mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-05 07:26:48 +00:00
207 lines
6.4 KiB
Python
207 lines
6.4 KiB
Python
from abc import ABC, abstractmethod
|
||
from typing import Any, Union
|
||
|
||
from app.chain import ChainBase
|
||
from app.db.systemconfig_oper import SystemConfigOper
|
||
from app.schemas import ActionContext, ActionParams, ActionResult
|
||
|
||
|
||
class ActionChain(ChainBase):
    """
    Thin ChainBase subclass that adds no behavior of its own.

    NOTE(review): presumably exists so workflow actions can reach the
    ChainBase helpers through a dedicated chain type — confirm with callers.
    """
|
||
|
||
|
||
class BaseAction(ABC):
    """
    Base class for workflow actions.

    Subclasses supply the ``name`` / ``description`` / ``data`` metadata,
    implement ``execute`` and ``success``, and may declare an input/output
    ``contract`` that drives how inputs are written into the context and how
    outputs are extracted from it.
    """

    # Action ID; also used as the per-action key inside the workflow cache dict
    _action_id = None
    # Completion flag
    _done_flag = False
    # Execution message
    _message = ""
    # SystemConfig key template for a workflow's cache ("%s" is the workflow id)
    _cache_key = "WorkflowCache-%s"
    # Input/output contract, overridden by concrete actions as needed.
    # Shape read by get_contract: {"inputs": [...], "outputs": [...],
    # "concurrency_key": ...}; each field is a name string or a dict with "name".
    contract = {}

    def __init__(self, action_id: str):
        self._action_id = action_id
        self._done_flag = False
        self._message = ""
        self.systemconfigoper = SystemConfigOper()

    # NOTE(review): chaining @classmethod with @property only works as a class
    # property on Python 3.9-3.12 (deprecated in 3.11, removed in 3.13). Kept
    # unchanged because subclasses declare these members the same way.
    @classmethod
    @property
    @abstractmethod
    def name(cls) -> str:  # noqa
        """Action name; implemented by subclasses."""
        pass

    @classmethod
    @property
    @abstractmethod
    def description(cls) -> str:  # noqa
        """Action description; implemented by subclasses."""
        pass

    @classmethod
    @property
    @abstractmethod
    def data(cls) -> dict:  # noqa
        """Action metadata dict; implemented by subclasses."""
        pass

    @classmethod
    def get_contract(cls) -> dict:
        """
        Return the normalized input/output contract of the action.

        :return: dict with ``inputs`` and ``outputs`` (lists of normalized
                 field dicts), ``condition_fields`` (an independent copy of
                 ``outputs``) and ``concurrency_key`` (taken verbatim from
                 ``contract``, ``None`` when absent).
        """
        contract = getattr(cls, "contract", None) or {}
        input_fields = cls._build_contract_fields(contract.get("inputs") or [])
        output_fields = cls._build_contract_fields(contract.get("outputs") or [])
        return {
            "inputs": input_fields,
            "outputs": output_fields,
            # Same content as outputs, but copied so a consumer mutating one
            # list (or its field dicts) does not silently change the other.
            "condition_fields": [dict(field) for field in output_fields],
            "concurrency_key": contract.get("concurrency_key"),
        }

    @classmethod
    def _build_contract_fields(cls, fields: list) -> list:
        """
        Normalize contract field declarations.

        Each entry may be a bare name string or a dict with at least ``name``.
        Entries of any other shape, or without a name, are skipped. A single
        string or dict passed instead of a list is treated as a one-element
        list, so ``"torrents"`` is not iterated character by character.

        :param fields: list of field declarations (or a single declaration)
        :return: list of dicts with keys ``name``, ``label`` (defaults to the
                 name), ``kind`` (defaults to ``"scalar"``), ``merge`` and
                 ``identity`` (``None`` when not declared)
        """
        if isinstance(fields, (str, dict)):
            fields = [fields]
        result = []
        for field in fields or []:
            if isinstance(field, str):
                field = {"name": field}
            if not isinstance(field, dict) or not field.get("name"):
                continue
            result.append({
                "name": field["name"],
                "label": field.get("label") or field["name"],
                "kind": field.get("kind") or "scalar",
                "merge": field.get("merge"),
                "identity": field.get("identity"),
            })
        return result

    @property
    def done(self) -> bool:
        """
        Whether the action has been marked complete via ``job_done``.
        """
        return self._done_flag

    @property
    @abstractmethod
    def success(self) -> bool:
        """
        Whether the action succeeded; implemented by subclasses.
        """
        pass

    @property
    def message(self) -> str:
        """
        Execution message recorded by ``job_done`` (empty string until then).
        """
        return self._message

    def job_done(self, message: Union[str, None] = None):
        """
        Mark the action as complete.

        :param message: optional execution message; ``None`` is stored as
                        ``""`` so ``message`` always returns a ``str``.
        """
        self._message = "" if message is None else message
        self._done_flag = True

    def check_cache(self, workflow_id: int, key: str) -> bool:
        """
        Check whether ``key`` has already been recorded for this action.

        :param workflow_id: workflow whose SystemConfig cache entry is read
        :param key: item identifier to look up
        :return: True if ``key`` is in this action's cached list
        """
        workflow_key = self._cache_key % workflow_id
        workflow_cache = self.systemconfigoper.get(workflow_key) or {}
        action_cache = workflow_cache.get(self._action_id) or []
        return key in action_cache

    def save_cache(self, workflow_id: int, data: Union[list, str]):
        """
        Record one item, or a list of items, in this action's cache.

        Items already present are not appended again; the updated workflow
        cache dict is written back through ``systemconfigoper.set``.

        :param workflow_id: workflow whose SystemConfig cache entry is updated
        :param data: a single item or a list of items to record
        """
        workflow_key = self._cache_key % workflow_id
        workflow_cache = self.systemconfigoper.get(workflow_key) or {}
        action_cache = workflow_cache.get(self._action_id) or []
        items = data if isinstance(data, list) else [data]
        for item in items:
            if item not in action_cache:
                action_cache.append(item)
        workflow_cache[self._action_id] = action_cache
        self.systemconfigoper.set(workflow_key, workflow_cache)

    @abstractmethod
    def execute(self, workflow_id: int, params: ActionParams, context: ActionContext) -> ActionContext:
        """
        Execute the action and return the (possibly updated) context.
        """
        raise NotImplementedError

    def execute_with_inputs(self, workflow_id: int, params: ActionParams, inputs: dict,
                            runtime: dict, context: ActionContext) -> ActionResult:
        """
        Execute the action with explicit inputs and runtime information.

        Contract inputs are written onto the context, runtime info (minus
        ``cancel_token``) is stored in ``context.runtime_state``, ``execute``
        runs, and the contract outputs are extracted from the returned context.

        :return: ActionResult carrying success, message, context and outputs
        """
        self._apply_inputs_to_context(inputs=inputs, context=context)
        self._apply_runtime_to_context(runtime=runtime, context=context)
        result_context = self.execute(workflow_id, params, context)
        outputs = self._extract_outputs_from_context(result_context)
        return ActionResult(
            success=self.success,
            message=self.message,
            context=result_context,
            outputs=outputs
        )

    def _apply_inputs_to_context(self, inputs: dict, context: ActionContext) -> None:
        """
        Copy explicit inputs onto the matching context attributes.

        Keeps actions that still read from ``context`` working. For each
        contract input, an exact key in ``inputs`` wins. Otherwise the first
        key whose last dotted segment equals the field name is used (legacy
        node paths such as ``outputs.A.torrents``). Fields with no matching
        key are left untouched.
        """
        inputs = inputs or {}
        missing = object()
        for field in self.get_contract().get("inputs") or []:
            field_name = field["name"]
            value = inputs.get(field_name, missing)
            if value is missing:
                # Legacy node input path, e.g. outputs.A.torrents
                value = next(
                    (input_value for input_key, input_value in inputs.items()
                     if isinstance(input_key, str) and input_key.split(".")[-1] == field_name),
                    missing,
                )
            if value is not missing:
                setattr(context, field_name, value)

    @staticmethod
    def _apply_runtime_to_context(runtime: dict, context: ActionContext) -> None:
        """
        Store runtime information under
        ``context.runtime_state["current_action_runtime"]``.

        The ``cancel_token`` entry is excluded. Nothing is written when
        ``runtime`` is empty or ``None``.
        """
        if not runtime:
            return
        context.runtime_state = context.runtime_state or {}
        context.runtime_state["current_action_runtime"] = {
            key: value for key, value in runtime.items()
            if key != "cancel_token"
        }

    @staticmethod
    def _is_empty_output(value: Any) -> bool:
        """
        Return True for values that are not published as outputs: ``None``
        or an empty ``str`` / ``list`` / ``dict``.

        Uses explicit type checks instead of ``value in (None, "", [], {})``.
        That form compares arbitrary objects with ``==``, which raises for
        array-like values whose comparison result has no single truth value.
        """
        if value is None:
            return True
        return isinstance(value, (str, list, dict)) and len(value) == 0

    def _extract_outputs_from_context(self, context: ActionContext) -> dict[str, Any]:
        """
        Collect the contract's output fields from the context.

        :return: mapping of output field name to value; missing, ``None`` and
                 empty ``str`` / ``list`` / ``dict`` values are omitted
        """
        outputs = {}
        for field in self.get_contract().get("outputs") or []:
            value = getattr(context, field["name"], None)
            if not self._is_empty_output(value):
                outputs[field["name"]] = value
        return outputs