feat(workflow): add execution configuration and structured execution state to workflow

This commit is contained in:
jxxghp
2026-06-04 15:57:34 +08:00
parent 7474ecd02f
commit a2984530f8
8 changed files with 1000 additions and 107 deletions

View File

@@ -1,13 +1,17 @@
import ast
import base64
import copy
import inspect
import pickle
import threading
from collections import defaultdict, deque
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from time import sleep
from typing import Any, Callable, List, Optional, Tuple
from fastapi.encoders import jsonable_encoder
from app.chain import ChainBase
from app.core.config import global_vars
from app.core.event import Event, eventmanager
@@ -19,6 +23,29 @@ from app.schemas.types import EventType
from app.workflow import WorkFlowManager
ARTIFACT_FIELDS = {"torrents", "medias", "fileitems", "downloads", "sites", "subscribes"}
DEFAULT_WORKFLOW_MAX_WORKERS = 4
class WorkflowCancelToken:
"""
工作流取消令牌。
"""
def __init__(self, workflow_id: int):
"""
初始化取消令牌。
:param workflow_id: 工作流ID
"""
self.workflow_id = workflow_id
def is_cancelled(self) -> bool:
"""
判断工作流是否已被取消。
"""
return global_vars.is_workflow_stopped(self.workflow_id)
class WorkflowExecutor:
"""
工作流执行器
@@ -35,32 +62,40 @@ class WorkflowExecutor:
self.step_callback = step_callback
self.actions = {action['id']: Action(**action) for action in workflow.actions}
self.flows = [ActionFlow(**flow) for flow in workflow.flows]
self.execution_config = getattr(workflow, "execution_config", None) or {}
self.restored_execution_state = getattr(workflow, "execution_state", None) or {}
self.total_actions = len(self.actions)
self.completed_actions = {
action_id for action_id in (workflow.current_action or "").split(",")
if action_id in self.actions
}
self.finished_actions = len(self.completed_actions)
self.success = True
self.has_failure = False
self.stopped = False
self.errmsg = ""
self.node_states = {action_id: "pending" for action_id in self.actions}
for action_id in self.completed_actions:
self.node_states[action_id] = "completed"
self.errors = self.get_restored_errors()
self.node_metadata = self.get_restored_node_metadata()
self.node_attempts = self.get_restored_attempts()
self.node_states = self.get_restored_node_states()
self.completed_actions = {
action_id for action_id, state in self.node_states.items()
if state == "success"
}
self.finished_actions = len([
state for state in self.node_states.values()
if state in ("success", "failed", "skipped")
])
self.flow_finished = set()
self.flow_satisfied = set()
self.flow_failed = set()
# 工作流管理器
self.workflowmanager = WorkFlowManager()
# 线程安全队列
self.queue = deque()
self.queued_actions = set()
self.active_concurrency_keys = set()
# 锁用于保证线程安全
self.lock = threading.Lock()
# 线程池
self.executor = ThreadPoolExecutor()
self.executor = ThreadPoolExecutor(max_workers=self.get_workflow_max_workers())
self.cancel_token = WorkflowCancelToken(self.workflow.id)
# 跟踪运行中的任务数
self.running_tasks = 0
@@ -74,26 +109,162 @@ class WorkflowExecutor:
self.incoming_flows[flow.target].append(flow)
# 初始上下文
if workflow.current_action and workflow.context:
logger.info(f"工作流已执行动作:{workflow.current_action}")
# Base64解码
decoded_data = base64.b64decode(workflow.context["content"])
# 反序列化数据
self.context = pickle.loads(decoded_data)
else:
self.context = ActionContext()
self.context.node_outputs = self.context.node_outputs or {}
self.context = self.restore_context()
self.ensure_context_partitions()
# 恢复工作流
global_vars.workflow_resume(self.workflow.id)
# 恢复时重新释放已完成节点的出边,使后继节点能继续执行。
for action_id in self.completed_actions:
self.release_successors(action_id, source_success=True)
# 恢复时重新释放已终态节点的出边,使后继节点能继续执行或保持跳过传播
for action_id, state in self.node_states.items():
if state == "success":
self.release_successors(action_id, source_success=True)
elif state in ("failed", "skipped"):
self.release_successors(action_id, source_success=False)
# 初始化队列,添加没有入边的起始节点。
for action_id in self.actions:
if action_id not in self.completed_actions and not self.incoming_flows.get(action_id):
if self.node_states.get(action_id) == "pending" and not self.incoming_flows.get(action_id):
self.enqueue_node(action_id)
def get_workflow_max_workers(self) -> int:
"""
获取工作流最大并发数。
"""
max_workers = self.execution_config.get("max_workers") if isinstance(self.execution_config, dict) else None
try:
return max(int(max_workers or DEFAULT_WORKFLOW_MAX_WORKERS), 1)
except (TypeError, ValueError):
return DEFAULT_WORKFLOW_MAX_WORKERS
def get_restored_node_metadata(self) -> dict:
"""
获取已持久化的节点状态元数据。
"""
nodes = self.restored_execution_state.get("nodes") if isinstance(self.restored_execution_state, dict) else {}
return nodes if isinstance(nodes, dict) else {}
def get_restored_errors(self) -> dict:
"""
获取已持久化的错误状态。
"""
errors = self.restored_execution_state.get("errors") if isinstance(self.restored_execution_state, dict) else {}
return errors if isinstance(errors, dict) else {}
def get_restored_attempts(self) -> dict:
"""
获取已持久化的节点尝试次数。
"""
attempts = {}
for action_id, metadata in self.get_restored_node_metadata().items():
if isinstance(metadata, dict) and metadata.get("attempt"):
attempts[action_id] = int(metadata.get("attempt") or 0)
return attempts
def get_restored_node_states(self) -> dict:
"""
获取结构化节点状态,兼容旧版 current_action 字符串。
"""
legacy_actions = {
action_id for action_id in (self.workflow.current_action or "").split(",")
if action_id in self.actions
}
states = {}
for action_id in self.actions:
metadata = self.node_metadata.get(action_id) or {}
state = metadata.get("state") if isinstance(metadata, dict) else None
if state == "completed":
state = "success"
if state in ("running", "queued"):
state = "pending"
if not state and action_id in legacy_actions:
state = "success"
states[action_id] = state or "pending"
return states
def restore_context(self) -> ActionContext:
"""
恢复工作流上下文,兼容旧版 Base64 Pickle 存储格式。
"""
context = ActionContext()
if self.workflow.current_action and self.workflow.context:
logger.info(f"工作流已执行动作:{self.workflow.current_action}")
try:
decoded_data = base64.b64decode(self.workflow.context["content"])
context = pickle.loads(decoded_data)
except Exception as err:
logger.error(f"工作流上下文恢复失败: {str(err)}")
context = ActionContext()
outputs = self.restored_execution_state.get("outputs") if isinstance(self.restored_execution_state, dict) else {}
if outputs and not getattr(context, "node_outputs", None):
context.node_outputs = outputs
return context
def ensure_context_partitions(self) -> None:
"""
确保上下文具备新版分区结构,并把旧字段映射到 artifacts。
"""
self.context.workflow_context = self.context.workflow_context or {}
self.context.node_outputs = self.context.node_outputs or {}
self.context.runtime_state = self.context.runtime_state or {}
self.context.artifacts = self.context.artifacts or {}
for key in ARTIFACT_FIELDS:
value = getattr(self.context, key, None)
if value not in (None, "", [], {}) and key not in self.context.artifacts:
self.context.artifacts[key] = value
self.update_runtime_state()
def update_runtime_state(self) -> None:
"""
更新上下文中的运行期状态分区。
"""
self.context.runtime_state.update({
"progress": self.context.progress,
"finished_actions": self.finished_actions,
"running_tasks": self.running_tasks,
"errors": self.errors,
"node_states": self.node_states,
"attempts": self.node_attempts,
})
def set_node_state(self, action_id: str, state: str, message: Optional[str] = None) -> None:
"""
更新节点结构化状态。
"""
now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
metadata = self.node_metadata.setdefault(action_id, {})
metadata["state"] = state
metadata["attempt"] = self.node_attempts.get(action_id, metadata.get("attempt") or 0)
if state == "running":
metadata["started_at"] = now
if state in ("success", "failed", "skipped"):
metadata["finished_at"] = now
if message is not None:
metadata["message"] = message
self.node_states[action_id] = state
self.update_runtime_state()
def build_execution_state(self) -> dict:
"""
构建可持久化的结构化执行状态。
"""
self.update_runtime_state()
return self.make_json_safe({
"version": 1,
"nodes": self.node_metadata,
"outputs": self.context.node_outputs,
"errors": self.errors,
"runtime": self.context.runtime_state,
})
@staticmethod
def make_json_safe(value: Any) -> Any:
"""
将运行期对象转换为可写入 JSON 列的数据。
"""
try:
return jsonable_encoder(value)
except Exception:
return str(value)
def execute(self) -> None:
"""
执行工作流
@@ -121,14 +292,9 @@ class WorkflowExecutor:
elif not self.queue:
should_sleep = True
else:
# 取出队首节点
node_id = self.queue.popleft()
self.queued_actions.discard(node_id)
if self.node_states.get(node_id) != "queued":
continue
self.node_states[node_id] = "running"
# 标记任务开始
self.running_tasks += 1
node_id = self.pop_dispatchable_node()
if not node_id:
should_sleep = True
if should_sleep:
sleep(0.1)
@@ -148,19 +314,48 @@ class WorkflowExecutor:
finally:
self.executor.shutdown(wait=True, cancel_futures=True)
def pop_dispatchable_node(self) -> Optional[str]:
"""
从队列中取出当前可调度节点。
"""
for _ in range(len(self.queue)):
node_id = self.queue.popleft()
self.queued_actions.discard(node_id)
if self.node_states.get(node_id) != "queued":
continue
concurrency_key = self.get_action_concurrency_key(self.actions[node_id])
if concurrency_key and concurrency_key in self.active_concurrency_keys:
self.queue.append(node_id)
self.queued_actions.add(node_id)
continue
if concurrency_key:
self.active_concurrency_keys.add(concurrency_key)
self.running_tasks += 1
self.set_node_state(node_id, "running")
return node_id
return None
def execute_node(self, workflow_id: int, node_id: str,
context: ActionContext) -> Tuple[Action, ActionResult]:
"""
执行单个节点操作返回修改后的上下文和节点ID
"""
action = self.actions[node_id]
action_result = self.workflowmanager.execute(workflow_id, action, context=context)
action_result = self.workflowmanager.execute(
workflow_id,
action,
context=context,
inputs=self.build_action_inputs(action),
runtime=self.build_action_runtime(action),
cancel_token=self.cancel_token
)
return action, action_result
def on_node_complete(self, future):
"""
节点完成回调:更新上下文、处理后继节点
"""
action = None
try:
action, action_result = future.result()
with self.lock:
@@ -172,6 +367,7 @@ class WorkflowExecutor:
state = bool(action_result.success)
message = action_result.message or ""
result_ctx = action_result.context or ActionContext()
self.node_attempts[action.id] = action_result.attempts or self.node_attempts.get(action.id, 1)
self.finished_actions += 1
self.update_progress()
@@ -186,31 +382,36 @@ class WorkflowExecutor:
# 节点执行失败时默认停止;显式配置 continue/ignore 时继续释放后续 all_done 汇合。
if not state:
self.node_states[action.id] = "failed"
self.errors[action.id] = message or f"{action.name} 失败"
self.set_node_state(action.id, "failed", message=message)
fail_policy = self.get_action_fail_policy(action)
if fail_policy != "ignore":
self.has_failure = True
self.errmsg = f"{action.name} 失败"
if fail_policy == "stop":
self.success = False
self.call_step_callback(action, completed=False)
return
if fail_policy not in ("continue", "ignore"):
self.success = False
self.errmsg = f"{action.name} 失败:无效失败策略 {fail_policy}"
self.call_step_callback(action, completed=False)
return
self.release_successors(action.id, source_success=False)
self.call_step_callback(action, completed=False)
return
# 更新主上下文
self.merge_context(result_ctx)
self.record_node_outputs(action.id, action_result, result_ctx)
self.ensure_result_context_partitions(result_ctx)
outputs = self.normalize_action_outputs(action, action_result, result_ctx)
self.merge_context_partitions(result_ctx)
self.merge_action_outputs(action, outputs)
self.record_node_outputs(action.id, outputs)
self.completed_actions.add(action.id)
self.node_states[action.id] = "completed"
self.set_node_state(action.id, "success", message=message)
# 处理后继节点
self.release_successors(action.id, source_success=True)
# 回调
if self.step_callback:
self.step_callback(action, self.context)
self.call_step_callback(action, completed=True)
except Exception as err:
logger.error(f"工作流节点执行回调失败: {str(err)}")
with self.lock:
@@ -219,7 +420,12 @@ class WorkflowExecutor:
finally:
# 标记任务完成
with self.lock:
if action:
concurrency_key = self.get_action_concurrency_key(action)
if concurrency_key:
self.active_concurrency_keys.discard(concurrency_key)
self.running_tasks -= 1
self.update_runtime_state()
def enqueue_node(self, node_id: str) -> None:
"""
@@ -231,7 +437,7 @@ class WorkflowExecutor:
return
self.queue.append(node_id)
self.queued_actions.add(node_id)
self.node_states[node_id] = "queued"
self.set_node_state(node_id, "queued")
def skip_node(self, node_id: str, message: str) -> None:
"""
@@ -242,9 +448,9 @@ class WorkflowExecutor:
if self.node_states.get(node_id) not in ("pending", "queued"):
return
self.queued_actions.discard(node_id)
self.node_states[node_id] = "skipped"
self.finished_actions += 1
self.update_progress()
self.set_node_state(node_id, "skipped", message=message)
self.context.execute_history.append(
ActionExecution(
action=self.actions[node_id].name,
@@ -252,13 +458,17 @@ class WorkflowExecutor:
message=message
)
)
self.call_step_callback(self.actions[node_id], completed=False)
self.release_successors(node_id, source_success=False)
def release_successors(self, source_id: str, source_success: bool) -> None:
"""
根据源节点状态释放出边,并重新判断目标节点是否可运行。
"""
for flow in self.outgoing_flows.get(source_id, []):
flows = self.outgoing_flows.get(source_id, [])
branch_policy = self.get_action_branch_policy(self.actions.get(source_id), flows)
matched_exclusive_flow = None
for flow in flows:
flow_key = self.get_flow_key(flow)
if flow_key in self.flow_finished:
continue
@@ -270,9 +480,15 @@ class WorkflowExecutor:
self.success = False
self.errmsg = f"流程条件判断失败:{err}"
return
if branch_policy == "exclusive" and condition_matched and matched_exclusive_flow:
condition_matched = False
elif branch_policy == "exclusive" and condition_matched:
matched_exclusive_flow = flow_key
self.flow_finished.add(flow_key)
if source_success and condition_matched:
self.flow_satisfied.add(flow_key)
if not source_success and self.node_states.get(source_id) == "failed":
self.flow_failed.add(flow_key)
self.evaluate_target_state(flow.target)
def evaluate_target_state(self, target_id: str) -> None:
@@ -291,8 +507,18 @@ class WorkflowExecutor:
total_count = len(incoming_flows)
finished_count = sum(1 for flow in incoming_flows if self.get_flow_key(flow) in self.flow_finished)
satisfied_count = sum(1 for flow in incoming_flows if self.get_flow_key(flow) in self.flow_satisfied)
failed_count = sum(1 for flow in incoming_flows if self.get_flow_key(flow) in self.flow_failed)
join_policy = self.get_action_join_policy(self.actions[target_id], incoming_flows)
if join_policy == "fail_fast":
if failed_count > 0:
self.skip_node(target_id, "上游失败触发 fail_fast已取消后续节点")
elif finished_count == total_count and satisfied_count == total_count:
self.enqueue_node(target_id)
elif finished_count == total_count:
self.skip_node(target_id, "上游条件未全部满足,已跳过")
return
if join_policy == "any_success":
if satisfied_count > 0:
self.enqueue_node(target_id)
@@ -322,14 +548,192 @@ class WorkflowExecutor:
根据已完成和已跳过节点数量更新整体进度。
"""
self.context.progress = round(self.finished_actions / self.total_actions * 100) if self.total_actions else 100
self.update_runtime_state()
def record_node_outputs(self, action_id: str, action_result: ActionResult, result_context: ActionContext) -> None:
def build_action_inputs(self, action: Action) -> dict:
"""
根据动作输入声明读取上游节点输出。
"""
inputs = {}
input_paths = action.inputs or self.get_action_data_value(action, "inputs") or []
if isinstance(input_paths, str):
input_paths = [item.strip() for item in input_paths.splitlines() if item.strip()]
for input_path in input_paths:
inputs[input_path] = self.resolve_context_path(input_path)
return inputs
def build_action_runtime(self, action: Action) -> dict:
"""
构建传递给动作的新运行期数据。
"""
return {
"workflow_id": self.workflow.id,
"action_id": action.id,
"execution_config": self.execution_config,
"runtime_state": self.context.runtime_state,
}
def ensure_result_context_partitions(self, context: ActionContext) -> None:
"""
确保动作返回上下文具备新版分区字段。
"""
context.workflow_context = context.workflow_context or {}
context.node_outputs = context.node_outputs or {}
context.runtime_state = context.runtime_state or {}
context.artifacts = context.artifacts or {}
def normalize_action_outputs(self, action: Action, action_result: ActionResult,
result_context: ActionContext) -> dict:
"""
根据动作输出声明整理当前节点输出。
"""
outputs = action_result.outputs or self.extract_context_outputs(result_context)
declared_outputs = action.outputs or self.get_action_data_value(action, "outputs")
if isinstance(declared_outputs, list):
return {key: outputs.get(key) for key in declared_outputs if outputs.get(key) not in (None, "", [], {})}
if isinstance(declared_outputs, dict):
return {
key: outputs.get(key)
for key in declared_outputs
if outputs.get(key) not in (None, "", [], {})
} or outputs
return outputs
def record_node_outputs(self, action_id: str, outputs: dict) -> None:
"""
记录当前节点输出,供后续条件表达式读取。
"""
outputs = action_result.outputs or self.extract_context_outputs(result_context)
if outputs:
self.context.node_outputs[action_id] = outputs
self.context.runtime_state["last_outputs"] = outputs
def merge_context_partitions(self, context: ActionContext) -> None:
"""
合并动作返回的新分区上下文。
"""
for key in ("workflow_context", "runtime_state", "artifacts"):
value = getattr(context, key, None)
if not value:
continue
current_value = getattr(self.context, key, None) or {}
current_value.update(value)
setattr(self.context, key, current_value)
def merge_action_outputs(self, action: Action, outputs: dict) -> None:
"""
按声明式合并策略写入全局上下文和 artifacts 分区。
"""
for key, value in outputs.items():
if value in (None, "", [], {}):
continue
output_config = self.get_action_output_config(action, key)
target_key = output_config.get("target") or key
merge_policy = output_config.get("merge") or self.get_default_merge_policy(action, target_key, value)
identity = output_config.get("identity")
self.merge_output_value(target_key, value, merge_policy, identity)
def merge_output_value(self, key: str, value: Any, merge_policy: str, identity: Optional[str] = None) -> None:
"""
按指定策略合并单个输出值。
"""
current_value = getattr(self.context, key, None) if key in ActionContext.model_fields else None
merged_value = self.apply_merge_policy(current_value, value, merge_policy, identity)
if key in ActionContext.model_fields:
setattr(self.context, key, merged_value)
if key in ARTIFACT_FIELDS:
current_artifact = self.context.artifacts.get(key)
self.context.artifacts[key] = self.apply_merge_policy(current_artifact, value, merge_policy, identity)
def get_action_output_config(self, action: Action, output_key: str) -> dict:
"""
获取动作输出声明配置。
"""
outputs_config = action.outputs or self.get_action_data_value(action, "outputs") or {}
if isinstance(outputs_config, dict):
value = outputs_config.get(output_key) or {}
return value if isinstance(value, dict) else {}
return {}
def get_default_merge_policy(self, action: Action, key: str, value: Any) -> str:
"""
获取输出默认合并策略。
"""
if action.type in ("FilterTorrentsAction", "FilterMediasAction", "FetchDownloadsAction"):
return "replace"
if isinstance(value, list):
return "append_unique"
if isinstance(value, dict):
return "merge_dict"
return "first_non_empty"
def apply_merge_policy(self, current_value: Any, value: Any, merge_policy: str,
identity: Optional[str] = None) -> Any:
"""
应用声明式合并策略。
"""
if merge_policy == "replace":
return value
if merge_policy == "merge_dict":
merged = current_value.copy() if isinstance(current_value, dict) else {}
if isinstance(value, dict):
merged.update(value)
return merged
return current_value or value
if merge_policy == "append_unique":
return self.append_unique_values(current_value, value, identity)
if merge_policy == "first_non_empty":
return current_value or value
return current_value or value
def append_unique_values(self, current_value: Any, value: Any, identity: Optional[str] = None) -> list:
"""
追加列表并按身份字段去重。
"""
current_list = list(current_value or [])
incoming_list = value if isinstance(value, list) else [value]
seen = {self.get_identity_value(item, identity) for item in current_list}
for item in incoming_list:
identity_value = self.get_identity_value(item, identity)
if identity_value in seen:
continue
current_list.append(item)
seen.add(identity_value)
return current_list
def get_identity_value(self, item: Any, identity: Optional[str] = None) -> Any:
"""
获取列表元素去重身份。
"""
if not identity:
identity_value = self.make_json_safe(item)
return self.make_hashable_identity(identity_value)
value = item
for part in identity.split("."):
value = self.read_value(value, int(part) if part.isdigit() else part)
return self.make_hashable_identity(self.make_json_safe(value))
@staticmethod
def make_hashable_identity(value: Any) -> Any:
"""
将身份值转换为可哈希对象。
"""
try:
hash(value)
return value
except TypeError:
return repr(value)
def call_step_callback(self, action: Action, completed: bool) -> None:
"""
持久化当前步骤上下文和结构化执行状态。
"""
if not self.step_callback:
return
callback_params = inspect.signature(self.step_callback).parameters
if len(callback_params) <= 2:
self.step_callback(action, self.context)
return
self.step_callback(action, self.context, self.build_execution_state(), completed)
@staticmethod
def extract_context_outputs(context: ActionContext) -> dict:
@@ -340,7 +744,7 @@ class WorkflowExecutor:
return {}
outputs = {}
for key in context.__class__.model_fields:
if key in ("execute_history", "progress", "node_outputs"):
if key in ("execute_history", "progress", "node_outputs", "runtime_state"):
continue
value = getattr(context, key, None)
if value in (None, "", [], {}):
@@ -368,12 +772,32 @@ class WorkflowExecutor:
return join_policy
return "all_success"
def get_action_branch_policy(self, action: Optional[Action], outgoing_flows: List[ActionFlow]) -> str:
"""
获取动作出边分支策略。
"""
if action:
branch_policy = action.branch_policy or self.get_action_data_value(action, "branch_policy")
if branch_policy:
return branch_policy
for flow in outgoing_flows:
branch_policy = flow.branch_policy or self.get_flow_data_value(flow, "branch_policy")
if branch_policy:
return branch_policy
return "parallel"
def get_action_fail_policy(self, action: Action) -> str:
"""
获取动作失败策略。
"""
return action.fail_policy or self.get_action_data_value(action, "fail_policy") or "stop"
def get_action_concurrency_key(self, action: Action) -> Optional[str]:
"""
获取动作并发互斥键。
"""
return action.concurrency_key or self.get_action_data_value(action, "concurrency_key")
def get_flow_condition(self, flow: ActionFlow) -> Optional[str]:
"""
获取流程边条件表达式。
@@ -381,10 +805,12 @@ class WorkflowExecutor:
return flow.condition or self.get_flow_data_value(flow, "condition")
@staticmethod
def get_action_data_value(action: Action, key: str) -> Any:
def get_action_data_value(action: Optional[Action], key: str) -> Any:
"""
从动作 data 中读取扩展配置。
"""
if not action:
return None
data = action.data or {}
return data.get(key) if isinstance(data, dict) else None
@@ -481,8 +907,18 @@ class WorkflowExecutor:
return None
if name == "context":
return self.context
if name == "workflow_context":
return self.context.workflow_context or {}
if name == "runtime_state":
return self.context.runtime_state or {}
if name == "artifacts":
return self.context.artifacts or {}
if name in ("outputs", "node_outputs"):
return self.context.node_outputs or {}
if name == "last":
return self.context.runtime_state.get("last_outputs") if self.context.runtime_state else {}
if name in (self.context.node_outputs or {}):
return self.context.node_outputs[name]
if name in ActionContext.model_fields:
return getattr(self.context, name, None)
raise ValueError(f"未知上下文变量 {name}")
@@ -598,7 +1034,7 @@ class WorkflowChain(ChainBase):
"""
workflowoper = WorkflowOper()
def save_step(action: Action, context: ActionContext):
def save_step(action: Action, context: ActionContext, execution_state: dict, completed: bool):
"""
保存上下文到数据库
"""
@@ -606,9 +1042,14 @@ class WorkflowChain(ChainBase):
serialized_data = pickle.dumps(context)
# 使用Base64编码字节流
encoded_data = base64.b64encode(serialized_data).decode('utf-8')
WorkflowOper().step(workflow_id, action_id=action.id, context={
"content": encoded_data
})
WorkflowOper().step(
workflow_id,
action_id=action.id if completed else "",
context={
"content": encoded_data
},
execution_state=execution_state
)
# 重置工作流
if from_begin:

View File

@@ -40,6 +40,10 @@ class Workflow(Base):
flows = Column(JSON, default=builtin_list)
# 执行上下文
context = Column(JSON, default=dict)
# 执行配置
execution_config = Column(JSON, default=dict)
# 结构化执行状态
execution_state = Column(JSON, default=dict)
# 创建时间
add_time = Column(String, default=lambda: datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
# 最后执行时间
@@ -218,6 +222,7 @@ class Workflow(Base):
"result": None,
"current_action": None,
"context": {},
"execution_state": {},
"run_count": 0 if reset_count else cls.run_count,
})
return True
@@ -231,39 +236,48 @@ class Workflow(Base):
result=None,
current_action=None,
context={},
execution_state={},
run_count=0 if reset_count else cls.run_count,
))
return True
@classmethod
@db_update
def update_current_action(cls, db, wid: int, action_id: str, context: dict):
def update_current_action(cls, db, wid: int, action_id: str, context: dict,
execution_state: Optional[dict] = None):
workflow = db.query(cls).filter(cls.id == wid).first()
current_actions = []
if workflow and workflow.current_action:
current_actions = [item for item in workflow.current_action.split(",") if item]
if action_id not in current_actions:
if action_id and action_id not in current_actions:
current_actions.append(action_id)
db.query(cls).filter(cls.id == wid).update({
update_values = {
"current_action": ",".join(current_actions),
"context": context
})
}
if execution_state is not None:
update_values["execution_state"] = execution_state
db.query(cls).filter(cls.id == wid).update(update_values)
return True
@classmethod
@async_db_update
async def async_update_current_action(cls, db: AsyncSession, wid: int, action_id: str, context: dict):
async def async_update_current_action(cls, db: AsyncSession, wid: int, action_id: str, context: dict,
execution_state: Optional[dict] = None):
from sqlalchemy import update
# 先获取当前current_action
result = await db.execute(select(cls.current_action).where(cls.id == wid))
current_action = result.scalar()
current_actions = [item for item in (current_action or "").split(",") if item]
if action_id not in current_actions:
if action_id and action_id not in current_actions:
current_actions.append(action_id)
new_current_action = ",".join(current_actions)
await db.execute(update(cls).where(cls.id == wid).values(
current_action=new_current_action,
context=context
))
update_values = {
"current_action": new_current_action,
"context": context
}
if execution_state is not None:
update_values["execution_state"] = execution_state
await db.execute(update(cls).where(cls.id == wid).values(**update_values))
return True

View File

@@ -91,11 +91,17 @@ class WorkflowOper(DbOper):
"""
return Workflow.fail(self._db, wid, result)
def step(self, wid: int, action_id: str, context: dict) -> bool:
def step(self, wid: int, action_id: str, context: dict, execution_state: Optional[dict] = None) -> bool:
"""
步进
"""
return Workflow.update_current_action(self._db, wid, action_id, context)
return Workflow.update_current_action(
self._db,
wid,
action_id,
context,
execution_state=execution_state
)
def reset(self, wid: int, reset_count: bool = False) -> bool:
"""

View File

@@ -1,6 +1,6 @@
from typing import Any, Optional, List
from typing import Any, List, Optional
from pydantic import BaseModel, Field, ConfigDict
from pydantic import BaseModel, ConfigDict, Field
from app.schemas.context import Context, MediaInfo
from app.schemas.download import DownloadTask
@@ -26,6 +26,8 @@ class Workflow(BaseModel):
run_count: Optional[int] = Field(default=0, description="已执行次数")
actions: Optional[list] = Field(default_factory=list, description="任务列表")
flows: Optional[list] = Field(default_factory=list, description="任务流")
execution_config: Optional[dict] = Field(default_factory=dict, description="工作流执行配置")
execution_state: Optional[dict] = Field(default_factory=dict, description="工作流结构化执行状态")
add_time: Optional[str] = Field(default=None, description="创建时间")
last_time: Optional[str] = Field(default=None, description="最后执行时间")
@@ -50,8 +52,14 @@ class Action(BaseModel):
description: Optional[str] = Field(default=None, description="动作描述")
position: Optional[dict] = Field(default_factory=dict, description="位置")
data: Optional[dict] = Field(default_factory=dict, description="参数")
inputs: Optional[List[str]] = Field(default_factory=list, description="动作输入声明")
outputs: Optional[dict] = Field(default_factory=dict, description="动作输出声明")
join_policy: Optional[str] = Field(default=None, description="多上游节点汇合策略")
fail_policy: Optional[str] = Field(default=None, description="动作失败后的工作流处理策略")
branch_policy: Optional[str] = Field(default=None, description="多出边分支策略")
concurrency_key: Optional[str] = Field(default=None, description="并发互斥键")
timeout: Optional[int] = Field(default=None, description="动作执行超时时间(秒)")
retry: Optional[dict] = Field(default_factory=dict, description="动作重试策略")
class ActionExecution(BaseModel):
@@ -74,7 +82,10 @@ class ActionContext(BaseModel):
downloads: Optional[List[DownloadTask]] = Field(default_factory=list, description="下载任务列表")
sites: Optional[List[Site]] = Field(default_factory=list, description="站点列表")
subscribes: Optional[List[Subscribe]] = Field(default_factory=list, description="订阅列表")
workflow_context: Optional[dict] = Field(default_factory=dict, description="工作流全局上下文")
node_outputs: Optional[dict] = Field(default_factory=dict, description="节点输出数据")
runtime_state: Optional[dict] = Field(default_factory=dict, description="运行期状态")
artifacts: Optional[dict] = Field(default_factory=dict, description="大对象引用与产物数据")
execute_history: Optional[List[ActionExecution]] = Field(default_factory=list, description="执行历史")
progress: Optional[int] = Field(default=0, description="执行进度(%")
@@ -87,6 +98,8 @@ class ActionResult(BaseModel):
message: Optional[str] = Field(default=None, description="动作执行消息")
context: Optional[ActionContext] = Field(default=None, description="动作执行后的上下文")
outputs: Optional[dict[str, Any]] = Field(default_factory=dict, description="当前节点显式输出")
next_policy: Optional[str] = Field(default=None, description="动作完成后的调度策略")
attempts: Optional[int] = Field(default=1, description="动作实际尝试次数")
class ActionFlow(BaseModel):
@@ -100,6 +113,7 @@ class ActionFlow(BaseModel):
data: Optional[dict] = Field(default_factory=dict, description="流程扩展配置")
condition: Optional[str] = Field(default=None, description="流转条件表达式")
join_policy: Optional[str] = Field(default=None, description="目标节点汇合策略")
branch_policy: Optional[str] = Field(default=None, description="源节点分支策略")
class WorkflowShare(BaseModel):

View File

@@ -1,7 +1,6 @@
import threading
from time import sleep
from typing import Dict, Any, Optional
from typing import List, Tuple
from time import monotonic, sleep
from typing import Any, Dict, List, Optional, Tuple
from app.core.config import global_vars
from app.core.event import eventmanager, Event
@@ -68,48 +67,54 @@ class WorkFlowManager(metaclass=Singleton):
self._actions = {}
self._event_workflows = {}
def execute(self, workflow_id: int, action: Action,
context: ActionContext = None) -> ActionResult:
def execute(self, workflow_id: int, action: Action, context: ActionContext = None,
inputs: Optional[dict] = None, runtime: Optional[dict] = None,
cancel_token: Optional[Any] = None) -> ActionResult:
"""
执行工作流动作
"""
if not context:
context = ActionContext()
if action.type in self._actions:
# 实例化之前,清理掉类对象的数据
# 实例化
action_obj = self._actions[action.type](action.id)
# 执行
logger.info(f"执行动作: {action.id} - {action.name}")
try:
result_context = action_obj.execute(workflow_id, action.data, context)
action_result = self._normalize_action_result(result_context, action_obj, context)
except Exception as err:
logger.error(f"{action.name} 执行失败: {err}")
return ActionResult(success=False, message=f"{err}", context=context)
loop = (action.data or {}).get("loop")
loop_interval = (action.data or {}).get("loop_interval")
if loop and loop_interval:
while not action_obj.done:
if global_vars.is_workflow_stopped(workflow_id):
break
# 等待
logger.info(f"{action.name} 等待 {loop_interval} 秒后继续执行 ...")
sleep(loop_interval)
# 执行
logger.info(f"继续执行动作: {action.id} - {action.name}")
result_context = action_obj.execute(workflow_id, action.data, action_result.context)
action_result = self._normalize_action_result(result_context, action_obj, action_result.context)
if action_result.success:
logger.info(f"{action.name} 执行成功")
else:
logger.error(f"{action.name} 执行失败!")
return action_result
else:
if action.type not in self._actions:
logger.error(f"未找到动作: {action.type} - {action.name}")
return ActionResult(success=False, message=" ", context=context)
retry_config = self._get_retry_config(action)
max_attempts = retry_config["max_attempts"]
interval = retry_config["interval"]
backoff = retry_config["backoff"]
action_result = ActionResult(success=False, message="", context=context)
for attempt in range(1, max_attempts + 1):
if self._is_cancelled(workflow_id, cancel_token):
return ActionResult(success=False, message="工作流已取消", context=context)
runtime_data = {
**(runtime or {}),
"attempt": attempt,
"max_attempts": max_attempts,
"cancel_token": cancel_token,
}
action_result = self._execute_action_once(
workflow_id=workflow_id,
action=action,
context=context,
inputs=inputs or {},
runtime=runtime_data,
cancel_token=cancel_token
)
action_result.attempts = attempt
context = action_result.context or context
if action_result.success:
logger.info(f"{action.name} 执行成功")
return action_result
if attempt < max_attempts and not self._is_cancelled(workflow_id, cancel_token):
wait_seconds = interval * (backoff ** (attempt - 1))
logger.info(f"{action.name} 执行失败,{wait_seconds} 秒后重试({attempt}/{max_attempts}...")
self._sleep_with_cancel(workflow_id, wait_seconds, cancel_token)
logger.error(f"{action.name} 执行失败!")
return action_result
def excute(self, workflow_id: int, action: Action,
context: ActionContext = None) -> Tuple[bool, str, ActionContext]:
"""
@@ -134,6 +139,100 @@ class WorkFlowManager(metaclass=Singleton):
context=result or fallback_context
)
def _execute_action_once(self, workflow_id: int, action: Action, context: ActionContext,
inputs: dict, runtime: dict, cancel_token: Optional[Any]) -> ActionResult:
action_obj = self._actions[action.type](action.id)
logger.info(f"执行动作: {action.id} - {action.name}")
try:
action_result = self._run_action_with_loop(
workflow_id=workflow_id,
action=action,
action_obj=action_obj,
context=context,
inputs=inputs,
runtime=runtime,
cancel_token=cancel_token
)
except Exception as err:
logger.error(f"{action.name} 执行失败: {err}")
return ActionResult(success=False, message=f"{err}", context=context)
return action_result
def _run_action_with_loop(self, workflow_id: int, action: Action, action_obj: Any,
context: ActionContext, inputs: dict, runtime: dict,
cancel_token: Optional[Any]) -> ActionResult:
timeout = self._get_action_timeout(action)
started_at = monotonic()
action_result = self._call_action(
workflow_id=workflow_id,
action=action,
action_obj=action_obj,
context=context,
inputs=inputs,
runtime=runtime
)
loop = self._get_action_data_value(action, "loop")
loop_interval = self._get_action_data_value(action, "loop_interval")
while loop and loop_interval and not action_obj.done:
if self._is_cancelled(workflow_id, cancel_token):
return ActionResult(success=False, message="工作流已取消", context=action_result.context or context)
if timeout and monotonic() - started_at >= timeout:
return ActionResult(success=False, message=f"动作执行超时({timeout}秒)", context=action_result.context or context)
logger.info(f"{action.name} 等待 {loop_interval} 秒后继续执行 ...")
self._sleep_with_cancel(workflow_id, loop_interval, cancel_token)
if self._is_cancelled(workflow_id, cancel_token):
return ActionResult(success=False, message="工作流已取消", context=action_result.context or context)
logger.info(f"继续执行动作: {action.id} - {action.name}")
action_result = self._call_action(
workflow_id=workflow_id,
action=action,
action_obj=action_obj,
context=action_result.context or context,
inputs=inputs,
runtime=runtime
)
return action_result
def _call_action(self, workflow_id: int, action: Action, action_obj: Any,
context: ActionContext, inputs: dict, runtime: dict) -> ActionResult:
if hasattr(action_obj, "execute_with_inputs"):
result = action_obj.execute_with_inputs(workflow_id, action.data, inputs, runtime, context)
else:
result = action_obj.execute(workflow_id, action.data, context)
return self._normalize_action_result(result, action_obj, context)
@staticmethod
def _get_action_data_value(action: Action, key: str) -> Any:
data = action.data or {}
return data.get(key) if isinstance(data, dict) else None
def _get_action_timeout(self, action: Action) -> Optional[int]:
timeout = action.timeout or self._get_action_data_value(action, "timeout")
return int(timeout) if timeout else None
def _get_retry_config(self, action: Action) -> dict:
retry_config = action.retry or self._get_action_data_value(action, "retry") or {}
if not isinstance(retry_config, dict):
retry_config = {}
return {
"max_attempts": max(int(retry_config.get("max_attempts") or 1), 1),
"interval": max(float(retry_config.get("interval") or 0), 0),
"backoff": max(float(retry_config.get("backoff") or 1), 1),
}
@staticmethod
def _is_cancelled(workflow_id: int, cancel_token: Optional[Any]) -> bool:
if cancel_token and cancel_token.is_cancelled():
return True
return global_vars.is_workflow_stopped(workflow_id)
def _sleep_with_cancel(self, workflow_id: int, seconds: float, cancel_token: Optional[Any]) -> None:
deadline = monotonic() + seconds
while monotonic() < deadline:
if self._is_cancelled(workflow_id, cancel_token):
return
sleep(min(0.1, deadline - monotonic()))
def list_actions(self) -> List[dict]:
"""
获取所有动作

View File

@@ -3,7 +3,7 @@ from typing import Union
from app.chain import ChainBase
from app.db.systemconfig_oper import SystemConfigOper
from app.schemas import ActionContext, ActionParams
from app.schemas import ActionContext, ActionParams, ActionResult
class ActionChain(ChainBase):
@@ -109,3 +109,16 @@ class BaseAction(ABC):
执行动作
"""
raise NotImplementedError
def execute_with_inputs(self, workflow_id: int, params: ActionParams, inputs: dict,
runtime: dict, context: ActionContext) -> ActionResult:
"""
使用显式输入与运行期信息执行动作。
"""
_ = inputs, runtime
result_context = self.execute(workflow_id, params, context)
return ActionResult(
success=self.success,
message=self.message,
context=result_context
)

