feat(workflow): enhance workflow context serialization and execution state management

This commit is contained in:
jxxghp
2026-06-05 00:41:02 +08:00
parent 51981d151e
commit fc8933c648
3 changed files with 199 additions and 51 deletions

View File

@@ -6,11 +6,11 @@ import pickle
import threading
from collections import defaultdict, deque
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from datetime import date, datetime
from time import sleep
from typing import Any, Callable, List, Optional, Tuple
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from app.chain import ChainBase
from app.core.config import global_vars
@@ -22,9 +22,64 @@ from app.schemas import ActionContext, ActionFlow, Action, ActionExecution, Acti
from app.schemas.types import EventType
from app.workflow import WorkFlowManager
# Output keys treated as artifact collections. The default merge policy
# replaces (instead of appending to) list outputs under these keys when the
# producing action's type starts with "Filter" or its name starts with "过滤".
ARTIFACT_FIELDS = {"torrents", "medias", "fileitems", "downloads", "sites", "subscribes"}
# Default worker count for workflow execution.
# NOTE(review): presumably the thread-pool size; the use site is not visible here, confirm.
DEFAULT_WORKFLOW_MAX_WORKERS = 4
# Marker returned by _serialize_workflow_value in place of an object that is
# already being serialized further up the same call path (a reference cycle).
CIRCULAR_REFERENCE_PLACEHOLDER = "[Circular]"
def _serialize_workflow_key(key: Any) -> Any:
    """Return a mapping key that is safe to use in a JSON object.

    ``None`` and str/int/float/bool keys are passed through untouched; every
    other key is replaced by its ``str()`` representation.
    """
    passthrough_types = (str, int, float, bool)
    if key is not None and not isinstance(key, passthrough_types):
        return str(key)
    return key
def _serialize_workflow_value(value: Any, stack: Optional[set[int]] = None) -> Any:
    """Convert a workflow context or value into a JSON-persistable structure.

    Conversion rules, applied in order:

    * ``None`` and str/int/float/bool are returned unchanged.
    * ``date`` / ``datetime`` become ISO-8601 strings.
    * ``Enum`` members are replaced by their serialized ``.value`` so the
      stored data can be validated back into the original enum field.
    * Pydantic models become a dict of their declared fields.
    * dicts keep their shape; keys go through ``_serialize_workflow_key``.
    * list / tuple / set / frozenset become lists.
    * Anything else falls back to ``str(value)``.

    A container reached again while it is still being serialized is replaced
    by ``CIRCULAR_REFERENCE_PLACEHOLDER`` instead of recursing forever.

    :param value: arbitrary runtime value to serialize.
    :param stack: ids of the containers currently being serialized on this
        call path; callers normally omit it.
    :return: a JSON-compatible structure.
    """
    # Imported locally so this function does not depend on the module's
    # top-level import block.
    from enum import Enum

    if value is None or isinstance(value, (str, int, float, bool)):
        return value
    if isinstance(value, (date, datetime)):
        return value.isoformat()
    if isinstance(value, Enum):
        # str(member) would yield "Class.MEMBER", which cannot be restored.
        return _serialize_workflow_value(value.value, stack)

    # "is None" rather than "or": an explicitly supplied empty set must be
    # reused, not silently replaced.
    if stack is None:
        stack = set()
    object_id = id(value)
    if object_id in stack:
        return CIRCULAR_REFERENCE_PLACEHOLDER

    if isinstance(value, BaseModel):
        stack.add(object_id)
        try:
            return {
                field_name: _serialize_workflow_value(getattr(value, field_name, None), stack)
                for field_name in value.__class__.model_fields
            }
        finally:
            # Pop on the way out so siblings sharing this object are not
            # mistaken for cycles.
            stack.remove(object_id)

    if isinstance(value, dict):
        stack.add(object_id)
        try:
            return {
                _serialize_workflow_key(key): _serialize_workflow_value(item, stack)
                for key, item in value.items()
            }
        finally:
            stack.remove(object_id)

    if isinstance(value, (list, tuple, set, frozenset)):
        stack.add(object_id)
        try:
            return [_serialize_workflow_value(item, stack) for item in value]
        finally:
            stack.remove(object_id)

    return str(value)
def _serialize_workflow_context(context: ActionContext) -> dict:
    """Build the workflow-context dictionary that is written to the database.

    The context is run through ``_serialize_workflow_value``; if the result is
    not a dict, an empty dict is returned instead.
    """
    payload = _serialize_workflow_value(context)
    if not isinstance(payload, dict):
        return {}
    return payload
