mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-28 07:26:52 +00:00
74 lines
2.9 KiB
Python
74 lines
2.9 KiB
Python
from typing import Any, Optional
|
|
|
|
from langchain.agents.middleware import AgentMiddleware, AgentState
|
|
from langchain_core.messages import AIMessage, BaseMessage, ToolMessage
|
|
from langgraph.runtime import Runtime
|
|
from langgraph.types import Overwrite
|
|
|
|
|
|
class PatchToolCallsMiddleware(AgentMiddleware):
|
|
"""修复消息历史中悬空工具调用的中间件。"""
|
|
|
|
@staticmethod
|
|
def _build_cancelled_tool_message(tool_call: dict[str, Any]) -> ToolMessage:
|
|
"""构造取消状态的工具响应消息。"""
|
|
tool_name = tool_call.get("name") or "unknown_tool"
|
|
tool_call_id = tool_call.get("id") or ""
|
|
tool_msg = (
|
|
f"Tool call {tool_name} with id {tool_call_id} was "
|
|
"cancelled - another message came in before it could be completed."
|
|
)
|
|
return ToolMessage(
|
|
content=tool_msg,
|
|
name=tool_name,
|
|
tool_call_id=tool_call_id,
|
|
)
|
|
|
|
@classmethod
|
|
def _normalize_messages(cls, messages: list[BaseMessage]) -> list[BaseMessage]:
|
|
"""规范化工具调用消息顺序,满足 OpenAI tool_calls 协议要求。"""
|
|
if not messages or len(messages) == 0:
|
|
return messages
|
|
|
|
tool_messages = {
|
|
msg.tool_call_id: msg
|
|
for msg in messages
|
|
if isinstance(msg, ToolMessage) and msg.tool_call_id
|
|
}
|
|
patched_messages = []
|
|
for msg in messages:
|
|
if isinstance(msg, ToolMessage):
|
|
continue
|
|
|
|
patched_messages.append(msg)
|
|
if not isinstance(msg, AIMessage) or not msg.tool_calls:
|
|
continue
|
|
|
|
for tool_call in msg.tool_calls:
|
|
tool_call_id = tool_call.get("id")
|
|
corresponding_tool_msg = tool_messages.get(tool_call_id)
|
|
if corresponding_tool_msg:
|
|
patched_messages.append(corresponding_tool_msg)
|
|
else:
|
|
patched_messages.append(cls._build_cancelled_tool_message(tool_call))
|
|
|
|
return patched_messages
|
|
|
|
def before_agent(self, state: AgentState, runtime: Runtime[Any]) -> Optional[dict[str, Any]]: # noqa: ARG002
|
|
"""在代理运行之前,处理任何 AIMessage 中悬空或乱序的工具调用。"""
|
|
messages = state["messages"]
|
|
patched_messages = self._normalize_messages(messages)
|
|
if patched_messages == messages:
|
|
return None
|
|
|
|
return {"messages": Overwrite(patched_messages)}
|
|
|
|
async def abefore_agent(self, state: AgentState, runtime: Runtime[Any]) -> Optional[dict[str, Any]]: # noqa: ARG002
|
|
"""在代理异步运行之前,处理任何 AIMessage 中悬空或乱序的工具调用。"""
|
|
messages = state["messages"]
|
|
patched_messages = self._normalize_messages(messages)
|
|
if patched_messages == messages:
|
|
return None
|
|
|
|
return {"messages": Overwrite(patched_messages)}
|