feat(workflow): implement action contract management for inputs and outputs

This commit is contained in:
jxxghp
2026-06-04 21:06:25 +08:00
parent a2984530f8
commit 97cfcda03c
20 changed files with 341 additions and 10 deletions

View File

@@ -555,13 +555,28 @@ class WorkflowExecutor:
根据动作输入声明读取上游节点输出。
"""
inputs = {}
input_paths = action.inputs or self.get_action_data_value(action, "inputs") or []
input_paths = action.inputs or self.get_action_data_value(action, "inputs")
if not input_paths:
input_paths = [
field["name"] for field in self.get_action_contract(action).get("inputs") or []
]
if isinstance(input_paths, str):
input_paths = [item.strip() for item in input_paths.splitlines() if item.strip()]
for input_path in input_paths:
inputs[input_path] = self.resolve_context_path(input_path)
inputs[input_path] = self.resolve_action_input(input_path)
return inputs
def resolve_action_input(self, input_path: str) -> Any:
"""
解析动作输入声明。
"""
if input_path in ActionContext.model_fields:
value = getattr(self.context, input_path, None)
if value not in (None, "", [], {}):
return value
return self.context.artifacts.get(input_path) if self.context.artifacts else value
return self.resolve_context_path(input_path)
def build_action_runtime(self, action: Action) -> dict:
"""
构建传递给动作的新运行期数据。
@@ -588,9 +603,14 @@ class WorkflowExecutor:
根据动作输出声明整理当前节点输出。
"""
outputs = action_result.outputs or self.extract_context_outputs(result_context)
declared_outputs = action.outputs or self.get_action_data_value(action, "outputs")
declared_outputs = self.get_action_output_declarations(action)
if isinstance(declared_outputs, list):
return {key: outputs.get(key) for key in declared_outputs if outputs.get(key) not in (None, "", [], {})}
normalized_outputs = {}
for item in declared_outputs:
key = item.get("name") if isinstance(item, dict) else item
if key and outputs.get(key) not in (None, "", [], {}):
normalized_outputs[key] = outputs.get(key)
return normalized_outputs or outputs
if isinstance(declared_outputs, dict):
return {
key: outputs.get(key)
@@ -648,12 +668,28 @@ class WorkflowExecutor:
"""
获取动作输出声明配置。
"""
outputs_config = action.outputs or self.get_action_data_value(action, "outputs") or {}
outputs_config = self.get_action_output_declarations(action)
if isinstance(outputs_config, dict):
value = outputs_config.get(output_key) or {}
return value if isinstance(value, dict) else {}
if isinstance(outputs_config, list):
for item in outputs_config:
if isinstance(item, dict) and item.get("name") == output_key:
return {
key: value for key, value in item.items()
if key not in ("name", "label", "kind") and value not in (None, "", [], {})
}
return {}
def get_action_output_declarations(self, action: Action) -> Any:
"""
获取动作输出声明,优先使用节点显式配置,其次使用动作固定契约。
"""
outputs_config = action.outputs or self.get_action_data_value(action, "outputs")
if outputs_config:
return outputs_config
return self.get_action_contract(action).get("outputs") or {}
def get_default_merge_policy(self, action: Action, key: str, value: Any) -> str:
"""
获取输出默认合并策略。
@@ -796,7 +832,20 @@ class WorkflowExecutor:
"""
获取动作并发互斥键。
"""
return action.concurrency_key or self.get_action_data_value(action, "concurrency_key")
return (
action.concurrency_key
or self.get_action_data_value(action, "concurrency_key")
or self.get_action_contract(action).get("concurrency_key")
)
def get_action_contract(self, action: Action) -> dict:
"""
获取动作固定输入输出契约。
"""
get_contract = getattr(self.workflowmanager, "get_action_contract", None)
if not get_contract:
return {}
return get_contract(action.type) or {}
def get_flow_condition(self, flow: ActionFlow) -> Optional[str]:
"""

View File

@@ -242,6 +242,7 @@ class WorkFlowManager(metaclass=Singleton):
"type": key,
"name": action.name,
"description": action.description,
"contract": action.get_contract(),
"data": {
"label": action.name,
**action.data
@@ -249,6 +250,15 @@ class WorkFlowManager(metaclass=Singleton):
} for key, action in self._actions.items()
]
def get_action_contract(self, action_type: str) -> dict:
"""
获取动作输入输出契约。
"""
action = self._actions.get(action_type)
if not action or not hasattr(action, "get_contract"):
return {}
return action.get_contract()
def update_workflow_event(self, workflow: Workflow):
"""
更新工作流事件触发器

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Union
from typing import Any, Union
from app.chain import ChainBase
from app.db.systemconfig_oper import SystemConfigOper
@@ -23,6 +23,8 @@ class BaseAction(ABC):
_message = ""
# 缓存键值
_cache_key = "WorkflowCache-%s"
# 动作输入输出契约,由具体动作按需覆盖
contract = {}
def __init__(self, action_id: str):
self._action_id = action_id
@@ -48,6 +50,41 @@ class BaseAction(ABC):
def data(cls) -> dict: # noqa
pass
@classmethod
def get_contract(cls) -> dict:
"""
获取动作输入输出契约。
"""
contract = getattr(cls, "contract", None) or {}
input_fields = cls._build_contract_fields(contract.get("inputs") or [])
output_fields = cls._build_contract_fields(contract.get("outputs") or [])
return {
"inputs": input_fields,
"outputs": output_fields,
"condition_fields": output_fields,
"concurrency_key": contract.get("concurrency_key"),
}
@classmethod
def _build_contract_fields(cls, fields: list) -> list:
"""
标准化动作契约字段。
"""
result = []
for field in fields:
if isinstance(field, str):
field = {"name": field}
if not isinstance(field, dict) or not field.get("name"):
continue
result.append({
"name": field["name"],
"label": field.get("label") or field["name"],
"kind": field.get("kind") or "scalar",
"merge": field.get("merge"),
"identity": field.get("identity"),
})
return result
@property
def done(self) -> bool:
"""
@@ -115,10 +152,55 @@ class BaseAction(ABC):
"""
使用显式输入与运行期信息执行动作。
"""
_ = inputs, runtime
self._apply_inputs_to_context(inputs=inputs, context=context)
self._apply_runtime_to_context(runtime=runtime, context=context)
result_context = self.execute(workflow_id, params, context)
outputs = self._extract_outputs_from_context(result_context)
return ActionResult(
success=self.success,
message=self.message,
context=result_context
context=result_context,
outputs=outputs
)
def _apply_inputs_to_context(self, inputs: dict, context: ActionContext) -> None:
"""
将显式输入回填到旧版上下文字段,兼容仍读取 context 的动作。
"""
inputs = inputs or {}
for field in self.get_contract().get("inputs") or []:
missing = object()
field_name = field["name"]
value = inputs.get(field_name, missing)
if value is missing:
# 兼容旧版节点输入路径,例如 outputs.A.torrents。
for input_key, input_value in inputs.items():
if isinstance(input_key, str) and input_key.split(".")[-1] == field_name:
value = input_value
break
if value is not missing:
setattr(context, field_name, value)
@staticmethod
def _apply_runtime_to_context(runtime: dict, context: ActionContext) -> None:
"""
将运行期信息写入 runtime_state供动作和执行状态读取。
"""
if not runtime:
return
context.runtime_state = context.runtime_state or {}
context.runtime_state["current_action_runtime"] = {
key: value for key, value in runtime.items()
if key != "cancel_token"
}
def _extract_outputs_from_context(self, context: ActionContext) -> dict[str, Any]:
"""
按动作契约从上下文提取输出。
"""
outputs = {}
for field in self.get_contract().get("outputs") or []:
value = getattr(context, field["name"], None)
if value not in (None, "", [], {}):
outputs[field["name"]] = value
return outputs

View File

@@ -26,6 +26,12 @@ class AddDownloadAction(BaseAction):
添加下载资源
"""
contract = {
"inputs": [{"name": "torrents", "label": "资源", "kind": "list"}],
"outputs": [{"name": "downloads", "label": "下载任务", "kind": "list"}],
"concurrency_key": "download",
}
def __init__(self, action_id: str):
super().__init__(action_id)
self._added_downloads = []

