feat: add managed agent command sessions

This commit is contained in:
jxxghp
2026-05-18 20:17:59 +08:00
parent f5eeeebeba
commit 9076acc52e
7 changed files with 1028 additions and 85 deletions

View File

@@ -1,16 +1,25 @@
"""执行Shell命令工具"""
"""执行 Shell 命令工具"""
from __future__ import annotations
import asyncio
import json
import os
import signal
import subprocess
from dataclasses import dataclass, field
from tempfile import NamedTemporaryFile
from typing import Any, Optional, TextIO, Type
from typing import Any, Literal, Optional, TextIO, Type
from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.agent.tools.impl.terminal_session import (
TERMINAL_DEFAULT_READ_BYTES,
TERMINAL_MAX_READ_BYTES,
TERMINAL_WAIT_DEFAULT_MS,
terminal_session_manager,
)
from app.log import logger
@@ -20,6 +29,14 @@ MAX_OUTPUT_PREVIEW_BYTES = 10 * 1024
READ_CHUNK_SIZE = 4096
KILL_GRACE_SECONDS = 3
COMMAND_CONCURRENCY_LIMIT = 2
COMMAND_FORBIDDEN_KEYWORDS = (
"rm -rf /",
":(){ :|:& };:",
"dd if=/dev/zero",
"mkfs",
"reboot",
"shutdown",
)
_command_semaphore = asyncio.Semaphore(COMMAND_CONCURRENCY_LIMIT)
@@ -38,11 +55,13 @@ class _CommandOutput:
@staticmethod
def _clip_text_to_bytes(text: str, byte_limit: int) -> str:
"""按 UTF-8 字节数截断文本,避免截断后出现非法字符。"""
if byte_limit <= 0:
return ""
return text.encode("utf-8")[:byte_limit].decode("utf-8", errors="ignore")
def _write_chunk(self, stream_name: str, text: str) -> None:
"""把输出分片按 stdout/stderr 分段写入临时文件。"""
if not self.temp_file_handle or not text:
return
@@ -56,6 +75,7 @@ class _CommandOutput:
self.temp_file_handle.write(text)
def _ensure_temp_file(self) -> None:
"""首次超出预览上限时创建临时文件并补写已缓存预览。"""
if self.temp_file_handle:
return
@@ -72,6 +92,7 @@ class _CommandOutput:
self._write_chunk(stream_name, chunk)
def close(self) -> None:
"""关闭临时文件句柄,确保输出落盘。"""
if not self.temp_file_handle:
return
self.temp_file_handle.flush()
@@ -79,6 +100,7 @@ class _CommandOutput:
self.temp_file_handle = None
def append(self, stream_name: str, text: str) -> None:
"""追加一段输出,超出预览上限后只保留完整日志文件。"""
if not text:
return
@@ -104,47 +126,141 @@ class _CommandOutput:
@property
def stdout(self) -> str:
"""返回当前保留的 stdout 预览。"""
return "".join(
text for stream_name, text in self.preview_entries if stream_name == "stdout"
).strip()
@property
def stderr(self) -> str:
"""返回当前保留的 stderr 预览。"""
return "".join(
text for stream_name, text in self.preview_entries if stream_name == "stderr"
).strip()
class ExecuteCommandInput(BaseModel):
"""执行Shell命令工具的输入参数模型"""
"""执行 Shell 命令工具的输入参数模型"""
explanation: str = Field(
..., description="Clear explanation of why this command is being executed"
..., description="Clear explanation of why this command action is needed"
)
action: Optional[Literal["start", "read", "wait", "write", "kill", "run"]] = Field(
"start",
description=(
"Command action. start launches a managed background session and returns "
"session_id. read/wait/write/kill operate on that session. run executes "
"once and waits until completion or timeout."
),
)
command: Optional[str] = Field(
None,
description="Shell command. Required for action=start or action=run.",
)
session_id: Optional[str] = Field(
None,
description="Command session id returned by action=start.",
)
input_text: Optional[str] = Field(
None,
description="Text to send to stdin for action=write. Use \\u0003 for Ctrl+C.",
)
signal_name: Optional[str] = Field(
"TERM",
description="Signal for action=kill, such as TERM, INT, KILL, or 15.",
)
cwd: Optional[str] = Field(
None,
description="Working directory for action=start or action=run.",
)
env: Optional[dict[str, Any]] = Field(
None,
description="Additional environment variables for action=start.",
)
use_pty: Optional[bool] = Field(
True,
description="Use a pseudo terminal for action=start when supported.",
)
since_seq: Optional[int] = Field(
None,
description="For action=read/wait, return output chunks after this seq.",
)
max_bytes: Optional[int] = Field(
TERMINAL_DEFAULT_READ_BYTES,
description="For action=read/wait, maximum output bytes to return.",
)
timeout_ms: Optional[int] = Field(
TERMINAL_WAIT_DEFAULT_MS,
description="For action=wait, maximum segmented wait time in milliseconds.",
)
command: str = Field(..., description="The shell command to execute")
timeout: Optional[int] = Field(
60, description="Max execution time in seconds (default: 60)"
60,
description="For action=run, max execution time in seconds.",
)
class ExecuteCommandTool(MoviePilotTool):
"""统一执行和管理 Shell 命令的 Agent 工具。"""
name: str = "execute_command"
description: str = (
"Safely execute shell commands on the server. Useful for system "
"maintenance, checking status, or running custom scripts. Includes "
"timeout, concurrency, and output preview limits."
"Start and manage shell commands on the server. By default action=start "
"launches a background session and immediately returns session_id/status/"
"last_seq/output_until_seq. Call the same tool with action=read, wait, "
"write, or kill to poll output, wait in short segments, send stdin, or "
"terminate it. Use action=run only when a one-shot bounded command result "
"is preferred."
)
args_schema: Type[BaseModel] = ExecuteCommandInput
require_admin: bool = True
result_max_chars = TERMINAL_MAX_READ_BYTES + 4096
def get_tool_message(self, **kwargs) -> Optional[str]:
"""根据命令生成友好的提示消息"""
command = kwargs.get("command", "")
return f"执行系统命令: {command}"
"""根据命令动作生成友好的提示消息"""
action = kwargs.get("action") or "start"
command = kwargs.get("command")
session_id = kwargs.get("session_id")
if action in {"start", "run"}:
return f"执行系统命令: {command or ''}"
if action == "read":
return f"读取命令输出: {session_id or ''}"
if action == "wait":
return f"等待命令会话: {session_id or ''}"
if action == "write":
return f"写入命令输入: {session_id or ''}"
if action == "kill":
return f"终止命令会话: {session_id or ''}"
return f"处理命令会话: {session_id or command or ''}"
@staticmethod
def _dump(payload: dict[str, Any]) -> str:
"""把结构化命令会话结果转换为 Agent 容易解析的 JSON 字符串。"""
return json.dumps(payload, ensure_ascii=False, indent=2)
@staticmethod
def _require_session_id(session_id: Optional[str]) -> str:
"""校验会话型 action 必须传入 session_id。"""
if not session_id:
raise ValueError("action 需要传入 session_id")
return session_id
@staticmethod
def _require_command(command: Optional[str]) -> str:
"""校验启动型 action 必须传入 command。"""
if not command or not command.strip():
raise ValueError("action 需要传入 command")
return command
@staticmethod
def _validate_command(command: str) -> None:
"""复用旧工具的基础危险命令过滤,避免明显破坏性命令进入 shell。"""
for keyword in COMMAND_FORBIDDEN_KEYWORDS:
if keyword in command:
raise ValueError(f"命令包含禁止使用的关键字 '{keyword}'")
@staticmethod
def _normalize_timeout(timeout: Optional[int]) -> tuple[int, Optional[str]]:
"""限制命令最长运行时间,避免 Agent 传入过大的 timeout"""
"""限制一次性执行命令最长运行时间。"""
try:
normalized = int(timeout or DEFAULT_TIMEOUT_SECONDS)
except (TypeError, ValueError):
@@ -161,7 +277,7 @@ class ExecuteCommandTool(MoviePilotTool):
@staticmethod
def _subprocess_kwargs() -> dict:
"""子进程创建独立进程组,便于超时场景清理整棵子进程。"""
"""一次性命令创建独立进程组,便于超时清理整棵子进程。"""
kwargs = {
"stdin": subprocess.DEVNULL,
"stdout": asyncio.subprocess.PIPE,
@@ -179,17 +295,16 @@ class ExecuteCommandTool(MoviePilotTool):
stream_name: str,
output: _CommandOutput,
) -> None:
"""按块读取输出,始终只把前 10KB 保留在返回结果中。"""
"""按块读取一次性命令输出,只把前 10KB 保留在返回结果中。"""
while True:
chunk = await stream.read(READ_CHUNK_SIZE)
if not chunk:
break
output.append(stream_name, chunk.decode("utf-8", errors="replace"))
@staticmethod
def _terminate_process(process: Any, sig: int):
"""向进程组发送终止信号不支持进程组的平台回退为单进程终止。"""
def _terminate_process(process: Any, sig: int) -> None:
"""向进程组发送终止信号不支持进程组的平台回退为单进程终止。"""
try:
if os.name == "posix":
os.killpg(process.pid, sig)
@@ -230,7 +345,7 @@ class ExecuteCommandTool(MoviePilotTool):
@staticmethod
async def _finish_reader_tasks(reader_tasks: list[asyncio.Task]) -> None:
"""等待输出读取任务退出,异常只记录不影响工具返回。"""
"""等待一次性命令输出读取任务退出,异常只记录不影响工具返回。"""
if not reader_tasks:
return
done, pending = await asyncio.wait(reader_tasks, timeout=1)
@@ -244,7 +359,7 @@ class ExecuteCommandTool(MoviePilotTool):
logger.debug("命令输出读取任务异常: %s", result)
@staticmethod
def _format_result(
def _format_run_result(
*,
exit_code: Optional[int],
output: _CommandOutput,
@@ -252,6 +367,7 @@ class ExecuteCommandTool(MoviePilotTool):
timed_out: bool,
timeout_note: Optional[str],
) -> str:
"""格式化 action=run 的兼容文本结果。"""
if timed_out:
result = f"命令执行超时 (限制: {timeout}秒,已终止进程)"
else:
@@ -260,11 +376,7 @@ class ExecuteCommandTool(MoviePilotTool):
if timeout_note:
result += f"\n\n提示:\n{timeout_note}"
if output.temp_file_path:
file_note = (
"截至命令终止前的完整输出"
if timed_out
else "完整输出"
)
file_note = "截至命令终止前的完整输出" if timed_out else "完整输出"
result += (
"\n\n提示:\n"
f"命令输出超过 10KB仅返回前 {MAX_OUTPUT_PREVIEW_BYTES} 字节内容。\n"
@@ -281,65 +393,129 @@ class ExecuteCommandTool(MoviePilotTool):
result += "\n\n(无输出内容)"
return result
async def run(self, command: str, timeout: Optional[int] = 60, **kwargs) -> str:
logger.info(
f"执行工具: {self.name}, 参数: command={command}, timeout={timeout}"
)
# 简单安全过滤
forbidden_keywords = [
"rm -rf /",
":(){ :|:& };:",
"dd if=/dev/zero",
"mkfs",
"reboot",
"shutdown",
]
for keyword in forbidden_keywords:
if keyword in command:
return f"错误:命令包含禁止使用的关键字 '{keyword}'"
async def _run_once(
self,
*,
command: str,
timeout: Optional[int],
cwd: Optional[str] = None,
) -> str:
"""按旧模式一次性执行命令,等待完成或超时后返回文本结果。"""
self._validate_command(command)
normalized_timeout, timeout_note = self._normalize_timeout(timeout)
async with _command_semaphore:
process = await asyncio.create_subprocess_shell(
command,
cwd=cwd,
**self._subprocess_kwargs(),
)
output = _CommandOutput(preview_limit_bytes=MAX_OUTPUT_PREVIEW_BYTES)
wait_task = asyncio.create_task(process.wait())
reader_tasks = [
asyncio.create_task(self._read_stream(process.stdout, "stdout", output)),
asyncio.create_task(self._read_stream(process.stderr, "stderr", output)),
]
timed_out = False
try:
await asyncio.wait_for(
asyncio.shield(wait_task), timeout=normalized_timeout
)
except asyncio.TimeoutError:
timed_out = True
await self._cleanup_process(process, wait_task)
try:
await self._finish_reader_tasks(reader_tasks)
finally:
output.close()
return self._format_run_result(
exit_code=process.returncode,
output=output,
timeout=normalized_timeout,
timed_out=timed_out,
timeout_note=timeout_note,
)
async def run(
self,
action: Optional[str] = "start",
command: Optional[str] = None,
session_id: Optional[str] = None,
input_text: Optional[str] = None,
signal_name: Optional[str] = "TERM",
cwd: Optional[str] = None,
env: Optional[dict[str, Any]] = None,
use_pty: Optional[bool] = True,
since_seq: Optional[int] = None,
max_bytes: Optional[int] = TERMINAL_DEFAULT_READ_BYTES,
timeout_ms: Optional[int] = TERMINAL_WAIT_DEFAULT_MS,
timeout: Optional[int] = 60,
**kwargs,
) -> str:
"""执行命令动作:默认后台启动,也支持读取、等待、写入、终止和一次性执行。"""
normalized_action = (action or "start").strip().lower()
logger.info(
"执行工具: %s, action=%s, command=%s, session_id=%s",
self.name,
normalized_action,
command,
session_id,
)
try:
async with _command_semaphore:
# 命令输出可能非常大,必须边读边落盘,不能使用 communicate() 一次性收集。
process = await asyncio.create_subprocess_shell(
command, **self._subprocess_kwargs()
if normalized_action == "start":
start_command = self._require_command(command)
self._validate_command(start_command)
payload = await terminal_session_manager.start(
command=start_command,
cwd=cwd,
env=env,
use_pty=use_pty,
)
output = _CommandOutput(preview_limit_bytes=MAX_OUTPUT_PREVIEW_BYTES)
wait_task = asyncio.create_task(process.wait())
reader_tasks = [
asyncio.create_task(
self._read_stream(process.stdout, "stdout", output)
),
asyncio.create_task(
self._read_stream(process.stderr, "stderr", output)
),
]
return self._dump(payload)
timed_out = False
try:
await asyncio.wait_for(
asyncio.shield(wait_task), timeout=normalized_timeout
)
except asyncio.TimeoutError:
timed_out = True
await self._cleanup_process(process, wait_task)
if normalized_action == "read":
payload = await terminal_session_manager.read(
session_id=self._require_session_id(session_id),
since_seq=since_seq,
max_bytes=max_bytes,
)
return self._dump(payload)
try:
await self._finish_reader_tasks(reader_tasks)
finally:
output.close()
if normalized_action == "wait":
payload = await terminal_session_manager.wait(
session_id=self._require_session_id(session_id),
timeout_ms=timeout_ms,
since_seq=since_seq,
max_bytes=max_bytes,
)
return self._dump(payload)
return self._format_result(
exit_code=process.returncode,
output=output,
timeout=normalized_timeout,
timed_out=timed_out,
timeout_note=timeout_note,
if normalized_action == "write":
payload = await terminal_session_manager.write(
session_id=self._require_session_id(session_id),
input_text=input_text or "",
)
return self._dump(payload)
if normalized_action == "kill":
payload = await terminal_session_manager.kill(
session_id=self._require_session_id(session_id),
sig=signal_name,
)
return self._dump(payload)
if normalized_action == "run":
return await self._run_once(
command=self._require_command(command),
timeout=timeout,
cwd=cwd,
)
except Exception as e:
logger.error(f"执行命令失败: {e}", exc_info=True)
return f"执行命令时发生错误: {str(e)}"
raise ValueError(f"不支持的 action: {action}")
except Exception as err:
logger.error("执行命令 action 失败: %s", err, exc_info=True)
return self._dump({"error": str(err), "status": "error", "action": normalized_action})