class WorkflowCancelToken:
@@ -185,13 +240,16 @@ class WorkflowExecutor:
恢复工作流上下文,兼容旧版 Base64 Pickle 存储格式。
"""
context = ActionContext()
if self.workflow.current_action and self.workflow.context:
logger.info(f"工作流已执行动作:{self.workflow.current_action}")
if self.workflow.context:
if self.workflow.current_action:
logger.info(f"工作流已执行动作:{self.workflow.current_action}")
try:
decoded_data = base64.b64decode(self.workflow.context["content"])
context = pickle.loads(decoded_data)
except Exception as err:
logger.error(f"工作流上下文恢复失败: {str(err)}")
if isinstance(self.workflow.context, dict) and self.workflow.context.get("content"):
decoded_data = base64.b64decode(self.workflow.context["content"])
context = pickle.loads(decoded_data)
elif isinstance(self.workflow.context, dict):
context = ActionContext.model_validate(self.workflow.context)
except Exception:
context = ActionContext()
outputs = self.restored_execution_state.get("outputs") if isinstance(self.restored_execution_state, dict) else {}
if outputs and not getattr(context, "node_outputs", None):
@@ -247,38 +305,14 @@ class WorkflowExecutor:
构建可持久化的结构化执行状态。
"""
self.update_runtime_state()
return self.make_json_safe({
execution_state = _serialize_workflow_value({
"version": 1,
"nodes": self.node_metadata,
"outputs": self.context.node_outputs,
"errors": self.errors,
"runtime": self.context.runtime_state,
})
@staticmethod
def make_json_safe(value: Any) -> Any:
    """
    Convert a runtime object into data that can be written to a JSON column.

    Dicts get string keys and recursively converted values; lists, tuples and
    sets become lists. Any other value is passed through ``jsonable_encoder``
    and the encoded result is converted again when it is a dict or list. If
    encoding raises, the value's ``str()`` form is returned.
    """
    sanitize = WorkflowExecutor.make_json_safe

    def sanitize_mapping(mapping: dict) -> dict:
        # Keys are always stringified; values are converted recursively.
        return {str(name): sanitize(entry) for name, entry in mapping.items()}

    if isinstance(value, dict):
        return sanitize_mapping(value)
    if isinstance(value, (list, tuple, set)):
        return [sanitize(element) for element in value]

    try:
        encoded = jsonable_encoder(value)
    except Exception:
        return str(value)

    if isinstance(encoded, dict):
        return sanitize_mapping(encoded)
    if isinstance(encoded, list):
        return [sanitize(element) for element in encoded]
    return encoded
return execution_state if isinstance(execution_state, dict) else {}
def execute(self) -> None:
"""
@@ -603,7 +637,8 @@ class WorkflowExecutor:
"runtime_state": self.context.runtime_state,
}
def ensure_result_context_partitions(self, context: ActionContext) -> None:
@staticmethod
def ensure_result_context_partitions(context: ActionContext) -> None:
"""
确保动作返回上下文具备新版分区字段。
"""
@@ -705,13 +740,19 @@ class WorkflowExecutor:
return outputs_config
return self.get_action_contract(action).get("outputs") or {}
def get_default_merge_policy(self, action: Action, key: str, value: Any) -> str:
@staticmethod
def get_default_merge_policy(action: Action, key: str, value: Any) -> str:
"""
获取输出默认合并策略。
"""
if action.type in ("FilterTorrentsAction", "FilterMediasAction", "FetchDownloadsAction"):
return "replace"
if isinstance(value, list):
action_type = action.type or ""
action_name = action.name or ""
if key in ARTIFACT_FIELDS and (
action_type.startswith("Filter")
or action_name.startswith("过滤")
):
return "replace"
return "append_unique"
if isinstance(value, dict):
return "merge_dict"
@@ -756,12 +797,11 @@ class WorkflowExecutor:
获取列表元素去重身份。
"""
if not identity:
identity_value = self.make_json_safe(item)
return self.make_hashable_identity(identity_value)
return self.make_hashable_identity(item)
value = item
for part in identity.split("."):
value = self.read_value(value, int(part) if part.isdigit() else part)
return self.make_hashable_identity(self.make_json_safe(value))
return self.make_hashable_identity(item)
@staticmethod
def make_hashable_identity(value: Any) -> Any:
@@ -1102,17 +1142,11 @@ class WorkflowChain(ChainBase):
"""
保存上下文到数据库
"""
# 序列化数据
serialized_data = pickle.dumps(context)
# 使用Base64编码字节流
encoded_data = base64.b64encode(serialized_data).decode('utf-8')
WorkflowOper().step(
workflow_id,
action_id=action.id if completed else "",
context={
"content": encoded_data
},
execution_state=execution_state
context=_serialize_workflow_context(context),
execution_state=_serialize_workflow_value(execution_state)
)
# 重置工作流

View File

@@ -100,7 +100,7 @@ class WorkflowOper(DbOper):
wid,
action_id,
context,
execution_state=execution_state
execution_state
)
def reset(self, wid: int, reset_count: bool = False) -> bool:

View File

@@ -67,6 +67,57 @@ class _FakeWorkflowManager:
return self.contracts.get(action_type) or {}
class _FakeWorkflowOper:
    """Test double that records workflow persistence calls."""

    def __init__(self, workflow):
        self.workflow = workflow
        self.steps = []
        self.started = False
        self.failed_result = None
        self.succeeded = False

    def reset(self, wid):
        """Pretend to reset the workflow; always reports success."""
        del wid  # unused
        return True

    def get(self, wid):
        """Return the workflow supplied at construction time."""
        del wid  # unused
        return self.workflow

    def start(self, wid):
        """Remember that the workflow was started."""
        del wid  # unused
        self.started = True
        return True

    def step(self, wid, action_id, context, execution_state=None):
        """Append the persisted step payload to ``steps``."""
        record = dict(
            wid=wid,
            action_id=action_id,
            context=context,
            execution_state=execution_state,
        )
        self.steps.append(record)
        return True

    def fail(self, wid, result):
        """Remember the failure result."""
        del wid  # unused
        self.failed_result = result
        return True

    def success(self, wid, result=None):
        """Remember that the workflow succeeded."""
        del wid, result  # unused
        self.succeeded = True
        return True
class _OpaqueValue:
"""模拟无法直接 JSON 序列化的值。"""
@@ -97,6 +148,31 @@ def test_workflow_executor_resumes_downstream_nodes(monkeypatch):
assert executor.context.progress == 100
def test_workflow_executor_restores_structured_context(monkeypatch):
    """Resuming a workflow must accept the new structured (plain dict) context format."""
    executed = []
    manager = _FakeWorkflowManager(executed)
    stored_context = {
        "workflow_context": {"trace_id": "wf-1"},
        "node_outputs": {"A": {"items": ["movie"]}},
        "progress": 50,
    }
    workflow = _build_workflow(current_action="A", context=stored_context)

    # Replace the action manager and the global stop/resume hooks with stubs.
    monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: manager)
    monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
    monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)

    executor = workflow_module.WorkflowExecutor(workflow)
    executor.execute()

    # Only the node after the already-completed "A" is expected to run.
    assert executed == ["B"]
    assert executor.context.workflow_context["trace_id"] == "wf-1"
    assert executor.context.node_outputs["A"]["items"] == ["movie"]
