feat(workflow): enhance workflow execution and context management

This commit is contained in:
jxxghp
2026-06-04 14:10:06 +08:00
parent fd280a49b7
commit 9056caae40
10 changed files with 483 additions and 119 deletions

View File

@@ -22,6 +22,10 @@ from app.schemas.types import EventType, EVENT_TYPE_NAMES
router = APIRouter()
WORKFLOW_TRIGGER_TIMER = "timer"
WORKFLOW_TRIGGER_EVENT = "event"
WORKFLOW_TRIGGER_MANUAL = "manual"
@router.get("/", summary="所有工作流", response_model=List[schemas.Workflow])
async def list_workflows(
@@ -148,16 +152,20 @@ async def workflow_fork(
except json.JSONDecodeError:
return schemas.Response(success=False, message="context字段JSON格式错误")
try:
event_conditions = json.loads(workflow.event_conditions or "{}") if workflow.event_conditions else {}
except json.JSONDecodeError:
return schemas.Response(success=False, message="event_conditions字段JSON格式错误")
share_id = workflow.id
# 创建工作流
workflow_dict = {
"name": workflow.name,
"description": workflow.description,
"timer": workflow.timer,
"trigger_type": workflow.trigger_type or "timer",
"trigger_type": workflow.trigger_type or WORKFLOW_TRIGGER_TIMER,
"event_type": workflow.event_type,
"event_conditions": json.loads(workflow.event_conditions or "{}")
if workflow.event_conditions
else {},
"event_conditions": event_conditions,
"actions": actions,
"flows": flows,
"context": context,
@@ -170,11 +178,11 @@ async def workflow_fork(
return schemas.Response(success=False, message="已存在相同名称的工作流")
# 创建新工作流
workflow = await Workflow(**workflow_dict).async_create(db)
workflow_obj = await Workflow(**workflow_dict).async_create(db)
# 更新复用次数
if workflow:
await MoviePilotServerHelper.async_workflow_fork_by_id(share_id=workflow.id)
if workflow_obj and share_id:
await MoviePilotServerHelper.async_workflow_fork_by_id(share_id=share_id)
return schemas.Response(success=True, message="复用成功")
@@ -225,14 +233,23 @@ def start_workflow(
workflow = WorkflowOper(db).get(workflow_id)
if not workflow:
return schemas.Response(success=False, message="工作流不存在")
if not workflow.trigger_type or workflow.trigger_type == "timer":
trigger_type = workflow.trigger_type or WORKFLOW_TRIGGER_TIMER
if trigger_type == WORKFLOW_TRIGGER_TIMER and not workflow.timer:
return schemas.Response(success=False, message="定时工作流缺少定时器配置")
if trigger_type not in {
WORKFLOW_TRIGGER_TIMER,
WORKFLOW_TRIGGER_EVENT,
WORKFLOW_TRIGGER_MANUAL,
}:
return schemas.Response(success=False, message="工作流触发类型不支持")
# 先更新状态,事件触发注册会重新读取工作流并跳过暂停状态。
workflow.update_state(db, workflow_id, "W")
if trigger_type == WORKFLOW_TRIGGER_TIMER:
# 添加定时任务
Scheduler().update_workflow_job(workflow)
else:
elif trigger_type == WORKFLOW_TRIGGER_EVENT:
# 事件触发:添加到事件触发器
WorkFlowManager().load_workflow_events(workflow_id)
# 更新状态
workflow.update_state(db, workflow_id, "W")
return schemas.Response(success=True)
@@ -251,10 +268,10 @@ def pause_workflow(
if not workflow:
return schemas.Response(success=False, message="工作流不存在")
# 根据触发类型进行不同处理
if workflow.trigger_type == "timer":
if workflow.trigger_type == WORKFLOW_TRIGGER_TIMER:
# 定时触发:移除定时任务
Scheduler().remove_workflow_job(workflow)
elif workflow.trigger_type == "event":
elif workflow.trigger_type == WORKFLOW_TRIGGER_EVENT:
# 事件触发:从事件触发器中移除
WorkFlowManager().remove_workflow_event(workflow_id, workflow.event_type)
# 停止工作流
@@ -319,8 +336,11 @@ def update_workflow(
wf.update(db, workflow.model_dump())
# 更新后的工作流对象
updated_workflow = workflow_oper.get(workflow.id)
# 更新定时任务
Scheduler().update_workflow_job(updated_workflow)
scheduler = Scheduler()
scheduler.remove_workflow_job(updated_workflow)
if not updated_workflow.trigger_type or updated_workflow.trigger_type == WORKFLOW_TRIGGER_TIMER:
if updated_workflow.timer:
scheduler.update_workflow_job(updated_workflow)
# 更新事件注册
WorkFlowManager().update_workflow_event(updated_workflow)
return schemas.Response(success=True, message="更新成功")
@@ -338,10 +358,10 @@ def delete_workflow(
workflow = WorkflowOper(db).get(workflow_id)
if not workflow:
return schemas.Response(success=False, message="工作流不存在")
if not workflow.trigger_type or workflow.trigger_type == "timer":
if not workflow.trigger_type or workflow.trigger_type == WORKFLOW_TRIGGER_TIMER:
# 定时触发:删除定时任务
Scheduler().remove_workflow_job(workflow)
else:
elif workflow.trigger_type == WORKFLOW_TRIGGER_EVENT:
# 事件触发:从事件触发器中移除
WorkFlowManager().remove_workflow_event(workflow_id, workflow.event_type)
# 删除工作流

View File

@@ -1,22 +1,21 @@
import base64
import copy
import pickle
import threading
from collections import defaultdict, deque
from concurrent.futures import ThreadPoolExecutor
from time import sleep
from typing import List, Tuple, Optional
from pydantic.fields import Callable
from typing import Callable, List, Optional, Tuple
from app.chain import ChainBase
from app.core.config import global_vars
from app.core.event import Event, eventmanager
from app.workflow import WorkFlowManager
from app.db.models import Workflow
from app.db.workflow_oper import WorkflowOper
from app.log import logger
from app.schemas import ActionContext, ActionFlow, Action, ActionExecution
from app.schemas.types import EventType
from app.workflow import WorkFlowManager
class WorkflowExecutor:
@@ -36,9 +35,14 @@ class WorkflowExecutor:
self.actions = {action['id']: Action(**action) for action in workflow.actions}
self.flows = [ActionFlow(**flow) for flow in workflow.flows]
self.total_actions = len(self.actions)
self.finished_actions = 0
self.completed_actions = {
action_id for action_id in (workflow.current_action or "").split(",")
if action_id in self.actions
}
self.finished_actions = len(self.completed_actions)
self.success = True
self.stopped = False
self.errmsg = ""
# 工作流管理器
@@ -66,6 +70,10 @@ class WorkflowExecutor:
if action_id not in self.indegree:
self.indegree[action_id] = 0
for action_id in self.completed_actions:
for succ_id in self.adjacency.get(action_id, []):
self.indegree[succ_id] -= 1
# 初始上下文
if workflow.current_action and workflow.context:
logger.info(f"工作流已执行动作:{workflow.current_action}")
@@ -80,47 +88,68 @@ class WorkflowExecutor:
global_vars.workflow_resume(self.workflow.id)
# 初始化队列添加入度为0的节点
for action_id in self.actions:
if self.indegree[action_id] == 0:
if action_id not in self.completed_actions and self.indegree[action_id] == 0:
self.queue.append(action_id)
def execute(self):
def execute(self) -> None:
"""
执行工作流
"""
while True:
with self.lock:
# 退出条件:队列为空且无运行任务
if not self.queue and self.running_tasks == 0:
break
# 退出条件:出现了错误
if not self.success:
break
if not self.queue:
try:
while True:
should_sleep = False
with self.lock:
if global_vars.is_workflow_stopped(self.workflow.id):
self.success = False
self.stopped = True
self.errmsg = "工作流已停止"
if self.running_tasks == 0:
break
should_sleep = True
# 退出条件:队列为空且无运行任务
elif not self.queue and self.running_tasks == 0:
break
# 出错后不再调度新节点,但等待已提交节点完成,避免后台线程继续写状态。
if not self.success:
if self.running_tasks == 0:
break
should_sleep = True
elif not self.queue:
should_sleep = True
else:
# 取出队首节点
node_id = self.queue.popleft()
# 标记任务开始
self.running_tasks += 1
if should_sleep:
sleep(0.1)
continue
# 取出队首节点
node_id = self.queue.popleft()
# 标记任务开始
self.running_tasks += 1
# 已停机
if global_vars.is_workflow_stopped(self.workflow.id):
global_vars.workflow_resume(self.workflow.id)
break
# 已停机
if global_vars.is_workflow_stopped(self.workflow.id):
with self.lock:
self.success = False
self.stopped = True
self.errmsg = "工作流已停止"
self.running_tasks -= 1
break
# 已执行的跳过
if (self.workflow.current_action
and node_id in self.workflow.current_action.split(',')):
continue
# 已执行的跳过,并继续释放后继节点。
if node_id in self.completed_actions:
self.on_node_skipped(node_id)
continue
# 提交任务到线程池
future = self.executor.submit(
self.execute_node,
self.workflow.id,
node_id,
self.context
)
future.add_done_callback(self.on_node_complete)
# 提交任务到线程池,每个节点使用上下文快照,避免并行节点互相修改同一个对象。
future = self.executor.submit(
self.execute_node,
self.workflow.id,
node_id,
copy.deepcopy(self.context)
)
future.add_done_callback(self.on_node_complete)
finally:
self.executor.shutdown(wait=True, cancel_futures=True)
def execute_node(self, workflow_id: int, node_id: int,
context: ActionContext) -> Tuple[Action, bool, str, ActionContext]:
@@ -135,31 +164,38 @@ class WorkflowExecutor:
"""
节点完成回调:更新上下文、处理后继节点
"""
action, state, message, result_ctx = future.result()
try:
self.finished_actions += 1
# 更新当前进度
self.context.progress = round(self.finished_actions / self.total_actions) * 100
action, state, message, result_ctx = future.result()
with self.lock:
if global_vars.is_workflow_stopped(self.workflow.id):
self.success = False
self.stopped = True
self.errmsg = "工作流已停止"
return
self.finished_actions += 1
# 更新当前进度
self.context.progress = round(self.finished_actions / self.total_actions * 100) if self.total_actions else 100
# 补充执行历史
self.context.execute_history.append(
ActionExecution(
action=action.name,
result=state,
message=message
# 补充执行历史
self.context.execute_history.append(
ActionExecution(
action=action.name,
result=state,
message=message
)
)
)
# 节点执行失败
if not state:
self.success = False
self.errmsg = f"{action.name} 失败"
with self.lock:
self.success = False
self.errmsg = f"{action.name} 失败"
return
with self.lock:
# 更新主上下文
self.merge_context(result_ctx)
self.completed_actions.add(action.id)
# 回调
if self.step_callback:
self.step_callback(action, self.context)
@@ -171,17 +207,51 @@ class WorkflowExecutor:
self.indegree[succ_id] -= 1
if self.indegree[succ_id] == 0:
self.queue.append(succ_id)
except Exception as err:
logger.error(f"工作流节点执行回调失败: {str(err)}")
with self.lock:
self.success = False
self.errmsg = str(err)
finally:
# 标记任务完成
with self.lock:
self.running_tasks -= 1
def merge_context(self, context: ActionContext):
def on_node_skipped(self, node_id: str) -> None:
"""
跳过已完成节点,并释放其后继节点。
"""
with self.lock:
for succ_id in self.adjacency.get(node_id, []):
self.indegree[succ_id] -= 1
if succ_id not in self.completed_actions and self.indegree[succ_id] == 0:
self.queue.append(succ_id)
self.running_tasks -= 1
def merge_context(self, context: ActionContext) -> None:
"""
合并上下文
"""
for key, value in context.model_dump().items():
if not getattr(self.context, key, None):
if not context:
return
for key in context.__class__.model_fields:
value = getattr(context, key, None)
if key in ("execute_history", "progress") or value in (None, "", [], {}):
continue
current_value = getattr(self.context, key, None)
if isinstance(value, list):
if current_value is None:
setattr(self.context, key, value)
continue
for item in value:
if item not in current_value:
current_value.append(item)
elif isinstance(value, dict):
if not current_value:
setattr(self.context, key, value)
else:
current_value.update(value)
elif not current_value:
setattr(self.context, key, value)
@@ -217,7 +287,7 @@ class WorkflowChain(ChainBase):
serialized_data = pickle.dumps(context)
# 使用Base64编码字节流
encoded_data = base64.b64encode(serialized_data).decode('utf-8')
workflowoper.step(workflow_id, action_id=action.id, context={
WorkflowOper().step(workflow_id, action_id=action.id, context={
"content": encoded_data
})
@@ -244,6 +314,10 @@ class WorkflowChain(ChainBase):
executor = WorkflowExecutor(workflow, step_callback=save_step)
executor.execute()
if executor.stopped:
logger.info(f"工作流 {workflow.name} 已停止")
return False, executor.errmsg
if not executor.success:
logger.info(f"工作流 {workflow.name} 执行失败:{executor.errmsg}")
workflowoper.fail(workflow_id, result=executor.errmsg)

View File

@@ -41,7 +41,7 @@ class Workflow(Base):
# 执行上下文
context = Column(JSON, default=dict)
# 创建时间
add_time = Column(String, default=datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
add_time = Column(String, default=lambda: datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
# 最后执行时间
last_time = Column(String)
@@ -79,7 +79,7 @@ class Workflow(Base):
and_(
or_(
cls.trigger_type == 'timer',
not cls.trigger_type
cls.trigger_type.is_(None)
),
cls.state != 'P'
)
@@ -93,7 +93,7 @@ class Workflow(Base):
and_(
or_(
cls.trigger_type == 'timer',
not cls.trigger_type
cls.trigger_type.is_(None)
),
cls.state != 'P'
)
@@ -217,6 +217,7 @@ class Workflow(Base):
"state": 'W',
"result": None,
"current_action": None,
"context": {},
"run_count": 0 if reset_count else cls.run_count,
})
return True
@@ -229,6 +230,7 @@ class Workflow(Base):
state='W',
result=None,
current_action=None,
context={},
run_count=0 if reset_count else cls.run_count,
))
return True
@@ -236,8 +238,14 @@ class Workflow(Base):
@classmethod
@db_update
def update_current_action(cls, db, wid: int, action_id: str, context: dict):
workflow = db.query(cls).filter(cls.id == wid).first()
current_actions = []
if workflow and workflow.current_action:
current_actions = [item for item in workflow.current_action.split(",") if item]
if action_id not in current_actions:
current_actions.append(action_id)
db.query(cls).filter(cls.id == wid).update({
"current_action": cls.current_action + f",{action_id}" if cls.current_action else action_id,
"current_action": ",".join(current_actions),
"context": context
})
return True
@@ -249,7 +257,10 @@ class Workflow(Base):
# 先获取当前current_action
result = await db.execute(select(cls.current_action).where(cls.id == wid))
current_action = result.scalar()
new_current_action = current_action + f",{action_id}" if current_action else action_id
current_actions = [item for item in (current_action or "").split(",") if item]
if action_id not in current_actions:
current_actions.append(action_id)
new_current_action = ",".join(current_actions)
await db.execute(update(cls).where(cls.id == wid).values(
current_action=new_current_action,

View File

@@ -19,13 +19,13 @@ class Workflow(BaseModel):
timer: Optional[str] = Field(default=None, description="定时器")
trigger_type: Optional[str] = Field(default='timer', description="触发类型timer-定时触发 event-事件触发 manual-手动触发")
event_type: Optional[str] = Field(default=None, description="事件类型当trigger_type为event时使用")
event_conditions: Optional[dict] = Field(default={}, description="事件条件JSON格式用于过滤事件")
event_conditions: Optional[dict] = Field(default_factory=dict, description="事件条件JSON格式用于过滤事件")
state: Optional[str] = Field(default=None, description="状态")
current_action: Optional[str] = Field(default=None, description="已执行动作")
result: Optional[str] = Field(default=None, description="任务执行结果")
run_count: Optional[int] = Field(default=0, description="已执行次数")
actions: Optional[list] = Field(default=[], description="任务列表")
flows: Optional[list] = Field(default=[], description="任务流")
actions: Optional[list] = Field(default_factory=list, description="任务列表")
flows: Optional[list] = Field(default_factory=list, description="任务流")
add_time: Optional[str] = Field(default=None, description="创建时间")
last_time: Optional[str] = Field(default=None, description="最后执行时间")
@@ -48,8 +48,8 @@ class Action(BaseModel):
type: Optional[str] = Field(default=None, description="动作类型 (类名)")
name: Optional[str] = Field(default=None, description="动作名称")
description: Optional[str] = Field(default=None, description="动作描述")
position: Optional[dict] = Field(default={}, description="位置")
data: Optional[dict] = Field(default={}, description="参数")
position: Optional[dict] = Field(default_factory=dict, description="位置")
data: Optional[dict] = Field(default_factory=dict, description="参数")
class ActionExecution(BaseModel):
@@ -66,13 +66,13 @@ class ActionContext(BaseModel):
动作基础上下文,各动作通用数据
"""
content: Optional[str] = Field(default=None, description="文本类内容")
torrents: Optional[List[Context]] = Field(default=[], description="资源列表")
medias: Optional[List[MediaInfo]] = Field(default=[], description="媒体列表")
fileitems: Optional[List[FileItem]] = Field(default=[], description="文件列表")
downloads: Optional[List[DownloadTask]] = Field(default=[], description="下载任务列表")
sites: Optional[List[Site]] = Field(default=[], description="站点列表")
subscribes: Optional[List[Subscribe]] = Field(default=[], description="订阅列表")
execute_history: Optional[List[ActionExecution]] = Field(default=[], description="执行历史")
torrents: Optional[List[Context]] = Field(default_factory=list, description="资源列表")
medias: Optional[List[MediaInfo]] = Field(default_factory=list, description="媒体列表")
fileitems: Optional[List[FileItem]] = Field(default_factory=list, description="文件列表")
downloads: Optional[List[DownloadTask]] = Field(default_factory=list, description="下载任务列表")
sites: Optional[List[Site]] = Field(default_factory=list, description="站点列表")
subscribes: Optional[List[Subscribe]] = Field(default_factory=list, description="订阅列表")
execute_history: Optional[List[ActionExecution]] = Field(default_factory=list, description="执行历史")
progress: Optional[int] = Field(default=0, description="执行进度(%")

View File

@@ -63,9 +63,18 @@ class WorkFlowManager(metaclass=Singleton):
"""
停止
"""
for event_type_str in list(self._event_workflows.keys()):
self.remove_workflow_event(event_type_str=event_type_str)
self._actions = {}
self._event_workflows = {}
def execute(self, workflow_id: int, action: Action,
context: ActionContext = None) -> Tuple[bool, str, ActionContext]:
"""
执行工作流动作
"""
return self.excute(workflow_id=workflow_id, action=action, context=context)
def excute(self, workflow_id: int, action: Action,
context: ActionContext = None) -> Tuple[bool, str, ActionContext]:
"""
@@ -126,8 +135,8 @@ class WorkFlowManager(metaclass=Singleton):
"""
更新工作流事件触发器
"""
# 确保先移除旧的事件监听器
self.remove_workflow_event(workflow_id=workflow.id, event_type_str=workflow.event_type)
# 工作流可能切换触发事件先按工作流ID从所有事件映射中移除。
self.remove_workflow_event(workflow_id=workflow.id)
# 如果工作流是事件触发类型且未被禁用
if workflow.trigger_type == "event" and workflow.state != 'P':
# 注册事件触发器
@@ -154,41 +163,46 @@ class WorkFlowManager(metaclass=Singleton):
"""
注册工作流事件触发器
"""
if not event_type_str:
return
try:
event_type = EventType(event_type_str)
except ValueError:
logger.error(f"无效的事件类型: {event_type_str}")
return
if event_type in EventType:
# 确保先移除旧的事件监听器
self.remove_workflow_event(workflow_id, event_type.value)
with self._lock:
# 添加新的事件监听器
eventmanager.add_event_listener(event_type, self._handle_event)
# 记录工作流事件触发器
if event_type.value not in self._event_workflows:
self._event_workflows[event_type.value] = []
self._event_workflows[event_type.value].append(workflow_id)
eventmanager.add_event_listener(event_type, self._handle_event)
# 记录工作流事件触发器
if workflow_id not in self._event_workflows[event_type.value]:
self._event_workflows[event_type.value].append(workflow_id)
logger.info(f"已注册工作流 {workflow_id} 事件触发器: {event_type.value}")
def remove_workflow_event(self, workflow_id: int, event_type_str: str):
def remove_workflow_event(self, workflow_id: Optional[int] = None, event_type_str: Optional[str] = None):
"""
移除工作流事件触发器
"""
try:
event_type = EventType(event_type_str)
except ValueError:
logger.error(f"无效的事件类型: {event_type_str}")
return
if event_type in EventType:
event_type_values = [event_type_str] if event_type_str else list(self._event_workflows.keys())
for event_type_value in event_type_values:
try:
event_type = EventType(event_type_value)
except ValueError:
logger.error(f"无效的事件类型: {event_type_value}")
continue
with self._lock:
eventmanager.remove_event_listener(event_type, self._handle_event)
if event_type.value in self._event_workflows:
if workflow_id in self._event_workflows[event_type.value]:
self._event_workflows[event_type.value].remove(workflow_id)
if not self._event_workflows[event_type.value]:
del self._event_workflows[event_type.value]
logger.info(f"已移除工作流 {workflow_id} 事件触发器")
workflow_ids = self._event_workflows.get(event_type.value)
if not workflow_ids:
continue
if workflow_id is None:
workflow_ids.clear()
elif workflow_id in workflow_ids:
workflow_ids.remove(workflow_id)
if not workflow_ids:
self._event_workflows.pop(event_type.value, None)
eventmanager.remove_event_listener(event_type, self._handle_event)
logger.info(f"已移除工作流 {workflow_id or ''} 事件触发器")
def _handle_event(self, event: Event):
"""

View File

@@ -26,6 +26,8 @@ class BaseAction(ABC):
def __init__(self, action_id: str):
self._action_id = action_id
self._done_flag = False
self._message = ""
self.systemconfigoper = SystemConfigOper()
@classmethod
@@ -92,9 +94,12 @@ class BaseAction(ABC):
workflow_cache = self.systemconfigoper.get(workflow_key) or {}
action_cache = workflow_cache.get(self._action_id) or []
if isinstance(data, list):
action_cache.extend(data)
for item in data:
if item not in action_cache:
action_cache.append(item)
else:
action_cache.append(data)
if data not in action_cache:
action_cache.append(data)
workflow_cache[self._action_id] = action_cache
self.systemconfigoper.set(workflow_key, workflow_cache)

View File

@@ -43,12 +43,19 @@ class FetchDownloadsAction(BaseAction):
"""
更新downloads中的下载任务状态
"""
__all_complete = False
self._downloads = context.downloads or []
if not self._downloads:
self.job_done("无下载任务")
return context
for download in self._downloads:
if global_vars.is_workflow_stopped(workflow_id):
break
logger.info(f"获取下载任务 {download.download_id} 状态 ...")
torrents = ActionChain().list_torrents(hashs=[download.download_id])
torrents = ActionChain().list_torrents(
hashs=[download.download_id],
downloader=download.downloader,
)
if not torrents:
download.completed = True
continue
@@ -61,5 +68,5 @@ class FetchDownloadsAction(BaseAction):
logger.info(f"下载任务 {download.download_id} 未完成")
download.completed = False
if all([d.completed for d in self._downloads]):
self.job_done()
self.job_done("下载任务已全部完成")
return context

View File

@@ -61,18 +61,18 @@ class ScrapeFileAction(BaseAction):
logger.info(f"{fileitem.path} 已刮削过,跳过")
continue
mediachain = MediaChain()
context = mediachain.recognize_by_path(
media_context = mediachain.recognize_by_path(
fileitem.path,
obtain_images=True,
)
if not context or not context.media_info:
if not media_context or not media_context.media_info:
_failed_count += 1
logger.info(f"{fileitem.path} 未识别到媒体信息,无法刮削")
continue
mediachain.scrape_metadata(
fileitem=fileitem,
meta=context.meta_info,
mediainfo=context.media_info
meta=media_context.meta_info,
mediainfo=media_context.media_info
)
self._scraped_files.append(fileitem)
# 保存缓存

View File

@@ -0,0 +1,80 @@
from types import SimpleNamespace
from app.schemas import ActionContext, DownloadTask, FileItem
from app.workflow.actions import fetch_downloads as fetch_downloads_module
from app.workflow.actions import scrape_file as scrape_file_module
from app.workflow.actions.fetch_downloads import FetchDownloadsAction
from app.workflow.actions.scrape_file import ScrapeFileAction
def test_fetch_downloads_updates_context_downloads(monkeypatch):
"""获取下载任务动作应更新上游上下文中的下载任务。"""
calls = []
class FakeActionChain:
"""模拟下载器查询链。"""
def list_torrents(self, hashs=None, downloader=None, **kwargs):
calls.append((hashs, downloader))
return [SimpleNamespace(path="/downloads/movie.mkv", progress=100)]
monkeypatch.setattr(fetch_downloads_module, "ActionChain", FakeActionChain)
monkeypatch.setattr(fetch_downloads_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
context = ActionContext(
downloads=[
DownloadTask(download_id="hash-1", downloader="qbittorrent"),
]
)
result = FetchDownloadsAction("fetch-downloads").execute(
workflow_id=1,
params={},
context=context,
)
assert calls == [(["hash-1"], "qbittorrent")]
assert result.downloads[0].completed is True
assert result.downloads[0].path == "/downloads/movie.mkv"
def test_scrape_file_keeps_workflow_action_context(monkeypatch):
"""刮削文件动作不应将工作流上下文替换为媒体识别上下文。"""
scraped = []
class FakeStorageChain:
"""模拟存储链。"""
def exists(self, fileitem):
return True
class FakeMediaChain:
"""模拟媒体识别和刮削链。"""
def recognize_by_path(self, path, obtain_images=False):
return SimpleNamespace(meta_info="meta", media_info="media")
def scrape_metadata(self, fileitem, meta=None, mediainfo=None):
scraped.append((fileitem.path, meta, mediainfo))
monkeypatch.setattr(scrape_file_module, "StorageChain", FakeStorageChain)
monkeypatch.setattr(scrape_file_module, "MediaChain", FakeMediaChain)
monkeypatch.setattr(scrape_file_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
monkeypatch.setattr(ScrapeFileAction, "check_cache", lambda self, workflow_id, key: False)
monkeypatch.setattr(ScrapeFileAction, "save_cache", lambda self, workflow_id, data: None)
context = ActionContext(
fileitems=[
FileItem(path="/library/movie.mkv", storage="local", type="file"),
]
)
result = ScrapeFileAction("scrape-file").execute(
workflow_id=1,
params={},
context=context,
)
assert result is context
assert result.fileitems[0].path == "/library/movie.mkv"
assert scraped == [("/library/movie.mkv", "meta", "media")]

View File

@@ -0,0 +1,153 @@
import base64
import pickle
import threading
from types import SimpleNamespace
from app.chain import workflow as workflow_module
from app.schemas import ActionContext
from app.schemas.types import EventType
from app import workflow as workflow_package
def _build_workflow(current_action=None, context=None):
"""构造最小工作流对象。"""
return SimpleNamespace(
id=1,
name="测试工作流",
actions=[
{"id": "A", "type": "FakeAction", "name": "动作A", "data": {}},
{"id": "B", "type": "FakeAction", "name": "动作B", "data": {}},
],
flows=[
{"id": "flow-1", "source": "A", "target": "B", "animated": True},
],
current_action=current_action,
context=context,
)
def _encoded_context(context: ActionContext) -> dict:
"""编码工作流恢复上下文。"""
return {
"content": base64.b64encode(pickle.dumps(context)).decode("utf-8"),
}
class _FakeWorkflowManager:
"""记录执行动作的工作流管理器。"""
def __init__(self, calls):
self.calls = calls
def excute(self, workflow_id, action, context=None):
self.calls.append(action.id)
return True, f"{action.name}完成", context or ActionContext()
def test_workflow_executor_resumes_downstream_nodes(monkeypatch):
"""恢复执行时应释放已完成节点的后继节点。"""
calls = []
fake_manager = _FakeWorkflowManager(calls)
workflow = _build_workflow(
current_action="A",
context=_encoded_context(ActionContext()),
)
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
executor = workflow_module.WorkflowExecutor(workflow)
executor.execute()
assert calls == ["B"]
assert executor.success is True
assert executor.context.progress == 100
def test_workflow_executor_reports_incremental_progress(monkeypatch):
"""顺序工作流的中间进度应按已完成比例计算。"""
calls = []
progresses = []
fake_manager = _FakeWorkflowManager(calls)
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
executor = workflow_module.WorkflowExecutor(
_build_workflow(),
step_callback=lambda action, context: progresses.append(context.progress),
)
executor.execute()
assert calls == ["A", "B"]
assert progresses == [50, 100]
def test_workflow_executor_stop_is_not_success(monkeypatch):
"""停止信号不应被执行器汇报为成功完成。"""
calls = []
fake_manager = _FakeWorkflowManager(calls)
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: True)
executor = workflow_module.WorkflowExecutor(_build_workflow())
executor.execute()
assert calls == []
assert executor.stopped is True
assert executor.success is False
assert executor.errmsg == "工作流已停止"
def test_workflow_context_merge_preserves_runtime_objects():
"""合并上下文时应保留运行时对象,而不是转成字典。"""
executor = object.__new__(workflow_module.WorkflowExecutor)
executor.context = ActionContext()
runtime_torrent = SimpleNamespace(title="runtime torrent")
result_context = ActionContext()
result_context.torrents.append(runtime_torrent)
executor.merge_context(result_context)
assert executor.context.torrents[0] is runtime_torrent
class _FakeEventManager:
"""记录事件监听器注册和移除次数。"""
def __init__(self):
self.added = []
self.removed = []
def add_event_listener(self, event_type, handler):
self.added.append(event_type)
def remove_event_listener(self, event_type, handler):
self.removed.append(event_type)
def test_workflow_event_listener_keeps_shared_handler_until_last_workflow(monkeypatch):
"""同一事件下移除单个工作流时不应断开其他工作流监听。"""
fake_eventmanager = _FakeEventManager()
manager = object.__new__(workflow_package.WorkFlowManager)
manager._lock = threading.Lock()
manager._event_workflows = {}
monkeypatch.setattr(workflow_package, "eventmanager", fake_eventmanager)
manager.register_workflow_event(1, EventType.DownloadAdded.value)
manager.register_workflow_event(2, EventType.DownloadAdded.value)
manager.remove_workflow_event(1, EventType.DownloadAdded.value)
assert fake_eventmanager.added == [EventType.DownloadAdded]
assert fake_eventmanager.removed == []
assert manager.get_event_workflows() == {EventType.DownloadAdded.value: [2]}
manager.remove_workflow_event(2, EventType.DownloadAdded.value)
assert fake_eventmanager.removed == [EventType.DownloadAdded]
assert manager.get_event_workflows() == {}