mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-05 23:16:47 +00:00
feat(workflow): add execution configuration and structured execution state to workflow
This commit is contained in:
@@ -1,13 +1,17 @@
|
||||
import ast
|
||||
import base64
|
||||
import copy
|
||||
import inspect
|
||||
import pickle
|
||||
import threading
|
||||
from collections import defaultdict, deque
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from time import sleep
|
||||
from typing import Any, Callable, List, Optional, Tuple
|
||||
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import global_vars
|
||||
from app.core.event import Event, eventmanager
|
||||
@@ -19,6 +23,29 @@ from app.schemas.types import EventType
|
||||
from app.workflow import WorkFlowManager
|
||||
|
||||
|
||||
ARTIFACT_FIELDS = {"torrents", "medias", "fileitems", "downloads", "sites", "subscribes"}
|
||||
DEFAULT_WORKFLOW_MAX_WORKERS = 4
|
||||
|
||||
|
||||
class WorkflowCancelToken:
|
||||
"""
|
||||
工作流取消令牌。
|
||||
"""
|
||||
|
||||
def __init__(self, workflow_id: int):
|
||||
"""
|
||||
初始化取消令牌。
|
||||
:param workflow_id: 工作流ID
|
||||
"""
|
||||
self.workflow_id = workflow_id
|
||||
|
||||
def is_cancelled(self) -> bool:
|
||||
"""
|
||||
判断工作流是否已被取消。
|
||||
"""
|
||||
return global_vars.is_workflow_stopped(self.workflow_id)
|
||||
|
||||
|
||||
class WorkflowExecutor:
|
||||
"""
|
||||
工作流执行器
|
||||
@@ -35,32 +62,40 @@ class WorkflowExecutor:
|
||||
self.step_callback = step_callback
|
||||
self.actions = {action['id']: Action(**action) for action in workflow.actions}
|
||||
self.flows = [ActionFlow(**flow) for flow in workflow.flows]
|
||||
self.execution_config = getattr(workflow, "execution_config", None) or {}
|
||||
self.restored_execution_state = getattr(workflow, "execution_state", None) or {}
|
||||
self.total_actions = len(self.actions)
|
||||
self.completed_actions = {
|
||||
action_id for action_id in (workflow.current_action or "").split(",")
|
||||
if action_id in self.actions
|
||||
}
|
||||
self.finished_actions = len(self.completed_actions)
|
||||
|
||||
self.success = True
|
||||
self.has_failure = False
|
||||
self.stopped = False
|
||||
self.errmsg = ""
|
||||
self.node_states = {action_id: "pending" for action_id in self.actions}
|
||||
for action_id in self.completed_actions:
|
||||
self.node_states[action_id] = "completed"
|
||||
self.errors = self.get_restored_errors()
|
||||
self.node_metadata = self.get_restored_node_metadata()
|
||||
self.node_attempts = self.get_restored_attempts()
|
||||
self.node_states = self.get_restored_node_states()
|
||||
self.completed_actions = {
|
||||
action_id for action_id, state in self.node_states.items()
|
||||
if state == "success"
|
||||
}
|
||||
self.finished_actions = len([
|
||||
state for state in self.node_states.values()
|
||||
if state in ("success", "failed", "skipped")
|
||||
])
|
||||
self.flow_finished = set()
|
||||
self.flow_satisfied = set()
|
||||
self.flow_failed = set()
|
||||
|
||||
# 工作流管理器
|
||||
self.workflowmanager = WorkFlowManager()
|
||||
# 线程安全队列
|
||||
self.queue = deque()
|
||||
self.queued_actions = set()
|
||||
self.active_concurrency_keys = set()
|
||||
# 锁用于保证线程安全
|
||||
self.lock = threading.Lock()
|
||||
# 线程池
|
||||
self.executor = ThreadPoolExecutor()
|
||||
self.executor = ThreadPoolExecutor(max_workers=self.get_workflow_max_workers())
|
||||
self.cancel_token = WorkflowCancelToken(self.workflow.id)
|
||||
# 跟踪运行中的任务数
|
||||
self.running_tasks = 0
|
||||
|
||||
@@ -74,26 +109,162 @@ class WorkflowExecutor:
|
||||
self.incoming_flows[flow.target].append(flow)
|
||||
|
||||
# 初始上下文
|
||||
if workflow.current_action and workflow.context:
|
||||
logger.info(f"工作流已执行动作:{workflow.current_action}")
|
||||
# Base64解码
|
||||
decoded_data = base64.b64decode(workflow.context["content"])
|
||||
# 反序列化数据
|
||||
self.context = pickle.loads(decoded_data)
|
||||
else:
|
||||
self.context = ActionContext()
|
||||
self.context.node_outputs = self.context.node_outputs or {}
|
||||
self.context = self.restore_context()
|
||||
self.ensure_context_partitions()
|
||||
|
||||
# 恢复工作流
|
||||
global_vars.workflow_resume(self.workflow.id)
|
||||
# 恢复时重新释放已完成节点的出边,使后继节点能继续执行。
|
||||
for action_id in self.completed_actions:
|
||||
self.release_successors(action_id, source_success=True)
|
||||
# 恢复时重新释放已终态节点的出边,使后继节点能继续执行或保持跳过传播。
|
||||
for action_id, state in self.node_states.items():
|
||||
if state == "success":
|
||||
self.release_successors(action_id, source_success=True)
|
||||
elif state in ("failed", "skipped"):
|
||||
self.release_successors(action_id, source_success=False)
|
||||
# 初始化队列,添加没有入边的起始节点。
|
||||
for action_id in self.actions:
|
||||
if action_id not in self.completed_actions and not self.incoming_flows.get(action_id):
|
||||
if self.node_states.get(action_id) == "pending" and not self.incoming_flows.get(action_id):
|
||||
self.enqueue_node(action_id)
|
||||
|
||||
def get_workflow_max_workers(self) -> int:
|
||||
"""
|
||||
获取工作流最大并发数。
|
||||
"""
|
||||
max_workers = self.execution_config.get("max_workers") if isinstance(self.execution_config, dict) else None
|
||||
try:
|
||||
return max(int(max_workers or DEFAULT_WORKFLOW_MAX_WORKERS), 1)
|
||||
except (TypeError, ValueError):
|
||||
return DEFAULT_WORKFLOW_MAX_WORKERS
|
||||
|
||||
def get_restored_node_metadata(self) -> dict:
|
||||
"""
|
||||
获取已持久化的节点状态元数据。
|
||||
"""
|
||||
nodes = self.restored_execution_state.get("nodes") if isinstance(self.restored_execution_state, dict) else {}
|
||||
return nodes if isinstance(nodes, dict) else {}
|
||||
|
||||
def get_restored_errors(self) -> dict:
|
||||
"""
|
||||
获取已持久化的错误状态。
|
||||
"""
|
||||
errors = self.restored_execution_state.get("errors") if isinstance(self.restored_execution_state, dict) else {}
|
||||
return errors if isinstance(errors, dict) else {}
|
||||
|
||||
def get_restored_attempts(self) -> dict:
|
||||
"""
|
||||
获取已持久化的节点尝试次数。
|
||||
"""
|
||||
attempts = {}
|
||||
for action_id, metadata in self.get_restored_node_metadata().items():
|
||||
if isinstance(metadata, dict) and metadata.get("attempt"):
|
||||
attempts[action_id] = int(metadata.get("attempt") or 0)
|
||||
return attempts
|
||||
|
||||
def get_restored_node_states(self) -> dict:
|
||||
"""
|
||||
获取结构化节点状态,兼容旧版 current_action 字符串。
|
||||
"""
|
||||
legacy_actions = {
|
||||
action_id for action_id in (self.workflow.current_action or "").split(",")
|
||||
if action_id in self.actions
|
||||
}
|
||||
states = {}
|
||||
for action_id in self.actions:
|
||||
metadata = self.node_metadata.get(action_id) or {}
|
||||
state = metadata.get("state") if isinstance(metadata, dict) else None
|
||||
if state == "completed":
|
||||
state = "success"
|
||||
if state in ("running", "queued"):
|
||||
state = "pending"
|
||||
if not state and action_id in legacy_actions:
|
||||
state = "success"
|
||||
states[action_id] = state or "pending"
|
||||
return states
|
||||
|
||||
def restore_context(self) -> ActionContext:
|
||||
"""
|
||||
恢复工作流上下文,兼容旧版 Base64 Pickle 存储格式。
|
||||
"""
|
||||
context = ActionContext()
|
||||
if self.workflow.current_action and self.workflow.context:
|
||||
logger.info(f"工作流已执行动作:{self.workflow.current_action}")
|
||||
try:
|
||||
decoded_data = base64.b64decode(self.workflow.context["content"])
|
||||
context = pickle.loads(decoded_data)
|
||||
except Exception as err:
|
||||
logger.error(f"工作流上下文恢复失败: {str(err)}")
|
||||
context = ActionContext()
|
||||
outputs = self.restored_execution_state.get("outputs") if isinstance(self.restored_execution_state, dict) else {}
|
||||
if outputs and not getattr(context, "node_outputs", None):
|
||||
context.node_outputs = outputs
|
||||
return context
|
||||
|
||||
def ensure_context_partitions(self) -> None:
|
||||
"""
|
||||
确保上下文具备新版分区结构,并把旧字段映射到 artifacts。
|
||||
"""
|
||||
self.context.workflow_context = self.context.workflow_context or {}
|
||||
self.context.node_outputs = self.context.node_outputs or {}
|
||||
self.context.runtime_state = self.context.runtime_state or {}
|
||||
self.context.artifacts = self.context.artifacts or {}
|
||||
for key in ARTIFACT_FIELDS:
|
||||
value = getattr(self.context, key, None)
|
||||
if value not in (None, "", [], {}) and key not in self.context.artifacts:
|
||||
self.context.artifacts[key] = value
|
||||
self.update_runtime_state()
|
||||
|
||||
def update_runtime_state(self) -> None:
|
||||
"""
|
||||
更新上下文中的运行期状态分区。
|
||||
"""
|
||||
self.context.runtime_state.update({
|
||||
"progress": self.context.progress,
|
||||
"finished_actions": self.finished_actions,
|
||||
"running_tasks": self.running_tasks,
|
||||
"errors": self.errors,
|
||||
"node_states": self.node_states,
|
||||
"attempts": self.node_attempts,
|
||||
})
|
||||
|
||||
def set_node_state(self, action_id: str, state: str, message: Optional[str] = None) -> None:
|
||||
"""
|
||||
更新节点结构化状态。
|
||||
"""
|
||||
now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||
metadata = self.node_metadata.setdefault(action_id, {})
|
||||
metadata["state"] = state
|
||||
metadata["attempt"] = self.node_attempts.get(action_id, metadata.get("attempt") or 0)
|
||||
if state == "running":
|
||||
metadata["started_at"] = now
|
||||
if state in ("success", "failed", "skipped"):
|
||||
metadata["finished_at"] = now
|
||||
if message is not None:
|
||||
metadata["message"] = message
|
||||
self.node_states[action_id] = state
|
||||
self.update_runtime_state()
|
||||
|
||||
def build_execution_state(self) -> dict:
|
||||
"""
|
||||
构建可持久化的结构化执行状态。
|
||||
"""
|
||||
self.update_runtime_state()
|
||||
return self.make_json_safe({
|
||||
"version": 1,
|
||||
"nodes": self.node_metadata,
|
||||
"outputs": self.context.node_outputs,
|
||||
"errors": self.errors,
|
||||
"runtime": self.context.runtime_state,
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
def make_json_safe(value: Any) -> Any:
|
||||
"""
|
||||
将运行期对象转换为可写入 JSON 列的数据。
|
||||
"""
|
||||
try:
|
||||
return jsonable_encoder(value)
|
||||
except Exception:
|
||||
return str(value)
|
||||
|
||||
def execute(self) -> None:
|
||||
"""
|
||||
执行工作流
|
||||
@@ -121,14 +292,9 @@ class WorkflowExecutor:
|
||||
elif not self.queue:
|
||||
should_sleep = True
|
||||
else:
|
||||
# 取出队首节点
|
||||
node_id = self.queue.popleft()
|
||||
self.queued_actions.discard(node_id)
|
||||
if self.node_states.get(node_id) != "queued":
|
||||
continue
|
||||
self.node_states[node_id] = "running"
|
||||
# 标记任务开始
|
||||
self.running_tasks += 1
|
||||
node_id = self.pop_dispatchable_node()
|
||||
if not node_id:
|
||||
should_sleep = True
|
||||
|
||||
if should_sleep:
|
||||
sleep(0.1)
|
||||
@@ -148,19 +314,48 @@ class WorkflowExecutor:
|
||||
finally:
|
||||
self.executor.shutdown(wait=True, cancel_futures=True)
|
||||
|
||||
def pop_dispatchable_node(self) -> Optional[str]:
|
||||
"""
|
||||
从队列中取出当前可调度节点。
|
||||
"""
|
||||
for _ in range(len(self.queue)):
|
||||
node_id = self.queue.popleft()
|
||||
self.queued_actions.discard(node_id)
|
||||
if self.node_states.get(node_id) != "queued":
|
||||
continue
|
||||
concurrency_key = self.get_action_concurrency_key(self.actions[node_id])
|
||||
if concurrency_key and concurrency_key in self.active_concurrency_keys:
|
||||
self.queue.append(node_id)
|
||||
self.queued_actions.add(node_id)
|
||||
continue
|
||||
if concurrency_key:
|
||||
self.active_concurrency_keys.add(concurrency_key)
|
||||
self.running_tasks += 1
|
||||
self.set_node_state(node_id, "running")
|
||||
return node_id
|
||||
return None
|
||||
|
||||
def execute_node(self, workflow_id: int, node_id: str,
|
||||
context: ActionContext) -> Tuple[Action, ActionResult]:
|
||||
"""
|
||||
执行单个节点操作,返回修改后的上下文和节点ID
|
||||
"""
|
||||
action = self.actions[node_id]
|
||||
action_result = self.workflowmanager.execute(workflow_id, action, context=context)
|
||||
action_result = self.workflowmanager.execute(
|
||||
workflow_id,
|
||||
action,
|
||||
context=context,
|
||||
inputs=self.build_action_inputs(action),
|
||||
runtime=self.build_action_runtime(action),
|
||||
cancel_token=self.cancel_token
|
||||
)
|
||||
return action, action_result
|
||||
|
||||
def on_node_complete(self, future):
|
||||
"""
|
||||
节点完成回调:更新上下文、处理后继节点
|
||||
"""
|
||||
action = None
|
||||
try:
|
||||
action, action_result = future.result()
|
||||
with self.lock:
|
||||
@@ -172,6 +367,7 @@ class WorkflowExecutor:
|
||||
state = bool(action_result.success)
|
||||
message = action_result.message or ""
|
||||
result_ctx = action_result.context or ActionContext()
|
||||
self.node_attempts[action.id] = action_result.attempts or self.node_attempts.get(action.id, 1)
|
||||
|
||||
self.finished_actions += 1
|
||||
self.update_progress()
|
||||
@@ -186,31 +382,36 @@ class WorkflowExecutor:
|
||||
|
||||
# 节点执行失败时默认停止;显式配置 continue/ignore 时继续释放后续 all_done 汇合。
|
||||
if not state:
|
||||
self.node_states[action.id] = "failed"
|
||||
self.errors[action.id] = message or f"{action.name} 失败"
|
||||
self.set_node_state(action.id, "failed", message=message)
|
||||
fail_policy = self.get_action_fail_policy(action)
|
||||
if fail_policy != "ignore":
|
||||
self.has_failure = True
|
||||
self.errmsg = f"{action.name} 失败"
|
||||
if fail_policy == "stop":
|
||||
self.success = False
|
||||
self.call_step_callback(action, completed=False)
|
||||
return
|
||||
if fail_policy not in ("continue", "ignore"):
|
||||
self.success = False
|
||||
self.errmsg = f"{action.name} 失败:无效失败策略 {fail_policy}"
|
||||
self.call_step_callback(action, completed=False)
|
||||
return
|
||||
self.release_successors(action.id, source_success=False)
|
||||
self.call_step_callback(action, completed=False)
|
||||
return
|
||||
|
||||
# 更新主上下文
|
||||
self.merge_context(result_ctx)
|
||||
self.record_node_outputs(action.id, action_result, result_ctx)
|
||||
self.ensure_result_context_partitions(result_ctx)
|
||||
outputs = self.normalize_action_outputs(action, action_result, result_ctx)
|
||||
self.merge_context_partitions(result_ctx)
|
||||
self.merge_action_outputs(action, outputs)
|
||||
self.record_node_outputs(action.id, outputs)
|
||||
self.completed_actions.add(action.id)
|
||||
self.node_states[action.id] = "completed"
|
||||
self.set_node_state(action.id, "success", message=message)
|
||||
# 处理后继节点
|
||||
self.release_successors(action.id, source_success=True)
|
||||
# 回调
|
||||
if self.step_callback:
|
||||
self.step_callback(action, self.context)
|
||||
self.call_step_callback(action, completed=True)
|
||||
except Exception as err:
|
||||
logger.error(f"工作流节点执行回调失败: {str(err)}")
|
||||
with self.lock:
|
||||
@@ -219,7 +420,12 @@ class WorkflowExecutor:
|
||||
finally:
|
||||
# 标记任务完成
|
||||
with self.lock:
|
||||
if action:
|
||||
concurrency_key = self.get_action_concurrency_key(action)
|
||||
if concurrency_key:
|
||||
self.active_concurrency_keys.discard(concurrency_key)
|
||||
self.running_tasks -= 1
|
||||
self.update_runtime_state()
|
||||
|
||||
def enqueue_node(self, node_id: str) -> None:
|
||||
"""
|
||||
@@ -231,7 +437,7 @@ class WorkflowExecutor:
|
||||
return
|
||||
self.queue.append(node_id)
|
||||
self.queued_actions.add(node_id)
|
||||
self.node_states[node_id] = "queued"
|
||||
self.set_node_state(node_id, "queued")
|
||||
|
||||
def skip_node(self, node_id: str, message: str) -> None:
|
||||
"""
|
||||
@@ -242,9 +448,9 @@ class WorkflowExecutor:
|
||||
if self.node_states.get(node_id) not in ("pending", "queued"):
|
||||
return
|
||||
self.queued_actions.discard(node_id)
|
||||
self.node_states[node_id] = "skipped"
|
||||
self.finished_actions += 1
|
||||
self.update_progress()
|
||||
self.set_node_state(node_id, "skipped", message=message)
|
||||
self.context.execute_history.append(
|
||||
ActionExecution(
|
||||
action=self.actions[node_id].name,
|
||||
@@ -252,13 +458,17 @@ class WorkflowExecutor:
|
||||
message=message
|
||||
)
|
||||
)
|
||||
self.call_step_callback(self.actions[node_id], completed=False)
|
||||
self.release_successors(node_id, source_success=False)
|
||||
|
||||
def release_successors(self, source_id: str, source_success: bool) -> None:
|
||||
"""
|
||||
根据源节点状态释放出边,并重新判断目标节点是否可运行。
|
||||
"""
|
||||
for flow in self.outgoing_flows.get(source_id, []):
|
||||
flows = self.outgoing_flows.get(source_id, [])
|
||||
branch_policy = self.get_action_branch_policy(self.actions.get(source_id), flows)
|
||||
matched_exclusive_flow = None
|
||||
for flow in flows:
|
||||
flow_key = self.get_flow_key(flow)
|
||||
if flow_key in self.flow_finished:
|
||||
continue
|
||||
@@ -270,9 +480,15 @@ class WorkflowExecutor:
|
||||
self.success = False
|
||||
self.errmsg = f"流程条件判断失败:{err}"
|
||||
return
|
||||
if branch_policy == "exclusive" and condition_matched and matched_exclusive_flow:
|
||||
condition_matched = False
|
||||
elif branch_policy == "exclusive" and condition_matched:
|
||||
matched_exclusive_flow = flow_key
|
||||
self.flow_finished.add(flow_key)
|
||||
if source_success and condition_matched:
|
||||
self.flow_satisfied.add(flow_key)
|
||||
if not source_success and self.node_states.get(source_id) == "failed":
|
||||
self.flow_failed.add(flow_key)
|
||||
self.evaluate_target_state(flow.target)
|
||||
|
||||
def evaluate_target_state(self, target_id: str) -> None:
|
||||
@@ -291,8 +507,18 @@ class WorkflowExecutor:
|
||||
total_count = len(incoming_flows)
|
||||
finished_count = sum(1 for flow in incoming_flows if self.get_flow_key(flow) in self.flow_finished)
|
||||
satisfied_count = sum(1 for flow in incoming_flows if self.get_flow_key(flow) in self.flow_satisfied)
|
||||
failed_count = sum(1 for flow in incoming_flows if self.get_flow_key(flow) in self.flow_failed)
|
||||
join_policy = self.get_action_join_policy(self.actions[target_id], incoming_flows)
|
||||
|
||||
if join_policy == "fail_fast":
|
||||
if failed_count > 0:
|
||||
self.skip_node(target_id, "上游失败触发 fail_fast,已取消后续节点")
|
||||
elif finished_count == total_count and satisfied_count == total_count:
|
||||
self.enqueue_node(target_id)
|
||||
elif finished_count == total_count:
|
||||
self.skip_node(target_id, "上游条件未全部满足,已跳过")
|
||||
return
|
||||
|
||||
if join_policy == "any_success":
|
||||
if satisfied_count > 0:
|
||||
self.enqueue_node(target_id)
|
||||
@@ -322,14 +548,192 @@ class WorkflowExecutor:
|
||||
根据已完成和已跳过节点数量更新整体进度。
|
||||
"""
|
||||
self.context.progress = round(self.finished_actions / self.total_actions * 100) if self.total_actions else 100
|
||||
self.update_runtime_state()
|
||||
|
||||
def record_node_outputs(self, action_id: str, action_result: ActionResult, result_context: ActionContext) -> None:
|
||||
def build_action_inputs(self, action: Action) -> dict:
|
||||
"""
|
||||
根据动作输入声明读取上游节点输出。
|
||||
"""
|
||||
inputs = {}
|
||||
input_paths = action.inputs or self.get_action_data_value(action, "inputs") or []
|
||||
if isinstance(input_paths, str):
|
||||
input_paths = [item.strip() for item in input_paths.splitlines() if item.strip()]
|
||||
for input_path in input_paths:
|
||||
inputs[input_path] = self.resolve_context_path(input_path)
|
||||
return inputs
|
||||
|
||||
def build_action_runtime(self, action: Action) -> dict:
|
||||
"""
|
||||
构建传递给动作的新运行期数据。
|
||||
"""
|
||||
return {
|
||||
"workflow_id": self.workflow.id,
|
||||
"action_id": action.id,
|
||||
"execution_config": self.execution_config,
|
||||
"runtime_state": self.context.runtime_state,
|
||||
}
|
||||
|
||||
def ensure_result_context_partitions(self, context: ActionContext) -> None:
|
||||
"""
|
||||
确保动作返回上下文具备新版分区字段。
|
||||
"""
|
||||
context.workflow_context = context.workflow_context or {}
|
||||
context.node_outputs = context.node_outputs or {}
|
||||
context.runtime_state = context.runtime_state or {}
|
||||
context.artifacts = context.artifacts or {}
|
||||
|
||||
def normalize_action_outputs(self, action: Action, action_result: ActionResult,
|
||||
result_context: ActionContext) -> dict:
|
||||
"""
|
||||
根据动作输出声明整理当前节点输出。
|
||||
"""
|
||||
outputs = action_result.outputs or self.extract_context_outputs(result_context)
|
||||
declared_outputs = action.outputs or self.get_action_data_value(action, "outputs")
|
||||
if isinstance(declared_outputs, list):
|
||||
return {key: outputs.get(key) for key in declared_outputs if outputs.get(key) not in (None, "", [], {})}
|
||||
if isinstance(declared_outputs, dict):
|
||||
return {
|
||||
key: outputs.get(key)
|
||||
for key in declared_outputs
|
||||
if outputs.get(key) not in (None, "", [], {})
|
||||
} or outputs
|
||||
return outputs
|
||||
|
||||
def record_node_outputs(self, action_id: str, outputs: dict) -> None:
|
||||
"""
|
||||
记录当前节点输出,供后续条件表达式读取。
|
||||
"""
|
||||
outputs = action_result.outputs or self.extract_context_outputs(result_context)
|
||||
if outputs:
|
||||
self.context.node_outputs[action_id] = outputs
|
||||
self.context.runtime_state["last_outputs"] = outputs
|
||||
|
||||
def merge_context_partitions(self, context: ActionContext) -> None:
|
||||
"""
|
||||
合并动作返回的新分区上下文。
|
||||
"""
|
||||
for key in ("workflow_context", "runtime_state", "artifacts"):
|
||||
value = getattr(context, key, None)
|
||||
if not value:
|
||||
continue
|
||||
current_value = getattr(self.context, key, None) or {}
|
||||
current_value.update(value)
|
||||
setattr(self.context, key, current_value)
|
||||
|
||||
def merge_action_outputs(self, action: Action, outputs: dict) -> None:
|
||||
"""
|
||||
按声明式合并策略写入全局上下文和 artifacts 分区。
|
||||
"""
|
||||
for key, value in outputs.items():
|
||||
if value in (None, "", [], {}):
|
||||
continue
|
||||
output_config = self.get_action_output_config(action, key)
|
||||
target_key = output_config.get("target") or key
|
||||
merge_policy = output_config.get("merge") or self.get_default_merge_policy(action, target_key, value)
|
||||
identity = output_config.get("identity")
|
||||
self.merge_output_value(target_key, value, merge_policy, identity)
|
||||
|
||||
def merge_output_value(self, key: str, value: Any, merge_policy: str, identity: Optional[str] = None) -> None:
|
||||
"""
|
||||
按指定策略合并单个输出值。
|
||||
"""
|
||||
current_value = getattr(self.context, key, None) if key in ActionContext.model_fields else None
|
||||
merged_value = self.apply_merge_policy(current_value, value, merge_policy, identity)
|
||||
if key in ActionContext.model_fields:
|
||||
setattr(self.context, key, merged_value)
|
||||
if key in ARTIFACT_FIELDS:
|
||||
current_artifact = self.context.artifacts.get(key)
|
||||
self.context.artifacts[key] = self.apply_merge_policy(current_artifact, value, merge_policy, identity)
|
||||
|
||||
def get_action_output_config(self, action: Action, output_key: str) -> dict:
|
||||
"""
|
||||
获取动作输出声明配置。
|
||||
"""
|
||||
outputs_config = action.outputs or self.get_action_data_value(action, "outputs") or {}
|
||||
if isinstance(outputs_config, dict):
|
||||
value = outputs_config.get(output_key) or {}
|
||||
return value if isinstance(value, dict) else {}
|
||||
return {}
|
||||
|
||||
def get_default_merge_policy(self, action: Action, key: str, value: Any) -> str:
|
||||
"""
|
||||
获取输出默认合并策略。
|
||||
"""
|
||||
if action.type in ("FilterTorrentsAction", "FilterMediasAction", "FetchDownloadsAction"):
|
||||
return "replace"
|
||||
if isinstance(value, list):
|
||||
return "append_unique"
|
||||
if isinstance(value, dict):
|
||||
return "merge_dict"
|
||||
return "first_non_empty"
|
||||
|
||||
def apply_merge_policy(self, current_value: Any, value: Any, merge_policy: str,
|
||||
identity: Optional[str] = None) -> Any:
|
||||
"""
|
||||
应用声明式合并策略。
|
||||
"""
|
||||
if merge_policy == "replace":
|
||||
return value
|
||||
if merge_policy == "merge_dict":
|
||||
merged = current_value.copy() if isinstance(current_value, dict) else {}
|
||||
if isinstance(value, dict):
|
||||
merged.update(value)
|
||||
return merged
|
||||
return current_value or value
|
||||
if merge_policy == "append_unique":
|
||||
return self.append_unique_values(current_value, value, identity)
|
||||
if merge_policy == "first_non_empty":
|
||||
return current_value or value
|
||||
return current_value or value
|
||||
|
||||
def append_unique_values(self, current_value: Any, value: Any, identity: Optional[str] = None) -> list:
|
||||
"""
|
||||
追加列表并按身份字段去重。
|
||||
"""
|
||||
current_list = list(current_value or [])
|
||||
incoming_list = value if isinstance(value, list) else [value]
|
||||
seen = {self.get_identity_value(item, identity) for item in current_list}
|
||||
for item in incoming_list:
|
||||
identity_value = self.get_identity_value(item, identity)
|
||||
if identity_value in seen:
|
||||
continue
|
||||
current_list.append(item)
|
||||
seen.add(identity_value)
|
||||
return current_list
|
||||
|
||||
def get_identity_value(self, item: Any, identity: Optional[str] = None) -> Any:
|
||||
"""
|
||||
获取列表元素去重身份。
|
||||
"""
|
||||
if not identity:
|
||||
identity_value = self.make_json_safe(item)
|
||||
return self.make_hashable_identity(identity_value)
|
||||
value = item
|
||||
for part in identity.split("."):
|
||||
value = self.read_value(value, int(part) if part.isdigit() else part)
|
||||
return self.make_hashable_identity(self.make_json_safe(value))
|
||||
|
||||
@staticmethod
|
||||
def make_hashable_identity(value: Any) -> Any:
|
||||
"""
|
||||
将身份值转换为可哈希对象。
|
||||
"""
|
||||
try:
|
||||
hash(value)
|
||||
return value
|
||||
except TypeError:
|
||||
return repr(value)
|
||||
|
||||
def call_step_callback(self, action: Action, completed: bool) -> None:
|
||||
"""
|
||||
持久化当前步骤上下文和结构化执行状态。
|
||||
"""
|
||||
if not self.step_callback:
|
||||
return
|
||||
callback_params = inspect.signature(self.step_callback).parameters
|
||||
if len(callback_params) <= 2:
|
||||
self.step_callback(action, self.context)
|
||||
return
|
||||
self.step_callback(action, self.context, self.build_execution_state(), completed)
|
||||
|
||||
@staticmethod
|
||||
def extract_context_outputs(context: ActionContext) -> dict:
|
||||
@@ -340,7 +744,7 @@ class WorkflowExecutor:
|
||||
return {}
|
||||
outputs = {}
|
||||
for key in context.__class__.model_fields:
|
||||
if key in ("execute_history", "progress", "node_outputs"):
|
||||
if key in ("execute_history", "progress", "node_outputs", "runtime_state"):
|
||||
continue
|
||||
value = getattr(context, key, None)
|
||||
if value in (None, "", [], {}):
|
||||
@@ -368,12 +772,32 @@ class WorkflowExecutor:
|
||||
return join_policy
|
||||
return "all_success"
|
||||
|
||||
def get_action_branch_policy(self, action: Optional[Action], outgoing_flows: List[ActionFlow]) -> str:
|
||||
"""
|
||||
获取动作出边分支策略。
|
||||
"""
|
||||
if action:
|
||||
branch_policy = action.branch_policy or self.get_action_data_value(action, "branch_policy")
|
||||
if branch_policy:
|
||||
return branch_policy
|
||||
for flow in outgoing_flows:
|
||||
branch_policy = flow.branch_policy or self.get_flow_data_value(flow, "branch_policy")
|
||||
if branch_policy:
|
||||
return branch_policy
|
||||
return "parallel"
|
||||
|
||||
def get_action_fail_policy(self, action: Action) -> str:
|
||||
"""
|
||||
获取动作失败策略。
|
||||
"""
|
||||
return action.fail_policy or self.get_action_data_value(action, "fail_policy") or "stop"
|
||||
|
||||
def get_action_concurrency_key(self, action: Action) -> Optional[str]:
|
||||
"""
|
||||
获取动作并发互斥键。
|
||||
"""
|
||||
return action.concurrency_key or self.get_action_data_value(action, "concurrency_key")
|
||||
|
||||
def get_flow_condition(self, flow: ActionFlow) -> Optional[str]:
|
||||
"""
|
||||
获取流程边条件表达式。
|
||||
@@ -381,10 +805,12 @@ class WorkflowExecutor:
|
||||
return flow.condition or self.get_flow_data_value(flow, "condition")
|
||||
|
||||
@staticmethod
|
||||
def get_action_data_value(action: Action, key: str) -> Any:
|
||||
def get_action_data_value(action: Optional[Action], key: str) -> Any:
|
||||
"""
|
||||
从动作 data 中读取扩展配置。
|
||||
"""
|
||||
if not action:
|
||||
return None
|
||||
data = action.data or {}
|
||||
return data.get(key) if isinstance(data, dict) else None
|
||||
|
||||
@@ -481,8 +907,18 @@ class WorkflowExecutor:
|
||||
return None
|
||||
if name == "context":
|
||||
return self.context
|
||||
if name == "workflow_context":
|
||||
return self.context.workflow_context or {}
|
||||
if name == "runtime_state":
|
||||
return self.context.runtime_state or {}
|
||||
if name == "artifacts":
|
||||
return self.context.artifacts or {}
|
||||
if name in ("outputs", "node_outputs"):
|
||||
return self.context.node_outputs or {}
|
||||
if name == "last":
|
||||
return self.context.runtime_state.get("last_outputs") if self.context.runtime_state else {}
|
||||
if name in (self.context.node_outputs or {}):
|
||||
return self.context.node_outputs[name]
|
||||
if name in ActionContext.model_fields:
|
||||
return getattr(self.context, name, None)
|
||||
raise ValueError(f"未知上下文变量 {name}")
|
||||
@@ -598,7 +1034,7 @@ class WorkflowChain(ChainBase):
|
||||
"""
|
||||
workflowoper = WorkflowOper()
|
||||
|
||||
def save_step(action: Action, context: ActionContext):
|
||||
def save_step(action: Action, context: ActionContext, execution_state: dict, completed: bool):
|
||||
"""
|
||||
保存上下文到数据库
|
||||
"""
|
||||
@@ -606,9 +1042,14 @@ class WorkflowChain(ChainBase):
|
||||
serialized_data = pickle.dumps(context)
|
||||
# 使用Base64编码字节流
|
||||
encoded_data = base64.b64encode(serialized_data).decode('utf-8')
|
||||
WorkflowOper().step(workflow_id, action_id=action.id, context={
|
||||
"content": encoded_data
|
||||
})
|
||||
WorkflowOper().step(
|
||||
workflow_id,
|
||||
action_id=action.id if completed else "",
|
||||
context={
|
||||
"content": encoded_data
|
||||
},
|
||||
execution_state=execution_state
|
||||
)
|
||||
|
||||
# 重置工作流
|
||||
if from_begin:
|
||||
|
||||
@@ -40,6 +40,10 @@ class Workflow(Base):
|
||||
flows = Column(JSON, default=builtin_list)
|
||||
# 执行上下文
|
||||
context = Column(JSON, default=dict)
|
||||
# 执行配置
|
||||
execution_config = Column(JSON, default=dict)
|
||||
# 结构化执行状态
|
||||
execution_state = Column(JSON, default=dict)
|
||||
# 创建时间
|
||||
add_time = Column(String, default=lambda: datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
|
||||
# 最后执行时间
|
||||
@@ -218,6 +222,7 @@ class Workflow(Base):
|
||||
"result": None,
|
||||
"current_action": None,
|
||||
"context": {},
|
||||
"execution_state": {},
|
||||
"run_count": 0 if reset_count else cls.run_count,
|
||||
})
|
||||
return True
|
||||
@@ -231,39 +236,48 @@ class Workflow(Base):
|
||||
result=None,
|
||||
current_action=None,
|
||||
context={},
|
||||
execution_state={},
|
||||
run_count=0 if reset_count else cls.run_count,
|
||||
))
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
@db_update
|
||||
def update_current_action(cls, db, wid: int, action_id: str, context: dict):
|
||||
def update_current_action(cls, db, wid: int, action_id: str, context: dict,
|
||||
execution_state: Optional[dict] = None):
|
||||
workflow = db.query(cls).filter(cls.id == wid).first()
|
||||
current_actions = []
|
||||
if workflow and workflow.current_action:
|
||||
current_actions = [item for item in workflow.current_action.split(",") if item]
|
||||
if action_id not in current_actions:
|
||||
if action_id and action_id not in current_actions:
|
||||
current_actions.append(action_id)
|
||||
db.query(cls).filter(cls.id == wid).update({
|
||||
update_values = {
|
||||
"current_action": ",".join(current_actions),
|
||||
"context": context
|
||||
})
|
||||
}
|
||||
if execution_state is not None:
|
||||
update_values["execution_state"] = execution_state
|
||||
db.query(cls).filter(cls.id == wid).update(update_values)
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
@async_db_update
|
||||
async def async_update_current_action(cls, db: AsyncSession, wid: int, action_id: str, context: dict):
|
||||
async def async_update_current_action(cls, db: AsyncSession, wid: int, action_id: str, context: dict,
|
||||
execution_state: Optional[dict] = None):
|
||||
from sqlalchemy import update
|
||||
# 先获取当前current_action
|
||||
result = await db.execute(select(cls.current_action).where(cls.id == wid))
|
||||
current_action = result.scalar()
|
||||
current_actions = [item for item in (current_action or "").split(",") if item]
|
||||
if action_id not in current_actions:
|
||||
if action_id and action_id not in current_actions:
|
||||
current_actions.append(action_id)
|
||||
new_current_action = ",".join(current_actions)
|
||||
|
||||
await db.execute(update(cls).where(cls.id == wid).values(
|
||||
current_action=new_current_action,
|
||||
context=context
|
||||
))
|
||||
update_values = {
|
||||
"current_action": new_current_action,
|
||||
"context": context
|
||||
}
|
||||
if execution_state is not None:
|
||||
update_values["execution_state"] = execution_state
|
||||
await db.execute(update(cls).where(cls.id == wid).values(**update_values))
|
||||
return True
|
||||
|
||||
@@ -91,11 +91,17 @@ class WorkflowOper(DbOper):
|
||||
"""
|
||||
return Workflow.fail(self._db, wid, result)
|
||||
|
||||
def step(self, wid: int, action_id: str, context: dict) -> bool:
|
||||
def step(self, wid: int, action_id: str, context: dict, execution_state: Optional[dict] = None) -> bool:
|
||||
"""
|
||||
步进
|
||||
"""
|
||||
return Workflow.update_current_action(self._db, wid, action_id, context)
|
||||
return Workflow.update_current_action(
|
||||
self._db,
|
||||
wid,
|
||||
action_id,
|
||||
context,
|
||||
execution_state=execution_state
|
||||
)
|
||||
|
||||
def reset(self, wid: int, reset_count: bool = False) -> bool:
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Any, Optional, List
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from app.schemas.context import Context, MediaInfo
|
||||
from app.schemas.download import DownloadTask
|
||||
@@ -26,6 +26,8 @@ class Workflow(BaseModel):
|
||||
run_count: Optional[int] = Field(default=0, description="已执行次数")
|
||||
actions: Optional[list] = Field(default_factory=list, description="任务列表")
|
||||
flows: Optional[list] = Field(default_factory=list, description="任务流")
|
||||
execution_config: Optional[dict] = Field(default_factory=dict, description="工作流执行配置")
|
||||
execution_state: Optional[dict] = Field(default_factory=dict, description="工作流结构化执行状态")
|
||||
add_time: Optional[str] = Field(default=None, description="创建时间")
|
||||
last_time: Optional[str] = Field(default=None, description="最后执行时间")
|
||||
|
||||
@@ -50,8 +52,14 @@ class Action(BaseModel):
|
||||
description: Optional[str] = Field(default=None, description="动作描述")
|
||||
position: Optional[dict] = Field(default_factory=dict, description="位置")
|
||||
data: Optional[dict] = Field(default_factory=dict, description="参数")
|
||||
inputs: Optional[List[str]] = Field(default_factory=list, description="动作输入声明")
|
||||
outputs: Optional[dict] = Field(default_factory=dict, description="动作输出声明")
|
||||
join_policy: Optional[str] = Field(default=None, description="多上游节点汇合策略")
|
||||
fail_policy: Optional[str] = Field(default=None, description="动作失败后的工作流处理策略")
|
||||
branch_policy: Optional[str] = Field(default=None, description="多出边分支策略")
|
||||
concurrency_key: Optional[str] = Field(default=None, description="并发互斥键")
|
||||
timeout: Optional[int] = Field(default=None, description="动作执行超时时间(秒)")
|
||||
retry: Optional[dict] = Field(default_factory=dict, description="动作重试策略")
|
||||
|
||||
|
||||
class ActionExecution(BaseModel):
|
||||
@@ -74,7 +82,10 @@ class ActionContext(BaseModel):
|
||||
downloads: Optional[List[DownloadTask]] = Field(default_factory=list, description="下载任务列表")
|
||||
sites: Optional[List[Site]] = Field(default_factory=list, description="站点列表")
|
||||
subscribes: Optional[List[Subscribe]] = Field(default_factory=list, description="订阅列表")
|
||||
workflow_context: Optional[dict] = Field(default_factory=dict, description="工作流全局上下文")
|
||||
node_outputs: Optional[dict] = Field(default_factory=dict, description="节点输出数据")
|
||||
runtime_state: Optional[dict] = Field(default_factory=dict, description="运行期状态")
|
||||
artifacts: Optional[dict] = Field(default_factory=dict, description="大对象引用与产物数据")
|
||||
execute_history: Optional[List[ActionExecution]] = Field(default_factory=list, description="执行历史")
|
||||
progress: Optional[int] = Field(default=0, description="执行进度(%)")
|
||||
|
||||
@@ -87,6 +98,8 @@ class ActionResult(BaseModel):
|
||||
message: Optional[str] = Field(default=None, description="动作执行消息")
|
||||
context: Optional[ActionContext] = Field(default=None, description="动作执行后的上下文")
|
||||
outputs: Optional[dict[str, Any]] = Field(default_factory=dict, description="当前节点显式输出")
|
||||
next_policy: Optional[str] = Field(default=None, description="动作完成后的调度策略")
|
||||
attempts: Optional[int] = Field(default=1, description="动作实际尝试次数")
|
||||
|
||||
|
||||
class ActionFlow(BaseModel):
|
||||
@@ -100,6 +113,7 @@ class ActionFlow(BaseModel):
|
||||
data: Optional[dict] = Field(default_factory=dict, description="流程扩展配置")
|
||||
condition: Optional[str] = Field(default=None, description="流转条件表达式")
|
||||
join_policy: Optional[str] = Field(default=None, description="目标节点汇合策略")
|
||||
branch_policy: Optional[str] = Field(default=None, description="源节点分支策略")
|
||||
|
||||
|
||||
class WorkflowShare(BaseModel):
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import threading
|
||||
from time import sleep
|
||||
from typing import Dict, Any, Optional
|
||||
from typing import List, Tuple
|
||||
from time import monotonic, sleep
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from app.core.config import global_vars
|
||||
from app.core.event import eventmanager, Event
|
||||
@@ -68,48 +67,54 @@ class WorkFlowManager(metaclass=Singleton):
|
||||
self._actions = {}
|
||||
self._event_workflows = {}
|
||||
|
||||
def execute(self, workflow_id: int, action: Action,
|
||||
context: ActionContext = None) -> ActionResult:
|
||||
def execute(self, workflow_id: int, action: Action, context: ActionContext = None,
|
||||
inputs: Optional[dict] = None, runtime: Optional[dict] = None,
|
||||
cancel_token: Optional[Any] = None) -> ActionResult:
|
||||
"""
|
||||
执行工作流动作
|
||||
"""
|
||||
if not context:
|
||||
context = ActionContext()
|
||||
if action.type in self._actions:
|
||||
# 实例化之前,清理掉类对象的数据
|
||||
|
||||
# 实例化
|
||||
action_obj = self._actions[action.type](action.id)
|
||||
# 执行
|
||||
logger.info(f"执行动作: {action.id} - {action.name}")
|
||||
try:
|
||||
result_context = action_obj.execute(workflow_id, action.data, context)
|
||||
action_result = self._normalize_action_result(result_context, action_obj, context)
|
||||
except Exception as err:
|
||||
logger.error(f"{action.name} 执行失败: {err}")
|
||||
return ActionResult(success=False, message=f"{err}", context=context)
|
||||
loop = (action.data or {}).get("loop")
|
||||
loop_interval = (action.data or {}).get("loop_interval")
|
||||
if loop and loop_interval:
|
||||
while not action_obj.done:
|
||||
if global_vars.is_workflow_stopped(workflow_id):
|
||||
break
|
||||
# 等待
|
||||
logger.info(f"{action.name} 等待 {loop_interval} 秒后继续执行 ...")
|
||||
sleep(loop_interval)
|
||||
# 执行
|
||||
logger.info(f"继续执行动作: {action.id} - {action.name}")
|
||||
result_context = action_obj.execute(workflow_id, action.data, action_result.context)
|
||||
action_result = self._normalize_action_result(result_context, action_obj, action_result.context)
|
||||
if action_result.success:
|
||||
logger.info(f"{action.name} 执行成功")
|
||||
else:
|
||||
logger.error(f"{action.name} 执行失败!")
|
||||
return action_result
|
||||
else:
|
||||
if action.type not in self._actions:
|
||||
logger.error(f"未找到动作: {action.type} - {action.name}")
|
||||
return ActionResult(success=False, message=" ", context=context)
|
||||
|
||||
retry_config = self._get_retry_config(action)
|
||||
max_attempts = retry_config["max_attempts"]
|
||||
interval = retry_config["interval"]
|
||||
backoff = retry_config["backoff"]
|
||||
action_result = ActionResult(success=False, message="", context=context)
|
||||
|
||||
for attempt in range(1, max_attempts + 1):
|
||||
if self._is_cancelled(workflow_id, cancel_token):
|
||||
return ActionResult(success=False, message="工作流已取消", context=context)
|
||||
runtime_data = {
|
||||
**(runtime or {}),
|
||||
"attempt": attempt,
|
||||
"max_attempts": max_attempts,
|
||||
"cancel_token": cancel_token,
|
||||
}
|
||||
action_result = self._execute_action_once(
|
||||
workflow_id=workflow_id,
|
||||
action=action,
|
||||
context=context,
|
||||
inputs=inputs or {},
|
||||
runtime=runtime_data,
|
||||
cancel_token=cancel_token
|
||||
)
|
||||
action_result.attempts = attempt
|
||||
context = action_result.context or context
|
||||
if action_result.success:
|
||||
logger.info(f"{action.name} 执行成功")
|
||||
return action_result
|
||||
if attempt < max_attempts and not self._is_cancelled(workflow_id, cancel_token):
|
||||
wait_seconds = interval * (backoff ** (attempt - 1))
|
||||
logger.info(f"{action.name} 执行失败,{wait_seconds} 秒后重试({attempt}/{max_attempts})...")
|
||||
self._sleep_with_cancel(workflow_id, wait_seconds, cancel_token)
|
||||
|
||||
logger.error(f"{action.name} 执行失败!")
|
||||
return action_result
|
||||
|
||||
def excute(self, workflow_id: int, action: Action,
|
||||
context: ActionContext = None) -> Tuple[bool, str, ActionContext]:
|
||||
"""
|
||||
@@ -134,6 +139,100 @@ class WorkFlowManager(metaclass=Singleton):
|
||||
context=result or fallback_context
|
||||
)
|
||||
|
||||
def _execute_action_once(self, workflow_id: int, action: Action, context: ActionContext,
|
||||
inputs: dict, runtime: dict, cancel_token: Optional[Any]) -> ActionResult:
|
||||
action_obj = self._actions[action.type](action.id)
|
||||
logger.info(f"执行动作: {action.id} - {action.name}")
|
||||
try:
|
||||
action_result = self._run_action_with_loop(
|
||||
workflow_id=workflow_id,
|
||||
action=action,
|
||||
action_obj=action_obj,
|
||||
context=context,
|
||||
inputs=inputs,
|
||||
runtime=runtime,
|
||||
cancel_token=cancel_token
|
||||
)
|
||||
except Exception as err:
|
||||
logger.error(f"{action.name} 执行失败: {err}")
|
||||
return ActionResult(success=False, message=f"{err}", context=context)
|
||||
return action_result
|
||||
|
||||
def _run_action_with_loop(self, workflow_id: int, action: Action, action_obj: Any,
|
||||
context: ActionContext, inputs: dict, runtime: dict,
|
||||
cancel_token: Optional[Any]) -> ActionResult:
|
||||
timeout = self._get_action_timeout(action)
|
||||
started_at = monotonic()
|
||||
action_result = self._call_action(
|
||||
workflow_id=workflow_id,
|
||||
action=action,
|
||||
action_obj=action_obj,
|
||||
context=context,
|
||||
inputs=inputs,
|
||||
runtime=runtime
|
||||
)
|
||||
loop = self._get_action_data_value(action, "loop")
|
||||
loop_interval = self._get_action_data_value(action, "loop_interval")
|
||||
while loop and loop_interval and not action_obj.done:
|
||||
if self._is_cancelled(workflow_id, cancel_token):
|
||||
return ActionResult(success=False, message="工作流已取消", context=action_result.context or context)
|
||||
if timeout and monotonic() - started_at >= timeout:
|
||||
return ActionResult(success=False, message=f"动作执行超时({timeout}秒)", context=action_result.context or context)
|
||||
logger.info(f"{action.name} 等待 {loop_interval} 秒后继续执行 ...")
|
||||
self._sleep_with_cancel(workflow_id, loop_interval, cancel_token)
|
||||
if self._is_cancelled(workflow_id, cancel_token):
|
||||
return ActionResult(success=False, message="工作流已取消", context=action_result.context or context)
|
||||
logger.info(f"继续执行动作: {action.id} - {action.name}")
|
||||
action_result = self._call_action(
|
||||
workflow_id=workflow_id,
|
||||
action=action,
|
||||
action_obj=action_obj,
|
||||
context=action_result.context or context,
|
||||
inputs=inputs,
|
||||
runtime=runtime
|
||||
)
|
||||
return action_result
|
||||
|
||||
def _call_action(self, workflow_id: int, action: Action, action_obj: Any,
|
||||
context: ActionContext, inputs: dict, runtime: dict) -> ActionResult:
|
||||
if hasattr(action_obj, "execute_with_inputs"):
|
||||
result = action_obj.execute_with_inputs(workflow_id, action.data, inputs, runtime, context)
|
||||
else:
|
||||
result = action_obj.execute(workflow_id, action.data, context)
|
||||
return self._normalize_action_result(result, action_obj, context)
|
||||
|
||||
@staticmethod
|
||||
def _get_action_data_value(action: Action, key: str) -> Any:
|
||||
data = action.data or {}
|
||||
return data.get(key) if isinstance(data, dict) else None
|
||||
|
||||
def _get_action_timeout(self, action: Action) -> Optional[int]:
|
||||
timeout = action.timeout or self._get_action_data_value(action, "timeout")
|
||||
return int(timeout) if timeout else None
|
||||
|
||||
def _get_retry_config(self, action: Action) -> dict:
|
||||
retry_config = action.retry or self._get_action_data_value(action, "retry") or {}
|
||||
if not isinstance(retry_config, dict):
|
||||
retry_config = {}
|
||||
return {
|
||||
"max_attempts": max(int(retry_config.get("max_attempts") or 1), 1),
|
||||
"interval": max(float(retry_config.get("interval") or 0), 0),
|
||||
"backoff": max(float(retry_config.get("backoff") or 1), 1),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _is_cancelled(workflow_id: int, cancel_token: Optional[Any]) -> bool:
|
||||
if cancel_token and cancel_token.is_cancelled():
|
||||
return True
|
||||
return global_vars.is_workflow_stopped(workflow_id)
|
||||
|
||||
def _sleep_with_cancel(self, workflow_id: int, seconds: float, cancel_token: Optional[Any]) -> None:
|
||||
deadline = monotonic() + seconds
|
||||
while monotonic() < deadline:
|
||||
if self._is_cancelled(workflow_id, cancel_token):
|
||||
return
|
||||
sleep(min(0.1, deadline - monotonic()))
|
||||
|
||||
def list_actions(self) -> List[dict]:
|
||||
"""
|
||||
获取所有动作
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Union
|
||||
|
||||
from app.chain import ChainBase
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.schemas import ActionContext, ActionParams
|
||||
from app.schemas import ActionContext, ActionParams, ActionResult
|
||||
|
||||
|
||||
class ActionChain(ChainBase):
|
||||
@@ -109,3 +109,16 @@ class BaseAction(ABC):
|
||||
执行动作
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def execute_with_inputs(self, workflow_id: int, params: ActionParams, inputs: dict,
|
||||
runtime: dict, context: ActionContext) -> ActionResult:
|
||||
"""
|
||||
使用显式输入与运行期信息执行动作。
|
||||
"""
|
||||
_ = inputs, runtime
|
||||
result_context = self.execute(workflow_id, params, context)
|
||||
return ActionResult(
|
||||
success=self.success,
|
||||
message=self.message,
|
||||
context=result_context
|
||||
)
|
||||
|
||||
45
database/versions/7c1a2b3d4e5f_2_2_9.py
Normal file
45
database/versions/7c1a2b3d4e5f_2_2_9.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""2.2.9
|
||||
为工作流增加执行配置和结构化执行状态
|
||||
|
||||
Revision ID: 7c1a2b3d4e5f
|
||||
Revises: d5e6f7a8b9c0
|
||||
Create Date: 2026-06-04
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "7c1a2b3d4e5f"
|
||||
down_revision = "d5e6f7a8b9c0"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def _has_column(inspector: sa.Inspector, table_name: str, column_name: str) -> bool:
|
||||
"""检查数据表是否已存在指定列。"""
|
||||
if table_name not in inspector.get_table_names():
|
||||
return False
|
||||
return any(column["name"] == column_name for column in inspector.get_columns(table_name))
|
||||
|
||||
|
||||
def _add_json_column_if_missing(table_name: str, column_name: str) -> None:
|
||||
"""缺失时为数据表新增 JSON 列。"""
|
||||
inspector = sa.inspect(op.get_bind())
|
||||
if not _has_column(inspector, table_name, column_name):
|
||||
op.add_column(table_name, sa.Column(column_name, sa.JSON(), nullable=True))
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""升级数据库结构。"""
|
||||
_add_json_column_if_missing("workflow", "execution_config")
|
||||
_add_json_column_if_missing("workflow", "execution_state")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""回滚数据库结构。"""
|
||||
inspector = sa.inspect(op.get_bind())
|
||||
if _has_column(inspector, "workflow", "execution_state"):
|
||||
op.drop_column("workflow", "execution_state")
|
||||
inspector = sa.inspect(op.get_bind())
|
||||
if _has_column(inspector, "workflow", "execution_config"):
|
||||
op.drop_column("workflow", "execution_config")
|
||||
@@ -1,28 +1,32 @@
|
||||
import base64
|
||||
import pickle
|
||||
import threading
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
|
||||
from app.chain import workflow as workflow_module
|
||||
from app.schemas import ActionContext, ActionResult
|
||||
from app.schemas import Action, ActionContext, ActionResult
|
||||
from app.schemas.types import EventType
|
||||
from app import workflow as workflow_package
|
||||
|
||||
|
||||
def _build_workflow(current_action=None, context=None, actions=None, flows=None):
|
||||
def _build_workflow(current_action=None, context=None, actions=None, flows=None,
|
||||
execution_config=None, execution_state=None):
|
||||
"""构造最小工作流对象。"""
|
||||
return SimpleNamespace(
|
||||
id=1,
|
||||
name="测试工作流",
|
||||
actions=actions or [
|
||||
actions=actions if actions is not None else [
|
||||
{"id": "A", "type": "FakeAction", "name": "动作A", "data": {}},
|
||||
{"id": "B", "type": "FakeAction", "name": "动作B", "data": {}},
|
||||
],
|
||||
flows=flows or [
|
||||
flows=flows if flows is not None else [
|
||||
{"id": "flow-1", "source": "A", "target": "B", "animated": True},
|
||||
],
|
||||
current_action=current_action,
|
||||
context=context,
|
||||
execution_config=execution_config or {},
|
||||
execution_state=execution_state or {},
|
||||
)
|
||||
|
||||
|
||||
@@ -39,9 +43,12 @@ class _FakeWorkflowManager:
|
||||
def __init__(self, calls, results=None):
|
||||
self.calls = calls
|
||||
self.results = results or {}
|
||||
self.received_inputs = []
|
||||
|
||||
def execute(self, workflow_id, action, context=None):
|
||||
def execute(self, workflow_id, action, context=None, inputs=None, runtime=None, cancel_token=None):
|
||||
"""执行伪动作并记录新版输入。"""
|
||||
self.calls.append(action.id)
|
||||
self.received_inputs.append((action.id, inputs or {}, runtime or {}, cancel_token))
|
||||
result = self.results.get(action.id)
|
||||
if callable(result):
|
||||
return result(action, context or ActionContext())
|
||||
@@ -262,6 +269,220 @@ def test_workflow_executor_all_done_join_can_continue_after_failure(monkeypatch)
|
||||
assert executor.success is True
|
||||
|
||||
|
||||
def test_workflow_executor_exclusive_branch_uses_first_matching_flow(monkeypatch):
|
||||
"""互斥分支应只执行第一条满足条件的出边。"""
|
||||
calls = []
|
||||
fake_manager = _FakeWorkflowManager(
|
||||
calls,
|
||||
results={
|
||||
"A": lambda action, context: ActionResult(
|
||||
success=True,
|
||||
message=f"{action.name}完成",
|
||||
context=context,
|
||||
outputs={"count": 2}
|
||||
)
|
||||
}
|
||||
)
|
||||
workflow = _build_workflow(
|
||||
actions=[
|
||||
{"id": "A", "type": "FakeAction", "name": "动作A", "data": {"branch_policy": "exclusive"}},
|
||||
{"id": "B", "type": "FakeAction", "name": "动作B", "data": {}},
|
||||
{"id": "C", "type": "FakeAction", "name": "动作C", "data": {}},
|
||||
],
|
||||
flows=[
|
||||
{"id": "flow-ab", "source": "A", "target": "B", "condition": "outputs.A.count > 0"},
|
||||
{"id": "flow-ac", "source": "A", "target": "C", "condition": "outputs.A.count > 1"},
|
||||
],
|
||||
)
|
||||
|
||||
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
|
||||
|
||||
executor = workflow_module.WorkflowExecutor(workflow)
|
||||
executor.execute()
|
||||
|
||||
assert calls == ["A", "B"]
|
||||
assert executor.node_states["C"] == "skipped"
|
||||
|
||||
|
||||
def test_workflow_executor_passes_declared_inputs(monkeypatch):
|
||||
"""动作输入声明应从 node_outputs 中读取指定路径。"""
|
||||
calls = []
|
||||
fake_manager = _FakeWorkflowManager(
|
||||
calls,
|
||||
results={
|
||||
"A": lambda action, context: ActionResult(
|
||||
success=True,
|
||||
message=f"{action.name}完成",
|
||||
context=context,
|
||||
outputs={"torrents": ["a", "b"]}
|
||||
)
|
||||
}
|
||||
)
|
||||
workflow = _build_workflow(
|
||||
actions=[
|
||||
{"id": "A", "type": "FakeAction", "name": "动作A", "data": {}},
|
||||
{
|
||||
"id": "B",
|
||||
"type": "FakeAction",
|
||||
"name": "动作B",
|
||||
"data": {"inputs": ["A.torrents", "outputs.A.torrents.count"]},
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
|
||||
|
||||
executor = workflow_module.WorkflowExecutor(workflow)
|
||||
executor.execute()
|
||||
|
||||
b_inputs = [item for action_id, item, _, _ in fake_manager.received_inputs if action_id == "B"][0]
|
||||
assert b_inputs == {
|
||||
"A.torrents": ["a", "b"],
|
||||
"outputs.A.torrents.count": 2,
|
||||
}
|
||||
|
||||
|
||||
def test_workflow_executor_persists_structured_state(monkeypatch):
|
||||
"""步骤回调应收到可持久化的结构化执行状态。"""
|
||||
calls = []
|
||||
states = []
|
||||
fake_manager = _FakeWorkflowManager(
|
||||
calls,
|
||||
results={
|
||||
"A": lambda action, context: ActionResult(
|
||||
success=True,
|
||||
message=f"{action.name}完成",
|
||||
context=context,
|
||||
outputs={"items": ["movie"]}
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
|
||||
|
||||
executor = workflow_module.WorkflowExecutor(
|
||||
_build_workflow(actions=[{"id": "A", "type": "FakeAction", "name": "动作A", "data": {}}], flows=[]),
|
||||
step_callback=lambda action, context, execution_state, completed: states.append(execution_state),
|
||||
)
|
||||
executor.execute()
|
||||
|
||||
assert states[-1]["nodes"]["A"]["state"] == "success"
|
||||
assert states[-1]["outputs"]["A"]["items"] == ["movie"]
|
||||
assert states[-1]["runtime"]["progress"] == 100
|
||||
|
||||
|
||||
def test_workflow_executor_restores_outputs_from_execution_state(monkeypatch):
|
||||
"""恢复执行时应从结构化状态读取节点输出并继续判断条件边。"""
|
||||
calls = []
|
||||
fake_manager = _FakeWorkflowManager(calls)
|
||||
workflow = _build_workflow(
|
||||
execution_state={
|
||||
"nodes": {
|
||||
"A": {"state": "success", "attempt": 1},
|
||||
},
|
||||
"outputs": {
|
||||
"A": {"torrents": ["movie"]},
|
||||
},
|
||||
},
|
||||
flows=[
|
||||
{"id": "flow-ab", "source": "A", "target": "B", "condition": "A.torrents.count > 0"},
|
||||
],
|
||||
)
|
||||
|
||||
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
|
||||
|
||||
executor = workflow_module.WorkflowExecutor(workflow)
|
||||
executor.execute()
|
||||
|
||||
assert calls == ["B"]
|
||||
assert executor.context.node_outputs["A"]["torrents"] == ["movie"]
|
||||
|
||||
|
||||
def test_workflow_executor_concurrency_key_serializes_parallel_nodes(monkeypatch):
|
||||
"""相同 concurrency_key 的并行节点不应同时运行。"""
|
||||
calls = []
|
||||
active_count = 0
|
||||
max_active_count = 0
|
||||
lock = threading.Lock()
|
||||
|
||||
def run_action(action, context):
|
||||
"""记录同一并发键下的同时运行数量。"""
|
||||
nonlocal active_count, max_active_count
|
||||
with lock:
|
||||
active_count += 1
|
||||
max_active_count = max(max_active_count, active_count)
|
||||
time.sleep(0.05)
|
||||
with lock:
|
||||
active_count -= 1
|
||||
return ActionResult(success=True, message=f"{action.name}完成", context=context)
|
||||
|
||||
fake_manager = _FakeWorkflowManager(calls, results={"A": run_action, "B": run_action})
|
||||
workflow = _build_workflow(
|
||||
actions=[
|
||||
{"id": "A", "type": "FakeAction", "name": "动作A", "data": {"concurrency_key": "download"}},
|
||||
{"id": "B", "type": "FakeAction", "name": "动作B", "data": {"concurrency_key": "download"}},
|
||||
],
|
||||
flows=[],
|
||||
execution_config={"max_workers": 2},
|
||||
)
|
||||
|
||||
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
|
||||
|
||||
executor = workflow_module.WorkflowExecutor(workflow)
|
||||
executor.execute()
|
||||
|
||||
assert set(calls) == {"A", "B"}
|
||||
assert max_active_count == 1
|
||||
|
||||
|
||||
def test_workflow_executor_filter_action_replaces_artifact_outputs(monkeypatch):
|
||||
"""过滤类动作默认应替换列表输出,避免把过滤前数据重新合并回来。"""
|
||||
calls = []
|
||||
fake_manager = _FakeWorkflowManager(
|
||||
calls,
|
||||
results={
|
||||
"A": lambda action, context: ActionResult(
|
||||
success=True,
|
||||
message=f"{action.name}完成",
|
||||
context=context,
|
||||
outputs={"torrents": ["old", "keep"]}
|
||||
),
|
||||
"B": lambda action, context: ActionResult(
|
||||
success=True,
|
||||
message=f"{action.name}完成",
|
||||
context=context,
|
||||
outputs={"torrents": ["keep"]}
|
||||
),
|
||||
}
|
||||
)
|
||||
workflow = _build_workflow(
|
||||
actions=[
|
||||
{"id": "A", "type": "FakeAction", "name": "动作A", "data": {}},
|
||||
{"id": "B", "type": "FilterTorrentsAction", "name": "过滤资源", "data": {}},
|
||||
],
|
||||
)
|
||||
|
||||
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
|
||||
|
||||
executor = workflow_module.WorkflowExecutor(workflow)
|
||||
executor.execute()
|
||||
|
||||
assert executor.context.torrents == ["keep"]
|
||||
assert executor.context.artifacts["torrents"] == ["keep"]
|
||||
|
||||
|
||||
def test_workflow_executor_stop_is_not_success(monkeypatch):
|
||||
"""停止信号不应被执行器汇报为成功完成。"""
|
||||
calls = []
|
||||
@@ -328,3 +549,43 @@ def test_workflow_event_listener_keeps_shared_handler_until_last_workflow(monkey
|
||||
|
||||
assert fake_eventmanager.removed == [EventType.DownloadAdded]
|
||||
assert manager.get_event_workflows() == {}
|
||||
|
||||
|
||||
def test_workflow_manager_retries_action_until_success(monkeypatch):
|
||||
"""动作管理器应按 retry 配置重试失败动作。"""
|
||||
|
||||
class RetryAction:
|
||||
"""模拟第二次才成功的动作。"""
|
||||
|
||||
call_count = 0
|
||||
|
||||
def __init__(self, action_id):
|
||||
self.action_id = action_id
|
||||
|
||||
def execute_with_inputs(self, workflow_id, params, inputs, runtime, context):
|
||||
"""执行动作并在第二次返回成功。"""
|
||||
_ = workflow_id, params, inputs, runtime
|
||||
RetryAction.call_count += 1
|
||||
if RetryAction.call_count == 1:
|
||||
return ActionResult(success=False, message="第一次失败", context=context)
|
||||
return ActionResult(success=True, message="第二次成功", context=context, outputs={"ok": True})
|
||||
|
||||
manager = object.__new__(workflow_package.WorkFlowManager)
|
||||
manager._actions = {"RetryAction": RetryAction}
|
||||
monkeypatch.setattr(workflow_package.global_vars, "is_workflow_stopped", lambda workflow_id: False)
|
||||
|
||||
result = manager.execute(
|
||||
workflow_id=1,
|
||||
action=Action(
|
||||
id="retry",
|
||||
type="RetryAction",
|
||||
name="重试动作",
|
||||
data={"retry": {"max_attempts": 2, "interval": 0}},
|
||||
),
|
||||
context=ActionContext(),
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.attempts == 2
|
||||
assert result.outputs == {"ok": True}
|
||||
assert RetryAction.call_count == 2
|
||||
|
||||
Reference in New Issue
Block a user