mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-05 07:26:48 +00:00
fix: simplify message typing lifecycle
This commit is contained in:
@@ -58,22 +58,36 @@ def _finish_processing_status(status: Optional[dict], user_id: Optional[str] = N
|
||||
"""结束入站消息的渠道处理状态。"""
|
||||
if not status:
|
||||
return
|
||||
try:
|
||||
channel = MessageChannel(status.get("channel"))
|
||||
except Exception:
|
||||
return
|
||||
try:
|
||||
AgentChain().run_module(
|
||||
"mark_message_processing_finished",
|
||||
channel=channel,
|
||||
source=status.get("source"),
|
||||
userid=status.get("userid") or user_id,
|
||||
message_id=status.get("message_id"),
|
||||
chat_id=status.get("chat_id"),
|
||||
status=status,
|
||||
)
|
||||
except Exception as err:
|
||||
logger.debug(f"结束Agent消息处理状态失败: {err}")
|
||||
AgentChain().finish_message_processing_status(
|
||||
status=status,
|
||||
userid=user_id,
|
||||
)
|
||||
|
||||
|
||||
async def _async_start_processing_status(task: "_MessageTask") -> Optional[dict]:
|
||||
"""
|
||||
在 Agent worker 中启动渠道处理状态。
|
||||
渠道启动可能触发外部 API,同步实现需切到线程池避免阻塞事件循环。
|
||||
"""
|
||||
if not task.channel:
|
||||
return None
|
||||
|
||||
def _start() -> Optional[dict]:
|
||||
"""在线程池中通过统一 Chain 接口启动处理状态。"""
|
||||
try:
|
||||
return AgentChain().start_message_processing_status(
|
||||
channel=MessageChannel(task.channel),
|
||||
source=task.source,
|
||||
userid=task.user_id,
|
||||
message_id=task.original_message_id,
|
||||
chat_id=task.original_chat_id,
|
||||
text=task.message,
|
||||
)
|
||||
except Exception as err:
|
||||
logger.debug(f"启动Agent消息处理状态失败: {err}")
|
||||
return None
|
||||
|
||||
return await run_in_threadpool(_start)
|
||||
|
||||
|
||||
async def _async_finish_processing_status(
|
||||
@@ -952,8 +966,6 @@ class AgentManager:
|
||||
self._session_queues: Dict[str, asyncio.Queue] = {}
|
||||
# 每个会话的worker任务
|
||||
self._session_workers: Dict[str, asyncio.Task] = {}
|
||||
# typing 这类状态按会话/聊天共享,前一条任务结束时可能仍需延续到后续排队消息。
|
||||
self._deferred_processing_statuses: Dict[str, dict] = {}
|
||||
|
||||
def get_session_status(self, session_id: str) -> dict[str, Any]:
|
||||
"""获取会话当前模型与 token 使用状态。"""
|
||||
@@ -1009,7 +1021,6 @@ class AgentManager:
|
||||
pass
|
||||
self._session_workers.clear()
|
||||
self._session_queues.clear()
|
||||
self._deferred_processing_statuses.clear()
|
||||
for agent in self.active_agents.values():
|
||||
await agent.cleanup()
|
||||
self.active_agents.clear()
|
||||
@@ -1026,7 +1037,6 @@ class AgentManager:
|
||||
username: str = None,
|
||||
original_message_id: Optional[str] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
processing_status: Optional[dict] = None,
|
||||
reply_mode: ReplyMode = ReplyMode.DISPATCH,
|
||||
) -> str:
|
||||
"""
|
||||
@@ -1044,7 +1054,6 @@ class AgentManager:
|
||||
username=username,
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id,
|
||||
processing_status=processing_status,
|
||||
reply_mode=reply_mode,
|
||||
)
|
||||
|
||||
@@ -1099,15 +1108,12 @@ class AgentManager:
|
||||
break
|
||||
|
||||
try:
|
||||
await self._start_task_processing_status(task)
|
||||
await self._process_message_internal(task)
|
||||
except Exception as e:
|
||||
logger.error(f"处理会话 {session_id} 的消息失败: {e}")
|
||||
finally:
|
||||
await self._finish_task_processing_status(
|
||||
session_id=session_id,
|
||||
task=task,
|
||||
queue=queue,
|
||||
)
|
||||
await self._finish_task_processing_status(task)
|
||||
queue.task_done()
|
||||
|
||||
except asyncio.CancelledError:
|
||||
@@ -1121,52 +1127,23 @@ class AgentManager:
|
||||
and self._session_queues[session_id].empty()
|
||||
):
|
||||
self._session_queues.pop(session_id, None)
|
||||
self._deferred_processing_statuses.pop(session_id, None)
|
||||
|
||||
@staticmethod
|
||||
def _is_shared_processing_status(status: Optional[dict]) -> bool:
|
||||
async def _start_task_processing_status(task: _MessageTask) -> None:
|
||||
"""
|
||||
判断状态是否属于同一聊天窗口共享的处理提示。
|
||||
reaction 绑定到具体消息,应按消息收口;typing 绑定到会话/聊天,需要等队列空闲再关闭。
|
||||
在 Agent worker 真正开始处理消息时启动渠道处理状态。
|
||||
"""
|
||||
metadata = (status or {}).get("metadata") or {}
|
||||
return isinstance(metadata, dict) and metadata.get("kind") == "typing"
|
||||
|
||||
async def _finish_task_processing_status(
|
||||
self,
|
||||
session_id: str,
|
||||
task: _MessageTask,
|
||||
queue: asyncio.Queue,
|
||||
) -> None:
|
||||
"""
|
||||
根据会话队列状态结束或延后处理提示。
|
||||
当后面还有排队消息时,typing 状态继续保留;队列真正空闲后再统一关闭。
|
||||
"""
|
||||
status = task.processing_status
|
||||
if self._is_shared_processing_status(status) and not queue.empty():
|
||||
self._deferred_processing_statuses[session_id] = status
|
||||
if task.processing_status:
|
||||
return
|
||||
task.processing_status = await _async_start_processing_status(task)
|
||||
|
||||
if status:
|
||||
await _async_finish_processing_status(status, task.user_id)
|
||||
if self._is_shared_processing_status(status):
|
||||
self._deferred_processing_statuses.pop(session_id, None)
|
||||
elif queue.empty():
|
||||
deferred_status = self._deferred_processing_statuses.pop(
|
||||
session_id, None
|
||||
)
|
||||
if deferred_status:
|
||||
await _async_finish_processing_status(
|
||||
deferred_status, task.user_id
|
||||
)
|
||||
return
|
||||
|
||||
if not queue.empty():
|
||||
return
|
||||
|
||||
deferred_status = self._deferred_processing_statuses.pop(session_id, None)
|
||||
if deferred_status:
|
||||
await _async_finish_processing_status(deferred_status, task.user_id)
|
||||
@staticmethod
|
||||
async def _finish_task_processing_status(task: _MessageTask) -> None:
|
||||
"""
|
||||
在 Agent worker 完成或异常后结束本条消息的渠道处理状态。
|
||||
"""
|
||||
await _async_finish_processing_status(task.processing_status, task.user_id)
|
||||
task.processing_status = None
|
||||
|
||||
async def _process_message_internal(self, task: _MessageTask):
|
||||
"""
|
||||
@@ -1232,7 +1209,6 @@ class AgentManager:
|
||||
break
|
||||
self._session_queues.pop(session_id, None)
|
||||
stopped = True
|
||||
self._deferred_processing_statuses.pop(session_id, None)
|
||||
|
||||
if stopped:
|
||||
logger.info(f"会话 {session_id} 的Agent推理已应急停止")
|
||||
@@ -1256,7 +1232,6 @@ class AgentManager:
|
||||
|
||||
# 清理队列
|
||||
self._session_queues.pop(session_id, None)
|
||||
self._deferred_processing_statuses.pop(session_id, None)
|
||||
|
||||
# 清理agent
|
||||
if session_id in self.active_agents:
|
||||
|
||||
@@ -41,6 +41,7 @@ from app.schemas import (
|
||||
MessageResponse,
|
||||
)
|
||||
from app.utils.identity import normalize_internal_user_id
|
||||
from app.schemas.message import ChannelCapability, ChannelCapabilityManager
|
||||
from app.schemas.category import CategoryConfig
|
||||
from app.schemas.types import (
|
||||
TorrentStatus,
|
||||
@@ -122,6 +123,74 @@ class ChainBase(metaclass=ABCMeta):
|
||||
"""
|
||||
self.filecache.delete(filename)
|
||||
|
||||
def start_message_processing_status(
|
||||
self,
|
||||
channel: MessageChannel,
|
||||
source: Optional[str],
|
||||
userid: Optional[Union[str, int]] = None,
|
||||
message_id: Optional[Union[str, int]] = None,
|
||||
chat_id: Optional[Union[str, int]] = None,
|
||||
text: Optional[str] = None,
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
启动渠道侧消息输入/处理状态。
|
||||
具体表现由消息模块实现,例如 typing 保活或消息 reaction。
|
||||
"""
|
||||
if not channel or not ChannelCapabilityManager.supports_capability(
|
||||
channel, ChannelCapability.PROCESSING_STATUS
|
||||
):
|
||||
return None
|
||||
try:
|
||||
status = self.run_module(
|
||||
"mark_message_processing_started",
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
message_id=message_id,
|
||||
chat_id=chat_id,
|
||||
text=text,
|
||||
)
|
||||
except Exception as err:
|
||||
logger.debug(f"启动消息处理状态失败: {err}")
|
||||
return None
|
||||
return status if isinstance(status, dict) else None
|
||||
|
||||
def finish_message_processing_status(
|
||||
self,
|
||||
status: Optional[dict] = None,
|
||||
channel: Optional[MessageChannel] = None,
|
||||
source: Optional[str] = None,
|
||||
userid: Optional[Union[str, int]] = None,
|
||||
message_id: Optional[Union[str, int]] = None,
|
||||
chat_id: Optional[Union[str, int]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
结束渠道侧消息输入/处理状态。
|
||||
优先使用 start 返回的 status,缺失时使用显式渠道和消息定位参数。
|
||||
"""
|
||||
target_channel = channel
|
||||
if status:
|
||||
try:
|
||||
target_channel = MessageChannel(status.get("channel"))
|
||||
except Exception:
|
||||
target_channel = channel
|
||||
if not target_channel or not ChannelCapabilityManager.supports_capability(
|
||||
target_channel, ChannelCapability.PROCESSING_STATUS
|
||||
):
|
||||
return
|
||||
try:
|
||||
self.run_module(
|
||||
"mark_message_processing_finished",
|
||||
channel=target_channel,
|
||||
source=(status or {}).get("source") or source,
|
||||
userid=(status or {}).get("userid") or userid,
|
||||
message_id=(status or {}).get("message_id") or message_id,
|
||||
chat_id=(status or {}).get("chat_id") or chat_id,
|
||||
status=status,
|
||||
)
|
||||
except Exception as err:
|
||||
logger.debug(f"结束消息处理状态失败: {err}")
|
||||
|
||||
@staticmethod
|
||||
def _normalize_notification_for_dispatch(
|
||||
message: Notification
|
||||
|
||||
@@ -137,14 +137,7 @@ class MessageChain(ChainBase):
|
||||
"""
|
||||
images = CommingMessage.MessageImage.normalize_list(images)
|
||||
|
||||
processing_status = self._mark_message_processing_started(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id,
|
||||
text=text,
|
||||
)
|
||||
processing_status = None
|
||||
continues_async = False
|
||||
try:
|
||||
# 语音输入只用于转写为文本,不默认改变回复形式。
|
||||
@@ -181,6 +174,23 @@ class MessageChain(ChainBase):
|
||||
text=text,
|
||||
)
|
||||
|
||||
if not self._is_agent_message(
|
||||
channel=channel,
|
||||
userid=userid,
|
||||
text=text,
|
||||
images=images,
|
||||
files=files,
|
||||
has_audio_input=has_audio_input,
|
||||
):
|
||||
processing_status = self._mark_message_processing_started(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id,
|
||||
text=text,
|
||||
)
|
||||
|
||||
continues_async = self._handle_message_core(
|
||||
channel=channel,
|
||||
source=source,
|
||||
@@ -310,7 +320,6 @@ class MessageChain(ChainBase):
|
||||
original_chat_id=original_chat_id,
|
||||
images=images,
|
||||
files=files,
|
||||
processing_status=processing_status,
|
||||
)
|
||||
|
||||
if (
|
||||
@@ -327,7 +336,6 @@ class MessageChain(ChainBase):
|
||||
original_chat_id=original_chat_id,
|
||||
images=images,
|
||||
files=files,
|
||||
processing_status=processing_status,
|
||||
)
|
||||
|
||||
if MediaInteractionChain().handle_text_interaction(
|
||||
@@ -350,6 +358,35 @@ class MessageChain(ChainBase):
|
||||
)
|
||||
return False
|
||||
|
||||
def _is_agent_message(
|
||||
self,
|
||||
channel: MessageChannel,
|
||||
userid: Union[str, int],
|
||||
text: str,
|
||||
images: Optional[List[CommingMessage.MessageImage]] = None,
|
||||
files: Optional[List[CommingMessage.MessageAttachment]] = None,
|
||||
has_audio_input: bool = False,
|
||||
) -> bool:
|
||||
"""
|
||||
判断本条消息是否会进入 Agent worker,由 Agent worker 管理 typing 生命周期。
|
||||
"""
|
||||
if text.startswith("CALLBACK:"):
|
||||
return self._parse_agent_choice_callback(text[9:]) is not None
|
||||
if text.lower().startswith("/ai"):
|
||||
return True
|
||||
if text.startswith("/"):
|
||||
return False
|
||||
if not (
|
||||
settings.AI_AGENT_ENABLE
|
||||
and (settings.AI_AGENT_GLOBAL or images or files or has_audio_input)
|
||||
):
|
||||
return False
|
||||
if self._get_latest_slash_interaction(userid):
|
||||
return False
|
||||
if media_interaction_manager.get_by_user(userid):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _mark_message_processing_started(
|
||||
self,
|
||||
channel: MessageChannel,
|
||||
@@ -360,27 +397,17 @@ class MessageChain(ChainBase):
|
||||
text: str,
|
||||
) -> Optional[_ProcessingStatus]:
|
||||
"""为支持的渠道标记“消息正在处理”。"""
|
||||
if not ChannelCapabilityManager.supports_capability(
|
||||
channel, ChannelCapability.PROCESSING_STATUS
|
||||
):
|
||||
status = self.start_message_processing_status(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
message_id=original_message_id,
|
||||
chat_id=original_chat_id,
|
||||
text=text,
|
||||
)
|
||||
if not status:
|
||||
return None
|
||||
|
||||
try:
|
||||
status = self.run_module(
|
||||
"mark_message_processing_started",
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
message_id=original_message_id,
|
||||
chat_id=original_chat_id,
|
||||
text=text,
|
||||
)
|
||||
except Exception as err:
|
||||
logger.debug(f"标记消息处理状态失败: {err}")
|
||||
return None
|
||||
|
||||
if not isinstance(status, dict):
|
||||
return None
|
||||
metadata = status.get("metadata")
|
||||
return self._ProcessingStatus(
|
||||
channel=channel,
|
||||
@@ -404,22 +431,16 @@ class MessageChain(ChainBase):
|
||||
结束渠道侧“消息正在处理”状态。
|
||||
不同渠道的表现可能是 reaction、typing 等,消息链只负责调用通用模块接口。
|
||||
"""
|
||||
if not status and not ChannelCapabilityManager.supports_capability(
|
||||
channel, ChannelCapability.PROCESSING_STATUS
|
||||
):
|
||||
if not status:
|
||||
return
|
||||
try:
|
||||
self.run_module(
|
||||
"mark_message_processing_finished",
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
message_id=status.message_id if status else original_message_id,
|
||||
chat_id=status.chat_id if status else original_chat_id,
|
||||
status=status.to_dict() if status else None,
|
||||
)
|
||||
except Exception as err:
|
||||
logger.debug(f"结束消息处理状态失败: {err}")
|
||||
self.finish_message_processing_status(
|
||||
status=status.to_dict(),
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
message_id=status.message_id or original_message_id,
|
||||
chat_id=status.chat_id or original_chat_id,
|
||||
)
|
||||
|
||||
def _handle_callback(
|
||||
self,
|
||||
@@ -501,7 +522,6 @@ class MessageChain(ChainBase):
|
||||
username=username,
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id,
|
||||
processing_status=processing_status,
|
||||
):
|
||||
return True
|
||||
|
||||
@@ -1148,7 +1168,6 @@ class MessageChain(ChainBase):
|
||||
images: Optional[List[CommingMessage.MessageImage]] = None,
|
||||
files: Optional[List[CommingMessage.MessageAttachment]] = None,
|
||||
session_id: Optional[str] = None,
|
||||
processing_status: Optional[_ProcessingStatus] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
处理AI智能体消息
|
||||
@@ -1261,9 +1280,6 @@ class MessageChain(ChainBase):
|
||||
username=username,
|
||||
original_message_id=str(original_message_id) if original_message_id else None,
|
||||
original_chat_id=original_chat_id,
|
||||
processing_status=processing_status.to_dict()
|
||||
if processing_status
|
||||
else None,
|
||||
),
|
||||
global_vars.loop,
|
||||
)
|
||||
|
||||
@@ -34,22 +34,10 @@ def _finish_command_processing_status(status: Optional[dict], user_id: Optional[
|
||||
"""
|
||||
if not status:
|
||||
return
|
||||
try:
|
||||
channel = MessageChannel(status.get("channel"))
|
||||
except Exception:
|
||||
return
|
||||
try:
|
||||
CommandChain().run_module(
|
||||
"mark_message_processing_finished",
|
||||
channel=channel,
|
||||
source=status.get("source"),
|
||||
userid=status.get("userid") or user_id,
|
||||
message_id=status.get("message_id"),
|
||||
chat_id=status.get("chat_id"),
|
||||
status=status,
|
||||
)
|
||||
except Exception as err:
|
||||
logger.debug(f"结束命令消息处理状态失败: {err}")
|
||||
CommandChain().finish_message_processing_status(
|
||||
status=status,
|
||||
userid=user_id,
|
||||
)
|
||||
|
||||
|
||||
class Command(metaclass=Singleton):
|
||||
|
||||
@@ -585,14 +585,12 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
continue
|
||||
client: Telegram = self.get_instance(conf.name)
|
||||
if client:
|
||||
stop_typing = not (metadata or {}).get("agent_managed_typing")
|
||||
result = client.edit_msg(
|
||||
chat_id=chat_id,
|
||||
message_id=message_id,
|
||||
text=text,
|
||||
title=title,
|
||||
buttons=buttons,
|
||||
stop_typing=stop_typing,
|
||||
)
|
||||
if result:
|
||||
return True
|
||||
|
||||
@@ -138,14 +138,9 @@ class TestMessageProcessingStatus(unittest.TestCase):
|
||||
with patch("app.agent.AgentChain") as chain_cls:
|
||||
_finish_processing_status(status, user_id="fallback")
|
||||
|
||||
chain_cls.return_value.run_module.assert_called_once_with(
|
||||
"mark_message_processing_finished",
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-main",
|
||||
userid="10001",
|
||||
message_id=None,
|
||||
chat_id="-100",
|
||||
chain_cls.return_value.finish_message_processing_status.assert_called_once_with(
|
||||
status=status,
|
||||
userid="fallback",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from unittest.mock import AsyncMock, Mock, patch
|
||||
sys.modules.setdefault("app.helper.sites", ModuleType("app.helper.sites"))
|
||||
setattr(sys.modules["app.helper.sites"], "SitesHelper", object)
|
||||
|
||||
from app.agent import AgentManager, _MessageTask
|
||||
from app.agent import AgentManager, _MessageTask, _async_start_processing_status
|
||||
from app.chain.message import MessageChain
|
||||
from app.command import Command, _finish_command_processing_status
|
||||
from app.modules.telegram import TelegramModule
|
||||
@@ -202,29 +202,25 @@ class TestTelegramTypingLifecycle(unittest.TestCase):
|
||||
with patch("app.command.CommandChain") as chain_cls:
|
||||
_finish_command_processing_status(status, user_id="fallback")
|
||||
|
||||
chain_cls.return_value.run_module.assert_called_once_with(
|
||||
"mark_message_processing_finished",
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
userid="10001",
|
||||
message_id=None,
|
||||
chat_id="-100",
|
||||
chain_cls.return_value.finish_message_processing_status.assert_called_once_with(
|
||||
status=status,
|
||||
userid="fallback",
|
||||
)
|
||||
|
||||
def test_async_agent_keeps_processing_status_for_worker(self):
|
||||
def test_async_agent_leaves_processing_status_to_worker(self):
|
||||
chain = MessageChain.__new__(MessageChain)
|
||||
status = MessageChain._ProcessingStatus(
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
userid="10001",
|
||||
chat_id="-100",
|
||||
metadata={"kind": "typing"},
|
||||
)
|
||||
|
||||
with patch.object(chain, "_record_user_message"), patch.object(
|
||||
chain, "_mark_message_processing_started", return_value=status
|
||||
), patch.object(chain, "_handle_message_core", return_value=True), patch.object(
|
||||
chain, "_mark_message_processing_started"
|
||||
) as start_status, patch(
|
||||
"app.chain.message.settings.AI_AGENT_ENABLE", True
|
||||
), patch(
|
||||
"app.chain.message.agent_manager.process_message",
|
||||
new_callable=AsyncMock,
|
||||
) as process_message, patch(
|
||||
"app.chain.message.asyncio.run_coroutine_threadsafe",
|
||||
side_effect=lambda coro, _loop: (coro.close(), Mock())[1],
|
||||
), patch.object(
|
||||
chain, "_mark_message_processing_finished"
|
||||
) as finish_status:
|
||||
chain.handle_message(
|
||||
@@ -236,7 +232,83 @@ class TestTelegramTypingLifecycle(unittest.TestCase):
|
||||
original_chat_id="-100",
|
||||
)
|
||||
|
||||
start_status.assert_not_called()
|
||||
finish_status.assert_not_called()
|
||||
process_message.assert_called_once()
|
||||
self.assertNotIn("processing_status", process_message.call_args.kwargs)
|
||||
self.assertEqual(
|
||||
process_message.call_args.kwargs["channel"],
|
||||
MessageChannel.Telegram.value,
|
||||
)
|
||||
self.assertEqual(process_message.call_args.kwargs["source"], "telegram-test")
|
||||
self.assertEqual(process_message.call_args.kwargs["original_chat_id"], "-100")
|
||||
|
||||
def test_agent_manager_starts_processing_status_when_task_runs(self):
|
||||
async def _run():
|
||||
manager = AgentManager()
|
||||
task = _MessageTask(
|
||||
session_id="session-1",
|
||||
user_id="10001",
|
||||
message="第一条",
|
||||
channel=MessageChannel.Telegram.value,
|
||||
source="telegram-test",
|
||||
original_chat_id="-100",
|
||||
)
|
||||
status = {
|
||||
"channel": MessageChannel.Telegram.value,
|
||||
"source": "telegram-test",
|
||||
"userid": "10001",
|
||||
"chat_id": "-100",
|
||||
"metadata": {"kind": "typing"},
|
||||
}
|
||||
|
||||
with patch(
|
||||
"app.agent._async_start_processing_status",
|
||||
new_callable=AsyncMock,
|
||||
return_value=status,
|
||||
) as start_status:
|
||||
await manager._start_task_processing_status(task)
|
||||
|
||||
start_status.assert_awaited_once_with(task)
|
||||
self.assertEqual(task.processing_status, status)
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
def test_agent_start_processing_status_uses_chain_interface(self):
|
||||
async def _run():
|
||||
task = _MessageTask(
|
||||
session_id="session-1",
|
||||
user_id="10001",
|
||||
message="第一条",
|
||||
channel=MessageChannel.Telegram.value,
|
||||
source="telegram-test",
|
||||
original_message_id="10",
|
||||
original_chat_id="-100",
|
||||
)
|
||||
status = {
|
||||
"channel": MessageChannel.Telegram.value,
|
||||
"source": "telegram-test",
|
||||
"userid": "10001",
|
||||
"message_id": "10",
|
||||
"chat_id": "-100",
|
||||
"metadata": {"kind": "typing"},
|
||||
}
|
||||
|
||||
with patch("app.agent.AgentChain") as chain_cls:
|
||||
chain_cls.return_value.start_message_processing_status.return_value = status
|
||||
result = await _async_start_processing_status(task)
|
||||
|
||||
chain_cls.return_value.start_message_processing_status.assert_called_once_with(
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
userid="10001",
|
||||
message_id="10",
|
||||
chat_id="-100",
|
||||
text="第一条",
|
||||
)
|
||||
self.assertEqual(result, status)
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
def test_callback_stops_typing_when_message_handler_returns(self):
|
||||
chain = MessageChain.__new__(MessageChain)
|
||||
@@ -281,7 +353,7 @@ class TestTelegramTypingLifecycle(unittest.TestCase):
|
||||
metadata={"kind": "typing"},
|
||||
)
|
||||
|
||||
with patch.object(chain, "run_module") as run_module:
|
||||
with patch.object(chain, "finish_message_processing_status") as finish_status:
|
||||
chain._mark_message_processing_finished(
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
@@ -290,119 +362,106 @@ class TestTelegramTypingLifecycle(unittest.TestCase):
|
||||
original_chat_id="-100",
|
||||
)
|
||||
|
||||
run_module.assert_called_once_with(
|
||||
"mark_message_processing_finished",
|
||||
finish_status.assert_called_once_with(
|
||||
status=status.to_dict(),
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
userid="10001",
|
||||
message_id=None,
|
||||
chat_id="-100",
|
||||
status=status.to_dict(),
|
||||
)
|
||||
|
||||
def test_agent_manager_defers_shared_typing_until_queued_task_finishes(self):
|
||||
def test_agent_manager_finishes_processing_status_after_each_task(self):
|
||||
async def _run():
|
||||
manager = AgentManager()
|
||||
queue = asyncio.Queue()
|
||||
first = _MessageTask(
|
||||
status = {
|
||||
"channel": MessageChannel.Telegram.value,
|
||||
"source": "telegram-test",
|
||||
"userid": "10001",
|
||||
"chat_id": "-100",
|
||||
"metadata": {"kind": "typing"},
|
||||
}
|
||||
task = _MessageTask(
|
||||
session_id="session-1",
|
||||
user_id="10001",
|
||||
message="第一条",
|
||||
processing_status={
|
||||
"channel": MessageChannel.Telegram.value,
|
||||
"source": "telegram-test",
|
||||
"userid": "10001",
|
||||
"chat_id": "-100",
|
||||
"metadata": {"kind": "typing"},
|
||||
},
|
||||
processing_status=status,
|
||||
)
|
||||
second = _MessageTask(
|
||||
session_id="session-1",
|
||||
user_id="10001",
|
||||
message="第二条",
|
||||
processing_status={
|
||||
"channel": MessageChannel.Telegram.value,
|
||||
"source": "telegram-test",
|
||||
"userid": "10001",
|
||||
"chat_id": "-100",
|
||||
"metadata": {"kind": "typing"},
|
||||
},
|
||||
)
|
||||
await queue.put(second)
|
||||
|
||||
with patch(
|
||||
"app.agent._async_finish_processing_status",
|
||||
new_callable=AsyncMock,
|
||||
) as finish_status:
|
||||
await manager._finish_task_processing_status(
|
||||
session_id="session-1",
|
||||
task=first,
|
||||
queue=queue,
|
||||
)
|
||||
finish_status.assert_not_awaited()
|
||||
self.assertEqual(
|
||||
manager._deferred_processing_statuses["session-1"],
|
||||
first.processing_status,
|
||||
)
|
||||
await manager._finish_task_processing_status(task)
|
||||
|
||||
queue.get_nowait()
|
||||
await manager._finish_task_processing_status(
|
||||
session_id="session-1",
|
||||
task=second,
|
||||
queue=queue,
|
||||
)
|
||||
|
||||
finish_status.assert_awaited_once_with(
|
||||
second.processing_status, "10001"
|
||||
)
|
||||
self.assertNotIn("session-1", manager._deferred_processing_statuses)
|
||||
finish_status.assert_awaited_once_with(status, "10001")
|
||||
self.assertIsNone(task.processing_status)
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
def test_agent_manager_closes_deferred_typing_when_next_task_has_no_status(self):
|
||||
def test_agent_worker_starts_and_finishes_each_queued_task(self):
|
||||
async def _run():
|
||||
manager = AgentManager()
|
||||
queue = asyncio.Queue()
|
||||
first = _MessageTask(
|
||||
manager._session_queues["session-1"] = asyncio.Queue()
|
||||
first_status = {
|
||||
"channel": MessageChannel.Telegram.value,
|
||||
"source": "telegram-test",
|
||||
"userid": "10001",
|
||||
"chat_id": "-100",
|
||||
"metadata": {"kind": "typing", "seq": 1},
|
||||
}
|
||||
second_status = {
|
||||
"channel": MessageChannel.Telegram.value,
|
||||
"source": "telegram-test",
|
||||
"userid": "10001",
|
||||
"chat_id": "-100",
|
||||
"metadata": {"kind": "typing", "seq": 2},
|
||||
}
|
||||
await manager._session_queues["session-1"].put(_MessageTask(
|
||||
session_id="session-1",
|
||||
user_id="10001",
|
||||
message="第一条",
|
||||
processing_status={
|
||||
"channel": MessageChannel.Telegram.value,
|
||||
"source": "telegram-test",
|
||||
"userid": "10001",
|
||||
"chat_id": "-100",
|
||||
"metadata": {"kind": "typing"},
|
||||
},
|
||||
)
|
||||
second = _MessageTask(
|
||||
channel=MessageChannel.Telegram.value,
|
||||
source="telegram-test",
|
||||
original_chat_id="-100",
|
||||
))
|
||||
await manager._session_queues["session-1"].put(_MessageTask(
|
||||
session_id="session-1",
|
||||
user_id="10001",
|
||||
message="第二条",
|
||||
processing_status=None,
|
||||
)
|
||||
await queue.put(second)
|
||||
channel=MessageChannel.Telegram.value,
|
||||
source="telegram-test",
|
||||
original_chat_id="-100",
|
||||
))
|
||||
|
||||
with patch(
|
||||
"app.agent._async_start_processing_status",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=[first_status, second_status],
|
||||
) as start_status, patch.object(
|
||||
manager,
|
||||
"_process_message_internal",
|
||||
new_callable=AsyncMock,
|
||||
), patch(
|
||||
"app.agent._async_finish_processing_status",
|
||||
new_callable=AsyncMock,
|
||||
) as finish_status:
|
||||
await manager._finish_task_processing_status(
|
||||
session_id="session-1",
|
||||
task=first,
|
||||
queue=queue,
|
||||
)
|
||||
queue.get_nowait()
|
||||
await manager._finish_task_processing_status(
|
||||
session_id="session-1",
|
||||
task=second,
|
||||
queue=queue,
|
||||
manager._session_workers["session-1"] = asyncio.create_task(
|
||||
manager._session_worker("session-1")
|
||||
)
|
||||
await manager._session_queues["session-1"].join()
|
||||
manager._session_workers["session-1"].cancel()
|
||||
await manager._session_workers["session-1"]
|
||||
|
||||
finish_status.assert_awaited_once_with(
|
||||
first.processing_status, "10001"
|
||||
self.assertEqual(start_status.await_count, 2)
|
||||
self.assertEqual(
|
||||
finish_status.await_args_list[0].args,
|
||||
(first_status, "10001"),
|
||||
)
|
||||
self.assertEqual(
|
||||
finish_status.await_args_list[1].args,
|
||||
(second_status, "10001"),
|
||||
)
|
||||
self.assertNotIn("session-1", manager._deferred_processing_statuses)
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
Reference in New Issue
Block a user