diff --git a/app/agent/callback/__init__.py b/app/agent/callback/__init__.py index b69ce501..b2a378d6 100644 --- a/app/agent/callback/__init__.py +++ b/app/agent/callback/__init__.py @@ -12,7 +12,7 @@ from app.schemas.message import ( ChannelCapabilityManager, ChannelCapability, ) -from app.schemas.types import MessageChannel +from app.schemas.types import MessageChannel, NotificationType class _StreamChain(ChainBase): @@ -212,22 +212,11 @@ class StreamingHandler: await self._flush() message_response = self._message_response - if ( - message_response - and message_response.channel == MessageChannel.Feishu - and isinstance(message_response.metadata, dict) - ): - stream_meta = message_response.metadata.get("feishu_streaming") or {} - card_id = str(stream_meta.get("card_id") or "").strip() - sequence = int(stream_meta.get("sequence") or 1) + 1 - if card_id: - await run_in_threadpool( - _StreamChain().run_module, - "close_feishu_streaming_card", - card_id=card_id, - sequence=sequence, - source=message_response.source, - ) + if message_response: + await run_in_threadpool( + _StreamChain().finalize_message, + message_response, + ) # 检查是否所有缓冲内容都已发送 with self._lock: @@ -480,6 +469,7 @@ class StreamingHandler: Notification( channel=self._channel, source=self._source, + mtype=NotificationType.Agent, userid=self._user_id, username=self._username, title=self._title, @@ -522,6 +512,7 @@ class StreamingHandler: Notification( channel=self._channel, source=self._source, + mtype=NotificationType.Agent, userid=self._user_id, username=self._username, title=self._title, diff --git a/app/chain/__init__.py b/app/chain/__init__.py index 3a50bbee..1690b1a0 100644 --- a/app/chain/__init__.py +++ b/app/chain/__init__.py @@ -1585,6 +1585,16 @@ class ChainBase(metaclass=ABCMeta): message=self._normalize_notification_for_dispatch(message), ) + def finalize_message( + self, + response: MessageResponse, + ) -> bool: + """ + 对已发送消息执行渠道收尾动作。 + 例如关闭流式卡片状态;无特殊收尾的渠道直接返回 False。 + """ + return self.run_module("finalize_message", response=response) + def metadata_img( self, mediainfo: MediaInfo, diff --git a/app/modules/feishu/__init__.py b/app/modules/feishu/__init__.py index 320b85af..0867a6b4 100644 --- a/app/modules/feishu/__init__.py +++ b/app/modules/feishu/__init__.py @@ -305,16 +305,18 @@ class FeishuModule(_ModuleBase, _MessageBase[Feishu]): return False return client.delete_message_reaction(message_id=message_id, reaction_id=reaction_id) - def close_feishu_streaming_card( - self, - card_id: str, - sequence: int, - source: str, - ) -> bool: - client_config = self.get_config(source) + def finalize_message(self, response: MessageResponse) -> bool: + if response.channel != self._channel or not isinstance(response.metadata, dict): + return False + stream_meta = response.metadata.get("feishu_streaming") or {} + card_id = str(stream_meta.get("card_id") or "").strip() + if not card_id: + return False + client_config = self.get_config(response.source) if not client_config: return False client = self.get_instance(client_config.name) if not client: return False + sequence = int(stream_meta.get("sequence") or 1) + 1 return client.close_streaming_card(card_id=card_id, sequence=sequence) diff --git a/tests/test_agent_tool_streaming.py b/tests/test_agent_tool_streaming.py index de427019..2e8f9631 100644 --- a/tests/test_agent_tool_streaming.py +++ b/tests/test_agent_tool_streaming.py @@ -12,7 +12,7 @@ from app.agent.tools.base import MoviePilotTool from app.api.endpoints.openai import _OpenAIStreamingHandler from app.core.config import settings from app.schemas.message import MessageResponse -from app.schemas.types import MessageChannel +from app.schemas.types import MessageChannel, NotificationType class DummyTool(MoviePilotTool): @@ -159,6 +159,10 @@ class TestAgentToolStreaming(unittest.TestCase): self.assertEqual( run_in_threadpool_mock.await_args.args[0].__name__, "send_direct_message" ) + self.assertEqual( + run_in_threadpool_mock.await_args.args[1].mtype, + NotificationType.Agent, + ) self.assertTrue(handler.has_sent_message) def test_flush_edits_message_via_threadpool(self): @@ -188,6 +192,38 @@ class TestAgentToolStreaming(unittest.TestCase): ) self.assertEqual(handler._sent_text, "hello world") + def test_stop_streaming_uses_generic_finalize_message(self): + handler = StreamingHandler() + handler._message_response = MessageResponse( + message_id="om_stream", + chat_id="oc_stream", + channel=MessageChannel.Feishu, + source="feishu-main", + metadata={"feishu_streaming": {"card_id": "card_stream", "sequence": 2}}, + success=True, + ) + handler._sent_text = "hello" + handler._buffer = "hello" + handler._streaming_enabled = True + + with patch( + "app.agent.callback.run_in_threadpool", new_callable=AsyncMock + ) as run_in_threadpool_mock, patch.object( + handler, "_cancel_flush_task", new_callable=AsyncMock + ), patch.object( + handler, "_flush", new_callable=AsyncMock + ): + asyncio.run(handler.stop_streaming()) + + self.assertEqual(run_in_threadpool_mock.await_count, 1) + self.assertEqual( + run_in_threadpool_mock.await_args.args[0].__name__, "finalize_message" + ) + self.assertEqual( + run_in_threadpool_mock.await_args.args[1].message_id, + "om_stream", + ) + def test_flush_without_channel_context_does_not_send_direct_message(self): handler = StreamingHandler() handler._streaming_enabled = True diff --git a/tests/test_feishu.py b/tests/test_feishu.py index 4eb58c7a..b869fa14 100644 --- a/tests/test_feishu.py +++ b/tests/test_feishu.py @@ -21,7 +21,7 @@ if "Pinyin2Hanzi" not in sys.modules: from app.modules.feishu import FeishuModule from app.modules.feishu.feishu import Feishu from app.schemas import Notification -from app.schemas.message import ChannelCapability, ChannelCapabilityManager +from app.schemas.message import ChannelCapability, ChannelCapabilityManager, MessageResponse from app.schemas.types import MessageChannel, NotificationType @@ -586,15 +586,30 @@ class TestFeishu(unittest.TestCase): self.assertEqual(reaction_id, "reaction_2") self.assertTrue(deleted) - def test_module_close_streaming_card_delegates_to_client(self): + def test_module_finalize_message_closes_streaming_card(self): module = FeishuModule() + module._channel = MessageChannel.Feishu client = MagicMock() client.close_streaming_card.return_value = True with patch.object(module, "get_config", return_value=SimpleNamespace(name="feishu-main")), patch.object( module, "get_instance", return_value=client ): - success = module.close_feishu_streaming_card("card_stream", 4, "feishu-main") + success = module.finalize_message( + MessageResponse( + message_id="om_stream", + chat_id="oc_stream", + channel=MessageChannel.Feishu, + source="feishu-main", + metadata={ + "feishu_streaming": { + "card_id": "card_stream", + "sequence": 3, + } + }, + success=True, + ) + ) self.assertTrue(success) client.close_streaming_card.assert_called_once_with(card_id="card_stream", sequence=4)