View File

@@ -19,6 +19,11 @@ class AddSubscribeAction(BaseAction):
添加订阅
"""
contract = {
"inputs": [{"name": "medias", "label": "媒体", "kind": "list"}],
"outputs": [{"name": "subscribes", "label": "订阅", "kind": "list"}],
}
def __init__(self, action_id: str):
super().__init__(action_id)
self._added_subscribes = []

View File

@@ -16,6 +16,12 @@ class FetchDownloadsAction(BaseAction):
获取下载任务
"""
contract = {
"inputs": [{"name": "downloads", "label": "下载任务", "kind": "list"}],
"outputs": [{"name": "downloads", "label": "下载任务", "kind": "list", "merge": "replace"}],
"concurrency_key": "download",
}
def __init__(self, action_id: str):
super().__init__(action_id)
self._downloads = []

View File

@@ -27,6 +27,10 @@ class FetchMediasAction(BaseAction):
获取媒体数据
"""
contract = {
"outputs": [{"name": "medias", "label": "媒体", "kind": "list"}],
}
def __init__(self, action_id: str):
super().__init__(action_id)

View File

@@ -30,6 +30,10 @@ class FetchRssAction(BaseAction):
获取RSS资源列表
"""
contract = {
"outputs": [{"name": "torrents", "label": "资源", "kind": "list"}],
}
def __init__(self, action_id: str):
super().__init__(action_id)
self._rss_torrents = []

