From fc8933c64894f177f0e8d7ab60885a26c234df79 Mon Sep 17 00:00:00 2001 From: jxxghp Date: Fri, 5 Jun 2026 00:41:02 +0800 Subject: [PATCH] feat(workflow): enhance workflow context serialization and execution state management --- app/chain/workflow.py | 134 +++++++++++++++++++------------ app/db/workflow_oper.py | 2 +- tests/test_workflow_execution.py | 114 ++++++++++++++++++++++++++ 3 files changed, 199 insertions(+), 51 deletions(-) diff --git a/app/chain/workflow.py b/app/chain/workflow.py index 7e47590e..f726ce44 100644 --- a/app/chain/workflow.py +++ b/app/chain/workflow.py @@ -6,11 +6,11 @@ import pickle import threading from collections import defaultdict, deque from concurrent.futures import ThreadPoolExecutor -from datetime import datetime +from datetime import date, datetime from time import sleep from typing import Any, Callable, List, Optional, Tuple -from fastapi.encoders import jsonable_encoder +from pydantic import BaseModel from app.chain import ChainBase from app.core.config import global_vars @@ -22,9 +22,64 @@ from app.schemas import ActionContext, ActionFlow, Action, ActionExecution, Acti from app.schemas.types import EventType from app.workflow import WorkFlowManager - ARTIFACT_FIELDS = {"torrents", "medias", "fileitems", "downloads", "sites", "subscribes"} DEFAULT_WORKFLOW_MAX_WORKERS = 4 +CIRCULAR_REFERENCE_PLACEHOLDER = "[Circular]" + + +def _serialize_workflow_key(key: Any) -> Any: + """将映射键转换为 JSON 安全值。""" + if key is None or isinstance(key, (str, int, float, bool)): + return key + return str(key) + + +def _serialize_workflow_value(value: Any, stack: Optional[set[int]] = None) -> Any: + """把工作流上下文和值转换为可持久化的 JSON 结构。""" + if value is None or isinstance(value, (str, int, float, bool)): + return value + if isinstance(value, (date, datetime)): + return value.isoformat() + + stack = stack or set() + object_id = id(value) + if object_id in stack: + return CIRCULAR_REFERENCE_PLACEHOLDER + + if isinstance(value, BaseModel): + stack.add(object_id) + try: + return { + field_name: _serialize_workflow_value(getattr(value, field_name, None), stack) + for field_name in value.__class__.model_fields + } + finally: + stack.remove(object_id) + + if isinstance(value, dict): + stack.add(object_id) + try: + return { + _serialize_workflow_key(key): _serialize_workflow_value(item, stack) + for key, item in value.items() + } + finally: + stack.remove(object_id) + + if isinstance(value, (list, tuple, set)): + stack.add(object_id) + try: + return [_serialize_workflow_value(item, stack) for item in value] + finally: + stack.remove(object_id) + + return str(value) + + +def _serialize_workflow_context(context: ActionContext) -> dict: + """构建可写入数据库的工作流上下文字典。""" + serialized = _serialize_workflow_value(context) + return serialized if isinstance(serialized, dict) else {} class WorkflowCancelToken: @@ -185,13 +240,16 @@ class WorkflowExecutor: 恢复工作流上下文,兼容旧版 Base64 Pickle 存储格式。 """ context = ActionContext() - if self.workflow.current_action and self.workflow.context: - logger.info(f"工作流已执行动作:{self.workflow.current_action}") + if self.workflow.context: + if self.workflow.current_action: + logger.info(f"工作流已执行动作:{self.workflow.current_action}") try: - decoded_data = base64.b64decode(self.workflow.context["content"]) - context = pickle.loads(decoded_data) - except Exception as err: - logger.error(f"工作流上下文恢复失败: {str(err)}") + if isinstance(self.workflow.context, dict) and self.workflow.context.get("content"): + decoded_data = base64.b64decode(self.workflow.context["content"]) + context = pickle.loads(decoded_data) + elif isinstance(self.workflow.context, dict): + context = ActionContext.model_validate(self.workflow.context) + except Exception: context = ActionContext() outputs = self.restored_execution_state.get("outputs") if isinstance(self.restored_execution_state, dict) else {} if outputs and not getattr(context, "node_outputs", None): @@ -247,38 +305,14 @@ class WorkflowExecutor: 构建可持久化的结构化执行状态。 """ self.update_runtime_state() - return self.make_json_safe({ + execution_state = _serialize_workflow_value({ "version": 1, "nodes": self.node_metadata, "outputs": self.context.node_outputs, "errors": self.errors, "runtime": self.context.runtime_state, }) - - @staticmethod - def make_json_safe(value: Any) -> Any: - """ - 将运行期对象转换为可写入 JSON 列的数据。 - """ - if isinstance(value, dict): - return { - str(key): WorkflowExecutor.make_json_safe(item) - for key, item in value.items() - } - if isinstance(value, (list, tuple, set)): - return [WorkflowExecutor.make_json_safe(item) for item in value] - try: - encoded_value = jsonable_encoder(value) - except Exception: - return str(value) - if isinstance(encoded_value, dict): - return { - str(key): WorkflowExecutor.make_json_safe(item) - for key, item in encoded_value.items() - } - if isinstance(encoded_value, list): - return [WorkflowExecutor.make_json_safe(item) for item in encoded_value] - return encoded_value + return execution_state if isinstance(execution_state, dict) else {} def execute(self) -> None: """ @@ -603,7 +637,8 @@ class WorkflowExecutor: "runtime_state": self.context.runtime_state, } - def ensure_result_context_partitions(self, context: ActionContext) -> None: + @staticmethod + def ensure_result_context_partitions(context: ActionContext) -> None: """ 确保动作返回上下文具备新版分区字段。 """ @@ -705,13 +740,19 @@ class WorkflowExecutor: return outputs_config return self.get_action_contract(action).get("outputs") or {} - def get_default_merge_policy(self, action: Action, key: str, value: Any) -> str: + @staticmethod + def get_default_merge_policy(action: Action, key: str, value: Any) -> str: """ 获取输出默认合并策略。 """ - if action.type in ("FilterTorrentsAction", "FilterMediasAction", "FetchDownloadsAction"): - return "replace" if isinstance(value, list): + action_type = action.type or "" + action_name = action.name or "" + if key in ARTIFACT_FIELDS and ( + action_type.startswith("Filter") + or action_name.startswith("过滤") + ): + return "replace" return "append_unique" if isinstance(value, dict): return "merge_dict" @@ -756,12 +797,11 @@ class WorkflowExecutor: 获取列表元素去重身份。 """ if not identity: - identity_value = self.make_json_safe(item) - return self.make_hashable_identity(identity_value) + return self.make_hashable_identity(item) value = item for part in identity.split("."): value = self.read_value(value, int(part) if part.isdigit() else part) - return self.make_hashable_identity(self.make_json_safe(value)) + return self.make_hashable_identity(item) @staticmethod def make_hashable_identity(value: Any) -> Any: @@ -1102,17 +1142,11 @@ class WorkflowChain(ChainBase): """ 保存上下文到数据库 """ - # 序列化数据 - serialized_data = pickle.dumps(context) - # 使用Base64编码字节流 - encoded_data = base64.b64encode(serialized_data).decode('utf-8') WorkflowOper().step( workflow_id, action_id=action.id if completed else "", - context={ - "content": encoded_data - }, - execution_state=execution_state + context=_serialize_workflow_context(context), + execution_state=_serialize_workflow_value(execution_state) ) # 重置工作流 diff --git a/app/db/workflow_oper.py b/app/db/workflow_oper.py index f4c82942..73e2f87d 100644 --- a/app/db/workflow_oper.py +++ b/app/db/workflow_oper.py @@ -100,7 +100,7 @@ class WorkflowOper(DbOper): wid, action_id, context, - execution_state=execution_state + execution_state ) def reset(self, wid: int, reset_count: bool = False) -> bool: diff --git a/tests/test_workflow_execution.py b/tests/test_workflow_execution.py index 34faf97f..4337b7bb 100644 --- a/tests/test_workflow_execution.py +++ b/tests/test_workflow_execution.py @@ -67,6 +67,57 @@ class _FakeWorkflowManager: return self.contracts.get(action_type) or {} +class _FakeWorkflowOper: + """记录工作流持久化调用。""" + + def __init__(self, workflow): + self.workflow = workflow + self.steps = [] + self.started = False + self.failed_result = None + self.succeeded = False + + def reset(self, wid): + """模拟重置工作流。""" + _ = wid + return True + + def get(self, wid): + """返回预置工作流。""" + _ = wid + return self.workflow + + def start(self, wid): + """记录启动调用。""" + _ = wid + self.started = True + return True + + def step(self, wid, action_id, context, execution_state=None): + """记录步骤持久化数据。""" + self.steps.append( + { + "wid": wid, + "action_id": action_id, + "context": context, + "execution_state": execution_state, + } + ) + return True + + def fail(self, wid, result): + """记录失败结果。""" + _ = wid + self.failed_result = result + return True + + def success(self, wid, result=None): + """记录成功结果。""" + _ = wid, result + self.succeeded = True + return True + + class _OpaqueValue: """模拟无法直接 JSON 序列化的值。""" @@ -97,6 +148,31 @@ def test_workflow_executor_resumes_downstream_nodes(monkeypatch): assert executor.context.progress == 100 +def test_workflow_executor_restores_structured_context(monkeypatch): + """恢复执行时应兼容新版结构化上下文存储格式。""" + calls = [] + fake_manager = _FakeWorkflowManager(calls) + workflow = _build_workflow( + current_action="A", + context={ + "workflow_context": {"trace_id": "wf-1"}, + "node_outputs": {"A": {"items": ["movie"]}}, + "progress": 50, + }, + ) + + monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager) + monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None) + monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False) + + executor = workflow_module.WorkflowExecutor(workflow) + executor.execute() + + assert calls == ["B"] + assert executor.context.workflow_context["trace_id"] == "wf-1" + assert executor.context.node_outputs["A"]["items"] == ["movie"] + + def test_workflow_executor_reports_incremental_progress(monkeypatch): """顺序工作流的中间进度应按已完成比例计算。""" calls = [] @@ -488,6 +564,44 @@ def test_workflow_executor_keeps_execution_state_dict_for_non_json_leaf(monkeypa assert states[-1]["outputs"]["A"]["opaque"] == "opaque-value" +def test_workflow_chain_process_serializes_circular_context(monkeypatch): + """工作流步骤持久化应清洗循环引用和不可序列化上下文。""" + calls = [] + + def run_action(action, context): + """构造包含循环引用的上下文。""" + context.workflow_context["self"] = context.workflow_context + context.workflow_context["opaque"] = _OpaqueValue() + return ActionResult(success=True, message=f"{action.name}完成", context=context) + + fake_manager = _FakeWorkflowManager(calls, results={"A": run_action}) + workflow = _build_workflow( + actions=[{"id": "A", "type": "FakeAction", "name": "动作A", "data": {}}], + flows=[{"id": "flow-end", "source": "A", "target": "END", "animated": True}], + ) + fake_oper = _FakeWorkflowOper(workflow) + + monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager) + monkeypatch.setattr(workflow_module, "WorkflowOper", lambda: fake_oper) + monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None) + monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False) + + success, message = workflow_module.WorkflowChain.process(workflow_id=1) + + assert success is True + assert message == "" + assert fake_oper.succeeded is True + saved_workflow_context = fake_oper.steps[-1]["context"]["workflow_context"] + saved_self = saved_workflow_context["self"] + + assert saved_workflow_context["opaque"] == "opaque-value" + if isinstance(saved_self, dict): + assert saved_self["self"] == workflow_module.CIRCULAR_REFERENCE_PLACEHOLDER + assert saved_self["opaque"] == "opaque-value" + else: + assert saved_self == workflow_module.CIRCULAR_REFERENCE_PLACEHOLDER + + def test_workflow_executor_concurrency_key_serializes_parallel_nodes(monkeypatch): """相同 concurrency_key 的并行节点不应同时运行。""" calls = []