mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-07 07:26:50 +00:00
feat(workflow): enhance workflow execution and context management
This commit is contained in:
@@ -22,6 +22,10 @@ from app.schemas.types import EventType, EVENT_TYPE_NAMES
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
WORKFLOW_TRIGGER_TIMER = "timer"
|
||||
WORKFLOW_TRIGGER_EVENT = "event"
|
||||
WORKFLOW_TRIGGER_MANUAL = "manual"
|
||||
|
||||
|
||||
@router.get("/", summary="所有工作流", response_model=List[schemas.Workflow])
|
||||
async def list_workflows(
|
||||
@@ -148,16 +152,20 @@ async def workflow_fork(
|
||||
except json.JSONDecodeError:
|
||||
return schemas.Response(success=False, message="context字段JSON格式错误")
|
||||
|
||||
try:
|
||||
event_conditions = json.loads(workflow.event_conditions or "{}") if workflow.event_conditions else {}
|
||||
except json.JSONDecodeError:
|
||||
return schemas.Response(success=False, message="event_conditions字段JSON格式错误")
|
||||
|
||||
share_id = workflow.id
|
||||
# 创建工作流
|
||||
workflow_dict = {
|
||||
"name": workflow.name,
|
||||
"description": workflow.description,
|
||||
"timer": workflow.timer,
|
||||
"trigger_type": workflow.trigger_type or "timer",
|
||||
"trigger_type": workflow.trigger_type or WORKFLOW_TRIGGER_TIMER,
|
||||
"event_type": workflow.event_type,
|
||||
"event_conditions": json.loads(workflow.event_conditions or "{}")
|
||||
if workflow.event_conditions
|
||||
else {},
|
||||
"event_conditions": event_conditions,
|
||||
"actions": actions,
|
||||
"flows": flows,
|
||||
"context": context,
|
||||
@@ -170,11 +178,11 @@ async def workflow_fork(
|
||||
return schemas.Response(success=False, message="已存在相同名称的工作流")
|
||||
|
||||
# 创建新工作流
|
||||
workflow = await Workflow(**workflow_dict).async_create(db)
|
||||
workflow_obj = await Workflow(**workflow_dict).async_create(db)
|
||||
|
||||
# 更新复用次数
|
||||
if workflow:
|
||||
await MoviePilotServerHelper.async_workflow_fork_by_id(share_id=workflow.id)
|
||||
if workflow_obj and share_id:
|
||||
await MoviePilotServerHelper.async_workflow_fork_by_id(share_id=share_id)
|
||||
|
||||
return schemas.Response(success=True, message="复用成功")
|
||||
|
||||
@@ -225,14 +233,23 @@ def start_workflow(
|
||||
workflow = WorkflowOper(db).get(workflow_id)
|
||||
if not workflow:
|
||||
return schemas.Response(success=False, message="工作流不存在")
|
||||
if not workflow.trigger_type or workflow.trigger_type == "timer":
|
||||
trigger_type = workflow.trigger_type or WORKFLOW_TRIGGER_TIMER
|
||||
if trigger_type == WORKFLOW_TRIGGER_TIMER and not workflow.timer:
|
||||
return schemas.Response(success=False, message="定时工作流缺少定时器配置")
|
||||
if trigger_type not in {
|
||||
WORKFLOW_TRIGGER_TIMER,
|
||||
WORKFLOW_TRIGGER_EVENT,
|
||||
WORKFLOW_TRIGGER_MANUAL,
|
||||
}:
|
||||
return schemas.Response(success=False, message="工作流触发类型不支持")
|
||||
# 先更新状态,事件触发注册会重新读取工作流并跳过暂停状态。
|
||||
workflow.update_state(db, workflow_id, "W")
|
||||
if trigger_type == WORKFLOW_TRIGGER_TIMER:
|
||||
# 添加定时任务
|
||||
Scheduler().update_workflow_job(workflow)
|
||||
else:
|
||||
elif trigger_type == WORKFLOW_TRIGGER_EVENT:
|
||||
# 事件触发:添加到事件触发器
|
||||
WorkFlowManager().load_workflow_events(workflow_id)
|
||||
# 更新状态
|
||||
workflow.update_state(db, workflow_id, "W")
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@@ -251,10 +268,10 @@ def pause_workflow(
|
||||
if not workflow:
|
||||
return schemas.Response(success=False, message="工作流不存在")
|
||||
# 根据触发类型进行不同处理
|
||||
if workflow.trigger_type == "timer":
|
||||
if workflow.trigger_type == WORKFLOW_TRIGGER_TIMER:
|
||||
# 定时触发:移除定时任务
|
||||
Scheduler().remove_workflow_job(workflow)
|
||||
elif workflow.trigger_type == "event":
|
||||
elif workflow.trigger_type == WORKFLOW_TRIGGER_EVENT:
|
||||
# 事件触发:从事件触发器中移除
|
||||
WorkFlowManager().remove_workflow_event(workflow_id, workflow.event_type)
|
||||
# 停止工作流
|
||||
@@ -319,8 +336,11 @@ def update_workflow(
|
||||
wf.update(db, workflow.model_dump())
|
||||
# 更新后的工作流对象
|
||||
updated_workflow = workflow_oper.get(workflow.id)
|
||||
# 更新定时任务
|
||||
Scheduler().update_workflow_job(updated_workflow)
|
||||
scheduler = Scheduler()
|
||||
scheduler.remove_workflow_job(updated_workflow)
|
||||
if not updated_workflow.trigger_type or updated_workflow.trigger_type == WORKFLOW_TRIGGER_TIMER:
|
||||
if updated_workflow.timer:
|
||||
scheduler.update_workflow_job(updated_workflow)
|
||||
# 更新事件注册
|
||||
WorkFlowManager().update_workflow_event(updated_workflow)
|
||||
return schemas.Response(success=True, message="更新成功")
|
||||
@@ -338,10 +358,10 @@ def delete_workflow(
|
||||
workflow = WorkflowOper(db).get(workflow_id)
|
||||
if not workflow:
|
||||
return schemas.Response(success=False, message="工作流不存在")
|
||||
if not workflow.trigger_type or workflow.trigger_type == "timer":
|
||||
if not workflow.trigger_type or workflow.trigger_type == WORKFLOW_TRIGGER_TIMER:
|
||||
# 定时触发:删除定时任务
|
||||
Scheduler().remove_workflow_job(workflow)
|
||||
else:
|
||||
elif workflow.trigger_type == WORKFLOW_TRIGGER_EVENT:
|
||||
# 事件触发:从事件触发器中移除
|
||||
WorkFlowManager().remove_workflow_event(workflow_id, workflow.event_type)
|
||||
# 删除工作流
|
||||
|
||||
@@ -1,22 +1,21 @@
|
||||
import base64
|
||||
import copy
|
||||
import pickle
|
||||
import threading
|
||||
from collections import defaultdict, deque
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from time import sleep
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
from pydantic.fields import Callable
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import global_vars
|
||||
from app.core.event import Event, eventmanager
|
||||
from app.workflow import WorkFlowManager
|
||||
from app.db.models import Workflow
|
||||
from app.db.workflow_oper import WorkflowOper
|
||||
from app.log import logger
|
||||
from app.schemas import ActionContext, ActionFlow, Action, ActionExecution
|
||||
from app.schemas.types import EventType
|
||||
from app.workflow import WorkFlowManager
|
||||
|
||||
|
||||
class WorkflowExecutor:
|
||||
@@ -36,9 +35,14 @@ class WorkflowExecutor:
|
||||
self.actions = {action['id']: Action(**action) for action in workflow.actions}
|
||||
self.flows = [ActionFlow(**flow) for flow in workflow.flows]
|
||||
self.total_actions = len(self.actions)
|
||||
self.finished_actions = 0
|
||||
self.completed_actions = {
|
||||
action_id for action_id in (workflow.current_action or "").split(",")
|
||||
if action_id in self.actions
|
||||
}
|
||||
self.finished_actions = len(self.completed_actions)
|
||||
|
||||
self.success = True
|
||||
self.stopped = False
|
||||
self.errmsg = ""
|
||||
|
||||
# 工作流管理器
|
||||
@@ -66,6 +70,10 @@ class WorkflowExecutor:
|
||||
if action_id not in self.indegree:
|
||||
self.indegree[action_id] = 0
|
||||
|
||||
for action_id in self.completed_actions:
|
||||
for succ_id in self.adjacency.get(action_id, []):
|
||||
self.indegree[succ_id] -= 1
|
||||
|
||||
# 初始上下文
|
||||
if workflow.current_action and workflow.context:
|
||||
logger.info(f"工作流已执行动作:{workflow.current_action}")
|
||||
@@ -80,47 +88,68 @@ class WorkflowExecutor:
|
||||
global_vars.workflow_resume(self.workflow.id)
|
||||
# 初始化队列,添加入度为0的节点
|
||||
for action_id in self.actions:
|
||||
if self.indegree[action_id] == 0:
|
||||
if action_id not in self.completed_actions and self.indegree[action_id] == 0:
|
||||
self.queue.append(action_id)
|
||||
|
||||
def execute(self):
|
||||
def execute(self) -> None:
|
||||
"""
|
||||
执行工作流
|
||||
"""
|
||||
while True:
|
||||
with self.lock:
|
||||
# 退出条件:队列为空且无运行任务
|
||||
if not self.queue and self.running_tasks == 0:
|
||||
break
|
||||
# 退出条件:出现了错误
|
||||
if not self.success:
|
||||
break
|
||||
if not self.queue:
|
||||
try:
|
||||
while True:
|
||||
should_sleep = False
|
||||
with self.lock:
|
||||
if global_vars.is_workflow_stopped(self.workflow.id):
|
||||
self.success = False
|
||||
self.stopped = True
|
||||
self.errmsg = "工作流已停止"
|
||||
if self.running_tasks == 0:
|
||||
break
|
||||
should_sleep = True
|
||||
# 退出条件:队列为空且无运行任务
|
||||
elif not self.queue and self.running_tasks == 0:
|
||||
break
|
||||
# 出错后不再调度新节点,但等待已提交节点完成,避免后台线程继续写状态。
|
||||
if not self.success:
|
||||
if self.running_tasks == 0:
|
||||
break
|
||||
should_sleep = True
|
||||
elif not self.queue:
|
||||
should_sleep = True
|
||||
else:
|
||||
# 取出队首节点
|
||||
node_id = self.queue.popleft()
|
||||
# 标记任务开始
|
||||
self.running_tasks += 1
|
||||
|
||||
if should_sleep:
|
||||
sleep(0.1)
|
||||
continue
|
||||
# 取出队首节点
|
||||
node_id = self.queue.popleft()
|
||||
# 标记任务开始
|
||||
self.running_tasks += 1
|
||||
|
||||
# 已停机
|
||||
if global_vars.is_workflow_stopped(self.workflow.id):
|
||||
global_vars.workflow_resume(self.workflow.id)
|
||||
break
|
||||
# 已停机
|
||||
if global_vars.is_workflow_stopped(self.workflow.id):
|
||||
with self.lock:
|
||||
self.success = False
|
||||
self.stopped = True
|
||||
self.errmsg = "工作流已停止"
|
||||
self.running_tasks -= 1
|
||||
break
|
||||
|
||||
# 已执行的跳过
|
||||
if (self.workflow.current_action
|
||||
and node_id in self.workflow.current_action.split(',')):
|
||||
continue
|
||||
# 已执行的跳过,并继续释放后继节点。
|
||||
if node_id in self.completed_actions:
|
||||
self.on_node_skipped(node_id)
|
||||
continue
|
||||
|
||||
# 提交任务到线程池
|
||||
future = self.executor.submit(
|
||||
self.execute_node,
|
||||
self.workflow.id,
|
||||
node_id,
|
||||
self.context
|
||||
)
|
||||
future.add_done_callback(self.on_node_complete)
|
||||
# 提交任务到线程池,每个节点使用上下文快照,避免并行节点互相修改同一个对象。
|
||||
future = self.executor.submit(
|
||||
self.execute_node,
|
||||
self.workflow.id,
|
||||
node_id,
|
||||
copy.deepcopy(self.context)
|
||||
)
|
||||
future.add_done_callback(self.on_node_complete)
|
||||
finally:
|
||||
self.executor.shutdown(wait=True, cancel_futures=True)
|
||||
|
||||
def execute_node(self, workflow_id: int, node_id: int,
|
||||
context: ActionContext) -> Tuple[Action, bool, str, ActionContext]:
|
||||
@@ -135,31 +164,38 @@ class WorkflowExecutor:
|
||||
"""
|
||||
节点完成回调:更新上下文、处理后继节点
|
||||
"""
|
||||
action, state, message, result_ctx = future.result()
|
||||
|
||||
try:
|
||||
self.finished_actions += 1
|
||||
# 更新当前进度
|
||||
self.context.progress = round(self.finished_actions / self.total_actions) * 100
|
||||
action, state, message, result_ctx = future.result()
|
||||
with self.lock:
|
||||
if global_vars.is_workflow_stopped(self.workflow.id):
|
||||
self.success = False
|
||||
self.stopped = True
|
||||
self.errmsg = "工作流已停止"
|
||||
return
|
||||
self.finished_actions += 1
|
||||
# 更新当前进度
|
||||
self.context.progress = round(self.finished_actions / self.total_actions * 100) if self.total_actions else 100
|
||||
|
||||
# 补充执行历史
|
||||
self.context.execute_history.append(
|
||||
ActionExecution(
|
||||
action=action.name,
|
||||
result=state,
|
||||
message=message
|
||||
# 补充执行历史
|
||||
self.context.execute_history.append(
|
||||
ActionExecution(
|
||||
action=action.name,
|
||||
result=state,
|
||||
message=message
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# 节点执行失败
|
||||
if not state:
|
||||
self.success = False
|
||||
self.errmsg = f"{action.name} 失败"
|
||||
with self.lock:
|
||||
self.success = False
|
||||
self.errmsg = f"{action.name} 失败"
|
||||
return
|
||||
|
||||
with self.lock:
|
||||
# 更新主上下文
|
||||
self.merge_context(result_ctx)
|
||||
self.completed_actions.add(action.id)
|
||||
# 回调
|
||||
if self.step_callback:
|
||||
self.step_callback(action, self.context)
|
||||
@@ -171,17 +207,51 @@ class WorkflowExecutor:
|
||||
self.indegree[succ_id] -= 1
|
||||
if self.indegree[succ_id] == 0:
|
||||
self.queue.append(succ_id)
|
||||
except Exception as err:
|
||||
logger.error(f"工作流节点执行回调失败: {str(err)}")
|
||||
with self.lock:
|
||||
self.success = False
|
||||
self.errmsg = str(err)
|
||||
finally:
|
||||
# 标记任务完成
|
||||
with self.lock:
|
||||
self.running_tasks -= 1
|
||||
|
||||
def merge_context(self, context: ActionContext):
|
||||
def on_node_skipped(self, node_id: str) -> None:
|
||||
"""
|
||||
跳过已完成节点,并释放其后继节点。
|
||||
"""
|
||||
with self.lock:
|
||||
for succ_id in self.adjacency.get(node_id, []):
|
||||
self.indegree[succ_id] -= 1
|
||||
if succ_id not in self.completed_actions and self.indegree[succ_id] == 0:
|
||||
self.queue.append(succ_id)
|
||||
self.running_tasks -= 1
|
||||
|
||||
def merge_context(self, context: ActionContext) -> None:
|
||||
"""
|
||||
合并上下文
|
||||
"""
|
||||
for key, value in context.model_dump().items():
|
||||
if not getattr(self.context, key, None):
|
||||
if not context:
|
||||
return
|
||||
for key in context.__class__.model_fields:
|
||||
value = getattr(context, key, None)
|
||||
if key in ("execute_history", "progress") or value in (None, "", [], {}):
|
||||
continue
|
||||
current_value = getattr(self.context, key, None)
|
||||
if isinstance(value, list):
|
||||
if current_value is None:
|
||||
setattr(self.context, key, value)
|
||||
continue
|
||||
for item in value:
|
||||
if item not in current_value:
|
||||
current_value.append(item)
|
||||
elif isinstance(value, dict):
|
||||
if not current_value:
|
||||
setattr(self.context, key, value)
|
||||
else:
|
||||
current_value.update(value)
|
||||
elif not current_value:
|
||||
setattr(self.context, key, value)
|
||||
|
||||
|
||||
@@ -217,7 +287,7 @@ class WorkflowChain(ChainBase):
|
||||
serialized_data = pickle.dumps(context)
|
||||
# 使用Base64编码字节流
|
||||
encoded_data = base64.b64encode(serialized_data).decode('utf-8')
|
||||
workflowoper.step(workflow_id, action_id=action.id, context={
|
||||
WorkflowOper().step(workflow_id, action_id=action.id, context={
|
||||
"content": encoded_data
|
||||
})
|
||||
|
||||
@@ -244,6 +314,10 @@ class WorkflowChain(ChainBase):
|
||||
executor = WorkflowExecutor(workflow, step_callback=save_step)
|
||||
executor.execute()
|
||||
|
||||
if executor.stopped:
|
||||
logger.info(f"工作流 {workflow.name} 已停止")
|
||||
return False, executor.errmsg
|
||||
|
||||
if not executor.success:
|
||||
logger.info(f"工作流 {workflow.name} 执行失败:{executor.errmsg}")
|
||||
workflowoper.fail(workflow_id, result=executor.errmsg)
|
||||
|
||||
@@ -41,7 +41,7 @@ class Workflow(Base):
|
||||
# 执行上下文
|
||||
context = Column(JSON, default=dict)
|
||||
# 创建时间
|
||||
add_time = Column(String, default=datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
|
||||
add_time = Column(String, default=lambda: datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
|
||||
# 最后执行时间
|
||||
last_time = Column(String)
|
||||
|
||||
@@ -79,7 +79,7 @@ class Workflow(Base):
|
||||
and_(
|
||||
or_(
|
||||
cls.trigger_type == 'timer',
|
||||
not cls.trigger_type
|
||||
cls.trigger_type.is_(None)
|
||||
),
|
||||
cls.state != 'P'
|
||||
)
|
||||
@@ -93,7 +93,7 @@ class Workflow(Base):
|
||||
and_(
|
||||
or_(
|
||||
cls.trigger_type == 'timer',
|
||||
not cls.trigger_type
|
||||
cls.trigger_type.is_(None)
|
||||
),
|
||||
cls.state != 'P'
|
||||
)
|
||||
@@ -217,6 +217,7 @@ class Workflow(Base):
|
||||
"state": 'W',
|
||||
"result": None,
|
||||
"current_action": None,
|
||||
"context": {},
|
||||
"run_count": 0 if reset_count else cls.run_count,
|
||||
})
|
||||
return True
|
||||
@@ -229,6 +230,7 @@ class Workflow(Base):
|
||||
state='W',
|
||||
result=None,
|
||||
current_action=None,
|
||||
context={},
|
||||
run_count=0 if reset_count else cls.run_count,
|
||||
))
|
||||
return True
|
||||
@@ -236,8 +238,14 @@ class Workflow(Base):
|
||||
@classmethod
|
||||
@db_update
|
||||
def update_current_action(cls, db, wid: int, action_id: str, context: dict):
|
||||
workflow = db.query(cls).filter(cls.id == wid).first()
|
||||
current_actions = []
|
||||
if workflow and workflow.current_action:
|
||||
current_actions = [item for item in workflow.current_action.split(",") if item]
|
||||
if action_id not in current_actions:
|
||||
current_actions.append(action_id)
|
||||
db.query(cls).filter(cls.id == wid).update({
|
||||
"current_action": cls.current_action + f",{action_id}" if cls.current_action else action_id,
|
||||
"current_action": ",".join(current_actions),
|
||||
"context": context
|
||||
})
|
||||
return True
|
||||
@@ -249,7 +257,10 @@ class Workflow(Base):
|
||||
# 先获取当前current_action
|
||||
result = await db.execute(select(cls.current_action).where(cls.id == wid))
|
||||
current_action = result.scalar()
|
||||
new_current_action = current_action + f",{action_id}" if current_action else action_id
|
||||
current_actions = [item for item in (current_action or "").split(",") if item]
|
||||
if action_id not in current_actions:
|
||||
current_actions.append(action_id)
|
||||
new_current_action = ",".join(current_actions)
|
||||
|
||||
await db.execute(update(cls).where(cls.id == wid).values(
|
||||
current_action=new_current_action,
|
||||
|
||||
@@ -19,13 +19,13 @@ class Workflow(BaseModel):
|
||||
timer: Optional[str] = Field(default=None, description="定时器")
|
||||
trigger_type: Optional[str] = Field(default='timer', description="触发类型:timer-定时触发 event-事件触发 manual-手动触发")
|
||||
event_type: Optional[str] = Field(default=None, description="事件类型(当trigger_type为event时使用)")
|
||||
event_conditions: Optional[dict] = Field(default={}, description="事件条件(JSON格式,用于过滤事件)")
|
||||
event_conditions: Optional[dict] = Field(default_factory=dict, description="事件条件(JSON格式,用于过滤事件)")
|
||||
state: Optional[str] = Field(default=None, description="状态")
|
||||
current_action: Optional[str] = Field(default=None, description="已执行动作")
|
||||
result: Optional[str] = Field(default=None, description="任务执行结果")
|
||||
run_count: Optional[int] = Field(default=0, description="已执行次数")
|
||||
actions: Optional[list] = Field(default=[], description="任务列表")
|
||||
flows: Optional[list] = Field(default=[], description="任务流")
|
||||
actions: Optional[list] = Field(default_factory=list, description="任务列表")
|
||||
flows: Optional[list] = Field(default_factory=list, description="任务流")
|
||||
add_time: Optional[str] = Field(default=None, description="创建时间")
|
||||
last_time: Optional[str] = Field(default=None, description="最后执行时间")
|
||||
|
||||
@@ -48,8 +48,8 @@ class Action(BaseModel):
|
||||
type: Optional[str] = Field(default=None, description="动作类型 (类名)")
|
||||
name: Optional[str] = Field(default=None, description="动作名称")
|
||||
description: Optional[str] = Field(default=None, description="动作描述")
|
||||
position: Optional[dict] = Field(default={}, description="位置")
|
||||
data: Optional[dict] = Field(default={}, description="参数")
|
||||
position: Optional[dict] = Field(default_factory=dict, description="位置")
|
||||
data: Optional[dict] = Field(default_factory=dict, description="参数")
|
||||
|
||||
|
||||
class ActionExecution(BaseModel):
|
||||
@@ -66,13 +66,13 @@ class ActionContext(BaseModel):
|
||||
动作基础上下文,各动作通用数据
|
||||
"""
|
||||
content: Optional[str] = Field(default=None, description="文本类内容")
|
||||
torrents: Optional[List[Context]] = Field(default=[], description="资源列表")
|
||||
medias: Optional[List[MediaInfo]] = Field(default=[], description="媒体列表")
|
||||
fileitems: Optional[List[FileItem]] = Field(default=[], description="文件列表")
|
||||
downloads: Optional[List[DownloadTask]] = Field(default=[], description="下载任务列表")
|
||||
sites: Optional[List[Site]] = Field(default=[], description="站点列表")
|
||||
subscribes: Optional[List[Subscribe]] = Field(default=[], description="订阅列表")
|
||||
execute_history: Optional[List[ActionExecution]] = Field(default=[], description="执行历史")
|
||||
torrents: Optional[List[Context]] = Field(default_factory=list, description="资源列表")
|
||||
medias: Optional[List[MediaInfo]] = Field(default_factory=list, description="媒体列表")
|
||||
fileitems: Optional[List[FileItem]] = Field(default_factory=list, description="文件列表")
|
||||
downloads: Optional[List[DownloadTask]] = Field(default_factory=list, description="下载任务列表")
|
||||
sites: Optional[List[Site]] = Field(default_factory=list, description="站点列表")
|
||||
subscribes: Optional[List[Subscribe]] = Field(default_factory=list, description="订阅列表")
|
||||
execute_history: Optional[List[ActionExecution]] = Field(default_factory=list, description="执行历史")
|
||||
progress: Optional[int] = Field(default=0, description="执行进度(%)")
|
||||
|
||||
|
||||
|
||||
@@ -63,9 +63,18 @@ class WorkFlowManager(metaclass=Singleton):
|
||||
"""
|
||||
停止
|
||||
"""
|
||||
for event_type_str in list(self._event_workflows.keys()):
|
||||
self.remove_workflow_event(event_type_str=event_type_str)
|
||||
self._actions = {}
|
||||
self._event_workflows = {}
|
||||
|
||||
def execute(self, workflow_id: int, action: Action,
|
||||
context: ActionContext = None) -> Tuple[bool, str, ActionContext]:
|
||||
"""
|
||||
执行工作流动作
|
||||
"""
|
||||
return self.excute(workflow_id=workflow_id, action=action, context=context)
|
||||
|
||||
def excute(self, workflow_id: int, action: Action,
|
||||
context: ActionContext = None) -> Tuple[bool, str, ActionContext]:
|
||||
"""
|
||||
@@ -126,8 +135,8 @@ class WorkFlowManager(metaclass=Singleton):
|
||||
"""
|
||||
更新工作流事件触发器
|
||||
"""
|
||||
# 确保先移除旧的事件监听器
|
||||
self.remove_workflow_event(workflow_id=workflow.id, event_type_str=workflow.event_type)
|
||||
# 工作流可能切换触发事件,先按工作流ID从所有事件映射中移除。
|
||||
self.remove_workflow_event(workflow_id=workflow.id)
|
||||
# 如果工作流是事件触发类型且未被禁用
|
||||
if workflow.trigger_type == "event" and workflow.state != 'P':
|
||||
# 注册事件触发器
|
||||
@@ -154,41 +163,46 @@ class WorkFlowManager(metaclass=Singleton):
|
||||
"""
|
||||
注册工作流事件触发器
|
||||
"""
|
||||
if not event_type_str:
|
||||
return
|
||||
try:
|
||||
event_type = EventType(event_type_str)
|
||||
except ValueError:
|
||||
logger.error(f"无效的事件类型: {event_type_str}")
|
||||
return
|
||||
if event_type in EventType:
|
||||
# 确保先移除旧的事件监听器
|
||||
self.remove_workflow_event(workflow_id, event_type.value)
|
||||
with self._lock:
|
||||
# 添加新的事件监听器
|
||||
eventmanager.add_event_listener(event_type, self._handle_event)
|
||||
# 记录工作流事件触发器
|
||||
if event_type.value not in self._event_workflows:
|
||||
self._event_workflows[event_type.value] = []
|
||||
self._event_workflows[event_type.value].append(workflow_id)
|
||||
eventmanager.add_event_listener(event_type, self._handle_event)
|
||||
# 记录工作流事件触发器
|
||||
if workflow_id not in self._event_workflows[event_type.value]:
|
||||
self._event_workflows[event_type.value].append(workflow_id)
|
||||
logger.info(f"已注册工作流 {workflow_id} 事件触发器: {event_type.value}")
|
||||
|
||||
def remove_workflow_event(self, workflow_id: int, event_type_str: str):
|
||||
def remove_workflow_event(self, workflow_id: Optional[int] = None, event_type_str: Optional[str] = None):
|
||||
"""
|
||||
移除工作流事件触发器
|
||||
"""
|
||||
try:
|
||||
event_type = EventType(event_type_str)
|
||||
except ValueError:
|
||||
logger.error(f"无效的事件类型: {event_type_str}")
|
||||
return
|
||||
if event_type in EventType:
|
||||
event_type_values = [event_type_str] if event_type_str else list(self._event_workflows.keys())
|
||||
for event_type_value in event_type_values:
|
||||
try:
|
||||
event_type = EventType(event_type_value)
|
||||
except ValueError:
|
||||
logger.error(f"无效的事件类型: {event_type_value}")
|
||||
continue
|
||||
with self._lock:
|
||||
eventmanager.remove_event_listener(event_type, self._handle_event)
|
||||
if event_type.value in self._event_workflows:
|
||||
if workflow_id in self._event_workflows[event_type.value]:
|
||||
self._event_workflows[event_type.value].remove(workflow_id)
|
||||
if not self._event_workflows[event_type.value]:
|
||||
del self._event_workflows[event_type.value]
|
||||
logger.info(f"已移除工作流 {workflow_id} 事件触发器")
|
||||
workflow_ids = self._event_workflows.get(event_type.value)
|
||||
if not workflow_ids:
|
||||
continue
|
||||
if workflow_id is None:
|
||||
workflow_ids.clear()
|
||||
elif workflow_id in workflow_ids:
|
||||
workflow_ids.remove(workflow_id)
|
||||
if not workflow_ids:
|
||||
self._event_workflows.pop(event_type.value, None)
|
||||
eventmanager.remove_event_listener(event_type, self._handle_event)
|
||||
logger.info(f"已移除工作流 {workflow_id or ''} 事件触发器")
|
||||
|
||||
def _handle_event(self, event: Event):
|
||||
"""
|
||||
|
||||
@@ -26,6 +26,8 @@ class BaseAction(ABC):
|
||||
|
||||
def __init__(self, action_id: str):
|
||||
self._action_id = action_id
|
||||
self._done_flag = False
|
||||
self._message = ""
|
||||
self.systemconfigoper = SystemConfigOper()
|
||||
|
||||
@classmethod
|
||||
@@ -92,9 +94,12 @@ class BaseAction(ABC):
|
||||
workflow_cache = self.systemconfigoper.get(workflow_key) or {}
|
||||
action_cache = workflow_cache.get(self._action_id) or []
|
||||
if isinstance(data, list):
|
||||
action_cache.extend(data)
|
||||
for item in data:
|
||||
if item not in action_cache:
|
||||
action_cache.append(item)
|
||||
else:
|
||||
action_cache.append(data)
|
||||
if data not in action_cache:
|
||||
action_cache.append(data)
|
||||
workflow_cache[self._action_id] = action_cache
|
||||
self.systemconfigoper.set(workflow_key, workflow_cache)
|
||||
|
||||
|
||||
@@ -43,12 +43,19 @@ class FetchDownloadsAction(BaseAction):
|
||||
"""
|
||||
更新downloads中的下载任务状态
|
||||
"""
|
||||
__all_complete = False
|
||||
self._downloads = context.downloads or []
|
||||
if not self._downloads:
|
||||
self.job_done("无下载任务")
|
||||
return context
|
||||
|
||||
for download in self._downloads:
|
||||
if global_vars.is_workflow_stopped(workflow_id):
|
||||
break
|
||||
logger.info(f"获取下载任务 {download.download_id} 状态 ...")
|
||||
torrents = ActionChain().list_torrents(hashs=[download.download_id])
|
||||
torrents = ActionChain().list_torrents(
|
||||
hashs=[download.download_id],
|
||||
downloader=download.downloader,
|
||||
)
|
||||
if not torrents:
|
||||
download.completed = True
|
||||
continue
|
||||
@@ -61,5 +68,5 @@ class FetchDownloadsAction(BaseAction):
|
||||
logger.info(f"下载任务 {download.download_id} 未完成")
|
||||
download.completed = False
|
||||
if all([d.completed for d in self._downloads]):
|
||||
self.job_done()
|
||||
self.job_done("下载任务已全部完成")
|
||||
return context
|
||||
|
||||
@@ -61,18 +61,18 @@ class ScrapeFileAction(BaseAction):
|
||||
logger.info(f"{fileitem.path} 已刮削过,跳过")
|
||||
continue
|
||||
mediachain = MediaChain()
|
||||
context = mediachain.recognize_by_path(
|
||||
media_context = mediachain.recognize_by_path(
|
||||
fileitem.path,
|
||||
obtain_images=True,
|
||||
)
|
||||
if not context or not context.media_info:
|
||||
if not media_context or not media_context.media_info:
|
||||
_failed_count += 1
|
||||
logger.info(f"{fileitem.path} 未识别到媒体信息,无法刮削")
|
||||
continue
|
||||
mediachain.scrape_metadata(
|
||||
fileitem=fileitem,
|
||||
meta=context.meta_info,
|
||||
mediainfo=context.media_info
|
||||
meta=media_context.meta_info,
|
||||
mediainfo=media_context.media_info
|
||||
)
|
||||
self._scraped_files.append(fileitem)
|
||||
# 保存缓存
|
||||
|
||||
80
tests/test_workflow_actions.py
Normal file
80
tests/test_workflow_actions.py
Normal file
@@ -0,0 +1,80 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from app.schemas import ActionContext, DownloadTask, FileItem
|
||||
from app.workflow.actions import fetch_downloads as fetch_downloads_module
|
||||
from app.workflow.actions import scrape_file as scrape_file_module
|
||||
from app.workflow.actions.fetch_downloads import FetchDownloadsAction
|
||||
from app.workflow.actions.scrape_file import ScrapeFileAction
|
||||
|
||||
|
||||
def test_fetch_downloads_updates_context_downloads(monkeypatch):
|
||||
"""获取下载任务动作应更新上游上下文中的下载任务。"""
|
||||
calls = []
|
||||
|
||||
class FakeActionChain:
|
||||
"""模拟下载器查询链。"""
|
||||
|
||||
def list_torrents(self, hashs=None, downloader=None, **kwargs):
|
||||
calls.append((hashs, downloader))
|
||||
return [SimpleNamespace(path="/downloads/movie.mkv", progress=100)]
|
||||
|
||||
monkeypatch.setattr(fetch_downloads_module, "ActionChain", FakeActionChain)
|
||||
monkeypatch.setattr(fetch_downloads_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
|
||||
|
||||
context = ActionContext(
|
||||
downloads=[
|
||||
DownloadTask(download_id="hash-1", downloader="qbittorrent"),
|
||||
]
|
||||
)
|
||||
|
||||
result = FetchDownloadsAction("fetch-downloads").execute(
|
||||
workflow_id=1,
|
||||
params={},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert calls == [(["hash-1"], "qbittorrent")]
|
||||
assert result.downloads[0].completed is True
|
||||
assert result.downloads[0].path == "/downloads/movie.mkv"
|
||||
|
||||
|
||||
def test_scrape_file_keeps_workflow_action_context(monkeypatch):
|
||||
"""刮削文件动作不应将工作流上下文替换为媒体识别上下文。"""
|
||||
scraped = []
|
||||
|
||||
class FakeStorageChain:
|
||||
"""模拟存储链。"""
|
||||
|
||||
def exists(self, fileitem):
|
||||
return True
|
||||
|
||||
class FakeMediaChain:
|
||||
"""模拟媒体识别和刮削链。"""
|
||||
|
||||
def recognize_by_path(self, path, obtain_images=False):
|
||||
return SimpleNamespace(meta_info="meta", media_info="media")
|
||||
|
||||
def scrape_metadata(self, fileitem, meta=None, mediainfo=None):
|
||||
scraped.append((fileitem.path, meta, mediainfo))
|
||||
|
||||
monkeypatch.setattr(scrape_file_module, "StorageChain", FakeStorageChain)
|
||||
monkeypatch.setattr(scrape_file_module, "MediaChain", FakeMediaChain)
|
||||
monkeypatch.setattr(scrape_file_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
|
||||
monkeypatch.setattr(ScrapeFileAction, "check_cache", lambda self, workflow_id, key: False)
|
||||
monkeypatch.setattr(ScrapeFileAction, "save_cache", lambda self, workflow_id, data: None)
|
||||
|
||||
context = ActionContext(
|
||||
fileitems=[
|
||||
FileItem(path="/library/movie.mkv", storage="local", type="file"),
|
||||
]
|
||||
)
|
||||
|
||||
result = ScrapeFileAction("scrape-file").execute(
|
||||
workflow_id=1,
|
||||
params={},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result is context
|
||||
assert result.fileitems[0].path == "/library/movie.mkv"
|
||||
assert scraped == [("/library/movie.mkv", "meta", "media")]
|
||||
153
tests/test_workflow_execution.py
Normal file
153
tests/test_workflow_execution.py
Normal file
@@ -0,0 +1,153 @@
|
||||
import base64
|
||||
import pickle
|
||||
import threading
|
||||
from types import SimpleNamespace
|
||||
|
||||
from app.chain import workflow as workflow_module
|
||||
from app.schemas import ActionContext
|
||||
from app.schemas.types import EventType
|
||||
from app import workflow as workflow_package
|
||||
|
||||
|
||||
def _build_workflow(current_action=None, context=None):
|
||||
"""构造最小工作流对象。"""
|
||||
return SimpleNamespace(
|
||||
id=1,
|
||||
name="测试工作流",
|
||||
actions=[
|
||||
{"id": "A", "type": "FakeAction", "name": "动作A", "data": {}},
|
||||
{"id": "B", "type": "FakeAction", "name": "动作B", "data": {}},
|
||||
],
|
||||
flows=[
|
||||
{"id": "flow-1", "source": "A", "target": "B", "animated": True},
|
||||
],
|
||||
current_action=current_action,
|
||||
context=context,
|
||||
)
|
||||
|
||||
|
||||
def _encoded_context(context: ActionContext) -> dict:
|
||||
"""编码工作流恢复上下文。"""
|
||||
return {
|
||||
"content": base64.b64encode(pickle.dumps(context)).decode("utf-8"),
|
||||
}
|
||||
|
||||
|
||||
class _FakeWorkflowManager:
|
||||
"""记录执行动作的工作流管理器。"""
|
||||
|
||||
def __init__(self, calls):
|
||||
self.calls = calls
|
||||
|
||||
def excute(self, workflow_id, action, context=None):
|
||||
self.calls.append(action.id)
|
||||
return True, f"{action.name}完成", context or ActionContext()
|
||||
|
||||
|
||||
def test_workflow_executor_resumes_downstream_nodes(monkeypatch):
|
||||
"""恢复执行时应释放已完成节点的后继节点。"""
|
||||
calls = []
|
||||
fake_manager = _FakeWorkflowManager(calls)
|
||||
workflow = _build_workflow(
|
||||
current_action="A",
|
||||
context=_encoded_context(ActionContext()),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
|
||||
|
||||
executor = workflow_module.WorkflowExecutor(workflow)
|
||||
executor.execute()
|
||||
|
||||
assert calls == ["B"]
|
||||
assert executor.success is True
|
||||
assert executor.context.progress == 100
|
||||
|
||||
|
||||
def test_workflow_executor_reports_incremental_progress(monkeypatch):
|
||||
"""顺序工作流的中间进度应按已完成比例计算。"""
|
||||
calls = []
|
||||
progresses = []
|
||||
fake_manager = _FakeWorkflowManager(calls)
|
||||
|
||||
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
|
||||
|
||||
executor = workflow_module.WorkflowExecutor(
|
||||
_build_workflow(),
|
||||
step_callback=lambda action, context: progresses.append(context.progress),
|
||||
)
|
||||
executor.execute()
|
||||
|
||||
assert calls == ["A", "B"]
|
||||
assert progresses == [50, 100]
|
||||
|
||||
|
||||
def test_workflow_executor_stop_is_not_success(monkeypatch):
|
||||
"""停止信号不应被执行器汇报为成功完成。"""
|
||||
calls = []
|
||||
fake_manager = _FakeWorkflowManager(calls)
|
||||
|
||||
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: True)
|
||||
|
||||
executor = workflow_module.WorkflowExecutor(_build_workflow())
|
||||
executor.execute()
|
||||
|
||||
assert calls == []
|
||||
assert executor.stopped is True
|
||||
assert executor.success is False
|
||||
assert executor.errmsg == "工作流已停止"
|
||||
|
||||
|
||||
def test_workflow_context_merge_preserves_runtime_objects():
|
||||
"""合并上下文时应保留运行时对象,而不是转成字典。"""
|
||||
executor = object.__new__(workflow_module.WorkflowExecutor)
|
||||
executor.context = ActionContext()
|
||||
runtime_torrent = SimpleNamespace(title="runtime torrent")
|
||||
result_context = ActionContext()
|
||||
result_context.torrents.append(runtime_torrent)
|
||||
|
||||
executor.merge_context(result_context)
|
||||
|
||||
assert executor.context.torrents[0] is runtime_torrent
|
||||
|
||||
|
||||
class _FakeEventManager:
|
||||
"""记录事件监听器注册和移除次数。"""
|
||||
|
||||
def __init__(self):
|
||||
self.added = []
|
||||
self.removed = []
|
||||
|
||||
def add_event_listener(self, event_type, handler):
|
||||
self.added.append(event_type)
|
||||
|
||||
def remove_event_listener(self, event_type, handler):
|
||||
self.removed.append(event_type)
|
||||
|
||||
|
||||
def test_workflow_event_listener_keeps_shared_handler_until_last_workflow(monkeypatch):
|
||||
"""同一事件下移除单个工作流时不应断开其他工作流监听。"""
|
||||
fake_eventmanager = _FakeEventManager()
|
||||
manager = object.__new__(workflow_package.WorkFlowManager)
|
||||
manager._lock = threading.Lock()
|
||||
manager._event_workflows = {}
|
||||
|
||||
monkeypatch.setattr(workflow_package, "eventmanager", fake_eventmanager)
|
||||
|
||||
manager.register_workflow_event(1, EventType.DownloadAdded.value)
|
||||
manager.register_workflow_event(2, EventType.DownloadAdded.value)
|
||||
manager.remove_workflow_event(1, EventType.DownloadAdded.value)
|
||||
|
||||
assert fake_eventmanager.added == [EventType.DownloadAdded]
|
||||
assert fake_eventmanager.removed == []
|
||||
assert manager.get_event_workflows() == {EventType.DownloadAdded.value: [2]}
|
||||
|
||||
manager.remove_workflow_event(2, EventType.DownloadAdded.value)
|
||||
|
||||
assert fake_eventmanager.removed == [EventType.DownloadAdded]
|
||||
assert manager.get_event_workflows() == {}
|
||||
Reference in New Issue
Block a user