diff --git a/app/agent/tools/impl/ask_user_choice.py b/app/agent/tools/impl/ask_user_choice.py
index c219b97e..e482c396 100644
--- a/app/agent/tools/impl/ask_user_choice.py
+++ b/app/agent/tools/impl/ask_user_choice.py
@@ -5,7 +5,7 @@ from typing import List, Optional, Type
from pydantic import BaseModel, Field, model_validator
from app.agent.tools.base import MoviePilotTool, ToolChain
-from app.chain.interaction import (
+from app.helper.interaction import (
AgentInteractionOption,
agent_interaction_manager,
)
diff --git a/app/chain/interaction.py b/app/chain/interaction.py
deleted file mode 100644
index c839e611..00000000
--- a/app/chain/interaction.py
+++ /dev/null
@@ -1,1363 +0,0 @@
-import math
-import re
-import uuid
-from dataclasses import dataclass, field
-from datetime import datetime, timedelta
-from threading import Lock
-from typing import Any, Dict, List, Optional, Tuple, Union
-
-from app.chain import ChainBase
-from app.chain.download import DownloadChain
-from app.chain.media import MediaChain
-from app.chain.search import SearchChain
-from app.chain.subscribe import SubscribeChain
-from app.core.config import settings
-from app.core.context import Context, MediaInfo
-from app.core.meta import MetaBase
-from app.db.user_oper import UserOper
-from app.helper.torrent import TorrentHelper
-from app.log import logger
-from app.schemas import Notification, NotExistMediaInfo
-from app.schemas.message import ChannelCapabilityManager
-from app.schemas.types import MediaType, MessageChannel
-from app.utils.string import StringUtils
-
-
-@dataclass(frozen=True)
-class AgentInteractionOption:
- """
- Agent 交互选项。
- """
-
- label: str
- value: str
-
-
-@dataclass
-class PendingAgentInteraction:
- """
- 待处理的 Agent 客户端交互请求。
- """
-
- request_id: str
- session_id: str
- user_id: str
- channel: Optional[str]
- source: Optional[str]
- username: Optional[str]
- title: Optional[str]
- prompt: str
- options: List[AgentInteractionOption]
- created_at: datetime = field(default_factory=datetime.now)
-
-
-class AgentInteractionManager:
- """
- 管理 Agent 发起的客户端交互请求。
- """
-
- _ttl = timedelta(hours=24)
-
- def __init__(self):
- self._pending_interactions: Dict[str, PendingAgentInteraction] = {}
- self._lock = Lock()
-
- def _cleanup_locked(self) -> None:
- expire_before = datetime.now() - self._ttl
- expired_ids = [
- request_id
- for request_id, request in self._pending_interactions.items()
- if request.created_at < expire_before
- ]
- for request_id in expired_ids:
- self._pending_interactions.pop(request_id, None)
-
- def create_request(
- self,
- session_id: str,
- user_id: str,
- channel: Optional[str],
- source: Optional[str],
- username: Optional[str],
- title: Optional[str],
- prompt: str,
- options: List[AgentInteractionOption],
- ) -> PendingAgentInteraction:
- """
- 创建一条待用户确认的 Agent 交互请求。
- """
- with self._lock:
- self._cleanup_locked()
- request_id = uuid.uuid4().hex[:12]
- while request_id in self._pending_interactions:
- request_id = uuid.uuid4().hex[:12]
- request = PendingAgentInteraction(
- request_id=request_id,
- session_id=session_id,
- user_id=str(user_id),
- channel=channel,
- source=source,
- username=username,
- title=title,
- prompt=prompt,
- options=options,
- )
- self._pending_interactions[request_id] = request
- return request
-
- def resolve(
- self,
- request_id: str,
- option_index: int,
- user_id: Optional[str] = None,
- ) -> Optional[tuple[PendingAgentInteraction, AgentInteractionOption]]:
- """
- 消费一条 Agent 交互请求,并返回选中的选项。
- """
- with self._lock:
- self._cleanup_locked()
- request = self._pending_interactions.get(request_id)
- if not request:
- return None
- if user_id is not None and str(request.user_id) != str(user_id):
- return None
- if option_index < 1 or option_index > len(request.options):
- return None
- option = request.options[option_index - 1]
- self._pending_interactions.pop(request_id, None)
- return request, option
-
- def clear(self) -> None:
- """
- 清空所有 Agent 交互请求。
- """
- with self._lock:
- self._pending_interactions.clear()
-
-
-agent_interaction_manager = AgentInteractionManager()
-
-
-@dataclass
-class PendingMediaInteraction:
- """
- 记录一次搜索/下载/订阅交互的当前上下文。
- """
-
- request_id: str
- user_id: str
- channel: Optional[MessageChannel]
- source: Optional[str]
- username: Optional[str]
- action: str
- keyword: str
- phase: str = "media"
- page: int = 0
- title: str = ""
- meta: Optional[MetaBase] = None
- current_media: Optional[MediaInfo] = None
- items: List[Any] = field(default_factory=list)
- created_at: datetime = field(default_factory=datetime.now)
-
-
-class MediaInteractionManager:
- """
- 管理用户当前激活的媒体交互状态。
-
- 每个用户只保留一个有效会话,避免旧按钮与新一轮搜索混用。
- """
-
- _ttl = timedelta(hours=24)
-
- def __init__(self):
- self._by_id: Dict[str, PendingMediaInteraction] = {}
- self._by_user: Dict[str, str] = {}
- self._lock = Lock()
-
- def _cleanup_locked(self) -> None:
- """
- 清理超时会话,避免内存中残留旧交互状态。
- """
- expire_before = datetime.now() - self._ttl
- expired = [
- request_id
- for request_id, request in self._by_id.items()
- if request.created_at < expire_before
- ]
- for request_id in expired:
- request = self._by_id.pop(request_id, None)
- if request:
- self._by_user.pop(str(request.user_id), None)
-
- def create_or_replace(
- self,
- user_id: Union[str, int],
- channel: Optional[MessageChannel],
- source: Optional[str],
- username: Optional[str],
- action: str,
- keyword: str,
- title: str = "",
- meta: Optional[MetaBase] = None,
- items: Optional[List[Any]] = None,
- ) -> PendingMediaInteraction:
- """
- 为用户创建新的交互状态,并替换旧会话。
- """
- with self._lock:
- self._cleanup_locked()
- user_key = str(user_id)
- old_request_id = self._by_user.get(user_key)
- if old_request_id:
- self._by_id.pop(old_request_id, None)
-
- request = PendingMediaInteraction(
- request_id=uuid.uuid4().hex[:12],
- user_id=user_key,
- channel=channel,
- source=source,
- username=username,
- action=action,
- keyword=keyword,
- title=title,
- meta=meta,
- items=list(items or []),
- )
- self._by_id[request.request_id] = request
- self._by_user[user_key] = request.request_id
- return request
-
- def get_by_user(
- self, user_id: Union[str, int]
- ) -> Optional[PendingMediaInteraction]:
- """
- 按用户读取当前会话,供文本回复和旧按钮兼容使用。
- """
- with self._lock:
- self._cleanup_locked()
- request_id = self._by_user.get(str(user_id))
- if not request_id:
- return None
- return self._by_id.get(request_id)
-
- def get_by_id(
- self, request_id: str, user_id: Union[str, int]
- ) -> Optional[PendingMediaInteraction]:
- """
- 按请求 ID 读取会话,并校验用户归属。
- """
- with self._lock:
- self._cleanup_locked()
- request = self._by_id.get(request_id)
- if not request or str(request.user_id) != str(user_id):
- return None
- return request
-
- def remove(self, request_id: str) -> None:
- """
- 主动结束一条会话。
- """
- with self._lock:
- request = self._by_id.pop(request_id, None)
- if request:
- self._by_user.pop(str(request.user_id), None)
-
- def clear(self) -> None:
- """
- 清空所有交互状态,主要用于测试。
- """
- with self._lock:
- self._by_id.clear()
- self._by_user.clear()
-
-
-media_interaction_manager = MediaInteractionManager()
-
-
-class MediaInteractionChain(ChainBase):
- """
- 处理媒体搜索、订阅、资源选择和翻页等交互流程。
- """
-
- _button_page_size = 8
- _text_page_size = 8
-
- @staticmethod
- def has_pending_interaction(user_id: Union[str, int]) -> bool:
- """
- 判断用户当前是否存在未结束的媒体交互。
- """
- return media_interaction_manager.get_by_user(user_id) is not None
-
- @staticmethod
- def _get_noexits_info(
- meta: MetaBase, mediainfo: MediaInfo
- ) -> Dict[Union[int, str], Dict[int, NotExistMediaInfo]]:
- """
- 构造媒体缺失集信息,用于全量重搜或自动下载补全集数。
- """
- if mediainfo.type == MediaType.TV:
- if not mediainfo.seasons:
- mediainfo = MediaChain().recognize_media(
- mtype=mediainfo.type,
- tmdbid=mediainfo.tmdb_id,
- doubanid=mediainfo.douban_id,
- cache=False,
- )
- if not mediainfo:
- logger.warn("媒体信息识别失败,无法补充季集信息")
- return {}
- if not mediainfo.seasons:
- logger.warn(
- "媒体信息中没有季集信息,标题:%s,tmdbid:%s,doubanid:%s",
- mediainfo.title,
- mediainfo.tmdb_id,
- mediainfo.douban_id,
- )
- return {}
-
- mediakey = mediainfo.tmdb_id or mediainfo.douban_id
- no_exists = {mediakey: {}}
- if meta.begin_season:
- episodes = mediainfo.seasons.get(meta.begin_season)
- if not episodes:
- return {}
- no_exists[mediakey][meta.begin_season] = NotExistMediaInfo(
- season=meta.begin_season,
- episodes=[],
- total_episode=len(episodes),
- start_episode=episodes[0],
- )
- else:
- for sea, eps in mediainfo.seasons.items():
- if not eps:
- continue
- no_exists[mediakey][sea] = NotExistMediaInfo(
- season=sea,
- episodes=[],
- total_episode=len(eps),
- start_episode=eps[0],
- )
- return no_exists
- return {}
-
- @staticmethod
- def parse_callback(
- callback_data: str,
- ) -> Optional[Tuple[Optional[str], str, Optional[int]]]:
- """
- 解析新旧两种媒体交互按钮格式。
- """
- if callback_data.startswith("media:"):
- parts = callback_data.split(":")
- if len(parts) < 3:
- return None
- request_id = parts[1]
- action = parts[2]
- index = None
- if len(parts) >= 4 and parts[3].isdigit():
- index = int(parts[3])
- return request_id, action, index
-
- match = re.match(r"^(select|download)_(\d+)$", callback_data)
- if match:
- return None, match.group(1), int(match.group(2))
- if callback_data == "page_p":
- return None, "page-prev", None
- if callback_data == "page_n":
- return None, "page-next", None
- return None
-
- def handle_callback_interaction(
- self,
- callback_data: str,
- channel: MessageChannel,
- source: str,
- userid: Union[str, int],
- username: str,
- original_message_id: Optional[Union[str, int]] = None,
- original_chat_id: Optional[str] = None,
- ) -> bool:
- """
- 处理按钮回调,并将当前视图刷新到原消息上。
- """
- parsed = self.parse_callback(callback_data)
- if not parsed:
- return False
-
- request_id, action, index = parsed
- if request_id:
- request = media_interaction_manager.get_by_id(request_id, userid)
- else:
- request = media_interaction_manager.get_by_user(userid)
-
- if not request:
- self.post_message(
- Notification(
- channel=channel,
- source=source,
- userid=userid,
- username=username,
- title="交互已失效,请重新搜索或订阅",
- )
- )
- return True
-
- request.channel = channel
- request.source = source
- request.username = username
-
- if action == "page-prev":
- if request.page <= 0:
- self._post_invalid_input(
- channel=channel,
- source=source,
- userid=userid,
- username=username,
- title="已经是第一页了!",
- )
- return True
- request.page -= 1
- self._render_interaction(
- request=request,
- channel=channel,
- source=source,
- userid=userid,
- original_message_id=original_message_id,
- original_chat_id=original_chat_id,
- )
- return True
-
- if action == "page-next":
- if not self._has_next_page(request):
- self._post_invalid_input(
- channel=channel,
- source=source,
- userid=userid,
- username=username,
- title="已经是最后一页了!",
- )
- return True
- request.page += 1
- self._render_interaction(
- request=request,
- channel=channel,
- source=source,
- userid=userid,
- original_message_id=original_message_id,
- original_chat_id=original_chat_id,
- )
- return True
-
- if action == "select":
- self._handle_media_selection(
- request=request,
- page_index=index,
- channel=channel,
- source=source,
- userid=userid,
- username=username,
- original_message_id=original_message_id,
- original_chat_id=original_chat_id,
- )
- return True
-
- if action == "download":
- self._handle_torrent_selection(
- request=request,
- page_index=index,
- channel=channel,
- source=source,
- userid=userid,
- username=username,
- )
- return True
-
- return False
-
- def handle_text_interaction(
- self,
- channel: MessageChannel,
- source: str,
- userid: Union[str, int],
- username: str,
- text: str,
- ) -> bool:
- """
- 处理文本式交互。
-
- 有会话时优先处理数字选择和翻页;无会话时负责识别搜索/订阅类入口。
- """
- request = media_interaction_manager.get_by_user(userid)
- normalized = (text or "").strip()
- lowered = normalized.lower()
-
- if request and lowered in {"退出", "关闭", "q", "quit", "exit"}:
- media_interaction_manager.remove(request.request_id)
- self.post_message(
- Notification(
- channel=channel,
- source=source,
- userid=userid,
- username=username,
- title="媒体交互已结束",
- )
- )
- return True
-
- if normalized.isdigit():
- if not request:
- self._post_invalid_input(
- channel=channel,
- source=source,
- userid=userid,
- username=username,
- )
- return True
- request.channel = channel
- request.source = source
- request.username = username
- index = int(normalized)
- if request.phase == "torrent":
- self._handle_torrent_selection(
- request=request,
- page_index=index,
- channel=channel,
- source=source,
- userid=userid,
- username=username,
- )
- else:
- self._handle_media_selection(
- request=request,
- page_index=index,
- channel=channel,
- source=source,
- userid=userid,
- username=username,
- )
- return True
-
- if lowered in {"p", "prev", "上一页"}:
- if not request:
- self._post_invalid_input(
- channel=channel,
- source=source,
- userid=userid,
- username=username,
- )
- return True
- if request.page <= 0:
- self._post_invalid_input(
- channel=channel,
- source=source,
- userid=userid,
- username=username,
- title="已经是第一页了!",
- )
- return True
- request.page -= 1
- request.channel = channel
- request.source = source
- request.username = username
- self._render_interaction(
- request=request,
- channel=channel,
- source=source,
- userid=userid,
- )
- return True
-
- if lowered in {"n", "next", "下一页"}:
- if not request:
- self._post_invalid_input(
- channel=channel,
- source=source,
- userid=userid,
- username=username,
- )
- return True
- if not self._has_next_page(request):
- self._post_invalid_input(
- channel=channel,
- source=source,
- userid=userid,
- username=username,
- title="已经是最后一页了!",
- )
- return True
- request.page += 1
- request.channel = channel
- request.source = source
- request.username = username
- self._render_interaction(
- request=request,
- channel=channel,
- source=source,
- userid=userid,
- )
- return True
-
- action, content = self._resolve_action(normalized)
- if not action:
- return False
-
- self._start_media_interaction(
- action=action,
- content=content,
- channel=channel,
- source=source,
- userid=userid,
- username=username,
- )
- return True
-
- @staticmethod
- def _resolve_action(text: str) -> Tuple[Optional[str], str]:
- """
- 将用户输入归类为搜索、订阅或普通聊天。
- """
- if text.startswith("订阅"):
- return "Subscribe", re.sub(r"订阅[::\s]*", "", text)
- if text.startswith("洗版"):
- return "ReSubscribe", re.sub(r"洗版[::\s]*", "", text)
- if text.startswith("搜索") or text.startswith("下载"):
- return "ReSearch", re.sub(r"(搜索|下载)[::\s]*", "", text)
- if StringUtils.is_link(text):
- return None, text
- if not StringUtils.is_media_title_like(text):
- return None, text
- return "Search", text
-
- def _start_media_interaction(
- self,
- action: str,
- content: str,
- channel: MessageChannel,
- source: str,
- userid: Union[str, int],
- username: str,
- ) -> None:
- """
- 根据用户输入搜索媒体,并进入媒体选择阶段。
- """
- meta, medias = MediaChain().search(content)
- if not meta.name:
- self._post_invalid_input(
- channel=channel,
- source=source,
- userid=userid,
- username=username,
- title="无法识别输入内容!",
- )
- return
- if not medias:
- self.post_message(
- Notification(
- channel=channel,
- source=source,
- userid=userid,
- username=username,
- title=f"{meta.name} 没有找到对应的媒体信息!",
- )
- )
- return
-
- logger.info("搜索到 %s 条相关媒体信息", len(medias))
- request = media_interaction_manager.create_or_replace(
- user_id=userid,
- channel=channel,
- source=source,
- username=username,
- action=action,
- keyword=content,
- title=meta.name,
- meta=meta,
- items=medias,
- )
- self._render_interaction(
- request=request,
- channel=channel,
- source=source,
- userid=userid,
- )
-
- def _handle_media_selection(
- self,
- request: PendingMediaInteraction,
- page_index: Optional[int],
- channel: MessageChannel,
- source: str,
- userid: Union[str, int],
- username: str,
- original_message_id: Optional[Union[str, int]] = None,
- original_chat_id: Optional[str] = None,
- ) -> None:
- """
- 处理媒体选择阶段的序号输入。
- """
- page_items, page, _ = self._page_items(
- items=request.items,
- page=request.page,
- page_size=self._page_size(request.channel),
- )
- request.page = page
- if not page_index or page_index < 1 or page_index > len(page_items):
- self._post_invalid_input(
- channel=channel,
- source=source,
- userid=userid,
- username=username,
- )
- return
-
- mediainfo: MediaInfo = page_items[page_index - 1]
- request.current_media = mediainfo
-
- if request.action in {"Search", "ReSearch"}:
- self._search_media_resources(
- request=request,
- mediainfo=mediainfo,
- channel=channel,
- source=source,
- userid=userid,
- username=username,
- original_message_id=original_message_id,
- original_chat_id=original_chat_id,
- )
- return
-
- if request.action in {"Subscribe", "ReSubscribe"}:
- self._subscribe_media(
- request=request,
- mediainfo=mediainfo,
- channel=channel,
- source=source,
- userid=userid,
- username=username,
- )
-
- def _search_media_resources(
- self,
- request: PendingMediaInteraction,
- mediainfo: MediaInfo,
- channel: MessageChannel,
- source: str,
- userid: Union[str, int],
- username: str,
- original_message_id: Optional[Union[str, int]] = None,
- original_chat_id: Optional[str] = None,
- ) -> None:
- """
- 根据已选媒体搜索资源,并切换到资源选择阶段。
- """
- exist_flag, no_exists = DownloadChain().get_no_exists_info(
- meta=request.meta,
- mediainfo=mediainfo,
- )
- if exist_flag and request.action == "Search":
- self.post_message(
- Notification(
- channel=channel,
- source=source,
- userid=userid,
- username=username,
- title=f"【{mediainfo.title_year}{request.meta.sea} 媒体库中已存在,如需重新下载请发送:搜索 名称 或 下载 名称】",
- )
- )
- return
- if exist_flag:
- no_exists = self._get_noexits_info(request.meta, mediainfo)
-
- messages = self._build_no_exists_messages(
- mediainfo=mediainfo,
- no_exists=no_exists,
- show_missing_only=request.action == "Search",
- )
- if messages:
- self.post_message(
- Notification(
- channel=channel,
- source=source,
- userid=userid,
- username=username,
- title=f"{mediainfo.title_year}:\n" + "\n".join(messages),
- )
- )
-
- logger.info("开始搜索 %s ...", mediainfo.title_year)
- self.post_message(
- Notification(
- channel=channel,
- source=source,
- userid=userid,
- username=username,
- title=f"开始搜索 {mediainfo.type.value} {mediainfo.title_year} ...",
- )
- )
-
- contexts = SearchChain().process(mediainfo=mediainfo, no_exists=no_exists)
- if not contexts:
- self.post_message(
- Notification(
- channel=channel,
- source=source,
- userid=userid,
- username=username,
- title=f"{mediainfo.title}{request.meta.sea} 未搜索到需要的资源!",
- )
- )
- return
-
- contexts = TorrentHelper().sort_torrents(contexts)
- if self._should_auto_download(userid):
- logger.info("用户 %s 在自动下载用户中,开始自动择优下载 ...", userid)
- self._auto_download(
- request=request,
- cache_list=contexts,
- channel=channel,
- source=source,
- userid=userid,
- username=username,
- no_exists=no_exists,
- )
- return
-
- request.phase = "torrent"
- request.page = 0
- request.title = mediainfo.title
- request.items = list(contexts)
- self._render_interaction(
- request=request,
- channel=channel,
- source=source,
- userid=userid,
- original_message_id=original_message_id,
- original_chat_id=original_chat_id,
- )
-
- def _subscribe_media(
- self,
- request: PendingMediaInteraction,
- mediainfo: MediaInfo,
- channel: MessageChannel,
- source: str,
- userid: Union[str, int],
- username: str,
- ) -> None:
- """
- 根据已选媒体创建订阅或洗版订阅。
- """
- best_version = request.action == "ReSubscribe"
- if not best_version:
- exist_flag, _ = DownloadChain().get_no_exists_info(
- meta=request.meta,
- mediainfo=mediainfo,
- )
- if exist_flag:
- self.post_message(
- Notification(
- channel=channel,
- source=source,
- userid=userid,
- username=username,
- title=f"【{mediainfo.title_year}{request.meta.sea} 媒体库中已存在,如需洗版请发送:洗版 XXX】",
- )
- )
- return
-
- mp_name = (
- UserOper().get_name(**{f"{channel.name.lower()}_userid": userid})
- if channel
- else None
- )
- SubscribeChain().add(
- title=mediainfo.title,
- year=mediainfo.year,
- mtype=mediainfo.type,
- tmdbid=mediainfo.tmdb_id,
- season=request.meta.begin_season,
- channel=channel,
- source=source,
- userid=userid,
- username=mp_name or username,
- best_version=best_version,
- )
-
- def _handle_torrent_selection(
- self,
- request: PendingMediaInteraction,
- page_index: Optional[int],
- channel: MessageChannel,
- source: str,
- userid: Union[str, int],
- username: str,
- ) -> None:
- """
- 处理资源选择阶段的下载操作。
- """
- if request.phase != "torrent":
- self._post_invalid_input(
- channel=channel,
- source=source,
- userid=userid,
- username=username,
- )
- return
-
- if page_index == 0:
- self._auto_download(
- request=request,
- cache_list=request.items,
- channel=channel,
- source=source,
- userid=userid,
- username=username,
- )
- return
-
- page_items, page, _ = self._page_items(
- items=request.items,
- page=request.page,
- page_size=self._page_size(request.channel),
- )
- request.page = page
- if not page_index or page_index < 1 or page_index > len(page_items):
- self._post_invalid_input(
- channel=channel,
- source=source,
- userid=userid,
- username=username,
- )
- return
-
- context: Context = page_items[page_index - 1]
- DownloadChain().download_single(
- context,
- channel=channel,
- source=source,
- userid=userid,
- username=username,
- )
-
- def _auto_download(
- self,
- request: PendingMediaInteraction,
- cache_list: List[Context],
- channel: MessageChannel,
- source: str,
- userid: Union[str, int],
- username: str,
- no_exists: Optional[Dict[Union[int, str], Dict[int, NotExistMediaInfo]]] = None,
- ) -> None:
- """
- 自动择优下载当前资源列表,并在未完成时补建订阅。
- """
- downloadchain = DownloadChain()
- if no_exists is None:
- exist_flag, no_exists = downloadchain.get_no_exists_info(
- meta=request.meta,
- mediainfo=request.current_media,
- )
- if exist_flag:
- no_exists = self._get_noexits_info(request.meta, request.current_media)
-
- downloads, lefts = downloadchain.batch_download(
- contexts=cache_list,
- no_exists=no_exists,
- channel=channel,
- source=source,
- userid=userid,
- username=username,
- )
- if downloads and not lefts:
- logger.info("%s 下载完成", request.current_media.title_year)
- return
-
- logger.info("%s 未下载未完整,添加订阅 ...", request.current_media.title_year)
- if downloads and request.current_media.type == MediaType.TV:
- note = [
- download.meta_info.begin_episode
- for download in downloads
- if download.meta_info.begin_episode
- ]
- else:
- note = None
-
- mp_name = (
- UserOper().get_name(**{f"{channel.name.lower()}_userid": userid})
- if channel
- else None
- )
- SubscribeChain().add(
- title=request.current_media.title,
- year=request.current_media.year,
- mtype=request.current_media.type,
- tmdbid=request.current_media.tmdb_id,
- season=request.meta.begin_season,
- channel=channel,
- source=source,
- userid=userid,
- username=mp_name or username,
- state="R",
- note=note,
- )
-
- def _render_interaction(
- self,
- request: PendingMediaInteraction,
- channel: MessageChannel,
- source: str,
- userid: Union[str, int],
- original_message_id: Optional[Union[str, int]] = None,
- original_chat_id: Optional[str] = None,
- ) -> None:
- """
- 按当前阶段渲染媒体列表或资源列表。
- """
- if request.phase == "torrent":
- self._post_torrents_message(
- request=request,
- channel=channel,
- source=source,
- userid=userid,
- original_message_id=original_message_id,
- original_chat_id=original_chat_id,
- )
- else:
- self._post_medias_message(
- request=request,
- channel=channel,
- source=source,
- userid=userid,
- original_message_id=original_message_id,
- original_chat_id=original_chat_id,
- )
-
- def _post_medias_message(
- self,
- request: PendingMediaInteraction,
- channel: MessageChannel,
- source: str,
- userid: Union[str, int],
- original_message_id: Optional[Union[str, int]] = None,
- original_chat_id: Optional[str] = None,
- ) -> None:
- """
- 发送或更新媒体选择列表。
- """
- page_items, page, total_pages = self._page_items(
- items=request.items,
- page=request.page,
- page_size=self._page_size(channel),
- )
- request.page = page
- total = len(request.items)
- if self._supports_interactive_buttons(channel):
- title = f"【{request.title}】共找到{total}条相关信息,请选择操作"
- buttons = self._create_media_buttons(
- channel=channel,
- request=request,
- items=page_items,
- total=total,
- total_pages=total_pages,
- )
- else:
- if total > self._page_size(channel):
- title = f"【{request.title}】共找到{total}条相关信息,请回复对应数字选择(p: 上一页 n: 下一页)"
- else:
- title = f"【{request.title}】共找到{total}条相关信息,请回复对应数字选择"
- buttons = None
-
- self.post_medias_message(
- Notification(
- channel=channel,
- source=source,
- title=title,
- userid=userid,
- buttons=buttons,
- original_message_id=original_message_id,
- original_chat_id=original_chat_id,
- ),
- medias=page_items,
- )
-
- def _post_torrents_message(
- self,
- request: PendingMediaInteraction,
- channel: MessageChannel,
- source: str,
- userid: Union[str, int],
- original_message_id: Optional[Union[str, int]] = None,
- original_chat_id: Optional[str] = None,
- ) -> None:
- """
- 发送或更新资源选择列表。
- """
- page_items, page, total_pages = self._page_items(
- items=request.items,
- page=request.page,
- page_size=self._page_size(channel),
- )
- request.page = page
- total = len(request.items)
- if self._supports_interactive_buttons(channel):
- title = f"【{request.title}】共找到{total}条相关资源,请选择下载"
- buttons = self._create_torrent_buttons(
- channel=channel,
- request=request,
- items=page_items,
- total=total,
- total_pages=total_pages,
- )
- else:
- if total > self._page_size(channel):
- title = f"【{request.title}】共找到{total}条相关资源,请回复对应数字下载(0: 自动选择 p: 上一页 n: 下一页)"
- else:
- title = f"【{request.title}】共找到{total}条相关资源,请回复对应数字下载(0: 自动选择)"
- buttons = None
-
- self.post_torrents_message(
- Notification(
- channel=channel,
- source=source,
- title=title,
- userid=userid,
- link=settings.MP_DOMAIN("#/resource"),
- buttons=buttons,
- original_message_id=original_message_id,
- original_chat_id=original_chat_id,
- ),
- torrents=page_items,
- )
-
- def _create_media_buttons(
- self,
- channel: MessageChannel,
- request: PendingMediaInteraction,
- items: List[MediaInfo],
- total: int,
- total_pages: int,
- ) -> List[List[Dict[str, str]]]:
- """
- 为媒体列表生成选择和翻页按钮。
- """
- buttons: List[List[Dict[str, str]]] = []
- max_text_length = ChannelCapabilityManager.get_max_button_text_length(channel)
- max_per_row = ChannelCapabilityManager.get_max_buttons_per_row(channel)
-
- current_row: List[Dict[str, str]] = []
- for index, media in enumerate(items, start=1):
- if max_per_row == 1:
- button_text = f"{index}. {media.title_year}"
- if len(button_text) > max_text_length:
- button_text = button_text[: max_text_length - 3] + "..."
- buttons.append(
- [
- {
- "text": button_text,
- "callback_data": f"media:{request.request_id}:select:{index}",
- }
- ]
- )
- continue
-
- current_row.append(
- {
- "text": f"{index}",
- "callback_data": f"media:{request.request_id}:select:{index}",
- }
- )
- if len(current_row) == max_per_row or index == len(items):
- buttons.append(current_row)
- current_row = []
-
- if total > self._page_size(channel):
- buttons.extend(self._navigation_buttons(request, total_pages))
- return buttons
-
- def _create_torrent_buttons(
- self,
- channel: MessageChannel,
- request: PendingMediaInteraction,
- items: List[Context],
- total: int,
- total_pages: int,
- ) -> List[List[Dict[str, str]]]:
- """
- 为资源列表生成下载和翻页按钮。
- """
- buttons: List[List[Dict[str, str]]] = [
- [
- {
- "text": "🤖 自动选择下载",
- "callback_data": f"media:{request.request_id}:download:0",
- }
- ]
- ]
- max_text_length = ChannelCapabilityManager.get_max_button_text_length(channel)
- max_per_row = ChannelCapabilityManager.get_max_buttons_per_row(channel)
-
- current_row: List[Dict[str, str]] = []
- for index, context in enumerate(items, start=1):
- torrent = context.torrent_info
- if max_per_row == 1:
- button_text = f"{index}. {torrent.site_name} - {torrent.seeders}↑"
- if len(button_text) > max_text_length:
- button_text = button_text[: max_text_length - 3] + "..."
- buttons.append(
- [
- {
- "text": button_text,
- "callback_data": f"media:{request.request_id}:download:{index}",
- }
- ]
- )
- continue
-
- current_row.append(
- {
- "text": f"{index}",
- "callback_data": f"media:{request.request_id}:download:{index}",
- }
- )
- if len(current_row) == max_per_row or index == len(items):
- buttons.append(current_row)
- current_row = []
-
- if total > self._page_size(channel):
- buttons.extend(self._navigation_buttons(request, total_pages))
- return buttons
-
- def _has_next_page(self, request: PendingMediaInteraction) -> bool:
- """
- 判断当前视图是否还有下一页。
- """
- _, page, total_pages = self._page_items(
- items=request.items,
- page=request.page,
- page_size=self._page_size(request.channel),
- )
- return page < total_pages - 1
-
- @staticmethod
- def _navigation_buttons(
- request: PendingMediaInteraction,
- total_pages: int,
- ) -> List[List[Dict[str, str]]]:
- """
- 按当前页状态生成上一页和下一页按钮。
- """
- buttons: List[List[Dict[str, str]]] = []
- nav_row: List[Dict[str, str]] = []
- if request.page > 0:
- nav_row.append(
- {
- "text": "⬅️ 上一页",
- "callback_data": f"media:{request.request_id}:page-prev",
- }
- )
- if request.page < total_pages - 1:
- nav_row.append(
- {
- "text": "下一页 ➡️",
- "callback_data": f"media:{request.request_id}:page-next",
- }
- )
- if nav_row:
- buttons.append(nav_row)
- return buttons
-
- @staticmethod
- def _page_items(
- items: List[Any],
- page: int,
- page_size: int,
- ) -> Tuple[List[Any], int, int]:
- """
- 返回当前页数据,并把页码限制在有效范围内。
- """
- total_pages = max(1, math.ceil(len(items) / page_size)) if page_size else 1
- page = min(max(0, page), total_pages - 1)
- start = page * page_size
- end = start + page_size
- return items[start:end], page, total_pages
-
- def _page_size(self, channel: Optional[MessageChannel]) -> int:
- """
- 按渠道交互能力选择分页大小。
- """
- return (
- self._button_page_size
- if self._supports_interactive_buttons(channel)
- else self._text_page_size
- )
-
- @staticmethod
- def _supports_interactive_buttons(channel: Optional[MessageChannel]) -> bool:
- """
- 判断渠道是否同时支持按钮展示与按钮回调。
- """
- return bool(
- channel
- and ChannelCapabilityManager.supports_buttons(channel)
- and ChannelCapabilityManager.supports_callbacks(channel)
- )
-
- @staticmethod
- def _build_no_exists_messages(
- mediainfo: MediaInfo,
- no_exists: Optional[Dict[Union[int, str], Dict[int, NotExistMediaInfo]]],
- show_missing_only: bool,
- ) -> List[str]:
- """
- 将缺失集信息转换为可发送的文案。
- """
- if not no_exists:
- return []
- mediakey = mediainfo.tmdb_id or mediainfo.douban_id
- season_map = no_exists.get(mediakey) or {}
- if show_missing_only:
- return [
- f"第 {sea} 季缺失 {StringUtils.str_series(no_exist.episodes) if no_exist.episodes else no_exist.total_episode} 集"
- for sea, no_exist in season_map.items()
- ]
- return [
- f"第 {sea} 季总 {no_exist.total_episode} 集"
- for sea, no_exist in season_map.items()
- ]
-
- @staticmethod
- def _should_auto_download(userid: Union[str, int]) -> bool:
- """
- 判断当前用户是否命中自动下载名单。
- """
- auto_download_user = settings.AUTO_DOWNLOAD_USER
- return bool(
- auto_download_user
- and (
- auto_download_user == "all"
- or any(userid == user for user in auto_download_user.split(","))
- )
- )
-
- def _post_invalid_input(
- self,
- channel: MessageChannel,
- source: str,
- userid: Union[str, int],
- username: Optional[str],
- title: str = "输入有误!",
- ) -> None:
- """
- 发送统一的非法输入提示。
- """
- self.post_message(
- Notification(
- channel=channel,
- source=source,
- userid=userid,
- username=username,
- title=title,
- )
- )
diff --git a/app/chain/media.py b/app/chain/media.py
index 08363f26..5833b56e 100644
--- a/app/chain/media.py
+++ b/app/chain/media.py
@@ -24,9 +24,9 @@ from app.schemas.types import (
ScrapingPolicy,
SystemConfigKey,
)
+from app.utils.http import RequestUtils
from app.utils.mixins import ConfigReloadMixin
from app.utils.singleton import Singleton
-from app.utils.http import RequestUtils
from app.utils.string import StringUtils
recognize_lock = Lock()
@@ -44,10 +44,10 @@ class ScrapingOption:
policy: ScrapingPolicy = ScrapingPolicy.MISSINGONLY
def __init__(
- self,
- type: Union[str, ScrapingTarget],
- metadata: Union[str, ScrapingMetadata],
- value: Union[ScrapingPolicy, bool, str],
+ self,
+ type: Union[str, ScrapingTarget],
+ metadata: Union[str, ScrapingMetadata],
+ value: Union[ScrapingPolicy, bool, str],
):
if isinstance(type, ScrapingTarget):
self.type = type
@@ -105,7 +105,7 @@ class ScrapingConfig:
self._policies[tuple(items)] = ScrapingOption(*items, value)
def option(
- self, item: Union[str, ScrapingTarget], metadata: Union[str, ScrapingMetadata]
+ self, item: Union[str, ScrapingTarget], metadata: Union[str, ScrapingMetadata]
) -> ScrapingOption:
if isinstance(item, ScrapingTarget):
@@ -173,11 +173,11 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
def on_config_changed(self):
self.scraping_policies = ScrapingConfig.from_system_config()
+ @staticmethod
def _should_scrape(
- self,
- scraping_option: ScrapingOption,
- file_exists: bool,
- global_overwrite: bool = False,
+ scraping_option: ScrapingOption,
+ file_exists: bool,
+ global_overwrite: bool = False,
) -> bool:
"""
判断是否应该执行刮削操作
@@ -211,7 +211,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
return False
def _save_file(
- self, fileitem: schemas.FileItem, path: Path, content: Union[bytes, str]
+ self, fileitem: schemas.FileItem, path: Path, content: Union[bytes, str]
):
"""
保存或上传文件
@@ -224,7 +224,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
return
# 使用tempfile创建临时文件
with NamedTemporaryFile(
- delete=True, delete_on_close=False, suffix=path.suffix
+ delete=True, delete_on_close=False, suffix=path.suffix
) as tmp_file:
tmp_file_path = Path(tmp_file.name)
# 写入内容
@@ -248,7 +248,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
logger.warn(f"文件保存失败:{path}")
def _download_and_save_image(
- self, fileitem: schemas.FileItem, path: Path, url: str
+ self, fileitem: schemas.FileItem, path: Path, url: str
):
"""
流式下载图片并保存到文件
@@ -268,7 +268,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
if r and r.status_code == 200:
# 使用tempfile创建临时文件,自动删除
with NamedTemporaryFile(
- delete=True, delete_on_close=False, suffix=path.suffix
+ delete=True, delete_on_close=False, suffix=path.suffix
) as tmp_file:
tmp_file_path = Path(tmp_file.name)
# 流式写入文件
@@ -295,12 +295,12 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
logger.error(f"{url} 图片下载失败:{str(err)}!")
def _get_target_fileitem_and_path(
- self,
- current_fileitem: schemas.FileItem,
- item_type: ScrapingTarget,
- metadata_type: ScrapingMetadata,
- filename_hint: Optional[str] = None,
- parent_fileitem: Optional[schemas.FileItem] = None,
+ self,
+ current_fileitem: schemas.FileItem,
+ item_type: ScrapingTarget,
+ metadata_type: ScrapingMetadata,
+ filename_hint: Optional[str] = None,
+ parent_fileitem: Optional[schemas.FileItem] = None,
) -> Tuple[schemas.FileItem, Optional[Path]]:
"""
根据当前上下文、刮削项类型和元数据类型生成目标 FileItem 和 Path
@@ -318,8 +318,8 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
# 电影文件NFO: 放在电影文件同级目录,名称与电影文件主体一致,后缀.nfo
final_filename = f"{target_dir_path.stem}.nfo"
target_dir_item = (
- parent_fileitem
- or self.storagechain.get_parent_item(current_fileitem)
+ parent_fileitem
+ or self.storagechain.get_parent_item(current_fileitem)
)
if not target_dir_item:
logger.error(
@@ -354,8 +354,8 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
# 图片通常是放在当前目录 (current_fileitem) 下
# 如果是 EPISODE 类型的图片(如thumb),通常也是放在文件同级目录,文件名与视频文件一致
elif (
- metadata_type in [ScrapingMetadata.THUMB]
- and item_type == ScrapingTarget.EPISODE
+ metadata_type in [ScrapingMetadata.THUMB]
+ and item_type == ScrapingTarget.EPISODE
):
hint_ext = Path(filename_hint).suffix if filename_hint else ".jpg"
final_filename = f"{target_dir_path.stem}{hint_ext}"
@@ -380,11 +380,11 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
return target_dir_item, target_full_path
def metadata_nfo(
- self,
- meta: MetaBase,
- mediainfo: MediaInfo,
- season: Optional[int] = None,
- episode: Optional[int] = None,
+ self,
+ meta: MetaBase,
+ mediainfo: MediaInfo,
+ season: Optional[int] = None,
+ episode: Optional[int] = None,
) -> Optional[str]:
"""
获取NFO文件内容文本
@@ -402,8 +402,9 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
episode=episode,
)
+ @staticmethod
def select_recognize_source(
- self, log_name: str, log_context: str, native_fn, plugin_fn
+ log_name: str, log_context: str, native_fn, plugin_fn
) -> Optional[MediaInfo]:
"""
选择识别模式,插件优先或原生优先
@@ -436,7 +437,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
return mediainfo
def recognize_by_meta(
- self, metainfo: MetaBase, episode_group: Optional[str] = None
+ self, metainfo: MetaBase, episode_group: Optional[str] = None
) -> Optional[MediaInfo]:
"""
根据主副标题识别媒体信息
@@ -513,7 +514,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
return self.recognize_media(meta=org_meta)
def recognize_by_path(
- self, path: str, episode_group: Optional[str] = None
+ self, path: str, episode_group: Optional[str] = None
) -> Optional[Context]:
"""
根据文件路径识别媒体信息
@@ -577,7 +578,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
return meta, medias
def get_tmdbinfo_by_doubanid(
- self, doubanid: str, mtype: MediaType = None
+ self, doubanid: str, mtype: MediaType = None
) -> Optional[dict]:
"""
根据豆瓣ID获取TMDB信息
@@ -648,7 +649,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
return None
def get_doubaninfo_by_tmdbid(
- self, tmdbid: int, mtype: MediaType = None, season: Optional[int] = None
+ self, tmdbid: int, mtype: MediaType = None, season: Optional[int] = None
) -> Optional[dict]:
"""
根据TMDBID获取豆瓣信息
@@ -752,8 +753,8 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
# 收集从根目录到文件的所有父目录
current_path = sub_path.parent
while (
- current_path != root_path
- and current_path.is_relative_to(root_path)
+ current_path != root_path
+ and current_path.is_relative_to(root_path)
):
all_dirs.add(current_path)
current_path = current_path.parent
@@ -805,15 +806,15 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
)
def _scrape_nfo_generic(
- self,
- current_fileitem: schemas.FileItem,
- meta: MetaBase,
- mediainfo: MediaInfo,
- item_type: ScrapingTarget,
- parent_fileitem: Optional[schemas.FileItem] = None,
- overwrite: bool = False,
- season_number: Optional[int] = None,
- episode_number: Optional[int] = None,
+ self,
+ current_fileitem: schemas.FileItem,
+ meta: MetaBase,
+ mediainfo: MediaInfo,
+ item_type: ScrapingTarget,
+ parent_fileitem: Optional[schemas.FileItem] = None,
+ overwrite: bool = False,
+ season_number: Optional[int] = None,
+ episode_number: Optional[int] = None,
):
"""
NFO 刮削
@@ -859,14 +860,14 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
logger.warn(f"{nfo_path.name} NFO 文件生成失败!")
def _scrape_images_generic(
- self,
- current_fileitem: schemas.FileItem,
- mediainfo: MediaInfo,
- item_type: ScrapingTarget,
- parent_fileitem: Optional[schemas.FileItem] = None,
- overwrite: bool = False,
- season_number: Optional[int] = None,
- episode_number: Optional[int] = None,
+ self,
+ current_fileitem: schemas.FileItem,
+ mediainfo: MediaInfo,
+ item_type: ScrapingTarget,
+ parent_fileitem: Optional[schemas.FileItem] = None,
+ overwrite: bool = False,
+ season_number: Optional[int] = None,
+ episode_number: Optional[int] = None,
):
"""
图片刮削
@@ -906,14 +907,14 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
# 判断是否匹配当前刮削的季号
if item_type == ScrapingTarget.TV and image_name.lower().startswith(
- "season"
+ "season"
):
logger.info(f"当前为电视剧根目录刮削,跳过季图片:{image_name}")
continue
if (
- item_type == ScrapingTarget.SEASON
- and season_number is not None
- and image_name.lower().startswith("season")
+ item_type == ScrapingTarget.SEASON
+ and season_number is not None
+ and image_name.lower().startswith("season")
):
# 检查是否只下载当前刮削季的图片
image_season_str = (
@@ -921,7 +922,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
)
if image_season_str is not None and image_season_str != str(
- season_number
+ season_number
).rjust(2, "0"):
logger.info(
f"当前刮削季为:{season_number},跳过非本季图片:{image_name}"
@@ -956,14 +957,14 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
)
def scrape_metadata(
- self,
- fileitem: schemas.FileItem,
- meta: MetaBase = None,
- mediainfo: MediaInfo = None,
- init_folder: bool = True,
- parent: schemas.FileItem = None,
- overwrite: bool = False,
- recursive: bool = True,
+ self,
+ fileitem: schemas.FileItem,
+ meta: MetaBase = None,
+ mediainfo: MediaInfo = None,
+ init_folder: bool = True,
+ parent: schemas.FileItem = None,
+ overwrite: bool = False,
+ recursive: bool = True,
):
"""
手动刮削媒体信息
@@ -982,7 +983,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
# 当前文件路径
filepath = Path(fileitem.path)
if fileitem.type == "file" and (
- not filepath.suffix or filepath.suffix.lower() not in settings.RMT_MEDIAEXT
+ not filepath.suffix or filepath.suffix.lower() not in settings.RMT_MEDIAEXT
):
return
@@ -1022,14 +1023,14 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
logger.info(f"{filepath.name} 刮削完成")
def _handle_movie_scraping(
- self,
- fileitem: schemas.FileItem,
- meta: MetaBase,
- mediainfo: MediaInfo,
- init_folder: bool,
- parent: schemas.FileItem,
- overwrite: bool,
- recursive: bool,
+ self,
+ fileitem: schemas.FileItem,
+ meta: MetaBase,
+ mediainfo: MediaInfo,
+ init_folder: bool,
+ parent: schemas.FileItem,
+ overwrite: bool,
+ recursive: bool,
):
"""
处理电影刮削
@@ -1051,20 +1052,18 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
meta=meta,
mediainfo=mediainfo,
init_folder=init_folder,
- parent=parent,
overwrite=overwrite,
recursive=recursive,
)
def _handle_movie_directory(
- self,
- fileitem: schemas.FileItem,
- meta: MetaBase,
- mediainfo: MediaInfo,
- init_folder: bool,
- parent: schemas.FileItem,
- overwrite: bool,
- recursive: bool,
+ self,
+ fileitem: schemas.FileItem,
+ meta: MetaBase,
+ mediainfo: MediaInfo,
+ init_folder: bool,
+ overwrite: bool,
+ recursive: bool,
):
"""
处理电影目录刮削
@@ -1105,14 +1104,14 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
)
def _handle_tv_scraping(
- self,
- fileitem: schemas.FileItem,
- meta: MetaBase,
- mediainfo: MediaInfo,
- init_folder: bool,
- parent: schemas.FileItem,
- overwrite: bool,
- recursive: bool,
+ self,
+ fileitem: schemas.FileItem,
+ meta: MetaBase,
+ mediainfo: MediaInfo,
+ init_folder: bool,
+ parent: schemas.FileItem,
+ overwrite: bool,
+ recursive: bool,
):
"""
处理电视剧刮削
@@ -1142,12 +1141,12 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
)
def _handle_tv_episode_file(
- self,
- fileitem: schemas.FileItem,
- filepath: Path,
- mediainfo: MediaInfo,
- parent: schemas.FileItem,
- overwrite: bool,
+ self,
+ fileitem: schemas.FileItem,
+ filepath: Path,
+ mediainfo: MediaInfo,
+ parent: schemas.FileItem,
+ overwrite: bool,
):
"""
处理电视剧集文件刮削
@@ -1191,15 +1190,15 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
)
def _handle_tv_directory(
- self,
- fileitem: schemas.FileItem,
- filepath: Path,
- meta: MetaBase,
- mediainfo: MediaInfo,
- init_folder: bool,
- parent: schemas.FileItem,
- overwrite: bool,
- recursive: bool,
+ self,
+ fileitem: schemas.FileItem,
+ filepath: Path,
+ meta: MetaBase,
+ mediainfo: MediaInfo,
+ init_folder: bool,
+ parent: schemas.FileItem,
+ overwrite: bool,
+ recursive: bool,
):
"""
处理电视剧目录刮削
@@ -1209,9 +1208,9 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
files = self.storagechain.list_files(fileitem=fileitem) or []
for file in files:
if (
- file.type == "dir"
- and file.name not in settings.RENAME_FORMAT_S0_NAMES
- and MetaInfo(file.name).begin_season is None
+ file.type == "dir"
+ and file.name not in settings.RENAME_FORMAT_S0_NAMES
+ and MetaInfo(file.name).begin_season is None
):
# 电视剧不处理非季子目录
continue
@@ -1235,13 +1234,13 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
)
def _initialize_tv_directory_metadata(
- self,
- fileitem: schemas.FileItem,
- filepath: Path,
- meta: MetaBase,
- mediainfo: MediaInfo,
- parent: schemas.FileItem,
- overwrite: bool,
+ self,
+ fileitem: schemas.FileItem,
+ filepath: Path,
+ meta: MetaBase,
+ mediainfo: MediaInfo,
+ parent: schemas.FileItem,
+ overwrite: bool,
):
"""
初始化电视剧目录元数据(识别季号并刮削)
@@ -1296,8 +1295,9 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
else:
logger.warn("无法识别元数据,跳过")
+ @staticmethod
async def async_select_recognize_source(
- self, log_name: str, log_context: str, native_fn, plugin_fn
+ log_name: str, log_context: str, native_fn, plugin_fn
) -> Optional[MediaInfo]:
"""
选择识别模式,插件优先或原生优先(异步版本)
@@ -1330,7 +1330,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
return mediainfo
async def async_recognize_by_meta(
- self, metainfo: MetaBase, episode_group: Optional[str] = None
+ self, metainfo: MetaBase, episode_group: Optional[str] = None
) -> Optional[MediaInfo]:
"""
根据主副标题识别媒体信息(异步版本)
@@ -1366,7 +1366,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
return mediainfo
async def async_recognize_help(
- self, title: str, org_meta: MetaBase
+ self, title: str, org_meta: MetaBase
) -> Optional[MediaInfo]:
"""
请求辅助识别,返回媒体信息(异步版本)
@@ -1417,7 +1417,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
return await self.async_recognize_media(meta=org_meta)
async def async_recognize_by_path(
- self, path: str, episode_group: Optional[str] = None
+ self, path: str, episode_group: Optional[str] = None
) -> Optional[Context]:
"""
根据文件路径识别媒体信息(异步版本)
@@ -1455,7 +1455,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
return Context(meta_info=file_meta, media_info=mediainfo)
async def async_search(
- self, title: str
+ self, title: str
) -> Tuple[Optional[MetaBase], List[MediaInfo]]:
"""
搜索媒体/人物信息(异步版本)
@@ -1502,7 +1502,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
@staticmethod
def _extract_year_from_tmdb(
- tmdbinfo: dict, season: Optional[int] = None
+ tmdbinfo: dict, season: Optional[int] = None
) -> Optional[str]:
"""
从TMDB信息中提取年份
@@ -1522,11 +1522,11 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
return year
def _match_tmdb_with_names(
- self,
- meta_names: list,
- year: Optional[str],
- mtype: MediaType,
- season: Optional[int] = None,
+ self,
+ meta_names: list,
+ year: Optional[str],
+ mtype: MediaType,
+ season: Optional[int] = None,
) -> Optional[dict]:
"""
使用名称列表匹配TMDB信息
@@ -1540,11 +1540,11 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
return None
async def _async_match_tmdb_with_names(
- self,
- meta_names: list,
- year: Optional[str],
- mtype: MediaType,
- season: Optional[int] = None,
+ self,
+ meta_names: list,
+ year: Optional[str],
+ mtype: MediaType,
+ season: Optional[int] = None,
) -> Optional[dict]:
"""
使用名称列表匹配TMDB信息(异步版本)
@@ -1558,7 +1558,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
return None
async def async_get_tmdbinfo_by_doubanid(
- self, doubanid: str, mtype: MediaType = None
+ self, doubanid: str, mtype: MediaType = None
) -> Optional[dict]:
"""
根据豆瓣ID获取TMDB信息(异步版本)
@@ -1629,7 +1629,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
return None
async def async_get_doubaninfo_by_tmdbid(
- self, tmdbid: int, mtype: MediaType = None, season: Optional[int] = None
+ self, tmdbid: int, mtype: MediaType = None, season: Optional[int] = None
) -> Optional[dict]:
"""
根据TMDBID获取豆瓣信息(异步版本)
diff --git a/app/chain/message.py b/app/chain/message.py
index 2f9c25cf..5e9d44eb 100644
--- a/app/chain/message.py
+++ b/app/chain/message.py
@@ -1,35 +1,40 @@
import asyncio
import base64
+import math
import mimetypes
import re
import time
import uuid
from datetime import datetime, timedelta
from pathlib import Path
-from typing import Any, Optional, Dict, Union, List
+from typing import Any, Optional, Dict, Union, List, Tuple
from urllib.parse import unquote, urlparse
from app.agent import ReplyMode, agent_manager, prompt_manager
+from app.agent.llm import LLMHelper
from app.chain import ChainBase
-from app.chain.interaction import (
- MediaInteractionChain,
- agent_interaction_manager,
- media_interaction_manager,
-)
+from app.chain.download import DownloadChain
+from app.chain.media import MediaChain
+from app.chain.search import SearchChain
from app.chain.site import SiteChain, site_interaction_manager
from app.chain.skills import SkillsChain, skills_interaction_manager
from app.chain.subscribe import SubscribeChain, subscribe_interaction_manager
from app.chain.transfer import TransferChain
from app.core.config import settings, global_vars
+from app.core.context import MediaInfo, Context
+from app.core.meta import MetaBase
from app.db.models import TransferHistory
from app.db.transferhistory_oper import TransferHistoryOper
-from app.agent.llm import LLMHelper
+from app.db.user_oper import UserOper
+from app.helper.interaction import agent_interaction_manager, media_interaction_manager, PendingMediaInteraction
+from app.helper.torrent import TorrentHelper
from app.helper.voice import VoiceHelper
from app.log import logger
-from app.schemas import Notification, CommingMessage
+from app.schemas import Notification, CommingMessage, NotExistMediaInfo
from app.schemas.message import ChannelCapabilityManager
-from app.schemas.types import EventType, MessageChannel
+from app.schemas.types import EventType, MessageChannel, MediaType
from app.utils.http import RequestUtils
+from app.utils.string import StringUtils
class MessageChain(ChainBase):
@@ -175,31 +180,31 @@ class MessageChain(ChainBase):
latest_slash_interaction = self._get_latest_slash_interaction(userid)
if latest_slash_interaction == "sites":
if SiteChain().handle_text_interaction(
- channel=channel,
- source=source,
- userid=userid,
- username=username,
- text=text,
+ channel=channel,
+ source=source,
+ userid=userid,
+ username=username,
+ text=text,
):
return
if latest_slash_interaction == "subscribes":
if SubscribeChain().handle_text_interaction(
- channel=channel,
- source=source,
- userid=userid,
- username=username,
- text=text,
+ channel=channel,
+ source=source,
+ userid=userid,
+ username=username,
+ text=text,
):
return
if latest_slash_interaction == "skills":
if SkillsChain().handle_text_interaction(
- channel=channel,
- source=source,
- userid=userid,
- username=username,
- text=text,
+ channel=channel,
+ source=source,
+ userid=userid,
+ username=username,
+ text=text,
):
return
@@ -378,9 +383,9 @@ class MessageChain(ChainBase):
"""
candidates = []
for name, manager in (
- ("sites", site_interaction_manager),
- ("subscribes", subscribe_interaction_manager),
- ("skills", skills_interaction_manager),
+ ("sites", site_interaction_manager),
+ ("subscribes", subscribe_interaction_manager),
+ ("skills", skills_interaction_manager),
):
request = manager.get_by_user(userid)
if request:
@@ -1567,3 +1572,1092 @@ class MessageChain(ChainBase):
except Exception as e:
logger.error(e)
return None
+
+
+class MediaInteractionChain(ChainBase):
+ """
+ 处理媒体搜索、订阅、资源选择和翻页等交互流程。
+ """
+
+ _button_page_size = 8
+ _text_page_size = 8
+
+ @staticmethod
+ def has_pending_interaction(user_id: Union[str, int]) -> bool:
+ """
+ 判断用户当前是否存在未结束的媒体交互。
+ """
+ return media_interaction_manager.get_by_user(user_id) is not None
+
+ @staticmethod
+ def _get_noexits_info(
+ meta: MetaBase, mediainfo: MediaInfo
+ ) -> Dict[Union[int, str], Dict[int, NotExistMediaInfo]]:
+ """
+ 构造媒体缺失集信息,用于全量重搜或自动下载补全集数。
+ """
+ if mediainfo.type == MediaType.TV:
+ if not mediainfo.seasons:
+ mediainfo = MediaChain().recognize_media(
+ mtype=mediainfo.type,
+ tmdbid=mediainfo.tmdb_id,
+ doubanid=mediainfo.douban_id,
+ cache=False,
+ )
+ if not mediainfo:
+ logger.warn("媒体信息识别失败,无法补充季集信息")
+ return {}
+ if not mediainfo.seasons:
+ logger.warn(
+ "媒体信息中没有季集信息,标题:%s,tmdbid:%s,doubanid:%s",
+ mediainfo.title,
+ mediainfo.tmdb_id,
+ mediainfo.douban_id,
+ )
+ return {}
+
+ mediakey = mediainfo.tmdb_id or mediainfo.douban_id
+ no_exists = {mediakey: {}}
+ if meta.begin_season:
+ episodes = mediainfo.seasons.get(meta.begin_season)
+ if not episodes:
+ return {}
+ no_exists[mediakey][meta.begin_season] = NotExistMediaInfo(
+ season=meta.begin_season,
+ episodes=[],
+ total_episode=len(episodes),
+ start_episode=episodes[0],
+ )
+ else:
+ for sea, eps in mediainfo.seasons.items():
+ if not eps:
+ continue
+ no_exists[mediakey][sea] = NotExistMediaInfo(
+ season=sea,
+ episodes=[],
+ total_episode=len(eps),
+ start_episode=eps[0],
+ )
+ return no_exists
+ return {}
+
+ @staticmethod
+ def parse_callback(
+ callback_data: str,
+ ) -> Optional[Tuple[Optional[str], str, Optional[int]]]:
+ """
+ 解析新旧两种媒体交互按钮格式。
+ """
+ if callback_data.startswith("media:"):
+ parts = callback_data.split(":")
+ if len(parts) < 3:
+ return None
+ request_id = parts[1]
+ action = parts[2]
+ index = None
+ if len(parts) >= 4 and parts[3].isdigit():
+ index = int(parts[3])
+ return request_id, action, index
+
+ match = re.match(r"^(select|download)_(\d+)$", callback_data)
+ if match:
+ return None, match.group(1), int(match.group(2))
+ if callback_data == "page_p":
+ return None, "page-prev", None
+ if callback_data == "page_n":
+ return None, "page-next", None
+ return None
+
+ def handle_callback_interaction(
+ self,
+ callback_data: str,
+ channel: MessageChannel,
+ source: str,
+ userid: Union[str, int],
+ username: str,
+ original_message_id: Optional[Union[str, int]] = None,
+ original_chat_id: Optional[str] = None,
+ ) -> bool:
+ """
+ 处理按钮回调,并将当前视图刷新到原消息上。
+ """
+ parsed = self.parse_callback(callback_data)
+ if not parsed:
+ return False
+
+ request_id, action, index = parsed
+ if request_id:
+ request = media_interaction_manager.get_by_id(request_id, userid)
+ else:
+ request = media_interaction_manager.get_by_user(userid)
+
+ if not request:
+ self.post_message(
+ Notification(
+ channel=channel,
+ source=source,
+ userid=userid,
+ username=username,
+ title="交互已失效,请重新搜索或订阅",
+ )
+ )
+ return True
+
+ request.channel = channel
+ request.source = source
+ request.username = username
+
+ if action == "page-prev":
+ if request.page <= 0:
+ self._post_invalid_input(
+ channel=channel,
+ source=source,
+ userid=userid,
+ username=username,
+ title="已经是第一页了!",
+ )
+ return True
+ request.page -= 1
+ self._render_interaction(
+ request=request,
+ channel=channel,
+ source=source,
+ userid=userid,
+ original_message_id=original_message_id,
+ original_chat_id=original_chat_id,
+ )
+ return True
+
+ if action == "page-next":
+ if not self._has_next_page(request):
+ self._post_invalid_input(
+ channel=channel,
+ source=source,
+ userid=userid,
+ username=username,
+ title="已经是最后一页了!",
+ )
+ return True
+ request.page += 1
+ self._render_interaction(
+ request=request,
+ channel=channel,
+ source=source,
+ userid=userid,
+ original_message_id=original_message_id,
+ original_chat_id=original_chat_id,
+ )
+ return True
+
+ if action == "select":
+ self._handle_media_selection(
+ request=request,
+ page_index=index,
+ channel=channel,
+ source=source,
+ userid=userid,
+ username=username,
+ original_message_id=original_message_id,
+ original_chat_id=original_chat_id,
+ )
+ return True
+
+ if action == "download":
+ self._handle_torrent_selection(
+ request=request,
+ page_index=index,
+ channel=channel,
+ source=source,
+ userid=userid,
+ username=username,
+ )
+ return True
+
+ return False
+
+ def handle_text_interaction(
+ self,
+ channel: MessageChannel,
+ source: str,
+ userid: Union[str, int],
+ username: str,
+ text: str,
+ ) -> bool:
+ """
+ 处理文本式交互。
+
+ 有会话时优先处理数字选择和翻页;无会话时负责识别搜索/订阅类入口。
+ """
+ request = media_interaction_manager.get_by_user(userid)
+ normalized = (text or "").strip()
+ lowered = normalized.lower()
+
+ if request and lowered in {"退出", "关闭", "q", "quit", "exit"}:
+ media_interaction_manager.remove(request.request_id)
+ self.post_message(
+ Notification(
+ channel=channel,
+ source=source,
+ userid=userid,
+ username=username,
+ title="媒体交互已结束",
+ )
+ )
+ return True
+
+ if normalized.isdigit():
+ if not request:
+ self._post_invalid_input(
+ channel=channel,
+ source=source,
+ userid=userid,
+ username=username,
+ )
+ return True
+ request.channel = channel
+ request.source = source
+ request.username = username
+ index = int(normalized)
+ if request.phase == "torrent":
+ self._handle_torrent_selection(
+ request=request,
+ page_index=index,
+ channel=channel,
+ source=source,
+ userid=userid,
+ username=username,
+ )
+ else:
+ self._handle_media_selection(
+ request=request,
+ page_index=index,
+ channel=channel,
+ source=source,
+ userid=userid,
+ username=username,
+ )
+ return True
+
+ if lowered in {"p", "prev", "上一页"}:
+ if not request:
+ self._post_invalid_input(
+ channel=channel,
+ source=source,
+ userid=userid,
+ username=username,
+ )
+ return True
+ if request.page <= 0:
+ self._post_invalid_input(
+ channel=channel,
+ source=source,
+ userid=userid,
+ username=username,
+ title="已经是第一页了!",
+ )
+ return True
+ request.page -= 1
+ request.channel = channel
+ request.source = source
+ request.username = username
+ self._render_interaction(
+ request=request,
+ channel=channel,
+ source=source,
+ userid=userid,
+ )
+ return True
+
+ if lowered in {"n", "next", "下一页"}:
+ if not request:
+ self._post_invalid_input(
+ channel=channel,
+ source=source,
+ userid=userid,
+ username=username,
+ )
+ return True
+ if not self._has_next_page(request):
+ self._post_invalid_input(
+ channel=channel,
+ source=source,
+ userid=userid,
+ username=username,
+ title="已经是最后一页了!",
+ )
+ return True
+ request.page += 1
+ request.channel = channel
+ request.source = source
+ request.username = username
+ self._render_interaction(
+ request=request,
+ channel=channel,
+ source=source,
+ userid=userid,
+ )
+ return True
+
+ action, content = self._resolve_action(normalized)
+ if not action:
+ return False
+
+ self._start_media_interaction(
+ action=action,
+ content=content,
+ channel=channel,
+ source=source,
+ userid=userid,
+ username=username,
+ )
+ return True
+
+ @staticmethod
+ def _resolve_action(text: str) -> Tuple[Optional[str], str]:
+ """
+ 将用户输入归类为搜索、订阅或普通聊天。
+ """
+ if text.startswith("订阅"):
+ return "Subscribe", re.sub(r"订阅[::\s]*", "", text)
+ if text.startswith("洗版"):
+ return "ReSubscribe", re.sub(r"洗版[::\s]*", "", text)
+ if text.startswith("搜索") or text.startswith("下载"):
+ return "ReSearch", re.sub(r"(搜索|下载)[::\s]*", "", text)
+ if StringUtils.is_link(text):
+ return None, text
+ if not StringUtils.is_media_title_like(text):
+ return None, text
+ return "Search", text
+
+ def _start_media_interaction(
+ self,
+ action: str,
+ content: str,
+ channel: MessageChannel,
+ source: str,
+ userid: Union[str, int],
+ username: str,
+ ) -> None:
+ """
+ 根据用户输入搜索媒体,并进入媒体选择阶段。
+ """
+ meta, medias = MediaChain().search(content)
+ if not meta.name:
+ self._post_invalid_input(
+ channel=channel,
+ source=source,
+ userid=userid,
+ username=username,
+ title="无法识别输入内容!",
+ )
+ return
+ if not medias:
+ self.post_message(
+ Notification(
+ channel=channel,
+ source=source,
+ userid=userid,
+ username=username,
+ title=f"{meta.name} 没有找到对应的媒体信息!",
+ )
+ )
+ return
+
+ logger.info("搜索到 %s 条相关媒体信息", len(medias))
+ request = media_interaction_manager.create_or_replace(
+ user_id=userid,
+ channel=channel,
+ source=source,
+ username=username,
+ action=action,
+ keyword=content,
+ title=meta.name,
+ meta=meta,
+ items=medias,
+ )
+ self._render_interaction(
+ request=request,
+ channel=channel,
+ source=source,
+ userid=userid,
+ )
+
+ def _handle_media_selection(
+ self,
+ request: PendingMediaInteraction,
+ page_index: Optional[int],
+ channel: MessageChannel,
+ source: str,
+ userid: Union[str, int],
+ username: str,
+ original_message_id: Optional[Union[str, int]] = None,
+ original_chat_id: Optional[str] = None,
+ ) -> None:
+ """
+ 处理媒体选择阶段的序号输入。
+ """
+ page_items, page, _ = self._page_items(
+ items=request.items,
+ page=request.page,
+ page_size=self._page_size(request.channel),
+ )
+ request.page = page
+ if not page_index or page_index < 1 or page_index > len(page_items):
+ self._post_invalid_input(
+ channel=channel,
+ source=source,
+ userid=userid,
+ username=username,
+ )
+ return
+
+ mediainfo: MediaInfo = page_items[page_index - 1]
+ request.current_media = mediainfo
+
+ if request.action in {"Search", "ReSearch"}:
+ self._search_media_resources(
+ request=request,
+ mediainfo=mediainfo,
+ channel=channel,
+ source=source,
+ userid=userid,
+ username=username,
+ original_message_id=original_message_id,
+ original_chat_id=original_chat_id,
+ )
+ return
+
+ if request.action in {"Subscribe", "ReSubscribe"}:
+ self._subscribe_media(
+ request=request,
+ mediainfo=mediainfo,
+ channel=channel,
+ source=source,
+ userid=userid,
+ username=username,
+ )
+
+ def _search_media_resources(
+ self,
+ request: PendingMediaInteraction,
+ mediainfo: MediaInfo,
+ channel: MessageChannel,
+ source: str,
+ userid: Union[str, int],
+ username: str,
+ original_message_id: Optional[Union[str, int]] = None,
+ original_chat_id: Optional[str] = None,
+ ) -> None:
+ """
+ 根据已选媒体搜索资源,并切换到资源选择阶段。
+ """
+ exist_flag, no_exists = DownloadChain().get_no_exists_info(
+ meta=request.meta,
+ mediainfo=mediainfo,
+ )
+ if exist_flag and request.action == "Search":
+ self.post_message(
+ Notification(
+ channel=channel,
+ source=source,
+ userid=userid,
+ username=username,
+ title=f"【{mediainfo.title_year}{request.meta.sea} 媒体库中已存在,如需重新下载请发送:搜索 名称 或 下载 名称】",
+ )
+ )
+ return
+ if exist_flag:
+ no_exists = self._get_noexits_info(request.meta, mediainfo)
+
+ messages = self._build_no_exists_messages(
+ mediainfo=mediainfo,
+ no_exists=no_exists,
+ show_missing_only=request.action == "Search",
+ )
+ if messages:
+ self.post_message(
+ Notification(
+ channel=channel,
+ source=source,
+ userid=userid,
+ username=username,
+ title=f"{mediainfo.title_year}:\n" + "\n".join(messages),
+ )
+ )
+
+ logger.info("开始搜索 %s ...", mediainfo.title_year)
+ self.post_message(
+ Notification(
+ channel=channel,
+ source=source,
+ userid=userid,
+ username=username,
+ title=f"开始搜索 {mediainfo.type.value} {mediainfo.title_year} ...",
+ )
+ )
+
+ contexts = SearchChain().process(mediainfo=mediainfo, no_exists=no_exists)
+ if not contexts:
+ self.post_message(
+ Notification(
+ channel=channel,
+ source=source,
+ userid=userid,
+ username=username,
+ title=f"{mediainfo.title}{request.meta.sea} 未搜索到需要的资源!",
+ )
+ )
+ return
+
+ contexts = TorrentHelper().sort_torrents(contexts)
+ if self._should_auto_download(userid):
+ logger.info("用户 %s 在自动下载用户中,开始自动择优下载 ...", userid)
+ self._auto_download(
+ request=request,
+ cache_list=contexts,
+ channel=channel,
+ source=source,
+ userid=userid,
+ username=username,
+ no_exists=no_exists,
+ )
+ return
+
+ request.phase = "torrent"
+ request.page = 0
+ request.title = mediainfo.title
+ request.items = list(contexts)
+ self._render_interaction(
+ request=request,
+ channel=channel,
+ source=source,
+ userid=userid,
+ original_message_id=original_message_id,
+ original_chat_id=original_chat_id,
+ )
+
+ def _subscribe_media(
+ self,
+ request: PendingMediaInteraction,
+ mediainfo: MediaInfo,
+ channel: MessageChannel,
+ source: str,
+ userid: Union[str, int],
+ username: str,
+ ) -> None:
+ """
+ 根据已选媒体创建订阅或洗版订阅。
+ """
+ best_version = request.action == "ReSubscribe"
+ if not best_version:
+ exist_flag, _ = DownloadChain().get_no_exists_info(
+ meta=request.meta,
+ mediainfo=mediainfo,
+ )
+ if exist_flag:
+ self.post_message(
+ Notification(
+ channel=channel,
+ source=source,
+ userid=userid,
+ username=username,
+ title=f"【{mediainfo.title_year}{request.meta.sea} 媒体库中已存在,如需洗版请发送:洗版 XXX】",
+ )
+ )
+ return
+
+ mp_name = (
+ UserOper().get_name(**{f"{channel.name.lower()}_userid": userid})
+ if channel
+ else None
+ )
+ SubscribeChain().add(
+ title=mediainfo.title,
+ year=mediainfo.year,
+ mtype=mediainfo.type,
+ tmdbid=mediainfo.tmdb_id,
+ season=request.meta.begin_season,
+ channel=channel,
+ source=source,
+ userid=userid,
+ username=mp_name or username,
+ best_version=best_version,
+ )
+
+ def _handle_torrent_selection(
+ self,
+ request: PendingMediaInteraction,
+ page_index: Optional[int],
+ channel: MessageChannel,
+ source: str,
+ userid: Union[str, int],
+ username: str,
+ ) -> None:
+ """
+ 处理资源选择阶段的下载操作。
+ """
+ if request.phase != "torrent":
+ self._post_invalid_input(
+ channel=channel,
+ source=source,
+ userid=userid,
+ username=username,
+ )
+ return
+
+ if page_index == 0:
+ self._auto_download(
+ request=request,
+ cache_list=request.items,
+ channel=channel,
+ source=source,
+ userid=userid,
+ username=username,
+ )
+ return
+
+ page_items, page, _ = self._page_items(
+ items=request.items,
+ page=request.page,
+ page_size=self._page_size(request.channel),
+ )
+ request.page = page
+ if not page_index or page_index < 1 or page_index > len(page_items):
+ self._post_invalid_input(
+ channel=channel,
+ source=source,
+ userid=userid,
+ username=username,
+ )
+ return
+
+ context: Context = page_items[page_index - 1]
+ DownloadChain().download_single(
+ context,
+ channel=channel,
+ source=source,
+ userid=userid,
+ username=username,
+ )
+
+ def _auto_download(
+ self,
+ request: PendingMediaInteraction,
+ cache_list: List[Context],
+ channel: MessageChannel,
+ source: str,
+ userid: Union[str, int],
+ username: str,
+ no_exists: Optional[Dict[Union[int, str], Dict[int, NotExistMediaInfo]]] = None,
+ ) -> None:
+ """
+ 自动择优下载当前资源列表,并在未完成时补建订阅。
+ """
+ downloadchain = DownloadChain()
+ if no_exists is None:
+ exist_flag, no_exists = downloadchain.get_no_exists_info(
+ meta=request.meta,
+ mediainfo=request.current_media,
+ )
+ if exist_flag:
+ no_exists = self._get_noexits_info(request.meta, request.current_media)
+
+ downloads, lefts = downloadchain.batch_download(
+ contexts=cache_list,
+ no_exists=no_exists,
+ channel=channel,
+ source=source,
+ userid=userid,
+ username=username,
+ )
+ if downloads and not lefts:
+ logger.info("%s 下载完成", request.current_media.title_year)
+ return
+
+ logger.info("%s 未下载未完整,添加订阅 ...", request.current_media.title_year)
+ if downloads and request.current_media.type == MediaType.TV:
+ note = [
+ download.meta_info.begin_episode
+ for download in downloads
+ if download.meta_info.begin_episode
+ ]
+ else:
+ note = None
+
+ mp_name = (
+ UserOper().get_name(**{f"{channel.name.lower()}_userid": userid})
+ if channel
+ else None
+ )
+ SubscribeChain().add(
+ title=request.current_media.title,
+ year=request.current_media.year,
+ mtype=request.current_media.type,
+ tmdbid=request.current_media.tmdb_id,
+ season=request.meta.begin_season,
+ channel=channel,
+ source=source,
+ userid=userid,
+ username=mp_name or username,
+ state="R",
+ note=note,
+ )
+
+ def _render_interaction(
+ self,
+ request: PendingMediaInteraction,
+ channel: MessageChannel,
+ source: str,
+ userid: Union[str, int],
+ original_message_id: Optional[Union[str, int]] = None,
+ original_chat_id: Optional[str] = None,
+ ) -> None:
+ """
+ 按当前阶段渲染媒体列表或资源列表。
+ """
+ if request.phase == "torrent":
+ self._post_torrents_message(
+ request=request,
+ channel=channel,
+ source=source,
+ userid=userid,
+ original_message_id=original_message_id,
+ original_chat_id=original_chat_id,
+ )
+ else:
+ self._post_medias_message(
+ request=request,
+ channel=channel,
+ source=source,
+ userid=userid,
+ original_message_id=original_message_id,
+ original_chat_id=original_chat_id,
+ )
+
+ def _post_medias_message(
+ self,
+ request: PendingMediaInteraction,
+ channel: MessageChannel,
+ source: str,
+ userid: Union[str, int],
+ original_message_id: Optional[Union[str, int]] = None,
+ original_chat_id: Optional[str] = None,
+ ) -> None:
+ """
+ 发送或更新媒体选择列表。
+ """
+ page_items, page, total_pages = self._page_items(
+ items=request.items,
+ page=request.page,
+ page_size=self._page_size(channel),
+ )
+ request.page = page
+ total = len(request.items)
+ if self._supports_interactive_buttons(channel):
+ title = f"【{request.title}】共找到{total}条相关信息,请选择操作"
+ buttons = self._create_media_buttons(
+ channel=channel,
+ request=request,
+ items=page_items,
+ total=total,
+ total_pages=total_pages,
+ )
+ else:
+ if total > self._page_size(channel):
+ title = f"【{request.title}】共找到{total}条相关信息,请回复对应数字选择(p: 上一页 n: 下一页)"
+ else:
+ title = f"【{request.title}】共找到{total}条相关信息,请回复对应数字选择"
+ buttons = None
+
+ self.post_medias_message(
+ Notification(
+ channel=channel,
+ source=source,
+ title=title,
+ userid=userid,
+ buttons=buttons,
+ original_message_id=original_message_id,
+ original_chat_id=original_chat_id,
+ ),
+ medias=page_items,
+ )
+
+ def _post_torrents_message(
+ self,
+ request: PendingMediaInteraction,
+ channel: MessageChannel,
+ source: str,
+ userid: Union[str, int],
+ original_message_id: Optional[Union[str, int]] = None,
+ original_chat_id: Optional[str] = None,
+ ) -> None:
+ """
+ 发送或更新资源选择列表。
+ """
+ page_items, page, total_pages = self._page_items(
+ items=request.items,
+ page=request.page,
+ page_size=self._page_size(channel),
+ )
+ request.page = page
+ total = len(request.items)
+ if self._supports_interactive_buttons(channel):
+ title = f"【{request.title}】共找到{total}条相关资源,请选择下载"
+ buttons = self._create_torrent_buttons(
+ channel=channel,
+ request=request,
+ items=page_items,
+ total=total,
+ total_pages=total_pages,
+ )
+ else:
+ if total > self._page_size(channel):
+ title = f"【{request.title}】共找到{total}条相关资源,请回复对应数字下载(0: 自动选择 p: 上一页 n: 下一页)"
+ else:
+ title = f"【{request.title}】共找到{total}条相关资源,请回复对应数字下载(0: 自动选择)"
+ buttons = None
+
+ self.post_torrents_message(
+ Notification(
+ channel=channel,
+ source=source,
+ title=title,
+ userid=userid,
+ link=settings.MP_DOMAIN("#/resource"),
+ buttons=buttons,
+ original_message_id=original_message_id,
+ original_chat_id=original_chat_id,
+ ),
+ torrents=page_items,
+ )
+
+ def _create_media_buttons(
+ self,
+ channel: MessageChannel,
+ request: PendingMediaInteraction,
+ items: List[MediaInfo],
+ total: int,
+ total_pages: int,
+ ) -> List[List[Dict[str, str]]]:
+ """
+ 为媒体列表生成选择和翻页按钮。
+ """
+ buttons: List[List[Dict[str, str]]] = []
+ max_text_length = ChannelCapabilityManager.get_max_button_text_length(channel)
+ max_per_row = ChannelCapabilityManager.get_max_buttons_per_row(channel)
+
+ current_row: List[Dict[str, str]] = []
+ for index, media in enumerate(items, start=1):
+ if max_per_row == 1:
+ button_text = f"{index}. {media.title_year}"
+ if len(button_text) > max_text_length:
+ button_text = button_text[: max_text_length - 3] + "..."
+ buttons.append(
+ [
+ {
+ "text": button_text,
+ "callback_data": f"media:{request.request_id}:select:{index}",
+ }
+ ]
+ )
+ continue
+
+ current_row.append(
+ {
+ "text": f"{index}",
+ "callback_data": f"media:{request.request_id}:select:{index}",
+ }
+ )
+ if len(current_row) == max_per_row or index == len(items):
+ buttons.append(current_row)
+ current_row = []
+
+ if total > self._page_size(channel):
+ buttons.extend(self._navigation_buttons(request, total_pages))
+ return buttons
+
+ def _create_torrent_buttons(
+ self,
+ channel: MessageChannel,
+ request: PendingMediaInteraction,
+ items: List[Context],
+ total: int,
+ total_pages: int,
+ ) -> List[List[Dict[str, str]]]:
+ """
+ 为资源列表生成下载和翻页按钮。
+ """
+ buttons: List[List[Dict[str, str]]] = [
+ [
+ {
+ "text": "🤖 自动选择下载",
+ "callback_data": f"media:{request.request_id}:download:0",
+ }
+ ]
+ ]
+ max_text_length = ChannelCapabilityManager.get_max_button_text_length(channel)
+ max_per_row = ChannelCapabilityManager.get_max_buttons_per_row(channel)
+
+ current_row: List[Dict[str, str]] = []
+ for index, context in enumerate(items, start=1):
+ torrent = context.torrent_info
+ if max_per_row == 1:
+ button_text = f"{index}. {torrent.site_name} - {torrent.seeders}↑"
+ if len(button_text) > max_text_length:
+ button_text = button_text[: max_text_length - 3] + "..."
+ buttons.append(
+ [
+ {
+ "text": button_text,
+ "callback_data": f"media:{request.request_id}:download:{index}",
+ }
+ ]
+ )
+ continue
+
+ current_row.append(
+ {
+ "text": f"{index}",
+ "callback_data": f"media:{request.request_id}:download:{index}",
+ }
+ )
+ if len(current_row) == max_per_row or index == len(items):
+ buttons.append(current_row)
+ current_row = []
+
+ if total > self._page_size(channel):
+ buttons.extend(self._navigation_buttons(request, total_pages))
+ return buttons
+
+ def _has_next_page(self, request: PendingMediaInteraction) -> bool:
+ """
+ 判断当前视图是否还有下一页。
+ """
+ _, page, total_pages = self._page_items(
+ items=request.items,
+ page=request.page,
+ page_size=self._page_size(request.channel),
+ )
+ return page < total_pages - 1
+
+ @staticmethod
+ def _navigation_buttons(
+ request: PendingMediaInteraction,
+ total_pages: int,
+ ) -> List[List[Dict[str, str]]]:
+ """
+ 按当前页状态生成上一页和下一页按钮。
+ """
+ buttons: List[List[Dict[str, str]]] = []
+ nav_row: List[Dict[str, str]] = []
+ if request.page > 0:
+ nav_row.append(
+ {
+ "text": "⬅️ 上一页",
+ "callback_data": f"media:{request.request_id}:page-prev",
+ }
+ )
+ if request.page < total_pages - 1:
+ nav_row.append(
+ {
+ "text": "下一页 ➡️",
+ "callback_data": f"media:{request.request_id}:page-next",
+ }
+ )
+ if nav_row:
+ buttons.append(nav_row)
+ return buttons
+
+ @staticmethod
+ def _page_items(
+ items: List[Any],
+ page: int,
+ page_size: int,
+ ) -> Tuple[List[Any], int, int]:
+ """
+ 返回当前页数据,并把页码限制在有效范围内。
+ """
+ total_pages = max(1, math.ceil(len(items) / page_size)) if page_size else 1
+ page = min(max(0, page), total_pages - 1)
+ start = page * page_size
+ end = start + page_size
+ return items[start:end], page, total_pages
+
+ def _page_size(self, channel: Optional[MessageChannel]) -> int:
+ """
+ 按渠道交互能力选择分页大小。
+ """
+ return (
+ self._button_page_size
+ if self._supports_interactive_buttons(channel)
+ else self._text_page_size
+ )
+
+ @staticmethod
+ def _supports_interactive_buttons(channel: Optional[MessageChannel]) -> bool:
+ """
+ 判断渠道是否同时支持按钮展示与按钮回调。
+ """
+ return bool(
+ channel
+ and ChannelCapabilityManager.supports_buttons(channel)
+ and ChannelCapabilityManager.supports_callbacks(channel)
+ )
+
+ @staticmethod
+ def _build_no_exists_messages(
+ mediainfo: MediaInfo,
+ no_exists: Optional[Dict[Union[int, str], Dict[int, NotExistMediaInfo]]],
+ show_missing_only: bool,
+ ) -> List[str]:
+ """
+ 将缺失集信息转换为可发送的文案。
+ """
+ if not no_exists:
+ return []
+ mediakey = mediainfo.tmdb_id or mediainfo.douban_id
+ season_map = no_exists.get(mediakey) or {}
+ if show_missing_only:
+ return [
+ f"第 {sea} 季缺失 {StringUtils.str_series(no_exist.episodes) if no_exist.episodes else no_exist.total_episode} 集"
+ for sea, no_exist in season_map.items()
+ ]
+ return [
+ f"第 {sea} 季总 {no_exist.total_episode} 集"
+ for sea, no_exist in season_map.items()
+ ]
+
+ @staticmethod
+ def _should_auto_download(userid: Union[str, int]) -> bool:
+ """
+ 判断当前用户是否命中自动下载名单。
+ """
+ auto_download_user = settings.AUTO_DOWNLOAD_USER
+ return bool(
+ auto_download_user
+ and (
+ auto_download_user == "all"
+ or any(userid == user for user in auto_download_user.split(","))
+ )
+ )
+
+ def _post_invalid_input(
+ self,
+ channel: MessageChannel,
+ source: str,
+ userid: Union[str, int],
+ username: Optional[str],
+ title: str = "输入有误!",
+ ) -> None:
+ """
+ 发送统一的非法输入提示。
+ """
+ self.post_message(
+ Notification(
+ channel=channel,
+ source=source,
+ userid=userid,
+ username=username,
+ title=title,
+ )
+ )
diff --git a/app/chain/site.py b/app/chain/site.py
index 54c6961f..e11e2a39 100644
--- a/app/chain/site.py
+++ b/app/chain/site.py
@@ -7,7 +7,7 @@ from urllib.parse import urljoin
from lxml import etree
from app.chain import ChainBase
-from app.helper.slash import (
+from app.helper.interaction import (
SlashInteractionManager,
build_navigation_buttons,
format_markdown_table,
@@ -1060,8 +1060,9 @@ class SiteChain(ChainBase):
original_chat_id=original_chat_id,
)
+ @staticmethod
def _format_site_list(
- self, site_list: List[Site], channel: Optional[MessageChannel]
+ site_list: List[Site], channel: Optional[MessageChannel]
) -> str:
"""
根据渠道能力格式化站点列表。
diff --git a/app/chain/skills.py b/app/chain/skills.py
index 648e324b..1625cdcb 100644
--- a/app/chain/skills.py
+++ b/app/chain/skills.py
@@ -1,145 +1,18 @@
import re
-from dataclasses import dataclass, field
-from datetime import datetime, timedelta
-from threading import Lock
-from typing import Dict, List, Optional, Tuple, Union
-import uuid
+from typing import List, Optional, Tuple, Union
from app.chain import ChainBase
-from app.helper.slash import (
+from app.helper.interaction import (
build_navigation_buttons,
page_items,
supports_interaction_buttons,
- update_or_post_message,
+ update_or_post_message, skills_interaction_manager, PendingSkillsInteraction,
)
from app.helper.skill import SkillHelper, SkillInfo
from app.schemas import Notification
from app.schemas.types import MessageChannel
-@dataclass
-class PendingSkillsInteraction:
- """
- 记录一次 /skills 会话的上下文,便于按钮和文本回复共用同一状态。
- """
-
- request_id: str
- user_id: str
- channel: Optional[MessageChannel]
- source: Optional[str]
- username: Optional[str]
- view: str = "root"
- local_page: int = 0
- market_page: int = 0
- market_query: str = ""
- awaiting_input: Optional[str] = None
- created_at: datetime = field(default_factory=datetime.now)
-
-
-class SkillsInteractionManager:
- """
- 管理用户当前的技能交互状态。
-
- 每个用户同一时间只保留一个有效会话,避免旧按钮继续生效。
- """
-
- _ttl = timedelta(hours=24)
-
- def __init__(self):
- self._by_id: Dict[str, PendingSkillsInteraction] = {}
- self._by_user: Dict[str, str] = {}
- self._lock = Lock()
-
- def _cleanup_locked(self):
- """
- 清理超时会话,避免按钮回调无限积累。
- """
- expire_before = datetime.now() - self._ttl
- expired = [
- request_id
- for request_id, request in self._by_id.items()
- if request.created_at < expire_before
- ]
- for request_id in expired:
- request = self._by_id.pop(request_id, None)
- if request:
- self._by_user.pop(str(request.user_id), None)
-
- def create_or_replace(
- self,
- user_id: Union[str, int],
- channel: Optional[MessageChannel],
- source: Optional[str],
- username: Optional[str],
- ) -> PendingSkillsInteraction:
- """
- 为用户创建新会话,并替换掉旧的技能交互状态。
- """
- with self._lock:
- self._cleanup_locked()
- user_key = str(user_id)
- old_request_id = self._by_user.get(user_key)
- if old_request_id:
- self._by_id.pop(old_request_id, None)
- request_id = uuid.uuid4().hex[:12]
- request = PendingSkillsInteraction(
- request_id=request_id,
- user_id=user_key,
- channel=channel,
- source=source,
- username=username,
- )
- self._by_id[request_id] = request
- self._by_user[user_key] = request_id
- return request
-
- def get_by_user(
- self, user_id: Union[str, int]
- ) -> Optional[PendingSkillsInteraction]:
- """
- 按用户获取当前有效会话,供纯文本回复路由使用。
- """
- with self._lock:
- self._cleanup_locked()
- request_id = self._by_user.get(str(user_id))
- if not request_id:
- return None
- return self._by_id.get(request_id)
-
- def get_by_id(
- self, request_id: str, user_id: Union[str, int]
- ) -> Optional[PendingSkillsInteraction]:
- """
- 按请求 ID 获取会话,并校验会话归属用户。
- """
- with self._lock:
- self._cleanup_locked()
- request = self._by_id.get(request_id)
- if not request or str(request.user_id) != str(user_id):
- return None
- return request
-
- def remove(self, request_id: str) -> None:
- """
- 主动结束会话,释放用户和请求 ID 的双向索引。
- """
- with self._lock:
- request = self._by_id.pop(request_id, None)
- if request:
- self._by_user.pop(str(request.user_id), None)
-
- def clear(self):
- """
- 清空所有会话,主要用于测试场景。
- """
- with self._lock:
- self._by_id.clear()
- self._by_user.clear()
-
-
-skills_interaction_manager = SkillsInteractionManager()
-
-
class SkillsChain(ChainBase):
"""
处理 /skills 指令、按钮回调和文本式技能管理交互。
@@ -153,11 +26,11 @@ class SkillsChain(ChainBase):
self.skillhelper = SkillHelper()
def remote_manage(
- self,
- arg_str: str,
- channel: MessageChannel,
- userid: Union[str, int],
- source: Optional[str] = None,
+ self,
+ arg_str: str,
+ channel: MessageChannel,
+ userid: Union[str, int],
+ source: Optional[str] = None,
):
"""
/skills 入口。创建新会话并渲染首屏菜单。
@@ -205,14 +78,14 @@ class SkillsChain(ChainBase):
return request_id, action, index
def handle_callback_interaction(
- self,
- callback_data: str,
- channel: MessageChannel,
- source: str,
- userid: Union[str, int],
- username: str,
- original_message_id: Optional[Union[str, int]] = None,
- original_chat_id: Optional[str] = None,
+ self,
+ callback_data: str,
+ channel: MessageChannel,
+ source: str,
+ userid: Union[str, int],
+ username: str,
+ original_message_id: Optional[Union[str, int]] = None,
+ original_chat_id: Optional[str] = None,
) -> bool:
"""
处理按钮交互,并在同一条消息上刷新当前视图。
@@ -364,12 +237,12 @@ class SkillsChain(ChainBase):
return True
def handle_text_interaction(
- self,
- channel: MessageChannel,
- source: str,
- userid: Union[str, int],
- username: str,
- text: str,
+ self,
+ channel: MessageChannel,
+ source: str,
+ userid: Union[str, int],
+ username: str,
+ text: str,
) -> bool:
"""
处理不支持按钮渠道上的文本指令,也兼容用户直接回复文字操作。
@@ -660,42 +533,42 @@ class SkillsChain(ChainBase):
return True
def _install_market_skill(
- self,
- request: PendingSkillsInteraction,
- page_index: int,
+ self,
+ request: PendingSkillsInteraction,
+ page_index: int,
) -> Tuple[bool, str]:
"""
按当前市场页的可见序号安装技能,避免跨页序号歧义。
"""
market_skills = self._get_market_skills(request=request)
- page_items, page, _ = self._page_items(
+ items, page, _ = self._page_items(
items=market_skills,
page=request.market_page,
page_size=self._page_size(request.channel),
)
request.market_page = page
- if page_index < 1 or page_index > len(page_items):
+ if page_index < 1 or page_index > len(items):
return False, "安装序号无效"
- return self.skillhelper.install_market_skill(page_items[page_index - 1])
+ return self.skillhelper.install_market_skill(items[page_index - 1])
def _remove_local_skill(
- self,
- request: PendingSkillsInteraction,
- page_index: int,
+ self,
+ request: PendingSkillsInteraction,
+ page_index: int,
) -> Tuple[bool, str]:
"""
按当前已安装页的可见序号删除技能,并拦截内置技能。
"""
local_skills = self.skillhelper.list_local_skills()
- page_items, page, _ = self._page_items(
+ items, page, _ = self._page_items(
items=local_skills,
page=request.local_page,
page_size=self._page_size(request.channel),
)
request.local_page = page
- if page_index < 1 or page_index > len(page_items):
+ if page_index < 1 or page_index > len(items):
return False, "删除序号无效"
- target = page_items[page_index - 1]
+ target = items[page_index - 1]
if not target.removable:
return False, f"技能 {target.id} 是内置技能,不能删除"
return self.skillhelper.remove_local_skill(target.id)
@@ -713,15 +586,15 @@ class SkillsChain(ChainBase):
return self.skillhelper.remove_custom_market_source(target.source)
def _render_interaction(
- self,
- request: PendingSkillsInteraction,
- channel: MessageChannel,
- source: Optional[str],
- userid: Union[str, int],
- username: Optional[str],
- original_message_id: Optional[Union[str, int]] = None,
- original_chat_id: Optional[str] = None,
- force_market_refresh: bool = False,
+ self,
+ request: PendingSkillsInteraction,
+ channel: MessageChannel,
+ source: Optional[str],
+ userid: Union[str, int],
+ username: Optional[str],
+ original_message_id: Optional[Union[str, int]] = None,
+ original_chat_id: Optional[str] = None,
+ force_market_refresh: bool = False,
) -> None:
"""
根据当前视图生成内容,并选择编辑原消息或发送新消息。
@@ -758,9 +631,9 @@ class SkillsChain(ChainBase):
)
def _build_root_view(
- self,
- request: PendingSkillsInteraction,
- force_market_refresh: bool = False,
+ self,
+ request: PendingSkillsInteraction,
+ force_market_refresh: bool = False,
) -> Tuple[str, str, Optional[List[List[dict]]]]:
"""
构建根菜单视图,汇总本地技能和市场概览。
@@ -809,14 +682,14 @@ class SkillsChain(ChainBase):
return "技能管理", "\n".join(text_lines), buttons
def _build_installed_view(
- self,
- request: PendingSkillsInteraction
+ self,
+ request: PendingSkillsInteraction
) -> Tuple[str, str, Optional[List[List[dict]]]]:
"""
构建已安装技能视图,列出来源和可删除状态。
"""
local_skills = self.skillhelper.list_local_skills()
- page_items, page, total_pages = self._page_items(
+ items, page, total_pages = self._page_items(
items=local_skills,
page=request.local_page,
page_size=self._page_size(request.channel),
@@ -824,11 +697,11 @@ class SkillsChain(ChainBase):
request.local_page = page
text_lines = [f"第 {page + 1}/{total_pages} 页,共 {len(local_skills)} 个技能"]
- if not page_items:
+ if not items:
text_lines.append("")
text_lines.append("当前没有已安装技能")
else:
- for index, skill in enumerate(page_items, start=1):
+ for index, skill in enumerate(items, start=1):
action = "可删除" if skill.removable else "内置不可删"
text_lines.extend(
[
@@ -869,9 +742,9 @@ class SkillsChain(ChainBase):
return "已安装技能", "\n".join(text_lines), buttons
def _build_market_view(
- self,
- request: PendingSkillsInteraction,
- force_market_refresh: bool = False,
+ self,
+ request: PendingSkillsInteraction,
+ force_market_refresh: bool = False,
) -> Tuple[str, str, Optional[List[List[dict]]]]:
"""
构建技能市场视图,仅展示尚未安装的技能。
@@ -880,7 +753,7 @@ class SkillsChain(ChainBase):
request=request,
force_market_refresh=force_market_refresh,
)
- page_items, page, total_pages = self._page_items(
+ items, page, total_pages = self._page_items(
items=market_skills,
page=request.market_page,
page_size=self._page_size(request.channel),
@@ -897,14 +770,14 @@ class SkillsChain(ChainBase):
"搜索输入中:直接回复关键词即可筛选市场技能,回复 `取消` 结束输入。",
]
)
- if not page_items:
+ if not items:
text_lines.append("")
if request.market_query:
text_lines.append("当前搜索没有匹配的市场技能")
else:
text_lines.append("当前没有可安装的市场技能")
else:
- for index, skill in enumerate(page_items, start=1):
+ for index, skill in enumerate(items, start=1):
text_lines.extend(
[
"",
@@ -970,8 +843,8 @@ class SkillsChain(ChainBase):
return "技能市场", "\n".join(text_lines), buttons
def _build_sources_view(
- self,
- request: PendingSkillsInteraction,
+ self,
+ request: PendingSkillsInteraction,
) -> Tuple[str, str, Optional[List[List[dict]]]]:
"""
构建技能源管理视图,提供自定义 GitHub 源的增删入口。
@@ -1052,9 +925,9 @@ class SkillsChain(ChainBase):
@staticmethod
def _page_items(
- items: List[SkillInfo],
- page: int,
- page_size: int,
+ items: List[SkillInfo],
+ page: int,
+ page_size: int,
) -> Tuple[List[SkillInfo], int, int]:
"""
返回当前页的数据,并把页码钳制到有效范围内。
@@ -1080,9 +953,9 @@ class SkillsChain(ChainBase):
@staticmethod
def _navigation_buttons(
- request: PendingSkillsInteraction,
- page: int,
- total_pages: int,
+ request: PendingSkillsInteraction,
+ page: int,
+ total_pages: int,
) -> List[List[dict]]:
"""
为分页视图生成上一页和下一页按钮。
@@ -1095,16 +968,16 @@ class SkillsChain(ChainBase):
)
def _update_or_post_message(
- self,
- channel: MessageChannel,
- source: Optional[str],
- userid: Union[str, int],
- username: Optional[str],
- title: str,
- text: str,
- buttons: Optional[List[List[dict]]] = None,
- original_message_id: Optional[Union[str, int]] = None,
- original_chat_id: Optional[str] = None,
+ self,
+ channel: MessageChannel,
+ source: Optional[str],
+ userid: Union[str, int],
+ username: Optional[str],
+ title: str,
+ text: str,
+ buttons: Optional[List[List[dict]]] = None,
+ original_message_id: Optional[Union[str, int]] = None,
+ original_chat_id: Optional[str] = None,
) -> None:
"""
优先编辑原消息,编辑失败时再回退为发送新消息。
@@ -1136,9 +1009,9 @@ class SkillsChain(ChainBase):
return "请输入 1、2、3、搜索 <关键词>、刷新 或 退出"
def _get_market_skills(
- self,
- request: PendingSkillsInteraction,
- force_market_refresh: bool = False,
+ self,
+ request: PendingSkillsInteraction,
+ force_market_refresh: bool = False,
) -> List[SkillInfo]:
"""
获取当前 /skills 会话可见的市场技能,并应用搜索词过滤。
@@ -1183,8 +1056,8 @@ class SkillsChain(ChainBase):
@staticmethod
def _apply_market_search(
- request: PendingSkillsInteraction,
- query: str,
+ request: PendingSkillsInteraction,
+ query: str,
) -> None:
"""
将会话切到市场搜索结果视图,并重置分页状态。
diff --git a/app/chain/subscribe.py b/app/chain/subscribe.py
index 17afae40..1f9ea7ba 100644
--- a/app/chain/subscribe.py
+++ b/app/chain/subscribe.py
@@ -1,6 +1,7 @@
import copy
import json
import random
+import re
import threading
import time
from datetime import datetime
@@ -11,7 +12,7 @@ from app.chain import ChainBase
from app.chain.download import DownloadChain
from app.chain.media import MediaChain
from app.chain.search import SearchChain
-from app.helper.slash import (
+from app.helper.interaction import (
SlashInteractionManager,
build_navigation_buttons,
format_markdown_table,
diff --git a/app/helper/__init__.py b/app/helper/__init__.py
index ff2e48a2..e69de29b 100644
--- a/app/helper/__init__.py
+++ b/app/helper/__init__.py
@@ -1 +0,0 @@
-from .cloudflare import under_challenge
diff --git a/app/helper/interaction.py b/app/helper/interaction.py
new file mode 100644
index 00000000..8cc69d65
--- /dev/null
+++ b/app/helper/interaction.py
@@ -0,0 +1,626 @@
+import math
+import uuid
+from dataclasses import dataclass, field
+from datetime import datetime, timedelta
+from threading import Lock
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+
+from app.core.context import MediaInfo
+from app.core.meta import MetaBase
+from app.schemas import Notification
+from app.schemas.message import ChannelCapabilityManager
+from app.schemas.types import MessageChannel
+
+
+@dataclass
+class PendingSlashInteraction:
+ """
+ 通用 slash 命令交互上下文。
+ """
+
+ request_id: str
+ user_id: str
+ channel: Optional[MessageChannel]
+ source: Optional[str]
+ username: Optional[str]
+ command: str
+ page: int = 0
+ awaiting_input: Optional[str] = None
+ created_at: datetime = field(default_factory=datetime.now)
+
+
+class SlashInteractionManager:
+ """
+ 管理单个 slash 命令的交互会话。
+ """
+
+ _ttl = timedelta(hours=24)
+
+ def __init__(self):
+ self._by_id: Dict[str, PendingSlashInteraction] = {}
+ self._by_user: Dict[str, str] = {}
+ self._lock = Lock()
+
+ def _cleanup_locked(self) -> None:
+ expire_before = datetime.now() - self._ttl
+ expired = [
+ request_id
+ for request_id, request in self._by_id.items()
+ if request.created_at < expire_before
+ ]
+ for request_id in expired:
+ request = self._by_id.pop(request_id, None)
+ if request:
+ self._by_user.pop(str(request.user_id), None)
+
+ def create_or_replace(
+ self,
+ user_id: Union[str, int],
+ command: str,
+ channel: Optional[MessageChannel],
+ source: Optional[str],
+ username: Optional[str],
+ ) -> PendingSlashInteraction:
+ with self._lock:
+ self._cleanup_locked()
+ user_key = str(user_id)
+ old_request_id = self._by_user.get(user_key)
+ if old_request_id:
+ self._by_id.pop(old_request_id, None)
+ request = PendingSlashInteraction(
+ request_id=uuid.uuid4().hex[:12],
+ user_id=user_key,
+ command=command,
+ channel=channel,
+ source=source,
+ username=username,
+ )
+ self._by_id[request.request_id] = request
+ self._by_user[user_key] = request.request_id
+ return request
+
+ def get_by_user(
+ self, user_id: Union[str, int]
+ ) -> Optional[PendingSlashInteraction]:
+ with self._lock:
+ self._cleanup_locked()
+ request_id = self._by_user.get(str(user_id))
+ if not request_id:
+ return None
+ return self._by_id.get(request_id)
+
+ def get_by_id(
+ self, request_id: str, user_id: Union[str, int]
+ ) -> Optional[PendingSlashInteraction]:
+ with self._lock:
+ self._cleanup_locked()
+ request = self._by_id.get(request_id)
+ if not request or str(request.user_id) != str(user_id):
+ return None
+ return request
+
+ def remove(self, request_id: str) -> None:
+ with self._lock:
+ request = self._by_id.pop(request_id, None)
+ if request:
+ self._by_user.pop(str(request.user_id), None)
+
+ def clear(self) -> None:
+ with self._lock:
+ self._by_id.clear()
+ self._by_user.clear()
+
+
+def supports_interaction_buttons(channel: Optional[MessageChannel]) -> bool:
+ """
+ 渠道同时支持按钮和回调时,优先使用按钮交互。
+ """
+ return bool(
+ channel
+ and ChannelCapabilityManager.supports_buttons(channel)
+ and ChannelCapabilityManager.supports_callbacks(channel)
+ )
+
+
+def supports_markdown(channel: Optional[MessageChannel]) -> bool:
+ """
+ 仅在支持 Markdown 的渠道上输出 Markdown 内容。
+ """
+ return bool(channel and ChannelCapabilityManager.supports_markdown(channel))
+
+
+def page_items(
+ items: Sequence[Any],
+ page: int,
+ page_size: int,
+) -> Tuple[List[Any], int, int]:
+ """
+ 对列表做分页并规范化页码。
+ """
+ total = len(items)
+ if total == 0:
+ return [], 0, 1
+ total_pages = max(1, math.ceil(total / max(1, page_size)))
+ page = min(max(0, page), total_pages - 1)
+ start = page * page_size
+ end = start + page_size
+ return list(items[start:end]), page, total_pages
+
+
+def build_navigation_buttons(
+ prefix: str,
+ request: Any,
+ page: int,
+ total_pages: int,
+) -> List[List[dict]]:
+ """
+ 构造标准上一页/下一页按钮。
+ """
+ buttons = []
+ nav_row = []
+ if page > 0:
+ nav_row.append(
+ {
+ "text": "⬅️ 上一页",
+ "callback_data": f"{prefix}:{request.request_id}:page-prev",
+ }
+ )
+ if page < total_pages - 1:
+ nav_row.append(
+ {
+ "text": "下一页 ➡️",
+ "callback_data": f"{prefix}:{request.request_id}:page-next",
+ }
+ )
+ if nav_row:
+ buttons.append(nav_row)
+ return buttons
+
+
+def update_or_post_message(
+ chain,
+ channel: MessageChannel,
+ source: Optional[str],
+ userid: Union[str, int],
+ username: Optional[str],
+ title: str,
+ text: str,
+ buttons: Optional[List[List[dict]]] = None,
+ original_message_id: Optional[Union[str, int]] = None,
+ original_chat_id: Optional[str] = None,
+) -> None:
+ """
+ 优先编辑原消息,失败时回退为发送新消息。
+ """
+ if (
+ original_message_id
+ and original_chat_id
+ and ChannelCapabilityManager.supports_editing(channel)
+ ):
+ edited = chain.edit_message(
+ channel=channel,
+ source=source,
+ message_id=original_message_id,
+ chat_id=original_chat_id,
+ title=title,
+ text=text,
+ buttons=buttons,
+ )
+ if edited:
+ return
+
+ chain.post_message(
+ Notification(
+ channel=channel,
+ source=source,
+ userid=userid,
+ username=username,
+ title=title,
+ text=text,
+ buttons=buttons,
+ )
+ )
+
+
+def escape_markdown_table_cell(value: object) -> str:
+ """
+ 最小化转义 Markdown 表格中的特殊字符。
+ """
+ text = str(value or "").replace("\n", "
")
+ return text.replace("|", "\\|")
+
+
+def format_markdown_table(
+ headers: Sequence[str],
+ rows: Sequence[Sequence[object]],
+) -> str:
+ """
+ 生成 Markdown 表格文本。
+ """
+ header_line = (
+ "| "
+ + " | ".join(escape_markdown_table_cell(item) for item in headers)
+ + " |"
+ )
+ separator_line = "| " + " | ".join("---" for _ in headers) + " |"
+ data_lines = [
+ "| "
+ + " | ".join(escape_markdown_table_cell(item) for item in row)
+ + " |"
+ for row in rows
+ ]
+ return "\n".join([header_line, separator_line, *data_lines])
+
+
+@dataclass
+class PendingMediaInteraction:
+ """
+ 记录一次搜索/下载/订阅交互的当前上下文。
+ """
+
+ request_id: str
+ user_id: str
+ channel: Optional[MessageChannel]
+ source: Optional[str]
+ username: Optional[str]
+ action: str
+ keyword: str
+ phase: str = "media"
+ page: int = 0
+ title: str = ""
+ meta: Optional[MetaBase] = None
+ current_media: Optional[MediaInfo] = None
+ items: List[Any] = field(default_factory=list)
+ created_at: datetime = field(default_factory=datetime.now)
+
+
+class MediaInteractionManager:
+ """
+ 管理用户当前激活的媒体交互状态。
+
+ 每个用户只保留一个有效会话,避免旧按钮与新一轮搜索混用。
+ """
+
+ _ttl = timedelta(hours=24)
+
+ def __init__(self):
+ self._by_id: Dict[str, PendingMediaInteraction] = {}
+ self._by_user: Dict[str, str] = {}
+ self._lock = Lock()
+
+ def _cleanup_locked(self) -> None:
+ """
+ 清理超时会话,避免内存中残留旧交互状态。
+ """
+ expire_before = datetime.now() - self._ttl
+ expired = [
+ request_id
+ for request_id, request in self._by_id.items()
+ if request.created_at < expire_before
+ ]
+ for request_id in expired:
+ request = self._by_id.pop(request_id, None)
+ if request:
+ self._by_user.pop(str(request.user_id), None)
+
+ def create_or_replace(
+ self,
+ user_id: Union[str, int],
+ channel: Optional[MessageChannel],
+ source: Optional[str],
+ username: Optional[str],
+ action: str,
+ keyword: str,
+ title: str = "",
+ meta: Optional[MetaBase] = None,
+ items: Optional[List[Any]] = None,
+ ) -> PendingMediaInteraction:
+ """
+ 为用户创建新的交互状态,并替换旧会话。
+ """
+ with self._lock:
+ self._cleanup_locked()
+ user_key = str(user_id)
+ old_request_id = self._by_user.get(user_key)
+ if old_request_id:
+ self._by_id.pop(old_request_id, None)
+
+ request = PendingMediaInteraction(
+ request_id=uuid.uuid4().hex[:12],
+ user_id=user_key,
+ channel=channel,
+ source=source,
+ username=username,
+ action=action,
+ keyword=keyword,
+ title=title,
+ meta=meta,
+ items=list(items or []),
+ )
+ self._by_id[request.request_id] = request
+ self._by_user[user_key] = request.request_id
+ return request
+
+ def get_by_user(
+ self, user_id: Union[str, int]
+ ) -> Optional[PendingMediaInteraction]:
+ """
+ 按用户读取当前会话,供文本回复和旧按钮兼容使用。
+ """
+ with self._lock:
+ self._cleanup_locked()
+ request_id = self._by_user.get(str(user_id))
+ if not request_id:
+ return None
+ return self._by_id.get(request_id)
+
+ def get_by_id(
+ self, request_id: str, user_id: Union[str, int]
+ ) -> Optional[PendingMediaInteraction]:
+ """
+ 按请求 ID 读取会话,并校验用户归属。
+ """
+ with self._lock:
+ self._cleanup_locked()
+ request = self._by_id.get(request_id)
+ if not request or str(request.user_id) != str(user_id):
+ return None
+ return request
+
+ def remove(self, request_id: str) -> None:
+ """
+ 主动结束一条会话。
+ """
+ with self._lock:
+ request = self._by_id.pop(request_id, None)
+ if request:
+ self._by_user.pop(str(request.user_id), None)
+
+ def clear(self) -> None:
+ """
+ 清空所有交互状态,主要用于测试。
+ """
+ with self._lock:
+ self._by_id.clear()
+ self._by_user.clear()
+
+
+media_interaction_manager = MediaInteractionManager()
+
+
+@dataclass(frozen=True)
+class AgentInteractionOption:
+ """
+ Agent 交互选项。
+ """
+
+ label: str
+ value: str
+
+
+@dataclass
+class PendingAgentInteraction:
+ """
+ 待处理的 Agent 客户端交互请求。
+ """
+
+ request_id: str
+ session_id: str
+ user_id: str
+ channel: Optional[str]
+ source: Optional[str]
+ username: Optional[str]
+ title: Optional[str]
+ prompt: str
+ options: List[AgentInteractionOption]
+ created_at: datetime = field(default_factory=datetime.now)
+
+
+class AgentInteractionManager:
+ """
+ 管理 Agent 发起的客户端交互请求。
+ """
+
+ _ttl = timedelta(hours=24)
+
+ def __init__(self):
+ self._pending_interactions: Dict[str, PendingAgentInteraction] = {}
+ self._lock = Lock()
+
+ def _cleanup_locked(self) -> None:
+ expire_before = datetime.now() - self._ttl
+ expired_ids = [
+ request_id
+ for request_id, request in self._pending_interactions.items()
+ if request.created_at < expire_before
+ ]
+ for request_id in expired_ids:
+ self._pending_interactions.pop(request_id, None)
+
+ def create_request(
+ self,
+ session_id: str,
+ user_id: str,
+ channel: Optional[str],
+ source: Optional[str],
+ username: Optional[str],
+ title: Optional[str],
+ prompt: str,
+ options: List[AgentInteractionOption],
+ ) -> PendingAgentInteraction:
+ """
+ 创建一条待用户确认的 Agent 交互请求。
+ """
+ with self._lock:
+ self._cleanup_locked()
+ request_id = uuid.uuid4().hex[:12]
+ while request_id in self._pending_interactions:
+ request_id = uuid.uuid4().hex[:12]
+ request = PendingAgentInteraction(
+ request_id=request_id,
+ session_id=session_id,
+ user_id=str(user_id),
+ channel=channel,
+ source=source,
+ username=username,
+ title=title,
+ prompt=prompt,
+ options=options,
+ )
+ self._pending_interactions[request_id] = request
+ return request
+
+ def resolve(
+ self,
+ request_id: str,
+ option_index: int,
+ user_id: Optional[str] = None,
+ ) -> Optional[tuple[PendingAgentInteraction, AgentInteractionOption]]:
+ """
+ 消费一条 Agent 交互请求,并返回选中的选项。
+ """
+ with self._lock:
+ self._cleanup_locked()
+ request = self._pending_interactions.get(request_id)
+ if not request:
+ return None
+ if user_id is not None and str(request.user_id) != str(user_id):
+ return None
+ if option_index < 1 or option_index > len(request.options):
+ return None
+ option = request.options[option_index - 1]
+ self._pending_interactions.pop(request_id, None)
+ return request, option
+
+ def clear(self) -> None:
+ """
+ 清空所有 Agent 交互请求。
+ """
+ with self._lock:
+ self._pending_interactions.clear()
+
+
+agent_interaction_manager = AgentInteractionManager()
+
+
+@dataclass
+class PendingSkillsInteraction:
+ """
+ 记录一次 /skills 会话的上下文,便于按钮和文本回复共用同一状态。
+ """
+
+ request_id: str
+ user_id: str
+ channel: Optional[MessageChannel]
+ source: Optional[str]
+ username: Optional[str]
+ view: str = "root"
+ local_page: int = 0
+ market_page: int = 0
+ market_query: str = ""
+ awaiting_input: Optional[str] = None
+ created_at: datetime = field(default_factory=datetime.now)
+
+
+class SkillsInteractionManager:
+ """
+ 管理用户当前的技能交互状态。
+
+ 每个用户同一时间只保留一个有效会话,避免旧按钮继续生效。
+ """
+
+ _ttl = timedelta(hours=24)
+
+ def __init__(self):
+ self._by_id: Dict[str, PendingSkillsInteraction] = {}
+ self._by_user: Dict[str, str] = {}
+ self._lock = Lock()
+
+ def _cleanup_locked(self):
+ """
+ 清理超时会话,避免按钮回调无限积累。
+ """
+ expire_before = datetime.now() - self._ttl
+ expired = [
+ request_id
+ for request_id, request in self._by_id.items()
+ if request.created_at < expire_before
+ ]
+ for request_id in expired:
+ request = self._by_id.pop(request_id, None)
+ if request:
+ self._by_user.pop(str(request.user_id), None)
+
+ def create_or_replace(
+ self,
+ user_id: Union[str, int],
+ channel: Optional[MessageChannel],
+ source: Optional[str],
+ username: Optional[str],
+ ) -> PendingSkillsInteraction:
+ """
+ 为用户创建新会话,并替换掉旧的技能交互状态。
+ """
+ with self._lock:
+ self._cleanup_locked()
+ user_key = str(user_id)
+ old_request_id = self._by_user.get(user_key)
+ if old_request_id:
+ self._by_id.pop(old_request_id, None)
+ request_id = uuid.uuid4().hex[:12]
+ request = PendingSkillsInteraction(
+ request_id=request_id,
+ user_id=user_key,
+ channel=channel,
+ source=source,
+ username=username,
+ )
+ self._by_id[request_id] = request
+ self._by_user[user_key] = request_id
+ return request
+
+ def get_by_user(
+ self, user_id: Union[str, int]
+ ) -> Optional[PendingSkillsInteraction]:
+ """
+ 按用户获取当前有效会话,供纯文本回复路由使用。
+ """
+ with self._lock:
+ self._cleanup_locked()
+ request_id = self._by_user.get(str(user_id))
+ if not request_id:
+ return None
+ return self._by_id.get(request_id)
+
+ def get_by_id(
+ self, request_id: str, user_id: Union[str, int]
+ ) -> Optional[PendingSkillsInteraction]:
+ """
+ 按请求 ID 获取会话,并校验会话归属用户。
+ """
+ with self._lock:
+ self._cleanup_locked()
+ request = self._by_id.get(request_id)
+ if not request or str(request.user_id) != str(user_id):
+ return None
+ return request
+
+ def remove(self, request_id: str) -> None:
+ """
+ 主动结束会话,释放用户和请求 ID 的双向索引。
+ """
+ with self._lock:
+ request = self._by_id.pop(request_id, None)
+ if request:
+ self._by_user.pop(str(request.user_id), None)
+
+ def clear(self):
+ """
+ 清空所有会话,主要用于测试场景。
+ """
+ with self._lock:
+ self._by_id.clear()
+ self._by_user.clear()
+
+
+skills_interaction_manager = SkillsInteractionManager()
diff --git a/app/helper/slash.py b/app/helper/slash.py
deleted file mode 100644
index 49f7105e..00000000
--- a/app/helper/slash.py
+++ /dev/null
@@ -1,244 +0,0 @@
-import math
-import uuid
-from dataclasses import dataclass, field
-from datetime import datetime, timedelta
-from threading import Lock
-from typing import Dict, List, Optional, Sequence, Tuple, Union
-
-from app.schemas import Notification
-from app.schemas.message import ChannelCapabilityManager
-from app.schemas.types import MessageChannel
-
-
-@dataclass
-class PendingSlashInteraction:
- """
- 通用 slash 命令交互上下文。
- """
-
- request_id: str
- user_id: str
- channel: Optional[MessageChannel]
- source: Optional[str]
- username: Optional[str]
- command: str
- page: int = 0
- awaiting_input: Optional[str] = None
- created_at: datetime = field(default_factory=datetime.now)
-
-
-class SlashInteractionManager:
- """
- 管理单个 slash 命令的交互会话。
- """
-
- _ttl = timedelta(hours=24)
-
- def __init__(self):
- self._by_id: Dict[str, PendingSlashInteraction] = {}
- self._by_user: Dict[str, str] = {}
- self._lock = Lock()
-
- def _cleanup_locked(self) -> None:
- expire_before = datetime.now() - self._ttl
- expired = [
- request_id
- for request_id, request in self._by_id.items()
- if request.created_at < expire_before
- ]
- for request_id in expired:
- request = self._by_id.pop(request_id, None)
- if request:
- self._by_user.pop(str(request.user_id), None)
-
- def create_or_replace(
- self,
- user_id: Union[str, int],
- command: str,
- channel: Optional[MessageChannel],
- source: Optional[str],
- username: Optional[str],
- ) -> PendingSlashInteraction:
- with self._lock:
- self._cleanup_locked()
- user_key = str(user_id)
- old_request_id = self._by_user.get(user_key)
- if old_request_id:
- self._by_id.pop(old_request_id, None)
- request = PendingSlashInteraction(
- request_id=uuid.uuid4().hex[:12],
- user_id=user_key,
- command=command,
- channel=channel,
- source=source,
- username=username,
- )
- self._by_id[request.request_id] = request
- self._by_user[user_key] = request.request_id
- return request
-
- def get_by_user(
- self, user_id: Union[str, int]
- ) -> Optional[PendingSlashInteraction]:
- with self._lock:
- self._cleanup_locked()
- request_id = self._by_user.get(str(user_id))
- if not request_id:
- return None
- return self._by_id.get(request_id)
-
- def get_by_id(
- self, request_id: str, user_id: Union[str, int]
- ) -> Optional[PendingSlashInteraction]:
- with self._lock:
- self._cleanup_locked()
- request = self._by_id.get(request_id)
- if not request or str(request.user_id) != str(user_id):
- return None
- return request
-
- def remove(self, request_id: str) -> None:
- with self._lock:
- request = self._by_id.pop(request_id, None)
- if request:
- self._by_user.pop(str(request.user_id), None)
-
- def clear(self) -> None:
- with self._lock:
- self._by_id.clear()
- self._by_user.clear()
-
-
-def supports_interaction_buttons(channel: Optional[MessageChannel]) -> bool:
- """
- 渠道同时支持按钮和回调时,优先使用按钮交互。
- """
- return bool(
- channel
- and ChannelCapabilityManager.supports_buttons(channel)
- and ChannelCapabilityManager.supports_callbacks(channel)
- )
-
-
-def supports_markdown(channel: Optional[MessageChannel]) -> bool:
- """
- 仅在支持 Markdown 的渠道上输出 Markdown 内容。
- """
- return bool(channel and ChannelCapabilityManager.supports_markdown(channel))
-
-
-def page_items(
- items: Sequence,
- page: int,
- page_size: int,
-) -> Tuple[List, int, int]:
- """
- 对列表做分页并规范化页码。
- """
- total = len(items)
- if total == 0:
- return [], 0, 1
- total_pages = max(1, math.ceil(total / max(1, page_size)))
- page = min(max(0, page), total_pages - 1)
- start = page * page_size
- end = start + page_size
- return list(items[start:end]), page, total_pages
-
-
-def build_navigation_buttons(
- prefix: str,
- request: PendingSlashInteraction,
- page: int,
- total_pages: int,
-) -> List[List[dict]]:
- """
- 构造标准上一页/下一页按钮。
- """
- buttons = []
- nav_row = []
- if page > 0:
- nav_row.append(
- {
- "text": "⬅️ 上一页",
- "callback_data": f"{prefix}:{request.request_id}:page-prev",
- }
- )
- if page < total_pages - 1:
- nav_row.append(
- {
- "text": "下一页 ➡️",
- "callback_data": f"{prefix}:{request.request_id}:page-next",
- }
- )
- if nav_row:
- buttons.append(nav_row)
- return buttons
-
-
-def update_or_post_message(
- chain,
- channel: MessageChannel,
- source: Optional[str],
- userid: Union[str, int],
- username: Optional[str],
- title: str,
- text: str,
- buttons: Optional[List[List[dict]]] = None,
- original_message_id: Optional[Union[str, int]] = None,
- original_chat_id: Optional[str] = None,
-) -> None:
- """
- 优先编辑原消息,失败时回退为发送新消息。
- """
- if (
- original_message_id
- and original_chat_id
- and ChannelCapabilityManager.supports_editing(channel)
- ):
- edited = chain.edit_message(
- channel=channel,
- source=source,
- message_id=original_message_id,
- chat_id=original_chat_id,
- title=title,
- text=text,
- buttons=buttons,
- )
- if edited:
- return
-
- chain.post_message(
- Notification(
- channel=channel,
- source=source,
- userid=userid,
- username=username,
- title=title,
- text=text,
- buttons=buttons,
- )
- )
-
-
-def escape_markdown_table_cell(value: object) -> str:
- """
- 最小化转义 Markdown 表格中的特殊字符。
- """
- text = str(value or "").replace("\n", "
")
- text = text.replace("|", "\\|")
- return text
-
-
-def format_markdown_table(headers: Sequence[str], rows: Sequence[Sequence[object]]) -> str:
- """
- 生成 Markdown 表格文本。
- """
- header_line = "| " + " | ".join(escape_markdown_table_cell(item) for item in headers) + " |"
- separator_line = "| " + " | ".join("---" for _ in headers) + " |"
- data_lines = [
- "| "
- + " | ".join(escape_markdown_table_cell(item) for item in row)
- + " |"
- for row in rows
- ]
- return "\n".join([header_line, separator_line, *data_lines])
diff --git a/tests/test_agent_interaction.py b/tests/test_agent_interaction.py
index 957f6d9d..743b4307 100644
--- a/tests/test_agent_interaction.py
+++ b/tests/test_agent_interaction.py
@@ -8,7 +8,7 @@ from app.agent.tools.impl.ask_user_choice import (
AskUserChoiceTool,
UserChoiceOptionInput,
)
-from app.chain.interaction import (
+from app.helper.interaction import (
AgentInteractionOption,
agent_interaction_manager,
)
diff --git a/tests/test_media_interaction.py b/tests/test_media_interaction.py
index 862c1d93..2d049d4c 100644
--- a/tests/test_media_interaction.py
+++ b/tests/test_media_interaction.py
@@ -9,7 +9,7 @@ sys.modules.setdefault("transmission_rpc", ModuleType("transmission_rpc"))
setattr(sys.modules["transmission_rpc"], "File", object)
sys.modules.setdefault("psutil", ModuleType("psutil"))
-from app.chain.interaction import MediaInteractionChain, media_interaction_manager
+from app.chain.media import MediaChain, media_interaction_manager
from app.chain.message import MessageChain
from app.core.context import MediaInfo
from app.core.meta import MetaBase
@@ -43,7 +43,7 @@ class TestMediaInteraction(unittest.TestCase):
self.assertIsNotNone(request)
with patch.object(chain, "_record_user_message"), patch(
- "app.chain.message.MediaInteractionChain.handle_text_interaction",
+ "app.chain.message.MediaChain.handle_text_interaction",
return_value=True,
) as handle_text, patch.object(chain, "_handle_ai_message") as handle_ai:
chain.handle_message(
@@ -72,7 +72,7 @@ class TestMediaInteraction(unittest.TestCase):
)
with patch(
- "app.chain.message.MediaInteractionChain.handle_callback_interaction",
+ "app.chain.message.MediaChain.handle_callback_interaction",
return_value=True,
) as handle_callback:
chain._handle_callback(
@@ -86,7 +86,7 @@ class TestMediaInteraction(unittest.TestCase):
handle_callback.assert_called_once()
def test_media_interaction_starts_search_and_posts_media_list(self):
- chain = MediaInteractionChain()
+ chain = MediaChain()
meta = self._build_meta("星际穿越")
medias = [
MediaInfo(title="星际穿越", year="2014"),
@@ -94,7 +94,7 @@ class TestMediaInteraction(unittest.TestCase):
]
with patch(
- "app.chain.interaction.MediaChain.search",
+ "app.chain.media.MediaChain.search",
return_value=(meta, medias),
), patch.object(chain, "post_medias_message") as post_medias_message:
handled = chain.handle_text_interaction(
@@ -119,7 +119,7 @@ class TestMediaInteraction(unittest.TestCase):
self.assertEqual(len(request.items), 2)
def test_media_interaction_legacy_page_callback_updates_existing_request(self):
- chain = MediaInteractionChain()
+ chain = MediaChain()
request = media_interaction_manager.create_or_replace(
user_id="10001",
channel=MessageChannel.Telegram,