From 9056caae404daae1f4918a33b81f9d788b50c356 Mon Sep 17 00:00:00 2001 From: jxxghp Date: Thu, 4 Jun 2026 14:10:06 +0800 Subject: [PATCH] feat(workflow): enhance workflow execution and context management --- app/api/endpoints/workflow.py | 54 ++++--- app/chain/workflow.py | 182 +++++++++++++++++------- app/db/models/workflow.py | 21 ++- app/schemas/workflow.py | 24 ++-- app/workflow/__init__.py | 58 +++++--- app/workflow/actions/__init__.py | 9 +- app/workflow/actions/fetch_downloads.py | 13 +- app/workflow/actions/scrape_file.py | 8 +- tests/test_workflow_actions.py | 80 +++++++++++ tests/test_workflow_execution.py | 153 ++++++++++++++++++++ 10 files changed, 483 insertions(+), 119 deletions(-) create mode 100644 tests/test_workflow_actions.py create mode 100644 tests/test_workflow_execution.py diff --git a/app/api/endpoints/workflow.py b/app/api/endpoints/workflow.py index 15d3b642..5c3351a4 100644 --- a/app/api/endpoints/workflow.py +++ b/app/api/endpoints/workflow.py @@ -22,6 +22,10 @@ from app.schemas.types import EventType, EVENT_TYPE_NAMES router = APIRouter() +WORKFLOW_TRIGGER_TIMER = "timer" +WORKFLOW_TRIGGER_EVENT = "event" +WORKFLOW_TRIGGER_MANUAL = "manual" + @router.get("/", summary="所有工作流", response_model=List[schemas.Workflow]) async def list_workflows( @@ -148,16 +152,20 @@ async def workflow_fork( except json.JSONDecodeError: return schemas.Response(success=False, message="context字段JSON格式错误") + try: + event_conditions = json.loads(workflow.event_conditions or "{}") if workflow.event_conditions else {} + except json.JSONDecodeError: + return schemas.Response(success=False, message="event_conditions字段JSON格式错误") + + share_id = workflow.id # 创建工作流 workflow_dict = { "name": workflow.name, "description": workflow.description, "timer": workflow.timer, - "trigger_type": workflow.trigger_type or "timer", + "trigger_type": workflow.trigger_type or WORKFLOW_TRIGGER_TIMER, "event_type": workflow.event_type, - "event_conditions": json.loads(workflow.event_conditions or "{}") - if workflow.event_conditions - else {}, + "event_conditions": event_conditions, "actions": actions, "flows": flows, "context": context, @@ -170,11 +178,11 @@ async def workflow_fork( return schemas.Response(success=False, message="已存在相同名称的工作流") # 创建新工作流 - workflow = await Workflow(**workflow_dict).async_create(db) + workflow_obj = await Workflow(**workflow_dict).async_create(db) # 更新复用次数 - if workflow: - await MoviePilotServerHelper.async_workflow_fork_by_id(share_id=workflow.id) + if workflow_obj and share_id: + await MoviePilotServerHelper.async_workflow_fork_by_id(share_id=share_id) return schemas.Response(success=True, message="复用成功") @@ -225,14 +233,23 @@ def start_workflow( workflow = WorkflowOper(db).get(workflow_id) if not workflow: return schemas.Response(success=False, message="工作流不存在") - if not workflow.trigger_type or workflow.trigger_type == "timer": + trigger_type = workflow.trigger_type or WORKFLOW_TRIGGER_TIMER + if trigger_type == WORKFLOW_TRIGGER_TIMER and not workflow.timer: + return schemas.Response(success=False, message="定时工作流缺少定时器配置") + if trigger_type not in { + WORKFLOW_TRIGGER_TIMER, + WORKFLOW_TRIGGER_EVENT, + WORKFLOW_TRIGGER_MANUAL, + }: + return schemas.Response(success=False, message="工作流触发类型不支持") + # 先更新状态,事件触发注册会重新读取工作流并跳过暂停状态。 + workflow.update_state(db, workflow_id, "W") + if trigger_type == WORKFLOW_TRIGGER_TIMER: # 添加定时任务 Scheduler().update_workflow_job(workflow) - else: + elif trigger_type == WORKFLOW_TRIGGER_EVENT: # 事件触发:添加到事件触发器 WorkFlowManager().load_workflow_events(workflow_id) - # 更新状态 - workflow.update_state(db, workflow_id, "W") return schemas.Response(success=True) @@ -251,10 +268,10 @@ def pause_workflow( if not workflow: return schemas.Response(success=False, message="工作流不存在") # 根据触发类型进行不同处理 - if workflow.trigger_type == "timer": + if workflow.trigger_type == WORKFLOW_TRIGGER_TIMER: # 定时触发:移除定时任务 Scheduler().remove_workflow_job(workflow) - elif workflow.trigger_type == "event": + elif workflow.trigger_type == WORKFLOW_TRIGGER_EVENT: # 事件触发:从事件触发器中移除 WorkFlowManager().remove_workflow_event(workflow_id, workflow.event_type) # 停止工作流 @@ -319,8 +336,11 @@ def update_workflow( wf.update(db, workflow.model_dump()) # 更新后的工作流对象 updated_workflow = workflow_oper.get(workflow.id) - # 更新定时任务 - Scheduler().update_workflow_job(updated_workflow) + scheduler = Scheduler() + scheduler.remove_workflow_job(updated_workflow) + if not updated_workflow.trigger_type or updated_workflow.trigger_type == WORKFLOW_TRIGGER_TIMER: + if updated_workflow.timer: + scheduler.update_workflow_job(updated_workflow) # 更新事件注册 WorkFlowManager().update_workflow_event(updated_workflow) return schemas.Response(success=True, message="更新成功") @@ -338,10 +358,10 @@ def delete_workflow( workflow = WorkflowOper(db).get(workflow_id) if not workflow: return schemas.Response(success=False, message="工作流不存在") - if not workflow.trigger_type or workflow.trigger_type == "timer": + if not workflow.trigger_type or workflow.trigger_type == WORKFLOW_TRIGGER_TIMER: # 定时触发:删除定时任务 Scheduler().remove_workflow_job(workflow) - else: + elif workflow.trigger_type == WORKFLOW_TRIGGER_EVENT: # 事件触发:从事件触发器中移除 WorkFlowManager().remove_workflow_event(workflow_id, workflow.event_type) # 删除工作流 diff --git a/app/chain/workflow.py b/app/chain/workflow.py index 638cff09..45a4e0dd 100644 --- a/app/chain/workflow.py +++ b/app/chain/workflow.py @@ -1,22 +1,21 @@ import base64 +import copy import pickle import threading from collections import defaultdict, deque from concurrent.futures import ThreadPoolExecutor from time import sleep -from typing import List, Tuple, Optional - -from pydantic.fields import Callable +from typing import Callable, List, Optional, Tuple from app.chain import ChainBase from app.core.config import global_vars from app.core.event import Event, eventmanager -from app.workflow import WorkFlowManager from app.db.models import Workflow from app.db.workflow_oper import WorkflowOper from app.log import logger from app.schemas import ActionContext, ActionFlow, Action, ActionExecution from app.schemas.types import EventType +from app.workflow import WorkFlowManager class WorkflowExecutor: @@ -36,9 +35,14 @@ class WorkflowExecutor: self.actions = {action['id']: Action(**action) for action in workflow.actions} self.flows = [ActionFlow(**flow) for flow in workflow.flows] self.total_actions = len(self.actions) - self.finished_actions = 0 + self.completed_actions = { + action_id for action_id in (workflow.current_action or "").split(",") + if action_id in self.actions + } + self.finished_actions = len(self.completed_actions) self.success = True + self.stopped = False self.errmsg = "" # 工作流管理器 @@ -66,6 +70,10 @@ class WorkflowExecutor: if action_id not in self.indegree: self.indegree[action_id] = 0 + for action_id in self.completed_actions: + for succ_id in self.adjacency.get(action_id, []): + self.indegree[succ_id] -= 1 + # 初始上下文 if workflow.current_action and workflow.context: logger.info(f"工作流已执行动作:{workflow.current_action}") @@ -80,47 +88,68 @@ class WorkflowExecutor: global_vars.workflow_resume(self.workflow.id) # 初始化队列,添加入度为0的节点 for action_id in self.actions: - if self.indegree[action_id] == 0: + if action_id not in self.completed_actions and self.indegree[action_id] == 0: self.queue.append(action_id) - def execute(self): + def execute(self) -> None: """ 执行工作流 """ - while True: - with self.lock: - # 退出条件:队列为空且无运行任务 - if not self.queue and self.running_tasks == 0: - break - # 退出条件:出现了错误 - if not self.success: - break - if not self.queue: + try: + while True: + should_sleep = False + with self.lock: + if global_vars.is_workflow_stopped(self.workflow.id): + self.success = False + self.stopped = True + self.errmsg = "工作流已停止" + if self.running_tasks == 0: + break + should_sleep = True + # 退出条件:队列为空且无运行任务 + elif not self.queue and self.running_tasks == 0: + break + # 出错后不再调度新节点,但等待已提交节点完成,避免后台线程继续写状态。 + if not self.success: + if self.running_tasks == 0: + break + should_sleep = True + elif not self.queue: + should_sleep = True + else: + # 取出队首节点 + node_id = self.queue.popleft() + # 标记任务开始 + self.running_tasks += 1 + + if should_sleep: sleep(0.1) continue - # 取出队首节点 - node_id = self.queue.popleft() - # 标记任务开始 - self.running_tasks += 1 - # 已停机 - if global_vars.is_workflow_stopped(self.workflow.id): - global_vars.workflow_resume(self.workflow.id) - break + # 已停机 + if global_vars.is_workflow_stopped(self.workflow.id): + with self.lock: + self.success = False + self.stopped = True + self.errmsg = "工作流已停止" + self.running_tasks -= 1 + break - # 已执行的跳过 - if (self.workflow.current_action - and node_id in self.workflow.current_action.split(',')): - continue + # 已执行的跳过,并继续释放后继节点。 + if node_id in self.completed_actions: + self.on_node_skipped(node_id) + continue - # 提交任务到线程池 - future = self.executor.submit( - self.execute_node, - self.workflow.id, - node_id, - self.context - ) - future.add_done_callback(self.on_node_complete) + # 提交任务到线程池,每个节点使用上下文快照,避免并行节点互相修改同一个对象。 + future = self.executor.submit( + self.execute_node, + self.workflow.id, + node_id, + copy.deepcopy(self.context) + ) + future.add_done_callback(self.on_node_complete) + finally: + self.executor.shutdown(wait=True, cancel_futures=True) def execute_node(self, workflow_id: int, node_id: int, context: ActionContext) -> Tuple[Action, bool, str, ActionContext]: @@ -135,31 +164,38 @@ class WorkflowExecutor: """ 节点完成回调:更新上下文、处理后继节点 """ - action, state, message, result_ctx = future.result() - try: - self.finished_actions += 1 - # 更新当前进度 - self.context.progress = round(self.finished_actions / self.total_actions) * 100 + action, state, message, result_ctx = future.result() + with self.lock: + if global_vars.is_workflow_stopped(self.workflow.id): + self.success = False + self.stopped = True + self.errmsg = "工作流已停止" + return + self.finished_actions += 1 + # 更新当前进度 + self.context.progress = round(self.finished_actions / self.total_actions * 100) if self.total_actions else 100 - # 补充执行历史 - self.context.execute_history.append( - ActionExecution( - action=action.name, - result=state, - message=message + # 补充执行历史 + self.context.execute_history.append( + ActionExecution( + action=action.name, + result=state, + message=message + ) ) - ) # 节点执行失败 if not state: - self.success = False - self.errmsg = f"{action.name} 失败" + with self.lock: + self.success = False + self.errmsg = f"{action.name} 失败" return with self.lock: # 更新主上下文 self.merge_context(result_ctx) + self.completed_actions.add(action.id) # 回调 if self.step_callback: self.step_callback(action, self.context) @@ -171,17 +207,51 @@ class WorkflowExecutor: self.indegree[succ_id] -= 1 if self.indegree[succ_id] == 0: self.queue.append(succ_id) + except Exception as err: + logger.error(f"工作流节点执行回调失败: {str(err)}") + with self.lock: + self.success = False + self.errmsg = str(err) finally: # 标记任务完成 with self.lock: self.running_tasks -= 1 - def merge_context(self, context: ActionContext): + def on_node_skipped(self, node_id: str) -> None: + """ + 跳过已完成节点,并释放其后继节点。 + """ + with self.lock: + for succ_id in self.adjacency.get(node_id, []): + self.indegree[succ_id] -= 1 + if succ_id not in self.completed_actions and self.indegree[succ_id] == 0: + self.queue.append(succ_id) + self.running_tasks -= 1 + + def merge_context(self, context: ActionContext) -> None: """ 合并上下文 """ - for key, value in context.model_dump().items(): - if not getattr(self.context, key, None): + if not context: + return + for key in context.__class__.model_fields: + value = getattr(context, key, None) + if key in ("execute_history", "progress") or value in (None, "", [], {}): + continue + current_value = getattr(self.context, key, None) + if isinstance(value, list): + if current_value is None: + setattr(self.context, key, value) + continue + for item in value: + if item not in current_value: + current_value.append(item) + elif isinstance(value, dict): + if not current_value: + setattr(self.context, key, value) + else: + current_value.update(value) + elif not current_value: setattr(self.context, key, value) @@ -217,7 +287,7 @@ class WorkflowChain(ChainBase): serialized_data = pickle.dumps(context) # 使用Base64编码字节流 encoded_data = base64.b64encode(serialized_data).decode('utf-8') - workflowoper.step(workflow_id, action_id=action.id, context={ + WorkflowOper().step(workflow_id, action_id=action.id, context={ "content": encoded_data }) @@ -244,6 +314,10 @@ class WorkflowChain(ChainBase): executor = WorkflowExecutor(workflow, step_callback=save_step) executor.execute() + if executor.stopped: + logger.info(f"工作流 {workflow.name} 已停止") + return False, executor.errmsg + if not executor.success: logger.info(f"工作流 {workflow.name} 执行失败:{executor.errmsg}") workflowoper.fail(workflow_id, result=executor.errmsg) diff --git a/app/db/models/workflow.py b/app/db/models/workflow.py index 1a28f95f..caa4bf32 100644 --- a/app/db/models/workflow.py +++ b/app/db/models/workflow.py @@ -41,7 +41,7 @@ class Workflow(Base): # 执行上下文 context = Column(JSON, default=dict) # 创建时间 - add_time = Column(String, default=datetime.now().strftime('%Y-%m-%d %H:%M:%S')) + add_time = Column(String, default=lambda: datetime.now().strftime('%Y-%m-%d %H:%M:%S')) # 最后执行时间 last_time = Column(String) @@ -79,7 +79,7 @@ class Workflow(Base): and_( or_( cls.trigger_type == 'timer', - not cls.trigger_type + cls.trigger_type.is_(None) ), cls.state != 'P' ) @@ -93,7 +93,7 @@ class Workflow(Base): and_( or_( cls.trigger_type == 'timer', - not cls.trigger_type + cls.trigger_type.is_(None) ), cls.state != 'P' ) @@ -217,6 +217,7 @@ class Workflow(Base): "state": 'W', "result": None, "current_action": None, + "context": {}, "run_count": 0 if reset_count else cls.run_count, }) return True @@ -229,6 +230,7 @@ class Workflow(Base): state='W', result=None, current_action=None, + context={}, run_count=0 if reset_count else cls.run_count, )) return True @@ -236,8 +238,14 @@ class Workflow(Base): @classmethod @db_update def update_current_action(cls, db, wid: int, action_id: str, context: dict): + workflow = db.query(cls).filter(cls.id == wid).first() + current_actions = [] + if workflow and workflow.current_action: + current_actions = [item for item in workflow.current_action.split(",") if item] + if action_id not in current_actions: + current_actions.append(action_id) db.query(cls).filter(cls.id == wid).update({ - "current_action": cls.current_action + f",{action_id}" if cls.current_action else action_id, + "current_action": ",".join(current_actions), "context": context }) return True @@ -249,7 +257,10 @@ class Workflow(Base): # 先获取当前current_action result = await db.execute(select(cls.current_action).where(cls.id == wid)) current_action = result.scalar() - new_current_action = current_action + f",{action_id}" if current_action else action_id + current_actions = [item for item in (current_action or "").split(",") if item] + if action_id not in current_actions: + current_actions.append(action_id) + new_current_action = ",".join(current_actions) await db.execute(update(cls).where(cls.id == wid).values( current_action=new_current_action, diff --git a/app/schemas/workflow.py b/app/schemas/workflow.py index 52f7401b..358a47d7 100644 --- a/app/schemas/workflow.py +++ b/app/schemas/workflow.py @@ -19,13 +19,13 @@ class Workflow(BaseModel): timer: Optional[str] = Field(default=None, description="定时器") trigger_type: Optional[str] = Field(default='timer', description="触发类型:timer-定时触发 event-事件触发 manual-手动触发") event_type: Optional[str] = Field(default=None, description="事件类型(当trigger_type为event时使用)") - event_conditions: Optional[dict] = Field(default={}, description="事件条件(JSON格式,用于过滤事件)") + event_conditions: Optional[dict] = Field(default_factory=dict, description="事件条件(JSON格式,用于过滤事件)") state: Optional[str] = Field(default=None, description="状态") current_action: Optional[str] = Field(default=None, description="已执行动作") result: Optional[str] = Field(default=None, description="任务执行结果") run_count: Optional[int] = Field(default=0, description="已执行次数") - actions: Optional[list] = Field(default=[], description="任务列表") - flows: Optional[list] = Field(default=[], description="任务流") + actions: Optional[list] = Field(default_factory=list, description="任务列表") + flows: Optional[list] = Field(default_factory=list, description="任务流") add_time: Optional[str] = Field(default=None, description="创建时间") last_time: Optional[str] = Field(default=None, description="最后执行时间") @@ -48,8 +48,8 @@ class Action(BaseModel): type: Optional[str] = Field(default=None, description="动作类型 (类名)") name: Optional[str] = Field(default=None, description="动作名称") description: Optional[str] = Field(default=None, description="动作描述") - position: Optional[dict] = Field(default={}, description="位置") - data: Optional[dict] = Field(default={}, description="参数") + position: Optional[dict] = Field(default_factory=dict, description="位置") + data: Optional[dict] = Field(default_factory=dict, description="参数") class ActionExecution(BaseModel): @@ -66,13 +66,13 @@ class ActionContext(BaseModel): 动作基础上下文,各动作通用数据 """ content: Optional[str] = Field(default=None, description="文本类内容") - torrents: Optional[List[Context]] = Field(default=[], description="资源列表") - medias: Optional[List[MediaInfo]] = Field(default=[], description="媒体列表") - fileitems: Optional[List[FileItem]] = Field(default=[], description="文件列表") - downloads: Optional[List[DownloadTask]] = Field(default=[], description="下载任务列表") - sites: Optional[List[Site]] = Field(default=[], description="站点列表") - subscribes: Optional[List[Subscribe]] = Field(default=[], description="订阅列表") - execute_history: Optional[List[ActionExecution]] = Field(default=[], description="执行历史") + torrents: Optional[List[Context]] = Field(default_factory=list, description="资源列表") + medias: Optional[List[MediaInfo]] = Field(default_factory=list, description="媒体列表") + fileitems: Optional[List[FileItem]] = Field(default_factory=list, description="文件列表") + downloads: Optional[List[DownloadTask]] = Field(default_factory=list, description="下载任务列表") + sites: Optional[List[Site]] = Field(default_factory=list, description="站点列表") + subscribes: Optional[List[Subscribe]] = Field(default_factory=list, description="订阅列表") + execute_history: Optional[List[ActionExecution]] = Field(default_factory=list, description="执行历史") progress: Optional[int] = Field(default=0, description="执行进度(%)") diff --git a/app/workflow/__init__.py b/app/workflow/__init__.py index ebb71689..5f55fc76 100644 --- a/app/workflow/__init__.py +++ b/app/workflow/__init__.py @@ -63,9 +63,18 @@ class WorkFlowManager(metaclass=Singleton): """ 停止 """ + for event_type_str in list(self._event_workflows.keys()): + self.remove_workflow_event(event_type_str=event_type_str) self._actions = {} self._event_workflows = {} + def execute(self, workflow_id: int, action: Action, + context: ActionContext = None) -> Tuple[bool, str, ActionContext]: + """ + 执行工作流动作 + """ + return self.excute(workflow_id=workflow_id, action=action, context=context) + def excute(self, workflow_id: int, action: Action, context: ActionContext = None) -> Tuple[bool, str, ActionContext]: """ @@ -126,8 +135,8 @@ class WorkFlowManager(metaclass=Singleton): """ 更新工作流事件触发器 """ - # 确保先移除旧的事件监听器 - self.remove_workflow_event(workflow_id=workflow.id, event_type_str=workflow.event_type) + # 工作流可能切换触发事件,先按工作流ID从所有事件映射中移除。 + self.remove_workflow_event(workflow_id=workflow.id) # 如果工作流是事件触发类型且未被禁用 if workflow.trigger_type == "event" and workflow.state != 'P': # 注册事件触发器 @@ -154,41 +163,46 @@ class WorkFlowManager(metaclass=Singleton): """ 注册工作流事件触发器 """ + if not event_type_str: + return try: event_type = EventType(event_type_str) except ValueError: logger.error(f"无效的事件类型: {event_type_str}") return if event_type in EventType: - # 确保先移除旧的事件监听器 - self.remove_workflow_event(workflow_id, event_type.value) with self._lock: - # 添加新的事件监听器 - eventmanager.add_event_listener(event_type, self._handle_event) - # 记录工作流事件触发器 if event_type.value not in self._event_workflows: self._event_workflows[event_type.value] = [] - self._event_workflows[event_type.value].append(workflow_id) + eventmanager.add_event_listener(event_type, self._handle_event) + # 记录工作流事件触发器 + if workflow_id not in self._event_workflows[event_type.value]: + self._event_workflows[event_type.value].append(workflow_id) logger.info(f"已注册工作流 {workflow_id} 事件触发器: {event_type.value}") - def remove_workflow_event(self, workflow_id: int, event_type_str: str): + def remove_workflow_event(self, workflow_id: Optional[int] = None, event_type_str: Optional[str] = None): """ 移除工作流事件触发器 """ - try: - event_type = EventType(event_type_str) - except ValueError: - logger.error(f"无效的事件类型: {event_type_str}") - return - if event_type in EventType: + event_type_values = [event_type_str] if event_type_str else list(self._event_workflows.keys()) + for event_type_value in event_type_values: + try: + event_type = EventType(event_type_value) + except ValueError: + logger.error(f"无效的事件类型: {event_type_value}") + continue with self._lock: - eventmanager.remove_event_listener(event_type, self._handle_event) - if event_type.value in self._event_workflows: - if workflow_id in self._event_workflows[event_type.value]: - self._event_workflows[event_type.value].remove(workflow_id) - if not self._event_workflows[event_type.value]: - del self._event_workflows[event_type.value] - logger.info(f"已移除工作流 {workflow_id} 事件触发器") + workflow_ids = self._event_workflows.get(event_type.value) + if not workflow_ids: + continue + if workflow_id is None: + workflow_ids.clear() + elif workflow_id in workflow_ids: + workflow_ids.remove(workflow_id) + if not workflow_ids: + self._event_workflows.pop(event_type.value, None) + eventmanager.remove_event_listener(event_type, self._handle_event) + logger.info(f"已移除工作流 {workflow_id or ''} 事件触发器") def _handle_event(self, event: Event): """ diff --git a/app/workflow/actions/__init__.py b/app/workflow/actions/__init__.py index 27f6a3c8..895a784a 100644 --- a/app/workflow/actions/__init__.py +++ b/app/workflow/actions/__init__.py @@ -26,6 +26,8 @@ class BaseAction(ABC): def __init__(self, action_id: str): self._action_id = action_id + self._done_flag = False + self._message = "" self.systemconfigoper = SystemConfigOper() @classmethod @@ -92,9 +94,12 @@ class BaseAction(ABC): workflow_cache = self.systemconfigoper.get(workflow_key) or {} action_cache = workflow_cache.get(self._action_id) or [] if isinstance(data, list): - action_cache.extend(data) + for item in data: + if item not in action_cache: + action_cache.append(item) else: - action_cache.append(data) + if data not in action_cache: + action_cache.append(data) workflow_cache[self._action_id] = action_cache self.systemconfigoper.set(workflow_key, workflow_cache) diff --git a/app/workflow/actions/fetch_downloads.py b/app/workflow/actions/fetch_downloads.py index 9a380fdf..aebeb927 100644 --- a/app/workflow/actions/fetch_downloads.py +++ b/app/workflow/actions/fetch_downloads.py @@ -43,12 +43,19 @@ class FetchDownloadsAction(BaseAction): """ 更新downloads中的下载任务状态 """ - __all_complete = False + self._downloads = context.downloads or [] + if not self._downloads: + self.job_done("无下载任务") + return context + for download in self._downloads: if global_vars.is_workflow_stopped(workflow_id): break logger.info(f"获取下载任务 {download.download_id} 状态 ...") - torrents = ActionChain().list_torrents(hashs=[download.download_id]) + torrents = ActionChain().list_torrents( + hashs=[download.download_id], + downloader=download.downloader, + ) if not torrents: download.completed = True continue @@ -61,5 +68,5 @@ class FetchDownloadsAction(BaseAction): logger.info(f"下载任务 {download.download_id} 未完成") download.completed = False if all([d.completed for d in self._downloads]): - self.job_done() + self.job_done("下载任务已全部完成") return context diff --git a/app/workflow/actions/scrape_file.py b/app/workflow/actions/scrape_file.py index 198ae8c5..7f4ceab3 100644 --- a/app/workflow/actions/scrape_file.py +++ b/app/workflow/actions/scrape_file.py @@ -61,18 +61,18 @@ class ScrapeFileAction(BaseAction): logger.info(f"{fileitem.path} 已刮削过,跳过") continue mediachain = MediaChain() - context = mediachain.recognize_by_path( + media_context = mediachain.recognize_by_path( fileitem.path, obtain_images=True, ) - if not context or not context.media_info: + if not media_context or not media_context.media_info: _failed_count += 1 logger.info(f"{fileitem.path} 未识别到媒体信息,无法刮削") continue mediachain.scrape_metadata( fileitem=fileitem, - meta=context.meta_info, - mediainfo=context.media_info + meta=media_context.meta_info, + mediainfo=media_context.media_info ) self._scraped_files.append(fileitem) # 保存缓存 diff --git a/tests/test_workflow_actions.py b/tests/test_workflow_actions.py new file mode 100644 index 00000000..f4cdbddc --- /dev/null +++ b/tests/test_workflow_actions.py @@ -0,0 +1,80 @@ +from types import SimpleNamespace + +from app.schemas import ActionContext, DownloadTask, FileItem +from app.workflow.actions import fetch_downloads as fetch_downloads_module +from app.workflow.actions import scrape_file as scrape_file_module +from app.workflow.actions.fetch_downloads import FetchDownloadsAction +from app.workflow.actions.scrape_file import ScrapeFileAction + + +def test_fetch_downloads_updates_context_downloads(monkeypatch): + """获取下载任务动作应更新上游上下文中的下载任务。""" + calls = [] + + class FakeActionChain: + """模拟下载器查询链。""" + + def list_torrents(self, hashs=None, downloader=None, **kwargs): + calls.append((hashs, downloader)) + return [SimpleNamespace(path="/downloads/movie.mkv", progress=100)] + + monkeypatch.setattr(fetch_downloads_module, "ActionChain", FakeActionChain) + monkeypatch.setattr(fetch_downloads_module.global_vars, "is_workflow_stopped", lambda workflow_id: False) + + context = ActionContext( + downloads=[ + DownloadTask(download_id="hash-1", downloader="qbittorrent"), + ] + ) + + result = FetchDownloadsAction("fetch-downloads").execute( + workflow_id=1, + params={}, + context=context, + ) + + assert calls == [(["hash-1"], "qbittorrent")] + assert result.downloads[0].completed is True + assert result.downloads[0].path == "/downloads/movie.mkv" + + +def test_scrape_file_keeps_workflow_action_context(monkeypatch): + """刮削文件动作不应将工作流上下文替换为媒体识别上下文。""" + scraped = [] + + class FakeStorageChain: + """模拟存储链。""" + + def exists(self, fileitem): + return True + + class FakeMediaChain: + """模拟媒体识别和刮削链。""" + + def recognize_by_path(self, path, obtain_images=False): + return SimpleNamespace(meta_info="meta", media_info="media") + + def scrape_metadata(self, fileitem, meta=None, mediainfo=None): + scraped.append((fileitem.path, meta, mediainfo)) + + monkeypatch.setattr(scrape_file_module, "StorageChain", FakeStorageChain) + monkeypatch.setattr(scrape_file_module, "MediaChain", FakeMediaChain) + monkeypatch.setattr(scrape_file_module.global_vars, "is_workflow_stopped", lambda workflow_id: False) + monkeypatch.setattr(ScrapeFileAction, "check_cache", lambda self, workflow_id, key: False) + monkeypatch.setattr(ScrapeFileAction, "save_cache", lambda self, workflow_id, data: None) + + context = ActionContext( + fileitems=[ + FileItem(path="/library/movie.mkv", storage="local", type="file"), + ] + ) + + result = ScrapeFileAction("scrape-file").execute( + workflow_id=1, + params={}, + context=context, + ) + + assert result is context + assert result.fileitems[0].path == "/library/movie.mkv" + assert scraped == [("/library/movie.mkv", "meta", "media")] diff --git a/tests/test_workflow_execution.py b/tests/test_workflow_execution.py new file mode 100644 index 00000000..a23a004d --- /dev/null +++ b/tests/test_workflow_execution.py @@ -0,0 +1,153 @@ +import base64 +import pickle +import threading +from types import SimpleNamespace + +from app.chain import workflow as workflow_module +from app.schemas import ActionContext +from app.schemas.types import EventType +from app import workflow as workflow_package + + +def _build_workflow(current_action=None, context=None): + """构造最小工作流对象。""" + return SimpleNamespace( + id=1, + name="测试工作流", + actions=[ + {"id": "A", "type": "FakeAction", "name": "动作A", "data": {}}, + {"id": "B", "type": "FakeAction", "name": "动作B", "data": {}}, + ], + flows=[ + {"id": "flow-1", "source": "A", "target": "B", "animated": True}, + ], + current_action=current_action, + context=context, + ) + + +def _encoded_context(context: ActionContext) -> dict: + """编码工作流恢复上下文。""" + return { + "content": base64.b64encode(pickle.dumps(context)).decode("utf-8"), + } + + +class _FakeWorkflowManager: + """记录执行动作的工作流管理器。""" + + def __init__(self, calls): + self.calls = calls + + def excute(self, workflow_id, action, context=None): + self.calls.append(action.id) + return True, f"{action.name}完成", context or ActionContext() + + +def test_workflow_executor_resumes_downstream_nodes(monkeypatch): + """恢复执行时应释放已完成节点的后继节点。""" + calls = [] + fake_manager = _FakeWorkflowManager(calls) + workflow = _build_workflow( + current_action="A", + context=_encoded_context(ActionContext()), + ) + + monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager) + monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None) + monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False) + + executor = workflow_module.WorkflowExecutor(workflow) + executor.execute() + + assert calls == ["B"] + assert executor.success is True + assert executor.context.progress == 100 + + +def test_workflow_executor_reports_incremental_progress(monkeypatch): + """顺序工作流的中间进度应按已完成比例计算。""" + calls = [] + progresses = [] + fake_manager = _FakeWorkflowManager(calls) + + monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager) + monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None) + monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False) + + executor = workflow_module.WorkflowExecutor( + _build_workflow(), + step_callback=lambda action, context: progresses.append(context.progress), + ) + executor.execute() + + assert calls == ["A", "B"] + assert progresses == [50, 100] + + +def test_workflow_executor_stop_is_not_success(monkeypatch): + """停止信号不应被执行器汇报为成功完成。""" + calls = [] + fake_manager = _FakeWorkflowManager(calls) + + monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager) + monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None) + monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: True) + + executor = workflow_module.WorkflowExecutor(_build_workflow()) + executor.execute() + + assert calls == [] + assert executor.stopped is True + assert executor.success is False + assert executor.errmsg == "工作流已停止" + + +def test_workflow_context_merge_preserves_runtime_objects(): + """合并上下文时应保留运行时对象,而不是转成字典。""" + executor = object.__new__(workflow_module.WorkflowExecutor) + executor.context = ActionContext() + runtime_torrent = SimpleNamespace(title="runtime torrent") + result_context = ActionContext() + result_context.torrents.append(runtime_torrent) + + executor.merge_context(result_context) + + assert executor.context.torrents[0] is runtime_torrent + + +class _FakeEventManager: + """记录事件监听器注册和移除次数。""" + + def __init__(self): + self.added = [] + self.removed = [] + + def add_event_listener(self, event_type, handler): + self.added.append(event_type) + + def remove_event_listener(self, event_type, handler): + self.removed.append(event_type) + + +def test_workflow_event_listener_keeps_shared_handler_until_last_workflow(monkeypatch): + """同一事件下移除单个工作流时不应断开其他工作流监听。""" + fake_eventmanager = _FakeEventManager() + manager = object.__new__(workflow_package.WorkFlowManager) + manager._lock = threading.Lock() + manager._event_workflows = {} + + monkeypatch.setattr(workflow_package, "eventmanager", fake_eventmanager) + + manager.register_workflow_event(1, EventType.DownloadAdded.value) + manager.register_workflow_event(2, EventType.DownloadAdded.value) + manager.remove_workflow_event(1, EventType.DownloadAdded.value) + + assert fake_eventmanager.added == [EventType.DownloadAdded] + assert fake_eventmanager.removed == [] + assert manager.get_event_workflows() == {EventType.DownloadAdded.value: [2]} + + manager.remove_workflow_event(2, EventType.DownloadAdded.value) + + assert fake_eventmanager.removed == [EventType.DownloadAdded] + assert manager.get_event_workflows() == {}