View File

@@ -0,0 +1,45 @@
"""2.2.9
为工作流增加执行配置和结构化执行状态
Revision ID: 7c1a2b3d4e5f
Revises: d5e6f7a8b9c0
Create Date: 2026-06-04
"""
from alembic import op
import sqlalchemy as sa
revision = "7c1a2b3d4e5f"
down_revision = "d5e6f7a8b9c0"
branch_labels = None
depends_on = None
def _has_column(inspector: sa.Inspector, table_name: str, column_name: str) -> bool:
"""检查数据表是否已存在指定列。"""
if table_name not in inspector.get_table_names():
return False
return any(column["name"] == column_name for column in inspector.get_columns(table_name))
def _add_json_column_if_missing(table_name: str, column_name: str) -> None:
"""缺失时为数据表新增 JSON 列。"""
inspector = sa.inspect(op.get_bind())
if not _has_column(inspector, table_name, column_name):
op.add_column(table_name, sa.Column(column_name, sa.JSON(), nullable=True))
def upgrade() -> None:
"""升级数据库结构。"""
_add_json_column_if_missing("workflow", "execution_config")
_add_json_column_if_missing("workflow", "execution_state")
def downgrade() -> None:
"""回滚数据库结构。"""
inspector = sa.inspect(op.get_bind())
if _has_column(inspector, "workflow", "execution_state"):
op.drop_column("workflow", "execution_state")
inspector = sa.inspect(op.get_bind())
if _has_column(inspector, "workflow", "execution_config"):
op.drop_column("workflow", "execution_config")

