# Mirror of https://github.com/jxxghp/MoviePilot.git
# Synced 2026-05-30 07:26:48 +00:00 (90 lines, 3.4 KiB, Python)
import asyncio
|
|
import unittest
|
|
|
|
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
|
|
|
from app.agent.middleware.patch_tool_calls import PatchToolCallsMiddleware
|
|
|
|
|
|
def _build_tool_call(tool_call_id: str = "call_1", name: str = "search") -> dict:
    """Build a minimal tool-call dict for use in the tests below.

    Args:
        tool_call_id: Value stored under the ``"id"`` key.
        name: Value stored under the ``"name"`` key.

    Returns:
        A new dict with the keys ``"id"``, ``"type"``, ``"name"`` and
        ``"args"``. ``"type"`` is always ``"tool_call"`` and ``"args"`` is a
        fresh empty dict on every call, so callers never share state.
    """
    return {
        "id": tool_call_id,
        "type": "tool_call",
        "name": name,
        "args": {},
    }
|
|
|
|
|
|
class TestPatchToolCallsMiddleware(unittest.TestCase):
    """Tests for the middleware that repairs tool-call message history.

    Each test feeds a hand-built message list to a fresh
    ``PatchToolCallsMiddleware`` and inspects the list exposed as
    ``result["messages"].value``.
    """

    def test_adds_missing_tool_messages_immediately_after_ai_message(self):
        """A missing tool response is filled in right after the AI message."""
        middleware = PatchToolCallsMiddleware()
        # The AI message requests tool call "call_1", but no ToolMessage
        # answers it before the next human turn.
        messages = [
            HumanMessage(content="查天气"),
            AIMessage(content="", tool_calls=[_build_tool_call()]),
            HumanMessage(content="不用查了"),
        ]

        result = middleware.before_agent({"messages": messages}, runtime=None)

        patched_messages = result["messages"].value
        # The original AI message object is kept (identity, not a copy) ...
        self.assertIs(patched_messages[1], messages[1])
        # ... a ToolMessage for the unanswered call is inserted directly after it ...
        self.assertIsInstance(patched_messages[2], ToolMessage)
        self.assertEqual(patched_messages[2].tool_call_id, "call_1")
        # ... and the later human message is pushed back by one position.
        self.assertIs(patched_messages[3], messages[2])

    def test_moves_late_tool_messages_next_to_matching_ai_message(self):
        """An out-of-order tool response is moved next to its AI message."""
        middleware = PatchToolCallsMiddleware()
        tool_message = ToolMessage(content="晴天", tool_call_id="call_1")
        # The tool response exists but arrives after an unrelated human turn.
        messages = [
            HumanMessage(content="查天气"),
            AIMessage(content="", tool_calls=[_build_tool_call()]),
            HumanMessage(content="再问一句"),
            tool_message,
        ]

        result = middleware.before_agent({"messages": messages}, runtime=None)

        patched_messages = result["messages"].value
        self.assertIs(patched_messages[1], messages[1])
        # The existing ToolMessage object is reused rather than replaced.
        self.assertIs(patched_messages[2], tool_message)
        self.assertIs(patched_messages[3], messages[2])
        # The response must not also remain at its original, later position.
        self.assertNotIn(tool_message, patched_messages[4:])

    def test_drops_orphan_tool_messages(self):
        """A tool response with no matching tool call is dropped from history."""
        middleware = PatchToolCallsMiddleware()
        # No AI message in this history ever requested "call_orphan".
        orphan_tool_message = ToolMessage(content="晴天", tool_call_id="call_orphan")
        messages = [
            HumanMessage(content="查天气"),
            orphan_tool_message,
            HumanMessage(content="继续"),
        ]

        result = middleware.before_agent({"messages": messages}, runtime=None)

        patched_messages = result["messages"].value
        self.assertEqual([msg.type for msg in patched_messages], ["human", "human"])
        self.assertNotIn(orphan_tool_message, patched_messages)

    def test_async_hook_normalizes_messages(self):
        """The async agent hook repairs tool-call history as well."""
        middleware = PatchToolCallsMiddleware()
        # History ends with an AI tool call that has no response at all.
        messages = [
            HumanMessage(content="查天气"),
            AIMessage(content="", tool_calls=[_build_tool_call()]),
        ]

        # Drive the coroutine to completion from this synchronous test.
        result = asyncio.run(middleware.abefore_agent({"messages": messages}, runtime=None))

        patched_messages = result["messages"].value
        self.assertEqual([msg.type for msg in patched_messages], ["human", "ai", "tool"])
|
|
|
|
|
|
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
|