mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-13 07:26:45 +00:00
Fix background prompt message leakage
This commit is contained in:
@@ -1029,7 +1029,7 @@ class AgentManager:
|
||||
output_callback: Optional[Callable[[str], None]] = None,
|
||||
reply_mode: ReplyMode = ReplyMode.CAPTURE_ONLY,
|
||||
persist_output_message: bool = True,
|
||||
allow_message_tools: bool = True,
|
||||
allow_message_tools: Optional[bool] = None,
|
||||
) -> None:
|
||||
"""
|
||||
以独立后台会话执行一段 prompt。
|
||||
@@ -1047,6 +1047,10 @@ class AgentManager:
|
||||
agent.force_streaming = bool(output_callback)
|
||||
agent.reply_mode = reply_mode
|
||||
agent.persist_output_message = persist_output_message
|
||||
if reply_mode == ReplyMode.CAPTURE_ONLY:
|
||||
allow_message_tools = False
|
||||
elif allow_message_tools is None:
|
||||
allow_message_tools = True
|
||||
agent.allow_message_tools = allow_message_tools
|
||||
|
||||
try:
|
||||
|
||||
@@ -435,6 +435,9 @@ class StreamingHandler:
|
||||
if not current_text or current_text == self._sent_text:
|
||||
# 没有新内容需要刷新
|
||||
return
|
||||
if not self._channel or not self._source:
|
||||
logger.debug("流式输出缺少渠道上下文,仅保留 buffer,不外发消息")
|
||||
return
|
||||
|
||||
chain = _StreamChain()
|
||||
|
||||
|
||||
@@ -113,16 +113,24 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
if tool_message:
|
||||
self._stream_handler.emit(f"\n\n⚙️ => {tool_message}\n\n")
|
||||
else:
|
||||
# 渠道不支持编辑:取出 Agent 文字 + 工具消息合并独立发送
|
||||
agent_message = await self._stream_handler.take()
|
||||
messages = []
|
||||
if agent_message:
|
||||
messages.append(agent_message)
|
||||
if tool_message:
|
||||
messages.append(f"⚙️ => {tool_message}")
|
||||
if messages:
|
||||
merged_message = "\n\n".join(messages)
|
||||
await self.send_tool_message(merged_message)
|
||||
if self._channel and self._source:
|
||||
# 渠道不支持编辑:取出 Agent 文字 + 工具消息合并独立发送
|
||||
agent_message = await self._stream_handler.take()
|
||||
messages = []
|
||||
if agent_message:
|
||||
messages.append(agent_message)
|
||||
if tool_message:
|
||||
messages.append(f"⚙️ => {tool_message}")
|
||||
if messages:
|
||||
merged_message = "\n\n".join(messages)
|
||||
await self.send_tool_message(merged_message)
|
||||
else:
|
||||
# 后台 capture 流程没有渠道上下文,不能把工具提示回灌到默认通知渠道。
|
||||
self._stream_handler.record_tool_call(
|
||||
tool_name=self.name,
|
||||
tool_message=tool_message,
|
||||
tool_kwargs=kwargs,
|
||||
)
|
||||
else:
|
||||
# 非VERBOSE:不逐条回显工具调用,转为在下一段文本前补一句聚合摘要
|
||||
self._stream_handler.record_tool_call(
|
||||
|
||||
@@ -79,15 +79,6 @@ class MoviePilotToolFactory:
|
||||
MoviePilot工具工厂
|
||||
"""
|
||||
|
||||
_MESSAGE_TOOL_CLASSES = frozenset(
|
||||
{
|
||||
SendMessageTool,
|
||||
AskUserChoiceTool,
|
||||
SendLocalFileTool,
|
||||
SendVoiceMessageTool,
|
||||
}
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _should_enable_choice_tool(channel: str = None) -> bool:
|
||||
if not channel:
|
||||
@@ -191,12 +182,9 @@ class MoviePilotToolFactory:
|
||||
)
|
||||
# 创建内置工具
|
||||
for ToolClass in tool_definitions:
|
||||
if (
|
||||
not allow_message_tools
|
||||
and ToolClass in MoviePilotToolFactory._MESSAGE_TOOL_CLASSES
|
||||
):
|
||||
continue
|
||||
tool = ToolClass(session_id=session_id, user_id=user_id)
|
||||
if not allow_message_tools and getattr(tool, "sends_message", False):
|
||||
continue
|
||||
tool.set_message_attr(channel=channel, source=source, username=username)
|
||||
tool.set_stream_handler(stream_handler=stream_handler)
|
||||
tool.set_agent_context(agent_context=agent_context)
|
||||
@@ -211,11 +199,6 @@ class MoviePilotToolFactory:
|
||||
tool_classes = plugin_info.get("tools", [])
|
||||
for ToolClass in tool_classes:
|
||||
try:
|
||||
if (
|
||||
not allow_message_tools
|
||||
and ToolClass in MoviePilotToolFactory._MESSAGE_TOOL_CLASSES
|
||||
):
|
||||
continue
|
||||
# 验证工具类是否继承自 MoviePilotTool
|
||||
if not issubclass(ToolClass, MoviePilotTool):
|
||||
logger.warning(
|
||||
@@ -224,6 +207,8 @@ class MoviePilotToolFactory:
|
||||
continue
|
||||
# 创建工具实例
|
||||
tool = ToolClass(session_id=session_id, user_id=user_id)
|
||||
if not allow_message_tools and getattr(tool, "sends_message", False):
|
||||
continue
|
||||
tool.set_message_attr(
|
||||
channel=channel, source=source, username=username
|
||||
)
|
||||
|
||||
@@ -636,6 +636,8 @@ class MessageChain(ChainBase):
|
||||
session_prefix=f"__agent_manual_redo_{history_id}",
|
||||
output_callback=_capture_output,
|
||||
reply_mode=ReplyMode.CAPTURE_ONLY,
|
||||
persist_output_message=False,
|
||||
allow_message_tools=False,
|
||||
)
|
||||
await self.async_post_message(
|
||||
Notification(
|
||||
|
||||
@@ -6,6 +6,7 @@ from langchain_core.messages import AIMessage
|
||||
|
||||
from app.agent import MoviePilotAgent, AgentManager, ReplyMode
|
||||
from app.agent.memory import memory_manager
|
||||
from app.utils.identity import SYSTEM_INTERNAL_USER_ID
|
||||
|
||||
|
||||
class _FakeGraphState:
|
||||
@@ -124,6 +125,31 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase):
|
||||
process_message.await_args.kwargs["reply_mode"],
|
||||
)
|
||||
|
||||
async def test_run_background_prompt_forces_disable_message_tools_when_capture_only(self):
|
||||
captured = {}
|
||||
|
||||
async def fake_process(self, message, images=None, files=None):
|
||||
captured["message"] = message
|
||||
captured["reply_mode"] = self.reply_mode
|
||||
captured["allow_message_tools"] = self.allow_message_tools
|
||||
captured["user_id"] = self.user_id
|
||||
|
||||
with (
|
||||
patch.object(MoviePilotAgent, "process", new=fake_process),
|
||||
patch.object(MoviePilotAgent, "cleanup", new=AsyncMock()),
|
||||
patch.object(memory_manager, "clear_memory"),
|
||||
):
|
||||
await AgentManager.run_background_prompt(
|
||||
message="background task",
|
||||
reply_mode=ReplyMode.CAPTURE_ONLY,
|
||||
allow_message_tools=True,
|
||||
)
|
||||
|
||||
self.assertEqual("background task", captured["message"])
|
||||
self.assertEqual(ReplyMode.CAPTURE_ONLY, captured["reply_mode"])
|
||||
self.assertFalse(captured["allow_message_tools"])
|
||||
self.assertEqual(SYSTEM_INTERNAL_USER_ID, captured["user_id"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -42,19 +42,19 @@ class TestAgentToolStreaming(unittest.TestCase):
|
||||
result, buffered_message = asyncio.run(self._run_tool("prefix"))
|
||||
|
||||
self.assertEqual(result, "ok")
|
||||
self.assertEqual(buffered_message, "prefix\n(调用了 1 次工具)\n")
|
||||
self.assertEqual(buffered_message, "prefix\n\n(调用了 1 次工具)\n\n")
|
||||
|
||||
def test_non_verbose_tool_call_reuses_existing_newline_before_summary(self):
|
||||
result, buffered_message = asyncio.run(self._run_tool("prefix\n"))
|
||||
|
||||
self.assertEqual(result, "ok")
|
||||
self.assertEqual(buffered_message, "prefix\n(调用了 1 次工具)\n")
|
||||
self.assertEqual(buffered_message, "prefix\n(调用了 1 次工具)\n\n")
|
||||
|
||||
def test_non_verbose_tool_call_emits_summary_even_when_buffer_was_empty(self):
|
||||
result, buffered_message = asyncio.run(self._run_tool(""))
|
||||
|
||||
self.assertEqual(result, "ok")
|
||||
self.assertEqual(buffered_message, "(调用了 1 次工具)\n")
|
||||
self.assertEqual(buffered_message, "(调用了 1 次工具)\n\n")
|
||||
|
||||
def test_non_verbose_tool_summary_is_inserted_before_next_text(self):
|
||||
async def _run():
|
||||
@@ -74,7 +74,7 @@ class TestAgentToolStreaming(unittest.TestCase):
|
||||
|
||||
self.assertEqual(
|
||||
buffered_message,
|
||||
"让我来检查一下:\n(调用了 1 次工具)\n已经拿到结果",
|
||||
"让我来检查一下:\n\n(调用了 1 次工具)\n\n已经拿到结果",
|
||||
)
|
||||
|
||||
def test_non_verbose_tool_summary_aggregates_multiple_categories(self):
|
||||
@@ -109,7 +109,7 @@ class TestAgentToolStreaming(unittest.TestCase):
|
||||
|
||||
self.assertEqual(
|
||||
buffered_message,
|
||||
"处理中:\n(执行了 2 次搜索,读取了 2 个文件)\n继续分析",
|
||||
"处理中:\n\n(执行了 2 次搜索,读取了 2 个文件)\n\n继续分析",
|
||||
)
|
||||
|
||||
def test_openai_streaming_handler_flushes_pending_summary_to_queue(self):
|
||||
@@ -130,7 +130,7 @@ class TestAgentToolStreaming(unittest.TestCase):
|
||||
|
||||
emitted, queued, buffered_message = asyncio.run(_run())
|
||||
|
||||
self.assertEqual(emitted, "(读取了 1 个文件)\n")
|
||||
self.assertEqual(emitted, "(读取了 1 个文件)\n\n")
|
||||
self.assertEqual(queued, emitted)
|
||||
self.assertEqual(buffered_message, emitted)
|
||||
|
||||
@@ -164,6 +164,7 @@ class TestAgentToolStreaming(unittest.TestCase):
|
||||
def test_flush_edits_message_via_threadpool(self):
|
||||
handler = StreamingHandler()
|
||||
handler._channel = MessageChannel.Telegram.value
|
||||
handler._source = "telegram"
|
||||
handler._streaming_enabled = True
|
||||
handler._message_response = MessageResponse(
|
||||
message_id=1,
|
||||
@@ -187,6 +188,43 @@ class TestAgentToolStreaming(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(handler._sent_text, "hello world")
|
||||
|
||||
def test_flush_without_channel_context_does_not_send_direct_message(self):
|
||||
handler = StreamingHandler()
|
||||
handler._streaming_enabled = True
|
||||
handler.emit("hello")
|
||||
|
||||
with patch(
|
||||
"app.agent.callback.run_in_threadpool", new_callable=AsyncMock
|
||||
) as run_in_threadpool_mock:
|
||||
asyncio.run(handler._flush())
|
||||
|
||||
run_in_threadpool_mock.assert_not_awaited()
|
||||
self.assertFalse(handler.has_sent_message)
|
||||
|
||||
def test_verbose_background_tool_call_does_not_post_message(self):
|
||||
async def _run():
|
||||
tool = DummyTool(session_id="session-1", user_id="10001")
|
||||
handler = StreamingHandler()
|
||||
await handler.start_streaming()
|
||||
tool.set_stream_handler(handler)
|
||||
tool.set_message_attr(channel=None, source=None, username="tester")
|
||||
|
||||
with (
|
||||
patch.object(settings, "AI_AGENT_VERBOSE", True),
|
||||
patch.object(
|
||||
DummyTool, "send_tool_message", new_callable=AsyncMock
|
||||
) as send_tool_message,
|
||||
):
|
||||
result = await tool._arun(explanation="run test tool")
|
||||
buffered_message = await handler.take()
|
||||
return result, buffered_message, send_tool_message
|
||||
|
||||
result, buffered_message, send_tool_message = asyncio.run(_run())
|
||||
|
||||
self.assertEqual(result, "ok")
|
||||
send_tool_message.assert_not_awaited()
|
||||
self.assertEqual(buffered_message, "(调用了 1 次工具)\n\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user