View File

@@ -30,6 +30,11 @@ class FetchTorrentsAction(BaseAction):
搜索站点资源
"""
contract = {
"inputs": [{"name": "medias", "label": "媒体", "kind": "list"}],
"outputs": [{"name": "torrents", "label": "资源", "kind": "list"}],
}
def __init__(self, action_id: str):
super().__init__(action_id)
self._torrents = []

View File

@@ -22,6 +22,11 @@ class FilterMediasAction(BaseAction):
过滤媒体数据
"""
contract = {
"inputs": [{"name": "medias", "label": "媒体", "kind": "list"}],
"outputs": [{"name": "medias", "label": "媒体", "kind": "list", "merge": "replace"}],
}
def __init__(self, action_id: str):
super().__init__(action_id)
self._medias = []

View File

@@ -27,6 +27,11 @@ class FilterTorrentsAction(BaseAction):
过滤资源数据
"""
contract = {
"inputs": [{"name": "torrents", "label": "资源", "kind": "list"}],
"outputs": [{"name": "torrents", "label": "资源", "kind": "list", "merge": "replace"}],
}
def __init__(self, action_id: str):
super().__init__(action_id)
self._torrents = []

View File

@@ -20,6 +20,8 @@ class InvokePluginAction(BaseAction):
调用插件
"""
contract = {}
def __init__(self, action_id: str):
super().__init__(action_id)
self._success = False

View File

@@ -7,6 +7,8 @@ class NoteAction(BaseAction):
备注
"""
contract = {}
@classmethod
@property
def name(cls) -> str: # noqa

View File

@@ -24,6 +24,10 @@ class ScanFileAction(BaseAction):
整理文件
"""
contract = {
"outputs": [{"name": "fileitems", "label": "文件", "kind": "list"}],
}
def __init__(self, action_id: str):
super().__init__(action_id)
self._fileitems = []

View File

@@ -18,6 +18,11 @@ class ScrapeFileAction(BaseAction):
刮削文件
"""
contract = {
"inputs": [{"name": "fileitems", "label": "文件", "kind": "list"}],
"outputs": [{"name": "fileitems", "label": "文件", "kind": "list"}],
}
def __init__(self, action_id: str):
super().__init__(action_id)
self._scraped_files = []

View File

@@ -16,6 +16,8 @@ class SendEventAction(BaseAction):
发送事件
"""
contract = {}
@classmethod
@property
def name(cls) -> str: # noqa

View File

@@ -20,6 +20,8 @@ class SendMessageAction(BaseAction):
发送消息
"""
contract = {}
def __init__(self, action_id: str):
super().__init__(action_id)

View File

@@ -26,6 +26,15 @@ class TransferFileAction(BaseAction):
整理文件
"""
contract = {
"inputs": [
{"name": "downloads", "label": "下载任务", "kind": "list"},
{"name": "fileitems", "label": "文件", "kind": "list"},
],
"outputs": [{"name": "fileitems", "label": "文件", "kind": "list"}],
"concurrency_key": "transfer",
}
def __init__(self, action_id: str):
super().__init__(action_id)
self._fileitems = []

View File

