From 4d0a722b09aad780fccb5cc6f9f08309b5f39be3 Mon Sep 17 00:00:00 2001 From: jxxghp Date: Fri, 1 May 2026 09:53:04 +0800 Subject: [PATCH] refactor: reorganize interaction chain --- app/agent/tools/impl/ask_user_choice.py | 2 +- app/chain/interaction.py | 1363 ----------------------- app/chain/media.py | 270 ++--- app/chain/message.py | 1148 ++++++++++++++++++- app/chain/site.py | 5 +- app/chain/skills.py | 287 ++--- app/chain/subscribe.py | 3 +- app/helper/__init__.py | 1 - app/helper/interaction.py | 626 +++++++++++ app/helper/slash.py | 244 ---- tests/test_agent_interaction.py | 2 +- tests/test_media_interaction.py | 12 +- 12 files changed, 1975 insertions(+), 1988 deletions(-) delete mode 100644 app/chain/interaction.py create mode 100644 app/helper/interaction.py delete mode 100644 app/helper/slash.py diff --git a/app/agent/tools/impl/ask_user_choice.py b/app/agent/tools/impl/ask_user_choice.py index c219b97e..e482c396 100644 --- a/app/agent/tools/impl/ask_user_choice.py +++ b/app/agent/tools/impl/ask_user_choice.py @@ -5,7 +5,7 @@ from typing import List, Optional, Type from pydantic import BaseModel, Field, model_validator from app.agent.tools.base import MoviePilotTool, ToolChain -from app.chain.interaction import ( +from app.helper.interaction import ( AgentInteractionOption, agent_interaction_manager, ) diff --git a/app/chain/interaction.py b/app/chain/interaction.py deleted file mode 100644 index c839e611..00000000 --- a/app/chain/interaction.py +++ /dev/null @@ -1,1363 +0,0 @@ -import math -import re -import uuid -from dataclasses import dataclass, field -from datetime import datetime, timedelta -from threading import Lock -from typing import Any, Dict, List, Optional, Tuple, Union - -from app.chain import ChainBase -from app.chain.download import DownloadChain -from app.chain.media import MediaChain -from app.chain.search import SearchChain -from app.chain.subscribe import SubscribeChain -from app.core.config import settings -from app.core.context import Context, MediaInfo -from app.core.meta import MetaBase -from app.db.user_oper import UserOper -from app.helper.torrent import TorrentHelper -from app.log import logger -from app.schemas import Notification, NotExistMediaInfo -from app.schemas.message import ChannelCapabilityManager -from app.schemas.types import MediaType, MessageChannel -from app.utils.string import StringUtils - - -@dataclass(frozen=True) -class AgentInteractionOption: - """ - Agent 交互选项。 - """ - - label: str - value: str - - -@dataclass -class PendingAgentInteraction: - """ - 待处理的 Agent 客户端交互请求。 - """ - - request_id: str - session_id: str - user_id: str - channel: Optional[str] - source: Optional[str] - username: Optional[str] - title: Optional[str] - prompt: str - options: List[AgentInteractionOption] - created_at: datetime = field(default_factory=datetime.now) - - -class AgentInteractionManager: - """ - 管理 Agent 发起的客户端交互请求。 - """ - - _ttl = timedelta(hours=24) - - def __init__(self): - self._pending_interactions: Dict[str, PendingAgentInteraction] = {} - self._lock = Lock() - - def _cleanup_locked(self) -> None: - expire_before = datetime.now() - self._ttl - expired_ids = [ - request_id - for request_id, request in self._pending_interactions.items() - if request.created_at < expire_before - ] - for request_id in expired_ids: - self._pending_interactions.pop(request_id, None) - - def create_request( - self, - session_id: str, - user_id: str, - channel: Optional[str], - source: Optional[str], - username: Optional[str], - title: Optional[str], - prompt: str, - options: List[AgentInteractionOption], - ) -> PendingAgentInteraction: - """ - 创建一条待用户确认的 Agent 交互请求。 - """ - with self._lock: - self._cleanup_locked() - request_id = uuid.uuid4().hex[:12] - while request_id in self._pending_interactions: - request_id = uuid.uuid4().hex[:12] - request = PendingAgentInteraction( - request_id=request_id, - session_id=session_id, - user_id=str(user_id), - channel=channel, - source=source, - username=username, - title=title, - prompt=prompt, - options=options, - ) - self._pending_interactions[request_id] = request - return request - - def resolve( - self, - request_id: str, - option_index: int, - user_id: Optional[str] = None, - ) -> Optional[tuple[PendingAgentInteraction, AgentInteractionOption]]: - """ - 消费一条 Agent 交互请求,并返回选中的选项。 - """ - with self._lock: - self._cleanup_locked() - request = self._pending_interactions.get(request_id) - if not request: - return None - if user_id is not None and str(request.user_id) != str(user_id): - return None - if option_index < 1 or option_index > len(request.options): - return None - option = request.options[option_index - 1] - self._pending_interactions.pop(request_id, None) - return request, option - - def clear(self) -> None: - """ - 清空所有 Agent 交互请求。 - """ - with self._lock: - self._pending_interactions.clear() - - -agent_interaction_manager = AgentInteractionManager() - - -@dataclass -class PendingMediaInteraction: - """ - 记录一次搜索/下载/订阅交互的当前上下文。 - """ - - request_id: str - user_id: str - channel: Optional[MessageChannel] - source: Optional[str] - username: Optional[str] - action: str - keyword: str - phase: str = "media" - page: int = 0 - title: str = "" - meta: Optional[MetaBase] = None - current_media: Optional[MediaInfo] = None - items: List[Any] = field(default_factory=list) - created_at: datetime = field(default_factory=datetime.now) - - -class MediaInteractionManager: - """ - 管理用户当前激活的媒体交互状态。 - - 每个用户只保留一个有效会话,避免旧按钮与新一轮搜索混用。 - """ - - _ttl = timedelta(hours=24) - - def __init__(self): - self._by_id: Dict[str, PendingMediaInteraction] = {} - self._by_user: Dict[str, str] = {} - self._lock = Lock() - - def _cleanup_locked(self) -> None: - """ - 清理超时会话,避免内存中残留旧交互状态。 - """ - expire_before = datetime.now() - self._ttl - expired = [ - request_id - for request_id, request in self._by_id.items() - if request.created_at < expire_before - ] - for request_id in expired: - request = self._by_id.pop(request_id, None) - if request: - self._by_user.pop(str(request.user_id), None) - - def create_or_replace( - self, - user_id: Union[str, int], - channel: Optional[MessageChannel], - source: Optional[str], - username: Optional[str], - action: str, - keyword: str, - title: str = "", - meta: Optional[MetaBase] = None, - items: Optional[List[Any]] = None, - ) -> PendingMediaInteraction: - """ - 为用户创建新的交互状态,并替换旧会话。 - """ - with self._lock: - self._cleanup_locked() - user_key = str(user_id) - old_request_id = self._by_user.get(user_key) - if old_request_id: - self._by_id.pop(old_request_id, None) - - request = PendingMediaInteraction( - request_id=uuid.uuid4().hex[:12], - user_id=user_key, - channel=channel, - source=source, - username=username, - action=action, - keyword=keyword, - title=title, - meta=meta, - items=list(items or []), - ) - self._by_id[request.request_id] = request - self._by_user[user_key] = request.request_id - return request - - def get_by_user( - self, user_id: Union[str, int] - ) -> Optional[PendingMediaInteraction]: - """ - 按用户读取当前会话,供文本回复和旧按钮兼容使用。 - """ - with self._lock: - self._cleanup_locked() - request_id = self._by_user.get(str(user_id)) - if not request_id: - return None - return self._by_id.get(request_id) - - def get_by_id( - self, request_id: str, user_id: Union[str, int] - ) -> Optional[PendingMediaInteraction]: - """ - 按请求 ID 读取会话,并校验用户归属。 - """ - with self._lock: - self._cleanup_locked() - request = self._by_id.get(request_id) - if not request or str(request.user_id) != str(user_id): - return None - return request - - def remove(self, request_id: str) -> None: - """ - 主动结束一条会话。 - """ - with self._lock: - request = self._by_id.pop(request_id, None) - if request: - self._by_user.pop(str(request.user_id), None) - - def clear(self) -> None: - """ - 清空所有交互状态,主要用于测试。 - """ - with self._lock: - self._by_id.clear() - self._by_user.clear() - - -media_interaction_manager = MediaInteractionManager() - - -class MediaInteractionChain(ChainBase): - """ - 处理媒体搜索、订阅、资源选择和翻页等交互流程。 - """ - - _button_page_size = 8 - _text_page_size = 8 - - @staticmethod - def has_pending_interaction(user_id: Union[str, int]) -> bool: - """ - 判断用户当前是否存在未结束的媒体交互。 - """ - return media_interaction_manager.get_by_user(user_id) is not None - - @staticmethod - def _get_noexits_info( - meta: MetaBase, mediainfo: MediaInfo - ) -> Dict[Union[int, str], Dict[int, NotExistMediaInfo]]: - """ - 构造媒体缺失集信息,用于全量重搜或自动下载补全集数。 - """ - if mediainfo.type == MediaType.TV: - if not mediainfo.seasons: - mediainfo = MediaChain().recognize_media( - mtype=mediainfo.type, - tmdbid=mediainfo.tmdb_id, - doubanid=mediainfo.douban_id, - cache=False, - ) - if not mediainfo: - logger.warn("媒体信息识别失败,无法补充季集信息") - return {} - if not mediainfo.seasons: - logger.warn( - "媒体信息中没有季集信息,标题:%s,tmdbid:%s,doubanid:%s", - mediainfo.title, - mediainfo.tmdb_id, - mediainfo.douban_id, - ) - return {} - - mediakey = mediainfo.tmdb_id or mediainfo.douban_id - no_exists = {mediakey: {}} - if meta.begin_season: - episodes = mediainfo.seasons.get(meta.begin_season) - if not episodes: - return {} - no_exists[mediakey][meta.begin_season] = NotExistMediaInfo( - season=meta.begin_season, - episodes=[], - total_episode=len(episodes), - start_episode=episodes[0], - ) - else: - for sea, eps in mediainfo.seasons.items(): - if not eps: - continue - no_exists[mediakey][sea] = NotExistMediaInfo( - season=sea, - episodes=[], - total_episode=len(eps), - start_episode=eps[0], - ) - return no_exists - return {} - - @staticmethod - def parse_callback( - callback_data: str, - ) -> Optional[Tuple[Optional[str], str, Optional[int]]]: - """ - 解析新旧两种媒体交互按钮格式。 - """ - if callback_data.startswith("media:"): - parts = callback_data.split(":") - if len(parts) < 3: - return None - request_id = parts[1] - action = parts[2] - index = None - if len(parts) >= 4 and parts[3].isdigit(): - index = int(parts[3]) - return request_id, action, index - - match = re.match(r"^(select|download)_(\d+)$", callback_data) - if match: - return None, match.group(1), int(match.group(2)) - if callback_data == "page_p": - return None, "page-prev", None - if callback_data == "page_n": - return None, "page-next", None - return None - - def handle_callback_interaction( - self, - callback_data: str, - channel: MessageChannel, - source: str, - userid: Union[str, int], - username: str, - original_message_id: Optional[Union[str, int]] = None, - original_chat_id: Optional[str] = None, - ) -> bool: - """ - 处理按钮回调,并将当前视图刷新到原消息上。 - """ - parsed = self.parse_callback(callback_data) - if not parsed: - return False - - request_id, action, index = parsed - if request_id: - request = media_interaction_manager.get_by_id(request_id, userid) - else: - request = media_interaction_manager.get_by_user(userid) - - if not request: - self.post_message( - Notification( - channel=channel, - source=source, - userid=userid, - username=username, - title="交互已失效,请重新搜索或订阅", - ) - ) - return True - - request.channel = channel - request.source = source - request.username = username - - if action == "page-prev": - if request.page <= 0: - self._post_invalid_input( - channel=channel, - source=source, - userid=userid, - username=username, - title="已经是第一页了!", - ) - return True - request.page -= 1 - self._render_interaction( - request=request, - channel=channel, - source=source, - userid=userid, - original_message_id=original_message_id, - original_chat_id=original_chat_id, - ) - return True - - if action == "page-next": - if not self._has_next_page(request): - self._post_invalid_input( - channel=channel, - source=source, - userid=userid, - username=username, - title="已经是最后一页了!", - ) - return True - request.page += 1 - self._render_interaction( - request=request, - channel=channel, - source=source, - userid=userid, - original_message_id=original_message_id, - original_chat_id=original_chat_id, - ) - return True - - if action == "select": - self._handle_media_selection( - request=request, - page_index=index, - channel=channel, - source=source, - userid=userid, - username=username, - original_message_id=original_message_id, - original_chat_id=original_chat_id, - ) - return True - - if action == "download": - self._handle_torrent_selection( - request=request, - page_index=index, - channel=channel, - source=source, - userid=userid, - username=username, - ) - return True - - return False - - def handle_text_interaction( - self, - channel: MessageChannel, - source: str, - userid: Union[str, int], - username: str, - text: str, - ) -> bool: - """ - 处理文本式交互。 - - 有会话时优先处理数字选择和翻页;无会话时负责识别搜索/订阅类入口。 - """ - request = media_interaction_manager.get_by_user(userid) - normalized = (text or "").strip() - lowered = normalized.lower() - - if request and lowered in {"退出", "关闭", "q", "quit", "exit"}: - media_interaction_manager.remove(request.request_id) - self.post_message( - Notification( - channel=channel, - source=source, - userid=userid, - username=username, - title="媒体交互已结束", - ) - ) - return True - - if normalized.isdigit(): - if not request: - self._post_invalid_input( - channel=channel, - source=source, - userid=userid, - username=username, - ) - return True - request.channel = channel - request.source = source - request.username = username - index = int(normalized) - if request.phase == "torrent": - self._handle_torrent_selection( - request=request, - page_index=index, - channel=channel, - source=source, - userid=userid, - username=username, - ) - else: - self._handle_media_selection( - request=request, - page_index=index, - channel=channel, - source=source, - userid=userid, - username=username, - ) - return True - - if lowered in {"p", "prev", "上一页"}: - if not request: - self._post_invalid_input( - channel=channel, - source=source, - userid=userid, - username=username, - ) - return True - if request.page <= 0: - self._post_invalid_input( - channel=channel, - source=source, - userid=userid, - username=username, - title="已经是第一页了!", - ) - return True - request.page -= 1 - request.channel = channel - request.source = source - request.username = username - self._render_interaction( - request=request, - channel=channel, - source=source, - userid=userid, - ) - return True - - if lowered in {"n", "next", "下一页"}: - if not request: - self._post_invalid_input( - channel=channel, - source=source, - userid=userid, - username=username, - ) - return True - if not self._has_next_page(request): - self._post_invalid_input( - channel=channel, - source=source, - userid=userid, - username=username, - title="已经是最后一页了!", - ) - return True - request.page += 1 - request.channel = channel - request.source = source - request.username = username - self._render_interaction( - request=request, - channel=channel, - source=source, - userid=userid, - ) - return True - - action, content = self._resolve_action(normalized) - if not action: - return False - - self._start_media_interaction( - action=action, - content=content, - channel=channel, - source=source, - userid=userid, - username=username, - ) - return True - - @staticmethod - def _resolve_action(text: str) -> Tuple[Optional[str], str]: - """ - 将用户输入归类为搜索、订阅或普通聊天。 - """ - if text.startswith("订阅"): - return "Subscribe", re.sub(r"订阅[::\s]*", "", text) - if text.startswith("洗版"): - return "ReSubscribe", re.sub(r"洗版[::\s]*", "", text) - if text.startswith("搜索") or text.startswith("下载"): - return "ReSearch", re.sub(r"(搜索|下载)[::\s]*", "", text) - if StringUtils.is_link(text): - return None, text - if not StringUtils.is_media_title_like(text): - return None, text - return "Search", text - - def _start_media_interaction( - self, - action: str, - content: str, - channel: MessageChannel, - source: str, - userid: Union[str, int], - username: str, - ) -> None: - """ - 根据用户输入搜索媒体,并进入媒体选择阶段。 - """ - meta, medias = MediaChain().search(content) - if not meta.name: - self._post_invalid_input( - channel=channel, - source=source, - userid=userid, - username=username, - title="无法识别输入内容!", - ) - return - if not medias: - self.post_message( - Notification( - channel=channel, - source=source, - userid=userid, - username=username, - title=f"{meta.name} 没有找到对应的媒体信息!", - ) - ) - return - - logger.info("搜索到 %s 条相关媒体信息", len(medias)) - request = media_interaction_manager.create_or_replace( - user_id=userid, - channel=channel, - source=source, - username=username, - action=action, - keyword=content, - title=meta.name, - meta=meta, - items=medias, - ) - self._render_interaction( - request=request, - channel=channel, - source=source, - userid=userid, - ) - - def _handle_media_selection( - self, - request: PendingMediaInteraction, - page_index: Optional[int], - channel: MessageChannel, - source: str, - userid: Union[str, int], - username: str, - original_message_id: Optional[Union[str, int]] = None, - original_chat_id: Optional[str] = None, - ) -> None: - """ - 处理媒体选择阶段的序号输入。 - """ - page_items, page, _ = self._page_items( - items=request.items, - page=request.page, - page_size=self._page_size(request.channel), - ) - request.page = page - if not page_index or page_index < 1 or page_index > len(page_items): - self._post_invalid_input( - channel=channel, - source=source, - userid=userid, - username=username, - ) - return - - mediainfo: MediaInfo = page_items[page_index - 1] - request.current_media = mediainfo - - if request.action in {"Search", "ReSearch"}: - self._search_media_resources( - request=request, - mediainfo=mediainfo, - channel=channel, - source=source, - userid=userid, - username=username, - original_message_id=original_message_id, - original_chat_id=original_chat_id, - ) - return - - if request.action in {"Subscribe", "ReSubscribe"}: - self._subscribe_media( - request=request, - mediainfo=mediainfo, - channel=channel, - source=source, - userid=userid, - username=username, - ) - - def _search_media_resources( - self, - request: PendingMediaInteraction, - mediainfo: MediaInfo, - channel: MessageChannel, - source: str, - userid: Union[str, int], - username: str, - original_message_id: Optional[Union[str, int]] = None, - original_chat_id: Optional[str] = None, - ) -> None: - """ - 根据已选媒体搜索资源,并切换到资源选择阶段。 - """ - exist_flag, no_exists = DownloadChain().get_no_exists_info( - meta=request.meta, - mediainfo=mediainfo, - ) - if exist_flag and request.action == "Search": - self.post_message( - Notification( - channel=channel, - source=source, - userid=userid, - username=username, - title=f"【{mediainfo.title_year}{request.meta.sea} 媒体库中已存在,如需重新下载请发送:搜索 名称 或 下载 名称】", - ) - ) - return - if exist_flag: - no_exists = self._get_noexits_info(request.meta, mediainfo) - - messages = self._build_no_exists_messages( - mediainfo=mediainfo, - no_exists=no_exists, - show_missing_only=request.action == "Search", - ) - if messages: - self.post_message( - Notification( - channel=channel, - source=source, - userid=userid, - username=username, - title=f"{mediainfo.title_year}:\n" + "\n".join(messages), - ) - ) - - logger.info("开始搜索 %s ...", mediainfo.title_year) - self.post_message( - Notification( - channel=channel, - source=source, - userid=userid, - username=username, - title=f"开始搜索 {mediainfo.type.value} {mediainfo.title_year} ...", - ) - ) - - contexts = SearchChain().process(mediainfo=mediainfo, no_exists=no_exists) - if not contexts: - self.post_message( - Notification( - channel=channel, - source=source, - userid=userid, - username=username, - title=f"{mediainfo.title}{request.meta.sea} 未搜索到需要的资源!", - ) - ) - return - - contexts = TorrentHelper().sort_torrents(contexts) - if self._should_auto_download(userid): - logger.info("用户 %s 在自动下载用户中,开始自动择优下载 ...", userid) - self._auto_download( - request=request, - cache_list=contexts, - channel=channel, - source=source, - userid=userid, - username=username, - no_exists=no_exists, - ) - return - - request.phase = "torrent" - request.page = 0 - request.title = mediainfo.title - request.items = list(contexts) - self._render_interaction( - request=request, - channel=channel, - source=source, - userid=userid, - original_message_id=original_message_id, - original_chat_id=original_chat_id, - ) - - def _subscribe_media( - self, - request: PendingMediaInteraction, - mediainfo: MediaInfo, - channel: MessageChannel, - source: str, - userid: Union[str, int], - username: str, - ) -> None: - """ - 根据已选媒体创建订阅或洗版订阅。 - """ - best_version = request.action == "ReSubscribe" - if not best_version: - exist_flag, _ = DownloadChain().get_no_exists_info( - meta=request.meta, - mediainfo=mediainfo, - ) - if exist_flag: - self.post_message( - Notification( - channel=channel, - source=source, - userid=userid, - username=username, - title=f"【{mediainfo.title_year}{request.meta.sea} 媒体库中已存在,如需洗版请发送:洗版 XXX】", - ) - ) - return - - mp_name = ( - UserOper().get_name(**{f"{channel.name.lower()}_userid": userid}) - if channel - else None - ) - SubscribeChain().add( - title=mediainfo.title, - year=mediainfo.year, - mtype=mediainfo.type, - tmdbid=mediainfo.tmdb_id, - season=request.meta.begin_season, - channel=channel, - source=source, - userid=userid, - username=mp_name or username, - best_version=best_version, - ) - - def _handle_torrent_selection( - self, - request: PendingMediaInteraction, - page_index: Optional[int], - channel: MessageChannel, - source: str, - userid: Union[str, int], - username: str, - ) -> None: - """ - 处理资源选择阶段的下载操作。 - """ - if request.phase != "torrent": - self._post_invalid_input( - channel=channel, - source=source, - userid=userid, - username=username, - ) - return - - if page_index == 0: - self._auto_download( - request=request, - cache_list=request.items, - channel=channel, - source=source, - userid=userid, - username=username, - ) - return - - page_items, page, _ = self._page_items( - items=request.items, - page=request.page, - page_size=self._page_size(request.channel), - ) - request.page = page - if not page_index or page_index < 1 or page_index > len(page_items): - self._post_invalid_input( - channel=channel, - source=source, - userid=userid, - username=username, - ) - return - - context: Context = page_items[page_index - 1] - DownloadChain().download_single( - context, - channel=channel, - source=source, - userid=userid, - username=username, - ) - - def _auto_download( - self, - request: PendingMediaInteraction, - cache_list: List[Context], - channel: MessageChannel, - source: str, - userid: Union[str, int], - username: str, - no_exists: Optional[Dict[Union[int, str], Dict[int, NotExistMediaInfo]]] = None, - ) -> None: - """ - 自动择优下载当前资源列表,并在未完成时补建订阅。 - """ - downloadchain = DownloadChain() - if no_exists is None: - exist_flag, no_exists = downloadchain.get_no_exists_info( - meta=request.meta, - mediainfo=request.current_media, - ) - if exist_flag: - no_exists = self._get_noexits_info(request.meta, request.current_media) - - downloads, lefts = downloadchain.batch_download( - contexts=cache_list, - no_exists=no_exists, - channel=channel, - source=source, - userid=userid, - username=username, - ) - if downloads and not lefts: - logger.info("%s 下载完成", request.current_media.title_year) - return - - logger.info("%s 未下载未完整,添加订阅 ...", request.current_media.title_year) - if downloads and request.current_media.type == MediaType.TV: - note = [ - download.meta_info.begin_episode - for download in downloads - if download.meta_info.begin_episode - ] - else: - note = None - - mp_name = ( - UserOper().get_name(**{f"{channel.name.lower()}_userid": userid}) - if channel - else None - ) - SubscribeChain().add( - title=request.current_media.title, - year=request.current_media.year, - mtype=request.current_media.type, - tmdbid=request.current_media.tmdb_id, - season=request.meta.begin_season, - channel=channel, - source=source, - userid=userid, - username=mp_name or username, - state="R", - note=note, - ) - - def _render_interaction( - self, - request: PendingMediaInteraction, - channel: MessageChannel, - source: str, - userid: Union[str, int], - original_message_id: Optional[Union[str, int]] = None, - original_chat_id: Optional[str] = None, - ) -> None: - """ - 按当前阶段渲染媒体列表或资源列表。 - """ - if request.phase == "torrent": - self._post_torrents_message( - request=request, - channel=channel, - source=source, - userid=userid, - original_message_id=original_message_id, - original_chat_id=original_chat_id, - ) - else: - self._post_medias_message( - request=request, - channel=channel, - source=source, - userid=userid, - original_message_id=original_message_id, - original_chat_id=original_chat_id, - ) - - def _post_medias_message( - self, - request: PendingMediaInteraction, - channel: MessageChannel, - source: str, - userid: Union[str, int], - original_message_id: Optional[Union[str, int]] = None, - original_chat_id: Optional[str] = None, - ) -> None: - """ - 发送或更新媒体选择列表。 - """ - page_items, page, total_pages = self._page_items( - items=request.items, - page=request.page, - page_size=self._page_size(channel), - ) - request.page = page - total = len(request.items) - if self._supports_interactive_buttons(channel): - title = f"【{request.title}】共找到{total}条相关信息,请选择操作" - buttons = self._create_media_buttons( - channel=channel, - request=request, - items=page_items, - total=total, - total_pages=total_pages, - ) - else: - if total > self._page_size(channel): - title = f"【{request.title}】共找到{total}条相关信息,请回复对应数字选择(p: 上一页 n: 下一页)" - else: - title = f"【{request.title}】共找到{total}条相关信息,请回复对应数字选择" - buttons = None - - self.post_medias_message( - Notification( - channel=channel, - source=source, - title=title, - userid=userid, - buttons=buttons, - original_message_id=original_message_id, - original_chat_id=original_chat_id, - ), - medias=page_items, - ) - - def _post_torrents_message( - self, - request: PendingMediaInteraction, - channel: MessageChannel, - source: str, - userid: Union[str, int], - original_message_id: Optional[Union[str, int]] = None, - original_chat_id: Optional[str] = None, - ) -> None: - """ - 发送或更新资源选择列表。 - """ - page_items, page, total_pages = self._page_items( - items=request.items, - page=request.page, - page_size=self._page_size(channel), - ) - request.page = page - total = len(request.items) - if self._supports_interactive_buttons(channel): - title = f"【{request.title}】共找到{total}条相关资源,请选择下载" - buttons = self._create_torrent_buttons( - channel=channel, - request=request, - items=page_items, - total=total, - total_pages=total_pages, - ) - else: - if total > self._page_size(channel): - title = f"【{request.title}】共找到{total}条相关资源,请回复对应数字下载(0: 自动选择 p: 上一页 n: 下一页)" - else: - title = f"【{request.title}】共找到{total}条相关资源,请回复对应数字下载(0: 自动选择)" - buttons = None - - self.post_torrents_message( - Notification( - channel=channel, - source=source, - title=title, - userid=userid, - link=settings.MP_DOMAIN("#/resource"), - buttons=buttons, - original_message_id=original_message_id, - original_chat_id=original_chat_id, - ), - torrents=page_items, - ) - - def _create_media_buttons( - self, - channel: MessageChannel, - request: PendingMediaInteraction, - items: List[MediaInfo], - total: int, - total_pages: int, - ) -> List[List[Dict[str, str]]]: - """ - 为媒体列表生成选择和翻页按钮。 - """ - buttons: List[List[Dict[str, str]]] = [] - max_text_length = ChannelCapabilityManager.get_max_button_text_length(channel) - max_per_row = ChannelCapabilityManager.get_max_buttons_per_row(channel) - - current_row: List[Dict[str, str]] = [] - for index, media in enumerate(items, start=1): - if max_per_row == 1: - button_text = f"{index}. {media.title_year}" - if len(button_text) > max_text_length: - button_text = button_text[: max_text_length - 3] + "..." - buttons.append( - [ - { - "text": button_text, - "callback_data": f"media:{request.request_id}:select:{index}", - } - ] - ) - continue - - current_row.append( - { - "text": f"{index}", - "callback_data": f"media:{request.request_id}:select:{index}", - } - ) - if len(current_row) == max_per_row or index == len(items): - buttons.append(current_row) - current_row = [] - - if total > self._page_size(channel): - buttons.extend(self._navigation_buttons(request, total_pages)) - return buttons - - def _create_torrent_buttons( - self, - channel: MessageChannel, - request: PendingMediaInteraction, - items: List[Context], - total: int, - total_pages: int, - ) -> List[List[Dict[str, str]]]: - """ - 为资源列表生成下载和翻页按钮。 - """ - buttons: List[List[Dict[str, str]]] = [ - [ - { - "text": "🤖 自动选择下载", - "callback_data": f"media:{request.request_id}:download:0", - } - ] - ] - max_text_length = ChannelCapabilityManager.get_max_button_text_length(channel) - max_per_row = ChannelCapabilityManager.get_max_buttons_per_row(channel) - - current_row: List[Dict[str, str]] = [] - for index, context in enumerate(items, start=1): - torrent = context.torrent_info - if max_per_row == 1: - button_text = f"{index}. {torrent.site_name} - {torrent.seeders}↑" - if len(button_text) > max_text_length: - button_text = button_text[: max_text_length - 3] + "..." - buttons.append( - [ - { - "text": button_text, - "callback_data": f"media:{request.request_id}:download:{index}", - } - ] - ) - continue - - current_row.append( - { - "text": f"{index}", - "callback_data": f"media:{request.request_id}:download:{index}", - } - ) - if len(current_row) == max_per_row or index == len(items): - buttons.append(current_row) - current_row = [] - - if total > self._page_size(channel): - buttons.extend(self._navigation_buttons(request, total_pages)) - return buttons - - def _has_next_page(self, request: PendingMediaInteraction) -> bool: - """ - 判断当前视图是否还有下一页。 - """ - _, page, total_pages = self._page_items( - items=request.items, - page=request.page, - page_size=self._page_size(request.channel), - ) - return page < total_pages - 1 - - @staticmethod - def _navigation_buttons( - request: PendingMediaInteraction, - total_pages: int, - ) -> List[List[Dict[str, str]]]: - """ - 按当前页状态生成上一页和下一页按钮。 - """ - buttons: List[List[Dict[str, str]]] = [] - nav_row: List[Dict[str, str]] = [] - if request.page > 0: - nav_row.append( - { - "text": "⬅️ 上一页", - "callback_data": f"media:{request.request_id}:page-prev", - } - ) - if request.page < total_pages - 1: - nav_row.append( - { - "text": "下一页 ➡️", - "callback_data": f"media:{request.request_id}:page-next", - } - ) - if nav_row: - buttons.append(nav_row) - return buttons - - @staticmethod - def _page_items( - items: List[Any], - page: int, - page_size: int, - ) -> Tuple[List[Any], int, int]: - """ - 返回当前页数据,并把页码限制在有效范围内。 - """ - total_pages = max(1, math.ceil(len(items) / page_size)) if page_size else 1 - page = min(max(0, page), total_pages - 1) - start = page * page_size - end = start + page_size - return items[start:end], page, total_pages - - def _page_size(self, channel: Optional[MessageChannel]) -> int: - """ - 按渠道交互能力选择分页大小。 - """ - return ( - self._button_page_size - if self._supports_interactive_buttons(channel) - else self._text_page_size - ) - - @staticmethod - def _supports_interactive_buttons(channel: Optional[MessageChannel]) -> bool: - """ - 判断渠道是否同时支持按钮展示与按钮回调。 - """ - return bool( - channel - and ChannelCapabilityManager.supports_buttons(channel) - and ChannelCapabilityManager.supports_callbacks(channel) - ) - - @staticmethod - def _build_no_exists_messages( - mediainfo: MediaInfo, - no_exists: Optional[Dict[Union[int, str], Dict[int, NotExistMediaInfo]]], - show_missing_only: bool, - ) -> List[str]: - """ - 将缺失集信息转换为可发送的文案。 - """ - if not no_exists: - return [] - mediakey = mediainfo.tmdb_id or mediainfo.douban_id - season_map = no_exists.get(mediakey) or {} - if show_missing_only: - return [ - f"第 {sea} 季缺失 {StringUtils.str_series(no_exist.episodes) if no_exist.episodes else no_exist.total_episode} 集" - for sea, no_exist in season_map.items() - ] - return [ - f"第 {sea} 季总 {no_exist.total_episode} 集" - for sea, no_exist in season_map.items() - ] - - @staticmethod - def _should_auto_download(userid: Union[str, int]) -> bool: - """ - 判断当前用户是否命中自动下载名单。 - """ - auto_download_user = settings.AUTO_DOWNLOAD_USER - return bool( - auto_download_user - and ( - auto_download_user == "all" - or any(userid == user for user in auto_download_user.split(",")) - ) - ) - - def _post_invalid_input( - self, - channel: MessageChannel, - source: str, - userid: Union[str, int], - username: Optional[str], - title: str = "输入有误!", - ) -> None: - """ - 发送统一的非法输入提示。 - """ - self.post_message( - Notification( - channel=channel, - source=source, - userid=userid, - username=username, - title=title, - ) - ) diff --git a/app/chain/media.py b/app/chain/media.py index 08363f26..5833b56e 100644 --- a/app/chain/media.py +++ b/app/chain/media.py @@ -24,9 +24,9 @@ from app.schemas.types import ( ScrapingPolicy, SystemConfigKey, ) +from app.utils.http import RequestUtils from app.utils.mixins import ConfigReloadMixin from app.utils.singleton import Singleton -from app.utils.http import RequestUtils from app.utils.string import StringUtils recognize_lock = Lock() @@ -44,10 +44,10 @@ class ScrapingOption: policy: ScrapingPolicy = ScrapingPolicy.MISSINGONLY def __init__( - self, - type: Union[str, ScrapingTarget], - metadata: Union[str, ScrapingMetadata], - value: Union[ScrapingPolicy, bool, str], + self, + type: Union[str, ScrapingTarget], + metadata: Union[str, ScrapingMetadata], + value: Union[ScrapingPolicy, bool, str], ): if isinstance(type, ScrapingTarget): self.type = type @@ -105,7 +105,7 @@ class ScrapingConfig: self._policies[tuple(items)] = ScrapingOption(*items, value) def option( - self, item: Union[str, ScrapingTarget], metadata: Union[str, ScrapingMetadata] + self, item: Union[str, ScrapingTarget], metadata: Union[str, ScrapingMetadata] ) -> ScrapingOption: if isinstance(item, ScrapingTarget): @@ -173,11 +173,11 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): def on_config_changed(self): self.scraping_policies = ScrapingConfig.from_system_config() + @staticmethod def _should_scrape( - self, - scraping_option: ScrapingOption, - file_exists: bool, - global_overwrite: bool = False, + scraping_option: ScrapingOption, + file_exists: bool, + global_overwrite: bool = False, ) -> bool: """ 判断是否应该执行刮削操作 @@ -211,7 +211,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): return False def _save_file( - self, fileitem: schemas.FileItem, path: Path, content: Union[bytes, str] + self, fileitem: schemas.FileItem, path: Path, content: Union[bytes, str] ): """ 保存或上传文件 @@ -224,7 +224,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): return # 使用tempfile创建临时文件 with NamedTemporaryFile( - delete=True, delete_on_close=False, suffix=path.suffix + delete=True, delete_on_close=False, suffix=path.suffix ) as tmp_file: tmp_file_path = Path(tmp_file.name) # 写入内容 @@ -248,7 +248,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): logger.warn(f"文件保存失败:{path}") def _download_and_save_image( - self, fileitem: schemas.FileItem, path: Path, url: str + self, fileitem: schemas.FileItem, path: Path, url: str ): """ 流式下载图片并保存到文件 @@ -268,7 +268,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): if r and r.status_code == 200: # 使用tempfile创建临时文件,自动删除 with NamedTemporaryFile( - delete=True, delete_on_close=False, suffix=path.suffix + delete=True, delete_on_close=False, suffix=path.suffix ) as tmp_file: tmp_file_path = Path(tmp_file.name) # 流式写入文件 @@ -295,12 +295,12 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): logger.error(f"{url} 图片下载失败:{str(err)}!") def _get_target_fileitem_and_path( - self, - current_fileitem: schemas.FileItem, - item_type: ScrapingTarget, - metadata_type: ScrapingMetadata, - filename_hint: Optional[str] = None, - parent_fileitem: Optional[schemas.FileItem] = None, + self, + current_fileitem: schemas.FileItem, + item_type: ScrapingTarget, + metadata_type: ScrapingMetadata, + filename_hint: Optional[str] = None, + parent_fileitem: Optional[schemas.FileItem] = None, ) -> Tuple[schemas.FileItem, Optional[Path]]: """ 根据当前上下文、刮削项类型和元数据类型生成目标 FileItem 和 Path @@ -318,8 +318,8 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): # 电影文件NFO: 放在电影文件同级目录,名称与电影文件主体一致,后缀.nfo final_filename = f"{target_dir_path.stem}.nfo" target_dir_item = ( - parent_fileitem - or self.storagechain.get_parent_item(current_fileitem) + parent_fileitem + or self.storagechain.get_parent_item(current_fileitem) ) if not target_dir_item: logger.error( @@ -354,8 +354,8 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): # 图片通常是放在当前目录 (current_fileitem) 下 # 如果是 EPISODE 类型的图片(如thumb),通常也是放在文件同级目录,文件名与视频文件一致 elif ( - metadata_type in [ScrapingMetadata.THUMB] - and item_type == ScrapingTarget.EPISODE + metadata_type in [ScrapingMetadata.THUMB] + and item_type == ScrapingTarget.EPISODE ): hint_ext = Path(filename_hint).suffix if filename_hint else ".jpg" final_filename = f"{target_dir_path.stem}{hint_ext}" @@ -380,11 +380,11 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): return target_dir_item, target_full_path def metadata_nfo( - self, - meta: MetaBase, - mediainfo: MediaInfo, - season: Optional[int] = None, - episode: Optional[int] = None, + self, + meta: MetaBase, + mediainfo: MediaInfo, + season: Optional[int] = None, + episode: Optional[int] = None, ) -> Optional[str]: """ 获取NFO文件内容文本 @@ -402,8 +402,9 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): episode=episode, ) + @staticmethod def select_recognize_source( - self, log_name: str, log_context: str, native_fn, plugin_fn + log_name: str, log_context: str, native_fn, plugin_fn ) -> Optional[MediaInfo]: """ 选择识别模式,插件优先或原生优先 @@ -436,7 +437,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): return mediainfo def recognize_by_meta( - self, metainfo: MetaBase, episode_group: Optional[str] = None + self, metainfo: MetaBase, episode_group: Optional[str] = None ) -> Optional[MediaInfo]: """ 根据主副标题识别媒体信息 @@ -513,7 +514,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): return self.recognize_media(meta=org_meta) def recognize_by_path( - self, path: str, episode_group: Optional[str] = None + self, path: str, episode_group: Optional[str] = None ) -> Optional[Context]: """ 根据文件路径识别媒体信息 @@ -577,7 +578,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): return meta, medias def get_tmdbinfo_by_doubanid( - self, doubanid: str, mtype: MediaType = None + self, doubanid: str, mtype: MediaType = None ) -> Optional[dict]: """ 根据豆瓣ID获取TMDB信息 @@ -648,7 +649,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): return None def get_doubaninfo_by_tmdbid( - self, tmdbid: int, mtype: MediaType = None, season: Optional[int] = None + self, tmdbid: int, mtype: MediaType = None, season: Optional[int] = None ) -> Optional[dict]: """ 根据TMDBID获取豆瓣信息 @@ -752,8 +753,8 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): # 收集从根目录到文件的所有父目录 current_path = sub_path.parent while ( - current_path != root_path - and current_path.is_relative_to(root_path) + current_path != root_path + and current_path.is_relative_to(root_path) ): all_dirs.add(current_path) current_path = current_path.parent @@ -805,15 +806,15 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): ) def _scrape_nfo_generic( - self, - current_fileitem: schemas.FileItem, - meta: MetaBase, - mediainfo: MediaInfo, - item_type: ScrapingTarget, - parent_fileitem: Optional[schemas.FileItem] = None, - overwrite: bool = False, - season_number: Optional[int] = None, - episode_number: Optional[int] = None, + self, + current_fileitem: schemas.FileItem, + meta: MetaBase, + mediainfo: MediaInfo, + item_type: ScrapingTarget, + parent_fileitem: Optional[schemas.FileItem] = None, + overwrite: bool = False, + season_number: Optional[int] = None, + episode_number: Optional[int] = None, ): """ NFO 刮削 @@ -859,14 +860,14 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): logger.warn(f"{nfo_path.name} NFO 文件生成失败!") def _scrape_images_generic( - self, - current_fileitem: schemas.FileItem, - mediainfo: MediaInfo, - item_type: ScrapingTarget, - parent_fileitem: Optional[schemas.FileItem] = None, - overwrite: bool = False, - season_number: Optional[int] = None, - episode_number: Optional[int] = None, + self, + current_fileitem: schemas.FileItem, + mediainfo: MediaInfo, + item_type: ScrapingTarget, + parent_fileitem: Optional[schemas.FileItem] = None, + overwrite: bool = False, + season_number: Optional[int] = None, + episode_number: Optional[int] = None, ): """ 图片刮削 @@ -906,14 +907,14 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): # 判断是否匹配当前刮削的季号 if item_type == ScrapingTarget.TV and image_name.lower().startswith( - "season" + "season" ): logger.info(f"当前为电视剧根目录刮削,跳过季图片:{image_name}") continue if ( - item_type == ScrapingTarget.SEASON - and season_number is not None - and image_name.lower().startswith("season") + item_type == ScrapingTarget.SEASON + and season_number is not None + and image_name.lower().startswith("season") ): # 检查是否只下载当前刮削季的图片 image_season_str = ( @@ -921,7 +922,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): ) if image_season_str is not None and image_season_str != str( - season_number + season_number ).rjust(2, "0"): logger.info( f"当前刮削季为:{season_number},跳过非本季图片:{image_name}" @@ -956,14 +957,14 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): ) def scrape_metadata( - self, - fileitem: schemas.FileItem, - meta: MetaBase = None, - mediainfo: MediaInfo = None, - init_folder: bool = True, - parent: schemas.FileItem = None, - overwrite: bool = False, - recursive: bool = True, + self, + fileitem: schemas.FileItem, + meta: MetaBase = None, + mediainfo: MediaInfo = None, + init_folder: bool = True, + parent: schemas.FileItem = None, + overwrite: bool = False, + recursive: bool = True, ): """ 手动刮削媒体信息 @@ -982,7 +983,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): # 当前文件路径 filepath = Path(fileitem.path) if fileitem.type == "file" and ( - not filepath.suffix or filepath.suffix.lower() not in settings.RMT_MEDIAEXT + not filepath.suffix or filepath.suffix.lower() not in settings.RMT_MEDIAEXT ): return @@ -1022,14 +1023,14 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): logger.info(f"{filepath.name} 刮削完成") def _handle_movie_scraping( - self, - fileitem: schemas.FileItem, - meta: MetaBase, - mediainfo: MediaInfo, - init_folder: bool, - parent: schemas.FileItem, - overwrite: bool, - recursive: bool, + self, + fileitem: schemas.FileItem, + meta: MetaBase, + mediainfo: MediaInfo, + init_folder: bool, + parent: schemas.FileItem, + overwrite: bool, + recursive: bool, ): """ 处理电影刮削 @@ -1051,20 +1052,18 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): meta=meta, mediainfo=mediainfo, init_folder=init_folder, - parent=parent, overwrite=overwrite, recursive=recursive, ) def _handle_movie_directory( - self, - fileitem: schemas.FileItem, - meta: MetaBase, - mediainfo: MediaInfo, - init_folder: bool, - parent: schemas.FileItem, - overwrite: bool, - recursive: bool, + self, + fileitem: schemas.FileItem, + meta: MetaBase, + mediainfo: MediaInfo, + init_folder: bool, + overwrite: bool, + recursive: bool, ): """ 处理电影目录刮削 @@ -1105,14 +1104,14 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): ) def _handle_tv_scraping( - self, - fileitem: schemas.FileItem, - meta: MetaBase, - mediainfo: MediaInfo, - init_folder: bool, - parent: schemas.FileItem, - overwrite: bool, - recursive: bool, + self, + fileitem: schemas.FileItem, + meta: MetaBase, + mediainfo: MediaInfo, + init_folder: bool, + parent: schemas.FileItem, + overwrite: bool, + recursive: bool, ): """ 处理电视剧刮削 @@ -1142,12 +1141,12 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): ) def _handle_tv_episode_file( - self, - fileitem: schemas.FileItem, - filepath: Path, - mediainfo: MediaInfo, - parent: schemas.FileItem, - overwrite: bool, + self, + fileitem: schemas.FileItem, + filepath: Path, + mediainfo: MediaInfo, + parent: schemas.FileItem, + overwrite: bool, ): """ 处理电视剧集文件刮削 @@ -1191,15 +1190,15 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): ) def _handle_tv_directory( - self, - fileitem: schemas.FileItem, - filepath: Path, - meta: MetaBase, - mediainfo: MediaInfo, - init_folder: bool, - parent: schemas.FileItem, - overwrite: bool, - recursive: bool, + self, + fileitem: schemas.FileItem, + filepath: Path, + meta: MetaBase, + mediainfo: MediaInfo, + init_folder: bool, + parent: schemas.FileItem, + overwrite: bool, + recursive: bool, ): """ 处理电视剧目录刮削 @@ -1209,9 +1208,9 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): files = self.storagechain.list_files(fileitem=fileitem) or [] for file in files: if ( - file.type == "dir" - and file.name not in settings.RENAME_FORMAT_S0_NAMES - and MetaInfo(file.name).begin_season is None + file.type == "dir" + and file.name not in settings.RENAME_FORMAT_S0_NAMES + and MetaInfo(file.name).begin_season is None ): # 电视剧不处理非季子目录 continue @@ -1235,13 +1234,13 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): ) def _initialize_tv_directory_metadata( - self, - fileitem: schemas.FileItem, - filepath: Path, - meta: MetaBase, - mediainfo: MediaInfo, - parent: schemas.FileItem, - overwrite: bool, + self, + fileitem: schemas.FileItem, + filepath: Path, + meta: MetaBase, + mediainfo: MediaInfo, + parent: schemas.FileItem, + overwrite: bool, ): """ 初始化电视剧目录元数据(识别季号并刮削) @@ -1296,8 +1295,9 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): else: logger.warn("无法识别元数据,跳过") + @staticmethod async def async_select_recognize_source( - self, log_name: str, log_context: str, native_fn, plugin_fn + log_name: str, log_context: str, native_fn, plugin_fn ) -> Optional[MediaInfo]: """ 选择识别模式,插件优先或原生优先(异步版本) @@ -1330,7 +1330,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): return mediainfo async def async_recognize_by_meta( - self, metainfo: MetaBase, episode_group: Optional[str] = None + self, metainfo: MetaBase, episode_group: Optional[str] = None ) -> Optional[MediaInfo]: """ 根据主副标题识别媒体信息(异步版本) @@ -1366,7 +1366,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): return mediainfo async def async_recognize_help( - self, title: str, org_meta: MetaBase + self, title: str, org_meta: MetaBase ) -> Optional[MediaInfo]: """ 请求辅助识别,返回媒体信息(异步版本) @@ -1417,7 +1417,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): return await self.async_recognize_media(meta=org_meta) async def async_recognize_by_path( - self, path: str, episode_group: Optional[str] = None + self, path: str, episode_group: Optional[str] = None ) -> Optional[Context]: """ 根据文件路径识别媒体信息(异步版本) @@ -1455,7 +1455,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): return Context(meta_info=file_meta, media_info=mediainfo) async def async_search( - self, title: str + self, title: str ) -> Tuple[Optional[MetaBase], List[MediaInfo]]: """ 搜索媒体/人物信息(异步版本) @@ -1502,7 +1502,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): @staticmethod def _extract_year_from_tmdb( - tmdbinfo: dict, season: Optional[int] = None + tmdbinfo: dict, season: Optional[int] = None ) -> Optional[str]: """ 从TMDB信息中提取年份 @@ -1522,11 +1522,11 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): return year def _match_tmdb_with_names( - self, - meta_names: list, - year: Optional[str], - mtype: MediaType, - season: Optional[int] = None, + self, + meta_names: list, + year: Optional[str], + mtype: MediaType, + season: Optional[int] = None, ) -> Optional[dict]: """ 使用名称列表匹配TMDB信息 @@ -1540,11 +1540,11 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): return None async def _async_match_tmdb_with_names( - self, - meta_names: list, - year: Optional[str], - mtype: MediaType, - season: Optional[int] = None, + self, + meta_names: list, + year: Optional[str], + mtype: MediaType, + season: Optional[int] = None, ) -> Optional[dict]: """ 使用名称列表匹配TMDB信息(异步版本) @@ -1558,7 +1558,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): return None async def async_get_tmdbinfo_by_doubanid( - self, doubanid: str, mtype: MediaType = None + self, doubanid: str, mtype: MediaType = None ) -> Optional[dict]: """ 根据豆瓣ID获取TMDB信息(异步版本) @@ -1629,7 +1629,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): return None async def async_get_doubaninfo_by_tmdbid( - self, tmdbid: int, mtype: MediaType = None, season: Optional[int] = None + self, tmdbid: int, mtype: MediaType = None, season: Optional[int] = None ) -> Optional[dict]: """ 根据TMDBID获取豆瓣信息(异步版本) diff --git a/app/chain/message.py b/app/chain/message.py index 2f9c25cf..5e9d44eb 100644 --- a/app/chain/message.py +++ b/app/chain/message.py @@ -1,35 +1,40 @@ import asyncio import base64 +import math import mimetypes import re import time import uuid from datetime import datetime, timedelta from pathlib import Path -from typing import Any, Optional, Dict, Union, List +from typing import Any, Optional, Dict, Union, List, Tuple from urllib.parse import unquote, urlparse from app.agent import ReplyMode, agent_manager, prompt_manager +from app.agent.llm import LLMHelper from app.chain import ChainBase -from app.chain.interaction import ( - MediaInteractionChain, - agent_interaction_manager, - media_interaction_manager, -) +from app.chain.download import DownloadChain +from app.chain.media import MediaChain +from app.chain.search import SearchChain from app.chain.site import SiteChain, site_interaction_manager from app.chain.skills import SkillsChain, skills_interaction_manager from app.chain.subscribe import SubscribeChain, subscribe_interaction_manager from app.chain.transfer import TransferChain from app.core.config import settings, global_vars +from app.core.context import MediaInfo, Context +from app.core.meta import MetaBase from app.db.models import TransferHistory from app.db.transferhistory_oper import TransferHistoryOper -from app.agent.llm import LLMHelper +from app.db.user_oper import UserOper +from app.helper.interaction import agent_interaction_manager, media_interaction_manager, PendingMediaInteraction +from app.helper.torrent import TorrentHelper from app.helper.voice import VoiceHelper from app.log import logger -from app.schemas import Notification, CommingMessage +from app.schemas import Notification, CommingMessage, NotExistMediaInfo from app.schemas.message import ChannelCapabilityManager -from app.schemas.types import EventType, MessageChannel +from app.schemas.types import EventType, MessageChannel, MediaType from app.utils.http import RequestUtils +from app.utils.string import StringUtils class MessageChain(ChainBase): @@ -175,31 +180,31 @@ class MessageChain(ChainBase): latest_slash_interaction = self._get_latest_slash_interaction(userid) if latest_slash_interaction == "sites": if SiteChain().handle_text_interaction( - channel=channel, - source=source, - userid=userid, - username=username, - text=text, + channel=channel, + source=source, + userid=userid, + username=username, + text=text, ): return if latest_slash_interaction == "subscribes": if SubscribeChain().handle_text_interaction( - channel=channel, - source=source, - userid=userid, - username=username, - text=text, + channel=channel, + source=source, + userid=userid, + username=username, + text=text, ): return if latest_slash_interaction == "skills": if SkillsChain().handle_text_interaction( - channel=channel, - source=source, - userid=userid, - username=username, - text=text, + channel=channel, + source=source, + userid=userid, + username=username, + text=text, ): return @@ -378,9 +383,9 @@ class MessageChain(ChainBase): """ candidates = [] for name, manager in ( - ("sites", site_interaction_manager), - ("subscribes", subscribe_interaction_manager), - ("skills", skills_interaction_manager), + ("sites", site_interaction_manager), + ("subscribes", subscribe_interaction_manager), + ("skills", skills_interaction_manager), ): request = manager.get_by_user(userid) if request: @@ -1567,3 +1572,1092 @@ class MessageChain(ChainBase): except Exception as e: logger.error(e) return None + + +class MediaInteractionChain(ChainBase): + """ + 处理媒体搜索、订阅、资源选择和翻页等交互流程。 + """ + + _button_page_size = 8 + _text_page_size = 8 + + @staticmethod + def has_pending_interaction(user_id: Union[str, int]) -> bool: + """ + 判断用户当前是否存在未结束的媒体交互。 + """ + return media_interaction_manager.get_by_user(user_id) is not None + + @staticmethod + def _get_noexits_info( + meta: MetaBase, mediainfo: MediaInfo + ) -> Dict[Union[int, str], Dict[int, NotExistMediaInfo]]: + """ + 构造媒体缺失集信息,用于全量重搜或自动下载补全集数。 + """ + if mediainfo.type == MediaType.TV: + if not mediainfo.seasons: + mediainfo = MediaChain().recognize_media( + mtype=mediainfo.type, + tmdbid=mediainfo.tmdb_id, + doubanid=mediainfo.douban_id, + cache=False, + ) + if not mediainfo: + logger.warn("媒体信息识别失败,无法补充季集信息") + return {} + if not mediainfo.seasons: + logger.warn( + "媒体信息中没有季集信息,标题:%s,tmdbid:%s,doubanid:%s", + mediainfo.title, + mediainfo.tmdb_id, + mediainfo.douban_id, + ) + return {} + + mediakey = mediainfo.tmdb_id or mediainfo.douban_id + no_exists = {mediakey: {}} + if meta.begin_season: + episodes = mediainfo.seasons.get(meta.begin_season) + if not episodes: + return {} + no_exists[mediakey][meta.begin_season] = NotExistMediaInfo( + season=meta.begin_season, + episodes=[], + total_episode=len(episodes), + start_episode=episodes[0], + ) + else: + for sea, eps in mediainfo.seasons.items(): + if not eps: + continue + no_exists[mediakey][sea] = NotExistMediaInfo( + season=sea, + episodes=[], + total_episode=len(eps), + start_episode=eps[0], + ) + return no_exists + return {} + + @staticmethod + def parse_callback( + callback_data: str, + ) -> Optional[Tuple[Optional[str], str, Optional[int]]]: + """ + 解析新旧两种媒体交互按钮格式。 + """ + if callback_data.startswith("media:"): + parts = callback_data.split(":") + if len(parts) < 3: + return None + request_id = parts[1] + action = parts[2] + index = None + if len(parts) >= 4 and parts[3].isdigit(): + index = int(parts[3]) + return request_id, action, index + + match = re.match(r"^(select|download)_(\d+)$", callback_data) + if match: + return None, match.group(1), int(match.group(2)) + if callback_data == "page_p": + return None, "page-prev", None + if callback_data == "page_n": + return None, "page-next", None + return None + + def handle_callback_interaction( + self, + callback_data: str, + channel: MessageChannel, + source: str, + userid: Union[str, int], + username: str, + original_message_id: Optional[Union[str, int]] = None, + original_chat_id: Optional[str] = None, + ) -> bool: + """ + 处理按钮回调,并将当前视图刷新到原消息上。 + """ + parsed = self.parse_callback(callback_data) + if not parsed: + return False + + request_id, action, index = parsed + if request_id: + request = media_interaction_manager.get_by_id(request_id, userid) + else: + request = media_interaction_manager.get_by_user(userid) + + if not request: + self.post_message( + Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title="交互已失效,请重新搜索或订阅", + ) + ) + return True + + request.channel = channel + request.source = source + request.username = username + + if action == "page-prev": + if request.page <= 0: + self._post_invalid_input( + channel=channel, + source=source, + userid=userid, + username=username, + title="已经是第一页了!", + ) + return True + request.page -= 1 + self._render_interaction( + request=request, + channel=channel, + source=source, + userid=userid, + original_message_id=original_message_id, + original_chat_id=original_chat_id, + ) + return True + + if action == "page-next": + if not self._has_next_page(request): + self._post_invalid_input( + channel=channel, + source=source, + userid=userid, + username=username, + title="已经是最后一页了!", + ) + return True + request.page += 1 + self._render_interaction( + request=request, + channel=channel, + source=source, + userid=userid, + original_message_id=original_message_id, + original_chat_id=original_chat_id, + ) + return True + + if action == "select": + self._handle_media_selection( + request=request, + page_index=index, + channel=channel, + source=source, + userid=userid, + username=username, + original_message_id=original_message_id, + original_chat_id=original_chat_id, + ) + return True + + if action == "download": + self._handle_torrent_selection( + request=request, + page_index=index, + channel=channel, + source=source, + userid=userid, + username=username, + ) + return True + + return False + + def handle_text_interaction( + self, + channel: MessageChannel, + source: str, + userid: Union[str, int], + username: str, + text: str, + ) -> bool: + """ + 处理文本式交互。 + + 有会话时优先处理数字选择和翻页;无会话时负责识别搜索/订阅类入口。 + """ + request = media_interaction_manager.get_by_user(userid) + normalized = (text or "").strip() + lowered = normalized.lower() + + if request and lowered in {"退出", "关闭", "q", "quit", "exit"}: + media_interaction_manager.remove(request.request_id) + self.post_message( + Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title="媒体交互已结束", + ) + ) + return True + + if normalized.isdigit(): + if not request: + self._post_invalid_input( + channel=channel, + source=source, + userid=userid, + username=username, + ) + return True + request.channel = channel + request.source = source + request.username = username + index = int(normalized) + if request.phase == "torrent": + self._handle_torrent_selection( + request=request, + page_index=index, + channel=channel, + source=source, + userid=userid, + username=username, + ) + else: + self._handle_media_selection( + request=request, + page_index=index, + channel=channel, + source=source, + userid=userid, + username=username, + ) + return True + + if lowered in {"p", "prev", "上一页"}: + if not request: + self._post_invalid_input( + channel=channel, + source=source, + userid=userid, + username=username, + ) + return True + if request.page <= 0: + self._post_invalid_input( + channel=channel, + source=source, + userid=userid, + username=username, + title="已经是第一页了!", + ) + return True + request.page -= 1 + request.channel = channel + request.source = source + request.username = username + self._render_interaction( + request=request, + channel=channel, + source=source, + userid=userid, + ) + return True + + if lowered in {"n", "next", "下一页"}: + if not request: + self._post_invalid_input( + channel=channel, + source=source, + userid=userid, + username=username, + ) + return True + if not self._has_next_page(request): + self._post_invalid_input( + channel=channel, + source=source, + userid=userid, + username=username, + title="已经是最后一页了!", + ) + return True + request.page += 1 + request.channel = channel + request.source = source + request.username = username + self._render_interaction( + request=request, + channel=channel, + source=source, + userid=userid, + ) + return True + + action, content = self._resolve_action(normalized) + if not action: + return False + + self._start_media_interaction( + action=action, + content=content, + channel=channel, + source=source, + userid=userid, + username=username, + ) + return True + + @staticmethod + def _resolve_action(text: str) -> Tuple[Optional[str], str]: + """ + 将用户输入归类为搜索、订阅或普通聊天。 + """ + if text.startswith("订阅"): + return "Subscribe", re.sub(r"订阅[::\s]*", "", text) + if text.startswith("洗版"): + return "ReSubscribe", re.sub(r"洗版[::\s]*", "", text) + if text.startswith("搜索") or text.startswith("下载"): + return "ReSearch", re.sub(r"(搜索|下载)[::\s]*", "", text) + if StringUtils.is_link(text): + return None, text + if not StringUtils.is_media_title_like(text): + return None, text + return "Search", text + + def _start_media_interaction( + self, + action: str, + content: str, + channel: MessageChannel, + source: str, + userid: Union[str, int], + username: str, + ) -> None: + """ + 根据用户输入搜索媒体,并进入媒体选择阶段。 + """ + meta, medias = MediaChain().search(content) + if not meta.name: + self._post_invalid_input( + channel=channel, + source=source, + userid=userid, + username=username, + title="无法识别输入内容!", + ) + return + if not medias: + self.post_message( + Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title=f"{meta.name} 没有找到对应的媒体信息!", + ) + ) + return + + logger.info("搜索到 %s 条相关媒体信息", len(medias)) + request = media_interaction_manager.create_or_replace( + user_id=userid, + channel=channel, + source=source, + username=username, + action=action, + keyword=content, + title=meta.name, + meta=meta, + items=medias, + ) + self._render_interaction( + request=request, + channel=channel, + source=source, + userid=userid, + ) + + def _handle_media_selection( + self, + request: PendingMediaInteraction, + page_index: Optional[int], + channel: MessageChannel, + source: str, + userid: Union[str, int], + username: str, + original_message_id: Optional[Union[str, int]] = None, + original_chat_id: Optional[str] = None, + ) -> None: + """ + 处理媒体选择阶段的序号输入。 + """ + page_items, page, _ = self._page_items( + items=request.items, + page=request.page, + page_size=self._page_size(request.channel), + ) + request.page = page + if not page_index or page_index < 1 or page_index > len(page_items): + self._post_invalid_input( + channel=channel, + source=source, + userid=userid, + username=username, + ) + return + + mediainfo: MediaInfo = page_items[page_index - 1] + request.current_media = mediainfo + + if request.action in {"Search", "ReSearch"}: + self._search_media_resources( + request=request, + mediainfo=mediainfo, + channel=channel, + source=source, + userid=userid, + username=username, + original_message_id=original_message_id, + original_chat_id=original_chat_id, + ) + return + + if request.action in {"Subscribe", "ReSubscribe"}: + self._subscribe_media( + request=request, + mediainfo=mediainfo, + channel=channel, + source=source, + userid=userid, + username=username, + ) + + def _search_media_resources( + self, + request: PendingMediaInteraction, + mediainfo: MediaInfo, + channel: MessageChannel, + source: str, + userid: Union[str, int], + username: str, + original_message_id: Optional[Union[str, int]] = None, + original_chat_id: Optional[str] = None, + ) -> None: + """ + 根据已选媒体搜索资源,并切换到资源选择阶段。 + """ + exist_flag, no_exists = DownloadChain().get_no_exists_info( + meta=request.meta, + mediainfo=mediainfo, + ) + if exist_flag and request.action == "Search": + self.post_message( + Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title=f"【{mediainfo.title_year}{request.meta.sea} 媒体库中已存在,如需重新下载请发送:搜索 名称 或 下载 名称】", + ) + ) + return + if exist_flag: + no_exists = self._get_noexits_info(request.meta, mediainfo) + + messages = self._build_no_exists_messages( + mediainfo=mediainfo, + no_exists=no_exists, + show_missing_only=request.action == "Search", + ) + if messages: + self.post_message( + Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title=f"{mediainfo.title_year}:\n" + "\n".join(messages), + ) + ) + + logger.info("开始搜索 %s ...", mediainfo.title_year) + self.post_message( + Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title=f"开始搜索 {mediainfo.type.value} {mediainfo.title_year} ...", + ) + ) + + contexts = SearchChain().process(mediainfo=mediainfo, no_exists=no_exists) + if not contexts: + self.post_message( + Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title=f"{mediainfo.title}{request.meta.sea} 未搜索到需要的资源!", + ) + ) + return + + contexts = TorrentHelper().sort_torrents(contexts) + if self._should_auto_download(userid): + logger.info("用户 %s 在自动下载用户中,开始自动择优下载 ...", userid) + self._auto_download( + request=request, + cache_list=contexts, + channel=channel, + source=source, + userid=userid, + username=username, + no_exists=no_exists, + ) + return + + request.phase = "torrent" + request.page = 0 + request.title = mediainfo.title + request.items = list(contexts) + self._render_interaction( + request=request, + channel=channel, + source=source, + userid=userid, + original_message_id=original_message_id, + original_chat_id=original_chat_id, + ) + + def _subscribe_media( + self, + request: PendingMediaInteraction, + mediainfo: MediaInfo, + channel: MessageChannel, + source: str, + userid: Union[str, int], + username: str, + ) -> None: + """ + 根据已选媒体创建订阅或洗版订阅。 + """ + best_version = request.action == "ReSubscribe" + if not best_version: + exist_flag, _ = DownloadChain().get_no_exists_info( + meta=request.meta, + mediainfo=mediainfo, + ) + if exist_flag: + self.post_message( + Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title=f"【{mediainfo.title_year}{request.meta.sea} 媒体库中已存在,如需洗版请发送:洗版 XXX】", + ) + ) + return + + mp_name = ( + UserOper().get_name(**{f"{channel.name.lower()}_userid": userid}) + if channel + else None + ) + SubscribeChain().add( + title=mediainfo.title, + year=mediainfo.year, + mtype=mediainfo.type, + tmdbid=mediainfo.tmdb_id, + season=request.meta.begin_season, + channel=channel, + source=source, + userid=userid, + username=mp_name or username, + best_version=best_version, + ) + + def _handle_torrent_selection( + self, + request: PendingMediaInteraction, + page_index: Optional[int], + channel: MessageChannel, + source: str, + userid: Union[str, int], + username: str, + ) -> None: + """ + 处理资源选择阶段的下载操作。 + """ + if request.phase != "torrent": + self._post_invalid_input( + channel=channel, + source=source, + userid=userid, + username=username, + ) + return + + if page_index == 0: + self._auto_download( + request=request, + cache_list=request.items, + channel=channel, + source=source, + userid=userid, + username=username, + ) + return + + page_items, page, _ = self._page_items( + items=request.items, + page=request.page, + page_size=self._page_size(request.channel), + ) + request.page = page + if not page_index or page_index < 1 or page_index > len(page_items): + self._post_invalid_input( + channel=channel, + source=source, + userid=userid, + username=username, + ) + return + + context: Context = page_items[page_index - 1] + DownloadChain().download_single( + context, + channel=channel, + source=source, + userid=userid, + username=username, + ) + + def _auto_download( + self, + request: PendingMediaInteraction, + cache_list: List[Context], + channel: MessageChannel, + source: str, + userid: Union[str, int], + username: str, + no_exists: Optional[Dict[Union[int, str], Dict[int, NotExistMediaInfo]]] = None, + ) -> None: + """ + 自动择优下载当前资源列表,并在未完成时补建订阅。 + """ + downloadchain = DownloadChain() + if no_exists is None: + exist_flag, no_exists = downloadchain.get_no_exists_info( + meta=request.meta, + mediainfo=request.current_media, + ) + if exist_flag: + no_exists = self._get_noexits_info(request.meta, request.current_media) + + downloads, lefts = downloadchain.batch_download( + contexts=cache_list, + no_exists=no_exists, + channel=channel, + source=source, + userid=userid, + username=username, + ) + if downloads and not lefts: + logger.info("%s 下载完成", request.current_media.title_year) + return + + logger.info("%s 未下载未完整,添加订阅 ...", request.current_media.title_year) + if downloads and request.current_media.type == MediaType.TV: + note = [ + download.meta_info.begin_episode + for download in downloads + if download.meta_info.begin_episode + ] + else: + note = None + + mp_name = ( + UserOper().get_name(**{f"{channel.name.lower()}_userid": userid}) + if channel + else None + ) + SubscribeChain().add( + title=request.current_media.title, + year=request.current_media.year, + mtype=request.current_media.type, + tmdbid=request.current_media.tmdb_id, + season=request.meta.begin_season, + channel=channel, + source=source, + userid=userid, + username=mp_name or username, + state="R", + note=note, + ) + + def _render_interaction( + self, + request: PendingMediaInteraction, + channel: MessageChannel, + source: str, + userid: Union[str, int], + original_message_id: Optional[Union[str, int]] = None, + original_chat_id: Optional[str] = None, + ) -> None: + """ + 按当前阶段渲染媒体列表或资源列表。 + """ + if request.phase == "torrent": + self._post_torrents_message( + request=request, + channel=channel, + source=source, + userid=userid, + original_message_id=original_message_id, + original_chat_id=original_chat_id, + ) + else: + self._post_medias_message( + request=request, + channel=channel, + source=source, + userid=userid, + original_message_id=original_message_id, + original_chat_id=original_chat_id, + ) + + def _post_medias_message( + self, + request: PendingMediaInteraction, + channel: MessageChannel, + source: str, + userid: Union[str, int], + original_message_id: Optional[Union[str, int]] = None, + original_chat_id: Optional[str] = None, + ) -> None: + """ + 发送或更新媒体选择列表。 + """ + page_items, page, total_pages = self._page_items( + items=request.items, + page=request.page, + page_size=self._page_size(channel), + ) + request.page = page + total = len(request.items) + if self._supports_interactive_buttons(channel): + title = f"【{request.title}】共找到{total}条相关信息,请选择操作" + buttons = self._create_media_buttons( + channel=channel, + request=request, + items=page_items, + total=total, + total_pages=total_pages, + ) + else: + if total > self._page_size(channel): + title = f"【{request.title}】共找到{total}条相关信息,请回复对应数字选择(p: 上一页 n: 下一页)" + else: + title = f"【{request.title}】共找到{total}条相关信息,请回复对应数字选择" + buttons = None + + self.post_medias_message( + Notification( + channel=channel, + source=source, + title=title, + userid=userid, + buttons=buttons, + original_message_id=original_message_id, + original_chat_id=original_chat_id, + ), + medias=page_items, + ) + + def _post_torrents_message( + self, + request: PendingMediaInteraction, + channel: MessageChannel, + source: str, + userid: Union[str, int], + original_message_id: Optional[Union[str, int]] = None, + original_chat_id: Optional[str] = None, + ) -> None: + """ + 发送或更新资源选择列表。 + """ + page_items, page, total_pages = self._page_items( + items=request.items, + page=request.page, + page_size=self._page_size(channel), + ) + request.page = page + total = len(request.items) + if self._supports_interactive_buttons(channel): + title = f"【{request.title}】共找到{total}条相关资源,请选择下载" + buttons = self._create_torrent_buttons( + channel=channel, + request=request, + items=page_items, + total=total, + total_pages=total_pages, + ) + else: + if total > self._page_size(channel): + title = f"【{request.title}】共找到{total}条相关资源,请回复对应数字下载(0: 自动选择 p: 上一页 n: 下一页)" + else: + title = f"【{request.title}】共找到{total}条相关资源,请回复对应数字下载(0: 自动选择)" + buttons = None + + self.post_torrents_message( + Notification( + channel=channel, + source=source, + title=title, + userid=userid, + link=settings.MP_DOMAIN("#/resource"), + buttons=buttons, + original_message_id=original_message_id, + original_chat_id=original_chat_id, + ), + torrents=page_items, + ) + + def _create_media_buttons( + self, + channel: MessageChannel, + request: PendingMediaInteraction, + items: List[MediaInfo], + total: int, + total_pages: int, + ) -> List[List[Dict[str, str]]]: + """ + 为媒体列表生成选择和翻页按钮。 + """ + buttons: List[List[Dict[str, str]]] = [] + max_text_length = ChannelCapabilityManager.get_max_button_text_length(channel) + max_per_row = ChannelCapabilityManager.get_max_buttons_per_row(channel) + + current_row: List[Dict[str, str]] = [] + for index, media in enumerate(items, start=1): + if max_per_row == 1: + button_text = f"{index}. {media.title_year}" + if len(button_text) > max_text_length: + button_text = button_text[: max_text_length - 3] + "..." + buttons.append( + [ + { + "text": button_text, + "callback_data": f"media:{request.request_id}:select:{index}", + } + ] + ) + continue + + current_row.append( + { + "text": f"{index}", + "callback_data": f"media:{request.request_id}:select:{index}", + } + ) + if len(current_row) == max_per_row or index == len(items): + buttons.append(current_row) + current_row = [] + + if total > self._page_size(channel): + buttons.extend(self._navigation_buttons(request, total_pages)) + return buttons + + def _create_torrent_buttons( + self, + channel: MessageChannel, + request: PendingMediaInteraction, + items: List[Context], + total: int, + total_pages: int, + ) -> List[List[Dict[str, str]]]: + """ + 为资源列表生成下载和翻页按钮。 + """ + buttons: List[List[Dict[str, str]]] = [ + [ + { + "text": "🤖 自动选择下载", + "callback_data": f"media:{request.request_id}:download:0", + } + ] + ] + max_text_length = ChannelCapabilityManager.get_max_button_text_length(channel) + max_per_row = ChannelCapabilityManager.get_max_buttons_per_row(channel) + + current_row: List[Dict[str, str]] = [] + for index, context in enumerate(items, start=1): + torrent = context.torrent_info + if max_per_row == 1: + button_text = f"{index}. {torrent.site_name} - {torrent.seeders}↑" + if len(button_text) > max_text_length: + button_text = button_text[: max_text_length - 3] + "..." + buttons.append( + [ + { + "text": button_text, + "callback_data": f"media:{request.request_id}:download:{index}", + } + ] + ) + continue + + current_row.append( + { + "text": f"{index}", + "callback_data": f"media:{request.request_id}:download:{index}", + } + ) + if len(current_row) == max_per_row or index == len(items): + buttons.append(current_row) + current_row = [] + + if total > self._page_size(channel): + buttons.extend(self._navigation_buttons(request, total_pages)) + return buttons + + def _has_next_page(self, request: PendingMediaInteraction) -> bool: + """ + 判断当前视图是否还有下一页。 + """ + _, page, total_pages = self._page_items( + items=request.items, + page=request.page, + page_size=self._page_size(request.channel), + ) + return page < total_pages - 1 + + @staticmethod + def _navigation_buttons( + request: PendingMediaInteraction, + total_pages: int, + ) -> List[List[Dict[str, str]]]: + """ + 按当前页状态生成上一页和下一页按钮。 + """ + buttons: List[List[Dict[str, str]]] = [] + nav_row: List[Dict[str, str]] = [] + if request.page > 0: + nav_row.append( + { + "text": "⬅️ 上一页", + "callback_data": f"media:{request.request_id}:page-prev", + } + ) + if request.page < total_pages - 1: + nav_row.append( + { + "text": "下一页 ➡️", + "callback_data": f"media:{request.request_id}:page-next", + } + ) + if nav_row: + buttons.append(nav_row) + return buttons + + @staticmethod + def _page_items( + items: List[Any], + page: int, + page_size: int, + ) -> Tuple[List[Any], int, int]: + """ + 返回当前页数据,并把页码限制在有效范围内。 + """ + total_pages = max(1, math.ceil(len(items) / page_size)) if page_size else 1 + page = min(max(0, page), total_pages - 1) + start = page * page_size + end = start + page_size + return items[start:end], page, total_pages + + def _page_size(self, channel: Optional[MessageChannel]) -> int: + """ + 按渠道交互能力选择分页大小。 + """ + return ( + self._button_page_size + if self._supports_interactive_buttons(channel) + else self._text_page_size + ) + + @staticmethod + def _supports_interactive_buttons(channel: Optional[MessageChannel]) -> bool: + """ + 判断渠道是否同时支持按钮展示与按钮回调。 + """ + return bool( + channel + and ChannelCapabilityManager.supports_buttons(channel) + and ChannelCapabilityManager.supports_callbacks(channel) + ) + + @staticmethod + def _build_no_exists_messages( + mediainfo: MediaInfo, + no_exists: Optional[Dict[Union[int, str], Dict[int, NotExistMediaInfo]]], + show_missing_only: bool, + ) -> List[str]: + """ + 将缺失集信息转换为可发送的文案。 + """ + if not no_exists: + return [] + mediakey = mediainfo.tmdb_id or mediainfo.douban_id + season_map = no_exists.get(mediakey) or {} + if show_missing_only: + return [ + f"第 {sea} 季缺失 {StringUtils.str_series(no_exist.episodes) if no_exist.episodes else no_exist.total_episode} 集" + for sea, no_exist in season_map.items() + ] + return [ + f"第 {sea} 季总 {no_exist.total_episode} 集" + for sea, no_exist in season_map.items() + ] + + @staticmethod + def _should_auto_download(userid: Union[str, int]) -> bool: + """ + 判断当前用户是否命中自动下载名单。 + """ + auto_download_user = settings.AUTO_DOWNLOAD_USER + return bool( + auto_download_user + and ( + auto_download_user == "all" + or any(userid == user for user in auto_download_user.split(",")) + ) + ) + + def _post_invalid_input( + self, + channel: MessageChannel, + source: str, + userid: Union[str, int], + username: Optional[str], + title: str = "输入有误!", + ) -> None: + """ + 发送统一的非法输入提示。 + """ + self.post_message( + Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title=title, + ) + ) diff --git a/app/chain/site.py b/app/chain/site.py index 54c6961f..e11e2a39 100644 --- a/app/chain/site.py +++ b/app/chain/site.py @@ -7,7 +7,7 @@ from urllib.parse import urljoin from lxml import etree from app.chain import ChainBase -from app.helper.slash import ( +from app.helper.interaction import ( SlashInteractionManager, build_navigation_buttons, format_markdown_table, @@ -1060,8 +1060,9 @@ class SiteChain(ChainBase): original_chat_id=original_chat_id, ) + @staticmethod def _format_site_list( - self, site_list: List[Site], channel: Optional[MessageChannel] + site_list: List[Site], channel: Optional[MessageChannel] ) -> str: """ 根据渠道能力格式化站点列表。 diff --git a/app/chain/skills.py b/app/chain/skills.py index 648e324b..1625cdcb 100644 --- a/app/chain/skills.py +++ b/app/chain/skills.py @@ -1,145 +1,18 @@ import re -from dataclasses import dataclass, field -from datetime import datetime, timedelta -from threading import Lock -from typing import Dict, List, Optional, Tuple, Union -import uuid +from typing import List, Optional, Tuple, Union from app.chain import ChainBase -from app.helper.slash import ( +from app.helper.interaction import ( build_navigation_buttons, page_items, supports_interaction_buttons, - update_or_post_message, + update_or_post_message, skills_interaction_manager, PendingSkillsInteraction, ) from app.helper.skill import SkillHelper, SkillInfo from app.schemas import Notification from app.schemas.types import MessageChannel -@dataclass -class PendingSkillsInteraction: - """ - 记录一次 /skills 会话的上下文,便于按钮和文本回复共用同一状态。 - """ - - request_id: str - user_id: str - channel: Optional[MessageChannel] - source: Optional[str] - username: Optional[str] - view: str = "root" - local_page: int = 0 - market_page: int = 0 - market_query: str = "" - awaiting_input: Optional[str] = None - created_at: datetime = field(default_factory=datetime.now) - - -class SkillsInteractionManager: - """ - 管理用户当前的技能交互状态。 - - 每个用户同一时间只保留一个有效会话,避免旧按钮继续生效。 - """ - - _ttl = timedelta(hours=24) - - def __init__(self): - self._by_id: Dict[str, PendingSkillsInteraction] = {} - self._by_user: Dict[str, str] = {} - self._lock = Lock() - - def _cleanup_locked(self): - """ - 清理超时会话,避免按钮回调无限积累。 - """ - expire_before = datetime.now() - self._ttl - expired = [ - request_id - for request_id, request in self._by_id.items() - if request.created_at < expire_before - ] - for request_id in expired: - request = self._by_id.pop(request_id, None) - if request: - self._by_user.pop(str(request.user_id), None) - - def create_or_replace( - self, - user_id: Union[str, int], - channel: Optional[MessageChannel], - source: Optional[str], - username: Optional[str], - ) -> PendingSkillsInteraction: - """ - 为用户创建新会话,并替换掉旧的技能交互状态。 - """ - with self._lock: - self._cleanup_locked() - user_key = str(user_id) - old_request_id = self._by_user.get(user_key) - if old_request_id: - self._by_id.pop(old_request_id, None) - request_id = uuid.uuid4().hex[:12] - request = PendingSkillsInteraction( - request_id=request_id, - user_id=user_key, - channel=channel, - source=source, - username=username, - ) - self._by_id[request_id] = request - self._by_user[user_key] = request_id - return request - - def get_by_user( - self, user_id: Union[str, int] - ) -> Optional[PendingSkillsInteraction]: - """ - 按用户获取当前有效会话,供纯文本回复路由使用。 - """ - with self._lock: - self._cleanup_locked() - request_id = self._by_user.get(str(user_id)) - if not request_id: - return None - return self._by_id.get(request_id) - - def get_by_id( - self, request_id: str, user_id: Union[str, int] - ) -> Optional[PendingSkillsInteraction]: - """ - 按请求 ID 获取会话,并校验会话归属用户。 - """ - with self._lock: - self._cleanup_locked() - request = self._by_id.get(request_id) - if not request or str(request.user_id) != str(user_id): - return None - return request - - def remove(self, request_id: str) -> None: - """ - 主动结束会话,释放用户和请求 ID 的双向索引。 - """ - with self._lock: - request = self._by_id.pop(request_id, None) - if request: - self._by_user.pop(str(request.user_id), None) - - def clear(self): - """ - 清空所有会话,主要用于测试场景。 - """ - with self._lock: - self._by_id.clear() - self._by_user.clear() - - -skills_interaction_manager = SkillsInteractionManager() - - class SkillsChain(ChainBase): """ 处理 /skills 指令、按钮回调和文本式技能管理交互。 @@ -153,11 +26,11 @@ class SkillsChain(ChainBase): self.skillhelper = SkillHelper() def remote_manage( - self, - arg_str: str, - channel: MessageChannel, - userid: Union[str, int], - source: Optional[str] = None, + self, + arg_str: str, + channel: MessageChannel, + userid: Union[str, int], + source: Optional[str] = None, ): """ /skills 入口。创建新会话并渲染首屏菜单。 @@ -205,14 +78,14 @@ class SkillsChain(ChainBase): return request_id, action, index def handle_callback_interaction( - self, - callback_data: str, - channel: MessageChannel, - source: str, - userid: Union[str, int], - username: str, - original_message_id: Optional[Union[str, int]] = None, - original_chat_id: Optional[str] = None, + self, + callback_data: str, + channel: MessageChannel, + source: str, + userid: Union[str, int], + username: str, + original_message_id: Optional[Union[str, int]] = None, + original_chat_id: Optional[str] = None, ) -> bool: """ 处理按钮交互,并在同一条消息上刷新当前视图。 @@ -364,12 +237,12 @@ class SkillsChain(ChainBase): return True def handle_text_interaction( - self, - channel: MessageChannel, - source: str, - userid: Union[str, int], - username: str, - text: str, + self, + channel: MessageChannel, + source: str, + userid: Union[str, int], + username: str, + text: str, ) -> bool: """ 处理不支持按钮渠道上的文本指令,也兼容用户直接回复文字操作。 @@ -660,42 +533,42 @@ class SkillsChain(ChainBase): return True def _install_market_skill( - self, - request: PendingSkillsInteraction, - page_index: int, + self, + request: PendingSkillsInteraction, + page_index: int, ) -> Tuple[bool, str]: """ 按当前市场页的可见序号安装技能,避免跨页序号歧义。 """ market_skills = self._get_market_skills(request=request) - page_items, page, _ = self._page_items( + items, page, _ = self._page_items( items=market_skills, page=request.market_page, page_size=self._page_size(request.channel), ) request.market_page = page - if page_index < 1 or page_index > len(page_items): + if page_index < 1 or page_index > len(items): return False, "安装序号无效" - return self.skillhelper.install_market_skill(page_items[page_index - 1]) + return self.skillhelper.install_market_skill(items[page_index - 1]) def _remove_local_skill( - self, - request: PendingSkillsInteraction, - page_index: int, + self, + request: PendingSkillsInteraction, + page_index: int, ) -> Tuple[bool, str]: """ 按当前已安装页的可见序号删除技能,并拦截内置技能。 """ local_skills = self.skillhelper.list_local_skills() - page_items, page, _ = self._page_items( + items, page, _ = self._page_items( items=local_skills, page=request.local_page, page_size=self._page_size(request.channel), ) request.local_page = page - if page_index < 1 or page_index > len(page_items): + if page_index < 1 or page_index > len(items): return False, "删除序号无效" - target = page_items[page_index - 1] + target = items[page_index - 1] if not target.removable: return False, f"技能 {target.id} 是内置技能,不能删除" return self.skillhelper.remove_local_skill(target.id) @@ -713,15 +586,15 @@ class SkillsChain(ChainBase): return self.skillhelper.remove_custom_market_source(target.source) def _render_interaction( - self, - request: PendingSkillsInteraction, - channel: MessageChannel, - source: Optional[str], - userid: Union[str, int], - username: Optional[str], - original_message_id: Optional[Union[str, int]] = None, - original_chat_id: Optional[str] = None, - force_market_refresh: bool = False, + self, + request: PendingSkillsInteraction, + channel: MessageChannel, + source: Optional[str], + userid: Union[str, int], + username: Optional[str], + original_message_id: Optional[Union[str, int]] = None, + original_chat_id: Optional[str] = None, + force_market_refresh: bool = False, ) -> None: """ 根据当前视图生成内容,并选择编辑原消息或发送新消息。 @@ -758,9 +631,9 @@ class SkillsChain(ChainBase): ) def _build_root_view( - self, - request: PendingSkillsInteraction, - force_market_refresh: bool = False, + self, + request: PendingSkillsInteraction, + force_market_refresh: bool = False, ) -> Tuple[str, str, Optional[List[List[dict]]]]: """ 构建根菜单视图,汇总本地技能和市场概览。 @@ -809,14 +682,14 @@ class SkillsChain(ChainBase): return "技能管理", "\n".join(text_lines), buttons def _build_installed_view( - self, - request: PendingSkillsInteraction + self, + request: PendingSkillsInteraction ) -> Tuple[str, str, Optional[List[List[dict]]]]: """ 构建已安装技能视图,列出来源和可删除状态。 """ local_skills = self.skillhelper.list_local_skills() - page_items, page, total_pages = self._page_items( + items, page, total_pages = self._page_items( items=local_skills, page=request.local_page, page_size=self._page_size(request.channel), @@ -824,11 +697,11 @@ class SkillsChain(ChainBase): request.local_page = page text_lines = [f"第 {page + 1}/{total_pages} 页,共 {len(local_skills)} 个技能"] - if not page_items: + if not items: text_lines.append("") text_lines.append("当前没有已安装技能") else: - for index, skill in enumerate(page_items, start=1): + for index, skill in enumerate(items, start=1): action = "可删除" if skill.removable else "内置不可删" text_lines.extend( [ @@ -869,9 +742,9 @@ class SkillsChain(ChainBase): return "已安装技能", "\n".join(text_lines), buttons def _build_market_view( - self, - request: PendingSkillsInteraction, - force_market_refresh: bool = False, + self, + request: PendingSkillsInteraction, + force_market_refresh: bool = False, ) -> Tuple[str, str, Optional[List[List[dict]]]]: """ 构建技能市场视图,仅展示尚未安装的技能。 @@ -880,7 +753,7 @@ class SkillsChain(ChainBase): request=request, force_market_refresh=force_market_refresh, ) - page_items, page, total_pages = self._page_items( + items, page, total_pages = self._page_items( items=market_skills, page=request.market_page, page_size=self._page_size(request.channel), @@ -897,14 +770,14 @@ class SkillsChain(ChainBase): "搜索输入中:直接回复关键词即可筛选市场技能,回复 `取消` 结束输入。", ] ) - if not page_items: + if not items: text_lines.append("") if request.market_query: text_lines.append("当前搜索没有匹配的市场技能") else: text_lines.append("当前没有可安装的市场技能") else: - for index, skill in enumerate(page_items, start=1): + for index, skill in enumerate(items, start=1): text_lines.extend( [ "", @@ -970,8 +843,8 @@ class SkillsChain(ChainBase): return "技能市场", "\n".join(text_lines), buttons def _build_sources_view( - self, - request: PendingSkillsInteraction, + self, + request: PendingSkillsInteraction, ) -> Tuple[str, str, Optional[List[List[dict]]]]: """ 构建技能源管理视图,提供自定义 GitHub 源的增删入口。 @@ -1052,9 +925,9 @@ class SkillsChain(ChainBase): @staticmethod def _page_items( - items: List[SkillInfo], - page: int, - page_size: int, + items: List[SkillInfo], + page: int, + page_size: int, ) -> Tuple[List[SkillInfo], int, int]: """ 返回当前页的数据,并把页码钳制到有效范围内。 @@ -1080,9 +953,9 @@ class SkillsChain(ChainBase): @staticmethod def _navigation_buttons( - request: PendingSkillsInteraction, - page: int, - total_pages: int, + request: PendingSkillsInteraction, + page: int, + total_pages: int, ) -> List[List[dict]]: """ 为分页视图生成上一页和下一页按钮。 @@ -1095,16 +968,16 @@ class SkillsChain(ChainBase): ) def _update_or_post_message( - self, - channel: MessageChannel, - source: Optional[str], - userid: Union[str, int], - username: Optional[str], - title: str, - text: str, - buttons: Optional[List[List[dict]]] = None, - original_message_id: Optional[Union[str, int]] = None, - original_chat_id: Optional[str] = None, + self, + channel: MessageChannel, + source: Optional[str], + userid: Union[str, int], + username: Optional[str], + title: str, + text: str, + buttons: Optional[List[List[dict]]] = None, + original_message_id: Optional[Union[str, int]] = None, + original_chat_id: Optional[str] = None, ) -> None: """ 优先编辑原消息,编辑失败时再回退为发送新消息。 @@ -1136,9 +1009,9 @@ class SkillsChain(ChainBase): return "请输入 1、2、3、搜索 <关键词>、刷新 或 退出" def _get_market_skills( - self, - request: PendingSkillsInteraction, - force_market_refresh: bool = False, + self, + request: PendingSkillsInteraction, + force_market_refresh: bool = False, ) -> List[SkillInfo]: """ 获取当前 /skills 会话可见的市场技能,并应用搜索词过滤。 @@ -1183,8 +1056,8 @@ class SkillsChain(ChainBase): @staticmethod def _apply_market_search( - request: PendingSkillsInteraction, - query: str, + request: PendingSkillsInteraction, + query: str, ) -> None: """ 将会话切到市场搜索结果视图,并重置分页状态。 diff --git a/app/chain/subscribe.py b/app/chain/subscribe.py index 17afae40..1f9ea7ba 100644 --- a/app/chain/subscribe.py +++ b/app/chain/subscribe.py @@ -1,6 +1,7 @@ import copy import json import random +import re import threading import time from datetime import datetime @@ -11,7 +12,7 @@ from app.chain import ChainBase from app.chain.download import DownloadChain from app.chain.media import MediaChain from app.chain.search import SearchChain -from app.helper.slash import ( +from app.helper.interaction import ( SlashInteractionManager, build_navigation_buttons, format_markdown_table, diff --git a/app/helper/__init__.py b/app/helper/__init__.py index ff2e48a2..e69de29b 100644 --- a/app/helper/__init__.py +++ b/app/helper/__init__.py @@ -1 +0,0 @@ -from .cloudflare import under_challenge diff --git a/app/helper/interaction.py b/app/helper/interaction.py new file mode 100644 index 00000000..8cc69d65 --- /dev/null +++ b/app/helper/interaction.py @@ -0,0 +1,626 @@ +import math +import uuid +from dataclasses import dataclass, field +from datetime import datetime, timedelta +from threading import Lock +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +from app.core.context import MediaInfo +from app.core.meta import MetaBase +from app.schemas import Notification +from app.schemas.message import ChannelCapabilityManager +from app.schemas.types import MessageChannel + + +@dataclass +class PendingSlashInteraction: + """ + 通用 slash 命令交互上下文。 + """ + + request_id: str + user_id: str + channel: Optional[MessageChannel] + source: Optional[str] + username: Optional[str] + command: str + page: int = 0 + awaiting_input: Optional[str] = None + created_at: datetime = field(default_factory=datetime.now) + + +class SlashInteractionManager: + """ + 管理单个 slash 命令的交互会话。 + """ + + _ttl = timedelta(hours=24) + + def __init__(self): + self._by_id: Dict[str, PendingSlashInteraction] = {} + self._by_user: Dict[str, str] = {} + self._lock = Lock() + + def _cleanup_locked(self) -> None: + expire_before = datetime.now() - self._ttl + expired = [ + request_id + for request_id, request in self._by_id.items() + if request.created_at < expire_before + ] + for request_id in expired: + request = self._by_id.pop(request_id, None) + if request: + self._by_user.pop(str(request.user_id), None) + + def create_or_replace( + self, + user_id: Union[str, int], + command: str, + channel: Optional[MessageChannel], + source: Optional[str], + username: Optional[str], + ) -> PendingSlashInteraction: + with self._lock: + self._cleanup_locked() + user_key = str(user_id) + old_request_id = self._by_user.get(user_key) + if old_request_id: + self._by_id.pop(old_request_id, None) + request = PendingSlashInteraction( + request_id=uuid.uuid4().hex[:12], + user_id=user_key, + command=command, + channel=channel, + source=source, + username=username, + ) + self._by_id[request.request_id] = request + self._by_user[user_key] = request.request_id + return request + + def get_by_user( + self, user_id: Union[str, int] + ) -> Optional[PendingSlashInteraction]: + with self._lock: + self._cleanup_locked() + request_id = self._by_user.get(str(user_id)) + if not request_id: + return None + return self._by_id.get(request_id) + + def get_by_id( + self, request_id: str, user_id: Union[str, int] + ) -> Optional[PendingSlashInteraction]: + with self._lock: + self._cleanup_locked() + request = self._by_id.get(request_id) + if not request or str(request.user_id) != str(user_id): + return None + return request + + def remove(self, request_id: str) -> None: + with self._lock: + request = self._by_id.pop(request_id, None) + if request: + self._by_user.pop(str(request.user_id), None) + + def clear(self) -> None: + with self._lock: + self._by_id.clear() + self._by_user.clear() + + +def supports_interaction_buttons(channel: Optional[MessageChannel]) -> bool: + """ + 渠道同时支持按钮和回调时,优先使用按钮交互。 + """ + return bool( + channel + and ChannelCapabilityManager.supports_buttons(channel) + and ChannelCapabilityManager.supports_callbacks(channel) + ) + + +def supports_markdown(channel: Optional[MessageChannel]) -> bool: + """ + 仅在支持 Markdown 的渠道上输出 Markdown 内容。 + """ + return bool(channel and ChannelCapabilityManager.supports_markdown(channel)) + + +def page_items( + items: Sequence[Any], + page: int, + page_size: int, +) -> Tuple[List[Any], int, int]: + """ + 对列表做分页并规范化页码。 + """ + total = len(items) + if total == 0: + return [], 0, 1 + total_pages = max(1, math.ceil(total / max(1, page_size))) + page = min(max(0, page), total_pages - 1) + start = page * page_size + end = start + page_size + return list(items[start:end]), page, total_pages + + +def build_navigation_buttons( + prefix: str, + request: Any, + page: int, + total_pages: int, +) -> List[List[dict]]: + """ + 构造标准上一页/下一页按钮。 + """ + buttons = [] + nav_row = [] + if page > 0: + nav_row.append( + { + "text": "⬅️ 上一页", + "callback_data": f"{prefix}:{request.request_id}:page-prev", + } + ) + if page < total_pages - 1: + nav_row.append( + { + "text": "下一页 ➡️", + "callback_data": f"{prefix}:{request.request_id}:page-next", + } + ) + if nav_row: + buttons.append(nav_row) + return buttons + + +def update_or_post_message( + chain, + channel: MessageChannel, + source: Optional[str], + userid: Union[str, int], + username: Optional[str], + title: str, + text: str, + buttons: Optional[List[List[dict]]] = None, + original_message_id: Optional[Union[str, int]] = None, + original_chat_id: Optional[str] = None, +) -> None: + """ + 优先编辑原消息,失败时回退为发送新消息。 + """ + if ( + original_message_id + and original_chat_id + and ChannelCapabilityManager.supports_editing(channel) + ): + edited = chain.edit_message( + channel=channel, + source=source, + message_id=original_message_id, + chat_id=original_chat_id, + title=title, + text=text, + buttons=buttons, + ) + if edited: + return + + chain.post_message( + Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title=title, + text=text, + buttons=buttons, + ) + ) + + +def escape_markdown_table_cell(value: object) -> str: + """ + 最小化转义 Markdown 表格中的特殊字符。 + """ + text = str(value or "").replace("\n", "
") + return text.replace("|", "\\|") + + +def format_markdown_table( + headers: Sequence[str], + rows: Sequence[Sequence[object]], +) -> str: + """ + 生成 Markdown 表格文本。 + """ + header_line = ( + "| " + + " | ".join(escape_markdown_table_cell(item) for item in headers) + + " |" + ) + separator_line = "| " + " | ".join("---" for _ in headers) + " |" + data_lines = [ + "| " + + " | ".join(escape_markdown_table_cell(item) for item in row) + + " |" + for row in rows + ] + return "\n".join([header_line, separator_line, *data_lines]) + + +@dataclass +class PendingMediaInteraction: + """ + 记录一次搜索/下载/订阅交互的当前上下文。 + """ + + request_id: str + user_id: str + channel: Optional[MessageChannel] + source: Optional[str] + username: Optional[str] + action: str + keyword: str + phase: str = "media" + page: int = 0 + title: str = "" + meta: Optional[MetaBase] = None + current_media: Optional[MediaInfo] = None + items: List[Any] = field(default_factory=list) + created_at: datetime = field(default_factory=datetime.now) + + +class MediaInteractionManager: + """ + 管理用户当前激活的媒体交互状态。 + + 每个用户只保留一个有效会话,避免旧按钮与新一轮搜索混用。 + """ + + _ttl = timedelta(hours=24) + + def __init__(self): + self._by_id: Dict[str, PendingMediaInteraction] = {} + self._by_user: Dict[str, str] = {} + self._lock = Lock() + + def _cleanup_locked(self) -> None: + """ + 清理超时会话,避免内存中残留旧交互状态。 + """ + expire_before = datetime.now() - self._ttl + expired = [ + request_id + for request_id, request in self._by_id.items() + if request.created_at < expire_before + ] + for request_id in expired: + request = self._by_id.pop(request_id, None) + if request: + self._by_user.pop(str(request.user_id), None) + + def create_or_replace( + self, + user_id: Union[str, int], + channel: Optional[MessageChannel], + source: Optional[str], + username: Optional[str], + action: str, + keyword: str, + title: str = "", + meta: Optional[MetaBase] = None, + items: Optional[List[Any]] = None, + ) -> PendingMediaInteraction: + """ + 为用户创建新的交互状态,并替换旧会话。 + """ + with self._lock: + self._cleanup_locked() + user_key = str(user_id) + old_request_id = self._by_user.get(user_key) + if old_request_id: + self._by_id.pop(old_request_id, None) + + request = PendingMediaInteraction( + request_id=uuid.uuid4().hex[:12], + user_id=user_key, + channel=channel, + source=source, + username=username, + action=action, + keyword=keyword, + title=title, + meta=meta, + items=list(items or []), + ) + self._by_id[request.request_id] = request + self._by_user[user_key] = request.request_id + return request + + def get_by_user( + self, user_id: Union[str, int] + ) -> Optional[PendingMediaInteraction]: + """ + 按用户读取当前会话,供文本回复和旧按钮兼容使用。 + """ + with self._lock: + self._cleanup_locked() + request_id = self._by_user.get(str(user_id)) + if not request_id: + return None + return self._by_id.get(request_id) + + def get_by_id( + self, request_id: str, user_id: Union[str, int] + ) -> Optional[PendingMediaInteraction]: + """ + 按请求 ID 读取会话,并校验用户归属。 + """ + with self._lock: + self._cleanup_locked() + request = self._by_id.get(request_id) + if not request or str(request.user_id) != str(user_id): + return None + return request + + def remove(self, request_id: str) -> None: + """ + 主动结束一条会话。 + """ + with self._lock: + request = self._by_id.pop(request_id, None) + if request: + self._by_user.pop(str(request.user_id), None) + + def clear(self) -> None: + """ + 清空所有交互状态,主要用于测试。 + """ + with self._lock: + self._by_id.clear() + self._by_user.clear() + + +media_interaction_manager = MediaInteractionManager() + + +@dataclass(frozen=True) +class AgentInteractionOption: + """ + Agent 交互选项。 + """ + + label: str + value: str + + +@dataclass +class PendingAgentInteraction: + """ + 待处理的 Agent 客户端交互请求。 + """ + + request_id: str + session_id: str + user_id: str + channel: Optional[str] + source: Optional[str] + username: Optional[str] + title: Optional[str] + prompt: str + options: List[AgentInteractionOption] + created_at: datetime = field(default_factory=datetime.now) + + +class AgentInteractionManager: + """ + 管理 Agent 发起的客户端交互请求。 + """ + + _ttl = timedelta(hours=24) + + def __init__(self): + self._pending_interactions: Dict[str, PendingAgentInteraction] = {} + self._lock = Lock() + + def _cleanup_locked(self) -> None: + expire_before = datetime.now() - self._ttl + expired_ids = [ + request_id + for request_id, request in self._pending_interactions.items() + if request.created_at < expire_before + ] + for request_id in expired_ids: + self._pending_interactions.pop(request_id, None) + + def create_request( + self, + session_id: str, + user_id: str, + channel: Optional[str], + source: Optional[str], + username: Optional[str], + title: Optional[str], + prompt: str, + options: List[AgentInteractionOption], + ) -> PendingAgentInteraction: + """ + 创建一条待用户确认的 Agent 交互请求。 + """ + with self._lock: + self._cleanup_locked() + request_id = uuid.uuid4().hex[:12] + while request_id in self._pending_interactions: + request_id = uuid.uuid4().hex[:12] + request = PendingAgentInteraction( + request_id=request_id, + session_id=session_id, + user_id=str(user_id), + channel=channel, + source=source, + username=username, + title=title, + prompt=prompt, + options=options, + ) + self._pending_interactions[request_id] = request + return request + + def resolve( + self, + request_id: str, + option_index: int, + user_id: Optional[str] = None, + ) -> Optional[tuple[PendingAgentInteraction, AgentInteractionOption]]: + """ + 消费一条 Agent 交互请求,并返回选中的选项。 + """ + with self._lock: + self._cleanup_locked() + request = self._pending_interactions.get(request_id) + if not request: + return None + if user_id is not None and str(request.user_id) != str(user_id): + return None + if option_index < 1 or option_index > len(request.options): + return None + option = request.options[option_index - 1] + self._pending_interactions.pop(request_id, None) + return request, option + + def clear(self) -> None: + """ + 清空所有 Agent 交互请求。 + """ + with self._lock: + self._pending_interactions.clear() + + +agent_interaction_manager = AgentInteractionManager() + + +@dataclass +class PendingSkillsInteraction: + """ + 记录一次 /skills 会话的上下文,便于按钮和文本回复共用同一状态。 + """ + + request_id: str + user_id: str + channel: Optional[MessageChannel] + source: Optional[str] + username: Optional[str] + view: str = "root" + local_page: int = 0 + market_page: int = 0 + market_query: str = "" + awaiting_input: Optional[str] = None + created_at: datetime = field(default_factory=datetime.now) + + +class SkillsInteractionManager: + """ + 管理用户当前的技能交互状态。 + + 每个用户同一时间只保留一个有效会话,避免旧按钮继续生效。 + """ + + _ttl = timedelta(hours=24) + + def __init__(self): + self._by_id: Dict[str, PendingSkillsInteraction] = {} + self._by_user: Dict[str, str] = {} + self._lock = Lock() + + def _cleanup_locked(self): + """ + 清理超时会话,避免按钮回调无限积累。 + """ + expire_before = datetime.now() - self._ttl + expired = [ + request_id + for request_id, request in self._by_id.items() + if request.created_at < expire_before + ] + for request_id in expired: + request = self._by_id.pop(request_id, None) + if request: + self._by_user.pop(str(request.user_id), None) + + def create_or_replace( + self, + user_id: Union[str, int], + channel: Optional[MessageChannel], + source: Optional[str], + username: Optional[str], + ) -> PendingSkillsInteraction: + """ + 为用户创建新会话,并替换掉旧的技能交互状态。 + """ + with self._lock: + self._cleanup_locked() + user_key = str(user_id) + old_request_id = self._by_user.get(user_key) + if old_request_id: + self._by_id.pop(old_request_id, None) + request_id = uuid.uuid4().hex[:12] + request = PendingSkillsInteraction( + request_id=request_id, + user_id=user_key, + channel=channel, + source=source, + username=username, + ) + self._by_id[request_id] = request + self._by_user[user_key] = request_id + return request + + def get_by_user( + self, user_id: Union[str, int] + ) -> Optional[PendingSkillsInteraction]: + """ + 按用户获取当前有效会话,供纯文本回复路由使用。 + """ + with self._lock: + self._cleanup_locked() + request_id = self._by_user.get(str(user_id)) + if not request_id: + return None + return self._by_id.get(request_id) + + def get_by_id( + self, request_id: str, user_id: Union[str, int] + ) -> Optional[PendingSkillsInteraction]: + """ + 按请求 ID 获取会话,并校验会话归属用户。 + """ + with self._lock: + self._cleanup_locked() + request = self._by_id.get(request_id) + if not request or str(request.user_id) != str(user_id): + return None + return request + + def remove(self, request_id: str) -> None: + """ + 主动结束会话,释放用户和请求 ID 的双向索引。 + """ + with self._lock: + request = self._by_id.pop(request_id, None) + if request: + self._by_user.pop(str(request.user_id), None) + + def clear(self): + """ + 清空所有会话,主要用于测试场景。 + """ + with self._lock: + self._by_id.clear() + self._by_user.clear() + + +skills_interaction_manager = SkillsInteractionManager() diff --git a/app/helper/slash.py b/app/helper/slash.py deleted file mode 100644 index 49f7105e..00000000 --- a/app/helper/slash.py +++ /dev/null @@ -1,244 +0,0 @@ -import math -import uuid -from dataclasses import dataclass, field -from datetime import datetime, timedelta -from threading import Lock -from typing import Dict, List, Optional, Sequence, Tuple, Union - -from app.schemas import Notification -from app.schemas.message import ChannelCapabilityManager -from app.schemas.types import MessageChannel - - -@dataclass -class PendingSlashInteraction: - """ - 通用 slash 命令交互上下文。 - """ - - request_id: str - user_id: str - channel: Optional[MessageChannel] - source: Optional[str] - username: Optional[str] - command: str - page: int = 0 - awaiting_input: Optional[str] = None - created_at: datetime = field(default_factory=datetime.now) - - -class SlashInteractionManager: - """ - 管理单个 slash 命令的交互会话。 - """ - - _ttl = timedelta(hours=24) - - def __init__(self): - self._by_id: Dict[str, PendingSlashInteraction] = {} - self._by_user: Dict[str, str] = {} - self._lock = Lock() - - def _cleanup_locked(self) -> None: - expire_before = datetime.now() - self._ttl - expired = [ - request_id - for request_id, request in self._by_id.items() - if request.created_at < expire_before - ] - for request_id in expired: - request = self._by_id.pop(request_id, None) - if request: - self._by_user.pop(str(request.user_id), None) - - def create_or_replace( - self, - user_id: Union[str, int], - command: str, - channel: Optional[MessageChannel], - source: Optional[str], - username: Optional[str], - ) -> PendingSlashInteraction: - with self._lock: - self._cleanup_locked() - user_key = str(user_id) - old_request_id = self._by_user.get(user_key) - if old_request_id: - self._by_id.pop(old_request_id, None) - request = PendingSlashInteraction( - request_id=uuid.uuid4().hex[:12], - user_id=user_key, - command=command, - channel=channel, - source=source, - username=username, - ) - self._by_id[request.request_id] = request - self._by_user[user_key] = request.request_id - return request - - def get_by_user( - self, user_id: Union[str, int] - ) -> Optional[PendingSlashInteraction]: - with self._lock: - self._cleanup_locked() - request_id = self._by_user.get(str(user_id)) - if not request_id: - return None - return self._by_id.get(request_id) - - def get_by_id( - self, request_id: str, user_id: Union[str, int] - ) -> Optional[PendingSlashInteraction]: - with self._lock: - self._cleanup_locked() - request = self._by_id.get(request_id) - if not request or str(request.user_id) != str(user_id): - return None - return request - - def remove(self, request_id: str) -> None: - with self._lock: - request = self._by_id.pop(request_id, None) - if request: - self._by_user.pop(str(request.user_id), None) - - def clear(self) -> None: - with self._lock: - self._by_id.clear() - self._by_user.clear() - - -def supports_interaction_buttons(channel: Optional[MessageChannel]) -> bool: - """ - 渠道同时支持按钮和回调时,优先使用按钮交互。 - """ - return bool( - channel - and ChannelCapabilityManager.supports_buttons(channel) - and ChannelCapabilityManager.supports_callbacks(channel) - ) - - -def supports_markdown(channel: Optional[MessageChannel]) -> bool: - """ - 仅在支持 Markdown 的渠道上输出 Markdown 内容。 - """ - return bool(channel and ChannelCapabilityManager.supports_markdown(channel)) - - -def page_items( - items: Sequence, - page: int, - page_size: int, -) -> Tuple[List, int, int]: - """ - 对列表做分页并规范化页码。 - """ - total = len(items) - if total == 0: - return [], 0, 1 - total_pages = max(1, math.ceil(total / max(1, page_size))) - page = min(max(0, page), total_pages - 1) - start = page * page_size - end = start + page_size - return list(items[start:end]), page, total_pages - - -def build_navigation_buttons( - prefix: str, - request: PendingSlashInteraction, - page: int, - total_pages: int, -) -> List[List[dict]]: - """ - 构造标准上一页/下一页按钮。 - """ - buttons = [] - nav_row = [] - if page > 0: - nav_row.append( - { - "text": "⬅️ 上一页", - "callback_data": f"{prefix}:{request.request_id}:page-prev", - } - ) - if page < total_pages - 1: - nav_row.append( - { - "text": "下一页 ➡️", - "callback_data": f"{prefix}:{request.request_id}:page-next", - } - ) - if nav_row: - buttons.append(nav_row) - return buttons - - -def update_or_post_message( - chain, - channel: MessageChannel, - source: Optional[str], - userid: Union[str, int], - username: Optional[str], - title: str, - text: str, - buttons: Optional[List[List[dict]]] = None, - original_message_id: Optional[Union[str, int]] = None, - original_chat_id: Optional[str] = None, -) -> None: - """ - 优先编辑原消息,失败时回退为发送新消息。 - """ - if ( - original_message_id - and original_chat_id - and ChannelCapabilityManager.supports_editing(channel) - ): - edited = chain.edit_message( - channel=channel, - source=source, - message_id=original_message_id, - chat_id=original_chat_id, - title=title, - text=text, - buttons=buttons, - ) - if edited: - return - - chain.post_message( - Notification( - channel=channel, - source=source, - userid=userid, - username=username, - title=title, - text=text, - buttons=buttons, - ) - ) - - -def escape_markdown_table_cell(value: object) -> str: - """ - 最小化转义 Markdown 表格中的特殊字符。 - """ - text = str(value or "").replace("\n", "
") - text = text.replace("|", "\\|") - return text - - -def format_markdown_table(headers: Sequence[str], rows: Sequence[Sequence[object]]) -> str: - """ - 生成 Markdown 表格文本。 - """ - header_line = "| " + " | ".join(escape_markdown_table_cell(item) for item in headers) + " |" - separator_line = "| " + " | ".join("---" for _ in headers) + " |" - data_lines = [ - "| " - + " | ".join(escape_markdown_table_cell(item) for item in row) - + " |" - for row in rows - ] - return "\n".join([header_line, separator_line, *data_lines]) diff --git a/tests/test_agent_interaction.py b/tests/test_agent_interaction.py index 957f6d9d..743b4307 100644 --- a/tests/test_agent_interaction.py +++ b/tests/test_agent_interaction.py @@ -8,7 +8,7 @@ from app.agent.tools.impl.ask_user_choice import ( AskUserChoiceTool, UserChoiceOptionInput, ) -from app.chain.interaction import ( +from app.helper.interaction import ( AgentInteractionOption, agent_interaction_manager, ) diff --git a/tests/test_media_interaction.py b/tests/test_media_interaction.py index 862c1d93..2d049d4c 100644 --- a/tests/test_media_interaction.py +++ b/tests/test_media_interaction.py @@ -9,7 +9,7 @@ sys.modules.setdefault("transmission_rpc", ModuleType("transmission_rpc")) setattr(sys.modules["transmission_rpc"], "File", object) sys.modules.setdefault("psutil", ModuleType("psutil")) -from app.chain.interaction import MediaInteractionChain, media_interaction_manager +from app.chain.media import MediaChain, media_interaction_manager from app.chain.message import MessageChain from app.core.context import MediaInfo from app.core.meta import MetaBase @@ -43,7 +43,7 @@ class TestMediaInteraction(unittest.TestCase): self.assertIsNotNone(request) with patch.object(chain, "_record_user_message"), patch( - "app.chain.message.MediaInteractionChain.handle_text_interaction", + "app.chain.message.MediaChain.handle_text_interaction", return_value=True, ) as handle_text, patch.object(chain, "_handle_ai_message") as handle_ai: chain.handle_message( @@ -72,7 +72,7 @@ class TestMediaInteraction(unittest.TestCase): ) with patch( - "app.chain.message.MediaInteractionChain.handle_callback_interaction", + "app.chain.message.MediaChain.handle_callback_interaction", return_value=True, ) as handle_callback: chain._handle_callback( @@ -86,7 +86,7 @@ class TestMediaInteraction(unittest.TestCase): handle_callback.assert_called_once() def test_media_interaction_starts_search_and_posts_media_list(self): - chain = MediaInteractionChain() + chain = MediaChain() meta = self._build_meta("星际穿越") medias = [ MediaInfo(title="星际穿越", year="2014"), @@ -94,7 +94,7 @@ class TestMediaInteraction(unittest.TestCase): ] with patch( - "app.chain.interaction.MediaChain.search", + "app.chain.media.MediaChain.search", return_value=(meta, medias), ), patch.object(chain, "post_medias_message") as post_medias_message: handled = chain.handle_text_interaction( @@ -119,7 +119,7 @@ class TestMediaInteraction(unittest.TestCase): self.assertEqual(len(request.items), 2) def test_media_interaction_legacy_page_callback_updates_existing_request(self): - chain = MediaInteractionChain() + chain = MediaChain() request = media_interaction_manager.create_or_replace( user_id="10001", channel=MessageChannel.Telegram,