From 97cfcda03cf43f754c1c2d490d2c823d670ee510 Mon Sep 17 00:00:00 2001 From: jxxghp Date: Thu, 4 Jun 2026 21:06:25 +0800 Subject: [PATCH] feat(workflow): implement action contract management for inputs and outputs --- app/chain/workflow.py | 61 +++++++++++++++-- app/workflow/__init__.py | 10 +++ app/workflow/actions/__init__.py | 88 ++++++++++++++++++++++++- app/workflow/actions/add_download.py | 6 ++ app/workflow/actions/add_subscribe.py | 5 ++ app/workflow/actions/fetch_downloads.py | 6 ++ app/workflow/actions/fetch_medias.py | 4 ++ app/workflow/actions/fetch_rss.py | 4 ++ app/workflow/actions/fetch_torrents.py | 5 ++ app/workflow/actions/filter_medias.py | 5 ++ app/workflow/actions/filter_torrents.py | 5 ++ app/workflow/actions/invoke_plugin.py | 2 + app/workflow/actions/note.py | 2 + app/workflow/actions/scan_file.py | 4 ++ app/workflow/actions/scrape_file.py | 5 ++ app/workflow/actions/send_event.py | 2 + app/workflow/actions/send_message.py | 2 + app/workflow/actions/transfer_file.py | 9 +++ tests/test_workflow_actions.py | 81 +++++++++++++++++++++++ tests/test_workflow_execution.py | 45 ++++++++++++- 20 files changed, 341 insertions(+), 10 deletions(-) diff --git a/app/chain/workflow.py b/app/chain/workflow.py index 1f14fb4c..a5744f55 100644 --- a/app/chain/workflow.py +++ b/app/chain/workflow.py @@ -555,13 +555,28 @@ class WorkflowExecutor: 根据动作输入声明读取上游节点输出。 """ inputs = {} - input_paths = action.inputs or self.get_action_data_value(action, "inputs") or [] + input_paths = action.inputs or self.get_action_data_value(action, "inputs") + if not input_paths: + input_paths = [ + field["name"] for field in self.get_action_contract(action).get("inputs") or [] + ] if isinstance(input_paths, str): input_paths = [item.strip() for item in input_paths.splitlines() if item.strip()] for input_path in input_paths: - inputs[input_path] = self.resolve_context_path(input_path) + inputs[input_path] = self.resolve_action_input(input_path) return inputs + def resolve_action_input(self, input_path: str) -> Any: + """ + 解析动作输入声明。 + """ + if input_path in ActionContext.model_fields: + value = getattr(self.context, input_path, None) + if value not in (None, "", [], {}): + return value + return self.context.artifacts.get(input_path) if self.context.artifacts else value + return self.resolve_context_path(input_path) + def build_action_runtime(self, action: Action) -> dict: """ 构建传递给动作的新运行期数据。 @@ -588,9 +603,14 @@ class WorkflowExecutor: 根据动作输出声明整理当前节点输出。 """ outputs = action_result.outputs or self.extract_context_outputs(result_context) - declared_outputs = action.outputs or self.get_action_data_value(action, "outputs") + declared_outputs = self.get_action_output_declarations(action) if isinstance(declared_outputs, list): - return {key: outputs.get(key) for key in declared_outputs if outputs.get(key) not in (None, "", [], {})} + normalized_outputs = {} + for item in declared_outputs: + key = item.get("name") if isinstance(item, dict) else item + if key and outputs.get(key) not in (None, "", [], {}): + normalized_outputs[key] = outputs.get(key) + return normalized_outputs or outputs if isinstance(declared_outputs, dict): return { key: outputs.get(key) @@ -648,12 +668,28 @@ class WorkflowExecutor: """ 获取动作输出声明配置。 """ - outputs_config = action.outputs or self.get_action_data_value(action, "outputs") or {} + outputs_config = self.get_action_output_declarations(action) if isinstance(outputs_config, dict): value = outputs_config.get(output_key) or {} return value if isinstance(value, dict) else {} + if isinstance(outputs_config, list): + for item in outputs_config: + if isinstance(item, dict) and item.get("name") == output_key: + return { + key: value for key, value in item.items() + if key not in ("name", "label", "kind") and value not in (None, "", [], {}) + } return {} + def get_action_output_declarations(self, action: Action) -> Any: + """ + 获取动作输出声明,优先使用节点显式配置,其次使用动作固定契约。 + """ + outputs_config = action.outputs or self.get_action_data_value(action, "outputs") + if outputs_config: + return outputs_config + return self.get_action_contract(action).get("outputs") or {} + def get_default_merge_policy(self, action: Action, key: str, value: Any) -> str: """ 获取输出默认合并策略。 @@ -796,7 +832,20 @@ class WorkflowExecutor: """ 获取动作并发互斥键。 """ - return action.concurrency_key or self.get_action_data_value(action, "concurrency_key") + return ( + action.concurrency_key + or self.get_action_data_value(action, "concurrency_key") + or self.get_action_contract(action).get("concurrency_key") + ) + + def get_action_contract(self, action: Action) -> dict: + """ + 获取动作固定输入输出契约。 + """ + get_contract = getattr(self.workflowmanager, "get_action_contract", None) + if not get_contract: + return {} + return get_contract(action.type) or {} def get_flow_condition(self, flow: ActionFlow) -> Optional[str]: """ diff --git a/app/workflow/__init__.py b/app/workflow/__init__.py index eeb51ede..415d780e 100644 --- a/app/workflow/__init__.py +++ b/app/workflow/__init__.py @@ -242,6 +242,7 @@ class WorkFlowManager(metaclass=Singleton): "type": key, "name": action.name, "description": action.description, + "contract": action.get_contract(), "data": { "label": action.name, **action.data @@ -249,6 +250,15 @@ class WorkFlowManager(metaclass=Singleton): } for key, action in self._actions.items() ] + def get_action_contract(self, action_type: str) -> dict: + """ + 获取动作输入输出契约。 + """ + action = self._actions.get(action_type) + if not action or not hasattr(action, "get_contract"): + return {} + return action.get_contract() + def update_workflow_event(self, workflow: Workflow): """ 更新工作流事件触发器 diff --git a/app/workflow/actions/__init__.py b/app/workflow/actions/__init__.py index 2f5a481c..8dfc01c6 100644 --- a/app/workflow/actions/__init__.py +++ b/app/workflow/actions/__init__.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Union +from typing import Any, Union from app.chain import ChainBase from app.db.systemconfig_oper import SystemConfigOper @@ -23,6 +23,8 @@ class BaseAction(ABC): _message = "" # 缓存键值 _cache_key = "WorkflowCache-%s" + # 动作输入输出契约,由具体动作按需覆盖 + contract = {} def __init__(self, action_id: str): self._action_id = action_id @@ -48,6 +50,41 @@ class BaseAction(ABC): def data(cls) -> dict: # noqa pass + @classmethod + def get_contract(cls) -> dict: + """ + 获取动作输入输出契约。 + """ + contract = getattr(cls, "contract", None) or {} + input_fields = cls._build_contract_fields(contract.get("inputs") or []) + output_fields = cls._build_contract_fields(contract.get("outputs") or []) + return { + "inputs": input_fields, + "outputs": output_fields, + "condition_fields": output_fields, + "concurrency_key": contract.get("concurrency_key"), + } + + @classmethod + def _build_contract_fields(cls, fields: list) -> list: + """ + 标准化动作契约字段。 + """ + result = [] + for field in fields: + if isinstance(field, str): + field = {"name": field} + if not isinstance(field, dict) or not field.get("name"): + continue + result.append({ + "name": field["name"], + "label": field.get("label") or field["name"], + "kind": field.get("kind") or "scalar", + "merge": field.get("merge"), + "identity": field.get("identity"), + }) + return result + @property def done(self) -> bool: """ @@ -115,10 +152,55 @@ class BaseAction(ABC): """ 使用显式输入与运行期信息执行动作。 """ - _ = inputs, runtime + self._apply_inputs_to_context(inputs=inputs, context=context) + self._apply_runtime_to_context(runtime=runtime, context=context) result_context = self.execute(workflow_id, params, context) + outputs = self._extract_outputs_from_context(result_context) return ActionResult( success=self.success, message=self.message, - context=result_context + context=result_context, + outputs=outputs ) + + def _apply_inputs_to_context(self, inputs: dict, context: ActionContext) -> None: + """ + 将显式输入回填到旧版上下文字段,兼容仍读取 context 的动作。 + """ + inputs = inputs or {} + for field in self.get_contract().get("inputs") or []: + missing = object() + field_name = field["name"] + value = inputs.get(field_name, missing) + if value is missing: + # 兼容旧版节点输入路径,例如 outputs.A.torrents。 + for input_key, input_value in inputs.items(): + if isinstance(input_key, str) and input_key.split(".")[-1] == field_name: + value = input_value + break + if value is not missing: + setattr(context, field_name, value) + + @staticmethod + def _apply_runtime_to_context(runtime: dict, context: ActionContext) -> None: + """ + 将运行期信息写入 runtime_state,供动作和执行状态读取。 + """ + if not runtime: + return + context.runtime_state = context.runtime_state or {} + context.runtime_state["current_action_runtime"] = { + key: value for key, value in runtime.items() + if key != "cancel_token" + } + + def _extract_outputs_from_context(self, context: ActionContext) -> dict[str, Any]: + """ + 按动作契约从上下文提取输出。 + """ + outputs = {} + for field in self.get_contract().get("outputs") or []: + value = getattr(context, field["name"], None) + if value not in (None, "", [], {}): + outputs[field["name"]] = value + return outputs diff --git a/app/workflow/actions/add_download.py b/app/workflow/actions/add_download.py index c097c274..e400e35f 100644 --- a/app/workflow/actions/add_download.py +++ b/app/workflow/actions/add_download.py @@ -26,6 +26,12 @@ class AddDownloadAction(BaseAction): 添加下载资源 """ + contract = { + "inputs": [{"name": "torrents", "label": "资源", "kind": "list"}], + "outputs": [{"name": "downloads", "label": "下载任务", "kind": "list"}], + "concurrency_key": "download", + } + def __init__(self, action_id: str): super().__init__(action_id) self._added_downloads = [] diff --git a/app/workflow/actions/add_subscribe.py b/app/workflow/actions/add_subscribe.py index 77ed32c0..10460328 100644 --- a/app/workflow/actions/add_subscribe.py +++ b/app/workflow/actions/add_subscribe.py @@ -19,6 +19,11 @@ class AddSubscribeAction(BaseAction): 添加订阅 """ + contract = { + "inputs": [{"name": "medias", "label": "媒体", "kind": "list"}], + "outputs": [{"name": "subscribes", "label": "订阅", "kind": "list"}], + } + def __init__(self, action_id: str): super().__init__(action_id) self._added_subscribes = [] diff --git a/app/workflow/actions/fetch_downloads.py b/app/workflow/actions/fetch_downloads.py index aebeb927..8689052d 100644 --- a/app/workflow/actions/fetch_downloads.py +++ b/app/workflow/actions/fetch_downloads.py @@ -16,6 +16,12 @@ class FetchDownloadsAction(BaseAction): 获取下载任务 """ + contract = { + "inputs": [{"name": "downloads", "label": "下载任务", "kind": "list"}], + "outputs": [{"name": "downloads", "label": "下载任务", "kind": "list", "merge": "replace"}], + "concurrency_key": "download", + } + def __init__(self, action_id: str): super().__init__(action_id) self._downloads = [] diff --git a/app/workflow/actions/fetch_medias.py b/app/workflow/actions/fetch_medias.py index ff6cc1f4..d2c1fdf4 100644 --- a/app/workflow/actions/fetch_medias.py +++ b/app/workflow/actions/fetch_medias.py @@ -27,6 +27,10 @@ class FetchMediasAction(BaseAction): 获取媒体数据 """ + contract = { + "outputs": [{"name": "medias", "label": "媒体", "kind": "list"}], + } + def __init__(self, action_id: str): super().__init__(action_id) diff --git a/app/workflow/actions/fetch_rss.py b/app/workflow/actions/fetch_rss.py index 3fbd5d30..6a64c07f 100644 --- a/app/workflow/actions/fetch_rss.py +++ b/app/workflow/actions/fetch_rss.py @@ -30,6 +30,10 @@ class FetchRssAction(BaseAction): 获取RSS资源列表 """ + contract = { + "outputs": [{"name": "torrents", "label": "资源", "kind": "list"}], + } + def __init__(self, action_id: str): super().__init__(action_id) self._rss_torrents = [] diff --git a/app/workflow/actions/fetch_torrents.py b/app/workflow/actions/fetch_torrents.py index cad332d3..c8a6c5dd 100644 --- a/app/workflow/actions/fetch_torrents.py +++ b/app/workflow/actions/fetch_torrents.py @@ -30,6 +30,11 @@ class FetchTorrentsAction(BaseAction): 搜索站点资源 """ + contract = { + "inputs": [{"name": "medias", "label": "媒体", "kind": "list"}], + "outputs": [{"name": "torrents", "label": "资源", "kind": "list"}], + } + def __init__(self, action_id: str): super().__init__(action_id) self._torrents = [] diff --git a/app/workflow/actions/filter_medias.py b/app/workflow/actions/filter_medias.py index 94667172..b4234e6a 100644 --- a/app/workflow/actions/filter_medias.py +++ b/app/workflow/actions/filter_medias.py @@ -22,6 +22,11 @@ class FilterMediasAction(BaseAction): 过滤媒体数据 """ + contract = { + "inputs": [{"name": "medias", "label": "媒体", "kind": "list"}], + "outputs": [{"name": "medias", "label": "媒体", "kind": "list", "merge": "replace"}], + } + def __init__(self, action_id: str): super().__init__(action_id) self._medias = [] diff --git a/app/workflow/actions/filter_torrents.py b/app/workflow/actions/filter_torrents.py index 23eed7e4..46015a35 100644 --- a/app/workflow/actions/filter_torrents.py +++ b/app/workflow/actions/filter_torrents.py @@ -27,6 +27,11 @@ class FilterTorrentsAction(BaseAction): 过滤资源数据 """ + contract = { + "inputs": [{"name": "torrents", "label": "资源", "kind": "list"}], + "outputs": [{"name": "torrents", "label": "资源", "kind": "list", "merge": "replace"}], + } + def __init__(self, action_id: str): super().__init__(action_id) self._torrents = [] diff --git a/app/workflow/actions/invoke_plugin.py b/app/workflow/actions/invoke_plugin.py index ce50a163..0a652c3c 100644 --- a/app/workflow/actions/invoke_plugin.py +++ b/app/workflow/actions/invoke_plugin.py @@ -20,6 +20,8 @@ class InvokePluginAction(BaseAction): 调用插件 """ + contract = {} + def __init__(self, action_id: str): super().__init__(action_id) self._success = False diff --git a/app/workflow/actions/note.py b/app/workflow/actions/note.py index 13086757..ed9d1767 100644 --- a/app/workflow/actions/note.py +++ b/app/workflow/actions/note.py @@ -7,6 +7,8 @@ class NoteAction(BaseAction): 备注 """ + contract = {} + @classmethod @property def name(cls) -> str: # noqa diff --git a/app/workflow/actions/scan_file.py b/app/workflow/actions/scan_file.py index eb64219e..d7935a61 100644 --- a/app/workflow/actions/scan_file.py +++ b/app/workflow/actions/scan_file.py @@ -24,6 +24,10 @@ class ScanFileAction(BaseAction): 整理文件 """ + contract = { + "outputs": [{"name": "fileitems", "label": "文件", "kind": "list"}], + } + def __init__(self, action_id: str): super().__init__(action_id) self._fileitems = [] diff --git a/app/workflow/actions/scrape_file.py b/app/workflow/actions/scrape_file.py index 7f4ceab3..30f29887 100644 --- a/app/workflow/actions/scrape_file.py +++ b/app/workflow/actions/scrape_file.py @@ -18,6 +18,11 @@ class ScrapeFileAction(BaseAction): 刮削文件 """ + contract = { + "inputs": [{"name": "fileitems", "label": "文件", "kind": "list"}], + "outputs": [{"name": "fileitems", "label": "文件", "kind": "list"}], + } + def __init__(self, action_id: str): super().__init__(action_id) self._scraped_files = [] diff --git a/app/workflow/actions/send_event.py b/app/workflow/actions/send_event.py index 7d96e6f0..c03ca0c1 100644 --- a/app/workflow/actions/send_event.py +++ b/app/workflow/actions/send_event.py @@ -16,6 +16,8 @@ class SendEventAction(BaseAction): 发送事件 """ + contract = {} + @classmethod @property def name(cls) -> str: # noqa diff --git a/app/workflow/actions/send_message.py b/app/workflow/actions/send_message.py index 2bf0bda1..c67a8ca2 100644 --- a/app/workflow/actions/send_message.py +++ b/app/workflow/actions/send_message.py @@ -20,6 +20,8 @@ class SendMessageAction(BaseAction): 发送消息 """ + contract = {} + def __init__(self, action_id: str): super().__init__(action_id) diff --git a/app/workflow/actions/transfer_file.py b/app/workflow/actions/transfer_file.py index 08e95991..9f3d16c2 100644 --- a/app/workflow/actions/transfer_file.py +++ b/app/workflow/actions/transfer_file.py @@ -26,6 +26,15 @@ class TransferFileAction(BaseAction): 整理文件 """ + contract = { + "inputs": [ + {"name": "downloads", "label": "下载任务", "kind": "list"}, + {"name": "fileitems", "label": "文件", "kind": "list"}, + ], + "outputs": [{"name": "fileitems", "label": "文件", "kind": "list"}], + "concurrency_key": "transfer", + } + def __init__(self, action_id: str): super().__init__(action_id) self._fileitems = [] diff --git a/tests/test_workflow_actions.py b/tests/test_workflow_actions.py index f4cdbddc..770e927d 100644 --- a/tests/test_workflow_actions.py +++ b/tests/test_workflow_actions.py @@ -1,10 +1,14 @@ from types import SimpleNamespace from app.schemas import ActionContext, DownloadTask, FileItem +from app.schemas.workflow import ActionResult +from app.workflow.actions import BaseAction from app.workflow.actions import fetch_downloads as fetch_downloads_module from app.workflow.actions import scrape_file as scrape_file_module from app.workflow.actions.fetch_downloads import FetchDownloadsAction from app.workflow.actions.scrape_file import ScrapeFileAction +from app.workflow.actions.fetch_rss import FetchRssAction +from app.workflow import WorkFlowManager def test_fetch_downloads_updates_context_downloads(monkeypatch): @@ -78,3 +82,80 @@ def test_scrape_file_keeps_workflow_action_context(monkeypatch): assert result is context assert result.fileitems[0].path == "/library/movie.mkv" assert scraped == [("/library/movie.mkv", "meta", "media")] + + +def test_execute_with_inputs_maps_contract_inputs_outputs_and_runtime(monkeypatch): + """新版动作桥接方法应按契约映射输入、输出和运行期信息。""" + + class ContractAction(BaseAction): + """测试动作契约桥接。""" + + contract = { + "inputs": [{"name": "torrents", "label": "资源", "kind": "list"}], + "outputs": [{"name": "downloads", "label": "下载任务", "kind": "list"}], + } + + @classmethod + @property + def name(cls) -> str: + return "契约动作" + + @classmethod + @property + def description(cls) -> str: + return "测试契约动作" + + @classmethod + @property + def data(cls) -> dict: + return {} + + @property + def success(self) -> bool: + return True + + def execute(self, workflow_id: int, params: dict, context: ActionContext) -> ActionContext: + """执行测试动作。""" + _ = workflow_id, params + context.downloads = [ + DownloadTask(download_id=f"{item}-hash", downloader="qbittorrent") + for item in context.torrents + ] + self.job_done("完成") + return context + + result = ContractAction("contract").execute_with_inputs( + workflow_id=1, + params={}, + inputs={"torrents": ["movie"]}, + runtime={"attempt": 1, "max_attempts": 1, "cancel_token": object()}, + context=ActionContext(), + ) + + assert isinstance(result, ActionResult) + assert result.outputs["downloads"][0].download_id == "movie-hash" + assert result.context.runtime_state["current_action_runtime"] == { + "attempt": 1, + "max_attempts": 1, + } + + path_result = ContractAction("contract").execute_with_inputs( + workflow_id=1, + params={}, + inputs={"outputs.FetchRssAction.torrents": ["legacy"]}, + runtime={}, + context=ActionContext(), + ) + + assert path_result.outputs["downloads"][0].download_id == "legacy-hash" + + +def test_workflow_manager_list_actions_exposes_contract(): + """动作列表应返回固定输入输出契约。""" + manager = object.__new__(WorkFlowManager) + manager._actions = {"FetchRssAction": FetchRssAction} + + actions = manager.list_actions() + + assert actions[0]["contract"]["outputs"][0]["name"] == "torrents" + assert actions[0]["contract"]["condition_fields"][0]["label"] == "资源" diff --git a/tests/test_workflow_execution.py b/tests/test_workflow_execution.py index b93288a4..7da5623a 100644 --- a/tests/test_workflow_execution.py +++ b/tests/test_workflow_execution.py @@ -40,9 +40,10 @@ def _encoded_context(context: ActionContext) -> dict: class _FakeWorkflowManager: """记录执行动作的工作流管理器。""" - def __init__(self, calls, results=None): + def __init__(self, calls, results=None, contracts=None): self.calls = calls self.results = results or {} + self.contracts = contracts or {} self.received_inputs = [] def execute(self, workflow_id, action, context=None, inputs=None, runtime=None, cancel_token=None): @@ -61,6 +62,10 @@ class _FakeWorkflowManager: result = self.execute(workflow_id, action, context) return result.success, result.message, result.context + def get_action_contract(self, action_type): + """获取伪动作契约。""" + return self.contracts.get(action_type) or {} + def test_workflow_executor_resumes_downstream_nodes(monkeypatch): """恢复执行时应释放已完成节点的后继节点。""" @@ -346,6 +351,44 @@ def test_workflow_executor_passes_declared_inputs(monkeypatch): } +def test_workflow_executor_uses_contract_inputs(monkeypatch): + """未手写输入声明时应按动作契约读取上下文字段。""" + calls = [] + fake_manager = _FakeWorkflowManager( + calls, + contracts={ + "NeedsTorrentsAction": { + "inputs": [{"name": "torrents", "label": "资源", "kind": "list"}], + "outputs": [], + } + }, + results={ + "A": lambda action, context: ActionResult( + success=True, + message=f"{action.name}完成", + context=context, + outputs={"torrents": ["a", "b"]} + ) + } + ) + workflow = _build_workflow( + actions=[ + {"id": "A", "type": "FakeAction", "name": "动作A", "data": {}}, + {"id": "B", "type": "NeedsTorrentsAction", "name": "动作B", "data": {}}, + ], + ) + + monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager) + monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None) + monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False) + + executor = workflow_module.WorkflowExecutor(workflow) + executor.execute() + + b_inputs = [item for action_id, item, _, _ in fake_manager.received_inputs if action_id == "B"][0] + assert b_inputs == {"torrents": ["a", "b"]} + + def test_workflow_executor_persists_structured_state(monkeypatch): """步骤回调应收到可持久化的结构化执行状态。""" calls = []