From baebd0ed1a2b190f3836822c1351d570dce62478 Mon Sep 17 00:00:00 2001 From: jxxghp Date: Thu, 30 Apr 2026 06:58:43 +0800 Subject: [PATCH] Fix background prompt message leakage --- app/agent/__init__.py | 6 +++- app/agent/callback/__init__.py | 3 ++ app/agent/tools/base.py | 28 +++++++++------ app/agent/tools/factory.py | 23 +++--------- app/chain/message.py | 2 ++ tests/test_agent_background_output.py | 26 ++++++++++++++ tests/test_agent_tool_streaming.py | 50 +++++++++++++++++++++++---- 7 files changed, 102 insertions(+), 36 deletions(-) diff --git a/app/agent/__init__.py b/app/agent/__init__.py index 4a1db4d0..d2be4ed0 100644 --- a/app/agent/__init__.py +++ b/app/agent/__init__.py @@ -1029,7 +1029,7 @@ class AgentManager: output_callback: Optional[Callable[[str], None]] = None, reply_mode: ReplyMode = ReplyMode.CAPTURE_ONLY, persist_output_message: bool = True, - allow_message_tools: bool = True, + allow_message_tools: Optional[bool] = None, ) -> None: """ 以独立后台会话执行一段 prompt。 @@ -1047,6 +1047,10 @@ class AgentManager: agent.force_streaming = bool(output_callback) agent.reply_mode = reply_mode agent.persist_output_message = persist_output_message + if reply_mode == ReplyMode.CAPTURE_ONLY: + allow_message_tools = False + elif allow_message_tools is None: + allow_message_tools = True agent.allow_message_tools = allow_message_tools try: diff --git a/app/agent/callback/__init__.py b/app/agent/callback/__init__.py index 11102d4b..3bfede67 100644 --- a/app/agent/callback/__init__.py +++ b/app/agent/callback/__init__.py @@ -435,6 +435,9 @@ class StreamingHandler: if not current_text or current_text == self._sent_text: # 没有新内容需要刷新 return + if not self._channel or not self._source: + logger.debug("流式输出缺少渠道上下文,仅保留 buffer,不外发消息") + return chain = _StreamChain() diff --git a/app/agent/tools/base.py b/app/agent/tools/base.py index 17757903..a40a69b3 100644 --- a/app/agent/tools/base.py +++ b/app/agent/tools/base.py @@ -113,16 +113,24 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta): if tool_message: self._stream_handler.emit(f"\n\n⚙️ => {tool_message}\n\n") else: - # 渠道不支持编辑:取出 Agent 文字 + 工具消息合并独立发送 - agent_message = await self._stream_handler.take() - messages = [] - if agent_message: - messages.append(agent_message) - if tool_message: - messages.append(f"⚙️ => {tool_message}") - if messages: - merged_message = "\n\n".join(messages) - await self.send_tool_message(merged_message) + if self._channel and self._source: + # 渠道不支持编辑:取出 Agent 文字 + 工具消息合并独立发送 + agent_message = await self._stream_handler.take() + messages = [] + if agent_message: + messages.append(agent_message) + if tool_message: + messages.append(f"⚙️ => {tool_message}") + if messages: + merged_message = "\n\n".join(messages) + await self.send_tool_message(merged_message) + else: + # 后台 capture 流程没有渠道上下文,不能把工具提示回灌到默认通知渠道。 + self._stream_handler.record_tool_call( + tool_name=self.name, + tool_message=tool_message, + tool_kwargs=kwargs, + ) else: # 非VERBOSE:不逐条回显工具调用,转为在下一段文本前补一句聚合摘要 self._stream_handler.record_tool_call( diff --git a/app/agent/tools/factory.py b/app/agent/tools/factory.py index 51f6a177..c3d90a93 100644 --- a/app/agent/tools/factory.py +++ b/app/agent/tools/factory.py @@ -79,15 +79,6 @@ class MoviePilotToolFactory: MoviePilot工具工厂 """ - _MESSAGE_TOOL_CLASSES = frozenset( - { - SendMessageTool, - AskUserChoiceTool, - SendLocalFileTool, - SendVoiceMessageTool, - } - ) - @staticmethod def _should_enable_choice_tool(channel: str = None) -> bool: if not channel: @@ -191,12 +182,9 @@ class MoviePilotToolFactory: ) # 创建内置工具 for ToolClass in tool_definitions: - if ( - not allow_message_tools - and ToolClass in MoviePilotToolFactory._MESSAGE_TOOL_CLASSES - ): - continue tool = ToolClass(session_id=session_id, user_id=user_id) + if not allow_message_tools and getattr(tool, "sends_message", False): + continue tool.set_message_attr(channel=channel, source=source, username=username) tool.set_stream_handler(stream_handler=stream_handler) tool.set_agent_context(agent_context=agent_context) @@ -211,11 +199,6 @@ class MoviePilotToolFactory: tool_classes = plugin_info.get("tools", []) for ToolClass in tool_classes: try: - if ( - not allow_message_tools - and ToolClass in MoviePilotToolFactory._MESSAGE_TOOL_CLASSES - ): - continue # 验证工具类是否继承自 MoviePilotTool if not issubclass(ToolClass, MoviePilotTool): logger.warning( @@ -224,6 +207,8 @@ class MoviePilotToolFactory: continue # 创建工具实例 tool = ToolClass(session_id=session_id, user_id=user_id) + if not allow_message_tools and getattr(tool, "sends_message", False): + continue tool.set_message_attr( channel=channel, source=source, username=username ) diff --git a/app/chain/message.py b/app/chain/message.py index 08120c37..ec070284 100644 --- a/app/chain/message.py +++ b/app/chain/message.py @@ -636,6 +636,8 @@ class MessageChain(ChainBase): session_prefix=f"__agent_manual_redo_{history_id}", output_callback=_capture_output, reply_mode=ReplyMode.CAPTURE_ONLY, + persist_output_message=False, + allow_message_tools=False, ) await self.async_post_message( Notification( diff --git a/tests/test_agent_background_output.py b/tests/test_agent_background_output.py index ff7074e8..a32f7cb3 100644 --- a/tests/test_agent_background_output.py +++ b/tests/test_agent_background_output.py @@ -6,6 +6,7 @@ from langchain_core.messages import AIMessage from app.agent import MoviePilotAgent, AgentManager, ReplyMode from app.agent.memory import memory_manager +from app.utils.identity import SYSTEM_INTERNAL_USER_ID class _FakeGraphState: @@ -124,6 +125,31 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase): process_message.await_args.kwargs["reply_mode"], ) + async def test_run_background_prompt_forces_disable_message_tools_when_capture_only(self): + captured = {} + + async def fake_process(self, message, images=None, files=None): + captured["message"] = message + captured["reply_mode"] = self.reply_mode + captured["allow_message_tools"] = self.allow_message_tools + captured["user_id"] = self.user_id + + with ( + patch.object(MoviePilotAgent, "process", new=fake_process), + patch.object(MoviePilotAgent, "cleanup", new=AsyncMock()), + patch.object(memory_manager, "clear_memory"), + ): + await AgentManager.run_background_prompt( + message="background task", + reply_mode=ReplyMode.CAPTURE_ONLY, + allow_message_tools=True, + ) + + self.assertEqual("background task", captured["message"]) + self.assertEqual(ReplyMode.CAPTURE_ONLY, captured["reply_mode"]) + self.assertFalse(captured["allow_message_tools"]) + self.assertEqual(SYSTEM_INTERNAL_USER_ID, captured["user_id"]) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_agent_tool_streaming.py b/tests/test_agent_tool_streaming.py index 755fba04..617ddeb9 100644 --- a/tests/test_agent_tool_streaming.py +++ b/tests/test_agent_tool_streaming.py @@ -42,19 +42,19 @@ class TestAgentToolStreaming(unittest.TestCase): result, buffered_message = asyncio.run(self._run_tool("prefix")) self.assertEqual(result, "ok") - self.assertEqual(buffered_message, "prefix\n(调用了 1 次工具)\n") + self.assertEqual(buffered_message, "prefix\n\n(调用了 1 次工具)\n\n") def test_non_verbose_tool_call_reuses_existing_newline_before_summary(self): result, buffered_message = asyncio.run(self._run_tool("prefix\n")) self.assertEqual(result, "ok") - self.assertEqual(buffered_message, "prefix\n(调用了 1 次工具)\n") + self.assertEqual(buffered_message, "prefix\n(调用了 1 次工具)\n\n") def test_non_verbose_tool_call_emits_summary_even_when_buffer_was_empty(self): result, buffered_message = asyncio.run(self._run_tool("")) self.assertEqual(result, "ok") - self.assertEqual(buffered_message, "(调用了 1 次工具)\n") + self.assertEqual(buffered_message, "(调用了 1 次工具)\n\n") def test_non_verbose_tool_summary_is_inserted_before_next_text(self): async def _run(): @@ -74,7 +74,7 @@ class TestAgentToolStreaming(unittest.TestCase): self.assertEqual( buffered_message, - "让我来检查一下:\n(调用了 1 次工具)\n已经拿到结果", + "让我来检查一下:\n\n(调用了 1 次工具)\n\n已经拿到结果", ) def test_non_verbose_tool_summary_aggregates_multiple_categories(self): @@ -109,7 +109,7 @@ class TestAgentToolStreaming(unittest.TestCase): self.assertEqual( buffered_message, - "处理中:\n(执行了 2 次搜索,读取了 2 个文件)\n继续分析", + "处理中:\n\n(执行了 2 次搜索,读取了 2 个文件)\n\n继续分析", ) def test_openai_streaming_handler_flushes_pending_summary_to_queue(self): @@ -130,7 +130,7 @@ class TestAgentToolStreaming(unittest.TestCase): emitted, queued, buffered_message = asyncio.run(_run()) - self.assertEqual(emitted, "(读取了 1 个文件)\n") + self.assertEqual(emitted, "(读取了 1 个文件)\n\n") self.assertEqual(queued, emitted) self.assertEqual(buffered_message, emitted) @@ -164,6 +164,7 @@ class TestAgentToolStreaming(unittest.TestCase): def test_flush_edits_message_via_threadpool(self): handler = StreamingHandler() handler._channel = MessageChannel.Telegram.value + handler._source = "telegram" handler._streaming_enabled = True handler._message_response = MessageResponse( message_id=1, @@ -187,6 +188,43 @@ class TestAgentToolStreaming(unittest.TestCase): ) self.assertEqual(handler._sent_text, "hello world") + def test_flush_without_channel_context_does_not_send_direct_message(self): + handler = StreamingHandler() + handler._streaming_enabled = True + handler.emit("hello") + + with patch( + "app.agent.callback.run_in_threadpool", new_callable=AsyncMock + ) as run_in_threadpool_mock: + asyncio.run(handler._flush()) + + run_in_threadpool_mock.assert_not_awaited() + self.assertFalse(handler.has_sent_message) + + def test_verbose_background_tool_call_does_not_post_message(self): + async def _run(): + tool = DummyTool(session_id="session-1", user_id="10001") + handler = StreamingHandler() + await handler.start_streaming() + tool.set_stream_handler(handler) + tool.set_message_attr(channel=None, source=None, username="tester") + + with ( + patch.object(settings, "AI_AGENT_VERBOSE", True), + patch.object( + DummyTool, "send_tool_message", new_callable=AsyncMock + ) as send_tool_message, + ): + result = await tool._arun(explanation="run test tool") + buffered_message = await handler.take() + return result, buffered_message, send_tool_message + + result, buffered_message, send_tool_message = asyncio.run(_run()) + + self.assertEqual(result, "ok") + send_tool_message.assert_not_awaited() + self.assertEqual(buffered_message, "(调用了 1 次工具)\n\n") + if __name__ == "__main__": unittest.main()