diff --git a/app/chain/message.py b/app/chain/message.py index 4600e2eb..7127a9e5 100644 --- a/app/chain/message.py +++ b/app/chain/message.py @@ -522,6 +522,7 @@ class MessageChain(ChainBase): username=username, original_message_id=original_message_id, original_chat_id=original_chat_id, + processing_status=processing_status, ): return True @@ -1168,6 +1169,7 @@ class MessageChain(ChainBase): images: Optional[List[CommingMessage.MessageImage]] = None, files: Optional[List[CommingMessage.MessageAttachment]] = None, session_id: Optional[str] = None, + processing_status: Optional[_ProcessingStatus] = None, ) -> bool: """ 处理AI智能体消息 @@ -1267,20 +1269,28 @@ class MessageChain(ChainBase): ) return False + process_kwargs = { + "session_id": session_id, + "user_id": str(userid), + "message": user_message, + "images": images, + "files": prepared_files, + "channel": channel.value if channel else None, + "source": source, + "username": username, + "original_message_id": str(original_message_id) + if original_message_id + else None, + "original_chat_id": original_chat_id, + } + # 回调消息的处理状态已由入口层创建,需要交给 Agent worker 结束; + # 普通 Agent 消息仍不传入,让 worker 在真正开始处理时自行启动状态。 + if processing_status: + process_kwargs["processing_status"] = processing_status.to_dict() + # 在事件循环中处理 asyncio.run_coroutine_threadsafe( - agent_manager.process_message( - session_id=session_id, - user_id=str(userid), - message=user_message, - images=images, - files=prepared_files, - channel=channel.value if channel else None, - source=source, - username=username, - original_message_id=str(original_message_id) if original_message_id else None, - original_chat_id=original_chat_id, - ), + agent_manager.process_message(**process_kwargs), global_vars.loop, ) return True diff --git a/tests/test_agent_interaction.py b/tests/test_agent_interaction.py index d4e45b4a..67558699 100644 --- a/tests/test_agent_interaction.py +++ b/tests/test_agent_interaction.py @@ -1,6 +1,6 @@ import asyncio import unittest -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, Mock, patch from app.agent.prompt import prompt_manager from app.agent.tools.factory import MoviePilotToolFactory @@ -13,6 +13,7 @@ from app.helper.interaction import ( agent_interaction_manager, ) from app.chain.message import MessageChain +from app.core.config import settings from app.schemas.types import MessageChannel @@ -149,12 +150,20 @@ class TestAgentInteraction(unittest.TestCase): ], ) - with patch.object(chain, "_handle_ai_message") as handle_ai_message, patch.object( + with patch.object(settings, "AI_AGENT_ENABLE", True), patch.object( chain.messagehelper, "put" - ) as message_put, patch.object(chain.messageoper, "add") as message_add, patch.object( + ) as message_put, patch.object( + chain.messageoper, "add" + ) as message_add, patch.object( chain, "edit_message", return_value=True - ) as edit_message: - chain._handle_callback( + ) as edit_message, patch( + "app.chain.message.agent_manager.process_message", + new_callable=AsyncMock, + ) as process_message, patch( + "app.chain.message.asyncio.run_coroutine_threadsafe", + side_effect=lambda coro, _loop: (coro.close(), Mock())[1], + ): + handled = chain._handle_callback( text=f"CALLBACK:agent_interaction:choice:{request.request_id}:1", channel=MessageChannel.Telegram, source="telegram-test", @@ -164,7 +173,7 @@ class TestAgentInteraction(unittest.TestCase): original_chat_id="456", ) - handle_ai_message.assert_called_once() + self.assertTrue(handled) edit_message.assert_called_once_with( channel=MessageChannel.Telegram, source="telegram-test", @@ -173,9 +182,13 @@ class TestAgentInteraction(unittest.TestCase): title="需要你的选择", text="请选择\n\n已选择:电影", ) - kwargs = handle_ai_message.call_args.kwargs - self.assertEqual(kwargs["text"], "我选择电影") + process_message.assert_called_once() + kwargs = process_message.call_args.kwargs + self.assertEqual(kwargs["message"], "我选择电影") self.assertEqual(kwargs["session_id"], "session-choice") + self.assertEqual(kwargs["channel"], MessageChannel.Telegram.value) + self.assertEqual(kwargs["source"], "telegram-test") + self.assertNotIn("processing_status", kwargs) message_put.assert_called_once() message_add.assert_called_once()