Fix background prompt message leakage

This commit is contained in:
jxxghp
2026-04-30 06:58:43 +08:00
parent 6532c60a3c
commit baebd0ed1a
7 changed files with 102 additions and 36 deletions

View File

@@ -1029,7 +1029,7 @@ class AgentManager:
output_callback: Optional[Callable[[str], None]] = None,
reply_mode: ReplyMode = ReplyMode.CAPTURE_ONLY,
persist_output_message: bool = True,
allow_message_tools: bool = True,
allow_message_tools: Optional[bool] = None,
) -> None:
"""
以独立后台会话执行一段 prompt。
@@ -1047,6 +1047,10 @@ class AgentManager:
agent.force_streaming = bool(output_callback)
agent.reply_mode = reply_mode
agent.persist_output_message = persist_output_message
if reply_mode == ReplyMode.CAPTURE_ONLY:
allow_message_tools = False
elif allow_message_tools is None:
allow_message_tools = True
agent.allow_message_tools = allow_message_tools
try:

View File

@@ -435,6 +435,9 @@ class StreamingHandler:
if not current_text or current_text == self._sent_text:
# 没有新内容需要刷新
return
if not self._channel or not self._source:
logger.debug("流式输出缺少渠道上下文,仅保留 buffer不外发消息")
return
chain = _StreamChain()

View File

@@ -113,16 +113,24 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
if tool_message:
self._stream_handler.emit(f"\n\n⚙️ => {tool_message}\n\n")
else:
# 渠道不支持编辑:取出 Agent 文字 + 工具消息合并独立发送
agent_message = await self._stream_handler.take()
messages = []
if agent_message:
messages.append(agent_message)
if tool_message:
messages.append(f"⚙️ => {tool_message}")
if messages:
merged_message = "\n\n".join(messages)
await self.send_tool_message(merged_message)
if self._channel and self._source:
# 渠道不支持编辑:取出 Agent 文字 + 工具消息合并独立发送
agent_message = await self._stream_handler.take()
messages = []
if agent_message:
messages.append(agent_message)
if tool_message:
messages.append(f"⚙️ => {tool_message}")
if messages:
merged_message = "\n\n".join(messages)
await self.send_tool_message(merged_message)
else:
# 后台 capture 流程没有渠道上下文,不能把工具提示回灌到默认通知渠道。
self._stream_handler.record_tool_call(
tool_name=self.name,
tool_message=tool_message,
tool_kwargs=kwargs,
)
else:
# 非VERBOSE不逐条回显工具调用转为在下一段文本前补一句聚合摘要
self._stream_handler.record_tool_call(

View File

@@ -79,15 +79,6 @@ class MoviePilotToolFactory:
MoviePilot工具工厂
"""
_MESSAGE_TOOL_CLASSES = frozenset(
{
SendMessageTool,
AskUserChoiceTool,
SendLocalFileTool,
SendVoiceMessageTool,
}
)
@staticmethod
def _should_enable_choice_tool(channel: str = None) -> bool:
if not channel:
@@ -191,12 +182,9 @@ class MoviePilotToolFactory:
)
# 创建内置工具
for ToolClass in tool_definitions:
if (
not allow_message_tools
and ToolClass in MoviePilotToolFactory._MESSAGE_TOOL_CLASSES
):
continue
tool = ToolClass(session_id=session_id, user_id=user_id)
if not allow_message_tools and getattr(tool, "sends_message", False):
continue
tool.set_message_attr(channel=channel, source=source, username=username)
tool.set_stream_handler(stream_handler=stream_handler)
tool.set_agent_context(agent_context=agent_context)
@@ -211,11 +199,6 @@ class MoviePilotToolFactory:
tool_classes = plugin_info.get("tools", [])
for ToolClass in tool_classes:
try:
if (
not allow_message_tools
and ToolClass in MoviePilotToolFactory._MESSAGE_TOOL_CLASSES
):
continue
# 验证工具类是否继承自 MoviePilotTool
if not issubclass(ToolClass, MoviePilotTool):
logger.warning(
@@ -224,6 +207,8 @@ class MoviePilotToolFactory:
continue
# 创建工具实例
tool = ToolClass(session_id=session_id, user_id=user_id)
if not allow_message_tools and getattr(tool, "sends_message", False):
continue
tool.set_message_attr(
channel=channel, source=source, username=username
)

View File

@@ -636,6 +636,8 @@ class MessageChain(ChainBase):
session_prefix=f"__agent_manual_redo_{history_id}",
output_callback=_capture_output,
reply_mode=ReplyMode.CAPTURE_ONLY,
persist_output_message=False,
allow_message_tools=False,
)
await self.async_post_message(
Notification(

View File

@@ -6,6 +6,7 @@ from langchain_core.messages import AIMessage
from app.agent import MoviePilotAgent, AgentManager, ReplyMode
from app.agent.memory import memory_manager
from app.utils.identity import SYSTEM_INTERNAL_USER_ID
class _FakeGraphState:
@@ -124,6 +125,31 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase):
process_message.await_args.kwargs["reply_mode"],
)
async def test_run_background_prompt_forces_disable_message_tools_when_capture_only(self):
captured = {}
async def fake_process(self, message, images=None, files=None):
captured["message"] = message
captured["reply_mode"] = self.reply_mode
captured["allow_message_tools"] = self.allow_message_tools
captured["user_id"] = self.user_id
with (
patch.object(MoviePilotAgent, "process", new=fake_process),
patch.object(MoviePilotAgent, "cleanup", new=AsyncMock()),
patch.object(memory_manager, "clear_memory"),
):
await AgentManager.run_background_prompt(
message="background task",
reply_mode=ReplyMode.CAPTURE_ONLY,
allow_message_tools=True,
)
self.assertEqual("background task", captured["message"])
self.assertEqual(ReplyMode.CAPTURE_ONLY, captured["reply_mode"])
self.assertFalse(captured["allow_message_tools"])
self.assertEqual(SYSTEM_INTERNAL_USER_ID, captured["user_id"])
if __name__ == "__main__":
unittest.main()