@@ -1,10 +1,14 @@
from types import SimpleNamespace
from app.schemas import ActionContext, DownloadTask, FileItem
from app.schemas.workflow import ActionResult
from app.workflow.actions import BaseAction
from app.workflow.actions import fetch_downloads as fetch_downloads_module
from app.workflow.actions import scrape_file as scrape_file_module
from app.workflow.actions.fetch_downloads import FetchDownloadsAction
from app.workflow.actions.scrape_file import ScrapeFileAction
from app.workflow.actions.fetch_rss import FetchRssAction
from app.workflow import WorkFlowManager
def test_fetch_downloads_updates_context_downloads(monkeypatch):
@@ -78,3 +82,80 @@ def test_scrape_file_keeps_workflow_action_context(monkeypatch):
assert result is context
assert result.fileitems[0].path == "/library/movie.mkv"
assert scraped == [("/library/movie.mkv", "meta", "media")]
def test_execute_with_inputs_maps_contract_inputs_outputs_and_runtime(monkeypatch):
"""新版动作桥接方法应按契约映射输入、输出和运行期信息。"""
class ContractAction(BaseAction):
"""测试动作契约桥接。"""
contract = {
"inputs": [{"name": "torrents", "label": "资源", "kind": "list"}],
"outputs": [{"name": "downloads", "label": "下载任务", "kind": "list"}],
}
@classmethod
@property
def name(cls) -> str:
return "契约动作"
@classmethod
@property
def description(cls) -> str:
return "测试契约动作"
@classmethod
@property
def data(cls) -> dict:
return {}
@property
def success(self) -> bool:
return True
def execute(self, workflow_id: int, params: dict, context: ActionContext) -> ActionContext:
"""执行测试动作。"""
_ = workflow_id, params
context.downloads = [
DownloadTask(download_id=f"{item}-hash", downloader="qbittorrent")
for item in context.torrents
]
self.job_done("完成")
return context
result = ContractAction("contract").execute_with_inputs(
workflow_id=1,
params={},
inputs={"torrents": ["movie"]},
runtime={"attempt": 1, "max_attempts": 1, "cancel_token": object()},
context=ActionContext(),
)
assert isinstance(result, ActionResult)
assert result.outputs["downloads"][0].download_id == "movie-hash"
assert result.context.runtime_state["current_action_runtime"] == {
"attempt": 1,
"max_attempts": 1,
}
path_result = ContractAction("contract").execute_with_inputs(
workflow_id=1,
params={},
inputs={"outputs.FetchRssAction.torrents": ["legacy"]},
runtime={},
context=ActionContext(),
)
assert path_result.outputs["downloads"][0].download_id == "legacy-hash"
def test_workflow_manager_list_actions_exposes_contract():
"""动作列表应返回固定输入输出契约。"""
manager = object.__new__(WorkFlowManager)
manager._actions = {"FetchRssAction": FetchRssAction}
actions = manager.list_actions()
assert actions[0]["contract"]["outputs"][0]["name"] == "torrents"
assert actions[0]["contract"]["condition_fields"][0]["label"] == "资源"

View File

@@ -40,9 +40,10 @@ def _encoded_context(context: ActionContext) -> dict:
class _FakeWorkflowManager:
"""记录执行动作的工作流管理器。"""
def __init__(self, calls, results=None):
def __init__(self, calls, results=None, contracts=None):
self.calls = calls
self.results = results or {}
self.contracts = contracts or {}
self.received_inputs = []
def execute(self, workflow_id, action, context=None, inputs=None, runtime=None, cancel_token=None):
@@ -61,6 +62,10 @@ class _FakeWorkflowManager:
result = self.execute(workflow_id, action, context)
return result.success, result.message, result.context
def get_action_contract(self, action_type):
"""获取伪动作契约。"""
return self.contracts.get(action_type) or {}
def test_workflow_executor_resumes_downstream_nodes(monkeypatch):
"""恢复执行时应释放已完成节点的后继节点。"""
@@ -346,6 +351,44 @@ def test_workflow_executor_passes_declared_inputs(monkeypatch):
}
def test_workflow_executor_uses_contract_inputs(monkeypatch):
"""未手写输入声明时应按动作契约读取上下文字段。"""
calls = []
fake_manager = _FakeWorkflowManager(
calls,
contracts={
"NeedsTorrentsAction": {
"inputs": [{"name": "torrents", "label": "资源", "kind": "list"}],
"outputs": [],
}
},
results={
"A": lambda action, context: ActionResult(
success=True,
message=f"{action.name}完成",
context=context,
outputs={"torrents": ["a", "b"]}
)
}
)
workflow = _build_workflow(
actions=[
{"id": "A", "type": "FakeAction", "name": "动作A", "data": {}},
{"id": "B", "type": "NeedsTorrentsAction", "name": "动作B", "data": {}},
],
)
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
executor = workflow_module.WorkflowExecutor(workflow)
executor.execute()
b_inputs = [item for action_id, item, _, _ in fake_manager.received_inputs if action_id == "B"][0]
assert b_inputs == {"torrents": ["a", "b"]}
def test_workflow_executor_persists_structured_state(monkeypatch):
"""步骤回调应收到可持久化的结构化执行状态。"""
calls = []