fix: simplify message typing lifecycle

This commit is contained in:
jxxghp
2026-05-23 00:11:56 +08:00
parent cde267c55f
commit a74f04a149
7 changed files with 336 additions and 236 deletions

View File

@@ -58,22 +58,36 @@ def _finish_processing_status(status: Optional[dict], user_id: Optional[str] = N
"""结束入站消息的渠道处理状态。"""
if not status:
return
try:
channel = MessageChannel(status.get("channel"))
except Exception:
return
try:
AgentChain().run_module(
"mark_message_processing_finished",
channel=channel,
source=status.get("source"),
userid=status.get("userid") or user_id,
message_id=status.get("message_id"),
chat_id=status.get("chat_id"),
status=status,
)
except Exception as err:
logger.debug(f"结束Agent消息处理状态失败: {err}")
AgentChain().finish_message_processing_status(
status=status,
userid=user_id,
)
async def _async_start_processing_status(task: "_MessageTask") -> Optional[dict]:
"""
在 Agent worker 中启动渠道处理状态。
渠道启动可能触发外部 API同步实现需切到线程池避免阻塞事件循环。
"""
if not task.channel:
return None
def _start() -> Optional[dict]:
"""在线程池中通过统一 Chain 接口启动处理状态。"""
try:
return AgentChain().start_message_processing_status(
channel=MessageChannel(task.channel),
source=task.source,
userid=task.user_id,
message_id=task.original_message_id,
chat_id=task.original_chat_id,
text=task.message,
)
except Exception as err:
logger.debug(f"启动Agent消息处理状态失败: {err}")
return None
return await run_in_threadpool(_start)
async def _async_finish_processing_status(
@@ -952,8 +966,6 @@ class AgentManager:
self._session_queues: Dict[str, asyncio.Queue] = {}
# 每个会话的worker任务
self._session_workers: Dict[str, asyncio.Task] = {}
# typing 这类状态按会话/聊天共享,前一条任务结束时可能仍需延续到后续排队消息。
self._deferred_processing_statuses: Dict[str, dict] = {}
def get_session_status(self, session_id: str) -> dict[str, Any]:
"""获取会话当前模型与 token 使用状态。"""
@@ -1009,7 +1021,6 @@ class AgentManager:
pass
self._session_workers.clear()
self._session_queues.clear()
self._deferred_processing_statuses.clear()
for agent in self.active_agents.values():
await agent.cleanup()
self.active_agents.clear()
@@ -1026,7 +1037,6 @@ class AgentManager:
username: str = None,
original_message_id: Optional[str] = None,
original_chat_id: Optional[str] = None,
processing_status: Optional[dict] = None,
reply_mode: ReplyMode = ReplyMode.DISPATCH,
) -> str:
"""
@@ -1044,7 +1054,6 @@ class AgentManager:
username=username,
original_message_id=original_message_id,
original_chat_id=original_chat_id,
processing_status=processing_status,
reply_mode=reply_mode,
)
@@ -1099,15 +1108,12 @@ class AgentManager:
break
try:
await self._start_task_processing_status(task)
await self._process_message_internal(task)
except Exception as e:
logger.error(f"处理会话 {session_id} 的消息失败: {e}")
finally:
await self._finish_task_processing_status(
session_id=session_id,
task=task,
queue=queue,
)
await self._finish_task_processing_status(task)
queue.task_done()
except asyncio.CancelledError:
@@ -1121,52 +1127,23 @@ class AgentManager:
and self._session_queues[session_id].empty()
):
self._session_queues.pop(session_id, None)
self._deferred_processing_statuses.pop(session_id, None)
@staticmethod
def _is_shared_processing_status(status: Optional[dict]) -> bool:
async def _start_task_processing_status(task: _MessageTask) -> None:
"""
判断状态是否属于同一聊天窗口共享的处理提示
reaction 绑定到具体消息应按消息收口typing 绑定到会话/聊天,需要等队列空闲再关闭。
在 Agent worker 真正开始处理消息时启动渠道处理状态
"""
metadata = (status or {}).get("metadata") or {}
return isinstance(metadata, dict) and metadata.get("kind") == "typing"
async def _finish_task_processing_status(
self,
session_id: str,
task: _MessageTask,
queue: asyncio.Queue,
) -> None:
"""
根据会话队列状态结束或延后处理提示。
当后面还有排队消息时typing 状态继续保留;队列真正空闲后再统一关闭。
"""
status = task.processing_status
if self._is_shared_processing_status(status) and not queue.empty():
self._deferred_processing_statuses[session_id] = status
if task.processing_status:
return
task.processing_status = await _async_start_processing_status(task)
if status:
await _async_finish_processing_status(status, task.user_id)
if self._is_shared_processing_status(status):
self._deferred_processing_statuses.pop(session_id, None)
elif queue.empty():
deferred_status = self._deferred_processing_statuses.pop(
session_id, None
)
if deferred_status:
await _async_finish_processing_status(
deferred_status, task.user_id
)
return
if not queue.empty():
return
deferred_status = self._deferred_processing_statuses.pop(session_id, None)
if deferred_status:
await _async_finish_processing_status(deferred_status, task.user_id)
@staticmethod
async def _finish_task_processing_status(task: _MessageTask) -> None:
"""
在 Agent worker 完成或异常后结束本条消息的渠道处理状态。
"""
await _async_finish_processing_status(task.processing_status, task.user_id)
task.processing_status = None
async def _process_message_internal(self, task: _MessageTask):
"""
@@ -1232,7 +1209,6 @@ class AgentManager:
break
self._session_queues.pop(session_id, None)
stopped = True
self._deferred_processing_statuses.pop(session_id, None)
if stopped:
logger.info(f"会话 {session_id} 的Agent推理已应急停止")
@@ -1256,7 +1232,6 @@ class AgentManager:
# 清理队列
self._session_queues.pop(session_id, None)
self._deferred_processing_statuses.pop(session_id, None)
# 清理agent
if session_id in self.active_agents:

View File

@@ -41,6 +41,7 @@ from app.schemas import (
MessageResponse,
)
from app.utils.identity import normalize_internal_user_id
from app.schemas.message import ChannelCapability, ChannelCapabilityManager
from app.schemas.category import CategoryConfig
from app.schemas.types import (
TorrentStatus,
@@ -122,6 +123,74 @@ class ChainBase(metaclass=ABCMeta):
"""
self.filecache.delete(filename)
def start_message_processing_status(
self,
channel: MessageChannel,
source: Optional[str],
userid: Optional[Union[str, int]] = None,
message_id: Optional[Union[str, int]] = None,
chat_id: Optional[Union[str, int]] = None,
text: Optional[str] = None,
) -> Optional[dict]:
"""
启动渠道侧消息输入/处理状态。
具体表现由消息模块实现,例如 typing 保活或消息 reaction。
"""
if not channel or not ChannelCapabilityManager.supports_capability(
channel, ChannelCapability.PROCESSING_STATUS
):
return None
try:
status = self.run_module(
"mark_message_processing_started",
channel=channel,
source=source,
userid=userid,
message_id=message_id,
chat_id=chat_id,
text=text,
)
except Exception as err:
logger.debug(f"启动消息处理状态失败: {err}")
return None
return status if isinstance(status, dict) else None
def finish_message_processing_status(
self,
status: Optional[dict] = None,
channel: Optional[MessageChannel] = None,
source: Optional[str] = None,
userid: Optional[Union[str, int]] = None,
message_id: Optional[Union[str, int]] = None,
chat_id: Optional[Union[str, int]] = None,
) -> None:
"""
结束渠道侧消息输入/处理状态。
优先使用 start 返回的 status缺失时使用显式渠道和消息定位参数。
"""
target_channel = channel
if status:
try:
target_channel = MessageChannel(status.get("channel"))
except Exception:
target_channel = channel
if not target_channel or not ChannelCapabilityManager.supports_capability(
target_channel, ChannelCapability.PROCESSING_STATUS
):
return
try:
self.run_module(
"mark_message_processing_finished",
channel=target_channel,
source=(status or {}).get("source") or source,
userid=(status or {}).get("userid") or userid,
message_id=(status or {}).get("message_id") or message_id,
chat_id=(status or {}).get("chat_id") or chat_id,
status=status,
)
except Exception as err:
logger.debug(f"结束消息处理状态失败: {err}")
@staticmethod
def _normalize_notification_for_dispatch(
message: Notification

View File

@@ -137,14 +137,7 @@ class MessageChain(ChainBase):
"""
images = CommingMessage.MessageImage.normalize_list(images)
processing_status = self._mark_message_processing_started(
channel=channel,
source=source,
userid=userid,
original_message_id=original_message_id,
original_chat_id=original_chat_id,
text=text,
)
processing_status = None
continues_async = False
try:
# 语音输入只用于转写为文本,不默认改变回复形式。
@@ -181,6 +174,23 @@ class MessageChain(ChainBase):
text=text,
)
if not self._is_agent_message(
channel=channel,
userid=userid,
text=text,
images=images,
files=files,
has_audio_input=has_audio_input,
):
processing_status = self._mark_message_processing_started(
channel=channel,
source=source,
userid=userid,
original_message_id=original_message_id,
original_chat_id=original_chat_id,
text=text,
)
continues_async = self._handle_message_core(
channel=channel,
source=source,
@@ -310,7 +320,6 @@ class MessageChain(ChainBase):
original_chat_id=original_chat_id,
images=images,
files=files,
processing_status=processing_status,
)
if (
@@ -327,7 +336,6 @@ class MessageChain(ChainBase):
original_chat_id=original_chat_id,
images=images,
files=files,
processing_status=processing_status,
)
if MediaInteractionChain().handle_text_interaction(
@@ -350,6 +358,35 @@ class MessageChain(ChainBase):
)
return False
def _is_agent_message(
self,
channel: MessageChannel,
userid: Union[str, int],
text: str,
images: Optional[List[CommingMessage.MessageImage]] = None,
files: Optional[List[CommingMessage.MessageAttachment]] = None,
has_audio_input: bool = False,
) -> bool:
"""
判断本条消息是否会进入 Agent worker由 Agent worker 管理 typing 生命周期。
"""
if text.startswith("CALLBACK:"):
return self._parse_agent_choice_callback(text[9:]) is not None
if text.lower().startswith("/ai"):
return True
if text.startswith("/"):
return False
if not (
settings.AI_AGENT_ENABLE
and (settings.AI_AGENT_GLOBAL or images or files or has_audio_input)
):
return False
if self._get_latest_slash_interaction(userid):
return False
if media_interaction_manager.get_by_user(userid):
return False
return True
def _mark_message_processing_started(
self,
channel: MessageChannel,
@@ -360,27 +397,17 @@ class MessageChain(ChainBase):
text: str,
) -> Optional[_ProcessingStatus]:
"""为支持的渠道标记“消息正在处理”。"""
if not ChannelCapabilityManager.supports_capability(
channel, ChannelCapability.PROCESSING_STATUS
):
status = self.start_message_processing_status(
channel=channel,
source=source,
userid=userid,
message_id=original_message_id,
chat_id=original_chat_id,
text=text,
)
if not status:
return None
try:
status = self.run_module(
"mark_message_processing_started",
channel=channel,
source=source,
userid=userid,
message_id=original_message_id,
chat_id=original_chat_id,
text=text,
)
except Exception as err:
logger.debug(f"标记消息处理状态失败: {err}")
return None
if not isinstance(status, dict):
return None
metadata = status.get("metadata")
return self._ProcessingStatus(
channel=channel,
@@ -404,22 +431,16 @@ class MessageChain(ChainBase):
结束渠道侧“消息正在处理”状态。
不同渠道的表现可能是 reaction、typing 等,消息链只负责调用通用模块接口。
"""
if not status and not ChannelCapabilityManager.supports_capability(
channel, ChannelCapability.PROCESSING_STATUS
):
if not status:
return
try:
self.run_module(
"mark_message_processing_finished",
channel=channel,
source=source,
userid=userid,
message_id=status.message_id if status else original_message_id,
chat_id=status.chat_id if status else original_chat_id,
status=status.to_dict() if status else None,
)
except Exception as err:
logger.debug(f"结束消息处理状态失败: {err}")
self.finish_message_processing_status(
status=status.to_dict(),
channel=channel,
source=source,
userid=userid,
message_id=status.message_id or original_message_id,
chat_id=status.chat_id or original_chat_id,
)
def _handle_callback(
self,
@@ -501,7 +522,6 @@ class MessageChain(ChainBase):
username=username,
original_message_id=original_message_id,
original_chat_id=original_chat_id,
processing_status=processing_status,
):
return True
@@ -1148,7 +1168,6 @@ class MessageChain(ChainBase):
images: Optional[List[CommingMessage.MessageImage]] = None,
files: Optional[List[CommingMessage.MessageAttachment]] = None,
session_id: Optional[str] = None,
processing_status: Optional[_ProcessingStatus] = None,
) -> bool:
"""
处理AI智能体消息
@@ -1261,9 +1280,6 @@ class MessageChain(ChainBase):
username=username,
original_message_id=str(original_message_id) if original_message_id else None,
original_chat_id=original_chat_id,
processing_status=processing_status.to_dict()
if processing_status
else None,
),
global_vars.loop,
)

