diff --git a/app/chain/workflow.py b/app/chain/workflow.py index 60c72acf..1f14fb4c 100644 --- a/app/chain/workflow.py +++ b/app/chain/workflow.py @@ -1,13 +1,17 @@ import ast import base64 import copy +import inspect import pickle import threading from collections import defaultdict, deque from concurrent.futures import ThreadPoolExecutor +from datetime import datetime from time import sleep from typing import Any, Callable, List, Optional, Tuple +from fastapi.encoders import jsonable_encoder + from app.chain import ChainBase from app.core.config import global_vars from app.core.event import Event, eventmanager @@ -19,6 +23,29 @@ from app.schemas.types import EventType from app.workflow import WorkFlowManager +ARTIFACT_FIELDS = {"torrents", "medias", "fileitems", "downloads", "sites", "subscribes"} +DEFAULT_WORKFLOW_MAX_WORKERS = 4 + + +class WorkflowCancelToken: + """ + 工作流取消令牌。 + """ + + def __init__(self, workflow_id: int): + """ + 初始化取消令牌。 + :param workflow_id: 工作流ID + """ + self.workflow_id = workflow_id + + def is_cancelled(self) -> bool: + """ + 判断工作流是否已被取消。 + """ + return global_vars.is_workflow_stopped(self.workflow_id) + + class WorkflowExecutor: """ 工作流执行器 @@ -35,32 +62,40 @@ class WorkflowExecutor: self.step_callback = step_callback self.actions = {action['id']: Action(**action) for action in workflow.actions} self.flows = [ActionFlow(**flow) for flow in workflow.flows] + self.execution_config = getattr(workflow, "execution_config", None) or {} + self.restored_execution_state = getattr(workflow, "execution_state", None) or {} self.total_actions = len(self.actions) - self.completed_actions = { - action_id for action_id in (workflow.current_action or "").split(",") - if action_id in self.actions - } - self.finished_actions = len(self.completed_actions) - self.success = True self.has_failure = False self.stopped = False self.errmsg = "" - self.node_states = {action_id: "pending" for action_id in self.actions} - for action_id in self.completed_actions: - self.node_states[action_id] = "completed" + self.errors = self.get_restored_errors() + self.node_metadata = self.get_restored_node_metadata() + self.node_attempts = self.get_restored_attempts() + self.node_states = self.get_restored_node_states() + self.completed_actions = { + action_id for action_id, state in self.node_states.items() + if state == "success" + } + self.finished_actions = len([ + state for state in self.node_states.values() + if state in ("success", "failed", "skipped") + ]) self.flow_finished = set() self.flow_satisfied = set() + self.flow_failed = set() # 工作流管理器 self.workflowmanager = WorkFlowManager() # 线程安全队列 self.queue = deque() self.queued_actions = set() + self.active_concurrency_keys = set() # 锁用于保证线程安全 self.lock = threading.Lock() # 线程池 - self.executor = ThreadPoolExecutor() + self.executor = ThreadPoolExecutor(max_workers=self.get_workflow_max_workers()) + self.cancel_token = WorkflowCancelToken(self.workflow.id) # 跟踪运行中的任务数 self.running_tasks = 0 @@ -74,26 +109,162 @@ class WorkflowExecutor: self.incoming_flows[flow.target].append(flow) # 初始上下文 - if workflow.current_action and workflow.context: - logger.info(f"工作流已执行动作:{workflow.current_action}") - # Base64解码 - decoded_data = base64.b64decode(workflow.context["content"]) - # 反序列化数据 - self.context = pickle.loads(decoded_data) - else: - self.context = ActionContext() - self.context.node_outputs = self.context.node_outputs or {} + self.context = self.restore_context() + self.ensure_context_partitions() # 恢复工作流 global_vars.workflow_resume(self.workflow.id) - # 恢复时重新释放已完成节点的出边,使后继节点能继续执行。 - for action_id in self.completed_actions: - self.release_successors(action_id, source_success=True) + # 恢复时重新释放已终态节点的出边,使后继节点能继续执行或保持跳过传播。 + for action_id, state in self.node_states.items(): + if state == "success": + self.release_successors(action_id, source_success=True) + elif state in ("failed", "skipped"): + self.release_successors(action_id, source_success=False) # 初始化队列,添加没有入边的起始节点。 for action_id in self.actions: - if action_id not in self.completed_actions and not self.incoming_flows.get(action_id): + if self.node_states.get(action_id) == "pending" and not self.incoming_flows.get(action_id): self.enqueue_node(action_id) + def get_workflow_max_workers(self) -> int: + """ + 获取工作流最大并发数。 + """ + max_workers = self.execution_config.get("max_workers") if isinstance(self.execution_config, dict) else None + try: + return max(int(max_workers or DEFAULT_WORKFLOW_MAX_WORKERS), 1) + except (TypeError, ValueError): + return DEFAULT_WORKFLOW_MAX_WORKERS + + def get_restored_node_metadata(self) -> dict: + """ + 获取已持久化的节点状态元数据。 + """ + nodes = self.restored_execution_state.get("nodes") if isinstance(self.restored_execution_state, dict) else {} + return nodes if isinstance(nodes, dict) else {} + + def get_restored_errors(self) -> dict: + """ + 获取已持久化的错误状态。 + """ + errors = self.restored_execution_state.get("errors") if isinstance(self.restored_execution_state, dict) else {} + return errors if isinstance(errors, dict) else {} + + def get_restored_attempts(self) -> dict: + """ + 获取已持久化的节点尝试次数。 + """ + attempts = {} + for action_id, metadata in self.get_restored_node_metadata().items(): + if isinstance(metadata, dict) and metadata.get("attempt"): + attempts[action_id] = int(metadata.get("attempt") or 0) + return attempts + + def get_restored_node_states(self) -> dict: + """ + 获取结构化节点状态,兼容旧版 current_action 字符串。 + """ + legacy_actions = { + action_id for action_id in (self.workflow.current_action or "").split(",") + if action_id in self.actions + } + states = {} + for action_id in self.actions: + metadata = self.node_metadata.get(action_id) or {} + state = metadata.get("state") if isinstance(metadata, dict) else None + if state == "completed": + state = "success" + if state in ("running", "queued"): + state = "pending" + if not state and action_id in legacy_actions: + state = "success" + states[action_id] = state or "pending" + return states + + def restore_context(self) -> ActionContext: + """ + 恢复工作流上下文,兼容旧版 Base64 Pickle 存储格式。 + """ + context = ActionContext() + if self.workflow.current_action and self.workflow.context: + logger.info(f"工作流已执行动作:{self.workflow.current_action}") + try: + decoded_data = base64.b64decode(self.workflow.context["content"]) + context = pickle.loads(decoded_data) + except Exception as err: + logger.error(f"工作流上下文恢复失败: {str(err)}") + context = ActionContext() + outputs = self.restored_execution_state.get("outputs") if isinstance(self.restored_execution_state, dict) else {} + if outputs and not getattr(context, "node_outputs", None): + context.node_outputs = outputs + return context + + def ensure_context_partitions(self) -> None: + """ + 确保上下文具备新版分区结构,并把旧字段映射到 artifacts。 + """ + self.context.workflow_context = self.context.workflow_context or {} + self.context.node_outputs = self.context.node_outputs or {} + self.context.runtime_state = self.context.runtime_state or {} + self.context.artifacts = self.context.artifacts or {} + for key in ARTIFACT_FIELDS: + value = getattr(self.context, key, None) + if value not in (None, "", [], {}) and key not in self.context.artifacts: + self.context.artifacts[key] = value + self.update_runtime_state() + + def update_runtime_state(self) -> None: + """ + 更新上下文中的运行期状态分区。 + """ + self.context.runtime_state.update({ + "progress": self.context.progress, + "finished_actions": self.finished_actions, + "running_tasks": self.running_tasks, + "errors": self.errors, + "node_states": self.node_states, + "attempts": self.node_attempts, + }) + + def set_node_state(self, action_id: str, state: str, message: Optional[str] = None) -> None: + """ + 更新节点结构化状态。 + """ + now = datetime.now().strftime('%Y-%m-%d %H:%M:%S') + metadata = self.node_metadata.setdefault(action_id, {}) + metadata["state"] = state + metadata["attempt"] = self.node_attempts.get(action_id, metadata.get("attempt") or 0) + if state == "running": + metadata["started_at"] = now + if state in ("success", "failed", "skipped"): + metadata["finished_at"] = now + if message is not None: + metadata["message"] = message + self.node_states[action_id] = state + self.update_runtime_state() + + def build_execution_state(self) -> dict: + """ + 构建可持久化的结构化执行状态。 + """ + self.update_runtime_state() + return self.make_json_safe({ + "version": 1, + "nodes": self.node_metadata, + "outputs": self.context.node_outputs, + "errors": self.errors, + "runtime": self.context.runtime_state, + }) + + @staticmethod + def make_json_safe(value: Any) -> Any: + """ + 将运行期对象转换为可写入 JSON 列的数据。 + """ + try: + return jsonable_encoder(value) + except Exception: + return str(value) + def execute(self) -> None: """ 执行工作流 @@ -121,14 +292,9 @@ class WorkflowExecutor: elif not self.queue: should_sleep = True else: - # 取出队首节点 - node_id = self.queue.popleft() - self.queued_actions.discard(node_id) - if self.node_states.get(node_id) != "queued": - continue - self.node_states[node_id] = "running" - # 标记任务开始 - self.running_tasks += 1 + node_id = self.pop_dispatchable_node() + if not node_id: + should_sleep = True if should_sleep: sleep(0.1) @@ -148,19 +314,48 @@ class WorkflowExecutor: finally: self.executor.shutdown(wait=True, cancel_futures=True) + def pop_dispatchable_node(self) -> Optional[str]: + """ + 从队列中取出当前可调度节点。 + """ + for _ in range(len(self.queue)): + node_id = self.queue.popleft() + self.queued_actions.discard(node_id) + if self.node_states.get(node_id) != "queued": + continue + concurrency_key = self.get_action_concurrency_key(self.actions[node_id]) + if concurrency_key and concurrency_key in self.active_concurrency_keys: + self.queue.append(node_id) + self.queued_actions.add(node_id) + continue + if concurrency_key: + self.active_concurrency_keys.add(concurrency_key) + self.running_tasks += 1 + self.set_node_state(node_id, "running") + return node_id + return None + def execute_node(self, workflow_id: int, node_id: str, context: ActionContext) -> Tuple[Action, ActionResult]: """ 执行单个节点操作,返回修改后的上下文和节点ID """ action = self.actions[node_id] - action_result = self.workflowmanager.execute(workflow_id, action, context=context) + action_result = self.workflowmanager.execute( + workflow_id, + action, + context=context, + inputs=self.build_action_inputs(action), + runtime=self.build_action_runtime(action), + cancel_token=self.cancel_token + ) return action, action_result def on_node_complete(self, future): """ 节点完成回调:更新上下文、处理后继节点 """ + action = None try: action, action_result = future.result() with self.lock: @@ -172,6 +367,7 @@ class WorkflowExecutor: state = bool(action_result.success) message = action_result.message or "" result_ctx = action_result.context or ActionContext() + self.node_attempts[action.id] = action_result.attempts or self.node_attempts.get(action.id, 1) self.finished_actions += 1 self.update_progress() @@ -186,31 +382,36 @@ class WorkflowExecutor: # 节点执行失败时默认停止;显式配置 continue/ignore 时继续释放后续 all_done 汇合。 if not state: - self.node_states[action.id] = "failed" + self.errors[action.id] = message or f"{action.name} 失败" + self.set_node_state(action.id, "failed", message=message) fail_policy = self.get_action_fail_policy(action) if fail_policy != "ignore": self.has_failure = True self.errmsg = f"{action.name} 失败" if fail_policy == "stop": self.success = False + self.call_step_callback(action, completed=False) return if fail_policy not in ("continue", "ignore"): self.success = False self.errmsg = f"{action.name} 失败:无效失败策略 {fail_policy}" + self.call_step_callback(action, completed=False) return self.release_successors(action.id, source_success=False) + self.call_step_callback(action, completed=False) return - # 更新主上下文 - self.merge_context(result_ctx) - self.record_node_outputs(action.id, action_result, result_ctx) + self.ensure_result_context_partitions(result_ctx) + outputs = self.normalize_action_outputs(action, action_result, result_ctx) + self.merge_context_partitions(result_ctx) + self.merge_action_outputs(action, outputs) + self.record_node_outputs(action.id, outputs) self.completed_actions.add(action.id) - self.node_states[action.id] = "completed" + self.set_node_state(action.id, "success", message=message) # 处理后继节点 self.release_successors(action.id, source_success=True) # 回调 - if self.step_callback: - self.step_callback(action, self.context) + self.call_step_callback(action, completed=True) except Exception as err: logger.error(f"工作流节点执行回调失败: {str(err)}") with self.lock: @@ -219,7 +420,12 @@ class WorkflowExecutor: finally: # 标记任务完成 with self.lock: + if action: + concurrency_key = self.get_action_concurrency_key(action) + if concurrency_key: + self.active_concurrency_keys.discard(concurrency_key) self.running_tasks -= 1 + self.update_runtime_state() def enqueue_node(self, node_id: str) -> None: """ @@ -231,7 +437,7 @@ class WorkflowExecutor: return self.queue.append(node_id) self.queued_actions.add(node_id) - self.node_states[node_id] = "queued" + self.set_node_state(node_id, "queued") def skip_node(self, node_id: str, message: str) -> None: """ @@ -242,9 +448,9 @@ class WorkflowExecutor: if self.node_states.get(node_id) not in ("pending", "queued"): return self.queued_actions.discard(node_id) - self.node_states[node_id] = "skipped" self.finished_actions += 1 self.update_progress() + self.set_node_state(node_id, "skipped", message=message) self.context.execute_history.append( ActionExecution( action=self.actions[node_id].name, @@ -252,13 +458,17 @@ class WorkflowExecutor: message=message ) ) + self.call_step_callback(self.actions[node_id], completed=False) self.release_successors(node_id, source_success=False) def release_successors(self, source_id: str, source_success: bool) -> None: """ 根据源节点状态释放出边,并重新判断目标节点是否可运行。 """ - for flow in self.outgoing_flows.get(source_id, []): + flows = self.outgoing_flows.get(source_id, []) + branch_policy = self.get_action_branch_policy(self.actions.get(source_id), flows) + matched_exclusive_flow = None + for flow in flows: flow_key = self.get_flow_key(flow) if flow_key in self.flow_finished: continue @@ -270,9 +480,15 @@ class WorkflowExecutor: self.success = False self.errmsg = f"流程条件判断失败:{err}" return + if branch_policy == "exclusive" and condition_matched and matched_exclusive_flow: + condition_matched = False + elif branch_policy == "exclusive" and condition_matched: + matched_exclusive_flow = flow_key self.flow_finished.add(flow_key) if source_success and condition_matched: self.flow_satisfied.add(flow_key) + if not source_success and self.node_states.get(source_id) == "failed": + self.flow_failed.add(flow_key) self.evaluate_target_state(flow.target) def evaluate_target_state(self, target_id: str) -> None: @@ -291,8 +507,18 @@ class WorkflowExecutor: total_count = len(incoming_flows) finished_count = sum(1 for flow in incoming_flows if self.get_flow_key(flow) in self.flow_finished) satisfied_count = sum(1 for flow in incoming_flows if self.get_flow_key(flow) in self.flow_satisfied) + failed_count = sum(1 for flow in incoming_flows if self.get_flow_key(flow) in self.flow_failed) join_policy = self.get_action_join_policy(self.actions[target_id], incoming_flows) + if join_policy == "fail_fast": + if failed_count > 0: + self.skip_node(target_id, "上游失败触发 fail_fast,已取消后续节点") + elif finished_count == total_count and satisfied_count == total_count: + self.enqueue_node(target_id) + elif finished_count == total_count: + self.skip_node(target_id, "上游条件未全部满足,已跳过") + return + if join_policy == "any_success": if satisfied_count > 0: self.enqueue_node(target_id) @@ -322,14 +548,192 @@ class WorkflowExecutor: 根据已完成和已跳过节点数量更新整体进度。 """ self.context.progress = round(self.finished_actions / self.total_actions * 100) if self.total_actions else 100 + self.update_runtime_state() - def record_node_outputs(self, action_id: str, action_result: ActionResult, result_context: ActionContext) -> None: + def build_action_inputs(self, action: Action) -> dict: + """ + 根据动作输入声明读取上游节点输出。 + """ + inputs = {} + input_paths = action.inputs or self.get_action_data_value(action, "inputs") or [] + if isinstance(input_paths, str): + input_paths = [item.strip() for item in input_paths.splitlines() if item.strip()] + for input_path in input_paths: + inputs[input_path] = self.resolve_context_path(input_path) + return inputs + + def build_action_runtime(self, action: Action) -> dict: + """ + 构建传递给动作的新运行期数据。 + """ + return { + "workflow_id": self.workflow.id, + "action_id": action.id, + "execution_config": self.execution_config, + "runtime_state": self.context.runtime_state, + } + + def ensure_result_context_partitions(self, context: ActionContext) -> None: + """ + 确保动作返回上下文具备新版分区字段。 + """ + context.workflow_context = context.workflow_context or {} + context.node_outputs = context.node_outputs or {} + context.runtime_state = context.runtime_state or {} + context.artifacts = context.artifacts or {} + + def normalize_action_outputs(self, action: Action, action_result: ActionResult, + result_context: ActionContext) -> dict: + """ + 根据动作输出声明整理当前节点输出。 + """ + outputs = action_result.outputs or self.extract_context_outputs(result_context) + declared_outputs = action.outputs or self.get_action_data_value(action, "outputs") + if isinstance(declared_outputs, list): + return {key: outputs.get(key) for key in declared_outputs if outputs.get(key) not in (None, "", [], {})} + if isinstance(declared_outputs, dict): + return { + key: outputs.get(key) + for key in declared_outputs + if outputs.get(key) not in (None, "", [], {}) + } or outputs + return outputs + + def record_node_outputs(self, action_id: str, outputs: dict) -> None: """ 记录当前节点输出,供后续条件表达式读取。 """ - outputs = action_result.outputs or self.extract_context_outputs(result_context) if outputs: self.context.node_outputs[action_id] = outputs + self.context.runtime_state["last_outputs"] = outputs + + def merge_context_partitions(self, context: ActionContext) -> None: + """ + 合并动作返回的新分区上下文。 + """ + for key in ("workflow_context", "runtime_state", "artifacts"): + value = getattr(context, key, None) + if not value: + continue + current_value = getattr(self.context, key, None) or {} + current_value.update(value) + setattr(self.context, key, current_value) + + def merge_action_outputs(self, action: Action, outputs: dict) -> None: + """ + 按声明式合并策略写入全局上下文和 artifacts 分区。 + """ + for key, value in outputs.items(): + if value in (None, "", [], {}): + continue + output_config = self.get_action_output_config(action, key) + target_key = output_config.get("target") or key + merge_policy = output_config.get("merge") or self.get_default_merge_policy(action, target_key, value) + identity = output_config.get("identity") + self.merge_output_value(target_key, value, merge_policy, identity) + + def merge_output_value(self, key: str, value: Any, merge_policy: str, identity: Optional[str] = None) -> None: + """ + 按指定策略合并单个输出值。 + """ + current_value = getattr(self.context, key, None) if key in ActionContext.model_fields else None + merged_value = self.apply_merge_policy(current_value, value, merge_policy, identity) + if key in ActionContext.model_fields: + setattr(self.context, key, merged_value) + if key in ARTIFACT_FIELDS: + current_artifact = self.context.artifacts.get(key) + self.context.artifacts[key] = self.apply_merge_policy(current_artifact, value, merge_policy, identity) + + def get_action_output_config(self, action: Action, output_key: str) -> dict: + """ + 获取动作输出声明配置。 + """ + outputs_config = action.outputs or self.get_action_data_value(action, "outputs") or {} + if isinstance(outputs_config, dict): + value = outputs_config.get(output_key) or {} + return value if isinstance(value, dict) else {} + return {} + + def get_default_merge_policy(self, action: Action, key: str, value: Any) -> str: + """ + 获取输出默认合并策略。 + """ + if action.type in ("FilterTorrentsAction", "FilterMediasAction", "FetchDownloadsAction"): + return "replace" + if isinstance(value, list): + return "append_unique" + if isinstance(value, dict): + return "merge_dict" + return "first_non_empty" + + def apply_merge_policy(self, current_value: Any, value: Any, merge_policy: str, + identity: Optional[str] = None) -> Any: + """ + 应用声明式合并策略。 + """ + if merge_policy == "replace": + return value + if merge_policy == "merge_dict": + merged = current_value.copy() if isinstance(current_value, dict) else {} + if isinstance(value, dict): + merged.update(value) + return merged + return current_value or value + if merge_policy == "append_unique": + return self.append_unique_values(current_value, value, identity) + if merge_policy == "first_non_empty": + return current_value or value + return current_value or value + + def append_unique_values(self, current_value: Any, value: Any, identity: Optional[str] = None) -> list: + """ + 追加列表并按身份字段去重。 + """ + current_list = list(current_value or []) + incoming_list = value if isinstance(value, list) else [value] + seen = {self.get_identity_value(item, identity) for item in current_list} + for item in incoming_list: + identity_value = self.get_identity_value(item, identity) + if identity_value in seen: + continue + current_list.append(item) + seen.add(identity_value) + return current_list + + def get_identity_value(self, item: Any, identity: Optional[str] = None) -> Any: + """ + 获取列表元素去重身份。 + """ + if not identity: + identity_value = self.make_json_safe(item) + return self.make_hashable_identity(identity_value) + value = item + for part in identity.split("."): + value = self.read_value(value, int(part) if part.isdigit() else part) + return self.make_hashable_identity(self.make_json_safe(value)) + + @staticmethod + def make_hashable_identity(value: Any) -> Any: + """ + 将身份值转换为可哈希对象。 + """ + try: + hash(value) + return value + except TypeError: + return repr(value) + + def call_step_callback(self, action: Action, completed: bool) -> None: + """ + 持久化当前步骤上下文和结构化执行状态。 + """ + if not self.step_callback: + return + callback_params = inspect.signature(self.step_callback).parameters + if len(callback_params) <= 2: + self.step_callback(action, self.context) + return + self.step_callback(action, self.context, self.build_execution_state(), completed) @staticmethod def extract_context_outputs(context: ActionContext) -> dict: @@ -340,7 +744,7 @@ class WorkflowExecutor: return {} outputs = {} for key in context.__class__.model_fields: - if key in ("execute_history", "progress", "node_outputs"): + if key in ("execute_history", "progress", "node_outputs", "runtime_state"): continue value = getattr(context, key, None) if value in (None, "", [], {}): @@ -368,12 +772,32 @@ class WorkflowExecutor: return join_policy return "all_success" + def get_action_branch_policy(self, action: Optional[Action], outgoing_flows: List[ActionFlow]) -> str: + """ + 获取动作出边分支策略。 + """ + if action: + branch_policy = action.branch_policy or self.get_action_data_value(action, "branch_policy") + if branch_policy: + return branch_policy + for flow in outgoing_flows: + branch_policy = flow.branch_policy or self.get_flow_data_value(flow, "branch_policy") + if branch_policy: + return branch_policy + return "parallel" + def get_action_fail_policy(self, action: Action) -> str: """ 获取动作失败策略。 """ return action.fail_policy or self.get_action_data_value(action, "fail_policy") or "stop" + def get_action_concurrency_key(self, action: Action) -> Optional[str]: + """ + 获取动作并发互斥键。 + """ + return action.concurrency_key or self.get_action_data_value(action, "concurrency_key") + def get_flow_condition(self, flow: ActionFlow) -> Optional[str]: """ 获取流程边条件表达式。 @@ -381,10 +805,12 @@ class WorkflowExecutor: return flow.condition or self.get_flow_data_value(flow, "condition") @staticmethod - def get_action_data_value(action: Action, key: str) -> Any: + def get_action_data_value(action: Optional[Action], key: str) -> Any: """ 从动作 data 中读取扩展配置。 """ + if not action: + return None data = action.data or {} return data.get(key) if isinstance(data, dict) else None @@ -481,8 +907,18 @@ class WorkflowExecutor: return None if name == "context": return self.context + if name == "workflow_context": + return self.context.workflow_context or {} + if name == "runtime_state": + return self.context.runtime_state or {} + if name == "artifacts": + return self.context.artifacts or {} if name in ("outputs", "node_outputs"): return self.context.node_outputs or {} + if name == "last": + return self.context.runtime_state.get("last_outputs") if self.context.runtime_state else {} + if name in (self.context.node_outputs or {}): + return self.context.node_outputs[name] if name in ActionContext.model_fields: return getattr(self.context, name, None) raise ValueError(f"未知上下文变量 {name}") @@ -598,7 +1034,7 @@ class WorkflowChain(ChainBase): """ workflowoper = WorkflowOper() - def save_step(action: Action, context: ActionContext): + def save_step(action: Action, context: ActionContext, execution_state: dict, completed: bool): """ 保存上下文到数据库 """ @@ -606,9 +1042,14 @@ class WorkflowChain(ChainBase): serialized_data = pickle.dumps(context) # 使用Base64编码字节流 encoded_data = base64.b64encode(serialized_data).decode('utf-8') - WorkflowOper().step(workflow_id, action_id=action.id, context={ - "content": encoded_data - }) + WorkflowOper().step( + workflow_id, + action_id=action.id if completed else "", + context={ + "content": encoded_data + }, + execution_state=execution_state + ) # 重置工作流 if from_begin: diff --git a/app/db/models/workflow.py b/app/db/models/workflow.py index caa4bf32..4a251ecd 100644 --- a/app/db/models/workflow.py +++ b/app/db/models/workflow.py @@ -40,6 +40,10 @@ class Workflow(Base): flows = Column(JSON, default=builtin_list) # 执行上下文 context = Column(JSON, default=dict) + # 执行配置 + execution_config = Column(JSON, default=dict) + # 结构化执行状态 + execution_state = Column(JSON, default=dict) # 创建时间 add_time = Column(String, default=lambda: datetime.now().strftime('%Y-%m-%d %H:%M:%S')) # 最后执行时间 @@ -218,6 +222,7 @@ class Workflow(Base): "result": None, "current_action": None, "context": {}, + "execution_state": {}, "run_count": 0 if reset_count else cls.run_count, }) return True @@ -231,39 +236,48 @@ class Workflow(Base): result=None, current_action=None, context={}, + execution_state={}, run_count=0 if reset_count else cls.run_count, )) return True @classmethod @db_update - def update_current_action(cls, db, wid: int, action_id: str, context: dict): + def update_current_action(cls, db, wid: int, action_id: str, context: dict, + execution_state: Optional[dict] = None): workflow = db.query(cls).filter(cls.id == wid).first() current_actions = [] if workflow and workflow.current_action: current_actions = [item for item in workflow.current_action.split(",") if item] - if action_id not in current_actions: + if action_id and action_id not in current_actions: current_actions.append(action_id) - db.query(cls).filter(cls.id == wid).update({ + update_values = { "current_action": ",".join(current_actions), "context": context - }) + } + if execution_state is not None: + update_values["execution_state"] = execution_state + db.query(cls).filter(cls.id == wid).update(update_values) return True @classmethod @async_db_update - async def async_update_current_action(cls, db: AsyncSession, wid: int, action_id: str, context: dict): + async def async_update_current_action(cls, db: AsyncSession, wid: int, action_id: str, context: dict, + execution_state: Optional[dict] = None): from sqlalchemy import update # 先获取当前current_action result = await db.execute(select(cls.current_action).where(cls.id == wid)) current_action = result.scalar() current_actions = [item for item in (current_action or "").split(",") if item] - if action_id not in current_actions: + if action_id and action_id not in current_actions: current_actions.append(action_id) new_current_action = ",".join(current_actions) - await db.execute(update(cls).where(cls.id == wid).values( - current_action=new_current_action, - context=context - )) + update_values = { + "current_action": new_current_action, + "context": context + } + if execution_state is not None: + update_values["execution_state"] = execution_state + await db.execute(update(cls).where(cls.id == wid).values(**update_values)) return True diff --git a/app/db/workflow_oper.py b/app/db/workflow_oper.py index 0175dbb0..f4c82942 100644 --- a/app/db/workflow_oper.py +++ b/app/db/workflow_oper.py @@ -91,11 +91,17 @@ class WorkflowOper(DbOper): """ return Workflow.fail(self._db, wid, result) - def step(self, wid: int, action_id: str, context: dict) -> bool: + def step(self, wid: int, action_id: str, context: dict, execution_state: Optional[dict] = None) -> bool: """ 步进 """ - return Workflow.update_current_action(self._db, wid, action_id, context) + return Workflow.update_current_action( + self._db, + wid, + action_id, + context, + execution_state=execution_state + ) def reset(self, wid: int, reset_count: bool = False) -> bool: """ diff --git a/app/schemas/workflow.py b/app/schemas/workflow.py index 30f52dc0..02309455 100644 --- a/app/schemas/workflow.py +++ b/app/schemas/workflow.py @@ -1,6 +1,6 @@ -from typing import Any, Optional, List +from typing import Any, List, Optional -from pydantic import BaseModel, Field, ConfigDict +from pydantic import BaseModel, ConfigDict, Field from app.schemas.context import Context, MediaInfo from app.schemas.download import DownloadTask @@ -26,6 +26,8 @@ class Workflow(BaseModel): run_count: Optional[int] = Field(default=0, description="已执行次数") actions: Optional[list] = Field(default_factory=list, description="任务列表") flows: Optional[list] = Field(default_factory=list, description="任务流") + execution_config: Optional[dict] = Field(default_factory=dict, description="工作流执行配置") + execution_state: Optional[dict] = Field(default_factory=dict, description="工作流结构化执行状态") add_time: Optional[str] = Field(default=None, description="创建时间") last_time: Optional[str] = Field(default=None, description="最后执行时间") @@ -50,8 +52,14 @@ class Action(BaseModel): description: Optional[str] = Field(default=None, description="动作描述") position: Optional[dict] = Field(default_factory=dict, description="位置") data: Optional[dict] = Field(default_factory=dict, description="参数") + inputs: Optional[List[str]] = Field(default_factory=list, description="动作输入声明") + outputs: Optional[dict] = Field(default_factory=dict, description="动作输出声明") join_policy: Optional[str] = Field(default=None, description="多上游节点汇合策略") fail_policy: Optional[str] = Field(default=None, description="动作失败后的工作流处理策略") + branch_policy: Optional[str] = Field(default=None, description="多出边分支策略") + concurrency_key: Optional[str] = Field(default=None, description="并发互斥键") + timeout: Optional[int] = Field(default=None, description="动作执行超时时间(秒)") + retry: Optional[dict] = Field(default_factory=dict, description="动作重试策略") class ActionExecution(BaseModel): @@ -74,7 +82,10 @@ class ActionContext(BaseModel): downloads: Optional[List[DownloadTask]] = Field(default_factory=list, description="下载任务列表") sites: Optional[List[Site]] = Field(default_factory=list, description="站点列表") subscribes: Optional[List[Subscribe]] = Field(default_factory=list, description="订阅列表") + workflow_context: Optional[dict] = Field(default_factory=dict, description="工作流全局上下文") node_outputs: Optional[dict] = Field(default_factory=dict, description="节点输出数据") + runtime_state: Optional[dict] = Field(default_factory=dict, description="运行期状态") + artifacts: Optional[dict] = Field(default_factory=dict, description="大对象引用与产物数据") execute_history: Optional[List[ActionExecution]] = Field(default_factory=list, description="执行历史") progress: Optional[int] = Field(default=0, description="执行进度(%)") @@ -87,6 +98,8 @@ class ActionResult(BaseModel): message: Optional[str] = Field(default=None, description="动作执行消息") context: Optional[ActionContext] = Field(default=None, description="动作执行后的上下文") outputs: Optional[dict[str, Any]] = Field(default_factory=dict, description="当前节点显式输出") + next_policy: Optional[str] = Field(default=None, description="动作完成后的调度策略") + attempts: Optional[int] = Field(default=1, description="动作实际尝试次数") class ActionFlow(BaseModel): @@ -100,6 +113,7 @@ class ActionFlow(BaseModel): data: Optional[dict] = Field(default_factory=dict, description="流程扩展配置") condition: Optional[str] = Field(default=None, description="流转条件表达式") join_policy: Optional[str] = Field(default=None, description="目标节点汇合策略") + branch_policy: Optional[str] = Field(default=None, description="源节点分支策略") class WorkflowShare(BaseModel): diff --git a/app/workflow/__init__.py b/app/workflow/__init__.py index 4f88b569..eeb51ede 100644 --- a/app/workflow/__init__.py +++ b/app/workflow/__init__.py @@ -1,7 +1,6 @@ import threading -from time import sleep -from typing import Dict, Any, Optional -from typing import List, Tuple +from time import monotonic, sleep +from typing import Any, Dict, List, Optional, Tuple from app.core.config import global_vars from app.core.event import eventmanager, Event @@ -68,48 +67,54 @@ class WorkFlowManager(metaclass=Singleton): self._actions = {} self._event_workflows = {} - def execute(self, workflow_id: int, action: Action, - context: ActionContext = None) -> ActionResult: + def execute(self, workflow_id: int, action: Action, context: ActionContext = None, + inputs: Optional[dict] = None, runtime: Optional[dict] = None, + cancel_token: Optional[Any] = None) -> ActionResult: """ 执行工作流动作 """ if not context: context = ActionContext() - if action.type in self._actions: - # 实例化之前,清理掉类对象的数据 - - # 实例化 - action_obj = self._actions[action.type](action.id) - # 执行 - logger.info(f"执行动作: {action.id} - {action.name}") - try: - result_context = action_obj.execute(workflow_id, action.data, context) - action_result = self._normalize_action_result(result_context, action_obj, context) - except Exception as err: - logger.error(f"{action.name} 执行失败: {err}") - return ActionResult(success=False, message=f"{err}", context=context) - loop = (action.data or {}).get("loop") - loop_interval = (action.data or {}).get("loop_interval") - if loop and loop_interval: - while not action_obj.done: - if global_vars.is_workflow_stopped(workflow_id): - break - # 等待 - logger.info(f"{action.name} 等待 {loop_interval} 秒后继续执行 ...") - sleep(loop_interval) - # 执行 - logger.info(f"继续执行动作: {action.id} - {action.name}") - result_context = action_obj.execute(workflow_id, action.data, action_result.context) - action_result = self._normalize_action_result(result_context, action_obj, action_result.context) - if action_result.success: - logger.info(f"{action.name} 执行成功") - else: - logger.error(f"{action.name} 执行失败!") - return action_result - else: + if action.type not in self._actions: logger.error(f"未找到动作: {action.type} - {action.name}") return ActionResult(success=False, message=" ", context=context) + retry_config = self._get_retry_config(action) + max_attempts = retry_config["max_attempts"] + interval = retry_config["interval"] + backoff = retry_config["backoff"] + action_result = ActionResult(success=False, message="", context=context) + + for attempt in range(1, max_attempts + 1): + if self._is_cancelled(workflow_id, cancel_token): + return ActionResult(success=False, message="工作流已取消", context=context) + runtime_data = { + **(runtime or {}), + "attempt": attempt, + "max_attempts": max_attempts, + "cancel_token": cancel_token, + } + action_result = self._execute_action_once( + workflow_id=workflow_id, + action=action, + context=context, + inputs=inputs or {}, + runtime=runtime_data, + cancel_token=cancel_token + ) + action_result.attempts = attempt + context = action_result.context or context + if action_result.success: + logger.info(f"{action.name} 执行成功") + return action_result + if attempt < max_attempts and not self._is_cancelled(workflow_id, cancel_token): + wait_seconds = interval * (backoff ** (attempt - 1)) + logger.info(f"{action.name} 执行失败,{wait_seconds} 秒后重试({attempt}/{max_attempts})...") + self._sleep_with_cancel(workflow_id, wait_seconds, cancel_token) + + logger.error(f"{action.name} 执行失败!") + return action_result + def excute(self, workflow_id: int, action: Action, context: ActionContext = None) -> Tuple[bool, str, ActionContext]: """ @@ -134,6 +139,100 @@ class WorkFlowManager(metaclass=Singleton): context=result or fallback_context ) + def _execute_action_once(self, workflow_id: int, action: Action, context: ActionContext, + inputs: dict, runtime: dict, cancel_token: Optional[Any]) -> ActionResult: + action_obj = self._actions[action.type](action.id) + logger.info(f"执行动作: {action.id} - {action.name}") + try: + action_result = self._run_action_with_loop( + workflow_id=workflow_id, + action=action, + action_obj=action_obj, + context=context, + inputs=inputs, + runtime=runtime, + cancel_token=cancel_token + ) + except Exception as err: + logger.error(f"{action.name} 执行失败: {err}") + return ActionResult(success=False, message=f"{err}", context=context) + return action_result + + def _run_action_with_loop(self, workflow_id: int, action: Action, action_obj: Any, + context: ActionContext, inputs: dict, runtime: dict, + cancel_token: Optional[Any]) -> ActionResult: + timeout = self._get_action_timeout(action) + started_at = monotonic() + action_result = self._call_action( + workflow_id=workflow_id, + action=action, + action_obj=action_obj, + context=context, + inputs=inputs, + runtime=runtime + ) + loop = self._get_action_data_value(action, "loop") + loop_interval = self._get_action_data_value(action, "loop_interval") + while loop and loop_interval and not action_obj.done: + if self._is_cancelled(workflow_id, cancel_token): + return ActionResult(success=False, message="工作流已取消", context=action_result.context or context) + if timeout and monotonic() - started_at >= timeout: + return ActionResult(success=False, message=f"动作执行超时({timeout}秒)", context=action_result.context or context) + logger.info(f"{action.name} 等待 {loop_interval} 秒后继续执行 ...") + self._sleep_with_cancel(workflow_id, loop_interval, cancel_token) + if self._is_cancelled(workflow_id, cancel_token): + return ActionResult(success=False, message="工作流已取消", context=action_result.context or context) + logger.info(f"继续执行动作: {action.id} - {action.name}") + action_result = self._call_action( + workflow_id=workflow_id, + action=action, + action_obj=action_obj, + context=action_result.context or context, + inputs=inputs, + runtime=runtime + ) + return action_result + + def _call_action(self, workflow_id: int, action: Action, action_obj: Any, + context: ActionContext, inputs: dict, runtime: dict) -> ActionResult: + if hasattr(action_obj, "execute_with_inputs"): + result = action_obj.execute_with_inputs(workflow_id, action.data, inputs, runtime, context) + else: + result = action_obj.execute(workflow_id, action.data, context) + return self._normalize_action_result(result, action_obj, context) + + @staticmethod + def _get_action_data_value(action: Action, key: str) -> Any: + data = action.data or {} + return data.get(key) if isinstance(data, dict) else None + + def _get_action_timeout(self, action: Action) -> Optional[int]: + timeout = action.timeout or self._get_action_data_value(action, "timeout") + return int(timeout) if timeout else None + + def _get_retry_config(self, action: Action) -> dict: + retry_config = action.retry or self._get_action_data_value(action, "retry") or {} + if not isinstance(retry_config, dict): + retry_config = {} + return { + "max_attempts": max(int(retry_config.get("max_attempts") or 1), 1), + "interval": max(float(retry_config.get("interval") or 0), 0), + "backoff": max(float(retry_config.get("backoff") or 1), 1), + } + + @staticmethod + def _is_cancelled(workflow_id: int, cancel_token: Optional[Any]) -> bool: + if cancel_token and cancel_token.is_cancelled(): + return True + return global_vars.is_workflow_stopped(workflow_id) + + def _sleep_with_cancel(self, workflow_id: int, seconds: float, cancel_token: Optional[Any]) -> None: + deadline = monotonic() + seconds + while monotonic() < deadline: + if self._is_cancelled(workflow_id, cancel_token): + return + sleep(min(0.1, deadline - monotonic())) + def list_actions(self) -> List[dict]: """ 获取所有动作 diff --git a/app/workflow/actions/__init__.py b/app/workflow/actions/__init__.py index 895a784a..2f5a481c 100644 --- a/app/workflow/actions/__init__.py +++ b/app/workflow/actions/__init__.py @@ -3,7 +3,7 @@ from typing import Union from app.chain import ChainBase from app.db.systemconfig_oper import SystemConfigOper -from app.schemas import ActionContext, ActionParams +from app.schemas import ActionContext, ActionParams, ActionResult class ActionChain(ChainBase): @@ -109,3 +109,16 @@ class BaseAction(ABC): 执行动作 """ raise NotImplementedError + + def execute_with_inputs(self, workflow_id: int, params: ActionParams, inputs: dict, + runtime: dict, context: ActionContext) -> ActionResult: + """ + 使用显式输入与运行期信息执行动作。 + """ + _ = inputs, runtime + result_context = self.execute(workflow_id, params, context) + return ActionResult( + success=self.success, + message=self.message, + context=result_context + ) diff --git a/database/versions/7c1a2b3d4e5f_2_2_9.py b/database/versions/7c1a2b3d4e5f_2_2_9.py new file mode 100644 index 00000000..978ed0ed --- /dev/null +++ b/database/versions/7c1a2b3d4e5f_2_2_9.py @@ -0,0 +1,45 @@ +"""2.2.9 +为工作流增加执行配置和结构化执行状态 + +Revision ID: 7c1a2b3d4e5f +Revises: d5e6f7a8b9c0 +Create Date: 2026-06-04 +""" + +from alembic import op +import sqlalchemy as sa + +revision = "7c1a2b3d4e5f" +down_revision = "d5e6f7a8b9c0" +branch_labels = None +depends_on = None + + +def _has_column(inspector: sa.Inspector, table_name: str, column_name: str) -> bool: + """检查数据表是否已存在指定列。""" + if table_name not in inspector.get_table_names(): + return False + return any(column["name"] == column_name for column in inspector.get_columns(table_name)) + + +def _add_json_column_if_missing(table_name: str, column_name: str) -> None: + """缺失时为数据表新增 JSON 列。""" + inspector = sa.inspect(op.get_bind()) + if not _has_column(inspector, table_name, column_name): + op.add_column(table_name, sa.Column(column_name, sa.JSON(), nullable=True)) + + +def upgrade() -> None: + """升级数据库结构。""" + _add_json_column_if_missing("workflow", "execution_config") + _add_json_column_if_missing("workflow", "execution_state") + + +def downgrade() -> None: + """回滚数据库结构。""" + inspector = sa.inspect(op.get_bind()) + if _has_column(inspector, "workflow", "execution_state"): + op.drop_column("workflow", "execution_state") + inspector = sa.inspect(op.get_bind()) + if _has_column(inspector, "workflow", "execution_config"): + op.drop_column("workflow", "execution_config") diff --git a/tests/test_workflow_execution.py b/tests/test_workflow_execution.py index e4a409db..b93288a4 100644 --- a/tests/test_workflow_execution.py +++ b/tests/test_workflow_execution.py @@ -1,28 +1,32 @@ import base64 import pickle import threading +import time from types import SimpleNamespace from app.chain import workflow as workflow_module -from app.schemas import ActionContext, ActionResult +from app.schemas import Action, ActionContext, ActionResult from app.schemas.types import EventType from app import workflow as workflow_package -def _build_workflow(current_action=None, context=None, actions=None, flows=None): +def _build_workflow(current_action=None, context=None, actions=None, flows=None, + execution_config=None, execution_state=None): """构造最小工作流对象。""" return SimpleNamespace( id=1, name="测试工作流", - actions=actions or [ + actions=actions if actions is not None else [ {"id": "A", "type": "FakeAction", "name": "动作A", "data": {}}, {"id": "B", "type": "FakeAction", "name": "动作B", "data": {}}, ], - flows=flows or [ + flows=flows if flows is not None else [ {"id": "flow-1", "source": "A", "target": "B", "animated": True}, ], current_action=current_action, context=context, + execution_config=execution_config or {}, + execution_state=execution_state or {}, ) @@ -39,9 +43,12 @@ class _FakeWorkflowManager: def __init__(self, calls, results=None): self.calls = calls self.results = results or {} + self.received_inputs = [] - def execute(self, workflow_id, action, context=None): + def execute(self, workflow_id, action, context=None, inputs=None, runtime=None, cancel_token=None): + """执行伪动作并记录新版输入。""" self.calls.append(action.id) + self.received_inputs.append((action.id, inputs or {}, runtime or {}, cancel_token)) result = self.results.get(action.id) if callable(result): return result(action, context or ActionContext()) @@ -262,6 +269,220 @@ def test_workflow_executor_all_done_join_can_continue_after_failure(monkeypatch) assert executor.success is True +def test_workflow_executor_exclusive_branch_uses_first_matching_flow(monkeypatch): + """互斥分支应只执行第一条满足条件的出边。""" + calls = [] + fake_manager = _FakeWorkflowManager( + calls, + results={ + "A": lambda action, context: ActionResult( + success=True, + message=f"{action.name}完成", + context=context, + outputs={"count": 2} + ) + } + ) + workflow = _build_workflow( + actions=[ + {"id": "A", "type": "FakeAction", "name": "动作A", "data": {"branch_policy": "exclusive"}}, + {"id": "B", "type": "FakeAction", "name": "动作B", "data": {}}, + {"id": "C", "type": "FakeAction", "name": "动作C", "data": {}}, + ], + flows=[ + {"id": "flow-ab", "source": "A", "target": "B", "condition": "outputs.A.count > 0"}, + {"id": "flow-ac", "source": "A", "target": "C", "condition": "outputs.A.count > 1"}, + ], + ) + + monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager) + monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None) + monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False) + + executor = workflow_module.WorkflowExecutor(workflow) + executor.execute() + + assert calls == ["A", "B"] + assert executor.node_states["C"] == "skipped" + + +def test_workflow_executor_passes_declared_inputs(monkeypatch): + """动作输入声明应从 node_outputs 中读取指定路径。""" + calls = [] + fake_manager = _FakeWorkflowManager( + calls, + results={ + "A": lambda action, context: ActionResult( + success=True, + message=f"{action.name}完成", + context=context, + outputs={"torrents": ["a", "b"]} + ) + } + ) + workflow = _build_workflow( + actions=[ + {"id": "A", "type": "FakeAction", "name": "动作A", "data": {}}, + { + "id": "B", + "type": "FakeAction", + "name": "动作B", + "data": {"inputs": ["A.torrents", "outputs.A.torrents.count"]}, + }, + ], + ) + + monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager) + monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None) + monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False) + + executor = workflow_module.WorkflowExecutor(workflow) + executor.execute() + + b_inputs = [item for action_id, item, _, _ in fake_manager.received_inputs if action_id == "B"][0] + assert b_inputs == { + "A.torrents": ["a", "b"], + "outputs.A.torrents.count": 2, + } + + +def test_workflow_executor_persists_structured_state(monkeypatch): + """步骤回调应收到可持久化的结构化执行状态。""" + calls = [] + states = [] + fake_manager = _FakeWorkflowManager( + calls, + results={ + "A": lambda action, context: ActionResult( + success=True, + message=f"{action.name}完成", + context=context, + outputs={"items": ["movie"]} + ) + } + ) + + monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager) + monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None) + monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False) + + executor = workflow_module.WorkflowExecutor( + _build_workflow(actions=[{"id": "A", "type": "FakeAction", "name": "动作A", "data": {}}], flows=[]), + step_callback=lambda action, context, execution_state, completed: states.append(execution_state), + ) + executor.execute() + + assert states[-1]["nodes"]["A"]["state"] == "success" + assert states[-1]["outputs"]["A"]["items"] == ["movie"] + assert states[-1]["runtime"]["progress"] == 100 + + +def test_workflow_executor_restores_outputs_from_execution_state(monkeypatch): + """恢复执行时应从结构化状态读取节点输出并继续判断条件边。""" + calls = [] + fake_manager = _FakeWorkflowManager(calls) + workflow = _build_workflow( + execution_state={ + "nodes": { + "A": {"state": "success", "attempt": 1}, + }, + "outputs": { + "A": {"torrents": ["movie"]}, + }, + }, + flows=[ + {"id": "flow-ab", "source": "A", "target": "B", "condition": "A.torrents.count > 0"}, + ], + ) + + monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager) + monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None) + monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False) + + executor = workflow_module.WorkflowExecutor(workflow) + executor.execute() + + assert calls == ["B"] + assert executor.context.node_outputs["A"]["torrents"] == ["movie"] + + +def test_workflow_executor_concurrency_key_serializes_parallel_nodes(monkeypatch): + """相同 concurrency_key 的并行节点不应同时运行。""" + calls = [] + active_count = 0 + max_active_count = 0 + lock = threading.Lock() + + def run_action(action, context): + """记录同一并发键下的同时运行数量。""" + nonlocal active_count, max_active_count + with lock: + active_count += 1 + max_active_count = max(max_active_count, active_count) + time.sleep(0.05) + with lock: + active_count -= 1 + return ActionResult(success=True, message=f"{action.name}完成", context=context) + + fake_manager = _FakeWorkflowManager(calls, results={"A": run_action, "B": run_action}) + workflow = _build_workflow( + actions=[ + {"id": "A", "type": "FakeAction", "name": "动作A", "data": {"concurrency_key": "download"}}, + {"id": "B", "type": "FakeAction", "name": "动作B", "data": {"concurrency_key": "download"}}, + ], + flows=[], + execution_config={"max_workers": 2}, + ) + + monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager) + monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None) + monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False) + + executor = workflow_module.WorkflowExecutor(workflow) + executor.execute() + + assert set(calls) == {"A", "B"} + assert max_active_count == 1 + + +def test_workflow_executor_filter_action_replaces_artifact_outputs(monkeypatch): + """过滤类动作默认应替换列表输出,避免把过滤前数据重新合并回来。""" + calls = [] + fake_manager = _FakeWorkflowManager( + calls, + results={ + "A": lambda action, context: ActionResult( + success=True, + message=f"{action.name}完成", + context=context, + outputs={"torrents": ["old", "keep"]} + ), + "B": lambda action, context: ActionResult( + success=True, + message=f"{action.name}完成", + context=context, + outputs={"torrents": ["keep"]} + ), + } + ) + workflow = _build_workflow( + actions=[ + {"id": "A", "type": "FakeAction", "name": "动作A", "data": {}}, + {"id": "B", "type": "FilterTorrentsAction", "name": "过滤资源", "data": {}}, + ], + ) + + monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager) + monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None) + monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False) + + executor = workflow_module.WorkflowExecutor(workflow) + executor.execute() + + assert executor.context.torrents == ["keep"] + assert executor.context.artifacts["torrents"] == ["keep"] + + def test_workflow_executor_stop_is_not_success(monkeypatch): """停止信号不应被执行器汇报为成功完成。""" calls = [] @@ -328,3 +549,43 @@ def test_workflow_event_listener_keeps_shared_handler_until_last_workflow(monkey assert fake_eventmanager.removed == [EventType.DownloadAdded] assert manager.get_event_workflows() == {} + + +def test_workflow_manager_retries_action_until_success(monkeypatch): + """动作管理器应按 retry 配置重试失败动作。""" + + class RetryAction: + """模拟第二次才成功的动作。""" + + call_count = 0 + + def __init__(self, action_id): + self.action_id = action_id + + def execute_with_inputs(self, workflow_id, params, inputs, runtime, context): + """执行动作并在第二次返回成功。""" + _ = workflow_id, params, inputs, runtime + RetryAction.call_count += 1 + if RetryAction.call_count == 1: + return ActionResult(success=False, message="第一次失败", context=context) + return ActionResult(success=True, message="第二次成功", context=context, outputs={"ok": True}) + + manager = object.__new__(workflow_package.WorkFlowManager) + manager._actions = {"RetryAction": RetryAction} + monkeypatch.setattr(workflow_package.global_vars, "is_workflow_stopped", lambda workflow_id: False) + + result = manager.execute( + workflow_id=1, + action=Action( + id="retry", + type="RetryAction", + name="重试动作", + data={"retry": {"max_attempts": 2, "interval": 0}}, + ), + context=ActionContext(), + ) + + assert result.success is True + assert result.attempts == 2 + assert result.outputs == {"ok": True} + assert RetryAction.call_count == 2