mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-08 07:26:48 +00:00
feat(workflow): enhance workflow context serialization and execution state management
This commit is contained in:
@@ -6,11 +6,11 @@ import pickle
|
||||
import threading
|
||||
from collections import defaultdict, deque
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from datetime import date, datetime
|
||||
from time import sleep
|
||||
from typing import Any, Callable, List, Optional, Tuple
|
||||
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import global_vars
|
||||
@@ -22,9 +22,64 @@ from app.schemas import ActionContext, ActionFlow, Action, ActionExecution, Acti
|
||||
from app.schemas.types import EventType
|
||||
from app.workflow import WorkFlowManager
|
||||
|
||||
|
||||
ARTIFACT_FIELDS = {"torrents", "medias", "fileitems", "downloads", "sites", "subscribes"}
|
||||
DEFAULT_WORKFLOW_MAX_WORKERS = 4
|
||||
CIRCULAR_REFERENCE_PLACEHOLDER = "[Circular]"
|
||||
|
||||
|
||||
def _serialize_workflow_key(key: Any) -> Any:
|
||||
"""将映射键转换为 JSON 安全值。"""
|
||||
if key is None or isinstance(key, (str, int, float, bool)):
|
||||
return key
|
||||
return str(key)
|
||||
|
||||
|
||||
def _serialize_workflow_value(value: Any, stack: Optional[set[int]] = None) -> Any:
|
||||
"""把工作流上下文和值转换为可持久化的 JSON 结构。"""
|
||||
if value is None or isinstance(value, (str, int, float, bool)):
|
||||
return value
|
||||
if isinstance(value, (date, datetime)):
|
||||
return value.isoformat()
|
||||
|
||||
stack = stack or set()
|
||||
object_id = id(value)
|
||||
if object_id in stack:
|
||||
return CIRCULAR_REFERENCE_PLACEHOLDER
|
||||
|
||||
if isinstance(value, BaseModel):
|
||||
stack.add(object_id)
|
||||
try:
|
||||
return {
|
||||
field_name: _serialize_workflow_value(getattr(value, field_name, None), stack)
|
||||
for field_name in value.__class__.model_fields
|
||||
}
|
||||
finally:
|
||||
stack.remove(object_id)
|
||||
|
||||
if isinstance(value, dict):
|
||||
stack.add(object_id)
|
||||
try:
|
||||
return {
|
||||
_serialize_workflow_key(key): _serialize_workflow_value(item, stack)
|
||||
for key, item in value.items()
|
||||
}
|
||||
finally:
|
||||
stack.remove(object_id)
|
||||
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
stack.add(object_id)
|
||||
try:
|
||||
return [_serialize_workflow_value(item, stack) for item in value]
|
||||
finally:
|
||||
stack.remove(object_id)
|
||||
|
||||
return str(value)
|
||||
|
||||
|
||||
def _serialize_workflow_context(context: ActionContext) -> dict:
|
||||
"""构建可写入数据库的工作流上下文字典。"""
|
||||
serialized = _serialize_workflow_value(context)
|
||||
return serialized if isinstance(serialized, dict) else {}
|
||||
|
||||
|
||||
class WorkflowCancelToken:
|
||||
@@ -185,13 +240,16 @@ class WorkflowExecutor:
|
||||
恢复工作流上下文,兼容旧版 Base64 Pickle 存储格式。
|
||||
"""
|
||||
context = ActionContext()
|
||||
if self.workflow.current_action and self.workflow.context:
|
||||
logger.info(f"工作流已执行动作:{self.workflow.current_action}")
|
||||
if self.workflow.context:
|
||||
if self.workflow.current_action:
|
||||
logger.info(f"工作流已执行动作:{self.workflow.current_action}")
|
||||
try:
|
||||
decoded_data = base64.b64decode(self.workflow.context["content"])
|
||||
context = pickle.loads(decoded_data)
|
||||
except Exception as err:
|
||||
logger.error(f"工作流上下文恢复失败: {str(err)}")
|
||||
if isinstance(self.workflow.context, dict) and self.workflow.context.get("content"):
|
||||
decoded_data = base64.b64decode(self.workflow.context["content"])
|
||||
context = pickle.loads(decoded_data)
|
||||
elif isinstance(self.workflow.context, dict):
|
||||
context = ActionContext.model_validate(self.workflow.context)
|
||||
except Exception:
|
||||
context = ActionContext()
|
||||
outputs = self.restored_execution_state.get("outputs") if isinstance(self.restored_execution_state, dict) else {}
|
||||
if outputs and not getattr(context, "node_outputs", None):
|
||||
@@ -247,38 +305,14 @@ class WorkflowExecutor:
|
||||
构建可持久化的结构化执行状态。
|
||||
"""
|
||||
self.update_runtime_state()
|
||||
return self.make_json_safe({
|
||||
execution_state = _serialize_workflow_value({
|
||||
"version": 1,
|
||||
"nodes": self.node_metadata,
|
||||
"outputs": self.context.node_outputs,
|
||||
"errors": self.errors,
|
||||
"runtime": self.context.runtime_state,
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
def make_json_safe(value: Any) -> Any:
|
||||
"""
|
||||
将运行期对象转换为可写入 JSON 列的数据。
|
||||
"""
|
||||
if isinstance(value, dict):
|
||||
return {
|
||||
str(key): WorkflowExecutor.make_json_safe(item)
|
||||
for key, item in value.items()
|
||||
}
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
return [WorkflowExecutor.make_json_safe(item) for item in value]
|
||||
try:
|
||||
encoded_value = jsonable_encoder(value)
|
||||
except Exception:
|
||||
return str(value)
|
||||
if isinstance(encoded_value, dict):
|
||||
return {
|
||||
str(key): WorkflowExecutor.make_json_safe(item)
|
||||
for key, item in encoded_value.items()
|
||||
}
|
||||
if isinstance(encoded_value, list):
|
||||
return [WorkflowExecutor.make_json_safe(item) for item in encoded_value]
|
||||
return encoded_value
|
||||
return execution_state if isinstance(execution_state, dict) else {}
|
||||
|
||||
def execute(self) -> None:
|
||||
"""
|
||||
@@ -603,7 +637,8 @@ class WorkflowExecutor:
|
||||
"runtime_state": self.context.runtime_state,
|
||||
}
|
||||
|
||||
def ensure_result_context_partitions(self, context: ActionContext) -> None:
|
||||
@staticmethod
|
||||
def ensure_result_context_partitions(context: ActionContext) -> None:
|
||||
"""
|
||||
确保动作返回上下文具备新版分区字段。
|
||||
"""
|
||||
@@ -705,13 +740,19 @@ class WorkflowExecutor:
|
||||
return outputs_config
|
||||
return self.get_action_contract(action).get("outputs") or {}
|
||||
|
||||
def get_default_merge_policy(self, action: Action, key: str, value: Any) -> str:
|
||||
@staticmethod
|
||||
def get_default_merge_policy(action: Action, key: str, value: Any) -> str:
|
||||
"""
|
||||
获取输出默认合并策略。
|
||||
"""
|
||||
if action.type in ("FilterTorrentsAction", "FilterMediasAction", "FetchDownloadsAction"):
|
||||
return "replace"
|
||||
if isinstance(value, list):
|
||||
action_type = action.type or ""
|
||||
action_name = action.name or ""
|
||||
if key in ARTIFACT_FIELDS and (
|
||||
action_type.startswith("Filter")
|
||||
or action_name.startswith("过滤")
|
||||
):
|
||||
return "replace"
|
||||
return "append_unique"
|
||||
if isinstance(value, dict):
|
||||
return "merge_dict"
|
||||
@@ -756,12 +797,11 @@ class WorkflowExecutor:
|
||||
获取列表元素去重身份。
|
||||
"""
|
||||
if not identity:
|
||||
identity_value = self.make_json_safe(item)
|
||||
return self.make_hashable_identity(identity_value)
|
||||
return self.make_hashable_identity(item)
|
||||
value = item
|
||||
for part in identity.split("."):
|
||||
value = self.read_value(value, int(part) if part.isdigit() else part)
|
||||
return self.make_hashable_identity(self.make_json_safe(value))
|
||||
return self.make_hashable_identity(item)
|
||||
|
||||
@staticmethod
|
||||
def make_hashable_identity(value: Any) -> Any:
|
||||
@@ -1102,17 +1142,11 @@ class WorkflowChain(ChainBase):
|
||||
"""
|
||||
保存上下文到数据库
|
||||
"""
|
||||
# 序列化数据
|
||||
serialized_data = pickle.dumps(context)
|
||||
# 使用Base64编码字节流
|
||||
encoded_data = base64.b64encode(serialized_data).decode('utf-8')
|
||||
WorkflowOper().step(
|
||||
workflow_id,
|
||||
action_id=action.id if completed else "",
|
||||
context={
|
||||
"content": encoded_data
|
||||
},
|
||||
execution_state=execution_state
|
||||
context=_serialize_workflow_context(context),
|
||||
execution_state=_serialize_workflow_value(execution_state)
|
||||
)
|
||||
|
||||
# 重置工作流
|
||||
|
||||
@@ -100,7 +100,7 @@ class WorkflowOper(DbOper):
|
||||
wid,
|
||||
action_id,
|
||||
context,
|
||||
execution_state=execution_state
|
||||
execution_state
|
||||
)
|
||||
|
||||
def reset(self, wid: int, reset_count: bool = False) -> bool:
|
||||
|
||||
@@ -67,6 +67,57 @@ class _FakeWorkflowManager:
|
||||
return self.contracts.get(action_type) or {}
|
||||
|
||||
|
||||
class _FakeWorkflowOper:
|
||||
"""记录工作流持久化调用。"""
|
||||
|
||||
def __init__(self, workflow):
|
||||
self.workflow = workflow
|
||||
self.steps = []
|
||||
self.started = False
|
||||
self.failed_result = None
|
||||
self.succeeded = False
|
||||
|
||||
def reset(self, wid):
|
||||
"""模拟重置工作流。"""
|
||||
_ = wid
|
||||
return True
|
||||
|
||||
def get(self, wid):
|
||||
"""返回预置工作流。"""
|
||||
_ = wid
|
||||
return self.workflow
|
||||
|
||||
def start(self, wid):
|
||||
"""记录启动调用。"""
|
||||
_ = wid
|
||||
self.started = True
|
||||
return True
|
||||
|
||||
def step(self, wid, action_id, context, execution_state=None):
|
||||
"""记录步骤持久化数据。"""
|
||||
self.steps.append(
|
||||
{
|
||||
"wid": wid,
|
||||
"action_id": action_id,
|
||||
"context": context,
|
||||
"execution_state": execution_state,
|
||||
}
|
||||
)
|
||||
return True
|
||||
|
||||
def fail(self, wid, result):
|
||||
"""记录失败结果。"""
|
||||
_ = wid
|
||||
self.failed_result = result
|
||||
return True
|
||||
|
||||
def success(self, wid, result=None):
|
||||
"""记录成功结果。"""
|
||||
_ = wid, result
|
||||
self.succeeded = True
|
||||
return True
|
||||
|
||||
|
||||
class _OpaqueValue:
|
||||
"""模拟无法直接 JSON 序列化的值。"""
|
||||
|
||||
@@ -97,6 +148,31 @@ def test_workflow_executor_resumes_downstream_nodes(monkeypatch):
|
||||
assert executor.context.progress == 100
|
||||
|
||||
|
||||
def test_workflow_executor_restores_structured_context(monkeypatch):
|
||||
"""恢复执行时应兼容新版结构化上下文存储格式。"""
|
||||
calls = []
|
||||
fake_manager = _FakeWorkflowManager(calls)
|
||||
workflow = _build_workflow(
|
||||
current_action="A",
|
||||
context={
|
||||
"workflow_context": {"trace_id": "wf-1"},
|
||||
"node_outputs": {"A": {"items": ["movie"]}},
|
||||
"progress": 50,
|
||||
},
|
||||
)
|
||||
|
||||
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
|
||||
|
||||
executor = workflow_module.WorkflowExecutor(workflow)
|
||||
executor.execute()
|
||||
|
||||
assert calls == ["B"]
|
||||
assert executor.context.workflow_context["trace_id"] == "wf-1"
|
||||
assert executor.context.node_outputs["A"]["items"] == ["movie"]
|
||||
|
||||
|
||||
def test_workflow_executor_reports_incremental_progress(monkeypatch):
|
||||
"""顺序工作流的中间进度应按已完成比例计算。"""
|
||||
calls = []
|
||||
@@ -488,6 +564,44 @@ def test_workflow_executor_keeps_execution_state_dict_for_non_json_leaf(monkeypa
|
||||
assert states[-1]["outputs"]["A"]["opaque"] == "opaque-value"
|
||||
|
||||
|
||||
def test_workflow_chain_process_serializes_circular_context(monkeypatch):
|
||||
"""工作流步骤持久化应清洗循环引用和不可序列化上下文。"""
|
||||
calls = []
|
||||
|
||||
def run_action(action, context):
|
||||
"""构造包含循环引用的上下文。"""
|
||||
context.workflow_context["self"] = context.workflow_context
|
||||
context.workflow_context["opaque"] = _OpaqueValue()
|
||||
return ActionResult(success=True, message=f"{action.name}完成", context=context)
|
||||
|
||||
fake_manager = _FakeWorkflowManager(calls, results={"A": run_action})
|
||||
workflow = _build_workflow(
|
||||
actions=[{"id": "A", "type": "FakeAction", "name": "动作A", "data": {}}],
|
||||
flows=[{"id": "flow-end", "source": "A", "target": "END", "animated": True}],
|
||||
)
|
||||
fake_oper = _FakeWorkflowOper(workflow)
|
||||
|
||||
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
|
||||
monkeypatch.setattr(workflow_module, "WorkflowOper", lambda: fake_oper)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
|
||||
|
||||
success, message = workflow_module.WorkflowChain.process(workflow_id=1)
|
||||
|
||||
assert success is True
|
||||
assert message == ""
|
||||
assert fake_oper.succeeded is True
|
||||
saved_workflow_context = fake_oper.steps[-1]["context"]["workflow_context"]
|
||||
saved_self = saved_workflow_context["self"]
|
||||
|
||||
assert saved_workflow_context["opaque"] == "opaque-value"
|
||||
if isinstance(saved_self, dict):
|
||||
assert saved_self["self"] == workflow_module.CIRCULAR_REFERENCE_PLACEHOLDER
|
||||
assert saved_self["opaque"] == "opaque-value"
|
||||
else:
|
||||
assert saved_self == workflow_module.CIRCULAR_REFERENCE_PLACEHOLDER
|
||||
|
||||
|
||||
def test_workflow_executor_concurrency_key_serializes_parallel_nodes(monkeypatch):
|
||||
"""相同 concurrency_key 的并行节点不应同时运行。"""
|
||||
calls = []
|
||||
|
||||
Reference in New Issue
Block a user