View File

@@ -34,22 +34,10 @@ def _finish_command_processing_status(status: Optional[dict], user_id: Optional[
"""
if not status:
return
try:
channel = MessageChannel(status.get("channel"))
except Exception:
return
try:
CommandChain().run_module(
"mark_message_processing_finished",
channel=channel,
source=status.get("source"),
userid=status.get("userid") or user_id,
message_id=status.get("message_id"),
chat_id=status.get("chat_id"),
status=status,
)
except Exception as err:
logger.debug(f"结束命令消息处理状态失败: {err}")
CommandChain().finish_message_processing_status(
status=status,
userid=user_id,
)
class Command(metaclass=Singleton):

View File

@@ -585,14 +585,12 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
continue
client: Telegram = self.get_instance(conf.name)
if client:
stop_typing = not (metadata or {}).get("agent_managed_typing")
result = client.edit_msg(
chat_id=chat_id,
message_id=message_id,
text=text,
title=title,
buttons=buttons,
stop_typing=stop_typing,
)
if result:
return True

View File

@@ -138,14 +138,9 @@ class TestMessageProcessingStatus(unittest.TestCase):
with patch("app.agent.AgentChain") as chain_cls:
_finish_processing_status(status, user_id="fallback")
chain_cls.return_value.run_module.assert_called_once_with(
"mark_message_processing_finished",
channel=MessageChannel.Telegram,
source="telegram-main",
userid="10001",
message_id=None,
chat_id="-100",
chain_cls.return_value.finish_message_processing_status.assert_called_once_with(
status=status,
userid="fallback",
)

View File

@@ -8,7 +8,7 @@ from unittest.mock import AsyncMock, Mock, patch
sys.modules.setdefault("app.helper.sites", ModuleType("app.helper.sites"))
setattr(sys.modules["app.helper.sites"], "SitesHelper", object)
from app.agent import AgentManager, _MessageTask
from app.agent import AgentManager, _MessageTask, _async_start_processing_status
from app.chain.message import MessageChain
from app.command import Command, _finish_command_processing_status
from app.modules.telegram import TelegramModule
@@ -202,29 +202,25 @@ class TestTelegramTypingLifecycle(unittest.TestCase):
with patch("app.command.CommandChain") as chain_cls:
_finish_command_processing_status(status, user_id="fallback")
chain_cls.return_value.run_module.assert_called_once_with(
"mark_message_processing_finished",
channel=MessageChannel.Telegram,
source="telegram-test",
userid="10001",
message_id=None,
chat_id="-100",
chain_cls.return_value.finish_message_processing_status.assert_called_once_with(
status=status,
userid="fallback",
)
def test_async_agent_keeps_processing_status_for_worker(self):
def test_async_agent_leaves_processing_status_to_worker(self):
chain = MessageChain.__new__(MessageChain)
status = MessageChain._ProcessingStatus(
channel=MessageChannel.Telegram,
source="telegram-test",
userid="10001",
chat_id="-100",
metadata={"kind": "typing"},
)
with patch.object(chain, "_record_user_message"), patch.object(
chain, "_mark_message_processing_started", return_value=status
), patch.object(chain, "_handle_message_core", return_value=True), patch.object(
chain, "_mark_message_processing_started"
) as start_status, patch(
"app.chain.message.settings.AI_AGENT_ENABLE", True
), patch(
"app.chain.message.agent_manager.process_message",
new_callable=AsyncMock,
) as process_message, patch(
"app.chain.message.asyncio.run_coroutine_threadsafe",
side_effect=lambda coro, _loop: (coro.close(), Mock())[1],
), patch.object(
chain, "_mark_message_processing_finished"
) as finish_status:
chain.handle_message(
@@ -236,7 +232,83 @@ class TestTelegramTypingLifecycle(unittest.TestCase):
original_chat_id="-100",
)
start_status.assert_not_called()
finish_status.assert_not_called()
process_message.assert_called_once()
self.assertNotIn("processing_status", process_message.call_args.kwargs)
self.assertEqual(
process_message.call_args.kwargs["channel"],
MessageChannel.Telegram.value,
)
self.assertEqual(process_message.call_args.kwargs["source"], "telegram-test")
self.assertEqual(process_message.call_args.kwargs["original_chat_id"], "-100")
def test_agent_manager_starts_processing_status_when_task_runs(self):
async def _run():
manager = AgentManager()
task = _MessageTask(
session_id="session-1",
user_id="10001",
message="第一条",
channel=MessageChannel.Telegram.value,
source="telegram-test",
original_chat_id="-100",
)
status = {
"channel": MessageChannel.Telegram.value,
"source": "telegram-test",
"userid": "10001",
"chat_id": "-100",
"metadata": {"kind": "typing"},
}
with patch(
"app.agent._async_start_processing_status",
new_callable=AsyncMock,
return_value=status,
) as start_status:
await manager._start_task_processing_status(task)
start_status.assert_awaited_once_with(task)
self.assertEqual(task.processing_status, status)
asyncio.run(_run())
def test_agent_start_processing_status_uses_chain_interface(self):
async def _run():
task = _MessageTask(
session_id="session-1",
user_id="10001",
message="第一条",
channel=MessageChannel.Telegram.value,
source="telegram-test",
original_message_id="10",
original_chat_id="-100",
)
status = {
"channel": MessageChannel.Telegram.value,
"source": "telegram-test",
"userid": "10001",
"message_id": "10",
"chat_id": "-100",
"metadata": {"kind": "typing"},
}
with patch("app.agent.AgentChain") as chain_cls:
chain_cls.return_value.start_message_processing_status.return_value = status
result = await _async_start_processing_status(task)
chain_cls.return_value.start_message_processing_status.assert_called_once_with(
channel=MessageChannel.Telegram,
source="telegram-test",
userid="10001",
message_id="10",
chat_id="-100",
text="第一条",
)
self.assertEqual(result, status)
asyncio.run(_run())
def test_callback_stops_typing_when_message_handler_returns(self):
chain = MessageChain.__new__(MessageChain)
@@ -281,7 +353,7 @@ class TestTelegramTypingLifecycle(unittest.TestCase):
metadata={"kind": "typing"},
)
with patch.object(chain, "run_module") as run_module:
with patch.object(chain, "finish_message_processing_status") as finish_status:
chain._mark_message_processing_finished(
channel=MessageChannel.Telegram,
source="telegram-test",
@@ -290,119 +362,106 @@ class TestTelegramTypingLifecycle(unittest.TestCase):
original_chat_id="-100",
)
run_module.assert_called_once_with(
"mark_message_processing_finished",
finish_status.assert_called_once_with(
status=status.to_dict(),
channel=MessageChannel.Telegram,
source="telegram-test",
userid="10001",
message_id=None,
chat_id="-100",
status=status.to_dict(),
)
def test_agent_manager_defers_shared_typing_until_queued_task_finishes(self):
def test_agent_manager_finishes_processing_status_after_each_task(self):
async def _run():
manager = AgentManager()
queue = asyncio.Queue()
first = _MessageTask(
status = {
"channel": MessageChannel.Telegram.value,
"source": "telegram-test",
"userid": "10001",
"chat_id": "-100",
"metadata": {"kind": "typing"},
}
task = _MessageTask(
session_id="session-1",
user_id="10001",
message="第一条",
processing_status={
"channel": MessageChannel.Telegram.value,
"source": "telegram-test",
"userid": "10001",
"chat_id": "-100",
"metadata": {"kind": "typing"},
},
processing_status=status,
)
second = _MessageTask(
session_id="session-1",
user_id="10001",
message="第二条",
processing_status={
"channel": MessageChannel.Telegram.value,
"source": "telegram-test",
"userid": "10001",
"chat_id": "-100",
"metadata": {"kind": "typing"},
},
)
await queue.put(second)
with patch(
"app.agent._async_finish_processing_status",
new_callable=AsyncMock,
) as finish_status:
await manager._finish_task_processing_status(
session_id="session-1",
task=first,
queue=queue,
)
finish_status.assert_not_awaited()
self.assertEqual(
manager._deferred_processing_statuses["session-1"],
first.processing_status,
)
await manager._finish_task_processing_status(task)
queue.get_nowait()
await manager._finish_task_processing_status(
session_id="session-1",
task=second,
queue=queue,
)
finish_status.assert_awaited_once_with(
second.processing_status, "10001"
)
self.assertNotIn("session-1", manager._deferred_processing_statuses)
finish_status.assert_awaited_once_with(status, "10001")
self.assertIsNone(task.processing_status)
asyncio.run(_run())
def test_agent_manager_closes_deferred_typing_when_next_task_has_no_status(self):
def test_agent_worker_starts_and_finishes_each_queued_task(self):
async def _run():
manager = AgentManager()
queue = asyncio.Queue()
first = _MessageTask(
manager._session_queues["session-1"] = asyncio.Queue()
first_status = {
"channel": MessageChannel.Telegram.value,
"source": "telegram-test",
"userid": "10001",
"chat_id": "-100",
"metadata": {"kind": "typing", "seq": 1},
}
second_status = {
"channel": MessageChannel.Telegram.value,
"source": "telegram-test",
"userid": "10001",
"chat_id": "-100",
"metadata": {"kind": "typing", "seq": 2},
}
await manager._session_queues["session-1"].put(_MessageTask(
session_id="session-1",
user_id="10001",
message="第一条",
processing_status={
"channel": MessageChannel.Telegram.value,
"source": "telegram-test",
"userid": "10001",
"chat_id": "-100",
"metadata": {"kind": "typing"},
},
)
second = _MessageTask(
channel=MessageChannel.Telegram.value,
source="telegram-test",
original_chat_id="-100",
))
await manager._session_queues["session-1"].put(_MessageTask(
session_id="session-1",
user_id="10001",
message="第二条",
processing_status=None,
)
await queue.put(second)
channel=MessageChannel.Telegram.value,
source="telegram-test",
original_chat_id="-100",
))
with patch(
"app.agent._async_start_processing_status",
new_callable=AsyncMock,
side_effect=[first_status, second_status],
) as start_status, patch.object(
manager,
"_process_message_internal",
new_callable=AsyncMock,
), patch(
"app.agent._async_finish_processing_status",
new_callable=AsyncMock,
) as finish_status:
await manager._finish_task_processing_status(
session_id="session-1",
task=first,
queue=queue,
)
queue.get_nowait()
await manager._finish_task_processing_status(
session_id="session-1",
task=second,
queue=queue,
manager._session_workers["session-1"] = asyncio.create_task(
manager._session_worker("session-1")
)
await manager._session_queues["session-1"].join()
manager._session_workers["session-1"].cancel()
await manager._session_workers["session-1"]
finish_status.assert_awaited_once_with(
first.processing_status, "10001"
self.assertEqual(start_status.await_count, 2)
self.assertEqual(
finish_status.await_args_list[0].args,
(first_status, "10001"),
)
self.assertEqual(
finish_status.await_args_list[1].args,
(second_status, "10001"),
)
self.assertNotIn("session-1", manager._deferred_processing_statuses)
asyncio.run(_run())