def test_workflow_executor_reports_incremental_progress(monkeypatch):
"""顺序工作流的中间进度应按已完成比例计算。"""
calls = []
@@ -488,6 +564,44 @@ def test_workflow_executor_keeps_execution_state_dict_for_non_json_leaf(monkeypa
assert states[-1]["outputs"]["A"]["opaque"] == "opaque-value"
def test_workflow_chain_process_serializes_circular_context(monkeypatch):
    """Step persistence must sanitize circular references and non-serializable context values."""
    executed = []

    def run_action(action, context):
        """Return a successful result whose context references itself."""
        shared = context.workflow_context
        shared["self"] = shared
        shared["opaque"] = _OpaqueValue()
        return ActionResult(success=True, message=f"{action.name}完成", context=context)

    manager = _FakeWorkflowManager(executed, results={"A": run_action})
    single_action = {"id": "A", "type": "FakeAction", "name": "动作A", "data": {}}
    end_flow = {"id": "flow-end", "source": "A", "target": "END", "animated": True}
    workflow = _build_workflow(actions=[single_action], flows=[end_flow])
    oper = _FakeWorkflowOper(workflow)

    # Route manager, persistence and global stop/resume hooks to the fakes.
    monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: manager)
    monkeypatch.setattr(workflow_module, "WorkflowOper", lambda: oper)
    monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
    monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)

    success, message = workflow_module.WorkflowChain.process(workflow_id=1)

    assert success is True
    assert message == ""
    assert oper.succeeded is True

    persisted = oper.steps[-1]["context"]["workflow_context"]
    nested = persisted["self"]
    placeholder = workflow_module.CIRCULAR_REFERENCE_PLACEHOLDER
    assert persisted["opaque"] == "opaque-value"
    # The cycle may be cut at the first or the second level; accept both.
    if not isinstance(nested, dict):
        assert nested == placeholder
    else:
        assert nested["self"] == placeholder
        assert nested["opaque"] == "opaque-value"
def test_workflow_executor_concurrency_key_serializes_parallel_nodes(monkeypatch):
"""相同 concurrency_key 的并行节点不应同时运行。"""
calls = []