View File

@@ -1,28 +1,32 @@
import base64
import pickle
import threading
import time
from types import SimpleNamespace
from app.chain import workflow as workflow_module
from app.schemas import ActionContext, ActionResult
from app.schemas import Action, ActionContext, ActionResult
from app.schemas.types import EventType
from app import workflow as workflow_package
def _build_workflow(current_action=None, context=None, actions=None, flows=None):
def _build_workflow(current_action=None, context=None, actions=None, flows=None,
execution_config=None, execution_state=None):
"""构造最小工作流对象。"""
return SimpleNamespace(
id=1,
name="测试工作流",
actions=actions or [
actions=actions if actions is not None else [
{"id": "A", "type": "FakeAction", "name": "动作A", "data": {}},
{"id": "B", "type": "FakeAction", "name": "动作B", "data": {}},
],
flows=flows or [
flows=flows if flows is not None else [
{"id": "flow-1", "source": "A", "target": "B", "animated": True},
],
current_action=current_action,
context=context,
execution_config=execution_config or {},
execution_state=execution_state or {},
)
@@ -39,9 +43,12 @@ class _FakeWorkflowManager:
def __init__(self, calls, results=None):
self.calls = calls
self.results = results or {}
self.received_inputs = []
def execute(self, workflow_id, action, context=None):
def execute(self, workflow_id, action, context=None, inputs=None, runtime=None, cancel_token=None):
"""执行伪动作并记录新版输入。"""
self.calls.append(action.id)
self.received_inputs.append((action.id, inputs or {}, runtime or {}, cancel_token))
result = self.results.get(action.id)
if callable(result):
return result(action, context or ActionContext())
@@ -262,6 +269,220 @@ def test_workflow_executor_all_done_join_can_continue_after_failure(monkeypatch)
assert executor.success is True
def test_workflow_executor_exclusive_branch_uses_first_matching_flow(monkeypatch):
"""互斥分支应只执行第一条满足条件的出边。"""
calls = []
fake_manager = _FakeWorkflowManager(
calls,
results={
"A": lambda action, context: ActionResult(
success=True,
message=f"{action.name}完成",
context=context,
outputs={"count": 2}
)
}
)
workflow = _build_workflow(
actions=[
{"id": "A", "type": "FakeAction", "name": "动作A", "data": {"branch_policy": "exclusive"}},
{"id": "B", "type": "FakeAction", "name": "动作B", "data": {}},
{"id": "C", "type": "FakeAction", "name": "动作C", "data": {}},
],
flows=[
{"id": "flow-ab", "source": "A", "target": "B", "condition": "outputs.A.count > 0"},
{"id": "flow-ac", "source": "A", "target": "C", "condition": "outputs.A.count > 1"},
],
)
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
executor = workflow_module.WorkflowExecutor(workflow)
executor.execute()
assert calls == ["A", "B"]
assert executor.node_states["C"] == "skipped"
def test_workflow_executor_passes_declared_inputs(monkeypatch):
"""动作输入声明应从 node_outputs 中读取指定路径。"""
calls = []
fake_manager = _FakeWorkflowManager(
calls,
results={
"A": lambda action, context: ActionResult(
success=True,
message=f"{action.name}完成",
context=context,
outputs={"torrents": ["a", "b"]}
)
}
)
workflow = _build_workflow(
actions=[
{"id": "A", "type": "FakeAction", "name": "动作A", "data": {}},
{
"id": "B",
"type": "FakeAction",
"name": "动作B",
"data": {"inputs": ["A.torrents", "outputs.A.torrents.count"]},
},
],
)
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
executor = workflow_module.WorkflowExecutor(workflow)
executor.execute()
b_inputs = [item for action_id, item, _, _ in fake_manager.received_inputs if action_id == "B"][0]
assert b_inputs == {
"A.torrents": ["a", "b"],
"outputs.A.torrents.count": 2,
}
def test_workflow_executor_persists_structured_state(monkeypatch):
"""步骤回调应收到可持久化的结构化执行状态。"""
calls = []
states = []
fake_manager = _FakeWorkflowManager(
calls,
results={
"A": lambda action, context: ActionResult(
success=True,
message=f"{action.name}完成",
context=context,
outputs={"items": ["movie"]}
)
}
)
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
executor = workflow_module.WorkflowExecutor(
_build_workflow(actions=[{"id": "A", "type": "FakeAction", "name": "动作A", "data": {}}], flows=[]),
step_callback=lambda action, context, execution_state, completed: states.append(execution_state),
)
executor.execute()
assert states[-1]["nodes"]["A"]["state"] == "success"
assert states[-1]["outputs"]["A"]["items"] == ["movie"]
assert states[-1]["runtime"]["progress"] == 100
def test_workflow_executor_restores_outputs_from_execution_state(monkeypatch):
"""恢复执行时应从结构化状态读取节点输出并继续判断条件边。"""
calls = []
fake_manager = _FakeWorkflowManager(calls)
workflow = _build_workflow(
execution_state={
"nodes": {
"A": {"state": "success", "attempt": 1},
},
"outputs": {
"A": {"torrents": ["movie"]},
},
},
flows=[
{"id": "flow-ab", "source": "A", "target": "B", "condition": "A.torrents.count > 0"},
],
)
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
executor = workflow_module.WorkflowExecutor(workflow)
executor.execute()
assert calls == ["B"]
assert executor.context.node_outputs["A"]["torrents"] == ["movie"]
def test_workflow_executor_concurrency_key_serializes_parallel_nodes(monkeypatch):
"""相同 concurrency_key 的并行节点不应同时运行。"""
calls = []
active_count = 0
max_active_count = 0
lock = threading.Lock()
def run_action(action, context):
"""记录同一并发键下的同时运行数量。"""
nonlocal active_count, max_active_count
with lock:
active_count += 1
max_active_count = max(max_active_count, active_count)
time.sleep(0.05)
with lock:
active_count -= 1
return ActionResult(success=True, message=f"{action.name}完成", context=context)
fake_manager = _FakeWorkflowManager(calls, results={"A": run_action, "B": run_action})
workflow = _build_workflow(
actions=[
{"id": "A", "type": "FakeAction", "name": "动作A", "data": {"concurrency_key": "download"}},
{"id": "B", "type": "FakeAction", "name": "动作B", "data": {"concurrency_key": "download"}},
],
flows=[],
execution_config={"max_workers": 2},
)
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
executor = workflow_module.WorkflowExecutor(workflow)
executor.execute()
assert set(calls) == {"A", "B"}
assert max_active_count == 1
def test_workflow_executor_filter_action_replaces_artifact_outputs(monkeypatch):
"""过滤类动作默认应替换列表输出,避免把过滤前数据重新合并回来。"""
calls = []
fake_manager = _FakeWorkflowManager(
calls,
results={
"A": lambda action, context: ActionResult(
success=True,
message=f"{action.name}完成",
context=context,
outputs={"torrents": ["old", "keep"]}
),
"B": lambda action, context: ActionResult(
success=True,
message=f"{action.name}完成",
context=context,
outputs={"torrents": ["keep"]}
),
}
)
workflow = _build_workflow(
actions=[
{"id": "A", "type": "FakeAction", "name": "动作A", "data": {}},
{"id": "B", "type": "FilterTorrentsAction", "name": "过滤资源", "data": {}},
],
)
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
executor = workflow_module.WorkflowExecutor(workflow)
executor.execute()
assert executor.context.torrents == ["keep"]
assert executor.context.artifacts["torrents"] == ["keep"]
def test_workflow_executor_stop_is_not_success(monkeypatch):
"""停止信号不应被执行器汇报为成功完成。"""
calls = []
@@ -328,3 +549,43 @@ def test_workflow_event_listener_keeps_shared_handler_until_last_workflow(monkey
assert fake_eventmanager.removed == [EventType.DownloadAdded]
assert manager.get_event_workflows() == {}
def test_workflow_manager_retries_action_until_success(monkeypatch):
"""动作管理器应按 retry 配置重试失败动作。"""
class RetryAction:
"""模拟第二次才成功的动作。"""
call_count = 0
def __init__(self, action_id):
self.action_id = action_id
def execute_with_inputs(self, workflow_id, params, inputs, runtime, context):
"""执行动作并在第二次返回成功。"""
_ = workflow_id, params, inputs, runtime
RetryAction.call_count += 1
if RetryAction.call_count == 1:
return ActionResult(success=False, message="第一次失败", context=context)
return ActionResult(success=True, message="第二次成功", context=context, outputs={"ok": True})
manager = object.__new__(workflow_package.WorkFlowManager)
manager._actions = {"RetryAction": RetryAction}
monkeypatch.setattr(workflow_package.global_vars, "is_workflow_stopped", lambda workflow_id: False)
result = manager.execute(
workflow_id=1,
action=Action(
id="retry",
type="RetryAction",
name="重试动作",
data={"retry": {"max_attempts": 2, "interval": 0}},
),
context=ActionContext(),
)
assert result.success is True
assert result.attempts == 2
assert result.outputs == {"ok": True}
assert RetryAction.call_count == 2