# Files
# archived-MoviePilot/app/workflow/actions/__init__.py
#
# 207 lines
# 6.4 KiB
# Python
# Raw Blame History
#
# This file contains ambiguous Unicode characters
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from abc import ABC, abstractmethod
from typing import Any, Union
from app.chain import ChainBase
from app.db.systemconfig_oper import SystemConfigOper
from app.schemas import ActionContext, ActionParams, ActionResult
class ActionChain(ChainBase):
    """Chain type for workflow actions; adds no members of its own to ChainBase."""
class BaseAction(ABC):
    """
    Base class for workflow actions.

    Concrete actions provide ``name``, ``description``, ``data``, ``success``
    and ``execute``. They may also override ``contract`` to declare their
    input and output fields.
    """

    # Action ID
    _action_id = None
    # Done flag
    _done_flag = False
    # Execution message
    _message = ""
    # Template of the per-workflow cache key (formatted with the workflow id)
    _cache_key = "WorkflowCache-%s"
    # Input/output contract of the action; concrete actions override it as
    # needed. Recognised keys: "inputs", "outputs", "concurrency_key".
    contract = {}

    def __init__(self, action_id: str):
        """
        Create an action bound to ``action_id`` with a fresh done/message state.
        """
        self._action_id = action_id
        self._done_flag = False
        self._message = ""
        self.systemconfigoper = SystemConfigOper()

    # NOTE(review): stacking @classmethod on top of @property was deprecated in
    # Python 3.11 and removed in 3.13. It is kept unchanged here because code
    # outside this block presumably relies on class-level access - confirm the
    # supported Python versions before changing it.
    @classmethod
    @property
    @abstractmethod
    def name(cls) -> str:  # noqa
        """
        Action name (class-level abstract property).
        """
        pass

    @classmethod
    @property
    @abstractmethod
    def description(cls) -> str:  # noqa
        """
        Action description (class-level abstract property).
        """
        pass

    @classmethod
    @property
    @abstractmethod
    def data(cls) -> dict:  # noqa
        """
        Action data (class-level abstract property).
        """
        pass

    @classmethod
    def get_contract(cls) -> dict:
        """
        Return the normalized input/output contract of the action.

        :return: dict with the keys ``inputs``, ``outputs``,
                 ``condition_fields`` (same content as ``outputs``) and
                 ``concurrency_key``
        """
        contract = getattr(cls, "contract", None) or {}
        input_fields = cls._build_contract_fields(contract.get("inputs") or [])
        output_fields = cls._build_contract_fields(contract.get("outputs") or [])
        return {
            "inputs": input_fields,
            "outputs": output_fields,
            # Equal to "outputs" but copied, so that a caller mutating one of
            # the two lists does not silently change the other.
            "condition_fields": [dict(field) for field in output_fields],
            "concurrency_key": contract.get("concurrency_key"),
        }

    @classmethod
    def _build_contract_fields(cls, fields: list) -> list:
        """
        Normalize the field declarations of an action contract.

        A plain string is treated as a field name. Entries that are not dicts
        or that have no ``name`` are dropped.

        :param fields: field declarations (strings or dicts)
        :return: list of dicts with the keys ``name``, ``label``, ``kind``,
                 ``merge`` and ``identity``
        """
        result = []
        for field in fields:
            if isinstance(field, str):
                field = {"name": field}
            if not isinstance(field, dict) or not field.get("name"):
                continue
            name = field["name"]
            result.append({
                "name": name,
                # The label falls back to the name, the kind to "scalar"
                "label": field.get("label") or name,
                "kind": field.get("kind") or "scalar",
                "merge": field.get("merge"),
                "identity": field.get("identity"),
            })
        return result

    @property
    def done(self) -> bool:
        """
        Whether the action has finished.
        """
        return self._done_flag

    @property
    @abstractmethod
    def success(self) -> bool:
        """
        Whether the action succeeded.
        """
        pass

    @property
    def message(self) -> Union[str, None]:
        """
        Execution message; ``None`` after ``job_done()`` was called without one.
        """
        return self._message

    def job_done(self, message: Union[str, None] = None):
        """
        Mark the action as finished.

        :param message: execution message to store (stored as is, also when None)
        """
        self._message = message
        self._done_flag = True

    def check_cache(self, workflow_id: int, key: str) -> bool:
        """
        Check whether ``key`` was already processed by this action.

        :param workflow_id: workflow the cache belongs to
        :param key: value to look up in this action's cache list
        :return: True when ``key`` is in the cached list
        """
        workflow_key = self._cache_key % workflow_id
        workflow_cache = self.systemconfigoper.get(workflow_key) or {}
        action_cache = workflow_cache.get(self._action_id) or []
        return key in action_cache

    def save_cache(self, workflow_id: int, data: Union[list, str]):
        """
        Add ``data`` to this action's cache list of the workflow.

        :param workflow_id: workflow the cache belongs to
        :param data: a single value, or a list of values; values that are
                     already cached are not added again
        """
        workflow_key = self._cache_key % workflow_id
        workflow_cache = self.systemconfigoper.get(workflow_key) or {}
        action_cache = workflow_cache.get(self._action_id) or []
        items = data if isinstance(data, list) else [data]
        for item in items:
            if item not in action_cache:
                action_cache.append(item)
        workflow_cache[self._action_id] = action_cache
        self.systemconfigoper.set(workflow_key, workflow_cache)

    @abstractmethod
    def execute(self, workflow_id: int, params: ActionParams, context: ActionContext) -> ActionContext:
        """
        Execute the action.
        """
        raise NotImplementedError

    def execute_with_inputs(self, workflow_id: int, params: ActionParams, inputs: dict,
                            runtime: dict, context: ActionContext) -> ActionResult:
        """
        Execute the action with explicit inputs and runtime information.

        The inputs and the runtime information are first written into
        ``context``. ``execute()`` is then called, and the outputs declared by
        the contract are read back from the context it returns.

        :return: ActionResult built from ``success``, ``message``, the
                 resulting context and the extracted outputs
        """
        self._apply_inputs_to_context(inputs=inputs, context=context)
        self._apply_runtime_to_context(runtime=runtime, context=context)
        result_context = self.execute(workflow_id, params, context)
        outputs = self._extract_outputs_from_context(result_context)
        return ActionResult(
            success=self.success,
            message=self.message,
            context=result_context,
            outputs=outputs
        )

    def _apply_inputs_to_context(self, inputs: dict, context: ActionContext) -> None:
        """
        Backfill explicit inputs into the legacy context fields, for
        compatibility with actions that still read from ``context``.

        Only fields declared as contract inputs are written. A value that is
        present in ``inputs`` is written even when it is None.
        """
        if not inputs:
            return
        # Sentinel that tells "key absent" apart from an explicit None value
        missing = object()
        for field in self.get_contract().get("inputs") or []:
            field_name = field["name"]
            value = inputs.get(field_name, missing)
            if value is missing:
                # Compatibility with legacy node input paths such as
                # outputs.A.torrents: the first key whose last dotted segment
                # equals the field name wins.
                value = next(
                    (
                        input_value
                        for input_key, input_value in inputs.items()
                        if isinstance(input_key, str)
                        and input_key.rpartition(".")[2] == field_name
                    ),
                    missing,
                )
            if value is not missing:
                setattr(context, field_name, value)

    @staticmethod
    def _apply_runtime_to_context(runtime: dict, context: ActionContext) -> None:
        """
        Write the runtime information into ``context.runtime_state`` so that
        the action and the execution state can read it.

        Nothing is changed when ``runtime`` is empty. The ``cancel_token``
        entry is never copied into the context.
        """
        if not runtime:
            return
        context.runtime_state = context.runtime_state or {}
        context.runtime_state["current_action_runtime"] = {
            key: value for key, value in runtime.items()
            if key != "cancel_token"
        }

    def _extract_outputs_from_context(self, context: ActionContext) -> dict[str, Any]:
        """
        Extract the outputs declared by the action contract from ``context``.

        Fields that are missing or equal to None, "", [] or {} are left out.
        Other falsy values such as 0 or False are kept.
        """
        outputs = {}
        for field in self.get_contract().get("outputs") or []:
            value = getattr(context, field["name"], None)
            if value not in (None, "", [], {}):
                outputs[field["name"]] = value
        return outputs