diff --git a/app/agent/__init__.py b/app/agent/__init__.py index 092ce8f1..e8f22ee1 100644 --- a/app/agent/__init__.py +++ b/app/agent/__init__.py @@ -8,6 +8,7 @@ from datetime import datetime from enum import Enum from typing import Any, Callable, Dict, List, Optional +from fastapi.concurrency import run_in_threadpool from langchain.agents import create_agent from langchain.agents.middleware import ( SummarizationMiddleware, @@ -53,6 +54,40 @@ class AgentChain(ChainBase): pass +def _finish_processing_status(status: Optional[dict], user_id: Optional[str] = None) -> None: + """结束入站消息的渠道处理状态。""" + if not status: + return + try: + channel = MessageChannel(status.get("channel")) + except Exception: + return + try: + AgentChain().run_module( + "mark_message_processing_finished", + channel=channel, + source=status.get("source"), + userid=status.get("userid") or user_id, + message_id=status.get("message_id"), + chat_id=status.get("chat_id"), + status=status, + ) + except Exception as err: + logger.debug(f"结束Agent消息处理状态失败: {err}") + + +async def _async_finish_processing_status( + status: Optional[dict], user_id: Optional[str] = None +) -> None: + """ + 在 Agent worker 中结束渠道处理状态。 + 渠道收口可能触发外部 API,同步实现需切到线程池避免阻塞事件循环。 + """ + if not status: + return + await run_in_threadpool(_finish_processing_status, status, user_id) + + @dataclass class _SessionUsageSnapshot: model: Optional[str] = None @@ -901,6 +936,7 @@ class _MessageTask: username: Optional[str] = None original_message_id: Optional[str] = None original_chat_id: Optional[str] = None + processing_status: Optional[dict] = None reply_mode: ReplyMode = ReplyMode.DISPATCH @@ -987,6 +1023,7 @@ class AgentManager: username: str = None, original_message_id: Optional[str] = None, original_chat_id: Optional[str] = None, + processing_status: Optional[dict] = None, reply_mode: ReplyMode = ReplyMode.DISPATCH, ) -> str: """ @@ -1004,6 +1041,7 @@ class AgentManager: username=username, original_message_id=original_message_id, original_chat_id=original_chat_id, + processing_status=processing_status, reply_mode=reply_mode, ) @@ -1062,6 +1100,9 @@ class AgentManager: except Exception as e: logger.error(f"处理会话 {session_id} 的消息失败: {e}") finally: + await _async_finish_processing_status( + task.processing_status, task.user_id + ) queue.task_done() except asyncio.CancelledError: diff --git a/app/chain/message.py b/app/chain/message.py index 098f3de2..bd0c55ae 100644 --- a/app/chain/message.py +++ b/app/chain/message.py @@ -31,7 +31,7 @@ from app.helper.interaction import agent_interaction_manager, media_interaction_ from app.helper.torrent import TorrentHelper from app.log import logger from app.schemas import Notification, CommingMessage, NotExistMediaInfo -from app.schemas.message import ChannelCapabilityManager +from app.schemas.message import ChannelCapabilityManager, ChannelCapability from app.schemas.types import EventType, MessageChannel, MediaType from app.utils.http import RequestUtils from app.utils.string import StringUtils @@ -48,11 +48,24 @@ class MessageChain(ChainBase): _session_timeout_minutes: int = 24 * 60 @dataclass - class _ProcessingMarker: + class _ProcessingStatus: channel: MessageChannel source: str - message_id: str - reaction_id: str + userid: Optional[Union[str, int]] = None + message_id: Optional[Union[str, int]] = None + chat_id: Optional[Union[str, int]] = None + metadata: Optional[Dict[str, Any]] = None + + def to_dict(self) -> Dict[str, Any]: + """转换为模块接口可安全传递的普通字典。""" + return { + "channel": self.channel.value, + "source": self.source, + "userid": self.userid, + "message_id": self.message_id, + "chat_id": self.chat_id, + "metadata": self.metadata or {}, + } def process(self, body: Any, form: Any, args: Any) -> None: """ @@ -158,25 +171,40 @@ class MessageChain(ChainBase): text=text, ) - self._mark_message_processing_started( - channel=channel, - source=source, - original_message_id=original_message_id, - text=text, - ) - self._handle_message_core( + processing_status = self._mark_message_processing_started( channel=channel, source=source, userid=userid, - username=username, - text=text, original_message_id=original_message_id, original_chat_id=original_chat_id, - images=images, - audio_refs=audio_refs, - files=files, - has_audio_input=has_audio_input, + text=text, ) + continues_async = False + try: + continues_async = self._handle_message_core( + channel=channel, + source=source, + userid=userid, + username=username, + text=text, + original_message_id=original_message_id, + original_chat_id=original_chat_id, + images=images, + audio_refs=audio_refs, + files=files, + has_audio_input=has_audio_input, + processing_status=processing_status, + ) + finally: + if continues_async is not True: + self._mark_message_processing_finished( + channel=channel, + source=source, + userid=userid, + status=processing_status, + original_message_id=original_message_id, + original_chat_id=original_chat_id, + ) def _handle_message_core( self, @@ -191,7 +219,8 @@ class MessageChain(ChainBase): audio_refs: Optional[List[str]] = None, files: Optional[List[CommingMessage.MessageAttachment]] = None, has_audio_input: bool = False, - ) -> None: + processing_status: Optional[_ProcessingStatus] = None, + ) -> bool: """执行实际消息路由,便于统一包裹处理中状态。""" if text.startswith("CALLBACK:"): @@ -211,14 +240,14 @@ class MessageChain(ChainBase): channel.value, text, ) - return + return False if text.startswith("/") and not text.lower().startswith("/ai"): self.eventmanager.send_event( EventType.CommandExcute, {"cmd": text, "user": userid, "channel": channel, "source": source}, ) - return + return False latest_slash_interaction = self._get_latest_slash_interaction(userid) if latest_slash_interaction == "sites": @@ -229,7 +258,7 @@ class MessageChain(ChainBase): username=username, text=text, ): - return + return False if latest_slash_interaction == "subscribes": if SubscribeChain().handle_text_interaction( @@ -239,7 +268,7 @@ class MessageChain(ChainBase): username=username, text=text, ): - return + return False if latest_slash_interaction == "skills": if SkillsChain().handle_text_interaction( @@ -249,7 +278,7 @@ class MessageChain(ChainBase): username=username, text=text, ): - return + return False if media_interaction_manager.get_by_user(userid): if MediaInteractionChain().handle_text_interaction( @@ -259,10 +288,10 @@ class MessageChain(ChainBase): username=username, text=text, ): - return + return False if text.lower().startswith("/ai"): - self._handle_ai_message( + return self._handle_ai_message( text=text, channel=channel, source=source, @@ -272,14 +301,14 @@ class MessageChain(ChainBase): original_chat_id=original_chat_id, images=images, files=files, + processing_status=processing_status, ) - return if ( settings.AI_AGENT_ENABLE and (settings.AI_AGENT_GLOBAL or images or files or has_audio_input) ): - self._handle_ai_message( + return self._handle_ai_message( text=text, channel=channel, source=source, @@ -289,8 +318,8 @@ class MessageChain(ChainBase): original_chat_id=original_chat_id, images=images, files=files, + processing_status=processing_status, ) - return if MediaInteractionChain().handle_text_interaction( channel=channel, @@ -299,7 +328,7 @@ class MessageChain(ChainBase): username=username, text=text, ): - return + return False self.eventmanager.send_event( EventType.UserMessage, @@ -310,36 +339,81 @@ class MessageChain(ChainBase): "source": source, }, ) + return False def _mark_message_processing_started( self, channel: MessageChannel, source: str, + userid: Union[str, int], original_message_id: Optional[Union[str, int]], + original_chat_id: Optional[Union[str, int]], text: str, - ) -> Optional[_ProcessingMarker]: + ) -> Optional[_ProcessingStatus]: """为支持的渠道标记“消息正在处理”。""" - if channel != MessageChannel.Feishu: + if not ChannelCapabilityManager.supports_capability( + channel, ChannelCapability.PROCESSING_STATUS + ): return None - if not original_message_id or not text or text.startswith("CALLBACK:"): + if not text: return None - reaction_id = self.run_module( - "add_feishu_message_reaction", - message_id=str(original_message_id), - emoji_type="GLANCE", - source=source, - ) - if not reaction_id: + try: + status = self.run_module( + "mark_message_processing_started", + channel=channel, + source=source, + userid=userid, + message_id=original_message_id, + chat_id=original_chat_id, + text=text, + ) + except Exception as err: + logger.debug(f"标记消息处理状态失败: {err}") return None - return self._ProcessingMarker( + if not isinstance(status, dict): + return None + metadata = status.get("metadata") + return self._ProcessingStatus( channel=channel, source=source, - message_id=str(original_message_id), - reaction_id=str(reaction_id), + userid=status.get("userid", userid), + message_id=status.get("message_id", original_message_id), + chat_id=status.get("chat_id", original_chat_id), + metadata=metadata if isinstance(metadata, dict) else {}, ) + def _mark_message_processing_finished( + self, + channel: MessageChannel, + source: str, + userid: Union[str, int], + status: Optional[_ProcessingStatus] = None, + original_message_id: Optional[Union[str, int]] = None, + original_chat_id: Optional[Union[str, int]] = None, + ) -> None: + """ + 结束渠道侧“消息正在处理”状态。 + 不同渠道的表现可能是 reaction、typing 等,消息链只负责调用通用模块接口。 + """ + if not status and not ChannelCapabilityManager.supports_capability( + channel, ChannelCapability.PROCESSING_STATUS + ): + return + try: + self.run_module( + "mark_message_processing_finished", + channel=channel, + source=source, + userid=userid, + message_id=status.message_id if status else original_message_id, + chat_id=status.chat_id if status else original_chat_id, + status=status.to_dict() if status else None, + ) + except Exception as err: + logger.debug(f"结束消息处理状态失败: {err}") + def _handle_callback( self, text: str, @@ -1063,7 +1137,8 @@ class MessageChain(ChainBase): images: Optional[List[CommingMessage.MessageImage]] = None, files: Optional[List[CommingMessage.MessageAttachment]] = None, session_id: Optional[str] = None, - ) -> None: + processing_status: Optional[_ProcessingStatus] = None, + ) -> bool: """ 处理AI智能体消息 """ @@ -1079,7 +1154,7 @@ class MessageChain(ChainBase): title="MoviePilot智能助手未启用,请在系统设置中启用", ) ) - return + return False images = CommingMessage.MessageImage.normalize_list(images) @@ -1099,7 +1174,7 @@ class MessageChain(ChainBase): title="请输入您的问题或需求", ) ) - return + return False # 生成或复用会话ID session_id = session_id or self._get_or_create_session_id(userid) @@ -1122,7 +1197,7 @@ class MessageChain(ChainBase): title="附件读取失败,请稍后重试", ) ) - return + return False elif images: image_attachments = self._build_image_attachments(images) if ( @@ -1140,7 +1215,7 @@ class MessageChain(ChainBase): title="附件读取失败,请稍后重试", ) ) - return + return False all_files.extend(image_attachments) images = None @@ -1160,7 +1235,7 @@ class MessageChain(ChainBase): title="文件读取失败,请稍后重试", ) ) - return + return False # 在事件循环中处理 asyncio.run_coroutine_threadsafe( @@ -1175,17 +1250,20 @@ class MessageChain(ChainBase): username=username, original_message_id=str(original_message_id) if original_message_id else None, original_chat_id=original_chat_id, + processing_status=processing_status.to_dict() + if processing_status + else None, ), global_vars.loop, ) - return + return True except Exception as e: logger.error(f"处理AI智能体消息失败: {e}") self.messagehelper.put( f"AI智能体处理失败: {str(e)}", role="system", title="MoviePilot助手" ) - return + return False def _transcribe_audio_refs( self, audio_refs: List[str], channel: MessageChannel, source: str diff --git a/app/modules/discord/__init__.py b/app/modules/discord/__init__.py index a112b776..3ceaafbf 100644 --- a/app/modules/discord/__init__.py +++ b/app/modules/discord/__init__.py @@ -473,6 +473,69 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]): return True return False + def mark_message_processing_started( + self, + channel: MessageChannel, + source: str, + userid: Optional[Union[str, int]] = None, + message_id: Optional[Union[str, int]] = None, + chat_id: Optional[Union[str, int]] = None, + text: Optional[str] = None, + ) -> Optional[dict]: + """ + 使用 Discord typing 指示标记“正在处理”。 + """ + if channel != self._channel: + return None + if not text: + return None + config = self.get_config(source) + if not config: + return None + client: Discord = self.get_instance(config.name) + if not client: + return None + if not client.start_typing( + userid=str(userid) if userid else None, + chat_id=str(chat_id) if chat_id else None, + ): + return None + return { + "channel": channel.value, + "source": source, + "userid": userid, + "message_id": str(message_id) if message_id else None, + "chat_id": str(chat_id) if chat_id else None, + "metadata": {"kind": "typing"}, + } + + def mark_message_processing_finished( + self, + channel: MessageChannel, + source: str, + userid: Optional[Union[str, int]] = None, + message_id: Optional[Union[str, int]] = None, + chat_id: Optional[Union[str, int]] = None, + status: Optional[dict] = None, + ) -> Optional[bool]: + """ + 停止 Discord typing 续发任务。 + """ + if channel != self._channel: + return None + target_chat_id = (status or {}).get("chat_id") or chat_id + target_userid = (status or {}).get("userid") or userid + config = self.get_config(source) + if not config: + return False + client: Discord = self.get_instance(config.name) + if not client: + return False + return client.stop_typing( + userid=str(target_userid) if target_userid else None, + chat_id=str(target_chat_id) if target_chat_id else None, + ) + def send_direct_message(self, message: Notification) -> Optional[MessageResponse]: """ 直接发送消息并返回消息ID等信息 diff --git a/app/modules/discord/discord.py b/app/modules/discord/discord.py index aa4a3953..435b09dd 100644 --- a/app/modules/discord/discord.py +++ b/app/modules/discord/discord.py @@ -79,6 +79,10 @@ class Discord: ] = {} # userid -> chat_id mapping for reply targeting self._broadcast_channel = None self._bot_user_id: Optional[int] = None + self._typing_tasks: Dict[str, asyncio.Task] = {} + self._typing_stop_events: Dict[str, asyncio.Event] = {} + self._typing_interval_seconds = 5 + self._typing_max_duration_seconds = 5 * 60 self._register_events() self._start() @@ -209,6 +213,9 @@ class Discord: if not self._client or not self._loop or not self._thread: return try: + asyncio.run_coroutine_threadsafe( + self._stop_all_typing_tasks(), self._loop + ).result(timeout=5) asyncio.run_coroutine_threadsafe(self._client.close(), self._loop).result( timeout=10 ) @@ -367,6 +374,125 @@ class Discord: logger.error(f"发送 Discord 种子列表失败:{err}") return False + def start_typing( + self, + userid: Optional[str] = None, + chat_id: Optional[str] = None, + max_duration_seconds: Optional[float] = None, + ) -> bool: + """ + 持续发送 Discord typing 指示,直到显式停止或达到最大续期。 + """ + if not self.get_state(): + return False + typing_key = self._typing_key(userid=userid, chat_id=chat_id) + if not typing_key: + return False + try: + future = asyncio.run_coroutine_threadsafe( + self._start_typing_task( + typing_key=typing_key, + userid=userid, + chat_id=chat_id, + max_duration_seconds=max_duration_seconds, + ), + self._loop, + ) + return future.result(timeout=10) + except Exception as err: + logger.error(f"发送 Discord typing 状态失败:{err}") + return False + + def stop_typing( + self, + userid: Optional[str] = None, + chat_id: Optional[str] = None, + ) -> bool: + """ + 停止 Discord typing 续发任务。 + """ + typing_key = self._typing_key(userid=userid, chat_id=chat_id) + if not typing_key or not self._loop: + return False + try: + future = asyncio.run_coroutine_threadsafe( + self._stop_typing_task(typing_key), self._loop + ) + return future.result(timeout=5) + except Exception as err: + logger.error(f"停止 Discord typing 状态失败:{err}") + return False + + @staticmethod + def _typing_key(userid: Optional[str] = None, chat_id: Optional[str] = None) -> str: + """优先按频道维度管理 typing 状态,缺失时退回用户维度。""" + if chat_id: + return f"chat:{chat_id}" + if userid: + return f"user:{userid}" + return "" + + async def _start_typing_task( + self, + typing_key: str, + userid: Optional[str] = None, + chat_id: Optional[str] = None, + max_duration_seconds: Optional[float] = None, + ) -> bool: + await self._stop_typing_task(typing_key) + channel = await self._resolve_channel(userid=userid, chat_id=chat_id) + if not channel: + return False + stop_event = asyncio.Event() + max_duration = max_duration_seconds or self._typing_max_duration_seconds + + async def _typing_worker() -> None: + started_at = self._loop.time() + try: + while not stop_event.is_set(): + if self._loop.time() - started_at >= max_duration: + logger.warning( + "Discord typing状态超过最大续期,自动停止: key=%s", + typing_key, + ) + break + try: + await channel.trigger_typing() + except Exception as err: + logger.debug(f"触发 Discord typing 状态失败:{err}") + try: + await asyncio.wait_for( + stop_event.wait(), + timeout=self._typing_interval_seconds, + ) + except asyncio.TimeoutError: + pass + finally: + current_task = asyncio.current_task() + if self._typing_tasks.get(typing_key) is current_task: + self._typing_tasks.pop(typing_key, None) + self._typing_stop_events.pop(typing_key, None) + + self._typing_stop_events[typing_key] = stop_event + self._typing_tasks[typing_key] = asyncio.create_task(_typing_worker()) + return True + + async def _stop_typing_task(self, typing_key: str) -> bool: + stop_event = self._typing_stop_events.pop(typing_key, None) + task = self._typing_tasks.pop(typing_key, None) + if stop_event: + stop_event.set() + if task and task is not asyncio.current_task() and not task.done(): + try: + await asyncio.wait_for(asyncio.shield(task), timeout=1) + except asyncio.TimeoutError: + pass + return bool(stop_event or task) + + async def _stop_all_typing_tasks(self) -> None: + for typing_key in list(self._typing_tasks.keys()): + await self._stop_typing_task(typing_key) + def delete_msg( self, message_id: Union[str, int], chat_id: Optional[str] = None ) -> Optional[bool]: diff --git a/app/modules/feishu/__init__.py b/app/modules/feishu/__init__.py index 4c8cdaa5..db6fa9fa 100644 --- a/app/modules/feishu/__init__.py +++ b/app/modules/feishu/__init__.py @@ -360,6 +360,67 @@ class FeishuModule(_ModuleBase, _MessageBase[Feishu]): return False return client.delete_message_reaction(message_id=message_id, reaction_id=reaction_id) + def mark_message_processing_started( + self, + channel: MessageChannel, + source: str, + userid: Optional[Union[str, int]] = None, + message_id: Optional[Union[str, int]] = None, + chat_id: Optional[Union[str, int]] = None, + text: Optional[str] = None, + ) -> Optional[dict]: + """ + 使用飞书消息表情标记“正在处理”。 + """ + if channel != self._channel: + return None + if not message_id or not text or str(text).startswith("CALLBACK:"): + return None + reaction_id = self.add_feishu_message_reaction( + message_id=str(message_id), + emoji_type=Feishu.PROCESSING_REACTION_EMOJI, + source=source, + ) + if not reaction_id: + return None + return { + "channel": channel.value, + "source": source, + "userid": userid, + "message_id": str(message_id), + "chat_id": str(chat_id) if chat_id else None, + "metadata": { + "kind": "reaction", + "reaction_id": str(reaction_id), + "emoji_type": Feishu.PROCESSING_REACTION_EMOJI, + }, + } + + def mark_message_processing_finished( + self, + channel: MessageChannel, + source: str, + userid: Optional[Union[str, int]] = None, + message_id: Optional[Union[str, int]] = None, + chat_id: Optional[Union[str, int]] = None, + status: Optional[dict] = None, + ) -> Optional[bool]: + """ + 删除飞书“正在处理”表情。 + """ + if channel != self._channel: + return None + metadata = (status or {}).get("metadata") or {} + target_message_id = (status or {}).get("message_id") or message_id + reaction_id = metadata.get("reaction_id") + if not target_message_id or not reaction_id: + return False + return self.delete_feishu_message_reaction( + message_id=str(target_message_id), + reaction_id=str(reaction_id), + source=source, + ) + def finalize_message(self, response: MessageResponse) -> bool: if response.channel != self._channel or not isinstance(response.metadata, dict): return False diff --git a/app/modules/slack/__init__.py b/app/modules/slack/__init__.py index 90e74796..dd7a9ba4 100644 --- a/app/modules/slack/__init__.py +++ b/app/modules/slack/__init__.py @@ -12,6 +12,7 @@ from app.schemas.types import ModuleType class SlackModule(_ModuleBase, _MessageBase[Slack]): + PROCESSING_REACTION = "eyes" _AUDIO_SUFFIXES = ( ".mp3", ".m4a", @@ -222,10 +223,14 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]): images = None audio_refs = None files = None + message_id = None + chat_id = None if msg_json.get("type") == "message": userid = msg_json.get("user") text = msg_json.get("text") username = msg_json.get("user") + message_id = msg_json.get("ts") + chat_id = msg_json.get("channel") images = self._extract_images(msg_json) audio_refs = self._extract_audio_refs(msg_json) files = self._extract_files(msg_json) @@ -270,6 +275,8 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]): flags=re.IGNORECASE, ).strip() username = "" + message_id = msg_json.get("event", {}).get("ts") + chat_id = msg_json.get("event", {}).get("channel") images = self._extract_images(msg_json.get("event", {})) audio_refs = self._extract_audio_refs(msg_json.get("event", {})) files = self._extract_files(msg_json.get("event", {})) @@ -281,6 +288,7 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]): userid = msg_json.get("user_id") text = msg_json.get("command") username = msg_json.get("user_name") + chat_id = msg_json.get("channel_id") else: return None logger.info( @@ -294,6 +302,8 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]): userid=userid, username=username, text=text, + message_id=message_id, + chat_id=chat_id, images=images, audio_refs=audio_refs, files=files, @@ -589,6 +599,78 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]): return True return False + def mark_message_processing_started( + self, + channel: MessageChannel, + source: str, + userid: Optional[Union[str, int]] = None, + message_id: Optional[Union[str, int]] = None, + chat_id: Optional[Union[str, int]] = None, + text: Optional[str] = None, + ) -> Optional[dict]: + """ + 使用 Slack reaction 标记“正在处理”。 + """ + if channel != self._channel: + return None + if not message_id or not chat_id or not text or str(text).startswith("CALLBACK:"): + return None + config = self.get_config(source) + if not config: + return None + client: Slack = self.get_instance(config.name) + if not client: + return None + if not client.add_reaction( + channel=str(chat_id), + timestamp=str(message_id), + emoji=self.PROCESSING_REACTION, + ): + return None + return { + "channel": channel.value, + "source": source, + "userid": userid, + "message_id": str(message_id), + "chat_id": str(chat_id), + "metadata": { + "kind": "reaction", + "emoji": self.PROCESSING_REACTION, + }, + } + + def mark_message_processing_finished( + self, + channel: MessageChannel, + source: str, + userid: Optional[Union[str, int]] = None, + message_id: Optional[Union[str, int]] = None, + chat_id: Optional[Union[str, int]] = None, + status: Optional[dict] = None, + ) -> Optional[bool]: + """ + 移除 Slack “正在处理” reaction。 + """ + if channel != self._channel: + return None + metadata = (status or {}).get("metadata") or {} + target_message_id = (status or {}).get("message_id") or message_id + target_chat_id = (status or {}).get("chat_id") or chat_id + emoji = metadata.get("emoji") or self.PROCESSING_REACTION + if not target_message_id or not target_chat_id: + return False + config = self.get_config(source) + if not config: + return False + client: Slack = self.get_instance(config.name) + if not client: + return False + return client.remove_reaction( + channel=str(target_chat_id), + timestamp=str(target_message_id), + emoji=str(emoji), + ) + def send_direct_message(self, message: Notification) -> Optional[MessageResponse]: """ 直接发送消息并返回消息ID等信息 diff --git a/app/modules/slack/slack.py b/app/modules/slack/slack.py index 77b95f94..e0c2086e 100644 --- a/app/modules/slack/slack.py +++ b/app/modules/slack/slack.py @@ -289,6 +289,40 @@ class Slack: logger.error(f"Slack文件发送失败: {err}") return False, str(err) + def add_reaction(self, channel: str, timestamp: str, emoji: str) -> bool: + """ + 为 Slack 消息添加 reaction,用作正在处理状态。 + """ + if not self._client or not channel or not timestamp or not emoji: + return False + try: + result = self._client.reactions_add( + channel=channel, + timestamp=timestamp, + name=emoji, + ) + return bool(result and result.get("ok", True)) + except Exception as err: + logger.error(f"Slack添加reaction失败: {err}") + return False + + def remove_reaction(self, channel: str, timestamp: str, emoji: str) -> bool: + """ + 移除 Slack 消息 reaction。 + """ + if not self._client or not channel or not timestamp or not emoji: + return False + try: + result = self._client.reactions_remove( + channel=channel, + timestamp=timestamp, + name=emoji, + ) + return bool(result and result.get("ok", True)) + except Exception as err: + logger.error(f"Slack移除reaction失败: {err}") + return False + def send_medias_msg(self, medias: List[MediaInfo], userid: Optional[str] = None, title: Optional[str] = None, buttons: Optional[List[List[dict]]] = None, original_message_id: Optional[str] = None, diff --git a/app/modules/telegram/__init__.py b/app/modules/telegram/__init__.py index 183f0675..f7392de4 100644 --- a/app/modules/telegram/__init__.py +++ b/app/modules/telegram/__init__.py @@ -596,6 +596,57 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): return True return False + def mark_message_processing_started( + self, + channel: MessageChannel, + source: str, + userid: Optional[Union[str, int]] = None, + message_id: Optional[Union[str, int]] = None, + chat_id: Optional[Union[str, int]] = None, + text: Optional[str] = None, + ) -> Optional[dict]: + """ + 标记 Telegram 消息正在处理。 + 入站侧已经启动 typing 任务,这里只返回可用于统一收口的上下文。 + """ + if channel != self._channel: + return None + if not text: + return None + return { + "channel": channel.value, + "source": source, + "userid": userid, + "message_id": message_id, + "chat_id": chat_id, + "metadata": {"kind": "typing"}, + } + + def mark_message_processing_finished( + self, + channel: MessageChannel, + source: str, + userid: Optional[Union[str, int]] = None, + message_id: Optional[Union[str, int]] = None, + chat_id: Optional[Union[str, int]] = None, + status: Optional[dict] = None, + ) -> Optional[bool]: + """ + 结束 Telegram typing 状态。 + """ + if channel != self._channel: + return None + if status: + chat_id = status.get("chat_id") or chat_id + userid = status.get("userid") or userid + client_config = self.get_config(source) + if not client_config: + return False + client: Telegram = self.get_instance(client_config.name) + if not client: + return False + return client.stop_typing(chat_id=chat_id, userid=userid) + def send_direct_message(self, message: Notification) -> Optional[MessageResponse]: """ 直接发送消息并返回消息ID等信息 diff --git a/app/modules/telegram/telegram.py b/app/modules/telegram/telegram.py index ad55c9d8..e3be15fd 100644 --- a/app/modules/telegram/telegram.py +++ b/app/modules/telegram/telegram.py @@ -43,7 +43,12 @@ class Telegram: ] = {} # userid -> chat_id mapping for reply targeting _bot_username: Optional[str] = None # Bot username for mention detection _typing_tasks: Dict[str, threading.Thread] = {} # chat_id -> typing任务 - _typing_stop_flags: Dict[str, bool] = {} # chat_id -> 停止标志 + _typing_stop_flags: Dict[str, threading.Event] = {} # chat_id -> 停止信号 + _typing_lock = threading.RLock() + _typing_interval_seconds = 5 + _typing_max_duration_seconds = 5 * 60 + _typing_command_max_duration_seconds = 30 + _typing_callback_max_duration_seconds = 60 def __init__( self, @@ -54,13 +59,13 @@ class Telegram: """ 初始化参数 """ + # 即使配置不完整也保留基础属性,便于测试和未启用实例安全调用发送方法。 + self._telegram_token = TELEGRAM_TOKEN + self._telegram_chat_id = TELEGRAM_CHAT_ID + self._polling_thread = None if not TELEGRAM_TOKEN or not TELEGRAM_CHAT_ID: logger.error("Telegram配置不完整!") return - # Token - self._telegram_token = TELEGRAM_TOKEN - # Chat Id - self._telegram_chat_id = TELEGRAM_CHAT_ID # 初始化机器人 if self._telegram_token and self._telegram_chat_id: # telegram bot api 地址,格式:https://api.telegram.org @@ -114,22 +119,42 @@ class Telegram: # Check if we should process this message if self._should_process_message(message): # 启动持续发送正在输入状态 - self._start_typing_task(message.chat.id) + message_text = message.text or message.caption or "" + max_duration = ( + self._typing_command_max_duration_seconds + if ( + message_text.startswith("/") + and not message_text.lower().startswith("/ai") + ) + else None + ) + self._start_typing_task( + message.chat.id, max_duration_seconds=max_duration + ) payload = self._serialize_update_payload(message) if not payload: logger.warn("Telegram消息序列化失败,跳过转发") + self._stop_typing_task(message.chat.id) return - RequestUtils(timeout=15).post_res(self._ds_url, json=payload) + response = RequestUtils(timeout=15).post_res( + self._ds_url, json=payload + ) + if not response or response.status_code >= 400: + logger.warn("Telegram消息转发失败,停止typing状态") + self._stop_typing_task(message.chat.id) @_bot.callback_query_handler(func=lambda call: True) def callback_query(call): """ 处理按钮点击回调 """ + chat_id = None + typing_started = False try: # Update user-chat mapping for callbacks too + chat_id = call.message.chat.id self._update_user_chat_mapping( - call.from_user.id, call.message.chat.id + call.from_user.id, chat_id ) # 解析回调数据 @@ -146,7 +171,7 @@ class Telegram: "message": { "message_id": call.message.message_id, "chat": { - "id": call.message.chat.id, + "id": chat_id, }, }, "data": callback_data, @@ -157,13 +182,24 @@ class Telegram: _bot.answer_callback_query(call.id) # 启动持续发送正在输入状态 - self._start_typing_task(call.message.chat.id) + self._start_typing_task( + chat_id, + max_duration_seconds=self._typing_callback_max_duration_seconds, + ) + typing_started = True # 发送给主程序处理 - RequestUtils(timeout=15).post_res(self._ds_url, json=callback_json) + response = RequestUtils(timeout=15).post_res( + self._ds_url, json=callback_json + ) + if not response or response.status_code >= 400: + logger.warn("Telegram按钮回调转发失败,停止typing状态") + self._stop_typing_task(chat_id) except Exception as err: logger.error(f"处理按钮回调失败:{str(err)}") + if typing_started and chat_id is not None: + self._stop_typing_task(chat_id) _bot.answer_callback_query(call.id, "处理失败,请重试") def run_polling(): @@ -326,46 +362,85 @@ class Telegram: """ return self._bot is not None - def _start_typing_task(self, chat_id: Union[str, int]) -> None: + def _start_typing_task( + self, + chat_id: Union[str, int], + max_duration_seconds: Optional[float] = None, + ) -> None: """ 启动持续发送正在输入状态的任务 """ chat_id_str = str(chat_id) # 如果已有任务在运行,先停止 - if chat_id_str in self._typing_tasks: - self._stop_typing_task(chat_id_str) + self._stop_typing_task(chat_id_str) - # 设置停止标志 - self._typing_stop_flags[chat_id_str] = False + # 使用独立 Event 避免同一 chat 新旧 typing 线程互相误改停止标记。 + stop_event = threading.Event() + max_duration = max_duration_seconds or self._typing_max_duration_seconds def typing_worker(): """定期发送typing状态的后台线程""" - while not self._typing_stop_flags.get(chat_id_str, True): - try: - if self._bot: - self._bot.send_chat_action(chat_id, "typing") - except Exception as e: - logger.debug(f"发送typing状态失败: {e}") - # 每5秒发送一次(Telegram客户端会在约5-6秒后消失状态) - for _ in range(50): - if self._typing_stop_flags.get(chat_id_str, True): + started_at = time.monotonic() + try: + while not stop_event.is_set(): + if time.monotonic() - started_at >= max_duration: + logger.warning( + "Telegram typing状态超过最大续期,自动停止: chat_id=%s", + chat_id_str, + ) break - time.sleep(0.1) + try: + if self._bot: + self._bot.send_chat_action(chat_id, "typing") + except Exception as e: + logger.debug(f"发送typing状态失败: {e}") + # Telegram 客户端约 5-6 秒后会隐藏 typing,需要周期性续发。 + stop_event.wait(self._typing_interval_seconds) + finally: + with self._typing_lock: + current = self._typing_tasks.get(chat_id_str) + if current is threading.current_thread(): + self._typing_tasks.pop(chat_id_str, None) + self._typing_stop_flags.pop(chat_id_str, None) thread = threading.Thread(target=typing_worker, daemon=True) + with self._typing_lock: + self._typing_stop_flags[chat_id_str] = stop_event + self._typing_tasks[chat_id_str] = thread thread.start() - self._typing_tasks[chat_id_str] = thread def _stop_typing_task(self, chat_id: Union[str, int]) -> None: """ 停止正在输入状态的任务 """ chat_id_str = str(chat_id) - self._typing_stop_flags[chat_id_str] = True - if chat_id_str in self._typing_tasks: + with self._typing_lock: + stop_event = self._typing_stop_flags.pop(chat_id_str, None) task = self._typing_tasks.pop(chat_id_str, None) - if task and task.is_alive(): - task.join(timeout=1) + if stop_event: + stop_event.set() + if task and task.is_alive() and task is not threading.current_thread(): + task.join(timeout=1) + + def stop_typing( + self, + chat_id: Optional[Union[str, int]] = None, + userid: Optional[Union[str, int]] = None, + ) -> bool: + """ + 外部链路主动停止 typing 状态。 + """ + if chat_id: + target_chat_id = chat_id + elif userid: + target_chat_id = self._get_user_chat_id(str(userid)) or str(userid) + else: + target_chat_id = None + target_chat_id = target_chat_id or (str(userid) if userid else None) + if not target_chat_id: + return False + self._stop_typing_task(target_chat_id) + return True def send_msg( self, @@ -395,12 +470,12 @@ class Telegram: if not self._telegram_token or not self._telegram_chat_id: return None - if not title and not text: - logger.warn("标题和内容不能同时为空") - return {"success": False} - # Determine target chat_id with improved logic using user mapping chat_id = self._determine_target_chat_id(userid, original_chat_id) + if not title and not text: + logger.warn("标题和内容不能同时为空") + self._stop_typing_task(chat_id) + return {"success": False} try: # 标准化标题后再加粗,避免**符号被显示为文本 @@ -483,6 +558,7 @@ class Telegram: voice_file = Path(voice_path) if not voice_file.exists(): logger.error(f"语音文件不存在: {voice_file}") + self._stop_typing_task(chat_id) return {"success": False} try: @@ -526,12 +602,13 @@ class Telegram: if not self._bot or not file_path: return None + chat_id = self._determine_target_chat_id(userid, original_chat_id) local_file = Path(file_path) if not local_file.exists() or not local_file.is_file(): logger.error(f"附件文件不存在: {local_file}") + self._stop_typing_task(chat_id) return {"success": False} - chat_id = self._determine_target_chat_id(userid, original_chat_id) send_name = file_name or local_file.name suffix = local_file.suffix.lower() is_image = suffix in {".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp"} @@ -622,6 +699,8 @@ class Telegram: if not self._telegram_token or not self._telegram_chat_id: return None + # 列表消息也可能是一次交互的最终响应,需要确保 typing 状态在发送后结束。 + chat_id = self._determine_target_chat_id(userid, original_chat_id) try: index, image, caption = 1, "", "*%s*" % title for media in medias: @@ -649,9 +728,6 @@ class Telegram: if link: caption = f"{caption}\n[查看详情]({link})" - # Determine target chat_id with improved logic using user mapping - chat_id = self._determine_target_chat_id(userid, original_chat_id) - # 创建按钮键盘 reply_markup = None if buttons: @@ -675,6 +751,8 @@ class Telegram: except Exception as msg_e: logger.error(f"发送消息失败:{msg_e}") return False + finally: + self._stop_typing_task(chat_id) def send_torrents_msg( self, @@ -699,6 +777,8 @@ class Telegram: if not self._telegram_token or not self._telegram_chat_id: return None + # 资源列表是搜索交互的常见出口,也必须统一释放 typing 状态。 + chat_id = self._determine_target_chat_id(userid, original_chat_id) try: index, caption = 1, "*%s*" % title image = torrents[0].media_info.get_message_image() @@ -725,9 +805,6 @@ class Telegram: if link: caption = f"{caption}\n[查看详情]({link})" - # Determine target chat_id with improved logic using user mapping - chat_id = self._determine_target_chat_id(userid, original_chat_id) - # 创建按钮键盘 reply_markup = None if buttons: @@ -751,6 +828,8 @@ class Telegram: except Exception as msg_e: logger.error(f"发送消息失败:{msg_e}") return False + finally: + self._stop_typing_task(chat_id) @staticmethod def _create_inline_keyboard(buttons: List[List[Dict]]) -> InlineKeyboardMarkup: @@ -872,6 +951,8 @@ class Telegram: except Exception as e: logger.error(f"编辑Telegram消息异常: {str(e)}") return False + finally: + self._stop_typing_task(chat_id) def __edit_message( self, diff --git a/app/schemas/message.py b/app/schemas/message.py index 6ec6e749..64e365ab 100644 --- a/app/schemas/message.py +++ b/app/schemas/message.py @@ -275,6 +275,8 @@ class ChannelCapability(Enum): LINKS = "links" # 支持文件发送 FILE_SENDING = "file_sending" + # 支持可收口的消息处理状态提示,如 reaction 或 typing + PROCESSING_STATUS = "processing_status" @dataclass @@ -312,6 +314,7 @@ class ChannelCapabilityManager: ChannelCapability.IMAGES, ChannelCapability.LINKS, ChannelCapability.FILE_SENDING, + ChannelCapability.PROCESSING_STATUS, }, max_buttons_per_row=4, max_button_rows=10, @@ -339,6 +342,7 @@ class ChannelCapabilityManager: ChannelCapability.IMAGES, ChannelCapability.LINKS, ChannelCapability.FILE_SENDING, + ChannelCapability.PROCESSING_STATUS, }, max_buttons_per_row=3, max_button_rows=8, @@ -370,6 +374,7 @@ class ChannelCapabilityManager: ChannelCapability.LINKS, ChannelCapability.MENU_COMMANDS, ChannelCapability.FILE_SENDING, + ChannelCapability.PROCESSING_STATUS, }, max_buttons_per_row=3, max_button_rows=8, @@ -390,6 +395,7 @@ class ChannelCapabilityManager: ChannelCapability.IMAGES, ChannelCapability.LINKS, ChannelCapability.FILE_SENDING, + ChannelCapability.PROCESSING_STATUS, }, max_buttons_per_row=5, max_button_rows=5, diff --git a/tests/test_feishu.py b/tests/test_feishu.py index 10aabba4..729c935c 100644 --- a/tests/test_feishu.py +++ b/tests/test_feishu.py @@ -965,6 +965,50 @@ class TestFeishu(unittest.TestCase): self.assertEqual(reaction_id, "reaction_2") self.assertTrue(deleted) + def test_module_processing_status_uses_reaction_helpers(self): + module = FeishuModule() + module._channel = MessageChannel.Feishu + + with ( + patch.object( + module, + "add_feishu_message_reaction", + return_value="reaction_processing", + ) as add_reaction, + patch.object( + module, + "delete_feishu_message_reaction", + return_value=True, + ) as delete_reaction, + ): + status = module.mark_message_processing_started( + channel=MessageChannel.Feishu, + source="feishu-main", + userid="ou_x", + message_id="om_x", + chat_id="oc_x", + text="hello", + ) + deleted = module.mark_message_processing_finished( + channel=MessageChannel.Feishu, + source="feishu-main", + userid="ou_x", + status=status, + ) + + add_reaction.assert_called_once_with( + message_id="om_x", + emoji_type="GLANCE", + source="feishu-main", + ) + delete_reaction.assert_called_once_with( + message_id="om_x", + reaction_id="reaction_processing", + source="feishu-main", + ) + self.assertEqual(status["metadata"]["reaction_id"], "reaction_processing") + self.assertTrue(deleted) + def test_module_finalize_message_closes_streaming_card(self): module = FeishuModule() module._channel = MessageChannel.Feishu diff --git a/tests/test_message_processing_status.py b/tests/test_message_processing_status.py new file mode 100644 index 00000000..001a4d29 --- /dev/null +++ b/tests/test_message_processing_status.py @@ -0,0 +1,153 @@ +import json +import unittest +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +from app.agent import _finish_processing_status +from app.modules.discord import DiscordModule +from app.modules.slack import SlackModule +from app.schemas.message import ChannelCapability, ChannelCapabilityManager +from app.schemas.types import MessageChannel + + +class TestMessageProcessingStatus(unittest.TestCase): + def test_processing_status_capability_only_enabled_for_supported_channels(self): + supported = { + MessageChannel.Telegram, + MessageChannel.Feishu, + MessageChannel.Slack, + MessageChannel.Discord, + } + + for channel in MessageChannel: + self.assertEqual( + ChannelCapabilityManager.supports_capability( + channel, ChannelCapability.PROCESSING_STATUS + ), + channel in supported, + ) + + def test_slack_processing_status_uses_reaction(self): + module = SlackModule() + module._channel = MessageChannel.Slack + client = MagicMock() + client.add_reaction.return_value = True + client.remove_reaction.return_value = True + + with ( + patch.object( + module, "get_config", return_value=SimpleNamespace(name="slack-main") + ), + patch.object(module, "get_instance", return_value=client), + ): + status = module.mark_message_processing_started( + channel=MessageChannel.Slack, + source="slack-main", + userid="U01", + message_id="1710000000.000100", + chat_id="C01", + text="hello", + ) + removed = module.mark_message_processing_finished( + channel=MessageChannel.Slack, + source="slack-main", + userid="U01", + status=status, + ) + + client.add_reaction.assert_called_once_with( + channel="C01", + timestamp="1710000000.000100", + emoji="eyes", + ) + client.remove_reaction.assert_called_once_with( + channel="C01", + timestamp="1710000000.000100", + emoji="eyes", + ) + self.assertEqual(status["metadata"]["kind"], "reaction") + self.assertTrue(removed) + + def test_slack_parser_exposes_message_location_for_reaction_status(self): + module = SlackModule() + + with patch.object( + module, "get_config", return_value=SimpleNamespace(name="slack-main") + ): + message = module.message_parser( + source="slack-main", + body=json.dumps( + { + "type": "message", + "user": "U01", + "text": "hello", + "ts": "1710000000.000100", + "channel": "C01", + } + ), + form=None, + args=None, + ) + + self.assertEqual(message.message_id, "1710000000.000100") + self.assertEqual(message.chat_id, "C01") + + def test_discord_processing_status_starts_and_stops_typing(self): + module = DiscordModule() + module._channel = MessageChannel.Discord + client = MagicMock() + client.start_typing.return_value = True + client.stop_typing.return_value = True + + with ( + patch.object( + module, "get_config", return_value=SimpleNamespace(name="discord-main") + ), + patch.object(module, "get_instance", return_value=client), + ): + status = module.mark_message_processing_started( + channel=MessageChannel.Discord, + source="discord-main", + userid="10001", + message_id="20002", + chat_id="30003", + text="hello", + ) + finished = module.mark_message_processing_finished( + channel=MessageChannel.Discord, + source="discord-main", + userid="10001", + status=status, + ) + + client.start_typing.assert_called_once_with(userid="10001", chat_id="30003") + client.stop_typing.assert_called_once_with(userid="10001", chat_id="30003") + self.assertEqual(status["metadata"]["kind"], "typing") + self.assertTrue(finished) + + def test_agent_finish_processing_status_uses_module_interface(self): + status = { + "channel": MessageChannel.Telegram.value, + "source": "telegram-main", + "userid": "10001", + "message_id": None, + "chat_id": "-100", + "metadata": {"kind": "typing"}, + } + + with patch("app.agent.AgentChain") as chain_cls: + _finish_processing_status(status, user_id="fallback") + + chain_cls.return_value.run_module.assert_called_once_with( + "mark_message_processing_finished", + channel=MessageChannel.Telegram, + source="telegram-main", + userid="10001", + message_id=None, + chat_id="-100", + status=status, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_telegram_typing_lifecycle.py b/tests/test_telegram_typing_lifecycle.py new file mode 100644 index 00000000..4e96fb87 --- /dev/null +++ b/tests/test_telegram_typing_lifecycle.py @@ -0,0 +1,189 @@ +import time +import unittest +from unittest.mock import Mock, patch + +from app.chain.message import MessageChain +from app.modules.telegram.telegram import Telegram +from app.schemas.types import MessageChannel + + +class TestTelegramTypingLifecycle(unittest.TestCase): + def setUp(self): + self._cleanup_typing_tasks() + + def tearDown(self): + self._cleanup_typing_tasks() + + @staticmethod + def _cleanup_typing_tasks(): + helper = Telegram.__new__(Telegram) + for chat_id in list(Telegram._typing_tasks.keys()): + helper._stop_typing_task(chat_id) + Telegram._typing_tasks.clear() + Telegram._typing_stop_flags.clear() + Telegram._user_chat_mapping.clear() + + @staticmethod + def _telegram_client() -> Telegram: + telegram = Telegram.__new__(Telegram) + telegram._bot = Mock() + telegram._telegram_token = "token" + telegram._telegram_chat_id = "default-chat" + # 缩短测试中的等待时间,不改变生产默认续发间隔。 + telegram._typing_interval_seconds = 0.01 + telegram._typing_max_duration_seconds = 1 + return telegram + + def test_start_typing_can_stop_by_chat_id(self): + telegram = self._telegram_client() + + telegram._start_typing_task("chat-1", max_duration_seconds=1) + time.sleep(0.03) + + self.assertIn("chat-1", Telegram._typing_tasks) + self.assertTrue(telegram._bot.send_chat_action.called) + self.assertTrue(telegram.stop_typing(chat_id="chat-1")) + self.assertNotIn("chat-1", Telegram._typing_tasks) + + def test_start_typing_can_stop_by_user_mapping(self): + telegram = self._telegram_client() + Telegram._user_chat_mapping["10001"] = "chat-2" + + telegram._start_typing_task("chat-2", max_duration_seconds=1) + time.sleep(0.03) + + self.assertTrue(telegram.stop_typing(userid="10001")) + self.assertNotIn("chat-2", Telegram._typing_tasks) + + def test_typing_task_has_max_duration_guard(self): + telegram = self._telegram_client() + + telegram._start_typing_task("chat-3", max_duration_seconds=0.02) + time.sleep(0.08) + + self.assertNotIn("chat-3", Telegram._typing_tasks) + + def test_slash_command_stops_typing_when_message_handler_returns(self): + chain = MessageChain.__new__(MessageChain) + status = MessageChain._ProcessingStatus( + channel=MessageChannel.Telegram, + source="telegram-test", + userid="10001", + chat_id="-100", + metadata={"kind": "typing"}, + ) + + with patch.object(chain, "_record_user_message"), patch.object( + chain, "_mark_message_processing_started", return_value=status + ), patch.object(chain, "_handle_message_core"), patch.object( + chain, "_mark_message_processing_finished" + ) as finish_status: + chain.handle_message( + channel=MessageChannel.Telegram, + source="telegram-test", + userid="10001", + username="tester", + text="/sites", + original_chat_id="-100", + ) + + finish_status.assert_called_once_with( + channel=MessageChannel.Telegram, + source="telegram-test", + userid="10001", + status=status, + original_message_id=None, + original_chat_id="-100", + ) + + def test_async_agent_keeps_processing_status_for_worker(self): + chain = MessageChain.__new__(MessageChain) + status = MessageChain._ProcessingStatus( + channel=MessageChannel.Telegram, + source="telegram-test", + userid="10001", + chat_id="-100", + metadata={"kind": "typing"}, + ) + + with patch.object(chain, "_record_user_message"), patch.object( + chain, "_mark_message_processing_started", return_value=status + ), patch.object(chain, "_handle_message_core", return_value=True), patch.object( + chain, "_mark_message_processing_finished" + ) as finish_status: + chain.handle_message( + channel=MessageChannel.Telegram, + source="telegram-test", + userid="10001", + username="tester", + text="/ai 搜索电影", + original_chat_id="-100", + ) + + finish_status.assert_not_called() + + def test_callback_stops_typing_when_message_handler_returns(self): + chain = MessageChain.__new__(MessageChain) + status = MessageChain._ProcessingStatus( + channel=MessageChannel.Telegram, + source="telegram-test", + userid="10001", + chat_id="-100", + metadata={"kind": "typing"}, + ) + + with patch.object(chain, "_record_user_message"), patch.object( + chain, "_mark_message_processing_started", return_value=status + ), patch.object(chain, "_handle_message_core"), patch.object( + chain, "_mark_message_processing_finished" + ) as finish_status: + chain.handle_message( + channel=MessageChannel.Telegram, + source="telegram-test", + userid="10001", + username="tester", + text="CALLBACK:sites:req-1:refresh", + original_chat_id="-100", + ) + + finish_status.assert_called_once_with( + channel=MessageChannel.Telegram, + source="telegram-test", + userid="10001", + status=status, + original_message_id=None, + original_chat_id="-100", + ) + + def test_chain_finishes_processing_through_module_interface(self): + chain = MessageChain.__new__(MessageChain) + status = MessageChain._ProcessingStatus( + channel=MessageChannel.Telegram, + source="telegram-test", + userid="10001", + chat_id="-100", + metadata={"kind": "typing"}, + ) + + with patch.object(chain, "run_module") as run_module: + chain._mark_message_processing_finished( + channel=MessageChannel.Telegram, + source="telegram-test", + userid="10001", + status=status, + original_chat_id="-100", + ) + + run_module.assert_called_once_with( + "mark_message_processing_finished", + channel=MessageChannel.Telegram, + source="telegram-test", + userid="10001", + message_id=None, + chat_id="-100", + status=status.to_dict(), + ) + + +if __name__ == "__main__": + unittest.main()