Files
archived-MoviePilot/tests/test_agent_patch_tool_calls.py
2026-05-26 16:42:52 +08:00

90 lines
3.4 KiB
Python

import asyncio
import unittest
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from app.agent.middleware.patch_tool_calls import PatchToolCallsMiddleware
def _build_tool_call(tool_call_id: str = "call_1", name: str = "search") -> dict:
"""构造测试用工具调用。"""
return {
"id": tool_call_id,
"type": "tool_call",
"name": name,
"args": {},
}
class TestPatchToolCallsMiddleware(unittest.TestCase):
    """Tests for the middleware that repairs tool-call message history."""

    def test_adds_missing_tool_messages_immediately_after_ai_message(self):
        """A missing tool response is patched in as a ToolMessage right after the AI message."""
        middleware = PatchToolCallsMiddleware()
        messages = [
            HumanMessage(content="查天气"),
            AIMessage(content="", tool_calls=[_build_tool_call()]),
            HumanMessage(content="不用查了"),
        ]

        result = middleware.before_agent({"messages": messages}, runtime=None)
        # NOTE(review): the hook appears to return the rewritten history wrapped in
        # an object exposing `.value`; confirm against the middleware implementation.
        patched_messages = result["messages"].value

        # Exactly one synthetic ToolMessage is inserted; nothing else is added or lost.
        self.assertEqual(len(patched_messages), 4)
        self.assertIs(patched_messages[1], messages[1])
        self.assertIsInstance(patched_messages[2], ToolMessage)
        self.assertEqual(patched_messages[2].tool_call_id, "call_1")
        self.assertIs(patched_messages[3], messages[2])

    def test_moves_late_tool_messages_next_to_matching_ai_message(self):
        """An out-of-order tool response is moved to directly follow its assistant message."""
        middleware = PatchToolCallsMiddleware()
        tool_message = ToolMessage(content="晴天", tool_call_id="call_1")
        messages = [
            HumanMessage(content="查天气"),
            AIMessage(content="", tool_calls=[_build_tool_call()]),
            HumanMessage(content="再问一句"),
            tool_message,
        ]

        result = middleware.before_agent({"messages": messages}, runtime=None)
        patched_messages = result["messages"].value

        # The history keeps its size: the tool message is relocated, not duplicated.
        self.assertEqual(len(patched_messages), 4)
        self.assertIs(patched_messages[1], messages[1])
        self.assertIs(patched_messages[2], tool_message)
        self.assertIs(patched_messages[3], messages[2])
        # Identity count (not `in`, which uses ==) proves the original object
        # appears exactly once anywhere in the patched history.
        self.assertEqual(sum(msg is tool_message for msg in patched_messages), 1)

    def test_drops_orphan_tool_messages(self):
        """A tool response with no matching tool call is removed from the history sent to the model."""
        middleware = PatchToolCallsMiddleware()
        orphan_tool_message = ToolMessage(content="晴天", tool_call_id="call_orphan")
        messages = [
            HumanMessage(content="查天气"),
            orphan_tool_message,
            HumanMessage(content="继续"),
        ]

        result = middleware.before_agent({"messages": messages}, runtime=None)
        patched_messages = result["messages"].value

        self.assertEqual([msg.type for msg in patched_messages], ["human", "human"])
        self.assertNotIn(orphan_tool_message, patched_messages)

    def test_async_hook_normalizes_messages(self):
        """The async agent entry point (`abefore_agent`) also repairs tool-call history."""
        middleware = PatchToolCallsMiddleware()
        messages = [
            HumanMessage(content="查天气"),
            AIMessage(content="", tool_calls=[_build_tool_call()]),
        ]

        result = asyncio.run(middleware.abefore_agent({"messages": messages}, runtime=None))
        patched_messages = result["messages"].value

        self.assertEqual([msg.type for msg in patched_messages], ["human", "ai", "tool"])
        # The synthesized response must answer the pending call, as in the sync test.
        self.assertEqual(patched_messages[2].tool_call_id, "call_1")
if __name__ == "__main__":
    # Run this module's tests when executed directly. module="__main__" is
    # unittest.main's default; it is passed explicitly only for clarity.
    unittest.main(module="__main__")