From 6fb6996d81c0f6be3b811b034cad442310fcfe51 Mon Sep 17 00:00:00 2001 From: jxxghp Date: Wed, 13 May 2026 10:11:31 +0800 Subject: [PATCH] fix: ensure stop_streaming waits for inflight initial flush before final edit; improve message edit/delete return types and logging - Update stop_streaming logic to await inflight initial flush task, preventing duplicate message sends on stream stop - Change message edit/delete methods to return Optional[bool] for clearer channel mismatch handling - Refine Feishu logging to include message_id instead of full data object - Suppress allowed_objects warnings in __init__.py - Add test to verify stop_streaming waits for inflight flush before final edit - Update .pylintrc to use 'E' for error enabling --- .pylintrc | 2 +- app/agent/__init__.py | 4 +++ app/agent/callback/__init__.py | 16 ++++++--- app/modules/discord/__init__.py | 10 +++--- app/modules/feishu/__init__.py | 6 ++-- app/modules/feishu/feishu.py | 10 ++++-- app/modules/slack/__init__.py | 10 +++--- app/modules/telegram/__init__.py | 10 +++--- tests/test_agent_tool_streaming.py | 57 ++++++++++++++++++++++++++++++ 9 files changed, 100 insertions(+), 25 deletions(-) diff --git a/.pylintrc b/.pylintrc index f59dd614..ee30bbd8 100644 --- a/.pylintrc +++ b/.pylintrc @@ -18,7 +18,7 @@ jobs=0 # 禁用大部分警告、约定和重构建议,只保留错误和重要警告 disable=all -enable=error, +enable=E, syntax-error, undefined-variable, used-before-assignment, diff --git a/app/agent/__init__.py b/app/agent/__init__.py index 9af5b39a..0db945ca 100644 --- a/app/agent/__init__.py +++ b/app/agent/__init__.py @@ -16,6 +16,10 @@ from langchain_core.messages import ( # noqa: F401 HumanMessage, BaseMessage, ) + +import warnings +warnings.filterwarnings("ignore", message=".*allowed_objects.*") + from langgraph.checkpoint.memory import InMemorySaver from app.agent.callback import StreamingHandler diff --git a/app/agent/callback/__init__.py b/app/agent/callback/__init__.py index cbcb1d10..652ee8b7 100644 --- a/app/agent/callback/__init__.py +++ b/app/agent/callback/__init__.py @@ -437,15 +437,23 @@ class StreamingHandler: async def _cancel_flush_task(self): """ - 取消当前的定时刷新任务 + 停止当前的定时刷新任务。 + + 停止流式输出时,刷新任务可能已经在线程池里发出了首条消息。 + 这里先等待该轮刷新自然完成,确保 message_id 等返回信息能落回本地状态; + 否则最终刷新会误以为尚未发送过消息,从而再次发送一条新消息。 """ - if self._flush_task and not self._flush_task.done(): - self._flush_task.cancel() + current_task = asyncio.current_task() + if ( + self._flush_task + and not self._flush_task.done() + and self._flush_task is not current_task + ): try: await self._flush_task except asyncio.CancelledError: pass - self._flush_task = None + self._flush_task = None async def _flush(self): """ diff --git a/app/modules/discord/__init__.py b/app/modules/discord/__init__.py index 7b7c33a6..a112b776 100644 --- a/app/modules/discord/__init__.py +++ b/app/modules/discord/__init__.py @@ -409,7 +409,7 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]): source: str, message_id: str, chat_id: Optional[str] = None, - ) -> bool: + ) -> Optional[bool]: """ 删除消息 :param channel: 消息渠道 @@ -418,10 +418,10 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]): :param chat_id: 聊天ID(频道ID) :return: 删除是否成功 """ + if channel != self._channel: + return None success = False for conf in self.get_configs().values(): - if channel != self._channel: - break if source != conf.name: continue client: Discord = self.get_instance(conf.name) @@ -441,7 +441,7 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]): title: Optional[str] = None, buttons: Optional[List[List[dict]]] = None, metadata: Optional[dict] = None, - ) -> bool: + ) -> Optional[bool]: """ 编辑消息 :param channel: 消息渠道 @@ -454,7 +454,7 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]): :return: 编辑是否成功 """ if channel != self._channel: - return False + return None for conf in self.get_configs().values(): if source != conf.name: continue diff --git a/app/modules/feishu/__init__.py b/app/modules/feishu/__init__.py index 3d4d55b6..34c69725 100644 --- a/app/modules/feishu/__init__.py +++ b/app/modules/feishu/__init__.py @@ -161,9 +161,9 @@ class FeishuModule(_ModuleBase, _MessageBase[Feishu]): title: Optional[str] = None, buttons: Optional[List[List[dict]]] = None, metadata: Optional[dict] = None, - ) -> bool: + ) -> Optional[bool]: if channel != self._channel: - return False + return None for conf in self.get_configs().values(): if source != conf.name: continue @@ -296,7 +296,7 @@ class FeishuModule(_ModuleBase, _MessageBase[Feishu]): message_id: str, reaction_id: str, source: str, - ) -> bool: + ) -> Optional[bool]: client_config = self.get_config(source) if not client_config: return False diff --git a/app/modules/feishu/feishu.py b/app/modules/feishu/feishu.py index 1be954b1..caedb42b 100644 --- a/app/modules/feishu/feishu.py +++ b/app/modules/feishu/feishu.py @@ -837,7 +837,10 @@ class Feishu: return None data = getattr(response, "data", None) - logger.info(f"_send_message 飞书回复消息成功:data={data}") + logger.info( + "_send_message 飞书回复消息成功:message_id=%s", + getattr(data, "message_id", None), + ) return { "success": True, "message_id": getattr(data, "message_id", None), @@ -880,7 +883,10 @@ class Feishu: return None data = getattr(response, "data", None) - logger.info(f"_reply_message 飞书回复消息成功:data={data}") + logger.info( + "_reply_message 飞书回复消息成功:message_id=%s", + getattr(data, "message_id", None), + ) return { "success": True, "message_id": getattr(data, "message_id", None), diff --git a/app/modules/slack/__init__.py b/app/modules/slack/__init__.py index 5405fafd..90e74796 100644 --- a/app/modules/slack/__init__.py +++ b/app/modules/slack/__init__.py @@ -527,7 +527,7 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]): source: str, message_id: str, chat_id: Optional[str] = None, - ) -> bool: + ) -> Optional[bool]: """ 删除消息 :param channel: 消息渠道 @@ -536,10 +536,10 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]): :param chat_id: 聊天ID(频道ID) :return: 删除是否成功 """ + if channel != self._channel: + return None success = False for conf in self.get_configs().values(): - if channel != self._channel: - break if source != conf.name: continue client: Slack = self.get_instance(conf.name) @@ -559,7 +559,7 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]): title: Optional[str] = None, buttons: Optional[List[List[dict]]] = None, metadata: Optional[dict] = None, - ) -> bool: + ) -> Optional[bool]: """ 编辑消息 :param channel: 消息渠道 @@ -572,7 +572,7 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]): :return: 编辑是否成功 """ if channel != self._channel: - return False + return None for conf in self.get_configs().values(): if source != conf.name: continue diff --git a/app/modules/telegram/__init__.py b/app/modules/telegram/__init__.py index 9810cbc0..183f0675 100644 --- a/app/modules/telegram/__init__.py +++ b/app/modules/telegram/__init__.py @@ -534,7 +534,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): source: str, message_id: int, chat_id: Optional[int] = None, - ) -> bool: + ) -> Optional[bool]: """ 删除消息 :param channel: 消息渠道 @@ -543,10 +543,10 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): :param chat_id: 聊天ID :return: 删除是否成功 """ + if channel != self._channel: + return None success = False for conf in self.get_configs().values(): - if channel != self._channel: - break if source != conf.name: continue client: Telegram = self.get_instance(conf.name) @@ -566,7 +566,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): title: Optional[str] = None, buttons: Optional[List[List[dict]]] = None, metadata: Optional[dict] = None, - ) -> bool: + ) -> Optional[bool]: """ 编辑消息 :param channel: 消息渠道 @@ -579,7 +579,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): :return: 编辑是否成功 """ if channel != self._channel: - return False + return None for conf in self.get_configs().values(): if source != conf.name: continue diff --git a/tests/test_agent_tool_streaming.py b/tests/test_agent_tool_streaming.py index 11fe9dc6..1a13a998 100644 --- a/tests/test_agent_tool_streaming.py +++ b/tests/test_agent_tool_streaming.py @@ -192,6 +192,63 @@ class TestAgentToolStreaming(unittest.TestCase): ) self.assertEqual(handler._sent_text, "hello world") + def test_stop_streaming_waits_inflight_initial_flush_before_final_edit(self): + async def _run(): + handler = StreamingHandler() + handler._channel = MessageChannel.Feishu.value + handler._source = "feishu-main" + handler._user_id = "ou_user" + handler._streaming_enabled = True + handler.emit("hello") + + send_started = asyncio.Event() + allow_send_finish = asyncio.Event() + calls = [] + + async def fake_run_in_threadpool(func, *args, **kwargs): + calls.append((func.__name__, args, kwargs)) + if func.__name__ == "send_direct_message": + send_started.set() + await allow_send_finish.wait() + return MessageResponse( + message_id="om_stream", + chat_id="oc_stream", + channel=MessageChannel.Feishu, + source="feishu-main", + success=True, + ) + return True + + with patch( + "app.agent.callback.run_in_threadpool", + new=fake_run_in_threadpool, + ): + # 模拟定时刷新已经开始发送首条消息,但飞书 API 尚未返回。 + handler._flush_task = asyncio.create_task(handler._flush()) + await send_started.wait() + handler.emit(" world") + + stop_task = asyncio.create_task(handler.stop_streaming()) + await asyncio.sleep(0) + self.assertFalse(stop_task.done()) + + allow_send_finish.set() + all_sent, final_text = await stop_task + + return all_sent, final_text, calls + + all_sent, final_text, calls = asyncio.run(_run()) + + self.assertTrue(all_sent) + self.assertEqual(final_text, "hello world") + self.assertEqual( + [call[0] for call in calls], + ["send_direct_message", "edit_message", "finalize_message"], + ) + edit_kwargs = calls[1][2] + self.assertEqual(edit_kwargs["message_id"], "om_stream") + self.assertEqual(edit_kwargs["text"], "hello world") + def test_stop_streaming_uses_generic_finalize_message(self): handler = StreamingHandler() handler._message_response = MessageResponse(