mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-13 23:16:46 +00:00
Compare commits
99 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
48b1ac28de | ||
|
|
6e329b17a9 | ||
|
|
6a492198a8 | ||
|
|
8bf9b6e7cb | ||
|
|
42e23ef564 | ||
|
|
c6806ee648 | ||
|
|
076fae696c | ||
|
|
ed294d3ea4 | ||
|
|
043be409d0 | ||
|
|
a5e7483870 | ||
|
|
365335be46 | ||
|
|
62543dd171 | ||
|
|
e2eef8ff21 | ||
|
|
3acf937d56 | ||
|
|
d572e523ba | ||
|
|
82113abe88 | ||
|
|
b7d121c58f | ||
|
|
6d5a85b144 | ||
|
|
78121917c6 | ||
|
|
a0913f0e32 | ||
|
|
e96e284715 | ||
|
|
c572a1b607 | ||
|
|
1845311f98 | ||
|
|
4f806db8b7 | ||
|
|
22858cc1e9 | ||
|
|
a0329a3eb0 | ||
|
|
b3e92088ee | ||
|
|
46db1c20f1 | ||
|
|
9d182e53b2 | ||
|
|
1205fc7fdb | ||
|
|
ff2826a448 | ||
|
|
ee750115ec | ||
|
|
0e13d22c97 | ||
|
|
8e7d040ac4 | ||
|
|
6755202958 | ||
|
|
8b7374a687 | ||
|
|
c17cca2365 | ||
|
|
8016a9539a | ||
|
|
e885fb15a0 | ||
|
|
c7f098771b | ||
|
|
fcd0908032 | ||
|
|
7ff1285084 | ||
|
|
b45b603b97 | ||
|
|
247208b8a9 | ||
|
|
182c46037b | ||
|
|
438d3210bc | ||
|
|
d523c7c916 | ||
|
|
09a19e94d5 | ||
|
|
3971c145df | ||
|
|
055117d83d | ||
|
|
c6baf43986 | ||
|
|
4ff16af3a7 | ||
|
|
17a1bd352b | ||
|
|
7421ca09cc | ||
|
|
9797e696e5 | ||
|
|
c36d6d8b2d | ||
|
|
3873786b99 | ||
|
|
76fdba7f09 | ||
|
|
72799e9638 | ||
|
|
2e77d03fe9 | ||
|
|
0c58eae5e7 | ||
|
|
b609567c38 | ||
|
|
7ecfa44fa0 | ||
|
|
a685b1dc3b | ||
|
|
63ce49a17c | ||
|
|
820fbe4076 | ||
|
|
efa05b7775 | ||
|
|
003781e903 | ||
|
|
ee71bafc96 | ||
|
|
bdd5f1231e | ||
|
|
6fee532c96 | ||
|
|
78aaad7b59 | ||
|
|
b128b0ede2 | ||
|
|
737d2f3bc6 | ||
|
|
179be53a65 | ||
|
|
1867f5e7c2 | ||
|
|
6662d24565 | ||
|
|
5880566a99 | ||
|
|
5d05b32711 | ||
|
|
fa2b720e92 | ||
|
|
d381238f83 | ||
|
|
751d627ead | ||
|
|
3e66a8de9b | ||
|
|
266052b12b | ||
|
|
803f4328f4 | ||
|
|
8e95568e11 | ||
|
|
ab09ee4819 | ||
|
|
41f94a172f | ||
|
|
566e597994 | ||
|
|
765fb9c05f | ||
|
|
b6720a19f7 | ||
|
|
3b130651c4 | ||
|
|
3f6c35dabe | ||
|
|
db2a952bca | ||
|
|
0ea9770bc3 | ||
|
|
0b20956c90 | ||
|
|
9f73b47d54 | ||
|
|
ce9c99af71 | ||
|
|
784024fb5d |
@@ -40,10 +40,11 @@ git clone https://github.com/jxxghp/MoviePilot
|
||||
```shell
|
||||
git clone https://github.com/jxxghp/MoviePilot-Resources
|
||||
```
|
||||
- 安装后端依赖,设置`app`为源代码根目录,运行 `main.py` 启动后端服务,默认监听端口:`3001`,API文档地址:`http://localhost:3001/docs`
|
||||
- 安装后端依赖,运行 `main.py` 启动后端服务,默认监听端口:`3001`,API文档地址:`http://localhost:3001/docs`
|
||||
```shell
|
||||
cd MoviePilot
|
||||
pip install -r requirements.txt
|
||||
python3 main.py
|
||||
python3 -m app.main
|
||||
```
|
||||
- 克隆前端项目 [MoviePilot-Frontend](https://github.com/jxxghp/MoviePilot-Frontend)
|
||||
```shell
|
||||
|
||||
355
app/agent/__init__.py
Normal file
355
app/agent/__init__.py
Normal file
@@ -0,0 +1,355 @@
|
||||
"""MoviePilot AI智能体实现"""
|
||||
|
||||
import asyncio
|
||||
from typing import Dict, List, Any
|
||||
|
||||
from langchain.agents import AgentExecutor, create_openai_tools_agent
|
||||
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langchain_community.callbacks import get_openai_callback
|
||||
from langchain_core.chat_history import InMemoryChatMessageHistory
|
||||
from langchain_core.messages import HumanMessage, AIMessage, ToolCall
|
||||
from langchain_core.runnables.history import RunnableWithMessageHistory
|
||||
|
||||
from app.agent.callback import StreamingCallbackHandler
|
||||
from app.agent.memory import ConversationMemoryManager
|
||||
from app.agent.prompt import PromptManager
|
||||
from app.agent.tools import MoviePilotToolFactory
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import settings
|
||||
from app.helper.message import MessageHelper
|
||||
from app.log import logger
|
||||
from app.schemas import Notification
|
||||
|
||||
|
||||
class AgentChain(ChainBase):
|
||||
pass
|
||||
|
||||
|
||||
class MoviePilotAgent:
|
||||
"""MoviePilot AI智能体"""
|
||||
|
||||
def __init__(self, session_id: str, user_id: str = None,
|
||||
channel: str = None, source: str = None, username: str = None):
|
||||
self.session_id = session_id
|
||||
self.user_id = user_id
|
||||
self.channel = channel # 消息渠道
|
||||
self.source = source # 消息来源
|
||||
self.username = username # 用户名
|
||||
|
||||
# 消息助手
|
||||
self.message_helper = MessageHelper()
|
||||
|
||||
# 记忆管理器
|
||||
self.memory_manager = ConversationMemoryManager()
|
||||
|
||||
# 提示词管理器
|
||||
self.prompt_manager = PromptManager()
|
||||
|
||||
# 回调处理器
|
||||
self.callback_handler = StreamingCallbackHandler(
|
||||
session_id=session_id
|
||||
)
|
||||
|
||||
# LLM模型
|
||||
self.llm = self._initialize_llm()
|
||||
|
||||
# 工具
|
||||
self.tools = self._initialize_tools()
|
||||
|
||||
# 会话存储
|
||||
self.session_store = self._initialize_session_store()
|
||||
|
||||
# 提示词模板
|
||||
self.prompt = self._initialize_prompt()
|
||||
|
||||
# Agent执行器
|
||||
self.agent_executor = self._create_agent_executor()
|
||||
|
||||
def _initialize_llm(self):
|
||||
"""初始化LLM模型"""
|
||||
provider = settings.LLM_PROVIDER.lower()
|
||||
api_key = settings.LLM_API_KEY
|
||||
if not api_key:
|
||||
raise ValueError("未配置 LLM_API_KEY")
|
||||
|
||||
if provider == "google":
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
return ChatGoogleGenerativeAI(
|
||||
model=settings.LLM_MODEL,
|
||||
google_api_key=api_key,
|
||||
max_retries=3,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=True,
|
||||
callbacks=[self.callback_handler]
|
||||
)
|
||||
elif provider == "deepseek":
|
||||
from langchain_deepseek import ChatDeepSeek
|
||||
return ChatDeepSeek(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
max_retries=3,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=True,
|
||||
callbacks=[self.callback_handler],
|
||||
stream_usage=True
|
||||
)
|
||||
else:
|
||||
from langchain_openai import ChatOpenAI
|
||||
return ChatOpenAI(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
max_retries=3,
|
||||
base_url=settings.LLM_BASE_URL,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=True,
|
||||
callbacks=[self.callback_handler],
|
||||
stream_usage=True
|
||||
)
|
||||
|
||||
def _initialize_tools(self) -> List:
|
||||
"""初始化工具列表"""
|
||||
return MoviePilotToolFactory.create_tools(
|
||||
session_id=self.session_id,
|
||||
user_id=self.user_id,
|
||||
channel=self.channel,
|
||||
source=self.source,
|
||||
username=self.username,
|
||||
callback_handler=self.callback_handler
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _initialize_session_store() -> Dict[str, InMemoryChatMessageHistory]:
|
||||
"""初始化内存存储"""
|
||||
return {}
|
||||
|
||||
def get_session_history(self, session_id: str) -> InMemoryChatMessageHistory:
|
||||
"""获取会话历史"""
|
||||
if session_id not in self.session_store:
|
||||
chat_history = InMemoryChatMessageHistory()
|
||||
messages: List[dict] = self.memory_manager.get_recent_messages_for_agent(
|
||||
session_id=session_id,
|
||||
user_id=self.user_id
|
||||
)
|
||||
if messages:
|
||||
for msg in messages:
|
||||
if msg.get("role") == "user":
|
||||
chat_history.add_user_message(HumanMessage(content=msg.get("content", "")))
|
||||
elif msg.get("role") == "agent":
|
||||
chat_history.add_ai_message(AIMessage(content=msg.get("content", "")))
|
||||
elif msg.get("role") == "tool_call":
|
||||
metadata = msg.get("metadata", {})
|
||||
chat_history.add_ai_message(AIMessage(
|
||||
content=msg.get("content", ""),
|
||||
tool_calls=[ToolCall(
|
||||
id=metadata.get("call_id"),
|
||||
name=metadata.get("tool_name"),
|
||||
args=metadata.get("parameters"),
|
||||
)]
|
||||
))
|
||||
elif msg.get("role") == "tool_result":
|
||||
chat_history.add_ai_message(AIMessage(content=msg.get("content", "")))
|
||||
elif msg.get("role") == "system":
|
||||
chat_history.add_ai_message(AIMessage(content=msg.get("content", "")))
|
||||
self.session_store[session_id] = chat_history
|
||||
return self.session_store[session_id]
|
||||
|
||||
@staticmethod
|
||||
def _initialize_prompt() -> ChatPromptTemplate:
|
||||
"""初始化提示词模板"""
|
||||
try:
|
||||
prompt_template = ChatPromptTemplate.from_messages([
|
||||
("system", "{system_prompt}"),
|
||||
MessagesPlaceholder(variable_name="chat_history"),
|
||||
("user", "{input}"),
|
||||
MessagesPlaceholder(variable_name="agent_scratchpad"),
|
||||
])
|
||||
logger.info("LangChain提示词模板初始化成功")
|
||||
return prompt_template
|
||||
except Exception as e:
|
||||
logger.error(f"初始化提示词失败: {e}")
|
||||
raise e
|
||||
|
||||
def _create_agent_executor(self) -> RunnableWithMessageHistory:
|
||||
"""创建Agent执行器"""
|
||||
try:
|
||||
agent = create_openai_tools_agent(
|
||||
llm=self.llm,
|
||||
tools=self.tools,
|
||||
prompt=self.prompt
|
||||
)
|
||||
executor = AgentExecutor(
|
||||
agent=agent,
|
||||
tools=self.tools,
|
||||
verbose=settings.LLM_VERBOSE,
|
||||
max_iterations=settings.LLM_MAX_ITERATIONS,
|
||||
return_intermediate_steps=True,
|
||||
handle_parsing_errors=True,
|
||||
early_stopping_method="force"
|
||||
)
|
||||
return RunnableWithMessageHistory(
|
||||
executor,
|
||||
self.get_session_history,
|
||||
input_messages_key="input",
|
||||
history_messages_key="chat_history"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"创建Agent执行器失败: {e}")
|
||||
raise e
|
||||
|
||||
async def process_message(self, message: str) -> str:
|
||||
"""处理用户消息"""
|
||||
try:
|
||||
# 添加用户消息到记忆
|
||||
await self.memory_manager.add_memory(
|
||||
self.session_id,
|
||||
user_id=self.user_id,
|
||||
role="user",
|
||||
content=message
|
||||
)
|
||||
|
||||
# 构建输入上下文
|
||||
input_context = {
|
||||
"system_prompt": self.prompt_manager.get_agent_prompt(channel=self.channel),
|
||||
"input": message
|
||||
}
|
||||
|
||||
# 执行Agent
|
||||
logger.info(f"Agent执行推理: session_id={self.session_id}, input={message}")
|
||||
await self._execute_agent(input_context)
|
||||
|
||||
# 获取Agent回复
|
||||
agent_message = await self.callback_handler.get_message()
|
||||
|
||||
# 发送Agent回复给用户(通过原渠道)
|
||||
await self.send_agent_message(agent_message)
|
||||
|
||||
# 添加Agent回复到记忆
|
||||
await self.memory_manager.add_memory(
|
||||
session_id=self.session_id,
|
||||
user_id=self.user_id,
|
||||
role="agent",
|
||||
content=agent_message
|
||||
)
|
||||
|
||||
return agent_message
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"处理消息时发生错误: {str(e)}"
|
||||
logger.error(error_message)
|
||||
# 发送错误消息给用户(通过原渠道)
|
||||
await self.send_agent_message(error_message)
|
||||
return error_message
|
||||
|
||||
async def _execute_agent(self, input_context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行LangChain Agent"""
|
||||
try:
|
||||
with get_openai_callback() as cb:
|
||||
result = await self.agent_executor.ainvoke(
|
||||
input_context,
|
||||
config={"configurable": {"session_id": self.session_id}},
|
||||
callbacks=[self.callback_handler]
|
||||
)
|
||||
logger.info(f"LLM调用消耗: \n{cb}")
|
||||
|
||||
if cb.total_tokens > 0:
|
||||
result["token_usage"] = {
|
||||
"prompt_tokens": cb.prompt_tokens,
|
||||
"completion_tokens": cb.completion_tokens,
|
||||
"total_tokens": cb.total_tokens
|
||||
}
|
||||
return result
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"Agent执行被取消: session_id={self.session_id}")
|
||||
return {
|
||||
"output": "任务已取消",
|
||||
"intermediate_steps": [],
|
||||
"token_usage": {}
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Agent执行失败: {e}")
|
||||
return {
|
||||
"output": f"执行过程中发生错误: {str(e)}",
|
||||
"intermediate_steps": [],
|
||||
"token_usage": {}
|
||||
}
|
||||
|
||||
async def send_agent_message(self, message: str, title: str = "MoviePilot助手"):
|
||||
"""通过原渠道发送消息给用户"""
|
||||
await AgentChain().async_post_message(
|
||||
Notification(
|
||||
channel=self.channel,
|
||||
source=self.source,
|
||||
userid=self.user_id,
|
||||
username=self.username,
|
||||
title=title,
|
||||
text=message
|
||||
)
|
||||
)
|
||||
|
||||
async def cleanup(self):
|
||||
"""清理智能体资源"""
|
||||
if self.session_id in self.session_store:
|
||||
del self.session_store[self.session_id]
|
||||
logger.info(f"MoviePilot智能体已清理: session_id={self.session_id}")
|
||||
|
||||
|
||||
class AgentManager:
|
||||
"""AI智能体管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.active_agents: Dict[str, MoviePilotAgent] = {}
|
||||
self.memory_manager = ConversationMemoryManager()
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化管理器"""
|
||||
await self.memory_manager.initialize()
|
||||
|
||||
async def close(self):
|
||||
"""关闭管理器"""
|
||||
await self.memory_manager.close()
|
||||
# 清理所有活跃的智能体
|
||||
for agent in self.active_agents.values():
|
||||
await agent.cleanup()
|
||||
self.active_agents.clear()
|
||||
|
||||
async def process_message(self, session_id: str, user_id: str, message: str,
|
||||
channel: str = None, source: str = None, username: str = None) -> str:
|
||||
"""处理用户消息"""
|
||||
# 获取或创建Agent实例
|
||||
if session_id not in self.active_agents:
|
||||
logger.info(f"创建新的AI智能体实例,session_id: {session_id}, user_id: {user_id}")
|
||||
agent = MoviePilotAgent(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
channel=channel,
|
||||
source=source,
|
||||
username=username
|
||||
)
|
||||
agent.memory_manager = self.memory_manager
|
||||
self.active_agents[session_id] = agent
|
||||
else:
|
||||
agent = self.active_agents[session_id]
|
||||
agent.user_id = user_id # 确保user_id是最新的
|
||||
# 更新渠道信息
|
||||
if channel:
|
||||
agent.channel = channel
|
||||
if source:
|
||||
agent.source = source
|
||||
if username:
|
||||
agent.username = username
|
||||
|
||||
# 处理消息
|
||||
return await agent.process_message(message)
|
||||
|
||||
async def clear_session(self, session_id: str, user_id: str):
|
||||
"""清空会话"""
|
||||
if session_id in self.active_agents:
|
||||
agent = self.active_agents[session_id]
|
||||
await agent.cleanup()
|
||||
del self.active_agents[session_id]
|
||||
await self.memory_manager.clear_memory(session_id, user_id)
|
||||
logger.info(f"会话 {session_id} 的记忆已清空")
|
||||
|
||||
|
||||
# 全局智能体管理器实例
|
||||
agent_manager = AgentManager()
|
||||
33
app/agent/callback/__init__.py
Normal file
33
app/agent/callback/__init__.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import threading
|
||||
|
||||
from langchain_core.callbacks import AsyncCallbackHandler
|
||||
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class StreamingCallbackHandler(AsyncCallbackHandler):
|
||||
"""流式输出回调处理器"""
|
||||
|
||||
def __init__(self, session_id: str):
|
||||
self._lock = threading.Lock()
|
||||
self.session_id = session_id
|
||||
self.current_message = ""
|
||||
|
||||
async def get_message(self):
|
||||
"""获取当前消息内容,获取后清空"""
|
||||
with self._lock:
|
||||
if not self.current_message:
|
||||
return ""
|
||||
msg = self.current_message
|
||||
logger.info(f"Agent消息: {msg}")
|
||||
self.current_message = ""
|
||||
return msg
|
||||
|
||||
async def on_llm_new_token(self, token: str, **kwargs):
|
||||
"""处理新的token"""
|
||||
if not token:
|
||||
return
|
||||
with self._lock:
|
||||
# 缓存当前消息
|
||||
self.current_message += token
|
||||
|
||||
280
app/agent/memory/__init__.py
Normal file
280
app/agent/memory/__init__.py
Normal file
@@ -0,0 +1,280 @@
|
||||
"""对话记忆管理器"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
from app.core.config import settings
|
||||
from app.helper.redis import AsyncRedisHelper
|
||||
from app.log import logger
|
||||
from app.schemas.agent import ConversationMemory
|
||||
|
||||
|
||||
class ConversationMemoryManager:
|
||||
"""对话记忆管理器"""
|
||||
|
||||
def __init__(self):
|
||||
# 内存中的会话记忆缓存
|
||||
self.memory_cache: Dict[str, ConversationMemory] = {}
|
||||
# 使用现有的Redis助手
|
||||
self.redis_helper = AsyncRedisHelper()
|
||||
# 内存缓存清理任务(Redis通过TTL自动过期)
|
||||
self.cleanup_task: Optional[asyncio.Task] = None
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化记忆管理器"""
|
||||
try:
|
||||
# 启动内存缓存清理任务(Redis通过TTL自动过期)
|
||||
self.cleanup_task = asyncio.create_task(self._cleanup_expired_memories())
|
||||
logger.info("对话记忆管理器初始化完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis连接失败,将使用内存存储: {e}")
|
||||
|
||||
async def close(self):
|
||||
"""关闭记忆管理器"""
|
||||
if self.cleanup_task:
|
||||
self.cleanup_task.cancel()
|
||||
try:
|
||||
await self.cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
await self.redis_helper.close()
|
||||
|
||||
logger.info("对话记忆管理器已关闭")
|
||||
|
||||
async def get_memory(self, session_id: str, user_id: str) -> ConversationMemory:
|
||||
"""获取会话记忆"""
|
||||
# 首先检查缓存
|
||||
cache_key = f"{user_id}:{session_id}" if user_id else session_id
|
||||
if cache_key in self.memory_cache:
|
||||
return self.memory_cache[cache_key]
|
||||
|
||||
# 尝试从Redis加载
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
try:
|
||||
redis_key = f"agent_memory:{user_id}:{session_id}" if user_id else f"agent_memory:{session_id}"
|
||||
memory_data = await self.redis_helper.get(redis_key, region="AI_AGENT")
|
||||
if memory_data:
|
||||
memory_dict = json.loads(memory_data) if isinstance(memory_data, str) else memory_data
|
||||
memory = ConversationMemory(**memory_dict)
|
||||
self.memory_cache[cache_key] = memory
|
||||
return memory
|
||||
except Exception as e:
|
||||
logger.warning(f"从Redis加载记忆失败: {e}")
|
||||
|
||||
# 创建新的记忆
|
||||
memory = ConversationMemory(session_id=session_id, user_id=user_id)
|
||||
self.memory_cache[cache_key] = memory
|
||||
await self._save_memory(memory)
|
||||
|
||||
return memory
|
||||
|
||||
async def set_title(self, session_id: str, user_id: str, title: str):
|
||||
"""设置会话标题"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
memory.title = title
|
||||
memory.updated_at = datetime.now()
|
||||
await self._save_memory(memory)
|
||||
|
||||
async def get_title(self, session_id: str, user_id: str) -> Optional[str]:
|
||||
"""获取会话标题"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
return memory.title
|
||||
|
||||
async def list_sessions(self, user_id: str, limit: int = 100) -> List[Dict[str, Any]]:
|
||||
"""列出历史会话摘要(按更新时间倒序)
|
||||
|
||||
- 当启用Redis时:遍历 `agent_memory:*` 键并读取摘要
|
||||
- 当未启用Redis时:基于内存缓存返回
|
||||
"""
|
||||
sessions: List[ConversationMemory] = []
|
||||
# 从Redis遍历
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
try:
|
||||
# 使用Redis助手的items方法遍历所有键
|
||||
async for key, value in self.redis_helper.items(region="AI_AGENT"):
|
||||
if key.startswith("agent_memory:"):
|
||||
try:
|
||||
# 解析键名获取user_id和session_id
|
||||
key_parts = key.split(":")
|
||||
if len(key_parts) >= 3:
|
||||
key_user_id = key_parts[2] if len(key_parts) > 3 else None
|
||||
if not user_id or key_user_id == user_id:
|
||||
data = value if isinstance(value, dict) else json.loads(value)
|
||||
memory = ConversationMemory(**data)
|
||||
sessions.append(memory)
|
||||
except Exception as err:
|
||||
logger.warning(f"解析Redis记忆数据失败: {err}")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning(f"遍历Redis会话失败: {e}")
|
||||
|
||||
# 合并内存缓存(确保包含近期的会话)
|
||||
for cache_key, memory in self.memory_cache.items():
|
||||
# 如果指定了user_id,只返回该用户的会话
|
||||
if not user_id or memory.user_id == user_id:
|
||||
sessions.append(memory)
|
||||
|
||||
# 去重(以 session_id 为键,取最近updated)
|
||||
uniq: Dict[str, ConversationMemory] = {}
|
||||
for mem in sessions:
|
||||
existed = uniq.get(mem.session_id)
|
||||
if (not existed) or (mem.updated_at > existed.updated_at):
|
||||
uniq[mem.session_id] = mem
|
||||
|
||||
# 排序并裁剪
|
||||
sorted_list = sorted(uniq.values(), key=lambda m: m.updated_at, reverse=True)[:limit]
|
||||
return [
|
||||
{
|
||||
"session_id": m.session_id,
|
||||
"title": m.title or "新会话",
|
||||
"message_count": len(m.messages),
|
||||
"created_at": m.created_at.isoformat(),
|
||||
"updated_at": m.updated_at.isoformat(),
|
||||
}
|
||||
for m in sorted_list
|
||||
]
|
||||
|
||||
async def add_memory(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
role: str,
|
||||
content: str,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""添加消息到记忆"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
|
||||
message = {
|
||||
"role": role,
|
||||
"content": content,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"metadata": metadata or {}
|
||||
}
|
||||
|
||||
memory.messages.append(message)
|
||||
memory.updated_at = datetime.now()
|
||||
|
||||
# 限制消息数量,避免记忆过大
|
||||
max_messages = settings.LLM_MAX_MEMORY_MESSAGES
|
||||
if len(memory.messages) > max_messages:
|
||||
# 保留最近的消息,但保留第一条系统消息
|
||||
system_messages = [msg for msg in memory.messages if msg["role"] == "system"]
|
||||
recent_messages = memory.messages[-(max_messages - len(system_messages)):]
|
||||
memory.messages = system_messages + recent_messages
|
||||
|
||||
await self._save_memory(memory)
|
||||
|
||||
logger.debug(f"消息已添加到记忆: session_id={session_id}, user_id={user_id}, role={role}")
|
||||
|
||||
def get_recent_messages_for_agent(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""为Agent获取最近的消息(仅内存缓存)
|
||||
|
||||
如果消息Token数量超过模型最大上下文长度的阀值,会自动进行摘要裁剪
|
||||
"""
|
||||
cache_key = f"{user_id}:{session_id}" if user_id else session_id
|
||||
memory = self.memory_cache.get(cache_key)
|
||||
if not memory:
|
||||
return []
|
||||
|
||||
# 获取所有消息
|
||||
messages = memory.messages
|
||||
|
||||
return messages
|
||||
|
||||
async def get_recent_messages(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
limit: int = 10,
|
||||
role_filter: Optional[list] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取最近的消息"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
|
||||
messages = memory.messages
|
||||
if role_filter:
|
||||
messages = [msg for msg in messages if msg["role"] in role_filter]
|
||||
|
||||
return messages[-limit:] if messages else []
|
||||
|
||||
async def get_context(self, session_id: str, user_id: str) -> Dict[str, Any]:
|
||||
"""获取会话上下文"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
return memory.context
|
||||
|
||||
async def clear_memory(self, session_id: str, user_id: str):
|
||||
"""清空会话记忆"""
|
||||
cache_key = f"{user_id}:{session_id}" if user_id else session_id
|
||||
if cache_key in self.memory_cache:
|
||||
del self.memory_cache[cache_key]
|
||||
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
redis_key = f"agent_memory:{user_id}:{session_id}" if user_id else f"agent_memory:{session_id}"
|
||||
await self.redis_helper.delete(redis_key, region="AI_AGENT")
|
||||
|
||||
logger.info(f"会话记忆已清空: session_id={session_id}, user_id={user_id}")
|
||||
|
||||
async def _save_memory(self, memory: ConversationMemory):
|
||||
"""保存记忆到存储
|
||||
|
||||
Redis中的记忆会自动通过TTL机制过期,无需手动清理
|
||||
"""
|
||||
# 更新内存缓存
|
||||
cache_key = f"{memory.user_id}:{memory.session_id}" if memory.user_id else memory.session_id
|
||||
self.memory_cache[cache_key] = memory
|
||||
|
||||
# 保存到Redis,设置TTL自动过期
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
try:
|
||||
memory_dict = memory.model_dump()
|
||||
redis_key = f"agent_memory:{memory.user_id}:{memory.session_id}" if memory.user_id else f"agent_memory:{memory.session_id}"
|
||||
ttl = int(timedelta(days=settings.LLM_REDIS_MEMORY_RETENTION_DAYS).total_seconds())
|
||||
await self.redis_helper.set(
|
||||
redis_key,
|
||||
memory_dict,
|
||||
ttl=ttl,
|
||||
region="AI_AGENT"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"保存记忆到Redis失败: {e}")
|
||||
|
||||
async def _cleanup_expired_memories(self):
|
||||
"""清理内存中过期记忆的后台任务
|
||||
|
||||
注意:Redis中的记忆通过TTL机制自动过期,这里只清理内存缓存
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
# 每小时清理一次
|
||||
await asyncio.sleep(3600)
|
||||
|
||||
current_time = datetime.now()
|
||||
expired_sessions = []
|
||||
|
||||
# 只检查内存缓存中的过期记忆
|
||||
# Redis中的记忆会通过TTL自动过期,无需手动处理
|
||||
for cache_key, memory in self.memory_cache.items():
|
||||
if (current_time - memory.updated_at).days > settings.LLM_MEMORY_RETENTION_DAYS:
|
||||
expired_sessions.append(cache_key)
|
||||
|
||||
# 只清理内存缓存,不删除Redis中的键(Redis会自动过期)
|
||||
for cache_key in expired_sessions:
|
||||
if cache_key in self.memory_cache:
|
||||
del self.memory_cache[cache_key]
|
||||
|
||||
if expired_sessions:
|
||||
logger.info(f"清理了{len(expired_sessions)}个过期内存会话记忆")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"清理记忆时发生错误: {e}")
|
||||
70
app/agent/prompt/Agent Prompt.txt
Normal file
70
app/agent/prompt/Agent Prompt.txt
Normal file
@@ -0,0 +1,70 @@
|
||||
You are MoviePilot's AI assistant, specialized in helping users manage media resources including subscriptions, searching, downloading, and organization.
|
||||
|
||||
## Your Identity and Capabilities
|
||||
|
||||
You are an AI agent for the MoviePilot media management system with the following core capabilities:
|
||||
|
||||
### Media Management Capabilities
|
||||
- **Search Media Resources**: Search for movies, TV shows, anime, and other media content based on user requirements
|
||||
- **Add Subscriptions**: Create subscription rules for media content that users are interested in
|
||||
- **Manage Downloads**: Search and add torrent resources to downloaders
|
||||
- **Query Status**: Check subscription status, download progress, and media library status
|
||||
|
||||
### Intelligent Interaction Capabilities
|
||||
- **Natural Language Understanding**: Understand user requests in natural language (Chinese/English)
|
||||
- **Context Memory**: Remember conversation history and user preferences
|
||||
- **Smart Recommendations**: Recommend related media content based on user preferences
|
||||
- **Task Execution**: Automatically execute complex media management tasks
|
||||
|
||||
## Working Principles
|
||||
|
||||
1. **Always respond in Chinese**: All responses must be in Chinese
|
||||
2. **Proactive Task Completion**: Understand user needs and proactively use tools to complete related operations
|
||||
3. **Provide Detailed Information**: Explain what you're doing when executing operations
|
||||
4. **Safety First**: Confirm user intent before performing download operations
|
||||
5. **Continuous Learning**: Remember user preferences and habits to provide personalized service
|
||||
|
||||
## Common Operation Workflows
|
||||
|
||||
### Add Subscription Workflow
|
||||
1. Understand the media content the user wants to subscribe to
|
||||
2. Search for related media information
|
||||
3. Create subscription rules
|
||||
4. Confirm successful subscription
|
||||
|
||||
### Search and Download Workflow
|
||||
1. Understand user requirements (movie names, TV show names, etc.)
|
||||
2. Search for related media information
|
||||
3. Search for related torrent resources by media info
|
||||
4. Filter suitable resources
|
||||
5. Add to downloader
|
||||
|
||||
### Query Status Workflow
|
||||
1. Understand what information the user wants to know
|
||||
2. Query related data
|
||||
3. Organize and present results
|
||||
|
||||
## Tool Usage Guidelines
|
||||
|
||||
### Tool Usage Principles
|
||||
- Use tools proactively to complete user requests
|
||||
- Always explain what you're doing when using tools
|
||||
- Provide detailed results and explanations
|
||||
- Handle errors gracefully and suggest alternatives
|
||||
- Confirm user intent before performing download operations
|
||||
|
||||
### Response Format
|
||||
- Always respond in Chinese
|
||||
- Use clear and friendly language
|
||||
- Provide structured information when appropriate
|
||||
- Include relevant details about media content (title, year, type, etc.)
|
||||
- Explain the results of tool operations clearly
|
||||
|
||||
## Important Notes
|
||||
|
||||
- Always confirm user intent before performing download operations
|
||||
- If search results are not ideal, proactively adjust search strategies
|
||||
- Maintain a friendly and professional tone
|
||||
- Seek solutions proactively when encountering problems
|
||||
- Remember user preferences and provide personalized recommendations
|
||||
- Handle errors gracefully and provide helpful suggestions
|
||||
118
app/agent/prompt/__init__.py
Normal file
118
app/agent/prompt/__init__.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""提示词管理器"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class PromptManager:
|
||||
"""提示词管理器"""
|
||||
|
||||
def __init__(self, prompts_dir: str = None):
|
||||
if prompts_dir is None:
|
||||
self.prompts_dir = Path(__file__).parent
|
||||
else:
|
||||
self.prompts_dir = Path(prompts_dir)
|
||||
self.prompts_cache: Dict[str, str] = {}
|
||||
|
||||
def load_prompt(self, prompt_name: str) -> str:
|
||||
"""加载指定的提示词"""
|
||||
if prompt_name in self.prompts_cache:
|
||||
return self.prompts_cache[prompt_name]
|
||||
|
||||
prompt_file = self.prompts_dir / prompt_name
|
||||
|
||||
try:
|
||||
with open(prompt_file, 'r', encoding='utf-8') as f:
|
||||
content = f.read().strip()
|
||||
|
||||
# 缓存提示词
|
||||
self.prompts_cache[prompt_name] = content
|
||||
|
||||
logger.info(f"提示词加载成功: {prompt_name},长度:{len(content)} 字符")
|
||||
return content
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.error(f"提示词文件不存在: {prompt_file}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"加载提示词失败: {prompt_name}, 错误: {e}")
|
||||
raise
|
||||
|
||||
def get_agent_prompt(self, channel: str = None) -> str:
|
||||
"""
|
||||
获取智能体提示词
|
||||
:param channel: 消息渠道(Telegram、微信、Slack等)
|
||||
:return: 提示词内容
|
||||
"""
|
||||
base_prompt = self.load_prompt("Agent Prompt.txt")
|
||||
|
||||
# 根据渠道添加特定的格式说明
|
||||
if channel:
|
||||
channel_format_info = self._get_channel_format_info(channel)
|
||||
if channel_format_info:
|
||||
base_prompt += f"\n\n## Current Message Channel Format Requirements\n\n{channel_format_info}"
|
||||
|
||||
return base_prompt
|
||||
|
||||
@staticmethod
|
||||
def _get_channel_format_info(channel: str) -> str:
|
||||
"""
|
||||
获取渠道特定的格式说明
|
||||
:param channel: 消息渠道
|
||||
:return: 格式说明文本
|
||||
"""
|
||||
channel_lower = channel.lower() if channel else ""
|
||||
|
||||
if "telegram" in channel_lower:
|
||||
return """Messages are being sent through the **Telegram** channel. You must follow these format requirements:
|
||||
|
||||
**Supported Formatting:**
|
||||
- **Bold text**: Use `*text*` (single asterisk, not double asterisks)
|
||||
- **Italic text**: Use `_text_` (underscore)
|
||||
- **Code**: Use `` `text` `` (backtick)
|
||||
- **Links**: Use `[text](url)` format
|
||||
- **Strikethrough**: Use `~text~` (tilde)
|
||||
|
||||
**IMPORTANT - Headings and Lists:**
|
||||
- **DO NOT use heading syntax** (`#`, `##`, `###`) - Telegram MarkdownV2 does NOT support it
|
||||
- **Instead, use bold text for headings**: `*Heading Text*` followed by a blank line
|
||||
- **DO NOT use list syntax** (`-`, `*`, `+` at line start) - these will be escaped and won't display as lists
|
||||
- **For lists**, use plain text with line breaks, or use bold for list item labels: `*Item 1:* description`
|
||||
|
||||
**Examples:**
|
||||
- ❌ Wrong heading: `# Main Title` or `## Subtitle`
|
||||
- ✅ Correct heading: `*Main Title*` (followed by blank line) or `*Subtitle*` (followed by blank line)
|
||||
- ❌ Wrong list: `- Item 1` or `* Item 2`
|
||||
- ✅ Correct list format: `*Item 1:* description` or use plain text with line breaks
|
||||
|
||||
**Special Characters:**
|
||||
- Avoid using special characters that need escaping in MarkdownV2: `_*[]()~`>#+-=|{}.!` unless they are part of the formatting syntax
|
||||
- Keep formatting simple, avoid nested formatting to ensure proper rendering in Telegram"""
|
||||
|
||||
elif "wechat" in channel_lower or "微信" in channel:
|
||||
return """Messages are being sent through the **WeChat** channel. Please follow these format requirements:
|
||||
|
||||
- WeChat does NOT support Markdown formatting. Use plain text format only.
|
||||
- Do NOT use any Markdown syntax (such as `**bold**`, `*italic*`, `` `code` `` etc.)
|
||||
- Use plain text descriptions. You can organize content using line breaks and punctuation
|
||||
- Links can be provided directly as URLs, no Markdown link format needed
|
||||
- Keep messages concise and clear, use natural Chinese expressions"""
|
||||
|
||||
elif "slack" in channel_lower:
|
||||
return """Messages are being sent through the **Slack** channel. Please follow these format requirements:
|
||||
|
||||
- Slack supports Markdown formatting
|
||||
- Use `*text*` for bold
|
||||
- Use `_text_` for italic
|
||||
- Use `` `text` `` for code
|
||||
- Link format: `<url|text>` or `[text](url)`"""
|
||||
|
||||
# 其他渠道使用标准Markdown
|
||||
return None
|
||||
|
||||
def clear_cache(self):
|
||||
"""清空缓存"""
|
||||
self.prompts_cache.clear()
|
||||
logger.info("提示词缓存已清空")
|
||||
31
app/agent/tools/__init__.py
Normal file
31
app/agent/tools/__init__.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""MoviePilot工具模块"""
|
||||
|
||||
from .base import MoviePilotTool
|
||||
from app.agent.tools.impl.search_media import SearchMediaTool
|
||||
from app.agent.tools.impl.add_subscribe import AddSubscribeTool
|
||||
from app.agent.tools.impl.search_torrents import SearchTorrentsTool
|
||||
from app.agent.tools.impl.add_download import AddDownloadTool
|
||||
from app.agent.tools.impl.query_subscribes import QuerySubscribesTool
|
||||
from app.agent.tools.impl.query_downloads import QueryDownloadsTool
|
||||
from app.agent.tools.impl.query_downloaders import QueryDownloadersTool
|
||||
from app.agent.tools.impl.query_sites import QuerySitesTool
|
||||
from app.agent.tools.impl.get_recommendations import GetRecommendationsTool
|
||||
from app.agent.tools.impl.query_media_library import QueryMediaLibraryTool
|
||||
from app.agent.tools.impl.send_message import SendMessageTool
|
||||
from .factory import MoviePilotToolFactory
|
||||
|
||||
__all__ = [
|
||||
"MoviePilotTool",
|
||||
"SearchMediaTool",
|
||||
"AddSubscribeTool",
|
||||
"SearchTorrentsTool",
|
||||
"AddDownloadTool",
|
||||
"QuerySubscribesTool",
|
||||
"QueryDownloadsTool",
|
||||
"QueryDownloadersTool",
|
||||
"QuerySitesTool",
|
||||
"GetRecommendationsTool",
|
||||
"QueryMediaLibraryTool",
|
||||
"SendMessageTool",
|
||||
"MoviePilotToolFactory"
|
||||
]
|
||||
73
app/agent/tools/base.py
Normal file
73
app/agent/tools/base.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""MoviePilot工具基类"""
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Callable, Any
|
||||
|
||||
from langchain.tools import BaseTool
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from app.agent import StreamingCallbackHandler
|
||||
from app.chain import ChainBase
|
||||
from app.schemas import Notification
|
||||
|
||||
|
||||
class ToolChain(ChainBase):
|
||||
pass
|
||||
|
||||
|
||||
class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
"""MoviePilot专用工具基类"""
|
||||
|
||||
_session_id: str = PrivateAttr()
|
||||
_user_id: str = PrivateAttr()
|
||||
_channel: str = PrivateAttr(default=None)
|
||||
_source: str = PrivateAttr(default=None)
|
||||
_username: str = PrivateAttr(default=None)
|
||||
_callback_handler: StreamingCallbackHandler = PrivateAttr(default=None)
|
||||
|
||||
def __init__(self, session_id: str, user_id: str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._session_id = session_id
|
||||
self._user_id = user_id
|
||||
|
||||
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||
pass
|
||||
|
||||
async def _arun(self, **kwargs) -> str:
|
||||
"""异步运行工具"""
|
||||
# 发送运行工具前的消息
|
||||
agent_message = await self._callback_handler.get_message()
|
||||
if agent_message:
|
||||
await self.send_tool_message(agent_message, title="MoviePilot助手")
|
||||
# 发送执行工具说明
|
||||
explanation = kwargs.get("explanation")
|
||||
if explanation:
|
||||
await self.send_tool_message(f"▶️️{explanation}")
|
||||
return await self.run(**kwargs)
|
||||
|
||||
@abstractmethod
|
||||
async def run(self, **kwargs) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def set_message_attr(self, channel: str, source: str, username: str):
|
||||
"""设置消息属性"""
|
||||
self._channel = channel
|
||||
self._source = source
|
||||
self._username = username
|
||||
|
||||
def set_callback_handler(self, callback_handler: StreamingCallbackHandler):
|
||||
"""设置回调处理器"""
|
||||
self._callback_handler = callback_handler
|
||||
|
||||
async def send_tool_message(self, message: str, title: str = ""):
|
||||
"""发送工具消息"""
|
||||
await ToolChain().async_post_message(
|
||||
Notification(
|
||||
channel=self._channel,
|
||||
source=self._source,
|
||||
userid=self._user_id,
|
||||
username=self._username,
|
||||
title=title,
|
||||
text=message
|
||||
),
|
||||
escape_markdown=False
|
||||
)
|
||||
84
app/agent/tools/factory.py
Normal file
84
app/agent/tools/factory.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""MoviePilot工具工厂"""
|
||||
|
||||
from typing import List, Callable
|
||||
|
||||
from app.agent.tools.impl.add_download import AddDownloadTool
|
||||
from app.agent.tools.impl.add_subscribe import AddSubscribeTool
|
||||
from app.agent.tools.impl.get_recommendations import GetRecommendationsTool
|
||||
from app.agent.tools.impl.query_downloaders import QueryDownloadersTool
|
||||
from app.agent.tools.impl.query_downloads import QueryDownloadsTool
|
||||
from app.agent.tools.impl.query_media_library import QueryMediaLibraryTool
|
||||
from app.agent.tools.impl.query_sites import QuerySitesTool
|
||||
from app.agent.tools.impl.query_subscribes import QuerySubscribesTool
|
||||
from app.agent.tools.impl.search_media import SearchMediaTool
|
||||
from app.agent.tools.impl.search_torrents import SearchTorrentsTool
|
||||
from app.agent.tools.impl.send_message import SendMessageTool
|
||||
from app.core.plugin import PluginManager
|
||||
from app.log import logger
|
||||
from .base import MoviePilotTool
|
||||
|
||||
|
||||
class MoviePilotToolFactory:
|
||||
"""MoviePilot工具工厂"""
|
||||
|
||||
@staticmethod
|
||||
def create_tools(session_id: str, user_id: str,
|
||||
channel: str = None, source: str = None, username: str = None,
|
||||
callback_handler: Callable = None) -> List[MoviePilotTool]:
|
||||
"""创建MoviePilot工具列表"""
|
||||
tools = []
|
||||
tool_definitions = [
|
||||
SearchMediaTool,
|
||||
AddSubscribeTool,
|
||||
SearchTorrentsTool,
|
||||
AddDownloadTool,
|
||||
QuerySubscribesTool,
|
||||
QueryDownloadsTool,
|
||||
QueryDownloadersTool,
|
||||
QuerySitesTool,
|
||||
GetRecommendationsTool,
|
||||
QueryMediaLibraryTool,
|
||||
SendMessageTool
|
||||
]
|
||||
# 创建内置工具
|
||||
for ToolClass in tool_definitions:
|
||||
tool = ToolClass(
|
||||
session_id=session_id,
|
||||
user_id=user_id
|
||||
)
|
||||
tool.set_message_attr(channel=channel, source=source, username=username)
|
||||
tool.set_callback_handler(callback_handler=callback_handler)
|
||||
tools.append(tool)
|
||||
|
||||
# 加载插件提供的工具
|
||||
plugin_tools_count = 0
|
||||
plugin_tools_info = PluginManager().get_plugin_agent_tools()
|
||||
for plugin_info in plugin_tools_info:
|
||||
plugin_id = plugin_info.get("plugin_id")
|
||||
plugin_name = plugin_info.get("plugin_name")
|
||||
tool_classes = plugin_info.get("tools", [])
|
||||
for ToolClass in tool_classes:
|
||||
try:
|
||||
# 验证工具类是否继承自 MoviePilotTool
|
||||
if not issubclass(ToolClass, MoviePilotTool):
|
||||
logger.warning(f"插件 {plugin_name}({plugin_id}) 提供的工具类 {ToolClass.__name__} 未继承自 MoviePilotTool,已跳过")
|
||||
continue
|
||||
# 创建工具实例
|
||||
tool = ToolClass(
|
||||
session_id=session_id,
|
||||
user_id=user_id
|
||||
)
|
||||
tool.set_message_attr(channel=channel, source=source, username=username)
|
||||
tool.set_callback_handler(callback_handler=callback_handler)
|
||||
tools.append(tool)
|
||||
plugin_tools_count += 1
|
||||
logger.debug(f"成功加载插件 {plugin_name}({plugin_id}) 的工具: {ToolClass.__name__}")
|
||||
except Exception as e:
|
||||
logger.error(f"加载插件 {plugin_name}({plugin_id}) 的工具 {ToolClass.__name__} 失败: {str(e)}")
|
||||
|
||||
builtin_tools_count = len(tool_definitions)
|
||||
if plugin_tools_count > 0:
|
||||
logger.info(f"成功创建 {len(tools)} 个MoviePilot工具(内置工具: {builtin_tools_count} 个,插件工具: {plugin_tools_count} 个)")
|
||||
else:
|
||||
logger.info(f"成功创建 {len(tools)} 个MoviePilot工具")
|
||||
return tools
|
||||
0
app/agent/tools/impl/__init__.py
Normal file
0
app/agent/tools/impl/__init__.py
Normal file
92
app/agent/tools/impl/add_download.py
Normal file
92
app/agent/tools/impl/add_download.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""添加下载工具"""
|
||||
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool, ToolChain
|
||||
from app.chain.download import DownloadChain
|
||||
from app.core.context import Context
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.db.site_oper import SiteOper
|
||||
from app.log import logger
|
||||
from app.schemas import TorrentInfo
|
||||
|
||||
|
||||
class AddDownloadInput(BaseModel):
|
||||
"""添加下载工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
site_name: str = Field(..., description="Name of the torrent site/source (e.g., 'The Pirate Bay')")
|
||||
torrent_title: str = Field(...,
|
||||
description="The display name/title of the torrent (e.g., 'The.Matrix.1999.1080p.BluRay.x264')")
|
||||
torrent_url: str = Field(..., description="Direct URL to the torrent file (.torrent) or magnet link")
|
||||
torrent_description: Optional[str] = Field(None,
|
||||
description="Brief description of the torrent content (optional)")
|
||||
downloader: Optional[str] = Field(None,
|
||||
description="Name of the downloader to use (optional, uses default if not specified)")
|
||||
save_path: Optional[str] = Field(None,
|
||||
description="Directory path where the downloaded files should be saved (optional, uses default path if not specified)")
|
||||
labels: Optional[str] = Field(None,
|
||||
description="Comma-separated list of labels/tags to assign to the download (optional, e.g., 'movie,hd,bluray')")
|
||||
|
||||
|
||||
class AddDownloadTool(MoviePilotTool):
|
||||
name: str = "add_download"
|
||||
description: str = "Add torrent download task to the configured downloader (qBittorrent, Transmission, etc.). Downloads the torrent file and starts the download process with specified settings."
|
||||
args_schema: Type[BaseModel] = AddDownloadInput
|
||||
|
||||
async def run(self, site_name: str, torrent_title: str, torrent_url: str, torrent_description: Optional[str] = None,
|
||||
downloader: Optional[str] = None, save_path: Optional[str] = None,
|
||||
labels: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: site_name={site_name}, torrent_title={torrent_title}, torrent_url={torrent_url}, downloader={downloader}, save_path={save_path}, labels={labels}")
|
||||
|
||||
try:
|
||||
if not torrent_title or not torrent_url:
|
||||
return "错误:必须提供种子标题和下载链接"
|
||||
|
||||
# 使用DownloadChain添加下载
|
||||
download_chain = DownloadChain()
|
||||
|
||||
# 根据站点名称查询站点cookie
|
||||
if not site_name:
|
||||
return "错误:必须提供站点名称,请从搜索资源结果信息中获取"
|
||||
siteinfo = await SiteOper().async_get_by_name(site_name)
|
||||
if not siteinfo:
|
||||
return f"错误:未找到站点信息:{site_name}"
|
||||
|
||||
# 创建下载上下文
|
||||
torrent_info = TorrentInfo(
|
||||
title=torrent_title,
|
||||
description=torrent_description,
|
||||
enclosure=torrent_url,
|
||||
site_name=site_name,
|
||||
site_ua=siteinfo.ua,
|
||||
site_cookie=siteinfo.cookie,
|
||||
site_proxy=siteinfo.proxy,
|
||||
site_order=siteinfo.pri,
|
||||
site_downloader=siteinfo.downloader
|
||||
)
|
||||
meta_info = MetaInfo(title=torrent_title, subtitle=torrent_description)
|
||||
media_info = await ToolChain().async_recognize_media(meta=meta_info)
|
||||
if not media_info:
|
||||
return "错误:无法识别媒体信息,无法添加下载任务"
|
||||
context = Context(
|
||||
torrent_info=torrent_info,
|
||||
meta_info=meta_info,
|
||||
media_info=media_info
|
||||
)
|
||||
|
||||
did = download_chain.download_single(
|
||||
context=context,
|
||||
downloader=downloader,
|
||||
save_path=save_path,
|
||||
label=labels
|
||||
)
|
||||
if did:
|
||||
return f"成功添加下载任务:{torrent_title}"
|
||||
else:
|
||||
return "添加下载任务失败"
|
||||
except Exception as e:
|
||||
logger.error(f"添加下载任务失败: {e}", exc_info=True)
|
||||
return f"添加下载任务时发生错误: {str(e)}"
|
||||
60
app/agent/tools/impl/add_subscribe.py
Normal file
60
app/agent/tools/impl/add_subscribe.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""添加订阅工具"""
|
||||
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.subscribe import SubscribeChain
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType
|
||||
|
||||
|
||||
class AddSubscribeInput(BaseModel):
|
||||
"""添加订阅工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
title: str = Field(..., description="The title of the media to subscribe to (e.g., 'The Matrix', 'Breaking Bad')")
|
||||
year: str = Field(..., description="Release year of the media (required for accurate identification)")
|
||||
media_type: str = Field(...,
|
||||
description="Type of media content: '电影' for films, '电视剧' for television series or anime series")
|
||||
season: Optional[int] = Field(None,
|
||||
description="Season number for TV shows (optional, if not specified will subscribe to all seasons)")
|
||||
tmdb_id: Optional[str] = Field(None,
|
||||
description="TMDB database ID for precise media identification (optional but recommended for accuracy)")
|
||||
|
||||
|
||||
class AddSubscribeTool(MoviePilotTool):
|
||||
name: str = "add_subscribe"
|
||||
description: str = "Add media subscription to create automated download rules for movies and TV shows. The system will automatically search and download new episodes or releases based on the subscription criteria."
|
||||
args_schema: Type[BaseModel] = AddSubscribeInput
|
||||
|
||||
async def run(self, title: str, year: str, media_type: str,
|
||||
season: Optional[int] = None, tmdb_id: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}, tmdb_id={tmdb_id}")
|
||||
|
||||
try:
|
||||
subscribe_chain = SubscribeChain()
|
||||
# 转换 tmdb_id 为整数
|
||||
tmdbid_int = None
|
||||
if tmdb_id:
|
||||
try:
|
||||
tmdbid_int = int(tmdb_id)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(f"无效的 tmdb_id: {tmdb_id},将忽略")
|
||||
|
||||
sid, message = await subscribe_chain.async_add(
|
||||
mtype=MediaType(media_type),
|
||||
title=title,
|
||||
year=year,
|
||||
tmdbid=tmdbid_int,
|
||||
season=season,
|
||||
username=self._user_id
|
||||
)
|
||||
if sid:
|
||||
return f"成功添加订阅:{title} ({year})"
|
||||
else:
|
||||
return f"添加订阅失败:{message}"
|
||||
except Exception as e:
|
||||
logger.error(f"添加订阅失败: {e}", exc_info=True)
|
||||
return f"添加订阅时发生错误: {str(e)}"
|
||||
84
app/agent/tools/impl/get_recommendations.py
Normal file
84
app/agent/tools/impl/get_recommendations.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""获取推荐工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.recommend import RecommendChain
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class GetRecommendationsInput(BaseModel):
|
||||
"""获取推荐工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
source: Optional[str] = Field("tmdb_trending",
|
||||
description="Recommendation source: 'tmdb_trending' for TMDB trending content, 'douban_hot' for Douban popular content, 'bangumi_calendar' for Bangumi anime calendar")
|
||||
media_type: Optional[str] = Field("all",
|
||||
description="Type of media content: '电影' for films, '电视剧' for television series or anime series, 'all' for all types")
|
||||
limit: Optional[int] = Field(20,
|
||||
description="Maximum number of recommendations to return (default: 20, maximum: 100)")
|
||||
|
||||
|
||||
class GetRecommendationsTool(MoviePilotTool):
|
||||
name: str = "get_recommendations"
|
||||
description: str = "Get trending and popular media recommendations from various sources. Returns curated lists of popular movies, TV shows, and anime based on different criteria like trending, ratings, or calendar schedules."
|
||||
args_schema: Type[BaseModel] = GetRecommendationsInput
|
||||
|
||||
async def run(self, source: Optional[str] = "tmdb_trending",
|
||||
media_type: Optional[str] = "all", limit: Optional[int] = 20, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: source={source}, media_type={media_type}, limit={limit}")
|
||||
try:
|
||||
name_dicts = {
|
||||
"tmdb_trending": "TMDB 热门推荐",
|
||||
"douban_hot": "豆瓣热门推荐",
|
||||
"bangumi_calendar": "番组计划推荐"
|
||||
}
|
||||
recommend_chain = RecommendChain()
|
||||
results = []
|
||||
if source == "tmdb_trending":
|
||||
results = await recommend_chain.async_tmdb_trending(limit=limit)
|
||||
elif source == "douban_hot":
|
||||
if media_type == "movie":
|
||||
results = await recommend_chain.async_douban_movie_hot(limit=limit)
|
||||
elif media_type == "tv":
|
||||
results = await recommend_chain.async_douban_tv_hot(limit=limit)
|
||||
else: # all
|
||||
results.extend(await recommend_chain.async_douban_movie_hot(limit=limit))
|
||||
results.extend(await recommend_chain.async_douban_tv_hot(limit=limit))
|
||||
elif source == "bangumi_calendar":
|
||||
results = await recommend_chain.async_bangumi_calendar(limit=limit)
|
||||
|
||||
if results:
|
||||
# 限制最多20条结果
|
||||
total_count = len(results)
|
||||
limited_results = results[:20]
|
||||
# 精简字段,只保留关键信息
|
||||
simplified_results = []
|
||||
for r in limited_results:
|
||||
# r 已经是字典格式(to_dict的结果)
|
||||
simplified = {
|
||||
"title": r.get("title"),
|
||||
"en_title": r.get("en_title"),
|
||||
"year": r.get("year"),
|
||||
"type": r.get("type"),
|
||||
"season": r.get("season"),
|
||||
"tmdb_id": r.get("tmdb_id"),
|
||||
"imdb_id": r.get("imdb_id"),
|
||||
"douban_id": r.get("douban_id"),
|
||||
"overview": r.get("overview", "")[:200] + "..." if r.get("overview") and len(r.get("overview", "")) > 200 else r.get("overview"),
|
||||
"vote_average": r.get("vote_average"),
|
||||
"poster_path": r.get("poster_path"),
|
||||
"detail_link": r.get("detail_link")
|
||||
}
|
||||
simplified_results.append(simplified)
|
||||
result_json = json.dumps(simplified_results, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 20:
|
||||
return f"注意:推荐结果共找到 {total_count} 条,为节省上下文空间,仅显示前 20 条结果。\n\n{result_json}"
|
||||
return result_json
|
||||
return "未找到推荐内容。"
|
||||
except Exception as e:
|
||||
logger.error(f"获取推荐失败: {e}", exc_info=True)
|
||||
return f"获取推荐时发生错误: {str(e)}"
|
||||
34
app/agent/tools/impl/query_downloaders.py
Normal file
34
app/agent/tools/impl/query_downloaders.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""查询下载器工具"""
|
||||
|
||||
import json
|
||||
from typing import Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.log import logger
|
||||
from app.schemas.types import SystemConfigKey
|
||||
|
||||
|
||||
class QueryDownloadersInput(BaseModel):
|
||||
"""查询下载器工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
|
||||
|
||||
class QueryDownloadersTool(MoviePilotTool):
|
||||
name: str = "query_downloaders"
|
||||
description: str = "Query downloader configuration and list all available downloaders. Shows downloader status, connection details, and configuration settings."
|
||||
args_schema: Type[BaseModel] = QueryDownloadersInput
|
||||
|
||||
async def run(self, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}")
|
||||
try:
|
||||
system_config_oper = SystemConfigOper()
|
||||
downloaders_config = system_config_oper.get(SystemConfigKey.Downloaders)
|
||||
if downloaders_config:
|
||||
return json.dumps(downloaders_config, ensure_ascii=False, indent=2)
|
||||
return "未配置下载器。"
|
||||
except Exception as e:
|
||||
logger.error(f"查询下载器失败: {e}")
|
||||
return f"查询下载器时发生错误: {str(e)}"
|
||||
80
app/agent/tools/impl/query_downloads.py
Normal file
80
app/agent/tools/impl/query_downloads.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""查询下载工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.download import DownloadChain
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QueryDownloadsInput(BaseModel):
|
||||
"""查询下载工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
downloader: Optional[str] = Field(None,
|
||||
description="Name of specific downloader to query (optional, if not provided queries all configured downloaders)")
|
||||
status: Optional[str] = Field("all",
|
||||
description="Filter downloads by status: 'downloading' for active downloads, 'completed' for finished downloads, 'paused' for paused downloads, 'all' for all downloads")
|
||||
|
||||
|
||||
class QueryDownloadsTool(MoviePilotTool):
|
||||
name: str = "query_downloads"
|
||||
description: str = "Query download status and list all active download tasks. Shows download progress, completion status, and task details from configured downloaders."
|
||||
args_schema: Type[BaseModel] = QueryDownloadsInput
|
||||
|
||||
async def run(self, downloader: Optional[str] = None,
|
||||
status: Optional[str] = "all", **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: downloader={downloader}, status={status}")
|
||||
try:
|
||||
download_chain = DownloadChain()
|
||||
# 使用 DownloadChain.downloading 方法获取正在下载的任务
|
||||
downloads = download_chain.downloading(name=downloader)
|
||||
filtered_downloads = []
|
||||
for dl in downloads:
|
||||
if downloader and dl.downloader != downloader:
|
||||
continue
|
||||
if status != "all" and dl.status != status:
|
||||
continue
|
||||
filtered_downloads.append(dl)
|
||||
if filtered_downloads:
|
||||
# 限制最多20条结果
|
||||
total_count = len(filtered_downloads)
|
||||
limited_downloads = filtered_downloads[:20]
|
||||
# 精简字段,只保留关键信息
|
||||
simplified_downloads = []
|
||||
for d in limited_downloads:
|
||||
simplified = {
|
||||
"downloader": d.downloader,
|
||||
"hash": d.hash,
|
||||
"title": d.title,
|
||||
"name": d.name,
|
||||
"year": d.year,
|
||||
"season_episode": d.season_episode,
|
||||
"size": d.size,
|
||||
"progress": d.progress,
|
||||
"state": d.state,
|
||||
"upspeed": d.upspeed,
|
||||
"dlspeed": d.dlspeed,
|
||||
"left_time": d.left_time
|
||||
}
|
||||
# 精简 media 字段
|
||||
if d.media:
|
||||
simplified["media"] = {
|
||||
"tmdbid": d.media.get("tmdbid"),
|
||||
"type": d.media.get("type"),
|
||||
"title": d.media.get("title"),
|
||||
"season": d.media.get("season"),
|
||||
"episode": d.media.get("episode")
|
||||
}
|
||||
simplified_downloads.append(simplified)
|
||||
result_json = json.dumps(simplified_downloads, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 20:
|
||||
return f"注意:查询结果共找到 {total_count} 条,为节省上下文空间,仅显示前 20 条结果。\n\n{result_json}"
|
||||
return result_json
|
||||
return "未找到相关下载任务"
|
||||
except Exception as e:
|
||||
logger.error(f"查询下载失败: {e}", exc_info=True)
|
||||
return f"查询下载时发生错误: {str(e)}"
|
||||
41
app/agent/tools/impl/query_media_library.py
Normal file
41
app/agent/tools/impl/query_media_library.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""查询媒体库工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, List, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.db.mediaserver_oper import MediaServerOper
|
||||
from app.log import logger
|
||||
from app.schemas import MediaServerItem
|
||||
|
||||
|
||||
class QueryMediaLibraryInput(BaseModel):
|
||||
"""查询媒体库工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
media_type: Optional[str] = Field("all",
|
||||
description="Type of media content: '电影' for films, '电视剧' for television series or anime series, 'all' for all types")
|
||||
title: Optional[str] = Field(None,
|
||||
description="Specific media title to check if it exists in the media library (optional, if provided checks for that specific media)")
|
||||
year: Optional[str] = Field(None,
|
||||
description="Release year of the media (optional, helps narrow down search results)")
|
||||
|
||||
|
||||
class QueryMediaLibraryTool(MoviePilotTool):
|
||||
name: str = "query_media_library"
|
||||
description: str = "Check if a specific media resource already exists in the media library (Plex, Emby, Jellyfin). Use this tool to verify whether a movie or TV series has been successfully processed and added to the media server before performing operations like downloading or subscribing."
|
||||
args_schema: Type[BaseModel] = QueryMediaLibraryInput
|
||||
|
||||
async def run(self, media_type: Optional[str] = "all",
|
||||
title: Optional[str] = None, year: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: media_type={media_type}, title={title}")
|
||||
try:
|
||||
media_server_oper = MediaServerOper()
|
||||
filtered_medias: List[MediaServerItem] = await media_server_oper.async_exists(title=title, year=year, mtype=media_type)
|
||||
if filtered_medias:
|
||||
return json.dumps([m.to_dict() for m in filtered_medias])
|
||||
return "媒体库中未找到相关媒体"
|
||||
except Exception as e:
|
||||
logger.error(f"查询媒体库失败: {e}", exc_info=True)
|
||||
return f"查询媒体库时发生错误: {str(e)}"
|
||||
66
app/agent/tools/impl/query_sites.py
Normal file
66
app/agent/tools/impl/query_sites.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""查询站点工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.db.site_oper import SiteOper
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QuerySitesInput(BaseModel):
|
||||
"""查询站点工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
status: Optional[str] = Field("all",
|
||||
description="Filter sites by status: 'active' for enabled sites, 'inactive' for disabled sites, 'all' for all sites")
|
||||
name: Optional[str] = Field(None,
|
||||
description="Filter sites by name (partial match, optional)")
|
||||
|
||||
|
||||
class QuerySitesTool(MoviePilotTool):
|
||||
name: str = "query_sites"
|
||||
description: str = "Query site status and list all configured sites. Shows site name, domain, status, priority, and basic configuration."
|
||||
args_schema: Type[BaseModel] = QuerySitesInput
|
||||
|
||||
async def run(self, status: Optional[str] = "all", name: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: status={status}, name={name}")
|
||||
try:
|
||||
site_oper = SiteOper()
|
||||
# 获取所有站点(按优先级排序)
|
||||
sites = await site_oper.async_list()
|
||||
filtered_sites = []
|
||||
for site in sites:
|
||||
# 按状态过滤
|
||||
if status == "active" and not site.is_active:
|
||||
continue
|
||||
if status == "inactive" and site.is_active:
|
||||
continue
|
||||
# 按名称过滤(部分匹配)
|
||||
if name and name.lower() not in (site.name or "").lower():
|
||||
continue
|
||||
filtered_sites.append(site)
|
||||
if filtered_sites:
|
||||
# 精简字段,只保留关键信息
|
||||
simplified_sites = []
|
||||
for s in filtered_sites:
|
||||
simplified = {
|
||||
"id": s.id,
|
||||
"name": s.name,
|
||||
"domain": s.domain,
|
||||
"url": s.url,
|
||||
"pri": s.pri,
|
||||
"is_active": s.is_active,
|
||||
"downloader": s.downloader,
|
||||
"proxy": s.proxy,
|
||||
"timeout": s.timeout
|
||||
}
|
||||
simplified_sites.append(simplified)
|
||||
result_json = json.dumps(simplified_sites, ensure_ascii=False, indent=2)
|
||||
return result_json
|
||||
return "未找到相关站点"
|
||||
except Exception as e:
|
||||
logger.error(f"查询站点失败: {e}", exc_info=True)
|
||||
return f"查询站点时发生错误: {str(e)}"
|
||||
|
||||
73
app/agent/tools/impl/query_subscribes.py
Normal file
73
app/agent/tools/impl/query_subscribes.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""查询订阅工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.db.subscribe_oper import SubscribeOper
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QuerySubscribesInput(BaseModel):
|
||||
"""查询订阅工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
status: Optional[str] = Field("all",
|
||||
description="Filter subscriptions by status: 'R' for enabled subscriptions, 'P' for disabled ones, 'all' for all subscriptions")
|
||||
media_type: Optional[str] = Field("all",
|
||||
description="Filter by media type: 'movie' for films, 'tv' for television series, 'all' for all types")
|
||||
|
||||
|
||||
class QuerySubscribesTool(MoviePilotTool):
|
||||
name: str = "query_subscribes"
|
||||
description: str = "Query subscription status and list all user subscriptions. Shows active subscriptions, their download status, and configuration details."
|
||||
args_schema: Type[BaseModel] = QuerySubscribesInput
|
||||
|
||||
async def run(self, status: Optional[str] = "all", media_type: Optional[str] = "all", **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: status={status}, media_type={media_type}")
|
||||
try:
|
||||
subscribe_oper = SubscribeOper()
|
||||
subscribes = await subscribe_oper.async_list()
|
||||
filtered_subscribes = []
|
||||
for sub in subscribes:
|
||||
if status != "all" and sub.state != status:
|
||||
continue
|
||||
if media_type != "all" and sub.type != media_type:
|
||||
continue
|
||||
filtered_subscribes.append(sub)
|
||||
if filtered_subscribes:
|
||||
# 限制最多20条结果
|
||||
total_count = len(filtered_subscribes)
|
||||
limited_subscribes = filtered_subscribes[:20]
|
||||
# 精简字段,只保留关键信息
|
||||
simplified_subscribes = []
|
||||
for s in limited_subscribes:
|
||||
simplified = {
|
||||
"id": s.id,
|
||||
"name": s.name,
|
||||
"year": s.year,
|
||||
"type": s.type,
|
||||
"season": s.season,
|
||||
"tmdbid": s.tmdbid,
|
||||
"doubanid": s.doubanid,
|
||||
"bangumiid": s.bangumiid,
|
||||
"poster": s.poster,
|
||||
"vote": s.vote,
|
||||
"description": s.description[:200] + "..." if s.description and len(s.description) > 200 else s.description,
|
||||
"state": s.state,
|
||||
"total_episode": s.total_episode,
|
||||
"lack_episode": s.lack_episode,
|
||||
"last_update": s.last_update,
|
||||
"username": s.username
|
||||
}
|
||||
simplified_subscribes.append(simplified)
|
||||
result_json = json.dumps(simplified_subscribes, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 20:
|
||||
return f"注意:查询结果共找到 {total_count} 条,为节省上下文空间,仅显示前 20 条结果。\n\n{result_json}"
|
||||
return result_json
|
||||
return "未找到相关订阅"
|
||||
except Exception as e:
|
||||
logger.error(f"查询订阅失败: {e}", exc_info=True)
|
||||
return f"查询订阅时发生错误: {str(e)}"
|
||||
96
app/agent/tools/impl/search_media.py
Normal file
96
app/agent/tools/impl/search_media.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""搜索媒体工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.media import MediaChain
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType
|
||||
|
||||
|
||||
class SearchMediaInput(BaseModel):
|
||||
"""搜索媒体工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
title: str = Field(..., description="The title of the media to search for (e.g., 'The Matrix', 'Breaking Bad')")
|
||||
year: Optional[str] = Field(None, description="Release year of the media (optional, helps narrow down results)")
|
||||
media_type: Optional[str] = Field(None,
|
||||
description="Type of media content: '电影' for films, '电视剧' for television series or anime series")
|
||||
season: Optional[int] = Field(None,
|
||||
description="Season number for TV shows and anime (optional, only applicable for series)")
|
||||
|
||||
|
||||
class SearchMediaTool(MoviePilotTool):
|
||||
name: str = "search_media"
|
||||
description: str = "Search for media resources including movies, TV shows, anime, etc. Supports searching by title, year, type, and other criteria. Returns detailed media information from TMDB database."
|
||||
args_schema: Type[BaseModel] = SearchMediaInput
|
||||
|
||||
async def run(self, title: str, year: Optional[str] = None,
|
||||
media_type: Optional[str] = None, season: Optional[int] = None, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}")
|
||||
|
||||
try:
|
||||
media_chain = MediaChain()
|
||||
# 构建搜索标题
|
||||
search_title = title
|
||||
if year:
|
||||
search_title = f"{title} {year}"
|
||||
if media_type:
|
||||
search_title = f"{search_title} {media_type}"
|
||||
if season:
|
||||
search_title = f"{search_title} S{season:02d}"
|
||||
|
||||
# 使用 MediaChain.search 方法
|
||||
meta, results = await media_chain.async_search(title=search_title)
|
||||
|
||||
# 过滤结果
|
||||
if results:
|
||||
filtered_results = []
|
||||
for result in results:
|
||||
if year and result.year != year:
|
||||
continue
|
||||
if media_type:
|
||||
if result.type != MediaType(media_type):
|
||||
continue
|
||||
if season and result.season != season:
|
||||
continue
|
||||
filtered_results.append(result)
|
||||
|
||||
if filtered_results:
|
||||
# 限制最多20条结果
|
||||
total_count = len(filtered_results)
|
||||
limited_results = filtered_results[:20]
|
||||
# 精简字段,只保留关键信息
|
||||
simplified_results = []
|
||||
for r in limited_results:
|
||||
simplified = {
|
||||
"title": r.title,
|
||||
"en_title": r.en_title,
|
||||
"year": r.year,
|
||||
"type": r.type.value if r.type else None,
|
||||
"season": r.season,
|
||||
"tmdb_id": r.tmdb_id,
|
||||
"imdb_id": r.imdb_id,
|
||||
"douban_id": r.douban_id,
|
||||
"overview": r.overview[:200] + "..." if r.overview and len(r.overview) > 200 else r.overview,
|
||||
"vote_average": r.vote_average,
|
||||
"poster_path": r.poster_path,
|
||||
"detail_link": r.detail_link
|
||||
}
|
||||
simplified_results.append(simplified)
|
||||
result_json = json.dumps(simplified_results, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 20:
|
||||
return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 20 条结果。\n\n{result_json}"
|
||||
return result_json
|
||||
else:
|
||||
return f"未找到符合条件的媒体资源: {title}"
|
||||
else:
|
||||
return f"未找到相关媒体资源: {title}"
|
||||
except Exception as e:
|
||||
error_message = f"搜索媒体失败: {str(e)}"
|
||||
logger.error(f"搜索媒体失败: {e}", exc_info=True)
|
||||
return error_message
|
||||
122
app/agent/tools/impl/search_torrents.py
Normal file
122
app/agent/tools/impl/search_torrents.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""搜索种子工具"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.search import SearchChain
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType
|
||||
|
||||
|
||||
class SearchTorrentsInput(BaseModel):
|
||||
"""搜索种子工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
title: str = Field(...,
|
||||
description="The title of the media resource to search for (e.g., 'The Matrix 1999', 'Breaking Bad S01E01')")
|
||||
year: Optional[str] = Field(None,
|
||||
description="Release year of the media (optional, helps narrow down search results)")
|
||||
media_type: Optional[str] = Field(None,
|
||||
description="Type of media content: '电影' for films, '电视剧' for television series or anime series")
|
||||
season: Optional[int] = Field(None, description="Season number for TV shows (optional, only applicable for series)")
|
||||
sites: Optional[List[int]] = Field(None,
|
||||
description="Array of specific site IDs to search on (optional, if not provided searches all configured sites)")
|
||||
filter_pattern: Optional[str] = Field(None,
|
||||
description="Regular expression pattern to filter torrent titles by resolution, quality, or other keywords (e.g., '4K|2160p|UHD' for 4K content, '1080p|BluRay' for 1080p BluRay)")
|
||||
|
||||
|
||||
class SearchTorrentsTool(MoviePilotTool):
|
||||
name: str = "search_torrents"
|
||||
description: str = "Search for torrent files across configured indexer sites based on media information. Returns available torrent downloads with details like file size, quality, and download links."
|
||||
args_schema: Type[BaseModel] = SearchTorrentsInput
|
||||
|
||||
async def run(self, title: str, year: Optional[str] = None,
|
||||
media_type: Optional[str] = None, season: Optional[int] = None,
|
||||
sites: Optional[List[int]] = None, filter_pattern: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}, sites={sites}, filter_pattern={filter_pattern}")
|
||||
|
||||
try:
|
||||
search_chain = SearchChain()
|
||||
torrents = await search_chain.async_search_by_title(title=title, sites=sites)
|
||||
filtered_torrents = []
|
||||
# 编译正则表达式(如果提供)
|
||||
regex_pattern = None
|
||||
if filter_pattern:
|
||||
try:
|
||||
regex_pattern = re.compile(filter_pattern, re.IGNORECASE)
|
||||
except re.error as e:
|
||||
logger.warning(f"正则表达式编译失败: {filter_pattern}, 错误: {e}")
|
||||
return f"正则表达式格式错误: {str(e)}"
|
||||
|
||||
for torrent in torrents:
|
||||
# torrent 是 Context 对象,需要通过 meta_info 和 media_info 访问属性
|
||||
if year and torrent.meta_info and torrent.meta_info.year != year:
|
||||
continue
|
||||
if media_type and torrent.media_info:
|
||||
if torrent.media_info.type != MediaType(media_type):
|
||||
continue
|
||||
if season and torrent.meta_info and torrent.meta_info.begin_season != season:
|
||||
continue
|
||||
# 使用正则表达式过滤标题(分辨率、质量等关键字)
|
||||
if regex_pattern and torrent.torrent_info and torrent.torrent_info.title:
|
||||
if not regex_pattern.search(torrent.torrent_info.title):
|
||||
continue
|
||||
filtered_torrents.append(torrent)
|
||||
|
||||
if filtered_torrents:
|
||||
# 限制最多50条结果
|
||||
total_count = len(filtered_torrents)
|
||||
limited_torrents = filtered_torrents[:50]
|
||||
# 精简字段,只保留关键信息
|
||||
simplified_torrents = []
|
||||
for t in limited_torrents:
|
||||
simplified = {}
|
||||
# 精简 torrent_info
|
||||
if t.torrent_info:
|
||||
simplified["torrent_info"] = {
|
||||
"title": t.torrent_info.title,
|
||||
"size": t.torrent_info.size,
|
||||
"seeders": t.torrent_info.seeders,
|
||||
"peers": t.torrent_info.peers,
|
||||
"site_name": t.torrent_info.site_name,
|
||||
"enclosure": t.torrent_info.enclosure,
|
||||
"page_url": t.torrent_info.page_url,
|
||||
"volume_factor": t.torrent_info.volume_factor,
|
||||
"pubdate": t.torrent_info.pubdate
|
||||
}
|
||||
# 精简 media_info
|
||||
if t.media_info:
|
||||
simplified["media_info"] = {
|
||||
"title": t.media_info.title,
|
||||
"en_title": t.media_info.en_title,
|
||||
"year": t.media_info.year,
|
||||
"type": t.media_info.type.value if t.media_info.type else None,
|
||||
"season": t.media_info.season,
|
||||
"tmdb_id": t.media_info.tmdb_id
|
||||
}
|
||||
# 精简 meta_info
|
||||
if t.meta_info:
|
||||
simplified["meta_info"] = {
|
||||
"name": t.meta_info.name,
|
||||
"cn_name": t.meta_info.cn_name,
|
||||
"en_name": t.meta_info.en_name,
|
||||
"year": t.meta_info.year,
|
||||
"type": t.meta_info.type.value if t.meta_info.type else None,
|
||||
"begin_season": t.meta_info.begin_season
|
||||
}
|
||||
simplified_torrents.append(simplified)
|
||||
result_json = json.dumps(simplified_torrents, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 50:
|
||||
return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 50 条结果。\n\n{result_json}"
|
||||
return result_json
|
||||
else:
|
||||
return f"未找到相关种子资源: {title}"
|
||||
except Exception as e:
|
||||
error_message = f"搜索种子时发生错误: {str(e)}"
|
||||
logger.error(f"搜索种子失败: {e}", exc_info=True)
|
||||
return error_message
|
||||
31
app/agent/tools/impl/send_message.py
Normal file
31
app/agent/tools/impl/send_message.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""发送消息工具"""
|
||||
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class SendMessageInput(BaseModel):
|
||||
"""发送消息工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
message: str = Field(..., description="The message content to send to the user (should be clear and informative)")
|
||||
message_type: Optional[str] = Field("info",
|
||||
description="Type of message: 'info' for general information, 'success' for successful operations, 'warning' for warnings, 'error' for error messages")
|
||||
|
||||
|
||||
class SendMessageTool(MoviePilotTool):
|
||||
name: str = "send_message"
|
||||
description: str = "Send notification message to the user through configured notification channels (Telegram, Slack, WeChat, etc.). Used to inform users about operation results, errors, or important updates."
|
||||
args_schema: Type[BaseModel] = SendMessageInput
|
||||
|
||||
async def run(self, message: str, message_type: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: message={message}, message_type={message_type}")
|
||||
try:
|
||||
await self.send_tool_message(message, title=message_type)
|
||||
return "消息已发送"
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e}")
|
||||
return f"发送消息时发生错误: {str(e)}"
|
||||
@@ -137,7 +137,7 @@ async def transfer(days: Optional[int] = 7,
|
||||
return [stat[1] for stat in transfer_stat]
|
||||
|
||||
|
||||
@router.get("/cpu", summary="获取当前CPU使用率", response_model=int)
|
||||
@router.get("/cpu", summary="获取当前CPU使用率", response_model=float)
|
||||
def cpu(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
获取当前CPU使用率
|
||||
@@ -145,7 +145,7 @@ def cpu(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
return SystemUtils.cpu_usage()
|
||||
|
||||
|
||||
@router.get("/cpu2", summary="获取当前CPU使用率(API_TOKEN)", response_model=int)
|
||||
@router.get("/cpu2", summary="获取当前CPU使用率(API_TOKEN)", response_model=float)
|
||||
def cpu2(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
|
||||
"""
|
||||
获取当前CPU使用率 API_TOKEN认证(?token=xxx)
|
||||
|
||||
@@ -40,10 +40,10 @@ def download(
|
||||
metainfo = MetaInfo(title=torrent_in.title, subtitle=torrent_in.description)
|
||||
# 媒体信息
|
||||
mediainfo = MediaInfo()
|
||||
mediainfo.from_dict(media_in.dict())
|
||||
mediainfo.from_dict(media_in.model_dump())
|
||||
# 种子信息
|
||||
torrentinfo = TorrentInfo()
|
||||
torrentinfo.from_dict(torrent_in.dict())
|
||||
torrentinfo.from_dict(torrent_in.model_dump())
|
||||
# 手动下载始终使用选择的下载器
|
||||
torrentinfo.site_downloader = downloader
|
||||
# 上下文
|
||||
@@ -64,6 +64,9 @@ def download(
|
||||
@router.post("/add", summary="添加下载(不含媒体信息)", response_model=schemas.Response)
|
||||
def add(
|
||||
torrent_in: schemas.TorrentInfo,
|
||||
tmdbid: Annotated[int | None, Body()] = None,
|
||||
doubanid: Annotated[str | None, Body()] = None,
|
||||
bangumiid: Annotated[int | None, Body()] = None,
|
||||
downloader: Annotated[str | None, Body()] = None,
|
||||
save_path: Annotated[str | None, Body()] = None,
|
||||
current_user: User = Depends(get_current_active_user)) -> Any:
|
||||
@@ -73,12 +76,12 @@ def add(
|
||||
# 元数据
|
||||
metainfo = MetaInfo(title=torrent_in.title, subtitle=torrent_in.description)
|
||||
# 媒体信息
|
||||
mediainfo = MediaChain().recognize_media(meta=metainfo)
|
||||
mediainfo = MediaChain().recognize_media(meta=metainfo, tmdbid=tmdbid, doubanid=doubanid, bangumiid=bangumiid)
|
||||
if not mediainfo:
|
||||
return schemas.Response(success=False, message="无法识别媒体信息")
|
||||
# 种子信息
|
||||
torrentinfo = TorrentInfo()
|
||||
torrentinfo.from_dict(torrent_in.dict())
|
||||
torrentinfo.from_dict(torrent_in.model_dump())
|
||||
# 上下文
|
||||
context = Context(
|
||||
meta_info=metainfo,
|
||||
|
||||
@@ -14,7 +14,7 @@ from app.db.models import User
|
||||
from app.db.models.downloadhistory import DownloadHistory
|
||||
from app.db.models.transferhistory import TransferHistory
|
||||
from app.db.user_oper import get_current_active_superuser_async, get_current_active_superuser
|
||||
from app.schemas.types import EventType, MediaType
|
||||
from app.schemas.types import EventType
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -70,7 +70,7 @@ async def transfer_history(title: Optional[str] = None,
|
||||
|
||||
return schemas.Response(success=True,
|
||||
data={
|
||||
"list": result,
|
||||
"list": [item.to_dict() for item in result],
|
||||
"total": total,
|
||||
})
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ def exists(media_in: schemas.MediaInfo,
|
||||
"""
|
||||
# 转化为媒体信息对象
|
||||
mediainfo = MediaInfo()
|
||||
mediainfo.from_dict(media_in.dict())
|
||||
mediainfo.from_dict(media_in.model_dump())
|
||||
existsinfo: schemas.ExistMediaInfo = MediaServerChain().media_exists(mediainfo=mediainfo)
|
||||
if not existsinfo:
|
||||
return []
|
||||
@@ -108,7 +108,7 @@ def not_exists(media_in: schemas.MediaInfo,
|
||||
meta.year = media_in.year
|
||||
# 转化为媒体信息对象
|
||||
mediainfo = MediaInfo()
|
||||
mediainfo.from_dict(media_in.dict())
|
||||
mediainfo.from_dict(media_in.model_dump())
|
||||
exist_flag, no_exists = DownloadChain().get_no_exists_info(meta=meta, mediainfo=mediainfo)
|
||||
mediakey = mediainfo.tmdb_id or mediainfo.douban_id
|
||||
if mediainfo.type == MediaType.MOVIE:
|
||||
|
||||
@@ -132,7 +132,7 @@ async def subscribe(subscription: schemas.Subscription, _: schemas.TokenPayload
|
||||
"""
|
||||
客户端webpush通知订阅
|
||||
"""
|
||||
subinfo = subscription.dict()
|
||||
subinfo = subscription.model_dump()
|
||||
if subinfo not in global_vars.get_subscriptions():
|
||||
global_vars.push_subscription(subinfo)
|
||||
logger.debug(f"通知订阅成功: {subinfo}")
|
||||
@@ -148,7 +148,7 @@ def send_notification(payload: schemas.SubscriptionMessage, _: schemas.TokenPayl
|
||||
try:
|
||||
webpush(
|
||||
subscription_info=sub,
|
||||
data=json.dumps(payload.dict()),
|
||||
data=json.dumps(payload.model_dump()),
|
||||
vapid_private_key=settings.VAPID.get("privateKey"),
|
||||
vapid_claims={
|
||||
"sub": settings.VAPID.get("subject")
|
||||
|
||||
@@ -67,7 +67,7 @@ async def add_site(
|
||||
site_in.name = site_info.get("name")
|
||||
site_in.id = None
|
||||
site_in.public = 1 if site_info.get("public") else 0
|
||||
site = Site(**site_in.dict())
|
||||
site = Site(**site_in.model_dump())
|
||||
site.create(db)
|
||||
# 通知站点更新
|
||||
await eventmanager.async_send_event(EventType.SiteUpdated, {
|
||||
@@ -92,7 +92,7 @@ async def update_site(
|
||||
# 校正地址格式
|
||||
_scheme, _netloc = StringUtils.get_url_netloc(site_in.url)
|
||||
site_in.url = f"{_scheme}://{_netloc}/"
|
||||
await site.async_update(db, site_in.dict())
|
||||
await site.async_update(db, site_in.model_dump())
|
||||
# 通知站点更新
|
||||
await eventmanager.async_send_event(EventType.SiteUpdated, {
|
||||
"domain": site_in.domain
|
||||
@@ -399,7 +399,7 @@ def auth_site(
|
||||
if not auth_info or not auth_info.site or not auth_info.params:
|
||||
return schemas.Response(success=False, message="请输入认证站点和认证参数")
|
||||
status, msg = SitesHelper().check_user(auth_info.site, auth_info.params)
|
||||
SystemConfigOper().set(SystemConfigKey.UserSiteAuthParams, auth_info.dict())
|
||||
SystemConfigOper().set(SystemConfigKey.UserSiteAuthParams, auth_info.model_dump())
|
||||
# 认证成功后,重新初始化插件
|
||||
PluginManager().init_config()
|
||||
Scheduler().init_plugin_jobs()
|
||||
|
||||
@@ -79,7 +79,7 @@ async def create_subscribe(
|
||||
# 订阅用户
|
||||
subscribe_in.username = current_user.name
|
||||
# 转化为字典
|
||||
subscribe_dict = subscribe_in.dict()
|
||||
subscribe_dict = subscribe_in.model_dump()
|
||||
if subscribe_in.id:
|
||||
subscribe_dict.pop("id", None)
|
||||
sid, message = await SubscribeChain().async_add(mtype=mtype,
|
||||
@@ -106,7 +106,7 @@ async def update_subscribe(
|
||||
return schemas.Response(success=False, message="订阅不存在")
|
||||
# 避免更新缺失集数
|
||||
old_subscribe_dict = subscribe.to_dict()
|
||||
subscribe_dict = subscribe_in.dict()
|
||||
subscribe_dict = subscribe_in.model_dump()
|
||||
if not subscribe_in.lack_episode:
|
||||
# 没有缺失集数时,缺失集数清空,避免更新为0
|
||||
subscribe_dict.pop("lack_episode")
|
||||
@@ -529,7 +529,7 @@ async def subscribe_fork(
|
||||
"""
|
||||
复用订阅
|
||||
"""
|
||||
sub_dict = sub.dict()
|
||||
sub_dict = sub.model_dump()
|
||||
sub_dict.pop("id")
|
||||
for key in list(sub_dict.keys()):
|
||||
if not hasattr(schemas.Subscribe(), key):
|
||||
|
||||
@@ -16,6 +16,7 @@ from fastapi import APIRouter, Body, Depends, HTTPException, Header, Request, Re
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from app import schemas
|
||||
from app.chain.mediaserver import MediaServerChain
|
||||
from app.chain.search import SearchChain
|
||||
from app.chain.system import SystemChain
|
||||
from app.core.cache import AsyncFileCache
|
||||
@@ -52,6 +53,7 @@ async def fetch_image(
|
||||
proxy: bool = False,
|
||||
use_cache: bool = False,
|
||||
if_none_match: Optional[str] = None,
|
||||
cookies: Optional[str | dict] = None,
|
||||
allowed_domains: Optional[set[str]] = None) -> Optional[Response]:
|
||||
"""
|
||||
处理图片缓存逻辑,支持HTTP缓存和磁盘缓存
|
||||
@@ -96,8 +98,13 @@ async def fetch_image(
|
||||
# 请求远程图片
|
||||
referer = "https://movie.douban.com/" if "doubanio.com" in url else None
|
||||
proxies = settings.PROXY if proxy else None
|
||||
response = await AsyncRequestUtils(ua=settings.NORMAL_USER_AGENT, proxies=proxies, referer=referer,
|
||||
accept_type="image/avif,image/webp,image/apng,*/*").get_res(url=url)
|
||||
response = await AsyncRequestUtils(
|
||||
ua=settings.NORMAL_USER_AGENT,
|
||||
proxies=proxies,
|
||||
referer=referer,
|
||||
cookies=cookies,
|
||||
accept_type="image/avif,image/webp,image/apng,*/*",
|
||||
).get_res(url=url)
|
||||
if not response:
|
||||
logger.warn(f"Failed to fetch image from URL: {url}")
|
||||
return None
|
||||
@@ -140,6 +147,7 @@ async def proxy_img(
|
||||
imgurl: str,
|
||||
proxy: bool = False,
|
||||
cache: bool = False,
|
||||
use_cookies: bool = False,
|
||||
if_none_match: Annotated[str | None, Header()] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token)
|
||||
) -> Response:
|
||||
@@ -150,7 +158,12 @@ async def proxy_img(
|
||||
hosts = [config.config.get("host") for config in MediaServerHelper().get_configs().values() if
|
||||
config and config.config and config.config.get("host")]
|
||||
allowed_domains = set(settings.SECURITY_IMAGE_DOMAINS) | set(hosts)
|
||||
return await fetch_image(url=imgurl, proxy=proxy, use_cache=cache,
|
||||
cookies = (
|
||||
MediaServerChain().get_image_cookies(server=None, image_url=imgurl)
|
||||
if use_cookies
|
||||
else None
|
||||
)
|
||||
return await fetch_image(url=imgurl, proxy=proxy, use_cache=cache, cookies=cookies,
|
||||
if_none_match=if_none_match, allowed_domains=allowed_domains)
|
||||
|
||||
|
||||
@@ -178,7 +191,7 @@ def get_global_setting(token: str):
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
|
||||
# FIXME: 新增敏感配置项时要在此处添加排除项
|
||||
info = settings.dict(
|
||||
info = settings.model_dump(
|
||||
exclude={"SECRET_KEY", "RESOURCE_SECRET_KEY", "API_TOKEN", "TMDB_API_KEY", "TVDB_API_KEY", "FANART_API_KEY",
|
||||
"COOKIECLOUD_KEY", "COOKIECLOUD_PASSWORD", "GITHUB_TOKEN", "REPO_GITHUB_TOKEN", "U115_APP_ID",
|
||||
"ALIPAN_APP_ID", "TVDB_V4_API_KEY", "TVDB_V4_API_PIN"}
|
||||
@@ -199,7 +212,7 @@ async def get_env_setting(_: User = Depends(get_current_active_user_async)):
|
||||
"""
|
||||
查询系统环境变量,包括当前版本号(仅管理员)
|
||||
"""
|
||||
info = settings.dict(
|
||||
info = settings.model_dump(
|
||||
exclude={"SECRET_KEY", "RESOURCE_SECRET_KEY"}
|
||||
)
|
||||
info.update({
|
||||
|
||||
@@ -41,7 +41,7 @@ async def create_user(
|
||||
user = await current_user.async_get_by_name(db, name=user_in.name)
|
||||
if user:
|
||||
return schemas.Response(success=False, message="用户已存在")
|
||||
user_info = user_in.dict()
|
||||
user_info = user_in.model_dump()
|
||||
if user_info.get("password"):
|
||||
user_info["hashed_password"] = get_password_hash(user_info["password"])
|
||||
user_info.pop("password")
|
||||
@@ -59,7 +59,7 @@ async def update_user(
|
||||
"""
|
||||
更新用户
|
||||
"""
|
||||
user_info = user_in.dict()
|
||||
user_info = user_in.model_dump()
|
||||
if user_info.get("password"):
|
||||
# 正则表达式匹配密码包含字母、数字、特殊字符中的至少两项
|
||||
pattern = r'^(?![a-zA-Z]+$)(?!\d+$)(?![^\da-zA-Z\s]+$).{6,50}$'
|
||||
|
||||
@@ -11,7 +11,7 @@ from app.chain.workflow import WorkflowChain
|
||||
from app.core.config import global_vars
|
||||
from app.core.plugin import PluginManager
|
||||
from app.core.security import verify_token
|
||||
from app.core.workflow import WorkFlowManager
|
||||
from app.workflow import WorkFlowManager
|
||||
from app.db import get_async_db, get_db
|
||||
from app.db.models import Workflow
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
@@ -47,7 +47,7 @@ async def create_workflow(workflow: schemas.Workflow,
|
||||
workflow.state = "P"
|
||||
if not workflow.trigger_type:
|
||||
workflow.trigger_type = "timer"
|
||||
workflow_obj = Workflow(**workflow.dict())
|
||||
workflow_obj = Workflow(**workflow.model_dump())
|
||||
await workflow_obj.async_create(db)
|
||||
return schemas.Response(success=True, message="创建工作流成功")
|
||||
|
||||
@@ -277,7 +277,7 @@ def update_workflow(workflow: schemas.Workflow,
|
||||
return schemas.Response(success=False, message="工作流不存在")
|
||||
if not wf.trigger_type:
|
||||
workflow.trigger_type = "timer"
|
||||
wf.update(db, workflow.dict())
|
||||
wf.update(db, workflow.model_dump())
|
||||
# 更新后的工作流对象
|
||||
updated_workflow = workflow_oper.get(workflow.id)
|
||||
# 更新定时任务
|
||||
|
||||
@@ -11,7 +11,7 @@ from fastapi.concurrency import run_in_threadpool
|
||||
from qbittorrentapi import TorrentFilesList
|
||||
from transmission_rpc import File
|
||||
|
||||
from app.core.cache import FileCache, AsyncFileCache
|
||||
from app.core.cache import FileCache, AsyncFileCache, fresh, async_fresh
|
||||
from app.core.config import settings
|
||||
from app.core.context import Context, MediaInfo, TorrentInfo
|
||||
from app.core.event import EventManager
|
||||
@@ -358,9 +358,10 @@ class ChainBase(metaclass=ABCMeta):
|
||||
if tmdbid:
|
||||
doubanid = None
|
||||
bangumiid = None
|
||||
return self.run_module("recognize_media", meta=meta, mtype=mtype,
|
||||
tmdbid=tmdbid, doubanid=doubanid, bangumiid=bangumiid,
|
||||
episode_group=episode_group, cache=cache)
|
||||
with fresh(not cache):
|
||||
return self.run_module("recognize_media", meta=meta, mtype=mtype,
|
||||
tmdbid=tmdbid, doubanid=doubanid, bangumiid=bangumiid,
|
||||
episode_group=episode_group, cache=cache)
|
||||
|
||||
async def async_recognize_media(self, meta: MetaBase = None,
|
||||
mtype: Optional[MediaType] = None,
|
||||
@@ -391,9 +392,10 @@ class ChainBase(metaclass=ABCMeta):
|
||||
if tmdbid:
|
||||
doubanid = None
|
||||
bangumiid = None
|
||||
return await self.async_run_module("async_recognize_media", meta=meta, mtype=mtype,
|
||||
tmdbid=tmdbid, doubanid=doubanid, bangumiid=bangumiid,
|
||||
episode_group=episode_group, cache=cache)
|
||||
async with async_fresh(not cache):
|
||||
return await self.async_run_module("async_recognize_media", meta=meta, mtype=mtype,
|
||||
tmdbid=tmdbid, doubanid=doubanid, bangumiid=bangumiid,
|
||||
episode_group=episode_group, cache=cache)
|
||||
|
||||
def match_doubaninfo(self, name: str, imdbid: Optional[str] = None,
|
||||
mtype: Optional[MediaType] = None, year: Optional[str] = None, season: Optional[int] = None,
|
||||
@@ -850,9 +852,13 @@ class ChainBase(metaclass=ABCMeta):
|
||||
# 渲染消息
|
||||
message = MessageTemplateHelper.render(message=message, meta=meta, mediainfo=mediainfo,
|
||||
torrentinfo=torrentinfo, transferinfo=transferinfo, **kwargs)
|
||||
# 检查消息是否有效
|
||||
if not message:
|
||||
logger.warning("消息为空,跳过发送")
|
||||
return
|
||||
# 保存消息
|
||||
self.messagehelper.put(message, role="user", title=message.title)
|
||||
self.messageoper.add(**message.dict())
|
||||
self.messageoper.add(**message.model_dump())
|
||||
# 发送消息按设置隔离
|
||||
if not message.userid and message.mtype:
|
||||
# 消息隔离设置
|
||||
@@ -899,15 +905,15 @@ class ChainBase(metaclass=ABCMeta):
|
||||
break
|
||||
# 按设定发送
|
||||
self.eventmanager.send_event(etype=EventType.NoticeMessage,
|
||||
data={**send_message.dict(), "type": send_message.mtype})
|
||||
self.messagequeue.send_message("post_message", message=send_message)
|
||||
data={**send_message.model_dump(), "type": send_message.mtype})
|
||||
self.messagequeue.send_message("post_message", message=send_message, **kwargs)
|
||||
if not send_orignal:
|
||||
return
|
||||
# 发送消息事件
|
||||
self.eventmanager.send_event(etype=EventType.NoticeMessage, data={**message.dict(), "type": message.mtype})
|
||||
self.eventmanager.send_event(etype=EventType.NoticeMessage, data={**message.model_dump(), "type": message.mtype})
|
||||
# 按原消息发送
|
||||
self.messagequeue.send_message("post_message", message=message,
|
||||
immediately=True if message.userid else False)
|
||||
immediately=True if message.userid else False, **kwargs)
|
||||
|
||||
async def async_post_message(self,
|
||||
message: Optional[Notification] = None,
|
||||
@@ -929,9 +935,13 @@ class ChainBase(metaclass=ABCMeta):
|
||||
# 渲染消息
|
||||
message = MessageTemplateHelper.render(message=message, meta=meta, mediainfo=mediainfo,
|
||||
torrentinfo=torrentinfo, transferinfo=transferinfo, **kwargs)
|
||||
# 检查消息是否有效
|
||||
if not message:
|
||||
logger.warning("消息为空,跳过发送")
|
||||
return
|
||||
# 保存消息
|
||||
self.messagehelper.put(message, role="user", title=message.title)
|
||||
await self.messageoper.async_add(**message.dict())
|
||||
await self.messageoper.async_add(**message.model_dump())
|
||||
# 发送消息按设置隔离
|
||||
if not message.userid and message.mtype:
|
||||
# 消息隔离设置
|
||||
@@ -978,16 +988,16 @@ class ChainBase(metaclass=ABCMeta):
|
||||
break
|
||||
# 按设定发送
|
||||
await self.eventmanager.async_send_event(etype=EventType.NoticeMessage,
|
||||
data={**send_message.dict(), "type": send_message.mtype})
|
||||
await self.messagequeue.async_send_message("post_message", message=send_message)
|
||||
data={**send_message.model_dump(), "type": send_message.mtype})
|
||||
await self.messagequeue.async_send_message("post_message", message=send_message, **kwargs)
|
||||
if not send_orignal:
|
||||
return
|
||||
# 发送消息事件
|
||||
await self.eventmanager.async_send_event(etype=EventType.NoticeMessage,
|
||||
data={**message.dict(), "type": message.mtype})
|
||||
data={**message.model_dump(), "type": message.mtype})
|
||||
# 按原消息发送
|
||||
await self.messagequeue.async_send_message("post_message", message=message,
|
||||
immediately=True if message.userid else False)
|
||||
immediately=True if message.userid else False, **kwargs)
|
||||
|
||||
def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None:
|
||||
"""
|
||||
@@ -998,7 +1008,7 @@ class ChainBase(metaclass=ABCMeta):
|
||||
"""
|
||||
note_list = [media.to_dict() for media in medias]
|
||||
self.messagehelper.put(message, role="user", note=note_list, title=message.title)
|
||||
self.messageoper.add(**message.dict(), note=note_list)
|
||||
self.messageoper.add(**message.model_dump(), note=note_list)
|
||||
return self.messagequeue.send_message("post_medias_message", message=message, medias=medias,
|
||||
immediately=True if message.userid else False)
|
||||
|
||||
@@ -1011,7 +1021,7 @@ class ChainBase(metaclass=ABCMeta):
|
||||
"""
|
||||
note_list = [torrent.torrent_info.to_dict() for torrent in torrents]
|
||||
self.messagehelper.put(message, role="user", note=note_list, title=message.title)
|
||||
self.messageoper.add(**message.dict(), note=note_list)
|
||||
self.messageoper.add(**message.model_dump(), note=note_list)
|
||||
return self.messagequeue.send_message("post_torrents_message", message=message, torrents=torrents,
|
||||
immediately=True if message.userid else False)
|
||||
|
||||
|
||||
@@ -290,7 +290,7 @@ class DownloadChain(ChainBase):
|
||||
# 登记下载记录
|
||||
downloadhis = DownloadHistoryOper()
|
||||
downloadhis.add(
|
||||
path=str(download_path),
|
||||
path=download_path.as_posix(),
|
||||
type=_media.type.value,
|
||||
title=_media.title,
|
||||
year=_media.year,
|
||||
@@ -331,8 +331,8 @@ class DownloadChain(ChainBase):
|
||||
files_to_add.append({
|
||||
"download_hash": _hash,
|
||||
"downloader": _downloader,
|
||||
"fullpath": str(_save_path / file),
|
||||
"savepath": str(_save_path),
|
||||
"fullpath": (_save_path / file).as_posix(),
|
||||
"savepath": _save_path.as_posix(),
|
||||
"filepath": file,
|
||||
"torrentname": _meta.org_string,
|
||||
})
|
||||
@@ -994,7 +994,7 @@ class DownloadChain(ChainBase):
|
||||
# 发出下载任务删除事件,如需处理辅种,可监听该事件
|
||||
self.eventmanager.send_event(EventType.DownloadDeleted, {
|
||||
"hash": hash_str,
|
||||
"torrents": [torrent.dict() for torrent in torrents]
|
||||
"torrents": [torrent.model_dump() for torrent in torrents]
|
||||
})
|
||||
else:
|
||||
logger.info(f"没有在下载器中查询到 {hash_str} 对应的下载任务")
|
||||
|
||||
@@ -113,6 +113,16 @@ class MediaServerChain(ChainBase):
|
||||
"""
|
||||
return self.run_module("mediaserver_play_url", server=server, item_id=item_id)
|
||||
|
||||
def get_image_cookies(
|
||||
self, server: Optional[str], image_url: str
|
||||
) -> Optional[str | dict]:
|
||||
"""
|
||||
获取图片的Cookies
|
||||
"""
|
||||
return self.run_module(
|
||||
"mediaserver_image_cookies", server=server, image_url=image_url
|
||||
)
|
||||
|
||||
def sync(self):
|
||||
"""
|
||||
同步媒体库所有数据到本地数据库
|
||||
@@ -167,7 +177,7 @@ class MediaServerChain(ChainBase):
|
||||
for episode in espisodes_info:
|
||||
seasoninfo[episode.season] = episode.episodes
|
||||
# 插入数据
|
||||
item_dict = item.dict()
|
||||
item_dict = item.model_dump()
|
||||
item_dict["seasoninfo"] = seasoninfo
|
||||
item_dict["item_type"] = item_type
|
||||
dboper.add(**item_dict)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import asyncio
|
||||
import re
|
||||
from typing import Any, Optional, Dict, Union, List
|
||||
|
||||
from app.agent import agent_manager
|
||||
from app.chain import ChainBase
|
||||
from app.chain.download import DownloadChain
|
||||
from app.chain.media import MediaChain
|
||||
@@ -163,6 +165,10 @@ class MessageChain(ChainBase):
|
||||
original_message_id=original_message_id, original_chat_id=original_chat_id)
|
||||
else:
|
||||
logger.warning(f"渠道 {channel.value} 不支持回调,但收到了回调消息:{text}")
|
||||
elif text.startswith('/ai') or text.startswith('/AI'):
|
||||
# AI智能体处理
|
||||
self._handle_ai_message(text=text, channel=channel, source=source,
|
||||
userid=userid, username=username)
|
||||
elif text.startswith('/'):
|
||||
# 执行命令
|
||||
self.eventmanager.send_event(
|
||||
@@ -815,3 +821,86 @@ class MessageChain(ChainBase):
|
||||
buttons.append(page_buttons)
|
||||
|
||||
return buttons
|
||||
|
||||
def _handle_ai_message(self, text: str, channel: MessageChannel, source: str,
|
||||
userid: Union[str, int], username: str) -> None:
|
||||
"""
|
||||
处理AI智能体消息
|
||||
"""
|
||||
try:
|
||||
# 检查AI智能体是否启用
|
||||
if not settings.AI_AGENT_ENABLE:
|
||||
self.post_message(Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title="MoviePilot智能助手未启用,请在系统设置中启用"
|
||||
))
|
||||
return
|
||||
|
||||
# 检查LLM配置
|
||||
if not settings.LLM_API_KEY:
|
||||
self.post_message(Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title="MoviePilot智能助未配置,请在系统设置中配置"
|
||||
))
|
||||
return
|
||||
|
||||
# 提取用户消息
|
||||
user_message = text[3:].strip() # 移除 "/ai" 前缀
|
||||
if not user_message:
|
||||
self.post_message(Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title="请输入您的问题或需求"
|
||||
))
|
||||
return
|
||||
|
||||
# 发送处理中消息
|
||||
self.post_message(Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title="MoviePilot助手已收到您的请求,请稍候..."
|
||||
))
|
||||
|
||||
# 生成会话ID
|
||||
session_id = f"user_{userid}_{hash(user_message) % 10000}"
|
||||
|
||||
# 在事件循环中处理
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(
|
||||
agent_manager.process_message(
|
||||
session_id=session_id,
|
||||
user_id=str(userid),
|
||||
message=user_message,
|
||||
channel=channel.value if channel else None,
|
||||
source=source,
|
||||
username=username
|
||||
)
|
||||
)
|
||||
except RuntimeError:
|
||||
# 如果没有事件循环,创建新的
|
||||
asyncio.run(
|
||||
agent_manager.process_message(
|
||||
session_id=session_id,
|
||||
user_id=str(userid),
|
||||
message=user_message,
|
||||
channel=channel.value if channel else None,
|
||||
source=source,
|
||||
username=username
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理AI智能体消息失败: {e}")
|
||||
self.messagehelper.put(f"AI智能体处理失败: {str(e)}", role="system", title="MoviePilot助手")
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ class SiteChain(ChainBase):
|
||||
if userdata:
|
||||
SiteOper().update_userdata(domain=StringUtils.get_url_domain(site.get("domain")),
|
||||
name=site.get("name"),
|
||||
payload=userdata.dict())
|
||||
payload=userdata.model_dump())
|
||||
# 发送事件
|
||||
eventmanager.send_event(EventType.SiteRefreshed, {
|
||||
"site_id": site.get("id")
|
||||
|
||||
@@ -150,7 +150,7 @@ class TorrentsChain(ChainBase):
|
||||
return []
|
||||
# 解析RSS
|
||||
rss_items = RssHelper().parse(site.get("rss"), True if site.get("proxy") else False,
|
||||
timeout=int(site.get("timeout") or 30))
|
||||
timeout=int(site.get("timeout") or 30), ua=site.get("ua") if site.get("ua") else None)
|
||||
if rss_items is None:
|
||||
# rss过期,尝试保留原配置生成新的rss
|
||||
self.__renew_rss_url(domain=domain, site=site)
|
||||
|
||||
@@ -33,6 +33,7 @@ from app.schemas.types import TorrentStatus, EventType, MediaType, ProgressKey,
|
||||
SystemConfigKey, ChainEventType, ContentType
|
||||
from app.utils.singleton import Singleton
|
||||
from app.utils.string import StringUtils
|
||||
from app.utils.system import SystemUtils
|
||||
|
||||
downloader_lock = threading.Lock()
|
||||
job_lock = threading.Lock()
|
||||
@@ -329,8 +330,12 @@ class JobManager:
|
||||
# 计算状态为完成的任务数
|
||||
if __mediaid__ not in self._job_view:
|
||||
return 0
|
||||
return sum([task.fileitem.size for task in self._job_view[__mediaid__].tasks if
|
||||
task.state == "completed" and task.fileitem.size is not None])
|
||||
return sum([
|
||||
task.fileitem.size if task.fileitem.size is not None
|
||||
else (SystemUtils.get_directory_size(Path(task.fileitem.path)) if task.fileitem.storage == "local" else 0)
|
||||
for task in self._job_view[__mediaid__].tasks
|
||||
if task.state == "completed"
|
||||
])
|
||||
|
||||
def total(self) -> int:
|
||||
"""
|
||||
@@ -1111,6 +1116,7 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
file_meta=file_meta)
|
||||
if begin_ep is not None:
|
||||
file_meta.begin_episode = begin_ep
|
||||
if part is not None:
|
||||
file_meta.part = part
|
||||
if end_ep is not None:
|
||||
file_meta.end_episode = end_ep
|
||||
@@ -1120,10 +1126,10 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
downloadhis = DownloadHistoryOper()
|
||||
if bluray_dir:
|
||||
# 蓝光原盘,按目录名查询
|
||||
download_history = downloadhis.get_by_path(str(file_path))
|
||||
download_history = downloadhis.get_by_path(file_path.as_posix())
|
||||
else:
|
||||
# 按文件全路径查询
|
||||
download_file = downloadhis.get_file_by_fullpath(str(file_path))
|
||||
download_file = downloadhis.get_file_by_fullpath(file_path.as_posix())
|
||||
if download_file:
|
||||
download_history = downloadhis.get_by_hash(download_file.download_hash)
|
||||
|
||||
@@ -1436,7 +1442,7 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
|
||||
for keyword in exclude_words:
|
||||
if keyword and re.search(r"%s" % keyword, file_path, re.IGNORECASE):
|
||||
logger.debug(f"{file_path} 命中屏蔽词 {keyword}")
|
||||
logger.warn(f"{file_path} 命中屏蔽词 {keyword}")
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -1472,7 +1478,7 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
file_path = save_path / file.name
|
||||
# 如果存在未被屏蔽的媒体文件,则不删除种子
|
||||
if (file_path.suffix in self.all_exts
|
||||
and not self._is_blocked_by_exclude_words(str(file_path), transfer_exclude_words)
|
||||
and not self._is_blocked_by_exclude_words(file_path.as_posix(), transfer_exclude_words)
|
||||
and file_path.exists()):
|
||||
return False
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from pydantic.fields import Callable
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import global_vars
|
||||
from app.core.event import Event, eventmanager
|
||||
from app.core.workflow import WorkFlowManager
|
||||
from app.workflow import WorkFlowManager
|
||||
from app.db.models import Workflow
|
||||
from app.db.workflow_oper import WorkflowOper
|
||||
from app.log import logger
|
||||
@@ -180,7 +180,7 @@ class WorkflowExecutor:
|
||||
"""
|
||||
合并上下文
|
||||
"""
|
||||
for key, value in context.dict().items():
|
||||
for key, value in context.model_dump().items():
|
||||
if not getattr(self.context, key, None):
|
||||
setattr(self.context, key, value)
|
||||
|
||||
|
||||
@@ -215,7 +215,7 @@ class Command(metaclass=Singleton):
|
||||
except Exception as e:
|
||||
logger.error(f"Error occurred during command initialization in background: {e}", exc_info=True)
|
||||
|
||||
def __trigger_register_commands_event(self) -> (Optional[Event], dict):
|
||||
def __trigger_register_commands_event(self) -> tuple[Optional[Event], dict]:
|
||||
"""
|
||||
触发事件,允许调整命令数据
|
||||
"""
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import contextvars
|
||||
import inspect
|
||||
import shutil
|
||||
import tempfile
|
||||
import threading
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager, asynccontextmanager
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Generator, AsyncGenerator, Tuple, Literal, Union
|
||||
@@ -27,6 +29,9 @@ DEFAULT_CACHE_TTL = 365 * 24 * 60 * 60
|
||||
|
||||
lock = threading.Lock()
|
||||
|
||||
# 上下文变量来控制缓存行为
|
||||
_fresh = contextvars.ContextVar('fresh', default=False)
|
||||
|
||||
|
||||
class CacheBackend(ABC):
|
||||
"""
|
||||
@@ -1010,6 +1015,49 @@ class AsyncFileBackend(AsyncCacheBackend):
|
||||
pass
|
||||
|
||||
|
||||
@contextmanager
|
||||
def fresh(fresh: bool = True):
|
||||
"""
|
||||
是否获取新数据(不使用缓存的值)
|
||||
|
||||
Usage:
|
||||
with fresh():
|
||||
result = some_cached_function()
|
||||
"""
|
||||
token = _fresh.set(fresh)
|
||||
logger.debug(f"Setting fresh mode to {fresh}. {id(token):#x}")
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_fresh.reset(token)
|
||||
logger.debug(f"Reset fresh mode. {id(token):#x}")
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_fresh(fresh: bool = True):
|
||||
"""
|
||||
是否获取新数据(不使用缓存的值)
|
||||
|
||||
Usage:
|
||||
async with async_fresh():
|
||||
result = await some_async_cached_function()
|
||||
"""
|
||||
token = _fresh.set(fresh)
|
||||
logger.debug(f"Setting async_fresh mode to {fresh}. {id(token):#x}")
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_fresh.reset(token)
|
||||
logger.debug(f"Reset async_fresh mode. {id(token):#x}")
|
||||
|
||||
def is_fresh() -> bool:
|
||||
"""
|
||||
是否获取新数据
|
||||
"""
|
||||
try:
|
||||
return _fresh.get()
|
||||
except LookupError:
|
||||
return False
|
||||
|
||||
def FileCache(base: Path = settings.TEMP_PATH, ttl: Optional[int] = None) -> CacheBackend:
|
||||
"""
|
||||
获取文件缓存后端实例(Redis或文件系统),ttl仅在Redis环境中有效
|
||||
@@ -1084,16 +1132,6 @@ def cached(region: Optional[str] = None, maxsize: Optional[int] = 1024, ttl: Opt
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
# 检查是否为异步函数
|
||||
is_async = inspect.iscoroutinefunction(func)
|
||||
|
||||
# 根据函数类型选择对应的缓存后端,没有ttl时默认是 LRU 缓存,否则是 TTL 缓存
|
||||
if is_async:
|
||||
# 异步函数使用异步缓存后端
|
||||
cache_backend = AsyncCache(cache_type="ttl" if ttl else "lru", maxsize=maxsize, ttl=ttl)
|
||||
else:
|
||||
# 同步函数使用同步缓存后端
|
||||
cache_backend = Cache(cache_type="ttl" if ttl else "lru", maxsize=maxsize, ttl=ttl)
|
||||
|
||||
def should_cache(value: Any) -> bool:
|
||||
"""
|
||||
@@ -1169,16 +1207,20 @@ def cached(region: Optional[str] = None, maxsize: Optional[int] = 1024, ttl: Opt
|
||||
is_async = inspect.iscoroutinefunction(func)
|
||||
|
||||
if is_async:
|
||||
# 异步函数使用异步缓存后端
|
||||
cache_backend = AsyncCache(cache_type="ttl" if ttl else "lru", maxsize=maxsize, ttl=ttl)
|
||||
# 异步函数的缓存装饰器
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
# 获取缓存键
|
||||
cache_key = __get_cache_key(args, kwargs)
|
||||
# 尝试获取缓存
|
||||
cached_value = await cache_backend.get(cache_key, region=cache_region)
|
||||
if should_cache(cached_value) and await async_is_valid_cache_value(cache_key, cached_value,
|
||||
cache_region):
|
||||
return cached_value
|
||||
|
||||
if not is_fresh():
|
||||
# 尝试获取缓存
|
||||
cached_value = await cache_backend.get(cache_key, region=cache_region)
|
||||
if should_cache(cached_value) and await async_is_valid_cache_value(cache_key, cached_value,
|
||||
cache_region):
|
||||
return cached_value
|
||||
# 执行异步函数并缓存结果
|
||||
result = await func(*args, **kwargs)
|
||||
# 判断是否需要缓存
|
||||
@@ -1198,15 +1240,19 @@ def cached(region: Optional[str] = None, maxsize: Optional[int] = 1024, ttl: Opt
|
||||
async_wrapper.cache_clear = cache_clear
|
||||
return async_wrapper
|
||||
else:
|
||||
# 同步函数使用同步缓存后端
|
||||
cache_backend = Cache(cache_type="ttl" if ttl else "lru", maxsize=maxsize, ttl=ttl)
|
||||
# 同步函数的缓存装饰器
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
# 获取缓存键
|
||||
cache_key = __get_cache_key(args, kwargs)
|
||||
# 尝试获取缓存
|
||||
cached_value = cache_backend.get(cache_key, region=cache_region)
|
||||
if should_cache(cached_value) and is_valid_cache_value(cache_key, cached_value, cache_region):
|
||||
return cached_value
|
||||
|
||||
if not is_fresh():
|
||||
# 尝试获取缓存
|
||||
cached_value = cache_backend.get(cache_key, region=cache_region)
|
||||
if should_cache(cached_value) and is_valid_cache_value(cache_key, cached_value, cache_region):
|
||||
return cached_value
|
||||
# 执行函数并缓存结果
|
||||
result = func(*args, **kwargs)
|
||||
# 判断是否需要缓存
|
||||
|
||||
@@ -11,7 +11,8 @@ from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from dotenv import set_key
|
||||
from pydantic import BaseModel, BaseSettings, validator, Field
|
||||
from pydantic import BaseModel, Field, ConfigDict, model_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
from app.log import logger, log_settings, LogConfigModel
|
||||
from app.schemas import MediaType
|
||||
@@ -49,8 +50,7 @@ class ConfigModel(BaseModel):
|
||||
Pydantic 配置模型,描述所有配置项及其类型和默认值
|
||||
"""
|
||||
|
||||
class Config:
|
||||
extra = "ignore" # 忽略未定义的配置项
|
||||
model_config = ConfigDict(extra="ignore") # 忽略未定义的配置项
|
||||
|
||||
# ==================== 基础应用配置 ====================
|
||||
# 项目名称
|
||||
@@ -92,7 +92,7 @@ class ConfigModel(BaseModel):
|
||||
# 超级管理员初始用户名
|
||||
SUPERUSER: str = "admin"
|
||||
# 超级管理员初始密码
|
||||
SUPERUSER_PASSWORD: str = None
|
||||
SUPERUSER_PASSWORD: Optional[str] = None
|
||||
# 辅助认证,允许通过外部服务进行认证、单点登录以及自动创建用户
|
||||
AUXILIARY_AUTH_ENABLE: bool = False
|
||||
# API密钥,需要更换
|
||||
@@ -398,24 +398,51 @@ class ConfigModel(BaseModel):
|
||||
|
||||
# ==================== 存储配置 ====================
|
||||
# 对rclone进行快照对比时,是否检查文件夹的修改时间
|
||||
RCLONE_SNAPSHOT_CHECK_FOLDER_MODTIME = True
|
||||
RCLONE_SNAPSHOT_CHECK_FOLDER_MODTIME: bool = True
|
||||
# 对OpenList进行快照对比时,是否检查文件夹的修改时间
|
||||
OPENLIST_SNAPSHOT_CHECK_FOLDER_MODTIME = True
|
||||
OPENLIST_SNAPSHOT_CHECK_FOLDER_MODTIME: bool = True
|
||||
|
||||
# ==================== Docker配置 ====================
|
||||
# Docker Client API地址
|
||||
DOCKER_CLIENT_API: Optional[str] = "tcp://127.0.0.1:38379"
|
||||
|
||||
# ==================== AI智能体配置 ====================
|
||||
# AI智能体开关
|
||||
AI_AGENT_ENABLE: bool = False
|
||||
# LLM提供商 (openai/google/deepseek)
|
||||
LLM_PROVIDER: str = "deepseek"
|
||||
# LLM模型名称
|
||||
LLM_MODEL: str = "deepseek-chat"
|
||||
# LLM API密钥
|
||||
LLM_API_KEY: Optional[str] = None
|
||||
# LLM基础URL(用于自定义API端点)
|
||||
LLM_BASE_URL: Optional[str] = "https://api.deepseek.com"
|
||||
# LLM温度参数
|
||||
LLM_TEMPERATURE: float = 0.1
|
||||
# LLM最大迭代次数
|
||||
LLM_MAX_ITERATIONS: int = 15
|
||||
# LLM工具调用超时时间(秒)
|
||||
LLM_TOOL_TIMEOUT: int = 300
|
||||
# 是否启用详细日志
|
||||
LLM_VERBOSE: bool = False
|
||||
# 最大记忆消息数量
|
||||
LLM_MAX_MEMORY_MESSAGES: int = 50
|
||||
# 记忆保留天数
|
||||
LLM_MEMORY_RETENTION_DAYS: int = 30
|
||||
# Redis记忆保留天数(如果使用Redis)
|
||||
LLM_REDIS_MEMORY_RETENTION_DAYS: int = 7
|
||||
|
||||
|
||||
class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
"""
|
||||
系统配置类
|
||||
"""
|
||||
|
||||
class Config:
|
||||
case_sensitive = True
|
||||
env_file = SystemUtils.get_env_path()
|
||||
env_file_encoding = "utf-8"
|
||||
model_config = SettingsConfigDict(
|
||||
case_sensitive=True,
|
||||
env_file=SystemUtils.get_env_path(),
|
||||
env_file_encoding="utf-8",
|
||||
)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
@@ -512,33 +539,54 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
f"配置项 '{field_name}' 的值 '{value}' 无法转换成正确的类型,使用默认值 '{default}',错误信息: {e}")
|
||||
return default, True
|
||||
|
||||
@validator('*', pre=True, always=True)
|
||||
def generic_type_validator(cls, value: Any, field): # noqa
|
||||
@model_validator(mode='before')
|
||||
@classmethod
|
||||
def generic_type_validator(cls, data: Any): # noqa
|
||||
"""
|
||||
通用校验器,尝试将配置值转换为期望的类型
|
||||
"""
|
||||
if field.name == "API_TOKEN":
|
||||
converted_value, needs_update = cls.validate_api_token(value, value)
|
||||
else:
|
||||
converted_value, needs_update = cls.generic_type_converter(value, value, field.type_, field.default,
|
||||
field.name)
|
||||
if needs_update:
|
||||
cls.update_env_config(field, value, converted_value)
|
||||
return converted_value
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
|
||||
# 处理 API_TOKEN 特殊验证
|
||||
if 'API_TOKEN' in data:
|
||||
converted_value, needs_update = cls.validate_api_token(data['API_TOKEN'], data['API_TOKEN'])
|
||||
if needs_update:
|
||||
cls.update_env_config("API_TOKEN", data["API_TOKEN"], converted_value)
|
||||
data['API_TOKEN'] = converted_value
|
||||
|
||||
# 对其他字段进行类型转换
|
||||
for field_name, field_info in cls.model_fields.items():
|
||||
if field_name not in data:
|
||||
continue
|
||||
value = data[field_name]
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
field = cls.model_fields.get(field_name)
|
||||
if field:
|
||||
converted_value, needs_update = cls.generic_type_converter(
|
||||
value, value, field.annotation, field.default, field_name
|
||||
)
|
||||
if needs_update:
|
||||
cls.update_env_config(field_name, value, converted_value)
|
||||
data[field_name] = converted_value
|
||||
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
def update_env_config(field: Any, original_value: Any, converted_value: Any) -> Tuple[bool, str]:
|
||||
def update_env_config(field_name: str, original_value: Any, converted_value: Any) -> Tuple[bool, str]:
|
||||
"""
|
||||
更新 env 配置
|
||||
"""
|
||||
message = None
|
||||
is_converted = original_value is not None and str(original_value) != str(converted_value)
|
||||
if is_converted:
|
||||
message = f"配置项 '{field.name}' 的值 '{original_value}' 无效,已替换为 '{converted_value}'"
|
||||
message = f"配置项 '{field_name}' 的值 '{original_value}' 无效,已替换为 '{converted_value}'"
|
||||
logger.warning(message)
|
||||
|
||||
if field.name in os.environ:
|
||||
message = f"配置项 '{field.name}' 已在环境变量中设置,请手动更新以保持一致性"
|
||||
if field_name in os.environ:
|
||||
message = f"配置项 '{field_name}' 已在环境变量中设置,请手动更新以保持一致性"
|
||||
logger.warning(message)
|
||||
return False, message
|
||||
else:
|
||||
@@ -548,10 +596,10 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
else:
|
||||
value_to_write = str(converted_value) if converted_value is not None else ""
|
||||
|
||||
set_key(dotenv_path=SystemUtils.get_env_path(), key_to_set=field.name, value_to_set=value_to_write,
|
||||
set_key(dotenv_path=SystemUtils.get_env_path(), key_to_set=field_name, value_to_set=value_to_write,
|
||||
quote_mode="always")
|
||||
if is_converted:
|
||||
logger.info(f"配置项 '{field.name}' 已自动修正并写入到 'app.env' 文件")
|
||||
logger.info(f"配置项 '{field_name}' 已自动修正并写入到 'app.env' 文件")
|
||||
return True, message
|
||||
|
||||
def update_setting(self, key: str, value: Any) -> Tuple[Optional[bool], str]:
|
||||
@@ -565,19 +613,17 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
return False, f"配置项 '{key}' 不存在"
|
||||
|
||||
try:
|
||||
field = self.__fields__[key]
|
||||
field = Settings.model_fields[key]
|
||||
original_value = getattr(self, key)
|
||||
if field.name == "API_TOKEN":
|
||||
if key == "API_TOKEN":
|
||||
converted_value, needs_update = self.validate_api_token(value, original_value)
|
||||
else:
|
||||
converted_value, needs_update = self.generic_type_converter(value,
|
||||
original_value,
|
||||
field.type_,
|
||||
field.default,
|
||||
key)
|
||||
converted_value, needs_update = self.generic_type_converter(
|
||||
value, original_value, field.annotation, field.default, key
|
||||
)
|
||||
# 如果没有抛出异常,则统一使用 converted_value 进行更新
|
||||
if needs_update or str(value) != str(converted_value):
|
||||
success, message = self.update_env_config(field, value, converted_value)
|
||||
success, message = self.update_env_config(key, value, converted_value)
|
||||
# 仅成功更新配置时,才更新内存
|
||||
if success:
|
||||
setattr(self, key, converted_value)
|
||||
|
||||
@@ -250,6 +250,8 @@ class MediaInfo:
|
||||
production_countries: list = field(default_factory=list)
|
||||
# 语种
|
||||
spoken_languages: list = field(default_factory=list)
|
||||
# 所有发行日期
|
||||
release_dates: list = field(default_factory=list)
|
||||
# 状态
|
||||
status: str = None
|
||||
# 标签
|
||||
@@ -257,7 +259,7 @@ class MediaInfo:
|
||||
# 评价数量
|
||||
vote_count: int = None
|
||||
# 流行度
|
||||
popularity: int = None
|
||||
popularity: float = None
|
||||
# 时长
|
||||
runtime: int = None
|
||||
# 下一集
|
||||
@@ -433,6 +435,18 @@ class MediaInfo:
|
||||
if self.release_date:
|
||||
# 年份
|
||||
self.year = self.release_date[:4]
|
||||
# 所有发行日期
|
||||
self.release_dates = [
|
||||
{
|
||||
"date": release_date.get("release_date"),
|
||||
"iso_code": result.get("iso_3166_1"),
|
||||
"note": release_date.get("note"),
|
||||
"type": release_date.get("type"),
|
||||
}
|
||||
for result in info.get("release_dates", {}).get("results", [])
|
||||
for release_date in result.get("release_dates", [])
|
||||
if release_date.get("release_date")
|
||||
]
|
||||
else:
|
||||
# 电视剧
|
||||
self.title = info.get('name')
|
||||
|
||||
@@ -586,7 +586,8 @@ class EventManager(metaclass=Singleton):
|
||||
# 插件同步函数在异步环境中运行,避免阻塞
|
||||
await run_in_threadpool(method, event)
|
||||
except Exception as e:
|
||||
self.__handle_event_error(event=event, handler=handler, e=e, module_name=plugin.name)
|
||||
self.__handle_event_error(event=event, module_name=plugin.name,
|
||||
class_name=class_name, method_name=method_name, e=e)
|
||||
|
||||
async def __invoke_module_method_async(self, handler: Any, class_name: str, method_name: str, event: Event):
|
||||
"""
|
||||
|
||||
@@ -94,7 +94,6 @@ class MetaVideo(MetaBase):
|
||||
title = re.sub(r'\d{4}[\s._-]\d{1,2}[\s._-]\d{1,2}', "", title)
|
||||
# 拆分tokens
|
||||
tokens = Tokens(title)
|
||||
self.tokens = tokens
|
||||
# 实例化StreamingPlatforms对象
|
||||
streaming_platforms = StreamingPlatforms()
|
||||
# 解析名称、年份、季、集、资源类型、分辨率等
|
||||
@@ -102,7 +101,7 @@ class MetaVideo(MetaBase):
|
||||
while token:
|
||||
self._index += 1 # 更新当前处理的token索引
|
||||
# Part
|
||||
self.__init_part(token)
|
||||
self.__init_part(token, tokens)
|
||||
# 标题
|
||||
if self._continue_flag:
|
||||
self.__init_name(token)
|
||||
@@ -123,7 +122,7 @@ class MetaVideo(MetaBase):
|
||||
self.__init_resource_type(token)
|
||||
# 流媒体平台
|
||||
if self._continue_flag:
|
||||
self.__init_web_source(token, streaming_platforms)
|
||||
self.__init_web_source(token, tokens, streaming_platforms)
|
||||
# 视频编码
|
||||
if self._continue_flag:
|
||||
self.__init_video_encode(token)
|
||||
@@ -311,7 +310,7 @@ class MetaVideo(MetaBase):
|
||||
self.en_name = token
|
||||
self._last_token_type = "enname"
|
||||
|
||||
def __init_part(self, token: str):
|
||||
def __init_part(self, token: str, tokens: Tokens):
|
||||
"""
|
||||
识别Part
|
||||
"""
|
||||
@@ -327,12 +326,12 @@ class MetaVideo(MetaBase):
|
||||
if re_res:
|
||||
if not self.part:
|
||||
self.part = re_res.group(1)
|
||||
nextv = self.tokens.cur()
|
||||
nextv = tokens.cur()
|
||||
if nextv \
|
||||
and ((nextv.isdigit() and (len(nextv) == 1 or len(nextv) == 2 and nextv.startswith('0')))
|
||||
or nextv.upper() in ['A', 'B', 'C', 'I', 'II', 'III']):
|
||||
self.part = "%s%s" % (self.part, nextv)
|
||||
self.tokens.get_next()
|
||||
tokens.get_next()
|
||||
self._last_token_type = "part"
|
||||
self._continue_flag = False
|
||||
# self._stop_name_flag = False
|
||||
@@ -582,7 +581,7 @@ class MetaVideo(MetaBase):
|
||||
self._effect.append(effect)
|
||||
self._last_token = effect.upper()
|
||||
|
||||
def __init_web_source(self, token: str, streaming_platforms: StreamingPlatforms):
|
||||
def __init_web_source(self, token: str, tokens: Tokens, streaming_platforms: StreamingPlatforms):
|
||||
"""
|
||||
识别流媒体平台
|
||||
"""
|
||||
@@ -594,10 +593,10 @@ class MetaVideo(MetaBase):
|
||||
|
||||
prev_token = None
|
||||
prev_idx = self._index - 2
|
||||
if 0 <= prev_idx < len(self.tokens.tokens):
|
||||
prev_token = self.tokens.tokens[prev_idx]
|
||||
if 0 <= prev_idx < len(tokens.tokens):
|
||||
prev_token = tokens.tokens[prev_idx]
|
||||
|
||||
next_token = self.tokens.peek()
|
||||
next_token = tokens.peek()
|
||||
|
||||
if streaming_platforms.is_streaming_platform(token):
|
||||
platform_name = streaming_platforms.get_streaming_platform_name(token)
|
||||
@@ -616,7 +615,7 @@ class MetaVideo(MetaBase):
|
||||
platform_name = streaming_platforms.get_streaming_platform_name(combined_token)
|
||||
query_range = 2
|
||||
if is_next:
|
||||
self.tokens.get_next()
|
||||
tokens.get_next()
|
||||
break
|
||||
|
||||
if not platform_name:
|
||||
@@ -626,8 +625,8 @@ class MetaVideo(MetaBase):
|
||||
match_start_idx = self._index - query_range
|
||||
match_end_idx = self._index - 1
|
||||
start_index = max(0, match_start_idx - query_range)
|
||||
end_index = min(len(self.tokens.tokens), match_end_idx + 1 + query_range)
|
||||
tokens_to_check = self.tokens.tokens[start_index:end_index]
|
||||
end_index = min(len(tokens.tokens), match_end_idx + 1 + query_range)
|
||||
tokens_to_check = tokens.tokens[start_index:end_index]
|
||||
|
||||
if any(tok and tok.upper() in web_tokens for tok in tokens_to_check):
|
||||
self.web_source = platform_name
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import ast
|
||||
import asyncio
|
||||
import concurrent
|
||||
import concurrent.futures
|
||||
@@ -9,15 +10,15 @@ import time
|
||||
import traceback
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
import threading
|
||||
from typing import Any, Dict, List, Optional, Type, Union, Callable, Tuple
|
||||
|
||||
from fastapi import HTTPException
|
||||
from starlette import status
|
||||
from watchdog.events import FileSystemEventHandler
|
||||
from watchdog.observers import Observer
|
||||
from watchfiles import watch
|
||||
|
||||
from app import schemas
|
||||
from app.core.cache import cached
|
||||
from app.core.cache import fresh, async_fresh
|
||||
from app.core.config import settings
|
||||
from app.core.event import eventmanager, Event
|
||||
from app.db.plugindata_oper import PluginDataOper
|
||||
@@ -27,64 +28,12 @@ from app.helper.sites import SitesHelper # noqa
|
||||
from app.log import logger
|
||||
from app.schemas.types import EventType, SystemConfigKey
|
||||
from app.utils.crypto import RSAUtils
|
||||
from app.utils.limit import rate_limit_window
|
||||
from app.utils.object import ObjectUtils
|
||||
from app.utils.singleton import Singleton
|
||||
from app.utils.string import StringUtils
|
||||
from app.utils.system import SystemUtils
|
||||
|
||||
|
||||
class PluginMonitorHandler(FileSystemEventHandler):
|
||||
|
||||
def on_modified(self, event):
|
||||
"""
|
||||
插件文件修改后重载
|
||||
"""
|
||||
if event.is_directory:
|
||||
return
|
||||
# 使用 pathlib 处理文件路径,跳过非 .py 文件以及 pycache 目录中的文件
|
||||
event_path = Path(event.src_path)
|
||||
if not event_path.name.endswith(".py") or "pycache" in event_path.parts:
|
||||
return
|
||||
|
||||
# 读取插件根目录下的__init__.py文件,读取class XXXX(_PluginBase)的类名
|
||||
try:
|
||||
plugins_root = settings.ROOT_PATH / "app" / "plugins"
|
||||
# 确保修改的文件在 plugins 目录下
|
||||
if plugins_root not in event_path.parents:
|
||||
return
|
||||
# 获取插件目录路径,没有找到__init__.py时,说明不是有效包,跳过插件重载
|
||||
# 插件重载目前没有支持app/plugins/plugin/package/__init__.py的场景,这里也不做支持
|
||||
plugin_dir = event_path.parent
|
||||
init_file = plugin_dir / "__init__.py"
|
||||
if not init_file.exists():
|
||||
logger.debug(f"{plugin_dir} 下没有找到 __init__.py,跳过插件重载")
|
||||
return
|
||||
|
||||
with open(init_file, "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
pid = None
|
||||
for line in lines:
|
||||
if line.startswith("class") and "(_PluginBase)" in line:
|
||||
pid = line.split("class ")[1].split("(_PluginBase)")[0].strip()
|
||||
if pid:
|
||||
self.__reload_plugin(pid)
|
||||
except Exception as e:
|
||||
logger.error(f"插件文件修改后重载出错:{str(e)}")
|
||||
|
||||
@staticmethod
|
||||
@rate_limit_window(max_calls=1, window_seconds=2, source="PluginMonitor", enable_logging=False)
|
||||
def __reload_plugin(pid):
|
||||
"""
|
||||
重新加载插件
|
||||
"""
|
||||
try:
|
||||
logger.info(f"插件 {pid} 文件修改,重新加载...")
|
||||
PluginManager().reload_plugin(pid)
|
||||
except Exception as e:
|
||||
logger.error(f"插件文件修改后重载出错:{str(e)}")
|
||||
|
||||
|
||||
class PluginManager(metaclass=Singleton):
|
||||
"""
|
||||
插件管理器
|
||||
@@ -97,8 +46,10 @@ class PluginManager(metaclass=Singleton):
|
||||
self._running_plugins: dict = {}
|
||||
# 配置Key
|
||||
self._config_key: str = "plugin.%s"
|
||||
# 监听器
|
||||
self._observer: Observer = None
|
||||
# 监控线程
|
||||
self._monitor_thread: Optional[threading.Thread] = None
|
||||
# 监控停止事件
|
||||
self._stop_monitor_event = threading.Event()
|
||||
# 开发者模式监测插件修改
|
||||
if settings.DEV or settings.PLUGIN_AUTO_RELOAD:
|
||||
self.__start_monitor()
|
||||
@@ -265,7 +216,6 @@ class PluginManager(metaclass=Singleton):
|
||||
|
||||
# 导入模块
|
||||
module = importlib.import_module(module_name)
|
||||
importlib.reload(module)
|
||||
|
||||
# 检查模块中的类
|
||||
for name, obj in module.__dict__.items():
|
||||
@@ -319,10 +269,9 @@ class PluginManager(metaclass=Singleton):
|
||||
重新加载插件文件修改监测
|
||||
"""
|
||||
if settings.DEV or settings.PLUGIN_AUTO_RELOAD:
|
||||
if self._observer and self._observer.is_alive():
|
||||
logger.info("插件文件修改监测已经在运行中...")
|
||||
else:
|
||||
self.__start_monitor()
|
||||
# 先关闭已有监测,再重新启动
|
||||
self.stop_monitor()
|
||||
self.__start_monitor()
|
||||
else:
|
||||
self.stop_monitor()
|
||||
|
||||
@@ -330,25 +279,123 @@ class PluginManager(metaclass=Singleton):
|
||||
"""
|
||||
启用监测插件文件修改监测
|
||||
"""
|
||||
if self._monitor_thread and self._monitor_thread.is_alive():
|
||||
logger.info("插件文件修改监测已经在运行中...")
|
||||
return
|
||||
|
||||
logger.info("开始监测插件文件修改...")
|
||||
monitor_handler = PluginMonitorHandler()
|
||||
self._observer = Observer()
|
||||
self._observer.schedule(monitor_handler, str(settings.ROOT_PATH / "app" / "plugins"), recursive=True)
|
||||
self._observer.start()
|
||||
|
||||
# 在启动新线程之前,确保停止事件是清除状态
|
||||
self._stop_monitor_event.clear()
|
||||
|
||||
# 创建并启动监控线程
|
||||
self._monitor_thread = threading.Thread(
|
||||
target=self._run_file_watcher,
|
||||
daemon=True
|
||||
)
|
||||
self._monitor_thread.start()
|
||||
|
||||
def stop_monitor(self):
|
||||
"""
|
||||
停止监测插件文件修改监测
|
||||
"""
|
||||
# 停止监测
|
||||
if self._observer and self._observer.is_alive():
|
||||
if self._monitor_thread and self._monitor_thread.is_alive():
|
||||
logger.info("正在停止插件文件修改监测...")
|
||||
self._observer.stop()
|
||||
self._observer.join()
|
||||
self._stop_monitor_event.set()
|
||||
self._monitor_thread.join(timeout=5)
|
||||
if self._monitor_thread.is_alive():
|
||||
logger.warning("插件文件修改监测线程在5秒内未能正常停止。")
|
||||
self._monitor_thread = None
|
||||
logger.info("插件文件修改监测停止完成")
|
||||
else:
|
||||
logger.info("未启用插件文件修改监测,无需停止")
|
||||
|
||||
def _run_file_watcher(self):
|
||||
"""
|
||||
运行 watchfiles 监视器的主循环。
|
||||
"""
|
||||
# 监视插件目录
|
||||
plugins_path = str(settings.ROOT_PATH / "app" / "plugins")
|
||||
logger.info(">>> 监控线程已启动,准备进入watch循环...")
|
||||
# 使用 watchfiles 监视目录变化,并响应变化事件
|
||||
# Todo: yield_on_timeout = True 时,每秒检查停止事件,会返回空集合;后续可以考虑用来做心跳之类的功能?
|
||||
for changes in watch(plugins_path, stop_event=self._stop_monitor_event, rust_timeout=1000,
|
||||
yield_on_timeout=True):
|
||||
# 如果收到停止事件,退出循环
|
||||
if not changes:
|
||||
continue
|
||||
|
||||
# 处理变化事件
|
||||
plugins_to_reload = set()
|
||||
for _change_type, path_str in changes:
|
||||
event_path = Path(path_str)
|
||||
|
||||
# 跳过非 .py 文件以及 pycache 目录中的文件
|
||||
if not event_path.name.endswith(".py") or "__pycache__" in event_path.parts:
|
||||
continue
|
||||
|
||||
# 解析插件ID
|
||||
pid = self._get_plugin_id_from_path(event_path)
|
||||
# 跳过无效插件文件
|
||||
if pid:
|
||||
# 收集需要重载的插件ID,自动去重,避免重复重载
|
||||
plugins_to_reload.add(pid)
|
||||
|
||||
# 触发重载
|
||||
if plugins_to_reload:
|
||||
logger.info(f"检测到插件文件变化,准备重载: {list(plugins_to_reload)}")
|
||||
for pid in plugins_to_reload:
|
||||
try:
|
||||
self.reload_plugin(pid)
|
||||
except Exception as e:
|
||||
logger.error(f"插件 {pid} 热重载失败: {e}", exc_info=True)
|
||||
|
||||
@staticmethod
|
||||
def _get_plugin_id_from_path(event_path: Path) -> Optional[str]:
|
||||
"""
|
||||
根据文件路径解析出插件的ID。
|
||||
:param event_path: 被修改文件的 Path 对象。
|
||||
:return: 插件ID字符串,如果不是有效插件文件则返回 None。
|
||||
"""
|
||||
try:
|
||||
plugins_root = settings.ROOT_PATH / "app" / "plugins"
|
||||
# 确保修改的文件在 plugins 目录下
|
||||
if not event_path.is_relative_to(plugins_root):
|
||||
return None
|
||||
|
||||
try:
|
||||
plugin_dir_name = event_path.relative_to(plugins_root).parts[0]
|
||||
plugin_dir = plugins_root / plugin_dir_name
|
||||
except (ValueError, IndexError):
|
||||
return None
|
||||
|
||||
init_file = plugin_dir / "__init__.py"
|
||||
if not init_file.exists():
|
||||
return None
|
||||
|
||||
# 读取 __init__.py 文件,查找插件主类名
|
||||
with open(init_file, "r", encoding="utf-8") as f:
|
||||
source_code = f.read()
|
||||
|
||||
tree = ast.parse(source_code)
|
||||
|
||||
# 遍历AST,查找继承自 _PluginBase 的类
|
||||
for node in ast.walk(tree):
|
||||
# 检查节点是否为类定义
|
||||
if isinstance(node, ast.ClassDef):
|
||||
# 遍历该类的所有基类
|
||||
for base in node.bases:
|
||||
# 检查基类是否是我们寻找的 _PluginBase
|
||||
# ast.Name 用于处理简单的基类名
|
||||
if isinstance(base, ast.Name) and base.id == '_PluginBase':
|
||||
# 返回这个类的名字
|
||||
return node.name
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"从路径解析插件ID时出错: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def __stop_plugin(plugin: Any):
|
||||
"""
|
||||
@@ -411,6 +458,10 @@ class PluginManager(metaclass=Singleton):
|
||||
except KeyError:
|
||||
# 模块可能已经被删除
|
||||
pass
|
||||
|
||||
importlib.invalidate_caches()
|
||||
logger.debug("已清除查找器的缓存")
|
||||
|
||||
if plugin_id:
|
||||
if modules_to_remove:
|
||||
logger.info(f"插件 {plugin_id} 共清除 {len(modules_to_remove)} 个模块缓存:{modules_to_remove}")
|
||||
@@ -694,6 +745,36 @@ class PluginManager(metaclass=Singleton):
|
||||
logger.error(f"获取插件 {plugin_id} 动作出错:{str(e)}")
|
||||
return ret_actions
|
||||
|
||||
def get_plugin_agent_tools(self, pid: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取插件智能体工具
|
||||
[{
|
||||
"plugin_id": "插件ID",
|
||||
"plugin_name": "插件名称",
|
||||
"tools": [ToolClass1, ToolClass2, ...]
|
||||
}]
|
||||
"""
|
||||
ret_tools = []
|
||||
# 创建字典快照避免并发修改
|
||||
running_plugins_snapshot = dict(self._running_plugins)
|
||||
for plugin_id, plugin in running_plugins_snapshot.items():
|
||||
if pid and pid != plugin_id:
|
||||
continue
|
||||
if hasattr(plugin, "get_agent_tools") and ObjectUtils.check_method(plugin.get_agent_tools):
|
||||
try:
|
||||
if not plugin.get_state():
|
||||
continue
|
||||
tools = plugin.get_agent_tools()
|
||||
if tools:
|
||||
ret_tools.append({
|
||||
"plugin_id": plugin_id,
|
||||
"plugin_name": plugin.plugin_name,
|
||||
"tools": tools
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"获取插件 {plugin_id} 智能体工具出错:{str(e)}")
|
||||
return ret_tools
|
||||
|
||||
@staticmethod
|
||||
def get_plugin_remote_entry(plugin_id: str, dist_path: str) -> str:
|
||||
"""
|
||||
@@ -864,14 +945,10 @@ class PluginManager(metaclass=Singleton):
|
||||
"""
|
||||
return list(self._running_plugins.keys())
|
||||
|
||||
@cached(maxsize=1, ttl=1800)
|
||||
def get_online_plugins(self, force: bool = False) -> List[schemas.Plugin]:
|
||||
"""
|
||||
获取所有在线插件信息
|
||||
"""
|
||||
if force:
|
||||
self.get_online_plugins.cache_clear()
|
||||
|
||||
if not settings.PLUGIN_MARKET:
|
||||
return []
|
||||
|
||||
@@ -1029,7 +1106,8 @@ class PluginManager(metaclass=Singleton):
|
||||
# 已安装插件
|
||||
installed_apps = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or []
|
||||
# 获取在线插件
|
||||
online_plugins = PluginHelper().get_plugins(market, package_version, force)
|
||||
with fresh(force):
|
||||
online_plugins = PluginHelper().get_plugins(market, package_version)
|
||||
if online_plugins is None:
|
||||
logger.warning(
|
||||
f"获取{package_version if package_version else ''}插件库失败:{market},请检查 GitHub 网络连接")
|
||||
@@ -1167,15 +1245,11 @@ class PluginManager(metaclass=Singleton):
|
||||
|
||||
return plugin
|
||||
|
||||
@cached(maxsize=1, ttl=1800)
|
||||
async def async_get_online_plugins(self, force: bool = False) -> List[schemas.Plugin]:
|
||||
"""
|
||||
异步获取所有在线插件信息
|
||||
:param force: 是否强制刷新(忽略缓存)
|
||||
"""
|
||||
if force:
|
||||
await self.async_get_online_plugins.cache_clear()
|
||||
|
||||
if not settings.PLUGIN_MARKET:
|
||||
return []
|
||||
|
||||
@@ -1240,7 +1314,8 @@ class PluginManager(metaclass=Singleton):
|
||||
# 已安装插件
|
||||
installed_apps = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or []
|
||||
# 获取在线插件
|
||||
online_plugins = await PluginHelper().async_get_plugins(market, package_version, force)
|
||||
async with async_fresh(force):
|
||||
online_plugins = await PluginHelper().async_get_plugins(market, package_version)
|
||||
if online_plugins is None:
|
||||
logger.warning(
|
||||
f"获取{package_version if package_version else ''}插件库失败:{market},请检查 GitHub 网络连接")
|
||||
|
||||
@@ -66,6 +66,12 @@ class Site(Base):
|
||||
result = await db.execute(select(cls).where(cls.domain == domain))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_get_by_name(cls, db: AsyncSession, name: str):
|
||||
result = await db.execute(select(cls).where(cls.name == name))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_actives(cls, db: Session):
|
||||
|
||||
@@ -85,6 +85,12 @@ class SiteOper(DbOper):
|
||||
"""
|
||||
return await Site.async_get_by_domain(self._db, domain)
|
||||
|
||||
async def async_get_by_name(self, name: str) -> Site:
|
||||
"""
|
||||
异步按名称获取站点
|
||||
"""
|
||||
return await Site.async_get_by_name(self._db, name)
|
||||
|
||||
def get_domains_by_ids(self, ids: List[int]) -> List[str]:
|
||||
"""
|
||||
按ID获取站点域名
|
||||
|
||||
@@ -128,10 +128,10 @@ class TransferHistoryOper(DbOper):
|
||||
self.add_force(
|
||||
src=fileitem.path,
|
||||
src_storage=fileitem.storage,
|
||||
src_fileitem=fileitem.dict(),
|
||||
src_fileitem=fileitem.model_dump(),
|
||||
dest=transferinfo.target_item.path if transferinfo.target_item else None,
|
||||
dest_storage=transferinfo.target_item.storage if transferinfo.target_item else None,
|
||||
dest_fileitem=transferinfo.target_item.dict() if transferinfo.target_item else None,
|
||||
dest_fileitem=transferinfo.target_item.model_dump() if transferinfo.target_item else None,
|
||||
mode=mode,
|
||||
type=mediainfo.type.value,
|
||||
category=mediainfo.category,
|
||||
@@ -159,10 +159,10 @@ class TransferHistoryOper(DbOper):
|
||||
his = self.add_force(
|
||||
src=fileitem.path,
|
||||
src_storage=fileitem.storage,
|
||||
src_fileitem=fileitem.dict(),
|
||||
src_fileitem=fileitem.model_dump(),
|
||||
dest=transferinfo.target_item.path if transferinfo.target_item else None,
|
||||
dest_storage=transferinfo.target_item.storage if transferinfo.target_item else None,
|
||||
dest_fileitem=transferinfo.target_item.dict() if transferinfo.target_item else None,
|
||||
dest_fileitem=transferinfo.target_item.model_dump() if transferinfo.target_item else None,
|
||||
mode=mode,
|
||||
type=mediainfo.type.value,
|
||||
category=mediainfo.category,
|
||||
@@ -188,7 +188,7 @@ class TransferHistoryOper(DbOper):
|
||||
year=meta.year,
|
||||
src=fileitem.path,
|
||||
src_storage=fileitem.storage,
|
||||
src_fileitem=fileitem.dict(),
|
||||
src_fileitem=fileitem.model_dump(),
|
||||
mode=mode,
|
||||
seasons=meta.season,
|
||||
episodes=meta.episode,
|
||||
|
||||
@@ -367,7 +367,6 @@ class TemplateHelper(metaclass=SingletonClass):
|
||||
return rendered
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"模板处理失败: {str(e)}")
|
||||
raise ValueError(f"模板处理失败: {str(e)}") from e
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -48,34 +48,11 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
if self.install_report():
|
||||
self.systemconfig.set(SystemConfigKey.PluginInstallReport, "1")
|
||||
|
||||
def get_plugins(self, repo_url: str, package_version: Optional[str] = None,
|
||||
force: bool = False) -> Optional[Dict[str, dict]]:
|
||||
"""
|
||||
获取Github所有最新插件列表
|
||||
:param repo_url: Github仓库地址
|
||||
:param package_version: 首选插件版本 (如 "v2", "v3"),如果不指定则获取 v1 版本
|
||||
:param force: 是否强制刷新,忽略缓存
|
||||
"""
|
||||
# 如果强制刷新,直接调用不带缓存的版本
|
||||
if force:
|
||||
return self._request_plugins(repo_url, package_version)
|
||||
else:
|
||||
return self._request_plugins_cached(repo_url, package_version)
|
||||
|
||||
@cached(maxsize=128, ttl=1800)
|
||||
def _request_plugins_cached(self, repo_url: str,
|
||||
package_version: Optional[str] = None) -> Optional[Dict[str, dict]]:
|
||||
"""
|
||||
获取Github所有最新插件列表(使用缓存)
|
||||
:param repo_url: Github仓库地址
|
||||
:param package_version: 首选插件版本 (如 "v2", "v3"),如果不指定则获取 v1 版本
|
||||
"""
|
||||
return self._request_plugins(repo_url, package_version)
|
||||
|
||||
def _request_plugins(self, repo_url: str,
|
||||
def get_plugins(self, repo_url: str,
|
||||
package_version: Optional[str] = None) -> Optional[Dict[str, dict]]:
|
||||
"""
|
||||
获取Github所有最新插件列表(不使用缓存)
|
||||
获取Github所有最新插件列表
|
||||
:param repo_url: Github仓库地址
|
||||
:param package_version: 首选插件版本 (如 "v2", "v3"),如果不指定则获取 v1 版本
|
||||
"""
|
||||
@@ -932,33 +909,11 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
logger.error(f"[GitHub] 所有策略均请求失败,URL: {url},请检查网络连接或 GitHub 配置")
|
||||
return None
|
||||
|
||||
async def async_get_plugins(self, repo_url: str, package_version: Optional[str] = None,
|
||||
force: bool = False) -> Optional[Dict[str, dict]]:
|
||||
"""
|
||||
异步获取Github所有最新插件列表
|
||||
:param repo_url: Github仓库地址
|
||||
:param package_version: 首选插件版本 (如 "v2", "v3"),如果不指定则获取 v1 版本
|
||||
:param force: 是否强制刷新,忽略缓存
|
||||
"""
|
||||
if force:
|
||||
return await self._async_request_plugins(repo_url, package_version)
|
||||
else:
|
||||
return await self._async_request_plugins_cached(repo_url, package_version)
|
||||
|
||||
@cached(maxsize=128, ttl=1800)
|
||||
async def _async_request_plugins_cached(self, repo_url: str,
|
||||
package_version: Optional[str] = None) -> Optional[Dict[str, dict]]:
|
||||
"""
|
||||
获取Github所有最新插件列表(使用缓存)
|
||||
:param repo_url: Github仓库地址
|
||||
:param package_version: 首选插件版本 (如 "v2", "v3"),如果不指定则获取 v1 版本
|
||||
"""
|
||||
return await self._async_request_plugins(repo_url, package_version)
|
||||
|
||||
async def _async_request_plugins(self, repo_url: str,
|
||||
async def async_get_plugins(self, repo_url: str,
|
||||
package_version: Optional[str] = None) -> Optional[Dict[str, dict]]:
|
||||
"""
|
||||
异步获取Github所有最新插件列表(不使用缓存)
|
||||
异步获取Github所有最新插件列表
|
||||
:param repo_url: Github仓库地址
|
||||
:param package_version: 首选插件版本 (如 "v2", "v3"),如果不指定则获取 v1 版本
|
||||
"""
|
||||
|
||||
@@ -228,13 +228,14 @@ class RssHelper:
|
||||
}
|
||||
|
||||
def parse(self, url, proxy: bool = False,
|
||||
timeout: Optional[int] = 15, headers: dict = None) -> Union[List[dict], None, bool]:
|
||||
timeout: Optional[int] = 15, headers: dict = None, ua: str = None) -> Union[List[dict], None, bool]:
|
||||
"""
|
||||
解析RSS订阅URL,获取RSS中的种子信息
|
||||
:param url: RSS地址
|
||||
:param proxy: 是否使用代理
|
||||
:param timeout: 请求超时
|
||||
:param headers: 自定义请求头
|
||||
:param ua: 自定义User-Agent
|
||||
:return: 种子信息列表,如为None代表Rss过期,如果为False则为错误
|
||||
"""
|
||||
# 开始处理
|
||||
@@ -243,8 +244,9 @@ class RssHelper:
|
||||
return False
|
||||
|
||||
try:
|
||||
ret = RequestUtils(proxies=settings.PROXY if proxy else None,
|
||||
timeout=timeout, headers=headers).get_res(url)
|
||||
ret = RequestUtils(ua=ua,
|
||||
proxies=settings.PROXY if proxy else None,
|
||||
timeout=timeout or 30, headers=headers).get_res(url)
|
||||
if not ret:
|
||||
logger.error(f"获取RSS失败:请求返回空值,URL: {url}")
|
||||
return False
|
||||
|
||||
@@ -47,7 +47,7 @@ class StorageHelper:
|
||||
if s.type == storage:
|
||||
s.config = conf
|
||||
break
|
||||
SystemConfigOper().set(SystemConfigKey.Storages, [s.dict() for s in storagies])
|
||||
SystemConfigOper().set(SystemConfigKey.Storages, [s.model_dump() for s in storagies])
|
||||
|
||||
def add_storage(self, storage: str, name: str, conf: dict):
|
||||
"""
|
||||
@@ -68,7 +68,7 @@ class StorageHelper:
|
||||
name=name,
|
||||
config=conf
|
||||
))
|
||||
SystemConfigOper().set(SystemConfigKey.Storages, [s.dict() for s in storagies])
|
||||
SystemConfigOper().set(SystemConfigKey.Storages, [s.model_dump() for s in storagies])
|
||||
|
||||
def reset_storage(self, storage: str):
|
||||
"""
|
||||
@@ -79,4 +79,4 @@ class StorageHelper:
|
||||
if s.type == storage:
|
||||
s.config = {}
|
||||
break
|
||||
SystemConfigOper().set(SystemConfigKey.Storages, [s.dict() for s in storagies])
|
||||
SystemConfigOper().set(SystemConfigKey.Storages, [s.model_dump() for s in storagies])
|
||||
|
||||
15
app/log.py
15
app/log.py
@@ -11,7 +11,8 @@ from pathlib import Path
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
import click
|
||||
from pydantic import BaseSettings, BaseModel
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
from app.utils.system import SystemUtils
|
||||
|
||||
@@ -21,8 +22,7 @@ class LogConfigModel(BaseModel):
|
||||
Pydantic 配置模型,描述所有配置项及其类型和默认值
|
||||
"""
|
||||
|
||||
class Config:
|
||||
extra = "ignore" # 忽略未定义的配置项
|
||||
model_config = ConfigDict(extra="ignore") # 忽略未定义的配置项
|
||||
|
||||
# 配置文件目录
|
||||
CONFIG_DIR: Optional[str] = None
|
||||
@@ -71,10 +71,11 @@ class LogSettings(BaseSettings, LogConfigModel):
|
||||
"""
|
||||
return self.LOG_MAX_FILE_SIZE * 1024 * 1024
|
||||
|
||||
class Config:
|
||||
case_sensitive = True
|
||||
env_file = SystemUtils.get_env_path()
|
||||
env_file_encoding = "utf-8"
|
||||
model_config = ConfigDict(
|
||||
case_sensitive=True,
|
||||
env_file=SystemUtils.get_env_path(),
|
||||
env_file_encoding="utf-8"
|
||||
)
|
||||
|
||||
|
||||
# 实例化日志设置
|
||||
|
||||
@@ -95,4 +95,4 @@ if __name__ == '__main__':
|
||||
# 更新数据库
|
||||
update_db()
|
||||
# 启动API服务
|
||||
Server.run()
|
||||
Server.run()
|
||||
@@ -232,6 +232,19 @@ class _DownloaderBase(ServiceBase[TService, DownloaderConf]):
|
||||
super().__init__()
|
||||
self._default_config_name: Optional[str] = None
|
||||
|
||||
def init_service(self, service_name: str,
|
||||
service_type: Optional[Union[Type[TService], Callable[..., TService]]] = None):
|
||||
"""
|
||||
初始化服务,获取配置并实例化对应服务
|
||||
|
||||
:param service_name: 服务名称,作为配置匹配的依据
|
||||
:param service_type: 服务的类型,可以是类类型(Type[TService])、工厂函数(Callable)或 None 来跳过实例化
|
||||
"""
|
||||
# 重置默认配置名称
|
||||
self.reset_default_config_name()
|
||||
# 初始化服务
|
||||
super().init_service(service_name, service_type)
|
||||
|
||||
def get_default_config_name(self) -> Optional[str]:
|
||||
"""
|
||||
获取默认服务配置的名称
|
||||
@@ -263,6 +276,12 @@ class _DownloaderBase(ServiceBase[TService, DownloaderConf]):
|
||||
return {}
|
||||
return {conf.name: conf for conf in configs if conf.type == self._service_name and conf.enabled}
|
||||
|
||||
def reset_default_config_name(self):
|
||||
"""
|
||||
重置默认配置名称
|
||||
"""
|
||||
self._default_config_name = None
|
||||
|
||||
|
||||
class _MediaServerBase(ServiceBase[TService, MediaServerConf]):
|
||||
"""
|
||||
|
||||
@@ -984,11 +984,13 @@ class DoubanModule(_ModuleBase):
|
||||
"""
|
||||
if result:
|
||||
doubanid = result.get("id")
|
||||
if doubanid and not str(doubanid).isdigit():
|
||||
doubanid = re.search(r"\d+", doubanid).group(0)
|
||||
result["id"] = doubanid
|
||||
logger.info(f"{imdbid} 查询到豆瓣信息:{result.get('title')}")
|
||||
return result
|
||||
if doubanid:
|
||||
if not str(doubanid).isdigit():
|
||||
doubanid = re.search(r"\d+", doubanid).group(0)
|
||||
result["id"] = doubanid
|
||||
logger.info(f"{imdbid} 查询到豆瓣信息:{result.get('title')}")
|
||||
return result
|
||||
return None
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -640,7 +640,7 @@ class Emby:
|
||||
item_type=item.get("Type"),
|
||||
title=item.get("Name"),
|
||||
original_title=item.get("OriginalTitle"),
|
||||
year=item.get("ProductionYear"),
|
||||
year=str(item.get("ProductionYear")),
|
||||
tmdbid=int(tmdbid) if tmdbid else None,
|
||||
imdbid=item.get("ProviderIds", {}).get("Imdb"),
|
||||
tvdbid=item.get("ProviderIds", {}).get("Tvdb"),
|
||||
@@ -1151,7 +1151,7 @@ class Emby:
|
||||
link = self.get_play_url(item.get("Id"))
|
||||
if item_type == MediaType.MOVIE.value:
|
||||
title = item.get("Name")
|
||||
subtitle = item.get("ProductionYear")
|
||||
subtitle = str(item.get("ProductionYear")) if item.get("ProductionYear") else None
|
||||
else:
|
||||
title = f'{item.get("SeriesName")}'
|
||||
subtitle = f'S{item.get("ParentIndexNumber")}:{item.get("IndexNumber")} - {item.get("Name")}'
|
||||
@@ -1223,7 +1223,7 @@ class Emby:
|
||||
ret_latest.append(schemas.MediaServerPlayItem(
|
||||
id=item.get("Id"),
|
||||
title=item.get("Name"),
|
||||
subtitle=item.get("ProductionYear"),
|
||||
subtitle=str(item.get("ProductionYear")) if item.get("ProductionYear") else None,
|
||||
type=item_type,
|
||||
image=image,
|
||||
link=link,
|
||||
|
||||
@@ -129,7 +129,7 @@ class TransHandler:
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify,
|
||||
)
|
||||
return self.result.copy()
|
||||
return self.result.model_copy()
|
||||
else:
|
||||
new_path = target_path / fileitem.name
|
||||
# 整理目录
|
||||
@@ -147,21 +147,18 @@ class TransHandler:
|
||||
fileitem=fileitem,
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify)
|
||||
return self.result.copy()
|
||||
return self.result.model_copy()
|
||||
|
||||
logger.info(f"文件夹 {fileitem.path} 整理成功")
|
||||
# 计算目录下所有文件大小
|
||||
total_size = sum(file.stat().st_size for file in Path(fileitem.path).rglob('*') if file.is_file())
|
||||
# 返回整理后的路径
|
||||
self.__set_result(success=True,
|
||||
fileitem=fileitem,
|
||||
target_item=new_diritem,
|
||||
target_diritem=new_diritem,
|
||||
total_size=total_size,
|
||||
need_scrape=need_scrape,
|
||||
need_notify=need_notify,
|
||||
transfer_type=transfer_type)
|
||||
return self.result.copy()
|
||||
return self.result.model_copy()
|
||||
else:
|
||||
# 整理单个文件
|
||||
if mediainfo.type == MediaType.TV:
|
||||
@@ -174,7 +171,7 @@ class TransHandler:
|
||||
fail_list=[fileitem.path],
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify)
|
||||
return self.result.copy()
|
||||
return self.result.model_copy()
|
||||
|
||||
# 文件结束季为空
|
||||
in_meta.end_season = None
|
||||
@@ -210,7 +207,7 @@ class TransHandler:
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify,
|
||||
)
|
||||
return self.result.copy()
|
||||
return self.result.model_copy()
|
||||
else:
|
||||
new_file = target_path / fileitem.name
|
||||
folder_path = target_path
|
||||
@@ -227,7 +224,7 @@ class TransHandler:
|
||||
fail_list=[fileitem.path],
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify)
|
||||
return self.result.copy()
|
||||
return self.result.model_copy()
|
||||
# 目标文件
|
||||
target_item = target_oper.get_item(new_file)
|
||||
if target_item:
|
||||
@@ -258,7 +255,7 @@ class TransHandler:
|
||||
fail_list=[fileitem.path],
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify)
|
||||
return self.result.copy()
|
||||
return self.result.model_copy()
|
||||
elif overwrite_mode == 'never':
|
||||
# 存在不覆盖
|
||||
self.__set_result(success=False,
|
||||
@@ -269,7 +266,7 @@ class TransHandler:
|
||||
fail_list=[fileitem.path],
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify)
|
||||
return self.result.copy()
|
||||
return self.result.model_copy()
|
||||
elif overwrite_mode == 'latest':
|
||||
# 仅保留最新版本
|
||||
logger.info(f"当前整理覆盖模式设置为仅保留最新版本,将覆盖:{new_file}")
|
||||
@@ -296,7 +293,7 @@ class TransHandler:
|
||||
fail_list=[fileitem.path],
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify)
|
||||
return self.result.copy()
|
||||
return self.result.model_copy()
|
||||
|
||||
logger.info(f"文件 {fileitem.path} 整理成功")
|
||||
self.__set_result(success=True,
|
||||
@@ -306,7 +303,7 @@ class TransHandler:
|
||||
need_scrape=need_scrape,
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify)
|
||||
return self.result.copy()
|
||||
return self.result.model_copy()
|
||||
finally:
|
||||
self.result = None
|
||||
|
||||
@@ -425,7 +422,7 @@ class TransHandler:
|
||||
# 复制文件到新目录
|
||||
target_fileitem = target_oper.get_folder(target_file.parent)
|
||||
if target_fileitem:
|
||||
if source_oper.move(fileitem, Path(target_fileitem.path), target_file.name):
|
||||
if source_oper.copy(fileitem, Path(target_fileitem.path), target_file.name):
|
||||
return target_oper.get_item(target_file), ""
|
||||
else:
|
||||
return None, f"【{target_storage}】{fileitem.path} 复制文件失败"
|
||||
|
||||
@@ -154,7 +154,7 @@ class FilterModule(_ModuleBase):
|
||||
custom_rules = self.rulehelper.get_custom_rules()
|
||||
for rule in custom_rules:
|
||||
logger.info(f"加载自定义规则 {rule.id} - {rule.name}")
|
||||
self.rule_set[rule.id] = rule.dict()
|
||||
self.rule_set[rule.id] = rule.model_dump()
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
|
||||
@@ -33,6 +33,8 @@ class SiteSchema(Enum):
|
||||
MTorrent = "MTorrent"
|
||||
Yema = "Yema"
|
||||
HDDolby = "HDDolby"
|
||||
Zhixing = "Zhixing"
|
||||
Bitpt = "Bitpt"
|
||||
|
||||
|
||||
class SiteParserBase(metaclass=ABCMeta):
|
||||
|
||||
161
app/modules/indexer/parser/bitpt.py
Normal file
161
app/modules/indexer/parser/bitpt.py
Normal file
@@ -0,0 +1,161 @@
|
||||
#
|
||||
# 极速之星 https://bitpt.cn/
|
||||
# author: ThedoRap
|
||||
# time: 2025-10-02
|
||||
#
|
||||
# -*- coding: utf-8 -*-
|
||||
import re
|
||||
from typing import Optional, Tuple
|
||||
from urllib.parse import urljoin, urlencode
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
from app.modules.indexer.parser import SiteParserBase, SiteSchema
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
class BitptSiteUserInfo(SiteParserBase):
|
||||
schema = SiteSchema.Bitpt
|
||||
|
||||
def _parse_site_page(self, html_text: str):
|
||||
self._user_basic_page = "userdetails.php?uid={uid}"
|
||||
self._user_detail_page = None
|
||||
self._user_basic_params = {}
|
||||
self._user_traffic_page = None
|
||||
self._sys_mail_unread_page = None
|
||||
self._user_mail_unread_page = None
|
||||
self._mail_unread_params = {}
|
||||
self._torrent_seeding_base = "browse.php"
|
||||
self._torrent_seeding_params = {"t": "myseed", "st": "2", "d": "desc"}
|
||||
self._torrent_seeding_headers = {}
|
||||
self._addition_headers = {}
|
||||
|
||||
def _parse_logged_in(self, html_text):
|
||||
soup = BeautifulSoup(html_text, 'html.parser')
|
||||
return bool(soup.find(id='userinfotop'))
|
||||
|
||||
def _parse_user_base_info(self, html_text: str):
|
||||
if not html_text:
|
||||
return None
|
||||
soup = BeautifulSoup(html_text, 'html.parser')
|
||||
table = soup.find('table', class_='frmtable')
|
||||
if not table:
|
||||
return
|
||||
|
||||
rows = table.find_all('tr')
|
||||
info_dict = {}
|
||||
for row in rows:
|
||||
cells = row.find_all('td')
|
||||
if len(cells) == 2:
|
||||
key = cells[0].text.strip()
|
||||
value = cells[1].text.strip()
|
||||
info_dict[key] = value
|
||||
|
||||
self.userid = info_dict.get('UID')
|
||||
self.username = info_dict.get('用户名').split('\xa0')[0] if '用户名' in info_dict else None
|
||||
self.user_level = info_dict.get('用户级别') if '用户级别' in info_dict else None
|
||||
self.join_at = StringUtils.unify_datetime_str(info_dict.get('注册时间')) if '注册时间' in info_dict else None
|
||||
|
||||
self.upload = StringUtils.num_filesize(info_dict.get('上传流量')) if '上传流量' in info_dict else 0
|
||||
self.download = StringUtils.num_filesize(info_dict.get('下载流量')) if '下载流量' in info_dict else 0
|
||||
self.ratio = float(info_dict.get('共享率')) if '共享率' in info_dict else 0
|
||||
bonus_str = info_dict.get('星辰', '')
|
||||
self.bonus = float(re.search(r'累计([\d\.]+)', bonus_str).group(1)) if re.search(r'累计([\d\.]+)', bonus_str) else 0
|
||||
self.message_unread = 0
|
||||
|
||||
if hasattr(self, '_torrent_seeding_base') and self._torrent_seeding_base:
|
||||
self.seeding = 0
|
||||
self.seeding_size = 0
|
||||
else:
|
||||
seeding_info = soup.find('div', style="margin:0 auto;width:90%;font-size:14px;margin-top:10px;margin-bottom:10px;text-align:center;")
|
||||
if seeding_info:
|
||||
seeding_link = seeding_info.find_all('a')[1].text if len(seeding_info.find_all('a')) > 1 else ''
|
||||
match = re.search(r'当前上传的种子\((\d+)个, 共([\d\.]+ [KMGT]B)\)', seeding_link)
|
||||
if match:
|
||||
self.seeding = int(match.group(1))
|
||||
self.seeding_size = StringUtils.num_filesize(match.group(2))
|
||||
else:
|
||||
self.seeding = 0
|
||||
self.seeding_size = 0
|
||||
|
||||
def _parse_user_traffic_info(self, html_text: str):
|
||||
pass
|
||||
|
||||
def _parse_user_detail_info(self, html_text: str):
|
||||
pass
|
||||
|
||||
def _parse_user_torrent_seeding_page_info(self, html_text: str) -> Tuple[int, int]:
|
||||
if not html_text:
|
||||
return 0, 0
|
||||
soup = BeautifulSoup(html_text, 'html.parser')
|
||||
torrent_table = soup.find('table', class_='torrenttable')
|
||||
if not torrent_table:
|
||||
return 0, 0
|
||||
rows = torrent_table.find_all('tr')
|
||||
if len(rows) <= 1:
|
||||
return 0, 0
|
||||
torrents = [row for row in rows[1:] if 'btr' in row.get('class', [])]
|
||||
page_seeding = 0
|
||||
page_seeding_size = 0
|
||||
for torrent in torrents:
|
||||
size_td = torrent.find('td', class_='r')
|
||||
if size_td:
|
||||
size_a = size_td.find('a')
|
||||
size_text = size_a.text.strip() if size_a else size_td.text.strip()
|
||||
if size_text:
|
||||
page_seeding += 1
|
||||
page_seeding_size += StringUtils.num_filesize(size_text)
|
||||
return page_seeding, page_seeding_size
|
||||
|
||||
def _parse_message_unread_links(self, html_text: str, msg_links: list) -> Optional[str]:
|
||||
pass
|
||||
|
||||
def _parse_message_content(self, html_text) -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
||||
pass
|
||||
|
||||
def _parse_user_torrent_seeding_info(self, html_text: str):
|
||||
pass
|
||||
|
||||
def parse(self):
|
||||
super().parse()
|
||||
if self._index_html:
|
||||
soup = BeautifulSoup(self._index_html, 'html.parser')
|
||||
user_link = soup.find('a', href=re.compile(r'userdetails\.php\?uid=\d+'))
|
||||
if user_link:
|
||||
uid_match = re.search(r'uid=(\d+)', user_link['href'])
|
||||
if uid_match:
|
||||
self.userid = uid_match.group(1)
|
||||
|
||||
if self.userid and self._user_basic_page:
|
||||
basic_url = self._user_basic_page.format(uid=self.userid)
|
||||
basic_html = self._get_page_content(url=urljoin(self._base_url, basic_url))
|
||||
self._parse_user_base_info(basic_html)
|
||||
|
||||
if hasattr(self, '_torrent_seeding_base') and self._torrent_seeding_base:
|
||||
seeding_base_url = urljoin(self._base_url, self._torrent_seeding_base)
|
||||
params = self._torrent_seeding_params.copy()
|
||||
page_num = 1
|
||||
while True:
|
||||
params['p'] = page_num
|
||||
query_string = urlencode(params)
|
||||
full_url = f"{seeding_base_url}?{query_string}"
|
||||
seeding_html = self._get_page_content(url=full_url)
|
||||
page_seeding, page_seeding_size = self._parse_user_torrent_seeding_page_info(seeding_html)
|
||||
self.seeding += page_seeding
|
||||
self.seeding_size += page_seeding_size
|
||||
if page_seeding == 0:
|
||||
break
|
||||
page_num += 1
|
||||
|
||||
# 🔑 最终对外统一转字符串
|
||||
self.userid = str(self.userid or "")
|
||||
self.username = str(self.username or "")
|
||||
self.user_level = str(self.user_level or "")
|
||||
self.join_at = str(self.join_at or "")
|
||||
|
||||
self.upload = str(self.upload or 0)
|
||||
self.download = str(self.download or 0)
|
||||
self.ratio = str(self.ratio or 0)
|
||||
self.bonus = str(self.bonus or 0.0)
|
||||
self.message_unread = str(self.message_unread or 0)
|
||||
|
||||
self.seeding = str(self.seeding or 0)
|
||||
self.seeding_size = str(self.seeding_size or 0)
|
||||
184
app/modules/indexer/parser/zhixing.py
Normal file
184
app/modules/indexer/parser/zhixing.py
Normal file
@@ -0,0 +1,184 @@
|
||||
#
|
||||
# 知行 http://pt.zhixing.bjtu.edu.cn/
|
||||
# author: ThedoRap
|
||||
# time: 2025-10-02
|
||||
#
|
||||
# -*- coding: utf-8 -*-
|
||||
import re
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from app.modules.indexer.parser import SiteParserBase, SiteSchema
|
||||
from app.utils.string import StringUtils
|
||||
from bs4 import BeautifulSoup
|
||||
from urllib.parse import urljoin
|
||||
|
||||
|
||||
class ZhixingSiteUserInfo(SiteParserBase):
|
||||
schema = SiteSchema.Zhixing
|
||||
|
||||
def _parse_site_page(self, html_text: str):
|
||||
"""
|
||||
获取站点页面地址
|
||||
"""
|
||||
self._user_basic_page = "user/{uid}/"
|
||||
self._user_detail_page = None
|
||||
self._user_basic_params = {}
|
||||
self._user_traffic_page = None
|
||||
self._sys_mail_unread_page = None
|
||||
self._user_mail_unread_page = None
|
||||
self._mail_unread_params = {}
|
||||
self._torrent_seeding_base = "user/{uid}/seeding"
|
||||
self._torrent_seeding_params = {}
|
||||
self._torrent_seeding_headers = {}
|
||||
self._addition_headers = {}
|
||||
|
||||
def _parse_logged_in(self, html_text):
|
||||
"""
|
||||
判断是否登录成功, 通过判断是否存在用户信息
|
||||
"""
|
||||
soup = BeautifulSoup(html_text, 'html.parser')
|
||||
return bool(soup.find(id='um'))
|
||||
|
||||
def _parse_user_base_info(self, html_text: str):
|
||||
"""
|
||||
解析用户基本信息,这里把_parse_user_traffic_info和_parse_user_detail_info合并到这里
|
||||
"""
|
||||
if not html_text:
|
||||
return None
|
||||
soup = BeautifulSoup(html_text, 'html.parser')
|
||||
details_tabs = soup.find_all('div', class_='user-details-tabs')
|
||||
info_dict = {}
|
||||
for tab in details_tabs:
|
||||
for p in tab.find_all('p'):
|
||||
text = p.text.strip()
|
||||
if ':' in text:
|
||||
parts = text.split(':', 1)
|
||||
elif ':' in text:
|
||||
parts = text.split(':', 1)
|
||||
else:
|
||||
continue
|
||||
if len(parts) == 2:
|
||||
key = parts[0].strip()
|
||||
value_text = parts[1].strip()
|
||||
value = re.split(r'\s*\(', value_text)[0].strip().split('查看')[0].strip()
|
||||
info_dict[key] = value
|
||||
|
||||
self._basic_info = info_dict # Save for fallback
|
||||
|
||||
self.userid = info_dict.get('UID')
|
||||
self.username = info_dict.get('用户名')
|
||||
self.user_level = info_dict.get('用户组')
|
||||
self.join_at = StringUtils.unify_datetime_str(info_dict.get('注册时间')) if '注册时间' in info_dict else None
|
||||
|
||||
def num_filesize_safe(s: str):
|
||||
if s:
|
||||
s = s.strip()
|
||||
if re.match(r'^\d+(\.\d+)?$', s):
|
||||
s += ' B'
|
||||
return StringUtils.num_filesize(s) if s else 0
|
||||
|
||||
self.upload = num_filesize_safe(info_dict.get('上传流量')) if '上传流量' in info_dict else 0
|
||||
self.download = num_filesize_safe(info_dict.get('下载流量')) if '下载流量' in info_dict else 0
|
||||
self.ratio = float(info_dict.get('共享率')) if '共享率' in info_dict else 0
|
||||
self.bonus = float(info_dict.get('保种积分')) if '保种积分' in info_dict else 0.0
|
||||
self.message_unread = 0 # 暂无消息解析
|
||||
|
||||
# Temporarily set seeding from basic, will override or fallback later
|
||||
self.seeding = int(info_dict.get('当前保种数量')) if '当前保种数量' in info_dict else 0
|
||||
self.seeding_size = num_filesize_safe(info_dict.get('当前保种容量')) if '当前保种容量' in info_dict else 0
|
||||
|
||||
def _parse_user_traffic_info(self, html_text: str):
|
||||
pass
|
||||
|
||||
def _parse_user_detail_info(self, html_text: str):
|
||||
pass
|
||||
|
||||
def _parse_user_torrent_seeding_page_info(self, html_text: str) -> Tuple[int, int]:
|
||||
"""
|
||||
解析用户做种信息单页,返回本页数量和大小
|
||||
"""
|
||||
if not html_text:
|
||||
return 0, 0
|
||||
soup = BeautifulSoup(html_text, 'html.parser')
|
||||
torrents = soup.find_all('tr', id=re.compile(r'^t\d+'))
|
||||
page_seeding = 0
|
||||
page_seeding_size = 0
|
||||
for torrent in torrents:
|
||||
size_td = torrent.find('td', class_='r')
|
||||
if size_td:
|
||||
size_text = size_td.find('a').text if size_td.find('a') else size_td.text.strip()
|
||||
page_seeding += 1
|
||||
page_seeding_size += StringUtils.num_filesize(size_text)
|
||||
return page_seeding, page_seeding_size
|
||||
|
||||
def _parse_message_unread_links(self, html_text: str, msg_links: list) -> Optional[str]:
|
||||
pass
|
||||
|
||||
def _parse_message_content(self, html_text) -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
||||
pass
|
||||
|
||||
def _parse_user_torrent_seeding_info(self, html_text: str):
|
||||
"""
|
||||
占位,避免抽象类报错
|
||||
"""
|
||||
pass
|
||||
|
||||
def parse(self):
|
||||
"""
|
||||
解析站点信息
|
||||
"""
|
||||
super().parse()
|
||||
# 先从首页解析userid
|
||||
if self._index_html:
|
||||
soup = BeautifulSoup(self._index_html, 'html.parser')
|
||||
user_link = soup.find('a', href=re.compile(r'/user/\d+/'))
|
||||
if user_link:
|
||||
uid_match = re.search(r'/user/(\d+)/', user_link['href'])
|
||||
if uid_match:
|
||||
self.userid = uid_match.group(1)
|
||||
# 如果有userid,则格式化页面
|
||||
if self.userid:
|
||||
if self._user_basic_page:
|
||||
basic_url = self._user_basic_page.format(uid=self.userid)
|
||||
basic_html = self._get_page_content(url=urljoin(self._base_url, basic_url))
|
||||
self._parse_user_base_info(basic_html)
|
||||
if hasattr(self, '_torrent_seeding_base') and self._torrent_seeding_base:
|
||||
self.seeding = 0 # Reset to sum from pages
|
||||
self.seeding_size = 0
|
||||
seeding_base = self._torrent_seeding_base.format(uid=self.userid)
|
||||
seeding_base_url = urljoin(self._base_url, seeding_base)
|
||||
page_num = 1
|
||||
while True:
|
||||
seeding_url = f"{seeding_base_url}/p{page_num}"
|
||||
seeding_html = self._get_page_content(url=seeding_url)
|
||||
page_seeding, page_seeding_size = self._parse_user_torrent_seeding_page_info(seeding_html)
|
||||
self.seeding += page_seeding
|
||||
self.seeding_size += page_seeding_size
|
||||
if page_seeding == 0:
|
||||
break
|
||||
page_num += 1
|
||||
# Fallback to basic if no seeding found from pages
|
||||
if self.seeding == 0 and hasattr(self, '_basic_info'):
|
||||
def num_filesize_safe(s: str):
|
||||
if s:
|
||||
s = s.strip()
|
||||
if re.match(r'^\d+(\.\d+)?$', s):
|
||||
s += ' B'
|
||||
return StringUtils.num_filesize(s) if s else 0
|
||||
self.seeding = int(self._basic_info.get('当前保种数量', 0))
|
||||
self.seeding_size = num_filesize_safe(self._basic_info.get('当前保种容量', ''))
|
||||
|
||||
# 🔑 最终对外统一转字符串,避免 join 报错
|
||||
self.userid = str(self.userid or "")
|
||||
self.username = str(self.username or "")
|
||||
self.user_level = str(self.user_level or "")
|
||||
self.join_at = str(self.join_at or "")
|
||||
|
||||
self.upload = str(self.upload or 0)
|
||||
self.download = str(self.download or 0)
|
||||
self.ratio = str(self.ratio or 0)
|
||||
self.bonus = str(self.bonus or 0.0)
|
||||
self.message_unread = str(self.message_unread or 0)
|
||||
|
||||
self.seeding = str(self.seeding or 0)
|
||||
self.seeding_size = str(self.seeding_size or 0)
|
||||
@@ -732,7 +732,7 @@ class Jellyfin:
|
||||
item_type=item.get("Type"),
|
||||
title=item.get("Name"),
|
||||
original_title=item.get("OriginalTitle"),
|
||||
year=item.get("ProductionYear"),
|
||||
year=str(item.get("ProductionYear")),
|
||||
tmdbid=int(tmdbid) if tmdbid else None,
|
||||
imdbid=item.get("ProviderIds", {}).get("Imdb"),
|
||||
tvdbid=item.get("ProviderIds", {}).get("Tvdb"),
|
||||
@@ -924,7 +924,7 @@ class Jellyfin:
|
||||
image = self.generate_image_link(item.get("Id"), "Backdrop", False)
|
||||
if item_type == MediaType.MOVIE.value:
|
||||
title = item.get("Name")
|
||||
subtitle = item.get("ProductionYear")
|
||||
subtitle = str(item.get("ProductionYear")) if item.get("ProductionYear") else None
|
||||
else:
|
||||
title = f'{item.get("SeriesName")}'
|
||||
subtitle = f'S{item.get("ParentIndexNumber")}:{item.get("IndexNumber")} - {item.get("Name")}'
|
||||
@@ -984,7 +984,7 @@ class Jellyfin:
|
||||
ret_latest.append(schemas.MediaServerPlayItem(
|
||||
id=item.get("Id"),
|
||||
title=item.get("Name"),
|
||||
subtitle=item.get("ProductionYear"),
|
||||
subtitle=str(item.get("ProductionYear")) if item.get("ProductionYear") else None,
|
||||
type=item_type,
|
||||
image=image,
|
||||
link=link,
|
||||
|
||||
@@ -437,7 +437,7 @@ class Plex:
|
||||
|
||||
@staticmethod
|
||||
def __get_ids(guids: List[Any]) -> dict:
|
||||
def parse_tmdb_id(value: str) -> (bool, int):
|
||||
def parse_tmdb_id(value: str) -> tuple[bool, int]:
|
||||
"""尝试将TMDB ID字符串转换为整数。如果成功,返回(True, int),失败则返回(False, None)。"""
|
||||
try:
|
||||
int_value = int(value)
|
||||
@@ -509,7 +509,7 @@ class Plex:
|
||||
item_type=item.type,
|
||||
title=item.title,
|
||||
original_title=item.originalTitle,
|
||||
year=item.year,
|
||||
year=str(item.year),
|
||||
tmdbid=ids.get("tmdb_id"),
|
||||
imdbid=ids.get("imdb_id"),
|
||||
tvdbid=ids.get("tvdb_id"),
|
||||
@@ -746,7 +746,7 @@ class Plex:
|
||||
item_type = MediaType.MOVIE.value if item.TYPE == "movie" else MediaType.TV.value
|
||||
if item_type == MediaType.MOVIE.value:
|
||||
title = item.title
|
||||
subtitle = item.year
|
||||
subtitle = str(item.year) if item.year else None
|
||||
else:
|
||||
title = item.grandparentTitle
|
||||
subtitle = f"S{item.parentIndex}:E{item.index} - {item.title}"
|
||||
@@ -825,7 +825,7 @@ class Plex:
|
||||
ret_resume.append(schemas.MediaServerPlayItem(
|
||||
id=item.key,
|
||||
title=title,
|
||||
subtitle=item.year,
|
||||
subtitle=str(item.year) if item.year else None,
|
||||
type=item_type,
|
||||
image=image,
|
||||
link=link,
|
||||
|
||||
@@ -264,7 +264,7 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
userid=userid, username=username, text=text)
|
||||
return None
|
||||
|
||||
def post_message(self, message: Notification) -> None:
|
||||
def post_message(self, message: Notification, **kwargs) -> None:
|
||||
"""
|
||||
发送消息
|
||||
:param message: 消息
|
||||
|
||||
@@ -120,7 +120,7 @@ class SynologyChatModule(_ModuleBase, _MessageBase[SynologyChat]):
|
||||
logger.debug(f"解析SynologyChat消息失败:{str(err)}")
|
||||
return None
|
||||
|
||||
def post_message(self, message: Notification) -> None:
|
||||
def post_message(self, message: Notification, **kwargs) -> None:
|
||||
"""
|
||||
发送消息
|
||||
:param message: 消息体
|
||||
|
||||
@@ -261,7 +261,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
|
||||
return cleaned
|
||||
|
||||
def post_message(self, message: Notification) -> None:
|
||||
def post_message(self, message: Notification, **kwargs) -> None:
|
||||
"""
|
||||
发送消息
|
||||
:param message: 消息体
|
||||
@@ -283,7 +283,8 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
image=message.image, userid=userid, link=message.link,
|
||||
buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id)
|
||||
original_chat_id=message.original_chat_id,
|
||||
escape_markdown=kwargs.get("escape_markdown"))
|
||||
|
||||
def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None:
|
||||
"""
|
||||
|
||||
@@ -31,7 +31,8 @@ class Telegram:
|
||||
_callback_handlers: Dict[str, Callable] = {} # 存储回调处理器
|
||||
_user_chat_mapping: Dict[str, str] = {} # userid -> chat_id mapping for reply targeting
|
||||
_bot_username: Optional[str] = None # Bot username for mention detection
|
||||
|
||||
_escape_chars = r'_*[]()~`>#+-=|{}.!' # Telegram MarkdownV2
|
||||
_markdown_escape_pattern = re.compile(f'([{re.escape(_escape_chars)}])') # Telegram MarkdownV2 规则转义特殊字符正则pattern
|
||||
def __init__(self, TELEGRAM_TOKEN: Optional[str] = None, TELEGRAM_CHAT_ID: Optional[str] = None, **kwargs):
|
||||
"""
|
||||
初始化参数
|
||||
@@ -52,7 +53,7 @@ class Telegram:
|
||||
else:
|
||||
apihelper.proxy = settings.PROXY
|
||||
# bot
|
||||
_bot = telebot.TeleBot(self._telegram_token, parse_mode="Markdown")
|
||||
_bot = telebot.TeleBot(self._telegram_token, parse_mode="MarkdownV2")
|
||||
# 记录句柄
|
||||
self._bot = _bot
|
||||
# 获取并存储bot用户名用于@检测
|
||||
@@ -215,7 +216,8 @@ class Telegram:
|
||||
userid: Optional[str] = None, link: Optional[str] = None,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
original_message_id: Optional[int] = None,
|
||||
original_chat_id: Optional[str] = None) -> Optional[bool]:
|
||||
original_chat_id: Optional[str] = None,
|
||||
escape_markdown: bool = True) -> Optional[bool]:
|
||||
"""
|
||||
发送Telegram消息
|
||||
:param title: 消息标题
|
||||
@@ -226,7 +228,8 @@ class Telegram:
|
||||
:param buttons: 按钮列表,格式:[[{"text": "按钮文本", "callback_data": "回调数据"}]]
|
||||
:param original_message_id: 原消息ID,如果提供则编辑原消息
|
||||
:param original_chat_id: 原消息的聊天ID,编辑消息时需要
|
||||
:userid: 发送消息的目标用户ID,为空则发给管理员
|
||||
:param escape_markdown: 是否对内容进行Markdown转义
|
||||
|
||||
"""
|
||||
if not self._telegram_token or not self._telegram_chat_id:
|
||||
return None
|
||||
@@ -236,10 +239,20 @@ class Telegram:
|
||||
return False
|
||||
|
||||
try:
|
||||
if title:
|
||||
# 标题总是转义(因为通常标题不包含Markdown格式)
|
||||
title = self.escape_markdown(title)
|
||||
if text:
|
||||
# 对text进行Markdown特殊字符转义
|
||||
text = re.sub(r"([_`])", r"\\\1", text)
|
||||
caption = f"*{title}*\n{text}"
|
||||
if escape_markdown:
|
||||
# 完全转义模式:转义所有特殊字符
|
||||
text = self.escape_markdown(text)
|
||||
else:
|
||||
# 智能转义模式:保留Markdown格式,只转义普通文本中的特殊字符
|
||||
text = self.escape_markdown_smart(text)
|
||||
if title:
|
||||
caption = f"*{title}*\n{text}"
|
||||
else:
|
||||
caption = text
|
||||
else:
|
||||
caption = f"*{title}*"
|
||||
|
||||
@@ -499,7 +512,7 @@ class Telegram:
|
||||
|
||||
if image:
|
||||
# 如果有图片,使用edit_message_media
|
||||
media = InputMediaPhoto(media=image, caption=text, parse_mode="Markdown")
|
||||
media = InputMediaPhoto(media=image, caption=text, parse_mode="MarkdownV2")
|
||||
self._bot.edit_message_media(
|
||||
chat_id=chat_id,
|
||||
message_id=message_id,
|
||||
@@ -512,7 +525,7 @@ class Telegram:
|
||||
chat_id=chat_id,
|
||||
message_id=message_id,
|
||||
text=text,
|
||||
parse_mode="Markdown",
|
||||
parse_mode="MarkdownV2",
|
||||
reply_markup=reply_markup
|
||||
)
|
||||
return True
|
||||
@@ -542,7 +555,7 @@ class Telegram:
|
||||
ret = self._bot.send_photo(chat_id=userid or self._telegram_chat_id,
|
||||
photo=photo,
|
||||
caption=caption,
|
||||
parse_mode="Markdown",
|
||||
parse_mode="MarkdownV2",
|
||||
reply_markup=reply_markup)
|
||||
if ret is None:
|
||||
raise RetryException("发送图片消息失败")
|
||||
@@ -553,12 +566,12 @@ class Telegram:
|
||||
for i in range(0, len(caption), 4095):
|
||||
ret = self._bot.send_message(chat_id=userid or self._telegram_chat_id,
|
||||
text=caption[i:i + 4095],
|
||||
parse_mode="Markdown",
|
||||
parse_mode="MarkdownV2",
|
||||
reply_markup=reply_markup if i == 0 else None)
|
||||
else:
|
||||
ret = self._bot.send_message(chat_id=userid or self._telegram_chat_id,
|
||||
text=caption,
|
||||
parse_mode="Markdown",
|
||||
parse_mode="MarkdownV2",
|
||||
reply_markup=reply_markup)
|
||||
if ret is None:
|
||||
raise RetryException("发送文本消息失败")
|
||||
@@ -597,3 +610,84 @@ class Telegram:
|
||||
self._bot.stop_polling()
|
||||
self._polling_thread.join()
|
||||
logger.info("Telegram消息接收服务已停止")
|
||||
|
||||
def escape_markdown(self, text: str) -> str:
|
||||
# 按 Telegram MarkdownV2 规则转义特殊字符
|
||||
if not isinstance(text, str):
|
||||
return str(text) if text is not None else ""
|
||||
return self._markdown_escape_pattern.sub(r'\\\1', text)
|
||||
|
||||
def escape_markdown_smart(self, text: str) -> str:
|
||||
"""
|
||||
智能转义Markdown文本:只转义不在Markdown标记内的特殊字符
|
||||
这样可以保留已有的Markdown格式(如*粗体*、_斜体_、[链接](url)等),
|
||||
同时转义普通文本中的特殊字符以避免API错误
|
||||
|
||||
注意:Telegram MarkdownV2不支持以下语法,这些字符会被转义:
|
||||
- 标题语法(#、##、###)会被转义为 \#、\##、\###
|
||||
- 列表语法(-、*、+)会被转义为 \-、\*、\+
|
||||
- 引用语法(>)会被转义为 \>
|
||||
|
||||
建议使用加粗文本模拟标题:*标题文本*
|
||||
|
||||
:param text: 要转义的文本
|
||||
:return: 转义后的文本
|
||||
"""
|
||||
if not isinstance(text, str):
|
||||
return str(text) if text is not None else ""
|
||||
|
||||
# 如果没有特殊字符,直接返回
|
||||
if not any(char in self._escape_chars for char in text):
|
||||
return text
|
||||
|
||||
# 标记受保护的区域(Markdown标记内的内容不转义)
|
||||
protected = [False] * len(text)
|
||||
|
||||
# 按优先级匹配Markdown标记(从最复杂到最简单)
|
||||
# 1. 链接:[text](url) - 必须最先匹配
|
||||
link_pattern = r'\[([^\]]*)\]\(([^)]*)\)'
|
||||
for match in re.finditer(link_pattern, text):
|
||||
for i in range(match.start(), match.end()):
|
||||
protected[i] = True
|
||||
|
||||
# 2. 粗体:*text*(单个*,不是**)
|
||||
bold_pattern = r'(?<!\*)\*(?!\*)([^*]+?)(?<!\*)\*(?!\*)'
|
||||
for match in re.finditer(bold_pattern, text):
|
||||
if not any(protected[match.start():match.end()]):
|
||||
for i in range(match.start(), match.end()):
|
||||
protected[i] = True
|
||||
|
||||
# 3. 斜体:_text_(单个_,不是__)
|
||||
italic_pattern = r'(?<!_)_(?!_)([^_]+?)(?<!_)_(?!_)'
|
||||
for match in re.finditer(italic_pattern, text):
|
||||
if not any(protected[match.start():match.end()]):
|
||||
for i in range(match.start(), match.end()):
|
||||
protected[i] = True
|
||||
|
||||
# 4. 代码:`text`
|
||||
code_pattern = r'`([^`]+)`'
|
||||
for match in re.finditer(code_pattern, text):
|
||||
if not any(protected[match.start():match.end()]):
|
||||
for i in range(match.start(), match.end()):
|
||||
protected[i] = True
|
||||
|
||||
# 5. 删除线:~text~
|
||||
strikethrough_pattern = r'~([^~]+)~'
|
||||
for match in re.finditer(strikethrough_pattern, text):
|
||||
if not any(protected[match.start():match.end()]):
|
||||
for i in range(match.start(), match.end()):
|
||||
protected[i] = True
|
||||
|
||||
# 构建结果:只转义未保护区域的特殊字符
|
||||
result = []
|
||||
for i, char in enumerate(text):
|
||||
if protected[i]:
|
||||
# 受保护区域(Markdown标记内),不转义
|
||||
result.append(char)
|
||||
elif char in self._escape_chars:
|
||||
# 未保护区域,转义特殊字符
|
||||
result.append('\\' + char)
|
||||
else:
|
||||
result.append(char)
|
||||
|
||||
return ''.join(result)
|
||||
@@ -8,7 +8,7 @@ from datetime import datetime
|
||||
import requests
|
||||
import requests.exceptions
|
||||
|
||||
from app.core.cache import cached
|
||||
from app.core.cache import cached, fresh, async_fresh
|
||||
from app.core.config import settings
|
||||
from app.utils.http import RequestUtils, AsyncRequestUtils
|
||||
from .exceptions import TMDbException
|
||||
@@ -18,14 +18,12 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class TMDb(object):
|
||||
|
||||
def __init__(self, obj_cached=True, session=None, language=None):
|
||||
def __init__(self, session=None, language=None):
|
||||
self._api_key = settings.TMDB_API_KEY
|
||||
self._language = language or settings.TMDB_LOCALE or "en-US"
|
||||
self._session_id = None
|
||||
self._session = session
|
||||
self._wait_on_rate_limit = True
|
||||
self._debug_enabled = False
|
||||
self._cache_enabled = obj_cached
|
||||
self._proxies = settings.PROXY
|
||||
self._domain = settings.TMDB_API_DOMAIN
|
||||
self._page = None
|
||||
@@ -41,7 +39,6 @@ class TMDb(object):
|
||||
self._remaining = 40
|
||||
self._reset = None
|
||||
self._timeout = 15
|
||||
self.obj_cached = obj_cached
|
||||
|
||||
self.__clear_async_cache__ = False
|
||||
|
||||
@@ -111,36 +108,8 @@ class TMDb(object):
|
||||
def wait_on_rate_limit(self, wait_on_rate_limit):
|
||||
self._wait_on_rate_limit = bool(wait_on_rate_limit)
|
||||
|
||||
@property
|
||||
def debug(self):
|
||||
return self._debug_enabled
|
||||
|
||||
@debug.setter
|
||||
def debug(self, debug):
|
||||
self._debug_enabled = bool(debug)
|
||||
|
||||
@property
|
||||
def cache(self):
|
||||
return self._cache_enabled
|
||||
|
||||
@cache.setter
|
||||
def cache(self, cache):
|
||||
self._cache_enabled = bool(cache)
|
||||
|
||||
@cached(maxsize=settings.CONF.tmdb, ttl=settings.CONF.meta, skip_none=True)
|
||||
def cached_request(self, method, url, data, json,
|
||||
_ts=datetime.strftime(datetime.now(), '%Y%m%d')):
|
||||
return self.request(method, url, data, json)
|
||||
|
||||
@cached(maxsize=settings.CONF.tmdb, ttl=settings.CONF.meta, skip_none=True)
|
||||
async def async_cached_request(self, method, url, data, json,
|
||||
_ts=datetime.strftime(datetime.now(), '%Y%m%d')):
|
||||
if self.__clear_async_cache__:
|
||||
self.__clear_async_cache__ = False
|
||||
await self.async_cached_request.cache_clear()
|
||||
return await self.async_request(method, url, data, json)
|
||||
|
||||
def request(self, method, url, data, json):
|
||||
def request(self, method, url, data, json, **kwargs):
|
||||
if method == "GET":
|
||||
req = self._req.get_res(url, params=data, json=json)
|
||||
else:
|
||||
@@ -149,7 +118,8 @@ class TMDb(object):
|
||||
raise TMDbException("无法连接TheMovieDb,请检查网络连接!")
|
||||
return req
|
||||
|
||||
async def async_request(self, method, url, data, json):
|
||||
@cached(maxsize=settings.CONF.tmdb, ttl=settings.CONF.meta, skip_none=True)
|
||||
async def async_request(self, method, url, data, json, **kwargs):
|
||||
if method == "GET":
|
||||
req = await self._async_req.get_res(url, params=data, json=json)
|
||||
else:
|
||||
@@ -160,7 +130,7 @@ class TMDb(object):
|
||||
|
||||
def cache_clear(self):
|
||||
self.__clear_async_cache__ = True
|
||||
return self.cached_request.cache_clear()
|
||||
return self.request.cache_clear()
|
||||
|
||||
def _validate_api_key(self):
|
||||
if self.api_key is None or self.api_key == "":
|
||||
@@ -204,13 +174,6 @@ class TMDb(object):
|
||||
if "total_pages" in json_data:
|
||||
self._total_pages = json_data["total_pages"]
|
||||
|
||||
if self.debug:
|
||||
logger.info(json_data)
|
||||
if is_async:
|
||||
logger.info(self.async_cached_request.cache_info())
|
||||
else:
|
||||
logger.info(self.cached_request.cache_info())
|
||||
|
||||
@staticmethod
|
||||
def _handle_errors(json_data):
|
||||
if "errors" in json_data:
|
||||
@@ -224,10 +187,9 @@ class TMDb(object):
|
||||
self._validate_api_key()
|
||||
url = self._build_url(action, params)
|
||||
|
||||
if self.cache and self.obj_cached and call_cached and method != "POST":
|
||||
req = self.cached_request(method, url, data, json)
|
||||
else:
|
||||
req = self.request(method, url, data, json)
|
||||
with fresh(not call_cached or method == "POST"):
|
||||
req = self.request(method, url, data, json,
|
||||
_ts=datetime.strftime(datetime.now(), '%Y%m%d'))
|
||||
|
||||
if req is None:
|
||||
return None
|
||||
@@ -253,10 +215,13 @@ class TMDb(object):
|
||||
self._validate_api_key()
|
||||
url = self._build_url(action, params)
|
||||
|
||||
if self.cache and self.obj_cached and call_cached and method != "POST":
|
||||
req = await self.async_cached_request(method, url, data, json)
|
||||
else:
|
||||
req = await self.async_request(method, url, data, json)
|
||||
if self.__clear_async_cache__:
|
||||
self.__clear_async_cache__ = False
|
||||
await self.async_request.cache_clear()
|
||||
|
||||
async with async_fresh(not call_cached or method == "POST"):
|
||||
req = await self.async_request(method, url, data, json,
|
||||
_ts=datetime.strftime(datetime.now(), '%Y%m%d'))
|
||||
|
||||
if req is None:
|
||||
return None
|
||||
|
||||
@@ -382,3 +382,27 @@ class TrimeMediaModule(_ModuleBase, _MediaServerBase[TrimeMedia]):
|
||||
if not server_obj:
|
||||
return []
|
||||
return server_obj.get_latest_backdrops(num=count, remote=remote) or []
|
||||
|
||||
def mediaserver_image_cookies(
|
||||
self,
|
||||
server: Optional[str] = None,
|
||||
image_url: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> Optional[str | dict]:
|
||||
"""
|
||||
获取飞牛影视服务器的图片Cookies
|
||||
|
||||
:param server: 媒体服务器名称
|
||||
:param image_url: 图片网址
|
||||
"""
|
||||
if not image_url:
|
||||
return None
|
||||
if server:
|
||||
server_obj = self.get_instance(server)
|
||||
if not server_obj:
|
||||
return None
|
||||
return server_obj.get_image_cookies(image_url)
|
||||
else:
|
||||
for server_obj in self.get_instances().values():
|
||||
if cookies := server_obj.get_image_cookies(image_url):
|
||||
return cookies
|
||||
|
||||
@@ -5,6 +5,7 @@ import app.modules.trimemedia.api as fnapi
|
||||
from app import schemas
|
||||
from app.log import logger
|
||||
from app.schemas import MediaType
|
||||
from app.utils.security import SecurityUtils
|
||||
from app.utils.url import UrlUtils
|
||||
|
||||
|
||||
@@ -20,6 +21,7 @@ class TrimeMedia:
|
||||
_sync_libraries: List[str] = []
|
||||
|
||||
_api: Optional[fnapi.Api] = None
|
||||
_version: Optional[fnapi.Version] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -55,6 +57,13 @@ class TrimeMedia:
|
||||
"""
|
||||
return self._api
|
||||
|
||||
@property
|
||||
def version(self) -> Optional[fnapi.Version]:
|
||||
"""
|
||||
获得飞牛API的版本
|
||||
"""
|
||||
return self._version
|
||||
|
||||
class _ApiCreateResult:
|
||||
api: fnapi.Api
|
||||
version: fnapi.Version
|
||||
@@ -123,7 +132,9 @@ class TrimeMedia:
|
||||
self.disconnect()
|
||||
if result := self.__create_api(self._host):
|
||||
self._api = result.api
|
||||
self._version = result.version
|
||||
# 版本号:0.8.53, 服务版本:0.8.23
|
||||
# 版本号:0.8.56, 服务版本:0.8.23 接口/memory/user/list改为/manager/user/list
|
||||
logger.debug(
|
||||
f"版本号:{result.version.frontend}, 服务版本:{result.version.backend}"
|
||||
)
|
||||
@@ -189,6 +200,7 @@ class TrimeMedia:
|
||||
],
|
||||
link=f"{self._playhost or self._api.host}/library/{library.guid}",
|
||||
server_type="trimemedia",
|
||||
use_cookies=True,
|
||||
)
|
||||
)
|
||||
return libraries
|
||||
@@ -437,7 +449,7 @@ class TrimeMedia:
|
||||
item_type=item_type,
|
||||
title=item.title,
|
||||
original_title=item.original_title,
|
||||
year=year,
|
||||
year=str(year),
|
||||
tmdbid=item.tmdb_id,
|
||||
imdbid=item.imdb_id,
|
||||
user_state=user_state,
|
||||
@@ -487,6 +499,7 @@ class TrimeMedia:
|
||||
else 0
|
||||
),
|
||||
server_type="trimemedia",
|
||||
use_cookies=True,
|
||||
)
|
||||
|
||||
def get_items(
|
||||
@@ -603,6 +616,7 @@ class TrimeMedia:
|
||||
if (item_details := self._api.item(item.guid)) is None:
|
||||
continue
|
||||
if remote:
|
||||
# FIXME 新版飞牛的壁纸无法直接在浏览器中访问
|
||||
img_host = self._playhost or self._api.host
|
||||
else:
|
||||
img_host = self._api.host
|
||||
@@ -631,3 +645,15 @@ class TrimeMedia:
|
||||
)
|
||||
else False
|
||||
)
|
||||
|
||||
def get_image_cookies(self, image_url: str):
|
||||
"""
|
||||
获得指定图片的Cookies
|
||||
"""
|
||||
if not self.is_authenticated():
|
||||
return None
|
||||
if not image_url or not SecurityUtils.is_safe_url(
|
||||
image_url, [self._api.host], strict=True
|
||||
):
|
||||
return None
|
||||
return {"Trim-MC-token": self._api.token}
|
||||
|
||||
@@ -139,7 +139,7 @@ class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]):
|
||||
logger.error(f"VoceChat消息处理发生错误:{str(err)}")
|
||||
return None
|
||||
|
||||
def post_message(self, message: Notification) -> None:
|
||||
def post_message(self, message: Notification, **kwargs) -> None:
|
||||
"""
|
||||
发送消息
|
||||
:param message: 消息内容
|
||||
|
||||
@@ -71,7 +71,7 @@ class WebPushModule(_ModuleBase, _MessageBase):
|
||||
def init_setting(self) -> Tuple[str, Union[str, bool]]:
|
||||
pass
|
||||
|
||||
def post_message(self, message: Notification) -> None:
|
||||
def post_message(self, message: Notification, **kwargs) -> None:
|
||||
"""
|
||||
发送消息
|
||||
:param message: 消息内容
|
||||
|
||||
@@ -184,7 +184,7 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]):
|
||||
logger.error(f"微信消息处理发生错误:{str(err)}")
|
||||
return None
|
||||
|
||||
def post_message(self, message: Notification) -> None:
|
||||
def post_message(self, message: Notification, **kwargs) -> None:
|
||||
"""
|
||||
发送消息
|
||||
:param message: 消息内容
|
||||
|
||||
@@ -46,12 +46,18 @@ class FileMonitorHandler(FileSystemEventHandler):
|
||||
self.callback = callback
|
||||
|
||||
def on_created(self, event: FileSystemEvent):
|
||||
self.callback.event_handler(event=event, text="创建", event_path=event.src_path,
|
||||
file_size=Path(event.src_path).stat().st_size)
|
||||
try:
|
||||
self.callback.event_handler(event=event, text="创建", event_path=event.src_path,
|
||||
file_size=Path(event.src_path).stat().st_size)
|
||||
except Exception as e:
|
||||
logger.error(f"on_created 异常: {e}")
|
||||
|
||||
def on_moved(self, event: FileSystemMovedEvent):
|
||||
self.callback.event_handler(event=event, text="移动", event_path=event.dest_path,
|
||||
file_size=Path(event.dest_path).stat().st_size)
|
||||
try:
|
||||
self.callback.event_handler(event=event, text="移动", event_path=event.dest_path,
|
||||
file_size=Path(event.dest_path).stat().st_size)
|
||||
except Exception as e:
|
||||
logger.error(f"on_moved 异常: {e}")
|
||||
|
||||
|
||||
class Monitor(metaclass=SingletonClass):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Dict, Tuple, Optional
|
||||
from typing import Any, List, Dict, Tuple, Optional, Type
|
||||
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import settings
|
||||
@@ -200,6 +200,20 @@ class _PluginBase(metaclass=ABCMeta):
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_agent_tools(self) -> List[Type]:
|
||||
"""
|
||||
获取插件智能体工具
|
||||
返回工具类列表,每个工具类必须继承自 MoviePilotTool
|
||||
[ToolClass1, ToolClass2, ...]
|
||||
|
||||
对工具类的要求:
|
||||
1、工具类必须继承自 app.agent.tools.base.MoviePilotTool
|
||||
2、工具类需要实现 run 方法(异步方法)
|
||||
3、工具类需要定义 name 和 description 属性
|
||||
4、工具类可以定义 args_schema 来指定输入参数模型
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def stop_service(self):
|
||||
"""
|
||||
|
||||
58
app/schemas/agent.py
Normal file
58
app/schemas/agent.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""AI智能体相关数据模型"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Any
|
||||
from pydantic import BaseModel, Field, ConfigDict, field_serializer
|
||||
|
||||
|
||||
class ConversationMemory(BaseModel):
|
||||
"""对话记忆模型"""
|
||||
|
||||
session_id: str = Field(description="会话ID")
|
||||
user_id: Optional[str] = Field(default=None, description="用户ID")
|
||||
title: Optional[str] = Field(default=None, description="会话标题")
|
||||
messages: List[Dict[str, Any]] = Field(default_factory=list, description="消息列表")
|
||||
context: Dict[str, Any] = Field(default_factory=dict, description="会话上下文")
|
||||
created_at: datetime = Field(default_factory=datetime.now, description="创建时间")
|
||||
updated_at: datetime = Field(default_factory=datetime.now, description="更新时间")
|
||||
|
||||
model_config = ConfigDict()
|
||||
|
||||
@field_serializer('created_at', 'updated_at', when_used='json')
|
||||
def serialize_datetime(self, value: datetime) -> str:
|
||||
return value.isoformat()
|
||||
|
||||
|
||||
class AgentState(BaseModel):
|
||||
"""AI智能体状态模型"""
|
||||
|
||||
session_id: str = Field(description="会话ID")
|
||||
current_task: Optional[str] = Field(default=None, description="当前任务")
|
||||
is_thinking: bool = Field(default=False, description="是否正在思考")
|
||||
last_activity: datetime = Field(default_factory=datetime.now, description="最后活动时间")
|
||||
|
||||
model_config = ConfigDict()
|
||||
|
||||
@field_serializer('last_activity', when_used='json')
|
||||
def serialize_datetime(self, value: datetime) -> str:
|
||||
return value.isoformat()
|
||||
|
||||
|
||||
class UserMessage(BaseModel):
|
||||
"""用户消息模型"""
|
||||
|
||||
session_id: str = Field(description="会话ID")
|
||||
content: str = Field(description="消息内容")
|
||||
user_id: Optional[str] = Field(default=None, description="用户ID")
|
||||
channel: Optional[str] = Field(default=None, description="消息渠道")
|
||||
source: Optional[str] = Field(default=None, description="消息来源")
|
||||
|
||||
|
||||
class ToolResult(BaseModel):
|
||||
"""工具执行结果模型"""
|
||||
|
||||
session_id: str = Field(description="会话ID")
|
||||
call_id: str = Field(description="调用ID")
|
||||
success: bool = Field(description="是否成功")
|
||||
result: Optional[str] = Field(default=None, description="执行结果")
|
||||
error: Optional[str] = Field(default=None, description="错误信息")
|
||||
@@ -86,7 +86,7 @@ class MediaInfo(BaseModel):
|
||||
# IMDB ID
|
||||
imdb_id: Optional[str] = None
|
||||
# TVDB ID
|
||||
tvdb_id: Optional[str] = None
|
||||
tvdb_id: Optional[int] = None
|
||||
# 豆瓣ID
|
||||
douban_id: Optional[str] = None
|
||||
# Bangumi ID
|
||||
@@ -158,6 +158,8 @@ class MediaInfo(BaseModel):
|
||||
production_countries: Optional[list] = Field(default_factory=list)
|
||||
# 语种
|
||||
spoken_languages: Optional[list] = Field(default_factory=list)
|
||||
# 所有发行日期
|
||||
release_dates: list = Field(default_factory=list)
|
||||
# 状态
|
||||
status: Optional[str] = None
|
||||
# 标签
|
||||
@@ -167,7 +169,7 @@ class MediaInfo(BaseModel):
|
||||
# 评价数量
|
||||
vote_count: Optional[int] = 0
|
||||
# 流行度
|
||||
popularity: Optional[int] = 0
|
||||
popularity: Optional[float] = 0.0
|
||||
# 时长
|
||||
runtime: Optional[int] = None
|
||||
# 下一集
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any, List, Set, Callable
|
||||
|
||||
from pydantic import BaseModel, Field, root_validator
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from app.schemas.message import MessageChannel
|
||||
from app.schemas.file import FileItem
|
||||
@@ -68,7 +68,8 @@ class AuthCredentials(ChainEventData):
|
||||
channel: Optional[str] = Field(default=None, description="认证渠道")
|
||||
service: Optional[str] = Field(default=None, description="服务名称")
|
||||
|
||||
@root_validator(pre=True)
|
||||
@model_validator(mode='before')
|
||||
@classmethod
|
||||
def check_fields_based_on_grant_type(cls, values): # noqa
|
||||
grant_type = values.get("grant_type")
|
||||
if not grant_type:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Optional, Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class DownloadHistory(BaseModel):
|
||||
@@ -51,8 +51,7 @@ class DownloadHistory(BaseModel):
|
||||
# 自定义剧集组
|
||||
episode_group: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class TransferHistory(BaseModel):
|
||||
@@ -97,5 +96,4 @@ class TransferHistory(BaseModel):
|
||||
# 日期
|
||||
date: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Union, List, Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
|
||||
from app.schemas.types import MediaType
|
||||
|
||||
@@ -74,6 +74,8 @@ class MediaServerLibrary(BaseModel):
|
||||
link: Optional[str] = None
|
||||
# 服务器类型
|
||||
server_type: Optional[str] = None
|
||||
# 飞牛的图片需要Cookies
|
||||
use_cookies: Optional[bool] = None
|
||||
|
||||
|
||||
class MediaServerItemUserState(BaseModel):
|
||||
@@ -125,8 +127,7 @@ class MediaServerItem(BaseModel):
|
||||
lst_mod_date: Optional[str] = None
|
||||
user_state: Optional[MediaServerItemUserState] = None
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class MediaServerSeasonInfo(BaseModel):
|
||||
@@ -178,3 +179,5 @@ class MediaServerPlayItem(BaseModel):
|
||||
percent: Optional[float] = None
|
||||
BackdropImageTags: Optional[list] = Field(default_factory=list)
|
||||
server_type: Optional[str] = None
|
||||
# 飞牛的图片需要Cookies
|
||||
use_cookies: Optional[bool] = None
|
||||
|
||||
@@ -40,7 +40,7 @@ class CommingMessage(BaseModel):
|
||||
"""
|
||||
转换为字典
|
||||
"""
|
||||
items = self.dict()
|
||||
items = self.model_dump()
|
||||
for k, v in items.items():
|
||||
if isinstance(v, MessageChannel):
|
||||
items[k] = v.value
|
||||
@@ -88,7 +88,7 @@ class Notification(BaseModel):
|
||||
"""
|
||||
转换为字典
|
||||
"""
|
||||
items = self.dict()
|
||||
items = self.model_dump()
|
||||
for k, v in items.items():
|
||||
if isinstance(v, MessageChannel) \
|
||||
or isinstance(v, NotificationType):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Optional, Any, Union, Dict
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
|
||||
|
||||
class Site(BaseModel):
|
||||
@@ -47,8 +47,7 @@ class Site(BaseModel):
|
||||
# 下载器
|
||||
downloader: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class SiteStatistic(BaseModel):
|
||||
@@ -67,8 +66,7 @@ class SiteStatistic(BaseModel):
|
||||
# 备注
|
||||
note: Optional[Any] = None
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class SiteUserData(BaseModel):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
|
||||
|
||||
class Subscribe(BaseModel):
|
||||
@@ -24,7 +24,7 @@ class Subscribe(BaseModel):
|
||||
# 背景图
|
||||
backdrop: Optional[str] = None
|
||||
# 评分
|
||||
vote: Optional[int] = 0
|
||||
vote: Optional[float] = 0.0
|
||||
# 描述
|
||||
description: Optional[str] = None
|
||||
# 过滤规则
|
||||
@@ -76,8 +76,7 @@ class Subscribe(BaseModel):
|
||||
# 剧集组
|
||||
episode_group: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class SubscribeShare(BaseModel):
|
||||
@@ -111,7 +110,7 @@ class SubscribeShare(BaseModel):
|
||||
# 背景图
|
||||
backdrop: Optional[str] = None
|
||||
# 评分
|
||||
vote: Optional[int] = 0
|
||||
vote: Optional[float] = 0.0
|
||||
# 描述
|
||||
description: Optional[str] = None
|
||||
# 包含
|
||||
|
||||
@@ -73,10 +73,10 @@ class TransferTask(BaseModel):
|
||||
返回字典
|
||||
"""
|
||||
dicts = vars(self).copy()
|
||||
dicts["fileitem"] = self.fileitem.dict() if self.fileitem else None
|
||||
dicts["meta"] = self.meta.dict() if self.meta else None
|
||||
dicts["mediainfo"] = self.mediainfo.dict() if self.mediainfo else None
|
||||
dicts["target_directory"] = self.target_directory.dict() if self.target_directory else None
|
||||
dicts["fileitem"] = self.fileitem.model_dump() if self.fileitem else None
|
||||
dicts["meta"] = self.meta.model_dump() if self.meta else None
|
||||
dicts["mediainfo"] = self.mediainfo.model_dump() if self.mediainfo else None
|
||||
dicts["target_directory"] = self.target_directory.model_dump() if self.target_directory else None
|
||||
return dicts
|
||||
|
||||
|
||||
@@ -144,8 +144,8 @@ class TransferInfo(BaseModel):
|
||||
返回字典
|
||||
"""
|
||||
dicts = vars(self).copy()
|
||||
dicts["fileitem"] = self.fileitem.dict() if self.fileitem else None
|
||||
dicts["target_item"] = self.target_item.dict() if self.target_item else None
|
||||
dicts["fileitem"] = self.fileitem.model_dump() if self.fileitem else None
|
||||
dicts["target_item"] = self.target_item.model_dump() if self.target_item else None
|
||||
return dicts
|
||||
|
||||
|
||||
|
||||
@@ -194,6 +194,8 @@ class SystemConfigKey(Enum):
|
||||
FollowSubscribers = "FollowSubscribers"
|
||||
# 通知发送时间
|
||||
NotificationSendTime = "NotificationSendTime"
|
||||
# AI智能体配置
|
||||
AIAgentConfig = "AIAgentConfig"
|
||||
# 通知消息格式模板
|
||||
NotificationTemplates = "NotificationTemplates"
|
||||
# 刮削开关设置
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
|
||||
|
||||
# Shared properties
|
||||
@@ -22,8 +22,7 @@ class UserBase(BaseModel):
|
||||
# 个性化设置
|
||||
settings: Optional[dict] = Field(default_factory=dict)
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
# Properties to receive via API on creation
|
||||
@@ -48,8 +47,7 @@ class UserUpdate(UserBase):
|
||||
class UserInDBBase(UserBase):
|
||||
id: Optional[int] = None
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
# Additional properties to return via API
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Optional, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
|
||||
from app.schemas.context import Context, MediaInfo
|
||||
from app.schemas.download import DownloadTask
|
||||
@@ -29,8 +29,7 @@ class Workflow(BaseModel):
|
||||
add_time: Optional[str] = Field(default=None, description="创建时间")
|
||||
last_time: Optional[str] = Field(default=None, description="最后执行时间")
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ActionParams(BaseModel):
|
||||
@@ -108,5 +107,4 @@ class WorkflowShare(BaseModel):
|
||||
date: Optional[str] = Field(default=None, description="分享时间")
|
||||
count: Optional[int] = Field(default=0, description="复用人次")
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
105
app/startup/agent_initializer.py
Normal file
105
app/startup/agent_initializer.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""
|
||||
AI智能体初始化器
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class AgentInitializer:
|
||||
"""AI智能体初始化器"""
|
||||
|
||||
def __init__(self):
|
||||
self.agent_manager = None
|
||||
self._initialized = False
|
||||
|
||||
async def initialize(self) -> bool:
|
||||
"""
|
||||
初始化AI智能体管理器
|
||||
"""
|
||||
try:
|
||||
if not settings.AI_AGENT_ENABLE:
|
||||
logger.info("AI智能体功能未启用")
|
||||
return True
|
||||
|
||||
from app.agent import agent_manager
|
||||
self.agent_manager = agent_manager
|
||||
|
||||
await agent_manager.initialize()
|
||||
self._initialized = True
|
||||
logger.info("AI智能体管理器初始化成功")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"AI智能体管理器初始化失败: {e}")
|
||||
return False
|
||||
|
||||
async def cleanup(self) -> None:
|
||||
"""
|
||||
清理AI智能体管理器
|
||||
"""
|
||||
try:
|
||||
if not self._initialized or not self.agent_manager:
|
||||
return
|
||||
|
||||
await self.agent_manager.close()
|
||||
self._initialized = False
|
||||
logger.info("AI智能体管理器已关闭")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"关闭AI智能体管理器时发生错误: {e}")
|
||||
|
||||
|
||||
# 全局AI智能体初始化器实例
|
||||
agent_initializer = AgentInitializer()
|
||||
|
||||
|
||||
def init_agent():
|
||||
"""
|
||||
初始化AI智能体(同步版本,用于在后台线程中运行)
|
||||
"""
|
||||
try:
|
||||
if not settings.AI_AGENT_ENABLE:
|
||||
logger.info("AI智能体功能未启用")
|
||||
return True
|
||||
|
||||
# 在新的事件循环中初始化AI智能体管理器
|
||||
def run_init():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
success = loop.run_until_complete(agent_initializer.initialize())
|
||||
if success:
|
||||
logger.info("AI智能体管理器初始化成功")
|
||||
else:
|
||||
logger.error("AI智能体管理器初始化失败")
|
||||
return success
|
||||
except Exception as e:
|
||||
logger.error(f"AI智能体管理器初始化失败: {e}")
|
||||
return False
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
# 在后台线程中初始化
|
||||
init_thread = threading.Thread(target=run_init, daemon=True)
|
||||
init_thread.start()
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"初始化AI智能体时发生错误: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def stop_agent():
|
||||
"""
|
||||
停止AI智能体(异步版本,用于在应用关闭时调用)
|
||||
"""
|
||||
try:
|
||||
await agent_initializer.cleanup()
|
||||
except Exception as e:
|
||||
logger.error(f"停止AI智能体时发生错误: {e}")
|
||||
@@ -27,6 +27,7 @@ from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.command import CommandChain
|
||||
from app.schemas import Notification, NotificationType
|
||||
from app.schemas.types import SystemConfigKey
|
||||
from app.startup.agent_initializer import init_agent, stop_agent
|
||||
|
||||
|
||||
def start_frontend():
|
||||
@@ -110,6 +111,8 @@ async def stop_modules():
|
||||
"""
|
||||
服务关闭
|
||||
"""
|
||||
# 停止AI智能体
|
||||
await stop_agent()
|
||||
# 停止模块
|
||||
ModuleManager().stop()
|
||||
# 停止事件消费
|
||||
@@ -151,6 +154,8 @@ def init_modules():
|
||||
EventManager().start()
|
||||
# 初始化订阅分享
|
||||
SubscribeHelper()
|
||||
# 初始化AI智能体
|
||||
init_agent()
|
||||
# 启动前端服务
|
||||
start_frontend()
|
||||
# 检查认证状态
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from app.core.workflow import WorkFlowManager
|
||||
from app.chain.workflow import WorkflowChain
|
||||
from app.workflow import WorkFlowManager
|
||||
|
||||
|
||||
def init_workflow():
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import time
|
||||
from functools import wraps
|
||||
@@ -48,7 +49,7 @@ def retry(ExceptionToCheck: Any,
|
||||
logger.warn(msg)
|
||||
else:
|
||||
print(msg)
|
||||
time.sleep(mdelay)
|
||||
await asyncio.sleep(mdelay)
|
||||
mtries -= 1
|
||||
mdelay *= backoff
|
||||
return await f(*args, **kwargs)
|
||||
|
||||
295
app/utils/debounce.py
Normal file
295
app/utils/debounce.py
Normal file
@@ -0,0 +1,295 @@
|
||||
import asyncio
|
||||
import functools
|
||||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from threading import Timer, Lock
|
||||
from typing import Callable, Any, Optional
|
||||
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class BaseDebouncer(ABC):
|
||||
"""
|
||||
防抖器的抽象基类。定义了防抖器的基本接口和日志功能。
|
||||
所有防抖器实现类必须继承此类并实现其抽象方法。
|
||||
"""
|
||||
def __init__(self, func: Callable, interval: float, *,
|
||||
leading: bool = False, enable_logging: bool = False, source: str = ""):
|
||||
"""
|
||||
初始化防抖器实例。
|
||||
:param func: 要防抖的函数或协程
|
||||
:param interval: 防抖间隔,单位秒
|
||||
:param leading: 是否启用前沿模式
|
||||
:param enable_logging: 是否启用日志记录
|
||||
:param source: 日志来源标识
|
||||
"""
|
||||
self.func = func
|
||||
self.interval = interval
|
||||
self.leading = leading
|
||||
self.enable_logging = enable_logging
|
||||
self.source = source
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, *args, **kwargs) -> None:
|
||||
"""
|
||||
定义防抖调用的契约,子类必须实现。
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel(self) -> None:
|
||||
"""
|
||||
定义取消挂起调用的契约,子类必须实现。
|
||||
"""
|
||||
pass
|
||||
|
||||
def format_log(self, message: str) -> str:
|
||||
"""
|
||||
格式化日志消息,加入 source 前缀。
|
||||
"""
|
||||
return f"[{self.source}] {message}" if self.source else message
|
||||
|
||||
def log(self, level: str, message: str):
|
||||
"""
|
||||
根据日志级别记录日志。
|
||||
"""
|
||||
if self.enable_logging:
|
||||
log_method = getattr(logger, level, logger.debug)
|
||||
log_method(self.format_log(message))
|
||||
|
||||
def log_debug(self, message: str):
|
||||
"""
|
||||
记录调试日志。
|
||||
"""
|
||||
self.log("debug", message)
|
||||
|
||||
def log_info(self, message: str):
|
||||
"""
|
||||
记录信息日志。
|
||||
"""
|
||||
self.log("info", message)
|
||||
|
||||
def log_warning(self, message: str):
|
||||
"""
|
||||
记录警告日志。
|
||||
"""
|
||||
self.log("warning", message)
|
||||
|
||||
def error(self, message: str):
|
||||
"""
|
||||
记录错误日志。
|
||||
"""
|
||||
self.log("error", message)
|
||||
|
||||
def critical(self, message: str):
|
||||
"""
|
||||
记录严重错误日志。
|
||||
"""
|
||||
self.log("critical", message)
|
||||
|
||||
|
||||
class Debouncer(BaseDebouncer):
|
||||
"""
|
||||
同步防抖实现类
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""
|
||||
初始化防抖器实例。
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
self.timer: Optional[Timer] = None
|
||||
self.lock = Lock()
|
||||
# 用于前沿模式,标记是否处于“冷却”或“不应期”
|
||||
self.is_cooling_down = False
|
||||
|
||||
def __call__(self, *args, **kwargs) -> None:
|
||||
"""
|
||||
调用防抖函数。
|
||||
:param args:
|
||||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
with self.lock:
|
||||
if self.leading:
|
||||
self._call_leading(*args, **kwargs)
|
||||
else:
|
||||
self._call_trailing(*args, **kwargs)
|
||||
|
||||
def _call_leading(self, *args, **kwargs):
|
||||
"""
|
||||
前沿模式的逻辑。
|
||||
"""
|
||||
# 如果不在冷却期,则立即执行
|
||||
if not self.is_cooling_down:
|
||||
self.log_info("前沿模式: 立即执行函数。")
|
||||
self.func(*args, **kwargs)
|
||||
|
||||
# 无论是否执行,都重置冷却计时器
|
||||
if self.timer and self.timer.is_alive():
|
||||
self.timer.cancel()
|
||||
|
||||
# 设置自己进入冷却期
|
||||
self.is_cooling_down = True
|
||||
|
||||
# 在间隔结束后,将冷却状态解除
|
||||
self.timer = Timer(self.interval, self._end_cool_down)
|
||||
self.timer.start()
|
||||
self.log_debug(f"前沿模式: 进入 {self.interval} 秒的冷却期。")
|
||||
|
||||
def _end_cool_down(self):
|
||||
"""
|
||||
计时器到期后,解除冷却状态
|
||||
"""
|
||||
with self.lock:
|
||||
self.is_cooling_down = False
|
||||
self.log_debug("前沿模式: 冷却时间结束,可以再次立即执行。")
|
||||
|
||||
def _call_trailing(self, *args, **kwargs):
|
||||
"""
|
||||
后沿模式的逻辑。
|
||||
"""
|
||||
# 【日志点】记录计时器被重置
|
||||
if self.timer and self.timer.is_alive():
|
||||
self.timer.cancel()
|
||||
self.log_debug("后沿模式: 检测到新的调用,已重置计时器。")
|
||||
|
||||
def execute():
|
||||
self.log_info("后沿模式: 计时结束,开始执行函数。")
|
||||
self.func(*args, **kwargs)
|
||||
|
||||
self.timer = Timer(self.interval, execute)
|
||||
self.timer.start()
|
||||
self.log_debug(f"后沿模式: 计时器已启动,将在 {self.interval} 秒后执行。")
|
||||
|
||||
def cancel(self) -> None:
|
||||
"""
|
||||
取消任何挂起的调用,并重置状态。
|
||||
"""
|
||||
with self.lock:
|
||||
if self.timer and self.timer.is_alive():
|
||||
self.timer.cancel()
|
||||
self.timer = None
|
||||
self.log_info("防抖器被手动取消。")
|
||||
self.is_cooling_down = False
|
||||
|
||||
|
||||
class AsyncDebouncer(BaseDebouncer):
|
||||
"""
|
||||
异步防抖实现类。
|
||||
"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""
|
||||
初始化异步防抖器实例。
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
self.task: Optional[asyncio.Task] = None
|
||||
self.lock = asyncio.Lock()
|
||||
self.is_cooling_down = False
|
||||
|
||||
async def __call__(self, *args, **kwargs) -> None:
|
||||
"""
|
||||
异步调用防抖函数。
|
||||
"""
|
||||
async with self.lock:
|
||||
if self.leading:
|
||||
await self._call_leading(*args, **kwargs)
|
||||
else:
|
||||
await self._call_trailing(*args, **kwargs)
|
||||
|
||||
async def _call_leading(self, *args, **kwargs):
|
||||
"""
|
||||
前沿模式的逻辑。
|
||||
"""
|
||||
if not self.is_cooling_down:
|
||||
self.log_info("前沿模式 (async): 立即执行协程。")
|
||||
await self.func(*args, **kwargs)
|
||||
|
||||
if self.task and not self.task.done():
|
||||
self.task.cancel()
|
||||
|
||||
self.is_cooling_down = True
|
||||
self.task = asyncio.create_task(self._end_cool_down())
|
||||
self.log_debug(f"前沿模式 (async): 进入 {self.interval} 秒的冷却期。")
|
||||
|
||||
async def _end_cool_down(self):
|
||||
"""
|
||||
计时器到期后,解除冷却状态
|
||||
"""
|
||||
await asyncio.sleep(self.interval)
|
||||
async with self.lock:
|
||||
self.is_cooling_down = False
|
||||
self.log_debug("前沿模式 (async): 冷却时间结束。")
|
||||
|
||||
async def _call_trailing(self, *args, **kwargs):
|
||||
"""
|
||||
后沿模式的逻辑。
|
||||
"""
|
||||
if self.task and not self.task.done():
|
||||
self.task.cancel()
|
||||
self.log_debug("后沿模式 (async): 检测到新的调用,已取消旧任务。")
|
||||
|
||||
self.task = asyncio.create_task(self._delayed_execute(*args, **kwargs))
|
||||
self.log_debug(f"后沿模式 (async): 任务已创建,将在 {self.interval} 秒后执行。")
|
||||
|
||||
async def _delayed_execute(self, *args, **kwargs):
|
||||
"""
|
||||
延迟执行实际的协程函数。
|
||||
"""
|
||||
try:
|
||||
await asyncio.sleep(self.interval)
|
||||
self.log_info("后沿模式 (async): 延迟结束,开始执行协程。")
|
||||
await self.func(*args, **kwargs)
|
||||
except asyncio.CancelledError:
|
||||
# 任务被取消是正常行为,无需处理
|
||||
pass
|
||||
|
||||
async def cancel(self) -> None:
|
||||
"""
|
||||
取消任何挂起的调用,并重置状态。
|
||||
"""
|
||||
async with self.lock:
|
||||
if self.task and not self.task.done():
|
||||
self.task.cancel()
|
||||
self.task = None
|
||||
self.log_info("异步防抖器被手动取消。")
|
||||
self.is_cooling_down = False
|
||||
|
||||
|
||||
def debounce(interval: float, *, leading: bool = False,
|
||||
enable_logging: bool = False, source: str = "") -> Callable:
|
||||
"""
|
||||
支持同步和异步的防抖装饰器工厂。
|
||||
"""
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
# 检查函数类型,并选择合适的引擎
|
||||
if inspect.iscoroutinefunction(func):
|
||||
# 异步函数,使用 AsyncDebouncer
|
||||
instance = AsyncDebouncer(func, interval,
|
||||
leading=leading,
|
||||
enable_logging=enable_logging,
|
||||
source=source)
|
||||
|
||||
@functools.wraps(func)
|
||||
async def async_wrapper(*args, **kwargs) -> Any:
|
||||
await instance(*args, **kwargs)
|
||||
|
||||
async_wrapper.cancel = instance.cancel
|
||||
return async_wrapper
|
||||
|
||||
else:
|
||||
# 同步函数,使用 Debouncer
|
||||
instance = Debouncer(func, interval,
|
||||
leading=leading,
|
||||
enable_logging=enable_logging,
|
||||
source=source)
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs) -> Any:
|
||||
instance(*args, **kwargs)
|
||||
|
||||
wrapper.cancel = instance.cancel
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
@@ -1,5 +1,7 @@
|
||||
import ast
|
||||
import dis
|
||||
import inspect
|
||||
import textwrap
|
||||
from types import FunctionType
|
||||
from typing import Any, Callable, get_type_hints
|
||||
|
||||
@@ -39,45 +41,42 @@ class ObjectUtils:
|
||||
return len(list(parameters.keys()))
|
||||
|
||||
@staticmethod
|
||||
def check_method(func: FunctionType) -> bool:
|
||||
def check_method(func: Callable[..., Any]) -> bool:
|
||||
"""
|
||||
检查函数是否已实现
|
||||
"""
|
||||
try:
|
||||
# 尝试通过源代码分析
|
||||
source = inspect.getsource(func)
|
||||
in_comment = False
|
||||
for line in source.split('\n'):
|
||||
line = line.strip()
|
||||
# 跳过空行
|
||||
if not line:
|
||||
continue
|
||||
# 处理"""单行注释
|
||||
if (line.startswith(('"""', "'''"))
|
||||
and line.endswith(('"""', "'''"))
|
||||
and len(line) > 3):
|
||||
continue
|
||||
# 处理"""多行注释
|
||||
if line.startswith(('"""', "'''")):
|
||||
in_comment = not in_comment
|
||||
continue
|
||||
# 在注释中则跳过
|
||||
if in_comment:
|
||||
continue
|
||||
# 跳过#注释、pass语句、装饰器、函数定义行
|
||||
if (line.startswith('#')
|
||||
or line == "pass"
|
||||
or line.startswith('@')
|
||||
or line.startswith('def ')):
|
||||
continue
|
||||
# 发现有效代码行
|
||||
src = inspect.getsource(func)
|
||||
tree = ast.parse(textwrap.dedent(src))
|
||||
node = tree.body[0]
|
||||
if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
return True
|
||||
body = node.body
|
||||
|
||||
for stmt in body:
|
||||
# 跳过 pass
|
||||
if isinstance(stmt, ast.Pass):
|
||||
continue
|
||||
# 跳过 docstring 或 ...
|
||||
if isinstance(stmt, ast.Expr):
|
||||
expr = stmt.value
|
||||
if isinstance(expr, ast.Constant):
|
||||
if isinstance(expr.value, str) or expr.value is Ellipsis:
|
||||
continue
|
||||
# 检查 raise NotImplementedError
|
||||
if isinstance(stmt, ast.Raise):
|
||||
exc = stmt.exc
|
||||
if isinstance(exc, ast.Call) and getattr(exc.func, "id", None) == "NotImplementedError":
|
||||
continue
|
||||
if isinstance(exc, ast.Name) and exc.id == "NotImplementedError":
|
||||
continue
|
||||
|
||||
return True
|
||||
# 没有有效代码行
|
||||
return False
|
||||
except Exception as err:
|
||||
print(err)
|
||||
# 源代码分析失败时,进行字节码分析
|
||||
code_obj = func.__code__
|
||||
code_obj = func.__code__ # type: ignore[attr-defined]
|
||||
instructions = list(dis.get_instructions(code_obj))
|
||||
# 检查是否为仅返回None的简单结构
|
||||
if len(instructions) == 2:
|
||||
|
||||
@@ -7,7 +7,6 @@ import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import uuid
|
||||
from glob import glob
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
@@ -225,23 +224,31 @@ class SystemUtils:
|
||||
if directory.is_file():
|
||||
return [directory]
|
||||
|
||||
if not min_filesize:
|
||||
min_filesize = 0
|
||||
|
||||
files = []
|
||||
# 预编译正则表达式
|
||||
if extensions:
|
||||
pattern = r".*(" + "|".join(extensions) + ")$"
|
||||
pattern = re.compile(r".*(" + "|".join(extensions) + r")$", re.IGNORECASE)
|
||||
else:
|
||||
pattern = r".*"
|
||||
pattern = re.compile(r".*")
|
||||
|
||||
# 遍历目录及子目录
|
||||
for matched_glob in glob('**', root_dir=directory, recursive=recursive, include_hidden=True):
|
||||
path = directory.joinpath(matched_glob)
|
||||
if path.is_file() \
|
||||
and re.match(pattern, path.name, re.IGNORECASE) \
|
||||
and path.stat().st_size >= min_filesize * 1024 * 1024:
|
||||
files.append(path)
|
||||
def _scan_directory(dir_path: Path, is_recursive: bool):
|
||||
try:
|
||||
with os.scandir(dir_path) as entries:
|
||||
for entry in entries:
|
||||
try:
|
||||
if entry.is_file(follow_symlinks=False):
|
||||
entry_path = Path(entry.path)
|
||||
if (pattern.match(entry.name) and
|
||||
(min_filesize <= 0 or entry.stat().st_size >= min_filesize * 1024 * 1024)):
|
||||
files.append(entry_path)
|
||||
elif entry.is_dir() and is_recursive:
|
||||
_scan_directory(Path(entry.path), is_recursive)
|
||||
except (OSError, PermissionError):
|
||||
continue
|
||||
except (OSError, PermissionError):
|
||||
pass
|
||||
|
||||
_scan_directory(directory, recursive)
|
||||
return files
|
||||
|
||||
@staticmethod
|
||||
@@ -256,29 +263,44 @@ class SystemUtils:
|
||||
:return: True存在 False不存在
|
||||
"""
|
||||
|
||||
if not min_filesize:
|
||||
min_filesize = 0
|
||||
|
||||
if not directory.exists():
|
||||
return False
|
||||
|
||||
# 预编译正则表达式
|
||||
if extensions:
|
||||
pattern = re.compile(r".*(" + "|".join(extensions) + r")$", re.IGNORECASE)
|
||||
else:
|
||||
pattern = re.compile(r".*")
|
||||
|
||||
if directory.is_file():
|
||||
# 检查单个文件是否符合条件
|
||||
if extensions and not pattern.match(directory.name):
|
||||
return False
|
||||
if min_filesize > 0 and directory.stat().st_size < min_filesize * 1024 * 1024:
|
||||
return False
|
||||
return True
|
||||
|
||||
if not min_filesize:
|
||||
min_filesize = 0
|
||||
def _search_files(dir_path: Path, is_recursive: bool) -> bool:
|
||||
try:
|
||||
with os.scandir(dir_path) as entries:
|
||||
for entry in entries:
|
||||
try:
|
||||
if entry.is_file(follow_symlinks=False):
|
||||
# 检查文件是否符合条件
|
||||
if (pattern.match(entry.name) and
|
||||
(min_filesize <= 0 or entry.stat().st_size >= min_filesize * 1024 * 1024)):
|
||||
return True
|
||||
elif entry.is_dir() and is_recursive:
|
||||
# 递归搜索子目录
|
||||
if _search_files(Path(entry.path), is_recursive):
|
||||
return True
|
||||
except (OSError, PermissionError):
|
||||
continue
|
||||
except (OSError, PermissionError):
|
||||
pass
|
||||
return False
|
||||
|
||||
pattern = r".*(" + "|".join(extensions) + ")$"
|
||||
|
||||
# 遍历目录及子目录
|
||||
for matched_glob in glob('**', root_dir=directory, recursive=recursive, include_hidden=True):
|
||||
path = directory.joinpath(matched_glob)
|
||||
if path.is_file() \
|
||||
and re.match(pattern, path.name, re.IGNORECASE) \
|
||||
and path.stat().st_size >= min_filesize * 1024 * 1024:
|
||||
return True
|
||||
|
||||
return False
|
||||
return _search_files(directory, recursive)
|
||||
|
||||
@staticmethod
|
||||
def list_sub_files(directory: Path, extensions: list) -> List[Path]:
|
||||
@@ -292,12 +314,20 @@ class SystemUtils:
|
||||
return [directory]
|
||||
|
||||
files = []
|
||||
pattern = r".*(" + "|".join(extensions) + ")$"
|
||||
|
||||
# 遍历目录
|
||||
for path in directory.iterdir():
|
||||
if path.is_file() and re.match(pattern, path.name, re.IGNORECASE):
|
||||
files.append(path)
|
||||
# 预编译正则表达式
|
||||
if extensions:
|
||||
pattern = re.compile(r".*(" + "|".join(extensions) + r")$", re.IGNORECASE)
|
||||
else:
|
||||
pattern = re.compile(r".*")
|
||||
|
||||
try:
|
||||
with os.scandir(directory) as entries:
|
||||
for entry in entries:
|
||||
if entry.is_file() and pattern.match(entry.name):
|
||||
files.append(Path(entry.path))
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
return files
|
||||
|
||||
@@ -346,7 +376,7 @@ class SystemUtils:
|
||||
return items
|
||||
|
||||
@staticmethod
|
||||
def get_directory_size(path: Path) -> float:
|
||||
def get_directory_size(path: Path) -> int:
|
||||
"""
|
||||
计算目录的大小
|
||||
|
||||
@@ -358,14 +388,21 @@ class SystemUtils:
|
||||
"""
|
||||
if not path or not path.exists():
|
||||
return 0
|
||||
if path.is_file():
|
||||
return path.stat().st_size
|
||||
total_size = 0
|
||||
for path in path.glob('**/*'):
|
||||
if path.is_file():
|
||||
total_size += path.stat().st_size
|
||||
|
||||
return total_size
|
||||
def _calc_dir_size(dir_path):
|
||||
total = 0
|
||||
try:
|
||||
with os.scandir(dir_path) as entries:
|
||||
for entry in entries:
|
||||
if entry.is_file():
|
||||
total += entry.stat().st_size
|
||||
elif entry.is_dir():
|
||||
total += _calc_dir_size(entry.path)
|
||||
except OSError:
|
||||
pass
|
||||
return total
|
||||
|
||||
return _calc_dir_size(path) if path.is_dir() else path.stat().st_size
|
||||
|
||||
@staticmethod
|
||||
def space_usage(dir_list: Union[Path, List[Path]]) -> Tuple[float, float]:
|
||||
|
||||
@@ -46,7 +46,7 @@ class WorkFlowManager(metaclass=Singleton):
|
||||
# 加载所有动作
|
||||
self._actions = {}
|
||||
actions = ModuleHelper.load(
|
||||
"app.actions",
|
||||
"app.workflow.actions",
|
||||
filter_func=lambda _, obj: filter_func(obj)
|
||||
)
|
||||
for action in actions:
|
||||
@@ -2,7 +2,7 @@ from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from app.actions import BaseAction
|
||||
from app.workflow.actions import BaseAction
|
||||
from app.chain.download import DownloadChain
|
||||
from app.chain.media import MediaChain
|
||||
from app.core.config import global_vars
|
||||
@@ -44,7 +44,7 @@ class AddDownloadAction(BaseAction):
|
||||
@classmethod
|
||||
@property
|
||||
def data(cls) -> dict: # noqa
|
||||
return AddDownloadParams().dict()
|
||||
return AddDownloadParams().model_dump()
|
||||
|
||||
@property
|
||||
def success(self) -> bool:
|
||||
@@ -1,4 +1,4 @@
|
||||
from app.actions import BaseAction
|
||||
from app.workflow.actions import BaseAction
|
||||
from app.chain.subscribe import SubscribeChain
|
||||
from app.core.config import settings, global_vars
|
||||
from app.core.context import MediaInfo
|
||||
@@ -37,7 +37,7 @@ class AddSubscribeAction(BaseAction):
|
||||
@classmethod
|
||||
@property
|
||||
def data(cls) -> dict: # noqa
|
||||
return AddSubscribeParams().dict()
|
||||
return AddSubscribeParams().model_dump()
|
||||
|
||||
@property
|
||||
def success(self) -> bool:
|
||||
@@ -57,7 +57,7 @@ class AddSubscribeAction(BaseAction):
|
||||
logger.info(f"{media.title} {media.year} 已添加过订阅,跳过")
|
||||
continue
|
||||
mediainfo = MediaInfo()
|
||||
mediainfo.from_dict(media.dict())
|
||||
mediainfo.from_dict(media.model_dump())
|
||||
subscribechain = SubscribeChain()
|
||||
if subscribechain.exists(mediainfo):
|
||||
logger.info(f"{media.title} 已存在订阅")
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user