View File

@@ -42,19 +42,19 @@ class TestAgentToolStreaming(unittest.TestCase):
result, buffered_message = asyncio.run(self._run_tool("prefix"))
self.assertEqual(result, "ok")
self.assertEqual(buffered_message, "prefix\n(调用了 1 次工具)\n")
self.assertEqual(buffered_message, "prefix\n\n(调用了 1 次工具)\n\n")
def test_non_verbose_tool_call_reuses_existing_newline_before_summary(self):
result, buffered_message = asyncio.run(self._run_tool("prefix\n"))
self.assertEqual(result, "ok")
self.assertEqual(buffered_message, "prefix\n(调用了 1 次工具)\n")
self.assertEqual(buffered_message, "prefix\n(调用了 1 次工具)\n\n")
def test_non_verbose_tool_call_emits_summary_even_when_buffer_was_empty(self):
result, buffered_message = asyncio.run(self._run_tool(""))
self.assertEqual(result, "ok")
self.assertEqual(buffered_message, "(调用了 1 次工具)\n")
self.assertEqual(buffered_message, "(调用了 1 次工具)\n\n")
def test_non_verbose_tool_summary_is_inserted_before_next_text(self):
async def _run():
@@ -74,7 +74,7 @@ class TestAgentToolStreaming(unittest.TestCase):
self.assertEqual(
buffered_message,
"让我来检查一下:\n(调用了 1 次工具)\n已经拿到结果",
"让我来检查一下:\n\n(调用了 1 次工具)\n\n已经拿到结果",
)
def test_non_verbose_tool_summary_aggregates_multiple_categories(self):
@@ -109,7 +109,7 @@ class TestAgentToolStreaming(unittest.TestCase):
self.assertEqual(
buffered_message,
"处理中:\n(执行了 2 次搜索,读取了 2 个文件)\n继续分析",
"处理中:\n\n(执行了 2 次搜索,读取了 2 个文件)\n\n继续分析",
)
def test_openai_streaming_handler_flushes_pending_summary_to_queue(self):
@@ -130,7 +130,7 @@ class TestAgentToolStreaming(unittest.TestCase):
emitted, queued, buffered_message = asyncio.run(_run())
self.assertEqual(emitted, "(读取了 1 个文件)\n")
self.assertEqual(emitted, "(读取了 1 个文件)\n\n")
self.assertEqual(queued, emitted)
self.assertEqual(buffered_message, emitted)
@@ -164,6 +164,7 @@ class TestAgentToolStreaming(unittest.TestCase):
def test_flush_edits_message_via_threadpool(self):
handler = StreamingHandler()
handler._channel = MessageChannel.Telegram.value
handler._source = "telegram"
handler._streaming_enabled = True
handler._message_response = MessageResponse(
message_id=1,
@@ -187,6 +188,43 @@ class TestAgentToolStreaming(unittest.TestCase):
)
self.assertEqual(handler._sent_text, "hello world")
def test_flush_without_channel_context_does_not_send_direct_message(self):
handler = StreamingHandler()
handler._streaming_enabled = True
handler.emit("hello")
with patch(
"app.agent.callback.run_in_threadpool", new_callable=AsyncMock
) as run_in_threadpool_mock:
asyncio.run(handler._flush())
run_in_threadpool_mock.assert_not_awaited()
self.assertFalse(handler.has_sent_message)
def test_verbose_background_tool_call_does_not_post_message(self):
async def _run():
tool = DummyTool(session_id="session-1", user_id="10001")
handler = StreamingHandler()
await handler.start_streaming()
tool.set_stream_handler(handler)
tool.set_message_attr(channel=None, source=None, username="tester")
with (
patch.object(settings, "AI_AGENT_VERBOSE", True),
patch.object(
DummyTool, "send_tool_message", new_callable=AsyncMock
) as send_tool_message,
):
result = await tool._arun(explanation="run test tool")
buffered_message = await handler.take()
return result, buffered_message, send_tool_message
result, buffered_message, send_tool_message = asyncio.run(_run())
self.assertEqual(result, "ok")
send_tool_message.assert_not_awaited()
self.assertEqual(buffered_message, "(调用了 1 次工具)\n\n")
if __name__ == "__main__":
unittest.main()