From 40d0b60aa2b7fbdc4128c29c0deb4b15d2011745 Mon Sep 17 00:00:00 2001 From: jxxghp Date: Sun, 31 May 2026 21:55:25 +0800 Subject: [PATCH] feat: add async subagent task control --- app/agent/__init__.py | 4 +- app/agent/middleware/subagents.py | 553 ++++++++++++++++++++++++-- tests/test_agent_background_output.py | 50 +++ tests/test_agent_subagents.py | 117 +++++- 4 files changed, 679 insertions(+), 45 deletions(-) diff --git a/app/agent/__init__.py b/app/agent/__init__.py index 7b1275fa..0edf7609 100644 --- a/app/agent/__init__.py +++ b/app/agent/__init__.py @@ -37,6 +37,7 @@ from app.agent.middleware.patch_tool_calls import PatchToolCallsMiddleware from app.agent.middleware.runtime_config import RuntimeConfigMiddleware from app.agent.middleware.skills import SkillsMiddleware from app.agent.middleware.subagents import ( + SUBAGENT_CONTROL_TOOL_NAME, SUBAGENT_TASK_TOOL_NAME, create_subagent_middlewares, is_subagent_stream_metadata, @@ -833,7 +834,8 @@ class MoviePilotAgent: always_include_tools.extend( tool.name for tool in subagent_task_tools - if getattr(tool, "name", None) == SUBAGENT_TASK_TOOL_NAME + if getattr(tool, "name", None) + in {SUBAGENT_TASK_TOOL_NAME, SUBAGENT_CONTROL_TOOL_NAME} ) # 中间件 diff --git a/app/agent/middleware/subagents.py b/app/agent/middleware/subagents.py index 5243bac1..9dd76d6d 100644 --- a/app/agent/middleware/subagents.py +++ b/app/agent/middleware/subagents.py @@ -1,10 +1,13 @@ """MoviePilot 子代理中间件适配。""" +import asyncio +import json import uuid from collections.abc import Awaitable, Callable from dataclasses import dataclass +from datetime import datetime from functools import lru_cache -from typing import Any, Optional +from typing import Any, Literal, Optional from langchain.agents import create_agent from langchain.agents.middleware.types import ( @@ -26,13 +29,28 @@ from app.log import logger SUBAGENT_TASK_TOOL_NAME = "task" +SUBAGENT_CONTROL_TOOL_NAME = "subagent_task" SUBAGENT_STREAM_MARKER_KEY = "ls_agent_type" SUBAGENT_STREAM_MARKER_VALUE = "subagent" +SUBAGENT_DEFAULT_WAIT_TIMEOUT_MS = 60000 +SUBAGENT_MAX_WAIT_TIMEOUT_MS = 300000 +SUBAGENT_MAX_ACTIVE_TASKS = 8 +SUBAGENT_MAX_CONCURRENT_TASKS = 4 +SUBAGENT_RESULT_MAX_CHARS = 12000 +SUBAGENT_DESCRIPTION_MAX_CHARS = 500 SUBAGENT_PARENT_PROMPT = """ -You may use the `task` tool to delegate independent research, retrieval, +You may use subagent tools to delegate independent research, retrieval, diagnosis, or planning work to built-in subagents. +Delegation modes: +- Use `task` for one blocking subtask when you need the result immediately. +- Use `subagent_task` for two or more independent subtasks. Start them first + with `action=start` and a `tasks` array, then use `action=status`, + `action=wait`, or `action=cancel` with the returned task IDs. +- Use `subagent_task` with `action=run` when you want to launch a bounded + batch and wait for the batch in one tool call. + Rules: - Delegate when a task benefits from focused investigation, such as media identity checks, site/resource search, subscription analysis, download/transfer diagnosis, or read-only system inspection. - Subagent output is private context for your decision-making. Do not expose a subagent's process or final report verbatim to the user. @@ -47,6 +65,14 @@ SUBAGENT_TASK_DESCRIPTION = ( "not be forwarded verbatim to the user." ) +SUBAGENT_CONTROL_DESCRIPTION = ( + "Start and manage multiple MoviePilot subagent tasks asynchronously. " + "Use action=start with tasks=[{description, subagent_type}] to launch a batch " + "and get task IDs immediately. Use action=status to inspect tasks, action=wait " + "to wait for all or any task result, action=cancel to stop running tasks, and " + "action=run to launch a bounded batch and wait in one call." +) + SUBAGENT_BASE_PROMPT = """You are a silent subagent working for the MoviePilot main agent. Requirements: @@ -80,6 +106,66 @@ class _TaskToolInput(BaseModel): ) +class _SubAgentTaskSpec(BaseModel): + """异步子代理任务定义。""" + + description: str = Field(..., description="Complete task description for the subagent") + subagent_type: str = Field( + default="general-purpose", + description="Subagent type to invoke, such as general-purpose or media-researcher", + ) + + +class _SubAgentControlInput(BaseModel): + """异步子代理管控工具输入。""" + + action: Literal["start", "status", "wait", "cancel", "run"] = Field( + default="start", + description="Task action: start, status, wait, cancel, or run.", + ) + description: Optional[str] = Field( + default=None, + description="Single task description for action=start or action=run.", + ) + subagent_type: Optional[str] = Field( + default="general-purpose", + description="Single task subagent type for action=start or action=run.", + ) + tasks: Optional[list[_SubAgentTaskSpec]] = Field( + default=None, + description="Batch task specs for action=start or action=run.", + ) + task_ids: Optional[list[str]] = Field( + default=None, + description="Task IDs returned by action=start. Empty means all known tasks.", + ) + task_id: Optional[str] = Field( + default=None, + description="Single task ID for status, wait, or cancel.", + ) + wait_mode: Literal["all", "any"] = Field( + default="all", + description="For action=wait or action=run: wait for all selected tasks or any one task.", + ) + timeout_ms: Optional[int] = Field( + default=SUBAGENT_DEFAULT_WAIT_TIMEOUT_MS, + description="Maximum wait time in milliseconds for action=wait or action=run.", + ) + + +@dataclass +class _SubAgentRuntimeTask: + """运行中的异步子代理任务记录。""" + + task_id: str + description: str + subagent_type: str + task: asyncio.Task + created_at: datetime + started_at: Optional[datetime] = None + finished_at: Optional[datetime] = None + + def is_subagent_stream_metadata(metadata: Any) -> bool: """判断流式 token 元数据是否来自子代理。""" if not isinstance(metadata, dict): @@ -327,6 +413,88 @@ def _extract_final_text(result: Any) -> str: return _extract_text_content(result).strip() +def _clip_text(text: Any, max_chars: int) -> tuple[str, bool]: + """裁剪过长文本,返回文本和是否被裁剪。""" + normalized = "" if text is None else str(text) + if len(normalized) <= max_chars: + return normalized, False + return normalized[:max_chars], True + + +def _format_datetime(value: Optional[datetime]) -> Optional[str]: + """格式化任务时间。""" + if not value: + return None + return value.strftime("%Y-%m-%d %H:%M:%S") + + +class _SubAgentAgentProvider: + """子代理图懒加载与执行器。""" + + def __init__( + self, + *, + model: BaseChatModel, + profiles: tuple[_SubAgentProfile, ...], + tools: list[BaseTool], + ) -> None: + """初始化子代理执行器。""" + self._model = model + self._profiles = {profile.name: profile for profile in profiles} + self._tools = tools + self._agents = {} + self._default_agent_name = "general-purpose" + + def _resolve_profile(self, agent_name: Optional[str]) -> _SubAgentProfile: + """解析子代理类型,未知类型回退到默认子代理。""" + return self._profiles.get(agent_name or "") or self._profiles[ + self._default_agent_name + ] + + def get_agent(self, agent_name: Optional[str]) -> tuple[str, Any]: + """懒加载指定名称的子代理图。""" + profile = self._resolve_profile(agent_name) + cached_agent = self._agents.get(profile.name) + if cached_agent: + return profile.name, cached_agent + + subagent_tools = _select_tools(self._tools, profile) + agent = create_agent( + model=self._model, + tools=subagent_tools, + system_prompt=profile.prompt, + name=profile.name, + ) + self._agents[profile.name] = agent + return profile.name, agent + + async def run_task( + self, + *, + description: str, + subagent_type: Optional[str], + task_id: Optional[str] = None, + ) -> str: + """调用指定子代理并只返回供主代理读取的结果。""" + agent_name, agent = self.get_agent(subagent_type) + thread_suffix = task_id or uuid.uuid4().hex + result = await agent.ainvoke( + {"messages": [HumanMessage(content=description)]}, + config={ + "configurable": { + "thread_id": f"subagent-{agent_name}-{thread_suffix}", + SUBAGENT_STREAM_MARKER_KEY: SUBAGENT_STREAM_MARKER_VALUE, + }, + "metadata": { + "lc_agent_name": agent_name, + SUBAGENT_STREAM_MARKER_KEY: SUBAGENT_STREAM_MARKER_VALUE, + }, + }, + ) + final_text = _extract_final_text(result) + return final_text or "The subagent did not return a usable result." + + class MoviePilotSubAgentMiddleware(AgentMiddleware): """MoviePilot 本地子代理中间件兜底实现。""" @@ -340,11 +508,11 @@ class MoviePilotSubAgentMiddleware(AgentMiddleware): task_description: str = SUBAGENT_TASK_DESCRIPTION, ) -> None: self.system_prompt = system_prompt - self._model = model - self._profiles = {profile.name: profile for profile in profiles} - self._tools = tools - self._agents = {} - self._default_agent_name = "general-purpose" + self._provider = _SubAgentAgentProvider( + model=model, + profiles=profiles, + tools=tools, + ) self.tools = [ StructuredTool.from_function( coroutine=self._run_task, @@ -359,42 +527,14 @@ class MoviePilotSubAgentMiddleware(AgentMiddleware): def _get_agent(self, agent_name: str) -> Any: """懒加载指定名称的子代理图。""" - profile = self._profiles.get(agent_name) or self._profiles[ - self._default_agent_name - ] - cached_agent = self._agents.get(profile.name) - if cached_agent: - return cached_agent - - subagent_tools = _select_tools(self._tools, profile) - agent = create_agent( - model=self._model, - tools=subagent_tools, - system_prompt=profile.prompt, - name=profile.name, - ) - self._agents[profile.name] = agent - return agent + return self._provider.get_agent(agent_name)[1] async def _run_task(self, description: str, subagent_type: str) -> str: """调用指定子代理并只返回供主代理读取的结果。""" - agent_name = subagent_type or self._default_agent_name - agent = self._get_agent(agent_name) - result = await agent.ainvoke( - {"messages": [HumanMessage(content=description)]}, - config={ - "configurable": { - "thread_id": f"subagent-{agent_name}-{uuid.uuid4().hex}", - SUBAGENT_STREAM_MARKER_KEY: SUBAGENT_STREAM_MARKER_VALUE, - }, - "metadata": { - "lc_agent_name": agent_name, - SUBAGENT_STREAM_MARKER_KEY: SUBAGENT_STREAM_MARKER_VALUE, - }, - }, + return await self._provider.run_task( + description=description, + subagent_type=subagent_type, ) - final_text = _extract_final_text(result) - return final_text or "The subagent did not return a usable result." async def awrap_model_call( self, @@ -411,6 +551,323 @@ class MoviePilotSubAgentMiddleware(AgentMiddleware): return await handler(request.override(system_message=new_system_message)) +class SubAgentTaskControlMiddleware(AgentMiddleware): + """提供异步子代理任务调度工具的中间件。""" + + def __init__( + self, + *, + model: BaseChatModel, + profiles: tuple[_SubAgentProfile, ...], + tools: list[BaseTool], + task_description: str = SUBAGENT_CONTROL_DESCRIPTION, + ) -> None: + """初始化异步子代理调度中间件。""" + self._provider = _SubAgentAgentProvider( + model=model, + profiles=profiles, + tools=tools, + ) + self._semaphore = asyncio.Semaphore(SUBAGENT_MAX_CONCURRENT_TASKS) + self._tasks: dict[str, _SubAgentRuntimeTask] = {} + self.tools = [ + StructuredTool.from_function( + coroutine=self._control_task, + name=SUBAGENT_CONTROL_TOOL_NAME, + description=( + f"{task_description}\n\nAvailable subagents:\n" + f"{_format_subagent_catalog(profiles)}" + ), + args_schema=_SubAgentControlInput, + ) + ] + + @staticmethod + def _json_response(payload: dict[str, Any]) -> str: + """将工具响应序列化为稳定 JSON。""" + return json.dumps(payload, ensure_ascii=False, indent=2) + + @staticmethod + def _normalize_timeout_ms(timeout_ms: Optional[int]) -> int: + """规范化等待超时时间。""" + if timeout_ms is None: + return SUBAGENT_DEFAULT_WAIT_TIMEOUT_MS + return max(0, min(int(timeout_ms), SUBAGENT_MAX_WAIT_TIMEOUT_MS)) + + @staticmethod + def _task_status(record: _SubAgentRuntimeTask) -> str: + """读取任务当前状态。""" + task = record.task + if task.cancelled(): + return "cancelled" + if not task.done(): + return "running" if record.started_at else "pending" + if task.exception(): + return "failed" + return "completed" + + @staticmethod + def _task_output(record: _SubAgentRuntimeTask) -> dict[str, Any]: + """格式化单个任务状态和结果。""" + description, description_truncated = _clip_text( + record.description, + SUBAGENT_DESCRIPTION_MAX_CHARS, + ) + payload: dict[str, Any] = { + "task_id": record.task_id, + "subagent_type": record.subagent_type, + "status": SubAgentTaskControlMiddleware._task_status(record), + "description": description, + "description_truncated": description_truncated, + "created_at": _format_datetime(record.created_at), + "started_at": _format_datetime(record.started_at), + "finished_at": _format_datetime(record.finished_at), + } + if not record.task.done(): + return payload + if record.task.cancelled(): + return payload + + error = record.task.exception() + if error: + payload["error"] = str(error) + return payload + + result, result_truncated = _clip_text( + record.task.result(), + SUBAGENT_RESULT_MAX_CHARS, + ) + payload["result"] = result + payload["result_truncated"] = result_truncated + return payload + + def _selected_records( + self, + *, + task_ids: Optional[list[str]] = None, + task_id: Optional[str] = None, + active_only: bool = False, + ) -> tuple[list[_SubAgentRuntimeTask], list[str]]: + """根据任务 ID 选择记录。""" + selected_ids = [] + if task_id: + selected_ids.append(task_id) + selected_ids.extend(task_ids or []) + if not selected_ids: + records = list(self._tasks.values()) + if active_only: + records = [record for record in records if not record.task.done()] + return records, [] + + records = [] + missing_ids = [] + seen_ids = set() + for selected_id in selected_ids: + if selected_id in seen_ids: + continue + seen_ids.add(selected_id) + record = self._tasks.get(selected_id) + if record: + records.append(record) + else: + missing_ids.append(selected_id) + return records, missing_ids + + def _normalize_specs( + self, + *, + description: Optional[str], + subagent_type: Optional[str], + tasks: Optional[list[_SubAgentTaskSpec]], + ) -> tuple[list[_SubAgentTaskSpec], Optional[str]]: + """规范化单任务和批量任务输入。""" + specs = [] + for task in tasks or []: + if isinstance(task, dict): + task = _SubAgentTaskSpec(**task) + if task.description.strip(): + specs.append(task) + if not specs and description and description.strip(): + specs.append( + _SubAgentTaskSpec( + description=description, + subagent_type=subagent_type or "general-purpose", + ) + ) + if not specs: + return [], "缺少可执行的子代理任务描述。" + if len(specs) > SUBAGENT_MAX_ACTIVE_TASKS: + return [], f"单次最多可提交 {SUBAGENT_MAX_ACTIVE_TASKS} 个子代理任务。" + + active_count = sum( + 1 for record in self._tasks.values() if not record.task.done() + ) + if active_count + len(specs) > SUBAGENT_MAX_ACTIVE_TASKS: + return [], ( + f"当前仍有 {active_count} 个子代理任务未完成," + f"总并发上限为 {SUBAGENT_MAX_ACTIVE_TASKS}。" + ) + return specs, None + + async def _execute_managed_task(self, record: _SubAgentRuntimeTask) -> str: + """执行受调度器管理的子代理任务。""" + async with self._semaphore: + record.started_at = datetime.now() + try: + return await self._provider.run_task( + description=record.description, + subagent_type=record.subagent_type, + task_id=record.task_id, + ) + except asyncio.CancelledError: + raise + except Exception as err: + logger.error(f"子代理任务执行失败: task_id={record.task_id}, error={err}") + raise + + def _mark_task_finished(self, task_id: str, task: asyncio.Task) -> None: + """记录任务完成时间并取出异常避免未读取告警。""" + record = self._tasks.get(task_id) + if record: + record.finished_at = datetime.now() + if task.cancelled(): + return + try: + task.exception() + except Exception: + return + + def _start_tasks(self, specs: list[_SubAgentTaskSpec]) -> list[_SubAgentRuntimeTask]: + """启动一批异步子代理任务。""" + records = [] + for spec in specs: + task_id = f"subagent-{uuid.uuid4().hex[:12]}" + record = _SubAgentRuntimeTask( + task_id=task_id, + description=spec.description.strip(), + subagent_type=spec.subagent_type or "general-purpose", + task=None, + created_at=datetime.now(), + ) + task = asyncio.create_task( + self._execute_managed_task(record), + name=task_id, + ) + record.task = task + task.add_done_callback( + lambda finished_task, finished_task_id=task_id: self._mark_task_finished( + finished_task_id, + finished_task, + ) + ) + self._tasks[task_id] = record + records.append(record) + return records + + async def _wait_records( + self, + *, + records: list[_SubAgentRuntimeTask], + wait_mode: str, + timeout_ms: Optional[int], + ) -> None: + """按等待模式等待一组任务完成。""" + pending_tasks = [record.task for record in records if not record.task.done()] + if not pending_tasks: + return + + timeout = self._normalize_timeout_ms(timeout_ms) / 1000 + if timeout <= 0: + return + + return_when = asyncio.FIRST_COMPLETED if wait_mode == "any" else asyncio.ALL_COMPLETED + await asyncio.wait( + pending_tasks, + timeout=timeout, + return_when=return_when, + ) + + async def _cancel_records(self, records: list[_SubAgentRuntimeTask]) -> None: + """取消一组尚未完成的任务。""" + cancellable_tasks = [ + record.task for record in records if not record.task.done() + ] + for task in cancellable_tasks: + task.cancel() + if cancellable_tasks: + await asyncio.gather(*cancellable_tasks, return_exceptions=True) + + async def _control_task( + self, + action: str = "start", + description: Optional[str] = None, + subagent_type: Optional[str] = "general-purpose", + tasks: Optional[list[_SubAgentTaskSpec]] = None, + task_ids: Optional[list[str]] = None, + task_id: Optional[str] = None, + wait_mode: str = "all", + timeout_ms: Optional[int] = SUBAGENT_DEFAULT_WAIT_TIMEOUT_MS, + ) -> str: + """管理异步子代理任务。""" + if action in {"start", "run"}: + specs, error = self._normalize_specs( + description=description, + subagent_type=subagent_type, + tasks=tasks, + ) + if error: + return self._json_response({"success": False, "error": error}) + + records = self._start_tasks(specs) + if action == "run": + await self._wait_records( + records=records, + wait_mode=wait_mode, + timeout_ms=timeout_ms, + ) + + return self._json_response( + { + "success": True, + "action": action, + "wait_mode": wait_mode if action == "run" else None, + "tasks": [self._task_output(record) for record in records], + } + ) + + records, missing_ids = self._selected_records( + task_ids=task_ids, + task_id=task_id, + active_only=action in {"wait", "cancel"} and not task_ids and not task_id, + ) + + if action == "wait": + await self._wait_records( + records=records, + wait_mode=wait_mode, + timeout_ms=timeout_ms, + ) + elif action == "cancel": + await self._cancel_records(records) + + return self._json_response( + { + "success": True, + "action": action, + "wait_mode": wait_mode if action == "wait" else None, + "missing_task_ids": missing_ids, + "tasks": [self._task_output(record) for record in records], + } + ) + + async def aafter_agent(self, state: Any, runtime: Any) -> None: + """Agent 结束时取消未完成的子代理任务,避免后台泄漏。""" + unfinished_records = [ + record for record in self._tasks.values() if not record.task.done() + ] + await self._cancel_records(unfinished_records) + + class SubAgentCallSummaryMiddleware(AgentMiddleware): """记录子代理调用次数的中间件。""" @@ -427,13 +884,14 @@ class SubAgentCallSummaryMiddleware(AgentMiddleware): tool = request.tool if ( tool - and getattr(tool, "name", None) == SUBAGENT_TASK_TOOL_NAME + and getattr(tool, "name", None) + in {SUBAGENT_TASK_TOOL_NAME, SUBAGENT_CONTROL_TOOL_NAME} and self.stream_handler and getattr(self.stream_handler, "is_streaming", False) ): tool_call = request.tool_call or {} self.stream_handler.record_tool_call( - tool_name=SUBAGENT_TASK_TOOL_NAME, + tool_name=getattr(tool, "name", SUBAGENT_TASK_TOOL_NAME), tool_message="Subagent invoked", tool_kwargs=tool_call.get("args") or {}, ) @@ -501,16 +959,27 @@ def create_subagent_middlewares( profiles=profiles, tools=tools, ) + control_middleware = SubAgentTaskControlMiddleware( + model=model, + profiles=profiles, + tools=tools, + ) - task_tools = list(getattr(subagent_middleware, "tools", []) or []) + task_tools = [ + *list(getattr(subagent_middleware, "tools", []) or []), + *list(getattr(control_middleware, "tools", []) or []), + ] return [ subagent_middleware, + control_middleware, SubAgentCallSummaryMiddleware(stream_handler=stream_handler), ], task_tools __all__ = [ + "SUBAGENT_CONTROL_TOOL_NAME", "SUBAGENT_TASK_TOOL_NAME", + "SubAgentTaskControlMiddleware", "create_subagent_middlewares", "is_subagent_stream_metadata", ] diff --git a/tests/test_agent_background_output.py b/tests/test_agent_background_output.py index 2a06cee6..5ee9f66f 100644 --- a/tests/test_agent_background_output.py +++ b/tests/test_agent_background_output.py @@ -13,6 +13,10 @@ from app.agent import ( _MessageTask, ) from app.agent.memory import memory_manager +from app.agent.middleware.subagents import ( + SUBAGENT_CONTROL_TOOL_NAME, + SUBAGENT_TASK_TOOL_NAME, +) from app.agent.tools.factory import MoviePilotToolFactory from app.core.config import settings from app.utils.identity import SYSTEM_INTERNAL_USER_ID @@ -355,6 +359,52 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase): self.assertIn("send_message", always_include) + async def test_create_agent_always_includes_subagent_tools(self): + """工具筛选开启时应保留同步和异步子代理入口。""" + captured = {} + agent = MoviePilotAgent(session_id="normal-session", user_id="system") + agent._initialize_tools = lambda: [] + agent._initialize_subagent_tools = lambda: [] + + def _tool_selector(**kwargs): + captured["always_include"] = kwargs["always_include"] + return "selector" + + with ( + patch.object(settings, "LLM_MAX_TOOLS", 5), + patch.object(agent, "_initialize_llm", new=AsyncMock(return_value=object())), + patch("app.agent.prompt_manager.get_agent_prompt", return_value="PROMPT"), + patch( + "app.agent.create_subagent_middlewares", + return_value=( + ["subagent"], + [ + SimpleNamespace(name=SUBAGENT_TASK_TOOL_NAME), + SimpleNamespace(name=SUBAGENT_CONTROL_TOOL_NAME), + ], + ), + ), + patch( + "app.agent.MoviePilotToolFactory.get_tool_selector_always_include_names", + return_value=[], + ), + patch("app.agent.SkillsMiddleware", side_effect=lambda *args, **kwargs: "skills"), + patch("app.agent.JobsMiddleware", side_effect=lambda *args, **kwargs: "jobs"), + patch("app.agent.RuntimeConfigMiddleware", side_effect=lambda *args, **kwargs: "runtime"), + patch("app.agent.MemoryMiddleware", side_effect=lambda *args, **kwargs: "memory"), + patch("app.agent.ActivityLogMiddleware", side_effect=lambda *args, **kwargs: "activity"), + patch("app.agent.SummarizationMiddleware", side_effect=lambda *args, **kwargs: "summary"), + patch("app.agent.PatchToolCallsMiddleware", side_effect=lambda *args, **kwargs: "patch"), + patch("app.agent.UsageMiddleware", side_effect=lambda *args, **kwargs: "usage"), + patch("app.agent.ToolSelectorMiddleware", side_effect=_tool_selector), + patch("app.agent.InMemorySaver", return_value="checkpointer"), + patch("app.agent.create_agent", side_effect=lambda **kwargs: kwargs), + ): + await agent._create_agent(streaming=False) + + self.assertIn(SUBAGENT_TASK_TOOL_NAME, captured["always_include"]) + self.assertIn(SUBAGENT_CONTROL_TOOL_NAME, captured["always_include"]) + async def test_create_agent_keeps_activity_log_for_normal_session(self): agent = MoviePilotAgent(session_id="normal-session", user_id="system") agent._initialize_tools = lambda: [] diff --git a/tests/test_agent_subagents.py b/tests/test_agent_subagents.py index a7baf960..da4eac9e 100644 --- a/tests/test_agent_subagents.py +++ b/tests/test_agent_subagents.py @@ -1,3 +1,5 @@ +import asyncio +import json import unittest from pathlib import Path from types import SimpleNamespace @@ -8,7 +10,9 @@ from langchain_core.language_models.fake_chat_models import FakeListChatModel import app.agent.middleware.subagents as subagent_module from app.agent.middleware.subagents import ( MoviePilotSubAgentMiddleware, + SUBAGENT_CONTROL_TOOL_NAME, SUBAGENT_TASK_TOOL_NAME, + SubAgentTaskControlMiddleware, create_subagent_middlewares, ) from app.agent.tools.tags import ToolTag @@ -25,10 +29,15 @@ class TestAgentSubagents(unittest.TestCase): stream_handler=None, ) - self.assertEqual(len(middlewares), 2) - self.assertEqual([tool.name for tool in task_tools], [SUBAGENT_TASK_TOOL_NAME]) + self.assertEqual(len(middlewares), 3) + self.assertEqual( + [tool.name for tool in task_tools], + [SUBAGENT_TASK_TOOL_NAME, SUBAGENT_CONTROL_TOOL_NAME], + ) self.assertIn("media-researcher", task_tools[0].description) self.assertIn("system-diagnostician", task_tools[0].description) + self.assertIn("action=start", task_tools[1].description) + self.assertIn("action=wait", task_tools[1].description) def test_subagent_tools_are_selected_by_tags(self): """子代理应根据工具标签筛选工具,而不是依赖工具名名单。""" @@ -83,5 +92,109 @@ class TestAgentSubagents(unittest.TestCase): self.assertEqual([], missing_tools) +class TestSubAgentTaskControlMiddleware(unittest.IsolatedAsyncioTestCase): + async def test_control_tool_starts_tasks_concurrently_and_waits(self): + """异步子代理管控工具应批量启动任务,并在 wait 时收集结果。""" + model = FakeListChatModel(responses=["ok"]) + middleware = SubAgentTaskControlMiddleware( + model=model, + profiles=subagent_module._builtin_subagent_profiles(), + tools=[], + ) + running_descriptions = [] + both_started = asyncio.Event() + allow_finish = asyncio.Event() + + async def _fake_run_task(self, *, description, subagent_type, task_id=None): + running_descriptions.append(description) + if len(running_descriptions) == 2: + both_started.set() + await allow_finish.wait() + return f"{subagent_type}:{description}:{task_id}" + + with patch.object( + subagent_module._SubAgentAgentProvider, + "run_task", + new=_fake_run_task, + ): + start_payload = json.loads( + await middleware._control_task( + action="start", + tasks=[ + { + "description": "检查媒体库", + "subagent_type": "media-researcher", + }, + { + "description": "检查下载器", + "subagent_type": "download-diagnostician", + }, + ], + ) + ) + + await asyncio.wait_for(both_started.wait(), timeout=1) + allow_finish.set() + task_ids = [task["task_id"] for task in start_payload["tasks"]] + wait_payload = json.loads( + await middleware._control_task( + action="wait", + task_ids=task_ids, + wait_mode="all", + timeout_ms=1000, + ) + ) + + self.assertTrue(start_payload["success"]) + self.assertEqual(2, len(task_ids)) + self.assertEqual(["检查媒体库", "检查下载器"], running_descriptions) + self.assertEqual( + ["completed", "completed"], + [task["status"] for task in wait_payload["tasks"]], + ) + self.assertIn("media-researcher:检查媒体库", wait_payload["tasks"][0]["result"]) + self.assertIn( + "download-diagnostician:检查下载器", + wait_payload["tasks"][1]["result"], + ) + + async def test_after_agent_cancels_unfinished_tasks(self): + """Agent 结束时应取消仍在运行的异步子代理任务。""" + model = FakeListChatModel(responses=["ok"]) + middleware = SubAgentTaskControlMiddleware( + model=model, + profiles=subagent_module._builtin_subagent_profiles(), + tools=[], + ) + task_started = asyncio.Event() + + async def _fake_run_task(self, *, description, subagent_type, task_id=None): + task_started.set() + await asyncio.Event().wait() + + with patch.object( + subagent_module._SubAgentAgentProvider, + "run_task", + new=_fake_run_task, + ): + start_payload = json.loads( + await middleware._control_task( + action="start", + description="长时间诊断", + subagent_type="system-diagnostician", + ) + ) + await asyncio.wait_for(task_started.wait(), timeout=1) + await middleware.aafter_agent({}, None) + status_payload = json.loads( + await middleware._control_task( + action="status", + task_ids=[start_payload["tasks"][0]["task_id"]], + ) + ) + + self.assertEqual("cancelled", status_payload["tasks"][0]["status"]) + + if __name__ == "__main__": unittest.main()