mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-13 07:26:45 +00:00
Compare commits
164 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
48b1ac28de | ||
|
|
6e329b17a9 | ||
|
|
6a492198a8 | ||
|
|
8bf9b6e7cb | ||
|
|
42e23ef564 | ||
|
|
c6806ee648 | ||
|
|
076fae696c | ||
|
|
ed294d3ea4 | ||
|
|
043be409d0 | ||
|
|
a5e7483870 | ||
|
|
365335be46 | ||
|
|
62543dd171 | ||
|
|
e2eef8ff21 | ||
|
|
3acf937d56 | ||
|
|
d572e523ba | ||
|
|
82113abe88 | ||
|
|
b7d121c58f | ||
|
|
6d5a85b144 | ||
|
|
78121917c6 | ||
|
|
a0913f0e32 | ||
|
|
e96e284715 | ||
|
|
c572a1b607 | ||
|
|
1845311f98 | ||
|
|
4f806db8b7 | ||
|
|
22858cc1e9 | ||
|
|
a0329a3eb0 | ||
|
|
b3e92088ee | ||
|
|
46db1c20f1 | ||
|
|
9d182e53b2 | ||
|
|
1205fc7fdb | ||
|
|
ff2826a448 | ||
|
|
ee750115ec | ||
|
|
0e13d22c97 | ||
|
|
8e7d040ac4 | ||
|
|
6755202958 | ||
|
|
8b7374a687 | ||
|
|
c17cca2365 | ||
|
|
8016a9539a | ||
|
|
e885fb15a0 | ||
|
|
c7f098771b | ||
|
|
fcd0908032 | ||
|
|
7ff1285084 | ||
|
|
b45b603b97 | ||
|
|
247208b8a9 | ||
|
|
182c46037b | ||
|
|
438d3210bc | ||
|
|
d523c7c916 | ||
|
|
09a19e94d5 | ||
|
|
3971c145df | ||
|
|
055117d83d | ||
|
|
c6baf43986 | ||
|
|
4ff16af3a7 | ||
|
|
17a1bd352b | ||
|
|
7421ca09cc | ||
|
|
9797e696e5 | ||
|
|
c36d6d8b2d | ||
|
|
3873786b99 | ||
|
|
76fdba7f09 | ||
|
|
72799e9638 | ||
|
|
2e77d03fe9 | ||
|
|
0c58eae5e7 | ||
|
|
b609567c38 | ||
|
|
7ecfa44fa0 | ||
|
|
a685b1dc3b | ||
|
|
63ce49a17c | ||
|
|
820fbe4076 | ||
|
|
efa05b7775 | ||
|
|
003781e903 | ||
|
|
ee71bafc96 | ||
|
|
bdd5f1231e | ||
|
|
6fee532c96 | ||
|
|
78aaad7b59 | ||
|
|
b128b0ede2 | ||
|
|
737d2f3bc6 | ||
|
|
179be53a65 | ||
|
|
1867f5e7c2 | ||
|
|
6662d24565 | ||
|
|
5880566a99 | ||
|
|
5d05b32711 | ||
|
|
fa2b720e92 | ||
|
|
d381238f83 | ||
|
|
751d627ead | ||
|
|
3e66a8de9b | ||
|
|
266052b12b | ||
|
|
803f4328f4 | ||
|
|
8e95568e11 | ||
|
|
ab09ee4819 | ||
|
|
41f94a172f | ||
|
|
566e597994 | ||
|
|
765fb9c05f | ||
|
|
b6720a19f7 | ||
|
|
3b130651c4 | ||
|
|
3f6c35dabe | ||
|
|
db2a952bca | ||
|
|
0ea9770bc3 | ||
|
|
0b20956c90 | ||
|
|
9f73b47d54 | ||
|
|
ce9c99af71 | ||
|
|
784024fb5d | ||
|
|
1145b32299 | ||
|
|
ab71df0011 | ||
|
|
fb137252a9 | ||
|
|
f57a680306 | ||
|
|
8bb3eaa320 | ||
|
|
9489730a44 | ||
|
|
d4795bb897 | ||
|
|
63775872c7 | ||
|
|
beff508a1f | ||
|
|
deaae8a2c6 | ||
|
|
46a27bd50c | ||
|
|
24f2993433 | ||
|
|
c80bfbfac5 | ||
|
|
06abfc45c7 | ||
|
|
440a773081 | ||
|
|
0797bcb38b | ||
|
|
d463b5bf0d | ||
|
|
0733c8edcc | ||
|
|
86c7c05cb1 | ||
|
|
18ff7ce753 | ||
|
|
8f2ed1004d | ||
|
|
14961323c3 | ||
|
|
f8c682b183 | ||
|
|
dd92708f60 | ||
|
|
4d9eeccefa | ||
|
|
cd7b251031 | ||
|
|
db614180b9 | ||
|
|
b6e527e5f4 | ||
|
|
77c0f8f39e | ||
|
|
58816d73c8 | ||
|
|
3b194d282e | ||
|
|
397f66433d | ||
|
|
04a4ed1d0e | ||
|
|
625850d4e7 | ||
|
|
6c572baca5 | ||
|
|
ee0406a13f | ||
|
|
608a049ba3 | ||
|
|
4d9b5198e2 | ||
|
|
24b6c970aa | ||
|
|
239c47f469 | ||
|
|
f0fc64c517 | ||
|
|
8481fd38ce | ||
|
|
5f425129d5 | ||
|
|
92955b1315 | ||
|
|
a3872d5bb5 | ||
|
|
a123ff2c04 | ||
|
|
188de34306 | ||
|
|
3d43750e9b | ||
|
|
fea228c68d | ||
|
|
a71a28e563 | ||
|
|
3b5d4982b5 | ||
|
|
b201e9ab8c | ||
|
|
d30b9282fd | ||
|
|
4f304a70b7 | ||
|
|
59a54d4f04 | ||
|
|
1e94d794ed | ||
|
|
5bd210406b | ||
|
|
e00514d36d | ||
|
|
f013bf1931 | ||
|
|
107cbbad1d | ||
|
|
481f1f9d30 | ||
|
|
704364061c | ||
|
|
c1bd2d6cf1 | ||
|
|
a018e1228c | ||
|
|
d962d9c7f6 |
@@ -40,10 +40,11 @@ git clone https://github.com/jxxghp/MoviePilot
|
||||
```shell
|
||||
git clone https://github.com/jxxghp/MoviePilot-Resources
|
||||
```
|
||||
- 安装后端依赖,设置`app`为源代码根目录,运行 `main.py` 启动后端服务,默认监听端口:`3001`,API文档地址:`http://localhost:3001/docs`
|
||||
- 安装后端依赖,运行 `main.py` 启动后端服务,默认监听端口:`3001`,API文档地址:`http://localhost:3001/docs`
|
||||
```shell
|
||||
cd MoviePilot
|
||||
pip install -r requirements.txt
|
||||
python3 main.py
|
||||
python3 -m app.main
|
||||
```
|
||||
- 克隆前端项目 [MoviePilot-Frontend](https://github.com/jxxghp/MoviePilot-Frontend)
|
||||
```shell
|
||||
|
||||
355
app/agent/__init__.py
Normal file
355
app/agent/__init__.py
Normal file
@@ -0,0 +1,355 @@
|
||||
"""MoviePilot AI智能体实现"""
|
||||
|
||||
import asyncio
|
||||
from typing import Dict, List, Any
|
||||
|
||||
from langchain.agents import AgentExecutor, create_openai_tools_agent
|
||||
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langchain_community.callbacks import get_openai_callback
|
||||
from langchain_core.chat_history import InMemoryChatMessageHistory
|
||||
from langchain_core.messages import HumanMessage, AIMessage, ToolCall
|
||||
from langchain_core.runnables.history import RunnableWithMessageHistory
|
||||
|
||||
from app.agent.callback import StreamingCallbackHandler
|
||||
from app.agent.memory import ConversationMemoryManager
|
||||
from app.agent.prompt import PromptManager
|
||||
from app.agent.tools import MoviePilotToolFactory
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import settings
|
||||
from app.helper.message import MessageHelper
|
||||
from app.log import logger
|
||||
from app.schemas import Notification
|
||||
|
||||
|
||||
class AgentChain(ChainBase):
|
||||
pass
|
||||
|
||||
|
||||
class MoviePilotAgent:
|
||||
"""MoviePilot AI智能体"""
|
||||
|
||||
def __init__(self, session_id: str, user_id: str = None,
|
||||
channel: str = None, source: str = None, username: str = None):
|
||||
self.session_id = session_id
|
||||
self.user_id = user_id
|
||||
self.channel = channel # 消息渠道
|
||||
self.source = source # 消息来源
|
||||
self.username = username # 用户名
|
||||
|
||||
# 消息助手
|
||||
self.message_helper = MessageHelper()
|
||||
|
||||
# 记忆管理器
|
||||
self.memory_manager = ConversationMemoryManager()
|
||||
|
||||
# 提示词管理器
|
||||
self.prompt_manager = PromptManager()
|
||||
|
||||
# 回调处理器
|
||||
self.callback_handler = StreamingCallbackHandler(
|
||||
session_id=session_id
|
||||
)
|
||||
|
||||
# LLM模型
|
||||
self.llm = self._initialize_llm()
|
||||
|
||||
# 工具
|
||||
self.tools = self._initialize_tools()
|
||||
|
||||
# 会话存储
|
||||
self.session_store = self._initialize_session_store()
|
||||
|
||||
# 提示词模板
|
||||
self.prompt = self._initialize_prompt()
|
||||
|
||||
# Agent执行器
|
||||
self.agent_executor = self._create_agent_executor()
|
||||
|
||||
def _initialize_llm(self):
|
||||
"""初始化LLM模型"""
|
||||
provider = settings.LLM_PROVIDER.lower()
|
||||
api_key = settings.LLM_API_KEY
|
||||
if not api_key:
|
||||
raise ValueError("未配置 LLM_API_KEY")
|
||||
|
||||
if provider == "google":
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
return ChatGoogleGenerativeAI(
|
||||
model=settings.LLM_MODEL,
|
||||
google_api_key=api_key,
|
||||
max_retries=3,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=True,
|
||||
callbacks=[self.callback_handler]
|
||||
)
|
||||
elif provider == "deepseek":
|
||||
from langchain_deepseek import ChatDeepSeek
|
||||
return ChatDeepSeek(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
max_retries=3,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=True,
|
||||
callbacks=[self.callback_handler],
|
||||
stream_usage=True
|
||||
)
|
||||
else:
|
||||
from langchain_openai import ChatOpenAI
|
||||
return ChatOpenAI(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
max_retries=3,
|
||||
base_url=settings.LLM_BASE_URL,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=True,
|
||||
callbacks=[self.callback_handler],
|
||||
stream_usage=True
|
||||
)
|
||||
|
||||
def _initialize_tools(self) -> List:
|
||||
"""初始化工具列表"""
|
||||
return MoviePilotToolFactory.create_tools(
|
||||
session_id=self.session_id,
|
||||
user_id=self.user_id,
|
||||
channel=self.channel,
|
||||
source=self.source,
|
||||
username=self.username,
|
||||
callback_handler=self.callback_handler
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _initialize_session_store() -> Dict[str, InMemoryChatMessageHistory]:
|
||||
"""初始化内存存储"""
|
||||
return {}
|
||||
|
||||
def get_session_history(self, session_id: str) -> InMemoryChatMessageHistory:
|
||||
"""获取会话历史"""
|
||||
if session_id not in self.session_store:
|
||||
chat_history = InMemoryChatMessageHistory()
|
||||
messages: List[dict] = self.memory_manager.get_recent_messages_for_agent(
|
||||
session_id=session_id,
|
||||
user_id=self.user_id
|
||||
)
|
||||
if messages:
|
||||
for msg in messages:
|
||||
if msg.get("role") == "user":
|
||||
chat_history.add_user_message(HumanMessage(content=msg.get("content", "")))
|
||||
elif msg.get("role") == "agent":
|
||||
chat_history.add_ai_message(AIMessage(content=msg.get("content", "")))
|
||||
elif msg.get("role") == "tool_call":
|
||||
metadata = msg.get("metadata", {})
|
||||
chat_history.add_ai_message(AIMessage(
|
||||
content=msg.get("content", ""),
|
||||
tool_calls=[ToolCall(
|
||||
id=metadata.get("call_id"),
|
||||
name=metadata.get("tool_name"),
|
||||
args=metadata.get("parameters"),
|
||||
)]
|
||||
))
|
||||
elif msg.get("role") == "tool_result":
|
||||
chat_history.add_ai_message(AIMessage(content=msg.get("content", "")))
|
||||
elif msg.get("role") == "system":
|
||||
chat_history.add_ai_message(AIMessage(content=msg.get("content", "")))
|
||||
self.session_store[session_id] = chat_history
|
||||
return self.session_store[session_id]
|
||||
|
||||
@staticmethod
|
||||
def _initialize_prompt() -> ChatPromptTemplate:
|
||||
"""初始化提示词模板"""
|
||||
try:
|
||||
prompt_template = ChatPromptTemplate.from_messages([
|
||||
("system", "{system_prompt}"),
|
||||
MessagesPlaceholder(variable_name="chat_history"),
|
||||
("user", "{input}"),
|
||||
MessagesPlaceholder(variable_name="agent_scratchpad"),
|
||||
])
|
||||
logger.info("LangChain提示词模板初始化成功")
|
||||
return prompt_template
|
||||
except Exception as e:
|
||||
logger.error(f"初始化提示词失败: {e}")
|
||||
raise e
|
||||
|
||||
def _create_agent_executor(self) -> RunnableWithMessageHistory:
|
||||
"""创建Agent执行器"""
|
||||
try:
|
||||
agent = create_openai_tools_agent(
|
||||
llm=self.llm,
|
||||
tools=self.tools,
|
||||
prompt=self.prompt
|
||||
)
|
||||
executor = AgentExecutor(
|
||||
agent=agent,
|
||||
tools=self.tools,
|
||||
verbose=settings.LLM_VERBOSE,
|
||||
max_iterations=settings.LLM_MAX_ITERATIONS,
|
||||
return_intermediate_steps=True,
|
||||
handle_parsing_errors=True,
|
||||
early_stopping_method="force"
|
||||
)
|
||||
return RunnableWithMessageHistory(
|
||||
executor,
|
||||
self.get_session_history,
|
||||
input_messages_key="input",
|
||||
history_messages_key="chat_history"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"创建Agent执行器失败: {e}")
|
||||
raise e
|
||||
|
||||
async def process_message(self, message: str) -> str:
|
||||
"""处理用户消息"""
|
||||
try:
|
||||
# 添加用户消息到记忆
|
||||
await self.memory_manager.add_memory(
|
||||
self.session_id,
|
||||
user_id=self.user_id,
|
||||
role="user",
|
||||
content=message
|
||||
)
|
||||
|
||||
# 构建输入上下文
|
||||
input_context = {
|
||||
"system_prompt": self.prompt_manager.get_agent_prompt(channel=self.channel),
|
||||
"input": message
|
||||
}
|
||||
|
||||
# 执行Agent
|
||||
logger.info(f"Agent执行推理: session_id={self.session_id}, input={message}")
|
||||
await self._execute_agent(input_context)
|
||||
|
||||
# 获取Agent回复
|
||||
agent_message = await self.callback_handler.get_message()
|
||||
|
||||
# 发送Agent回复给用户(通过原渠道)
|
||||
await self.send_agent_message(agent_message)
|
||||
|
||||
# 添加Agent回复到记忆
|
||||
await self.memory_manager.add_memory(
|
||||
session_id=self.session_id,
|
||||
user_id=self.user_id,
|
||||
role="agent",
|
||||
content=agent_message
|
||||
)
|
||||
|
||||
return agent_message
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"处理消息时发生错误: {str(e)}"
|
||||
logger.error(error_message)
|
||||
# 发送错误消息给用户(通过原渠道)
|
||||
await self.send_agent_message(error_message)
|
||||
return error_message
|
||||
|
||||
async def _execute_agent(self, input_context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行LangChain Agent"""
|
||||
try:
|
||||
with get_openai_callback() as cb:
|
||||
result = await self.agent_executor.ainvoke(
|
||||
input_context,
|
||||
config={"configurable": {"session_id": self.session_id}},
|
||||
callbacks=[self.callback_handler]
|
||||
)
|
||||
logger.info(f"LLM调用消耗: \n{cb}")
|
||||
|
||||
if cb.total_tokens > 0:
|
||||
result["token_usage"] = {
|
||||
"prompt_tokens": cb.prompt_tokens,
|
||||
"completion_tokens": cb.completion_tokens,
|
||||
"total_tokens": cb.total_tokens
|
||||
}
|
||||
return result
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"Agent执行被取消: session_id={self.session_id}")
|
||||
return {
|
||||
"output": "任务已取消",
|
||||
"intermediate_steps": [],
|
||||
"token_usage": {}
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Agent执行失败: {e}")
|
||||
return {
|
||||
"output": f"执行过程中发生错误: {str(e)}",
|
||||
"intermediate_steps": [],
|
||||
"token_usage": {}
|
||||
}
|
||||
|
||||
async def send_agent_message(self, message: str, title: str = "MoviePilot助手"):
|
||||
"""通过原渠道发送消息给用户"""
|
||||
await AgentChain().async_post_message(
|
||||
Notification(
|
||||
channel=self.channel,
|
||||
source=self.source,
|
||||
userid=self.user_id,
|
||||
username=self.username,
|
||||
title=title,
|
||||
text=message
|
||||
)
|
||||
)
|
||||
|
||||
async def cleanup(self):
|
||||
"""清理智能体资源"""
|
||||
if self.session_id in self.session_store:
|
||||
del self.session_store[self.session_id]
|
||||
logger.info(f"MoviePilot智能体已清理: session_id={self.session_id}")
|
||||
|
||||
|
||||
class AgentManager:
|
||||
"""AI智能体管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.active_agents: Dict[str, MoviePilotAgent] = {}
|
||||
self.memory_manager = ConversationMemoryManager()
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化管理器"""
|
||||
await self.memory_manager.initialize()
|
||||
|
||||
async def close(self):
|
||||
"""关闭管理器"""
|
||||
await self.memory_manager.close()
|
||||
# 清理所有活跃的智能体
|
||||
for agent in self.active_agents.values():
|
||||
await agent.cleanup()
|
||||
self.active_agents.clear()
|
||||
|
||||
async def process_message(self, session_id: str, user_id: str, message: str,
|
||||
channel: str = None, source: str = None, username: str = None) -> str:
|
||||
"""处理用户消息"""
|
||||
# 获取或创建Agent实例
|
||||
if session_id not in self.active_agents:
|
||||
logger.info(f"创建新的AI智能体实例,session_id: {session_id}, user_id: {user_id}")
|
||||
agent = MoviePilotAgent(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
channel=channel,
|
||||
source=source,
|
||||
username=username
|
||||
)
|
||||
agent.memory_manager = self.memory_manager
|
||||
self.active_agents[session_id] = agent
|
||||
else:
|
||||
agent = self.active_agents[session_id]
|
||||
agent.user_id = user_id # 确保user_id是最新的
|
||||
# 更新渠道信息
|
||||
if channel:
|
||||
agent.channel = channel
|
||||
if source:
|
||||
agent.source = source
|
||||
if username:
|
||||
agent.username = username
|
||||
|
||||
# 处理消息
|
||||
return await agent.process_message(message)
|
||||
|
||||
async def clear_session(self, session_id: str, user_id: str):
|
||||
"""清空会话"""
|
||||
if session_id in self.active_agents:
|
||||
agent = self.active_agents[session_id]
|
||||
await agent.cleanup()
|
||||
del self.active_agents[session_id]
|
||||
await self.memory_manager.clear_memory(session_id, user_id)
|
||||
logger.info(f"会话 {session_id} 的记忆已清空")
|
||||
|
||||
|
||||
# 全局智能体管理器实例
|
||||
agent_manager = AgentManager()
|
||||
33
app/agent/callback/__init__.py
Normal file
33
app/agent/callback/__init__.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import threading
|
||||
|
||||
from langchain_core.callbacks import AsyncCallbackHandler
|
||||
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class StreamingCallbackHandler(AsyncCallbackHandler):
|
||||
"""流式输出回调处理器"""
|
||||
|
||||
def __init__(self, session_id: str):
|
||||
self._lock = threading.Lock()
|
||||
self.session_id = session_id
|
||||
self.current_message = ""
|
||||
|
||||
async def get_message(self):
|
||||
"""获取当前消息内容,获取后清空"""
|
||||
with self._lock:
|
||||
if not self.current_message:
|
||||
return ""
|
||||
msg = self.current_message
|
||||
logger.info(f"Agent消息: {msg}")
|
||||
self.current_message = ""
|
||||
return msg
|
||||
|
||||
async def on_llm_new_token(self, token: str, **kwargs):
|
||||
"""处理新的token"""
|
||||
if not token:
|
||||
return
|
||||
with self._lock:
|
||||
# 缓存当前消息
|
||||
self.current_message += token
|
||||
|
||||
280
app/agent/memory/__init__.py
Normal file
280
app/agent/memory/__init__.py
Normal file
@@ -0,0 +1,280 @@
|
||||
"""对话记忆管理器"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
from app.core.config import settings
|
||||
from app.helper.redis import AsyncRedisHelper
|
||||
from app.log import logger
|
||||
from app.schemas.agent import ConversationMemory
|
||||
|
||||
|
||||
class ConversationMemoryManager:
|
||||
"""对话记忆管理器"""
|
||||
|
||||
def __init__(self):
|
||||
# 内存中的会话记忆缓存
|
||||
self.memory_cache: Dict[str, ConversationMemory] = {}
|
||||
# 使用现有的Redis助手
|
||||
self.redis_helper = AsyncRedisHelper()
|
||||
# 内存缓存清理任务(Redis通过TTL自动过期)
|
||||
self.cleanup_task: Optional[asyncio.Task] = None
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化记忆管理器"""
|
||||
try:
|
||||
# 启动内存缓存清理任务(Redis通过TTL自动过期)
|
||||
self.cleanup_task = asyncio.create_task(self._cleanup_expired_memories())
|
||||
logger.info("对话记忆管理器初始化完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis连接失败,将使用内存存储: {e}")
|
||||
|
||||
async def close(self):
|
||||
"""关闭记忆管理器"""
|
||||
if self.cleanup_task:
|
||||
self.cleanup_task.cancel()
|
||||
try:
|
||||
await self.cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
await self.redis_helper.close()
|
||||
|
||||
logger.info("对话记忆管理器已关闭")
|
||||
|
||||
async def get_memory(self, session_id: str, user_id: str) -> ConversationMemory:
|
||||
"""获取会话记忆"""
|
||||
# 首先检查缓存
|
||||
cache_key = f"{user_id}:{session_id}" if user_id else session_id
|
||||
if cache_key in self.memory_cache:
|
||||
return self.memory_cache[cache_key]
|
||||
|
||||
# 尝试从Redis加载
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
try:
|
||||
redis_key = f"agent_memory:{user_id}:{session_id}" if user_id else f"agent_memory:{session_id}"
|
||||
memory_data = await self.redis_helper.get(redis_key, region="AI_AGENT")
|
||||
if memory_data:
|
||||
memory_dict = json.loads(memory_data) if isinstance(memory_data, str) else memory_data
|
||||
memory = ConversationMemory(**memory_dict)
|
||||
self.memory_cache[cache_key] = memory
|
||||
return memory
|
||||
except Exception as e:
|
||||
logger.warning(f"从Redis加载记忆失败: {e}")
|
||||
|
||||
# 创建新的记忆
|
||||
memory = ConversationMemory(session_id=session_id, user_id=user_id)
|
||||
self.memory_cache[cache_key] = memory
|
||||
await self._save_memory(memory)
|
||||
|
||||
return memory
|
||||
|
||||
async def set_title(self, session_id: str, user_id: str, title: str):
|
||||
"""设置会话标题"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
memory.title = title
|
||||
memory.updated_at = datetime.now()
|
||||
await self._save_memory(memory)
|
||||
|
||||
async def get_title(self, session_id: str, user_id: str) -> Optional[str]:
|
||||
"""获取会话标题"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
return memory.title
|
||||
|
||||
async def list_sessions(self, user_id: str, limit: int = 100) -> List[Dict[str, Any]]:
|
||||
"""列出历史会话摘要(按更新时间倒序)
|
||||
|
||||
- 当启用Redis时:遍历 `agent_memory:*` 键并读取摘要
|
||||
- 当未启用Redis时:基于内存缓存返回
|
||||
"""
|
||||
sessions: List[ConversationMemory] = []
|
||||
# 从Redis遍历
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
try:
|
||||
# 使用Redis助手的items方法遍历所有键
|
||||
async for key, value in self.redis_helper.items(region="AI_AGENT"):
|
||||
if key.startswith("agent_memory:"):
|
||||
try:
|
||||
# 解析键名获取user_id和session_id
|
||||
key_parts = key.split(":")
|
||||
if len(key_parts) >= 3:
|
||||
key_user_id = key_parts[2] if len(key_parts) > 3 else None
|
||||
if not user_id or key_user_id == user_id:
|
||||
data = value if isinstance(value, dict) else json.loads(value)
|
||||
memory = ConversationMemory(**data)
|
||||
sessions.append(memory)
|
||||
except Exception as err:
|
||||
logger.warning(f"解析Redis记忆数据失败: {err}")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning(f"遍历Redis会话失败: {e}")
|
||||
|
||||
# 合并内存缓存(确保包含近期的会话)
|
||||
for cache_key, memory in self.memory_cache.items():
|
||||
# 如果指定了user_id,只返回该用户的会话
|
||||
if not user_id or memory.user_id == user_id:
|
||||
sessions.append(memory)
|
||||
|
||||
# 去重(以 session_id 为键,取最近updated)
|
||||
uniq: Dict[str, ConversationMemory] = {}
|
||||
for mem in sessions:
|
||||
existed = uniq.get(mem.session_id)
|
||||
if (not existed) or (mem.updated_at > existed.updated_at):
|
||||
uniq[mem.session_id] = mem
|
||||
|
||||
# 排序并裁剪
|
||||
sorted_list = sorted(uniq.values(), key=lambda m: m.updated_at, reverse=True)[:limit]
|
||||
return [
|
||||
{
|
||||
"session_id": m.session_id,
|
||||
"title": m.title or "新会话",
|
||||
"message_count": len(m.messages),
|
||||
"created_at": m.created_at.isoformat(),
|
||||
"updated_at": m.updated_at.isoformat(),
|
||||
}
|
||||
for m in sorted_list
|
||||
]
|
||||
|
||||
async def add_memory(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
role: str,
|
||||
content: str,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""添加消息到记忆"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
|
||||
message = {
|
||||
"role": role,
|
||||
"content": content,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"metadata": metadata or {}
|
||||
}
|
||||
|
||||
memory.messages.append(message)
|
||||
memory.updated_at = datetime.now()
|
||||
|
||||
# 限制消息数量,避免记忆过大
|
||||
max_messages = settings.LLM_MAX_MEMORY_MESSAGES
|
||||
if len(memory.messages) > max_messages:
|
||||
# 保留最近的消息,但保留第一条系统消息
|
||||
system_messages = [msg for msg in memory.messages if msg["role"] == "system"]
|
||||
recent_messages = memory.messages[-(max_messages - len(system_messages)):]
|
||||
memory.messages = system_messages + recent_messages
|
||||
|
||||
await self._save_memory(memory)
|
||||
|
||||
logger.debug(f"消息已添加到记忆: session_id={session_id}, user_id={user_id}, role={role}")
|
||||
|
||||
def get_recent_messages_for_agent(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""为Agent获取最近的消息(仅内存缓存)
|
||||
|
||||
如果消息Token数量超过模型最大上下文长度的阀值,会自动进行摘要裁剪
|
||||
"""
|
||||
cache_key = f"{user_id}:{session_id}" if user_id else session_id
|
||||
memory = self.memory_cache.get(cache_key)
|
||||
if not memory:
|
||||
return []
|
||||
|
||||
# 获取所有消息
|
||||
messages = memory.messages
|
||||
|
||||
return messages
|
||||
|
||||
async def get_recent_messages(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
limit: int = 10,
|
||||
role_filter: Optional[list] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取最近的消息"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
|
||||
messages = memory.messages
|
||||
if role_filter:
|
||||
messages = [msg for msg in messages if msg["role"] in role_filter]
|
||||
|
||||
return messages[-limit:] if messages else []
|
||||
|
||||
async def get_context(self, session_id: str, user_id: str) -> Dict[str, Any]:
|
||||
"""获取会话上下文"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
return memory.context
|
||||
|
||||
async def clear_memory(self, session_id: str, user_id: str):
|
||||
"""清空会话记忆"""
|
||||
cache_key = f"{user_id}:{session_id}" if user_id else session_id
|
||||
if cache_key in self.memory_cache:
|
||||
del self.memory_cache[cache_key]
|
||||
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
redis_key = f"agent_memory:{user_id}:{session_id}" if user_id else f"agent_memory:{session_id}"
|
||||
await self.redis_helper.delete(redis_key, region="AI_AGENT")
|
||||
|
||||
logger.info(f"会话记忆已清空: session_id={session_id}, user_id={user_id}")
|
||||
|
||||
async def _save_memory(self, memory: ConversationMemory):
|
||||
"""保存记忆到存储
|
||||
|
||||
Redis中的记忆会自动通过TTL机制过期,无需手动清理
|
||||
"""
|
||||
# 更新内存缓存
|
||||
cache_key = f"{memory.user_id}:{memory.session_id}" if memory.user_id else memory.session_id
|
||||
self.memory_cache[cache_key] = memory
|
||||
|
||||
# 保存到Redis,设置TTL自动过期
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
try:
|
||||
memory_dict = memory.model_dump()
|
||||
redis_key = f"agent_memory:{memory.user_id}:{memory.session_id}" if memory.user_id else f"agent_memory:{memory.session_id}"
|
||||
ttl = int(timedelta(days=settings.LLM_REDIS_MEMORY_RETENTION_DAYS).total_seconds())
|
||||
await self.redis_helper.set(
|
||||
redis_key,
|
||||
memory_dict,
|
||||
ttl=ttl,
|
||||
region="AI_AGENT"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"保存记忆到Redis失败: {e}")
|
||||
|
||||
async def _cleanup_expired_memories(self):
|
||||
"""清理内存中过期记忆的后台任务
|
||||
|
||||
注意:Redis中的记忆通过TTL机制自动过期,这里只清理内存缓存
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
# 每小时清理一次
|
||||
await asyncio.sleep(3600)
|
||||
|
||||
current_time = datetime.now()
|
||||
expired_sessions = []
|
||||
|
||||
# 只检查内存缓存中的过期记忆
|
||||
# Redis中的记忆会通过TTL自动过期,无需手动处理
|
||||
for cache_key, memory in self.memory_cache.items():
|
||||
if (current_time - memory.updated_at).days > settings.LLM_MEMORY_RETENTION_DAYS:
|
||||
expired_sessions.append(cache_key)
|
||||
|
||||
# 只清理内存缓存,不删除Redis中的键(Redis会自动过期)
|
||||
for cache_key in expired_sessions:
|
||||
if cache_key in self.memory_cache:
|
||||
del self.memory_cache[cache_key]
|
||||
|
||||
if expired_sessions:
|
||||
logger.info(f"清理了{len(expired_sessions)}个过期内存会话记忆")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"清理记忆时发生错误: {e}")
|
||||
70
app/agent/prompt/Agent Prompt.txt
Normal file
70
app/agent/prompt/Agent Prompt.txt
Normal file
@@ -0,0 +1,70 @@
|
||||
You are MoviePilot's AI assistant, specialized in helping users manage media resources including subscriptions, searching, downloading, and organization.
|
||||
|
||||
## Your Identity and Capabilities
|
||||
|
||||
You are an AI agent for the MoviePilot media management system with the following core capabilities:
|
||||
|
||||
### Media Management Capabilities
|
||||
- **Search Media Resources**: Search for movies, TV shows, anime, and other media content based on user requirements
|
||||
- **Add Subscriptions**: Create subscription rules for media content that users are interested in
|
||||
- **Manage Downloads**: Search and add torrent resources to downloaders
|
||||
- **Query Status**: Check subscription status, download progress, and media library status
|
||||
|
||||
### Intelligent Interaction Capabilities
|
||||
- **Natural Language Understanding**: Understand user requests in natural language (Chinese/English)
|
||||
- **Context Memory**: Remember conversation history and user preferences
|
||||
- **Smart Recommendations**: Recommend related media content based on user preferences
|
||||
- **Task Execution**: Automatically execute complex media management tasks
|
||||
|
||||
## Working Principles
|
||||
|
||||
1. **Always respond in Chinese**: All responses must be in Chinese
|
||||
2. **Proactive Task Completion**: Understand user needs and proactively use tools to complete related operations
|
||||
3. **Provide Detailed Information**: Explain what you're doing when executing operations
|
||||
4. **Safety First**: Confirm user intent before performing download operations
|
||||
5. **Continuous Learning**: Remember user preferences and habits to provide personalized service
|
||||
|
||||
## Common Operation Workflows
|
||||
|
||||
### Add Subscription Workflow
|
||||
1. Understand the media content the user wants to subscribe to
|
||||
2. Search for related media information
|
||||
3. Create subscription rules
|
||||
4. Confirm successful subscription
|
||||
|
||||
### Search and Download Workflow
|
||||
1. Understand user requirements (movie names, TV show names, etc.)
|
||||
2. Search for related media information
|
||||
3. Search for related torrent resources by media info
|
||||
4. Filter suitable resources
|
||||
5. Add to downloader
|
||||
|
||||
### Query Status Workflow
|
||||
1. Understand what information the user wants to know
|
||||
2. Query related data
|
||||
3. Organize and present results
|
||||
|
||||
## Tool Usage Guidelines
|
||||
|
||||
### Tool Usage Principles
|
||||
- Use tools proactively to complete user requests
|
||||
- Always explain what you're doing when using tools
|
||||
- Provide detailed results and explanations
|
||||
- Handle errors gracefully and suggest alternatives
|
||||
- Confirm user intent before performing download operations
|
||||
|
||||
### Response Format
|
||||
- Always respond in Chinese
|
||||
- Use clear and friendly language
|
||||
- Provide structured information when appropriate
|
||||
- Include relevant details about media content (title, year, type, etc.)
|
||||
- Explain the results of tool operations clearly
|
||||
|
||||
## Important Notes
|
||||
|
||||
- Always confirm user intent before performing download operations
|
||||
- If search results are not ideal, proactively adjust search strategies
|
||||
- Maintain a friendly and professional tone
|
||||
- Seek solutions proactively when encountering problems
|
||||
- Remember user preferences and provide personalized recommendations
|
||||
- Handle errors gracefully and provide helpful suggestions
|
||||
118
app/agent/prompt/__init__.py
Normal file
118
app/agent/prompt/__init__.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""提示词管理器"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class PromptManager:
|
||||
"""提示词管理器"""
|
||||
|
||||
def __init__(self, prompts_dir: str = None):
|
||||
if prompts_dir is None:
|
||||
self.prompts_dir = Path(__file__).parent
|
||||
else:
|
||||
self.prompts_dir = Path(prompts_dir)
|
||||
self.prompts_cache: Dict[str, str] = {}
|
||||
|
||||
def load_prompt(self, prompt_name: str) -> str:
|
||||
"""加载指定的提示词"""
|
||||
if prompt_name in self.prompts_cache:
|
||||
return self.prompts_cache[prompt_name]
|
||||
|
||||
prompt_file = self.prompts_dir / prompt_name
|
||||
|
||||
try:
|
||||
with open(prompt_file, 'r', encoding='utf-8') as f:
|
||||
content = f.read().strip()
|
||||
|
||||
# 缓存提示词
|
||||
self.prompts_cache[prompt_name] = content
|
||||
|
||||
logger.info(f"提示词加载成功: {prompt_name},长度:{len(content)} 字符")
|
||||
return content
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.error(f"提示词文件不存在: {prompt_file}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"加载提示词失败: {prompt_name}, 错误: {e}")
|
||||
raise
|
||||
|
||||
def get_agent_prompt(self, channel: str = None) -> str:
|
||||
"""
|
||||
获取智能体提示词
|
||||
:param channel: 消息渠道(Telegram、微信、Slack等)
|
||||
:return: 提示词内容
|
||||
"""
|
||||
base_prompt = self.load_prompt("Agent Prompt.txt")
|
||||
|
||||
# 根据渠道添加特定的格式说明
|
||||
if channel:
|
||||
channel_format_info = self._get_channel_format_info(channel)
|
||||
if channel_format_info:
|
||||
base_prompt += f"\n\n## Current Message Channel Format Requirements\n\n{channel_format_info}"
|
||||
|
||||
return base_prompt
|
||||
|
||||
@staticmethod
|
||||
def _get_channel_format_info(channel: str) -> str:
|
||||
"""
|
||||
获取渠道特定的格式说明
|
||||
:param channel: 消息渠道
|
||||
:return: 格式说明文本
|
||||
"""
|
||||
channel_lower = channel.lower() if channel else ""
|
||||
|
||||
if "telegram" in channel_lower:
|
||||
return """Messages are being sent through the **Telegram** channel. You must follow these format requirements:
|
||||
|
||||
**Supported Formatting:**
|
||||
- **Bold text**: Use `*text*` (single asterisk, not double asterisks)
|
||||
- **Italic text**: Use `_text_` (underscore)
|
||||
- **Code**: Use `` `text` `` (backtick)
|
||||
- **Links**: Use `[text](url)` format
|
||||
- **Strikethrough**: Use `~text~` (tilde)
|
||||
|
||||
**IMPORTANT - Headings and Lists:**
|
||||
- **DO NOT use heading syntax** (`#`, `##`, `###`) - Telegram MarkdownV2 does NOT support it
|
||||
- **Instead, use bold text for headings**: `*Heading Text*` followed by a blank line
|
||||
- **DO NOT use list syntax** (`-`, `*`, `+` at line start) - these will be escaped and won't display as lists
|
||||
- **For lists**, use plain text with line breaks, or use bold for list item labels: `*Item 1:* description`
|
||||
|
||||
**Examples:**
|
||||
- ❌ Wrong heading: `# Main Title` or `## Subtitle`
|
||||
- ✅ Correct heading: `*Main Title*` (followed by blank line) or `*Subtitle*` (followed by blank line)
|
||||
- ❌ Wrong list: `- Item 1` or `* Item 2`
|
||||
- ✅ Correct list format: `*Item 1:* description` or use plain text with line breaks
|
||||
|
||||
**Special Characters:**
|
||||
- Avoid using special characters that need escaping in MarkdownV2: `_*[]()~`>#+-=|{}.!` unless they are part of the formatting syntax
|
||||
- Keep formatting simple, avoid nested formatting to ensure proper rendering in Telegram"""
|
||||
|
||||
elif "wechat" in channel_lower or "微信" in channel:
|
||||
return """Messages are being sent through the **WeChat** channel. Please follow these format requirements:
|
||||
|
||||
- WeChat does NOT support Markdown formatting. Use plain text format only.
|
||||
- Do NOT use any Markdown syntax (such as `**bold**`, `*italic*`, `` `code` `` etc.)
|
||||
- Use plain text descriptions. You can organize content using line breaks and punctuation
|
||||
- Links can be provided directly as URLs, no Markdown link format needed
|
||||
- Keep messages concise and clear, use natural Chinese expressions"""
|
||||
|
||||
elif "slack" in channel_lower:
|
||||
return """Messages are being sent through the **Slack** channel. Please follow these format requirements:
|
||||
|
||||
- Slack supports Markdown formatting
|
||||
- Use `*text*` for bold
|
||||
- Use `_text_` for italic
|
||||
- Use `` `text` `` for code
|
||||
- Link format: `<url|text>` or `[text](url)`"""
|
||||
|
||||
# 其他渠道使用标准Markdown
|
||||
return None
|
||||
|
||||
def clear_cache(self):
|
||||
"""清空缓存"""
|
||||
self.prompts_cache.clear()
|
||||
logger.info("提示词缓存已清空")
|
||||
31
app/agent/tools/__init__.py
Normal file
31
app/agent/tools/__init__.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""MoviePilot工具模块"""
|
||||
|
||||
from .base import MoviePilotTool
|
||||
from app.agent.tools.impl.search_media import SearchMediaTool
|
||||
from app.agent.tools.impl.add_subscribe import AddSubscribeTool
|
||||
from app.agent.tools.impl.search_torrents import SearchTorrentsTool
|
||||
from app.agent.tools.impl.add_download import AddDownloadTool
|
||||
from app.agent.tools.impl.query_subscribes import QuerySubscribesTool
|
||||
from app.agent.tools.impl.query_downloads import QueryDownloadsTool
|
||||
from app.agent.tools.impl.query_downloaders import QueryDownloadersTool
|
||||
from app.agent.tools.impl.query_sites import QuerySitesTool
|
||||
from app.agent.tools.impl.get_recommendations import GetRecommendationsTool
|
||||
from app.agent.tools.impl.query_media_library import QueryMediaLibraryTool
|
||||
from app.agent.tools.impl.send_message import SendMessageTool
|
||||
from .factory import MoviePilotToolFactory
|
||||
|
||||
__all__ = [
|
||||
"MoviePilotTool",
|
||||
"SearchMediaTool",
|
||||
"AddSubscribeTool",
|
||||
"SearchTorrentsTool",
|
||||
"AddDownloadTool",
|
||||
"QuerySubscribesTool",
|
||||
"QueryDownloadsTool",
|
||||
"QueryDownloadersTool",
|
||||
"QuerySitesTool",
|
||||
"GetRecommendationsTool",
|
||||
"QueryMediaLibraryTool",
|
||||
"SendMessageTool",
|
||||
"MoviePilotToolFactory"
|
||||
]
|
||||
73
app/agent/tools/base.py
Normal file
73
app/agent/tools/base.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""MoviePilot工具基类"""
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Callable, Any
|
||||
|
||||
from langchain.tools import BaseTool
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from app.agent import StreamingCallbackHandler
|
||||
from app.chain import ChainBase
|
||||
from app.schemas import Notification
|
||||
|
||||
|
||||
class ToolChain(ChainBase):
|
||||
pass
|
||||
|
||||
|
||||
class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
"""MoviePilot专用工具基类"""
|
||||
|
||||
_session_id: str = PrivateAttr()
|
||||
_user_id: str = PrivateAttr()
|
||||
_channel: str = PrivateAttr(default=None)
|
||||
_source: str = PrivateAttr(default=None)
|
||||
_username: str = PrivateAttr(default=None)
|
||||
_callback_handler: StreamingCallbackHandler = PrivateAttr(default=None)
|
||||
|
||||
def __init__(self, session_id: str, user_id: str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._session_id = session_id
|
||||
self._user_id = user_id
|
||||
|
||||
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||
pass
|
||||
|
||||
async def _arun(self, **kwargs) -> str:
|
||||
"""异步运行工具"""
|
||||
# 发送运行工具前的消息
|
||||
agent_message = await self._callback_handler.get_message()
|
||||
if agent_message:
|
||||
await self.send_tool_message(agent_message, title="MoviePilot助手")
|
||||
# 发送执行工具说明
|
||||
explanation = kwargs.get("explanation")
|
||||
if explanation:
|
||||
await self.send_tool_message(f"▶️️{explanation}")
|
||||
return await self.run(**kwargs)
|
||||
|
||||
@abstractmethod
|
||||
async def run(self, **kwargs) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def set_message_attr(self, channel: str, source: str, username: str):
|
||||
"""设置消息属性"""
|
||||
self._channel = channel
|
||||
self._source = source
|
||||
self._username = username
|
||||
|
||||
def set_callback_handler(self, callback_handler: StreamingCallbackHandler):
|
||||
"""设置回调处理器"""
|
||||
self._callback_handler = callback_handler
|
||||
|
||||
async def send_tool_message(self, message: str, title: str = ""):
|
||||
"""发送工具消息"""
|
||||
await ToolChain().async_post_message(
|
||||
Notification(
|
||||
channel=self._channel,
|
||||
source=self._source,
|
||||
userid=self._user_id,
|
||||
username=self._username,
|
||||
title=title,
|
||||
text=message
|
||||
),
|
||||
escape_markdown=False
|
||||
)
|
||||
84
app/agent/tools/factory.py
Normal file
84
app/agent/tools/factory.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""MoviePilot工具工厂"""
|
||||
|
||||
from typing import List, Callable
|
||||
|
||||
from app.agent.tools.impl.add_download import AddDownloadTool
|
||||
from app.agent.tools.impl.add_subscribe import AddSubscribeTool
|
||||
from app.agent.tools.impl.get_recommendations import GetRecommendationsTool
|
||||
from app.agent.tools.impl.query_downloaders import QueryDownloadersTool
|
||||
from app.agent.tools.impl.query_downloads import QueryDownloadsTool
|
||||
from app.agent.tools.impl.query_media_library import QueryMediaLibraryTool
|
||||
from app.agent.tools.impl.query_sites import QuerySitesTool
|
||||
from app.agent.tools.impl.query_subscribes import QuerySubscribesTool
|
||||
from app.agent.tools.impl.search_media import SearchMediaTool
|
||||
from app.agent.tools.impl.search_torrents import SearchTorrentsTool
|
||||
from app.agent.tools.impl.send_message import SendMessageTool
|
||||
from app.core.plugin import PluginManager
|
||||
from app.log import logger
|
||||
from .base import MoviePilotTool
|
||||
|
||||
|
||||
class MoviePilotToolFactory:
|
||||
"""MoviePilot工具工厂"""
|
||||
|
||||
@staticmethod
|
||||
def create_tools(session_id: str, user_id: str,
|
||||
channel: str = None, source: str = None, username: str = None,
|
||||
callback_handler: Callable = None) -> List[MoviePilotTool]:
|
||||
"""创建MoviePilot工具列表"""
|
||||
tools = []
|
||||
tool_definitions = [
|
||||
SearchMediaTool,
|
||||
AddSubscribeTool,
|
||||
SearchTorrentsTool,
|
||||
AddDownloadTool,
|
||||
QuerySubscribesTool,
|
||||
QueryDownloadsTool,
|
||||
QueryDownloadersTool,
|
||||
QuerySitesTool,
|
||||
GetRecommendationsTool,
|
||||
QueryMediaLibraryTool,
|
||||
SendMessageTool
|
||||
]
|
||||
# 创建内置工具
|
||||
for ToolClass in tool_definitions:
|
||||
tool = ToolClass(
|
||||
session_id=session_id,
|
||||
user_id=user_id
|
||||
)
|
||||
tool.set_message_attr(channel=channel, source=source, username=username)
|
||||
tool.set_callback_handler(callback_handler=callback_handler)
|
||||
tools.append(tool)
|
||||
|
||||
# 加载插件提供的工具
|
||||
plugin_tools_count = 0
|
||||
plugin_tools_info = PluginManager().get_plugin_agent_tools()
|
||||
for plugin_info in plugin_tools_info:
|
||||
plugin_id = plugin_info.get("plugin_id")
|
||||
plugin_name = plugin_info.get("plugin_name")
|
||||
tool_classes = plugin_info.get("tools", [])
|
||||
for ToolClass in tool_classes:
|
||||
try:
|
||||
# 验证工具类是否继承自 MoviePilotTool
|
||||
if not issubclass(ToolClass, MoviePilotTool):
|
||||
logger.warning(f"插件 {plugin_name}({plugin_id}) 提供的工具类 {ToolClass.__name__} 未继承自 MoviePilotTool,已跳过")
|
||||
continue
|
||||
# 创建工具实例
|
||||
tool = ToolClass(
|
||||
session_id=session_id,
|
||||
user_id=user_id
|
||||
)
|
||||
tool.set_message_attr(channel=channel, source=source, username=username)
|
||||
tool.set_callback_handler(callback_handler=callback_handler)
|
||||
tools.append(tool)
|
||||
plugin_tools_count += 1
|
||||
logger.debug(f"成功加载插件 {plugin_name}({plugin_id}) 的工具: {ToolClass.__name__}")
|
||||
except Exception as e:
|
||||
logger.error(f"加载插件 {plugin_name}({plugin_id}) 的工具 {ToolClass.__name__} 失败: {str(e)}")
|
||||
|
||||
builtin_tools_count = len(tool_definitions)
|
||||
if plugin_tools_count > 0:
|
||||
logger.info(f"成功创建 {len(tools)} 个MoviePilot工具(内置工具: {builtin_tools_count} 个,插件工具: {plugin_tools_count} 个)")
|
||||
else:
|
||||
logger.info(f"成功创建 {len(tools)} 个MoviePilot工具")
|
||||
return tools
|
||||
0
app/agent/tools/impl/__init__.py
Normal file
0
app/agent/tools/impl/__init__.py
Normal file
92
app/agent/tools/impl/add_download.py
Normal file
92
app/agent/tools/impl/add_download.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""添加下载工具"""
|
||||
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool, ToolChain
|
||||
from app.chain.download import DownloadChain
|
||||
from app.core.context import Context
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.db.site_oper import SiteOper
|
||||
from app.log import logger
|
||||
from app.schemas import TorrentInfo
|
||||
|
||||
|
||||
class AddDownloadInput(BaseModel):
|
||||
"""添加下载工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
site_name: str = Field(..., description="Name of the torrent site/source (e.g., 'The Pirate Bay')")
|
||||
torrent_title: str = Field(...,
|
||||
description="The display name/title of the torrent (e.g., 'The.Matrix.1999.1080p.BluRay.x264')")
|
||||
torrent_url: str = Field(..., description="Direct URL to the torrent file (.torrent) or magnet link")
|
||||
torrent_description: Optional[str] = Field(None,
|
||||
description="Brief description of the torrent content (optional)")
|
||||
downloader: Optional[str] = Field(None,
|
||||
description="Name of the downloader to use (optional, uses default if not specified)")
|
||||
save_path: Optional[str] = Field(None,
|
||||
description="Directory path where the downloaded files should be saved (optional, uses default path if not specified)")
|
||||
labels: Optional[str] = Field(None,
|
||||
description="Comma-separated list of labels/tags to assign to the download (optional, e.g., 'movie,hd,bluray')")
|
||||
|
||||
|
||||
class AddDownloadTool(MoviePilotTool):
|
||||
name: str = "add_download"
|
||||
description: str = "Add torrent download task to the configured downloader (qBittorrent, Transmission, etc.). Downloads the torrent file and starts the download process with specified settings."
|
||||
args_schema: Type[BaseModel] = AddDownloadInput
|
||||
|
||||
async def run(self, site_name: str, torrent_title: str, torrent_url: str, torrent_description: Optional[str] = None,
|
||||
downloader: Optional[str] = None, save_path: Optional[str] = None,
|
||||
labels: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: site_name={site_name}, torrent_title={torrent_title}, torrent_url={torrent_url}, downloader={downloader}, save_path={save_path}, labels={labels}")
|
||||
|
||||
try:
|
||||
if not torrent_title or not torrent_url:
|
||||
return "错误:必须提供种子标题和下载链接"
|
||||
|
||||
# 使用DownloadChain添加下载
|
||||
download_chain = DownloadChain()
|
||||
|
||||
# 根据站点名称查询站点cookie
|
||||
if not site_name:
|
||||
return "错误:必须提供站点名称,请从搜索资源结果信息中获取"
|
||||
siteinfo = await SiteOper().async_get_by_name(site_name)
|
||||
if not siteinfo:
|
||||
return f"错误:未找到站点信息:{site_name}"
|
||||
|
||||
# 创建下载上下文
|
||||
torrent_info = TorrentInfo(
|
||||
title=torrent_title,
|
||||
description=torrent_description,
|
||||
enclosure=torrent_url,
|
||||
site_name=site_name,
|
||||
site_ua=siteinfo.ua,
|
||||
site_cookie=siteinfo.cookie,
|
||||
site_proxy=siteinfo.proxy,
|
||||
site_order=siteinfo.pri,
|
||||
site_downloader=siteinfo.downloader
|
||||
)
|
||||
meta_info = MetaInfo(title=torrent_title, subtitle=torrent_description)
|
||||
media_info = await ToolChain().async_recognize_media(meta=meta_info)
|
||||
if not media_info:
|
||||
return "错误:无法识别媒体信息,无法添加下载任务"
|
||||
context = Context(
|
||||
torrent_info=torrent_info,
|
||||
meta_info=meta_info,
|
||||
media_info=media_info
|
||||
)
|
||||
|
||||
did = download_chain.download_single(
|
||||
context=context,
|
||||
downloader=downloader,
|
||||
save_path=save_path,
|
||||
label=labels
|
||||
)
|
||||
if did:
|
||||
return f"成功添加下载任务:{torrent_title}"
|
||||
else:
|
||||
return "添加下载任务失败"
|
||||
except Exception as e:
|
||||
logger.error(f"添加下载任务失败: {e}", exc_info=True)
|
||||
return f"添加下载任务时发生错误: {str(e)}"
|
||||
60
app/agent/tools/impl/add_subscribe.py
Normal file
60
app/agent/tools/impl/add_subscribe.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""添加订阅工具"""
|
||||
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.subscribe import SubscribeChain
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType
|
||||
|
||||
|
||||
class AddSubscribeInput(BaseModel):
|
||||
"""添加订阅工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
title: str = Field(..., description="The title of the media to subscribe to (e.g., 'The Matrix', 'Breaking Bad')")
|
||||
year: str = Field(..., description="Release year of the media (required for accurate identification)")
|
||||
media_type: str = Field(...,
|
||||
description="Type of media content: '电影' for films, '电视剧' for television series or anime series")
|
||||
season: Optional[int] = Field(None,
|
||||
description="Season number for TV shows (optional, if not specified will subscribe to all seasons)")
|
||||
tmdb_id: Optional[str] = Field(None,
|
||||
description="TMDB database ID for precise media identification (optional but recommended for accuracy)")
|
||||
|
||||
|
||||
class AddSubscribeTool(MoviePilotTool):
|
||||
name: str = "add_subscribe"
|
||||
description: str = "Add media subscription to create automated download rules for movies and TV shows. The system will automatically search and download new episodes or releases based on the subscription criteria."
|
||||
args_schema: Type[BaseModel] = AddSubscribeInput
|
||||
|
||||
async def run(self, title: str, year: str, media_type: str,
|
||||
season: Optional[int] = None, tmdb_id: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}, tmdb_id={tmdb_id}")
|
||||
|
||||
try:
|
||||
subscribe_chain = SubscribeChain()
|
||||
# 转换 tmdb_id 为整数
|
||||
tmdbid_int = None
|
||||
if tmdb_id:
|
||||
try:
|
||||
tmdbid_int = int(tmdb_id)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(f"无效的 tmdb_id: {tmdb_id},将忽略")
|
||||
|
||||
sid, message = await subscribe_chain.async_add(
|
||||
mtype=MediaType(media_type),
|
||||
title=title,
|
||||
year=year,
|
||||
tmdbid=tmdbid_int,
|
||||
season=season,
|
||||
username=self._user_id
|
||||
)
|
||||
if sid:
|
||||
return f"成功添加订阅:{title} ({year})"
|
||||
else:
|
||||
return f"添加订阅失败:{message}"
|
||||
except Exception as e:
|
||||
logger.error(f"添加订阅失败: {e}", exc_info=True)
|
||||
return f"添加订阅时发生错误: {str(e)}"
|
||||
84
app/agent/tools/impl/get_recommendations.py
Normal file
84
app/agent/tools/impl/get_recommendations.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""获取推荐工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.recommend import RecommendChain
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class GetRecommendationsInput(BaseModel):
|
||||
"""获取推荐工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
source: Optional[str] = Field("tmdb_trending",
|
||||
description="Recommendation source: 'tmdb_trending' for TMDB trending content, 'douban_hot' for Douban popular content, 'bangumi_calendar' for Bangumi anime calendar")
|
||||
media_type: Optional[str] = Field("all",
|
||||
description="Type of media content: '电影' for films, '电视剧' for television series or anime series, 'all' for all types")
|
||||
limit: Optional[int] = Field(20,
|
||||
description="Maximum number of recommendations to return (default: 20, maximum: 100)")
|
||||
|
||||
|
||||
class GetRecommendationsTool(MoviePilotTool):
|
||||
name: str = "get_recommendations"
|
||||
description: str = "Get trending and popular media recommendations from various sources. Returns curated lists of popular movies, TV shows, and anime based on different criteria like trending, ratings, or calendar schedules."
|
||||
args_schema: Type[BaseModel] = GetRecommendationsInput
|
||||
|
||||
async def run(self, source: Optional[str] = "tmdb_trending",
|
||||
media_type: Optional[str] = "all", limit: Optional[int] = 20, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: source={source}, media_type={media_type}, limit={limit}")
|
||||
try:
|
||||
name_dicts = {
|
||||
"tmdb_trending": "TMDB 热门推荐",
|
||||
"douban_hot": "豆瓣热门推荐",
|
||||
"bangumi_calendar": "番组计划推荐"
|
||||
}
|
||||
recommend_chain = RecommendChain()
|
||||
results = []
|
||||
if source == "tmdb_trending":
|
||||
results = await recommend_chain.async_tmdb_trending(limit=limit)
|
||||
elif source == "douban_hot":
|
||||
if media_type == "movie":
|
||||
results = await recommend_chain.async_douban_movie_hot(limit=limit)
|
||||
elif media_type == "tv":
|
||||
results = await recommend_chain.async_douban_tv_hot(limit=limit)
|
||||
else: # all
|
||||
results.extend(await recommend_chain.async_douban_movie_hot(limit=limit))
|
||||
results.extend(await recommend_chain.async_douban_tv_hot(limit=limit))
|
||||
elif source == "bangumi_calendar":
|
||||
results = await recommend_chain.async_bangumi_calendar(limit=limit)
|
||||
|
||||
if results:
|
||||
# 限制最多20条结果
|
||||
total_count = len(results)
|
||||
limited_results = results[:20]
|
||||
# 精简字段,只保留关键信息
|
||||
simplified_results = []
|
||||
for r in limited_results:
|
||||
# r 已经是字典格式(to_dict的结果)
|
||||
simplified = {
|
||||
"title": r.get("title"),
|
||||
"en_title": r.get("en_title"),
|
||||
"year": r.get("year"),
|
||||
"type": r.get("type"),
|
||||
"season": r.get("season"),
|
||||
"tmdb_id": r.get("tmdb_id"),
|
||||
"imdb_id": r.get("imdb_id"),
|
||||
"douban_id": r.get("douban_id"),
|
||||
"overview": r.get("overview", "")[:200] + "..." if r.get("overview") and len(r.get("overview", "")) > 200 else r.get("overview"),
|
||||
"vote_average": r.get("vote_average"),
|
||||
"poster_path": r.get("poster_path"),
|
||||
"detail_link": r.get("detail_link")
|
||||
}
|
||||
simplified_results.append(simplified)
|
||||
result_json = json.dumps(simplified_results, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 20:
|
||||
return f"注意:推荐结果共找到 {total_count} 条,为节省上下文空间,仅显示前 20 条结果。\n\n{result_json}"
|
||||
return result_json
|
||||
return "未找到推荐内容。"
|
||||
except Exception as e:
|
||||
logger.error(f"获取推荐失败: {e}", exc_info=True)
|
||||
return f"获取推荐时发生错误: {str(e)}"
|
||||
34
app/agent/tools/impl/query_downloaders.py
Normal file
34
app/agent/tools/impl/query_downloaders.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""查询下载器工具"""
|
||||
|
||||
import json
|
||||
from typing import Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.log import logger
|
||||
from app.schemas.types import SystemConfigKey
|
||||
|
||||
|
||||
class QueryDownloadersInput(BaseModel):
|
||||
"""查询下载器工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
|
||||
|
||||
class QueryDownloadersTool(MoviePilotTool):
|
||||
name: str = "query_downloaders"
|
||||
description: str = "Query downloader configuration and list all available downloaders. Shows downloader status, connection details, and configuration settings."
|
||||
args_schema: Type[BaseModel] = QueryDownloadersInput
|
||||
|
||||
async def run(self, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}")
|
||||
try:
|
||||
system_config_oper = SystemConfigOper()
|
||||
downloaders_config = system_config_oper.get(SystemConfigKey.Downloaders)
|
||||
if downloaders_config:
|
||||
return json.dumps(downloaders_config, ensure_ascii=False, indent=2)
|
||||
return "未配置下载器。"
|
||||
except Exception as e:
|
||||
logger.error(f"查询下载器失败: {e}")
|
||||
return f"查询下载器时发生错误: {str(e)}"
|
||||
80
app/agent/tools/impl/query_downloads.py
Normal file
80
app/agent/tools/impl/query_downloads.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""查询下载工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.download import DownloadChain
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QueryDownloadsInput(BaseModel):
|
||||
"""查询下载工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
downloader: Optional[str] = Field(None,
|
||||
description="Name of specific downloader to query (optional, if not provided queries all configured downloaders)")
|
||||
status: Optional[str] = Field("all",
|
||||
description="Filter downloads by status: 'downloading' for active downloads, 'completed' for finished downloads, 'paused' for paused downloads, 'all' for all downloads")
|
||||
|
||||
|
||||
class QueryDownloadsTool(MoviePilotTool):
|
||||
name: str = "query_downloads"
|
||||
description: str = "Query download status and list all active download tasks. Shows download progress, completion status, and task details from configured downloaders."
|
||||
args_schema: Type[BaseModel] = QueryDownloadsInput
|
||||
|
||||
async def run(self, downloader: Optional[str] = None,
|
||||
status: Optional[str] = "all", **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: downloader={downloader}, status={status}")
|
||||
try:
|
||||
download_chain = DownloadChain()
|
||||
# 使用 DownloadChain.downloading 方法获取正在下载的任务
|
||||
downloads = download_chain.downloading(name=downloader)
|
||||
filtered_downloads = []
|
||||
for dl in downloads:
|
||||
if downloader and dl.downloader != downloader:
|
||||
continue
|
||||
if status != "all" and dl.status != status:
|
||||
continue
|
||||
filtered_downloads.append(dl)
|
||||
if filtered_downloads:
|
||||
# 限制最多20条结果
|
||||
total_count = len(filtered_downloads)
|
||||
limited_downloads = filtered_downloads[:20]
|
||||
# 精简字段,只保留关键信息
|
||||
simplified_downloads = []
|
||||
for d in limited_downloads:
|
||||
simplified = {
|
||||
"downloader": d.downloader,
|
||||
"hash": d.hash,
|
||||
"title": d.title,
|
||||
"name": d.name,
|
||||
"year": d.year,
|
||||
"season_episode": d.season_episode,
|
||||
"size": d.size,
|
||||
"progress": d.progress,
|
||||
"state": d.state,
|
||||
"upspeed": d.upspeed,
|
||||
"dlspeed": d.dlspeed,
|
||||
"left_time": d.left_time
|
||||
}
|
||||
# 精简 media 字段
|
||||
if d.media:
|
||||
simplified["media"] = {
|
||||
"tmdbid": d.media.get("tmdbid"),
|
||||
"type": d.media.get("type"),
|
||||
"title": d.media.get("title"),
|
||||
"season": d.media.get("season"),
|
||||
"episode": d.media.get("episode")
|
||||
}
|
||||
simplified_downloads.append(simplified)
|
||||
result_json = json.dumps(simplified_downloads, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 20:
|
||||
return f"注意:查询结果共找到 {total_count} 条,为节省上下文空间,仅显示前 20 条结果。\n\n{result_json}"
|
||||
return result_json
|
||||
return "未找到相关下载任务"
|
||||
except Exception as e:
|
||||
logger.error(f"查询下载失败: {e}", exc_info=True)
|
||||
return f"查询下载时发生错误: {str(e)}"
|
||||
41
app/agent/tools/impl/query_media_library.py
Normal file
41
app/agent/tools/impl/query_media_library.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""查询媒体库工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, List, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.db.mediaserver_oper import MediaServerOper
|
||||
from app.log import logger
|
||||
from app.schemas import MediaServerItem
|
||||
|
||||
|
||||
class QueryMediaLibraryInput(BaseModel):
|
||||
"""查询媒体库工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
media_type: Optional[str] = Field("all",
|
||||
description="Type of media content: '电影' for films, '电视剧' for television series or anime series, 'all' for all types")
|
||||
title: Optional[str] = Field(None,
|
||||
description="Specific media title to check if it exists in the media library (optional, if provided checks for that specific media)")
|
||||
year: Optional[str] = Field(None,
|
||||
description="Release year of the media (optional, helps narrow down search results)")
|
||||
|
||||
|
||||
class QueryMediaLibraryTool(MoviePilotTool):
|
||||
name: str = "query_media_library"
|
||||
description: str = "Check if a specific media resource already exists in the media library (Plex, Emby, Jellyfin). Use this tool to verify whether a movie or TV series has been successfully processed and added to the media server before performing operations like downloading or subscribing."
|
||||
args_schema: Type[BaseModel] = QueryMediaLibraryInput
|
||||
|
||||
async def run(self, media_type: Optional[str] = "all",
|
||||
title: Optional[str] = None, year: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: media_type={media_type}, title={title}")
|
||||
try:
|
||||
media_server_oper = MediaServerOper()
|
||||
filtered_medias: List[MediaServerItem] = await media_server_oper.async_exists(title=title, year=year, mtype=media_type)
|
||||
if filtered_medias:
|
||||
return json.dumps([m.to_dict() for m in filtered_medias])
|
||||
return "媒体库中未找到相关媒体"
|
||||
except Exception as e:
|
||||
logger.error(f"查询媒体库失败: {e}", exc_info=True)
|
||||
return f"查询媒体库时发生错误: {str(e)}"
|
||||
66
app/agent/tools/impl/query_sites.py
Normal file
66
app/agent/tools/impl/query_sites.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""查询站点工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.db.site_oper import SiteOper
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QuerySitesInput(BaseModel):
|
||||
"""查询站点工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
status: Optional[str] = Field("all",
|
||||
description="Filter sites by status: 'active' for enabled sites, 'inactive' for disabled sites, 'all' for all sites")
|
||||
name: Optional[str] = Field(None,
|
||||
description="Filter sites by name (partial match, optional)")
|
||||
|
||||
|
||||
class QuerySitesTool(MoviePilotTool):
|
||||
name: str = "query_sites"
|
||||
description: str = "Query site status and list all configured sites. Shows site name, domain, status, priority, and basic configuration."
|
||||
args_schema: Type[BaseModel] = QuerySitesInput
|
||||
|
||||
async def run(self, status: Optional[str] = "all", name: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: status={status}, name={name}")
|
||||
try:
|
||||
site_oper = SiteOper()
|
||||
# 获取所有站点(按优先级排序)
|
||||
sites = await site_oper.async_list()
|
||||
filtered_sites = []
|
||||
for site in sites:
|
||||
# 按状态过滤
|
||||
if status == "active" and not site.is_active:
|
||||
continue
|
||||
if status == "inactive" and site.is_active:
|
||||
continue
|
||||
# 按名称过滤(部分匹配)
|
||||
if name and name.lower() not in (site.name or "").lower():
|
||||
continue
|
||||
filtered_sites.append(site)
|
||||
if filtered_sites:
|
||||
# 精简字段,只保留关键信息
|
||||
simplified_sites = []
|
||||
for s in filtered_sites:
|
||||
simplified = {
|
||||
"id": s.id,
|
||||
"name": s.name,
|
||||
"domain": s.domain,
|
||||
"url": s.url,
|
||||
"pri": s.pri,
|
||||
"is_active": s.is_active,
|
||||
"downloader": s.downloader,
|
||||
"proxy": s.proxy,
|
||||
"timeout": s.timeout
|
||||
}
|
||||
simplified_sites.append(simplified)
|
||||
result_json = json.dumps(simplified_sites, ensure_ascii=False, indent=2)
|
||||
return result_json
|
||||
return "未找到相关站点"
|
||||
except Exception as e:
|
||||
logger.error(f"查询站点失败: {e}", exc_info=True)
|
||||
return f"查询站点时发生错误: {str(e)}"
|
||||
|
||||
73
app/agent/tools/impl/query_subscribes.py
Normal file
73
app/agent/tools/impl/query_subscribes.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""查询订阅工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.db.subscribe_oper import SubscribeOper
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QuerySubscribesInput(BaseModel):
|
||||
"""查询订阅工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
status: Optional[str] = Field("all",
|
||||
description="Filter subscriptions by status: 'R' for enabled subscriptions, 'P' for disabled ones, 'all' for all subscriptions")
|
||||
media_type: Optional[str] = Field("all",
|
||||
description="Filter by media type: 'movie' for films, 'tv' for television series, 'all' for all types")
|
||||
|
||||
|
||||
class QuerySubscribesTool(MoviePilotTool):
|
||||
name: str = "query_subscribes"
|
||||
description: str = "Query subscription status and list all user subscriptions. Shows active subscriptions, their download status, and configuration details."
|
||||
args_schema: Type[BaseModel] = QuerySubscribesInput
|
||||
|
||||
async def run(self, status: Optional[str] = "all", media_type: Optional[str] = "all", **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: status={status}, media_type={media_type}")
|
||||
try:
|
||||
subscribe_oper = SubscribeOper()
|
||||
subscribes = await subscribe_oper.async_list()
|
||||
filtered_subscribes = []
|
||||
for sub in subscribes:
|
||||
if status != "all" and sub.state != status:
|
||||
continue
|
||||
if media_type != "all" and sub.type != media_type:
|
||||
continue
|
||||
filtered_subscribes.append(sub)
|
||||
if filtered_subscribes:
|
||||
# 限制最多20条结果
|
||||
total_count = len(filtered_subscribes)
|
||||
limited_subscribes = filtered_subscribes[:20]
|
||||
# 精简字段,只保留关键信息
|
||||
simplified_subscribes = []
|
||||
for s in limited_subscribes:
|
||||
simplified = {
|
||||
"id": s.id,
|
||||
"name": s.name,
|
||||
"year": s.year,
|
||||
"type": s.type,
|
||||
"season": s.season,
|
||||
"tmdbid": s.tmdbid,
|
||||
"doubanid": s.doubanid,
|
||||
"bangumiid": s.bangumiid,
|
||||
"poster": s.poster,
|
||||
"vote": s.vote,
|
||||
"description": s.description[:200] + "..." if s.description and len(s.description) > 200 else s.description,
|
||||
"state": s.state,
|
||||
"total_episode": s.total_episode,
|
||||
"lack_episode": s.lack_episode,
|
||||
"last_update": s.last_update,
|
||||
"username": s.username
|
||||
}
|
||||
simplified_subscribes.append(simplified)
|
||||
result_json = json.dumps(simplified_subscribes, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 20:
|
||||
return f"注意:查询结果共找到 {total_count} 条,为节省上下文空间,仅显示前 20 条结果。\n\n{result_json}"
|
||||
return result_json
|
||||
return "未找到相关订阅"
|
||||
except Exception as e:
|
||||
logger.error(f"查询订阅失败: {e}", exc_info=True)
|
||||
return f"查询订阅时发生错误: {str(e)}"
|
||||
96
app/agent/tools/impl/search_media.py
Normal file
96
app/agent/tools/impl/search_media.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""搜索媒体工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.media import MediaChain
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType
|
||||
|
||||
|
||||
class SearchMediaInput(BaseModel):
|
||||
"""搜索媒体工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
title: str = Field(..., description="The title of the media to search for (e.g., 'The Matrix', 'Breaking Bad')")
|
||||
year: Optional[str] = Field(None, description="Release year of the media (optional, helps narrow down results)")
|
||||
media_type: Optional[str] = Field(None,
|
||||
description="Type of media content: '电影' for films, '电视剧' for television series or anime series")
|
||||
season: Optional[int] = Field(None,
|
||||
description="Season number for TV shows and anime (optional, only applicable for series)")
|
||||
|
||||
|
||||
class SearchMediaTool(MoviePilotTool):
|
||||
name: str = "search_media"
|
||||
description: str = "Search for media resources including movies, TV shows, anime, etc. Supports searching by title, year, type, and other criteria. Returns detailed media information from TMDB database."
|
||||
args_schema: Type[BaseModel] = SearchMediaInput
|
||||
|
||||
async def run(self, title: str, year: Optional[str] = None,
|
||||
media_type: Optional[str] = None, season: Optional[int] = None, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}")
|
||||
|
||||
try:
|
||||
media_chain = MediaChain()
|
||||
# 构建搜索标题
|
||||
search_title = title
|
||||
if year:
|
||||
search_title = f"{title} {year}"
|
||||
if media_type:
|
||||
search_title = f"{search_title} {media_type}"
|
||||
if season:
|
||||
search_title = f"{search_title} S{season:02d}"
|
||||
|
||||
# 使用 MediaChain.search 方法
|
||||
meta, results = await media_chain.async_search(title=search_title)
|
||||
|
||||
# 过滤结果
|
||||
if results:
|
||||
filtered_results = []
|
||||
for result in results:
|
||||
if year and result.year != year:
|
||||
continue
|
||||
if media_type:
|
||||
if result.type != MediaType(media_type):
|
||||
continue
|
||||
if season and result.season != season:
|
||||
continue
|
||||
filtered_results.append(result)
|
||||
|
||||
if filtered_results:
|
||||
# 限制最多20条结果
|
||||
total_count = len(filtered_results)
|
||||
limited_results = filtered_results[:20]
|
||||
# 精简字段,只保留关键信息
|
||||
simplified_results = []
|
||||
for r in limited_results:
|
||||
simplified = {
|
||||
"title": r.title,
|
||||
"en_title": r.en_title,
|
||||
"year": r.year,
|
||||
"type": r.type.value if r.type else None,
|
||||
"season": r.season,
|
||||
"tmdb_id": r.tmdb_id,
|
||||
"imdb_id": r.imdb_id,
|
||||
"douban_id": r.douban_id,
|
||||
"overview": r.overview[:200] + "..." if r.overview and len(r.overview) > 200 else r.overview,
|
||||
"vote_average": r.vote_average,
|
||||
"poster_path": r.poster_path,
|
||||
"detail_link": r.detail_link
|
||||
}
|
||||
simplified_results.append(simplified)
|
||||
result_json = json.dumps(simplified_results, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 20:
|
||||
return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 20 条结果。\n\n{result_json}"
|
||||
return result_json
|
||||
else:
|
||||
return f"未找到符合条件的媒体资源: {title}"
|
||||
else:
|
||||
return f"未找到相关媒体资源: {title}"
|
||||
except Exception as e:
|
||||
error_message = f"搜索媒体失败: {str(e)}"
|
||||
logger.error(f"搜索媒体失败: {e}", exc_info=True)
|
||||
return error_message
|
||||
122
app/agent/tools/impl/search_torrents.py
Normal file
122
app/agent/tools/impl/search_torrents.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""搜索种子工具"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.search import SearchChain
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType
|
||||
|
||||
|
||||
class SearchTorrentsInput(BaseModel):
|
||||
"""搜索种子工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
title: str = Field(...,
|
||||
description="The title of the media resource to search for (e.g., 'The Matrix 1999', 'Breaking Bad S01E01')")
|
||||
year: Optional[str] = Field(None,
|
||||
description="Release year of the media (optional, helps narrow down search results)")
|
||||
media_type: Optional[str] = Field(None,
|
||||
description="Type of media content: '电影' for films, '电视剧' for television series or anime series")
|
||||
season: Optional[int] = Field(None, description="Season number for TV shows (optional, only applicable for series)")
|
||||
sites: Optional[List[int]] = Field(None,
|
||||
description="Array of specific site IDs to search on (optional, if not provided searches all configured sites)")
|
||||
filter_pattern: Optional[str] = Field(None,
|
||||
description="Regular expression pattern to filter torrent titles by resolution, quality, or other keywords (e.g., '4K|2160p|UHD' for 4K content, '1080p|BluRay' for 1080p BluRay)")
|
||||
|
||||
|
||||
class SearchTorrentsTool(MoviePilotTool):
|
||||
name: str = "search_torrents"
|
||||
description: str = "Search for torrent files across configured indexer sites based on media information. Returns available torrent downloads with details like file size, quality, and download links."
|
||||
args_schema: Type[BaseModel] = SearchTorrentsInput
|
||||
|
||||
async def run(self, title: str, year: Optional[str] = None,
|
||||
media_type: Optional[str] = None, season: Optional[int] = None,
|
||||
sites: Optional[List[int]] = None, filter_pattern: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}, sites={sites}, filter_pattern={filter_pattern}")
|
||||
|
||||
try:
|
||||
search_chain = SearchChain()
|
||||
torrents = await search_chain.async_search_by_title(title=title, sites=sites)
|
||||
filtered_torrents = []
|
||||
# 编译正则表达式(如果提供)
|
||||
regex_pattern = None
|
||||
if filter_pattern:
|
||||
try:
|
||||
regex_pattern = re.compile(filter_pattern, re.IGNORECASE)
|
||||
except re.error as e:
|
||||
logger.warning(f"正则表达式编译失败: {filter_pattern}, 错误: {e}")
|
||||
return f"正则表达式格式错误: {str(e)}"
|
||||
|
||||
for torrent in torrents:
|
||||
# torrent 是 Context 对象,需要通过 meta_info 和 media_info 访问属性
|
||||
if year and torrent.meta_info and torrent.meta_info.year != year:
|
||||
continue
|
||||
if media_type and torrent.media_info:
|
||||
if torrent.media_info.type != MediaType(media_type):
|
||||
continue
|
||||
if season and torrent.meta_info and torrent.meta_info.begin_season != season:
|
||||
continue
|
||||
# 使用正则表达式过滤标题(分辨率、质量等关键字)
|
||||
if regex_pattern and torrent.torrent_info and torrent.torrent_info.title:
|
||||
if not regex_pattern.search(torrent.torrent_info.title):
|
||||
continue
|
||||
filtered_torrents.append(torrent)
|
||||
|
||||
if filtered_torrents:
|
||||
# 限制最多50条结果
|
||||
total_count = len(filtered_torrents)
|
||||
limited_torrents = filtered_torrents[:50]
|
||||
# 精简字段,只保留关键信息
|
||||
simplified_torrents = []
|
||||
for t in limited_torrents:
|
||||
simplified = {}
|
||||
# 精简 torrent_info
|
||||
if t.torrent_info:
|
||||
simplified["torrent_info"] = {
|
||||
"title": t.torrent_info.title,
|
||||
"size": t.torrent_info.size,
|
||||
"seeders": t.torrent_info.seeders,
|
||||
"peers": t.torrent_info.peers,
|
||||
"site_name": t.torrent_info.site_name,
|
||||
"enclosure": t.torrent_info.enclosure,
|
||||
"page_url": t.torrent_info.page_url,
|
||||
"volume_factor": t.torrent_info.volume_factor,
|
||||
"pubdate": t.torrent_info.pubdate
|
||||
}
|
||||
# 精简 media_info
|
||||
if t.media_info:
|
||||
simplified["media_info"] = {
|
||||
"title": t.media_info.title,
|
||||
"en_title": t.media_info.en_title,
|
||||
"year": t.media_info.year,
|
||||
"type": t.media_info.type.value if t.media_info.type else None,
|
||||
"season": t.media_info.season,
|
||||
"tmdb_id": t.media_info.tmdb_id
|
||||
}
|
||||
# 精简 meta_info
|
||||
if t.meta_info:
|
||||
simplified["meta_info"] = {
|
||||
"name": t.meta_info.name,
|
||||
"cn_name": t.meta_info.cn_name,
|
||||
"en_name": t.meta_info.en_name,
|
||||
"year": t.meta_info.year,
|
||||
"type": t.meta_info.type.value if t.meta_info.type else None,
|
||||
"begin_season": t.meta_info.begin_season
|
||||
}
|
||||
simplified_torrents.append(simplified)
|
||||
result_json = json.dumps(simplified_torrents, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 50:
|
||||
return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 50 条结果。\n\n{result_json}"
|
||||
return result_json
|
||||
else:
|
||||
return f"未找到相关种子资源: {title}"
|
||||
except Exception as e:
|
||||
error_message = f"搜索种子时发生错误: {str(e)}"
|
||||
logger.error(f"搜索种子失败: {e}", exc_info=True)
|
||||
return error_message
|
||||
31
app/agent/tools/impl/send_message.py
Normal file
31
app/agent/tools/impl/send_message.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""发送消息工具"""
|
||||
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class SendMessageInput(BaseModel):
|
||||
"""发送消息工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
message: str = Field(..., description="The message content to send to the user (should be clear and informative)")
|
||||
message_type: Optional[str] = Field("info",
|
||||
description="Type of message: 'info' for general information, 'success' for successful operations, 'warning' for warnings, 'error' for error messages")
|
||||
|
||||
|
||||
class SendMessageTool(MoviePilotTool):
|
||||
name: str = "send_message"
|
||||
description: str = "Send notification message to the user through configured notification channels (Telegram, Slack, WeChat, etc.). Used to inform users about operation results, errors, or important updates."
|
||||
args_schema: Type[BaseModel] = SendMessageInput
|
||||
|
||||
async def run(self, message: str, message_type: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: message={message}, message_type={message_type}")
|
||||
try:
|
||||
await self.send_tool_message(message, title=message_type)
|
||||
return "消息已发送"
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e}")
|
||||
return f"发送消息时发生错误: {str(e)}"
|
||||
@@ -137,7 +137,7 @@ async def transfer(days: Optional[int] = 7,
|
||||
return [stat[1] for stat in transfer_stat]
|
||||
|
||||
|
||||
@router.get("/cpu", summary="获取当前CPU使用率", response_model=int)
|
||||
@router.get("/cpu", summary="获取当前CPU使用率", response_model=float)
|
||||
def cpu(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
获取当前CPU使用率
|
||||
@@ -145,7 +145,7 @@ def cpu(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
return SystemUtils.cpu_usage()
|
||||
|
||||
|
||||
@router.get("/cpu2", summary="获取当前CPU使用率(API_TOKEN)", response_model=int)
|
||||
@router.get("/cpu2", summary="获取当前CPU使用率(API_TOKEN)", response_model=float)
|
||||
def cpu2(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
|
||||
"""
|
||||
获取当前CPU使用率 API_TOKEN认证(?token=xxx)
|
||||
|
||||
@@ -40,10 +40,10 @@ def download(
|
||||
metainfo = MetaInfo(title=torrent_in.title, subtitle=torrent_in.description)
|
||||
# 媒体信息
|
||||
mediainfo = MediaInfo()
|
||||
mediainfo.from_dict(media_in.dict())
|
||||
mediainfo.from_dict(media_in.model_dump())
|
||||
# 种子信息
|
||||
torrentinfo = TorrentInfo()
|
||||
torrentinfo.from_dict(torrent_in.dict())
|
||||
torrentinfo.from_dict(torrent_in.model_dump())
|
||||
# 手动下载始终使用选择的下载器
|
||||
torrentinfo.site_downloader = downloader
|
||||
# 上下文
|
||||
@@ -64,6 +64,9 @@ def download(
|
||||
@router.post("/add", summary="添加下载(不含媒体信息)", response_model=schemas.Response)
|
||||
def add(
|
||||
torrent_in: schemas.TorrentInfo,
|
||||
tmdbid: Annotated[int | None, Body()] = None,
|
||||
doubanid: Annotated[str | None, Body()] = None,
|
||||
bangumiid: Annotated[int | None, Body()] = None,
|
||||
downloader: Annotated[str | None, Body()] = None,
|
||||
save_path: Annotated[str | None, Body()] = None,
|
||||
current_user: User = Depends(get_current_active_user)) -> Any:
|
||||
@@ -73,12 +76,12 @@ def add(
|
||||
# 元数据
|
||||
metainfo = MetaInfo(title=torrent_in.title, subtitle=torrent_in.description)
|
||||
# 媒体信息
|
||||
mediainfo = MediaChain().recognize_media(meta=metainfo)
|
||||
mediainfo = MediaChain().recognize_media(meta=metainfo, tmdbid=tmdbid, doubanid=doubanid, bangumiid=bangumiid)
|
||||
if not mediainfo:
|
||||
return schemas.Response(success=False, message="无法识别媒体信息")
|
||||
# 种子信息
|
||||
torrentinfo = TorrentInfo()
|
||||
torrentinfo.from_dict(torrent_in.dict())
|
||||
torrentinfo.from_dict(torrent_in.model_dump())
|
||||
# 上下文
|
||||
context = Context(
|
||||
meta_info=metainfo,
|
||||
|
||||
@@ -14,7 +14,7 @@ from app.db.models import User
|
||||
from app.db.models.downloadhistory import DownloadHistory
|
||||
from app.db.models.transferhistory import TransferHistory
|
||||
from app.db.user_oper import get_current_active_superuser_async, get_current_active_superuser
|
||||
from app.schemas.types import EventType, MediaType
|
||||
from app.schemas.types import EventType
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -70,7 +70,7 @@ async def transfer_history(title: Optional[str] = None,
|
||||
|
||||
return schemas.Response(success=True,
|
||||
data={
|
||||
"list": result,
|
||||
"list": [item.to_dict() for item in result],
|
||||
"total": total,
|
||||
})
|
||||
|
||||
|
||||
@@ -8,8 +8,10 @@ from app import schemas
|
||||
from app.chain.user import UserChain
|
||||
from app.core import security
|
||||
from app.core.config import settings
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.helper.sites import SitesHelper # noqa
|
||||
from app.helper.wallpaper import WallpaperHelper
|
||||
from app.schemas.types import SystemConfigKey
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -29,7 +31,10 @@ def login_access_token(
|
||||
if not success:
|
||||
raise HTTPException(status_code=401, detail=user_or_message)
|
||||
|
||||
# 用户等级
|
||||
level = SitesHelper().auth_level
|
||||
# 是否显示配置向导
|
||||
show_wizard = not SystemConfigOper().get(SystemConfigKey.SetupWizardState) and not settings.ADVANCED_MODE
|
||||
return schemas.Token(
|
||||
access_token=security.create_access_token(
|
||||
userid=user_or_message.id,
|
||||
@@ -45,6 +50,7 @@ def login_access_token(
|
||||
avatar=user_or_message.avatar,
|
||||
level=level,
|
||||
permissions=user_or_message.permissions or {},
|
||||
widzard=show_wizard
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ def exists(media_in: schemas.MediaInfo,
|
||||
"""
|
||||
# 转化为媒体信息对象
|
||||
mediainfo = MediaInfo()
|
||||
mediainfo.from_dict(media_in.dict())
|
||||
mediainfo.from_dict(media_in.model_dump())
|
||||
existsinfo: schemas.ExistMediaInfo = MediaServerChain().media_exists(mediainfo=mediainfo)
|
||||
if not existsinfo:
|
||||
return []
|
||||
@@ -108,7 +108,7 @@ def not_exists(media_in: schemas.MediaInfo,
|
||||
meta.year = media_in.year
|
||||
# 转化为媒体信息对象
|
||||
mediainfo = MediaInfo()
|
||||
mediainfo.from_dict(media_in.dict())
|
||||
mediainfo.from_dict(media_in.model_dump())
|
||||
exist_flag, no_exists = DownloadChain().get_no_exists_info(meta=meta, mediainfo=mediainfo)
|
||||
mediakey = mediainfo.tmdb_id or mediainfo.douban_id
|
||||
if mediainfo.type == MediaType.MOVIE:
|
||||
|
||||
@@ -132,7 +132,7 @@ async def subscribe(subscription: schemas.Subscription, _: schemas.TokenPayload
|
||||
"""
|
||||
客户端webpush通知订阅
|
||||
"""
|
||||
subinfo = subscription.dict()
|
||||
subinfo = subscription.model_dump()
|
||||
if subinfo not in global_vars.get_subscriptions():
|
||||
global_vars.push_subscription(subinfo)
|
||||
logger.debug(f"通知订阅成功: {subinfo}")
|
||||
@@ -148,7 +148,7 @@ def send_notification(payload: schemas.SubscriptionMessage, _: schemas.TokenPayl
|
||||
try:
|
||||
webpush(
|
||||
subscription_info=sub,
|
||||
data=json.dumps(payload.dict()),
|
||||
data=json.dumps(payload.model_dump()),
|
||||
vapid_private_key=settings.VAPID.get("privateKey"),
|
||||
vapid_claims={
|
||||
"sub": settings.VAPID.get("subject")
|
||||
|
||||
@@ -13,7 +13,7 @@ from app import schemas
|
||||
from app.command import Command
|
||||
from app.core.config import settings
|
||||
from app.core.plugin import PluginManager
|
||||
from app.core.security import verify_apikey, verify_token, verify_apitoken
|
||||
from app.core.security import verify_apikey, verify_token
|
||||
from app.db.models import User
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.db.user_oper import get_current_active_superuser, get_current_active_superuser_async
|
||||
@@ -21,7 +21,6 @@ from app.factory import app
|
||||
from app.helper.plugin import PluginHelper
|
||||
from app.log import logger
|
||||
from app.scheduler import Scheduler
|
||||
from app.schemas.plugin import PluginMemoryInfo
|
||||
from app.schemas.types import SystemConfigKey
|
||||
|
||||
PROTECTED_ROUTES = {"/api/v1/openapi.json", "/docs", "/docs/oauth2-redirect", "/redoc"}
|
||||
@@ -494,57 +493,6 @@ def clone_plugin(plugin_id: str,
|
||||
return schemas.Response(success=False, message=f"创建插件分身失败:{str(e)}")
|
||||
|
||||
|
||||
@router.get("/memory", summary="插件内存使用统计", response_model=List[PluginMemoryInfo])
|
||||
def plugin_memory_stats(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
|
||||
"""
|
||||
获取所有插件的内存使用统计信息
|
||||
"""
|
||||
try:
|
||||
plugin_manager = PluginManager()
|
||||
memory_stats = plugin_manager.get_plugin_memory_stats()
|
||||
return memory_stats
|
||||
except Exception as e:
|
||||
logger.error(f"获取插件内存统计失败:{str(e)}")
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"获取插件内存统计失败:{str(e)}")
|
||||
|
||||
|
||||
@router.get("/memory/{plugin_id}", summary="单个插件内存使用统计", response_model=PluginMemoryInfo)
|
||||
def plugin_memory_stat(plugin_id: str, _: Annotated[str, Depends(verify_apitoken)]) -> Any:
|
||||
"""
|
||||
获取指定插件的内存使用统计信息
|
||||
"""
|
||||
try:
|
||||
plugin_manager = PluginManager()
|
||||
memory_stats = plugin_manager.get_plugin_memory_stats(plugin_id)
|
||||
if not memory_stats:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"插件 {plugin_id} 不存在或未运行")
|
||||
return memory_stats[0]
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取插件 {plugin_id} 内存统计失败:{str(e)}")
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"获取插件内存统计失败:{str(e)}")
|
||||
|
||||
|
||||
@router.delete("/memory/cache", summary="清除插件内存统计缓存")
|
||||
def clear_plugin_memory_cache(_: Annotated[str, Depends(verify_apitoken)],
|
||||
plugin_id: Optional[str] = None) -> Any:
|
||||
"""
|
||||
清除插件内存统计缓存
|
||||
"""
|
||||
try:
|
||||
plugin_manager = PluginManager()
|
||||
plugin_manager.clear_plugin_memory_cache(plugin_id)
|
||||
message = f"已清除插件 {plugin_id} 的内存统计缓存" if plugin_id else "已清除所有插件的内存统计缓存"
|
||||
return schemas.Response(success=True, message=message)
|
||||
except Exception as e:
|
||||
logger.error(f"清除插件内存统计缓存失败:{str(e)}")
|
||||
return schemas.Response(success=False, message=f"清除缓存失败:{str(e)}")
|
||||
|
||||
|
||||
@router.get("/{plugin_id}", summary="获取插件配置")
|
||||
async def plugin_config(plugin_id: str,
|
||||
_: User = Depends(get_current_active_superuser_async)) -> dict:
|
||||
|
||||
@@ -20,7 +20,7 @@ async def search_latest(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
查询搜索结果
|
||||
"""
|
||||
torrents = await SearchChain().async_last_search_results()
|
||||
torrents = await SearchChain().async_last_search_results() or []
|
||||
return [torrent.to_dict() for torrent in torrents]
|
||||
|
||||
|
||||
|
||||
@@ -67,7 +67,7 @@ async def add_site(
|
||||
site_in.name = site_info.get("name")
|
||||
site_in.id = None
|
||||
site_in.public = 1 if site_info.get("public") else 0
|
||||
site = Site(**site_in.dict())
|
||||
site = Site(**site_in.model_dump())
|
||||
site.create(db)
|
||||
# 通知站点更新
|
||||
await eventmanager.async_send_event(EventType.SiteUpdated, {
|
||||
@@ -92,7 +92,7 @@ async def update_site(
|
||||
# 校正地址格式
|
||||
_scheme, _netloc = StringUtils.get_url_netloc(site_in.url)
|
||||
site_in.url = f"{_scheme}://{_netloc}/"
|
||||
await site.async_update(db, site_in.dict())
|
||||
await site.async_update(db, site_in.model_dump())
|
||||
# 通知站点更新
|
||||
await eventmanager.async_send_event(EventType.SiteUpdated, {
|
||||
"domain": site_in.domain
|
||||
@@ -399,7 +399,7 @@ def auth_site(
|
||||
if not auth_info or not auth_info.site or not auth_info.params:
|
||||
return schemas.Response(success=False, message="请输入认证站点和认证参数")
|
||||
status, msg = SitesHelper().check_user(auth_info.site, auth_info.params)
|
||||
SystemConfigOper().set(SystemConfigKey.UserSiteAuthParams, auth_info.dict())
|
||||
SystemConfigOper().set(SystemConfigKey.UserSiteAuthParams, auth_info.model_dump())
|
||||
# 认证成功后,重新初始化插件
|
||||
PluginManager().init_config()
|
||||
Scheduler().init_plugin_jobs()
|
||||
|
||||
@@ -79,7 +79,7 @@ async def create_subscribe(
|
||||
# 订阅用户
|
||||
subscribe_in.username = current_user.name
|
||||
# 转化为字典
|
||||
subscribe_dict = subscribe_in.dict()
|
||||
subscribe_dict = subscribe_in.model_dump()
|
||||
if subscribe_in.id:
|
||||
subscribe_dict.pop("id", None)
|
||||
sid, message = await SubscribeChain().async_add(mtype=mtype,
|
||||
@@ -106,7 +106,7 @@ async def update_subscribe(
|
||||
return schemas.Response(success=False, message="订阅不存在")
|
||||
# 避免更新缺失集数
|
||||
old_subscribe_dict = subscribe.to_dict()
|
||||
subscribe_dict = subscribe_in.dict()
|
||||
subscribe_dict = subscribe_in.model_dump()
|
||||
if not subscribe_in.lack_episode:
|
||||
# 没有缺失集数时,缺失集数清空,避免更新为0
|
||||
subscribe_dict.pop("lack_episode")
|
||||
@@ -421,11 +421,23 @@ async def popular_subscribes(
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
min_sub: Optional[int] = None,
|
||||
genre_id: Optional[int] = None,
|
||||
min_rating: Optional[float] = None,
|
||||
max_rating: Optional[float] = None,
|
||||
sort_type: Optional[str] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
查询热门订阅
|
||||
"""
|
||||
subscribes = await SubscribeHelper().async_get_statistic(stype=stype, page=page, count=count)
|
||||
subscribes = await SubscribeHelper().async_get_statistic(
|
||||
stype=stype,
|
||||
page=page,
|
||||
count=count,
|
||||
genre_id=genre_id,
|
||||
min_rating=min_rating,
|
||||
max_rating=max_rating,
|
||||
sort_type=sort_type
|
||||
)
|
||||
if subscribes:
|
||||
ret_medias = []
|
||||
for sub in subscribes:
|
||||
@@ -517,7 +529,7 @@ async def subscribe_fork(
|
||||
"""
|
||||
复用订阅
|
||||
"""
|
||||
sub_dict = sub.dict()
|
||||
sub_dict = sub.model_dump()
|
||||
sub_dict.pop("id")
|
||||
for key in list(sub_dict.keys()):
|
||||
if not hasattr(schemas.Subscribe(), key):
|
||||
@@ -570,11 +582,23 @@ async def popular_subscribes(
|
||||
name: Optional[str] = None,
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
genre_id: Optional[int] = None,
|
||||
min_rating: Optional[float] = None,
|
||||
max_rating: Optional[float] = None,
|
||||
sort_type: Optional[str] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
查询分享的订阅
|
||||
"""
|
||||
return await SubscribeHelper().async_get_shares(name=name, page=page, count=count)
|
||||
return await SubscribeHelper().async_get_shares(
|
||||
name=name,
|
||||
page=page,
|
||||
count=count,
|
||||
genre_id=genre_id,
|
||||
min_rating=min_rating,
|
||||
max_rating=max_rating,
|
||||
sort_type=sort_type
|
||||
)
|
||||
|
||||
|
||||
@router.get("/share/statistics", summary="查询订阅分享统计", response_model=List[schemas.SubscribeShareStatistics])
|
||||
|
||||
@@ -11,10 +11,12 @@ import aiofiles
|
||||
import pillow_avif # noqa 用于自动注册AVIF支持
|
||||
from PIL import Image
|
||||
from anyio import Path as AsyncPath
|
||||
from app.helper.sites import SitesHelper # noqa # noqa
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Header, Request, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from app import schemas
|
||||
from app.chain.mediaserver import MediaServerChain
|
||||
from app.chain.search import SearchChain
|
||||
from app.chain.system import SystemChain
|
||||
from app.core.cache import AsyncFileCache
|
||||
@@ -31,7 +33,6 @@ from app.helper.mediaserver import MediaServerHelper
|
||||
from app.helper.message import MessageHelper
|
||||
from app.helper.progress import ProgressHelper
|
||||
from app.helper.rule import RuleHelper
|
||||
from app.helper.sites import SitesHelper # noqa # noqa
|
||||
from app.helper.subscribe import SubscribeHelper
|
||||
from app.helper.system import SystemHelper
|
||||
from app.log import logger
|
||||
@@ -52,19 +53,21 @@ async def fetch_image(
|
||||
proxy: bool = False,
|
||||
use_cache: bool = False,
|
||||
if_none_match: Optional[str] = None,
|
||||
allowed_domains: Optional[set[str]] = None) -> Response:
|
||||
cookies: Optional[str | dict] = None,
|
||||
allowed_domains: Optional[set[str]] = None) -> Optional[Response]:
|
||||
"""
|
||||
处理图片缓存逻辑,支持HTTP缓存和磁盘缓存
|
||||
"""
|
||||
if not url:
|
||||
raise HTTPException(status_code=404, detail="URL not provided")
|
||||
return None
|
||||
|
||||
if allowed_domains is None:
|
||||
allowed_domains = set(settings.SECURITY_IMAGE_DOMAINS)
|
||||
|
||||
# 验证URL安全性
|
||||
if not SecurityUtils.is_safe_url(url, allowed_domains):
|
||||
raise HTTPException(status_code=404, detail="Unsafe URL")
|
||||
logger.warn(f"Blocked unsafe image URL: {url}")
|
||||
return None
|
||||
|
||||
# 缓存路径
|
||||
sanitized_path = SecurityUtils.sanitize_url_path(url)
|
||||
@@ -95,18 +98,24 @@ async def fetch_image(
|
||||
# 请求远程图片
|
||||
referer = "https://movie.douban.com/" if "doubanio.com" in url else None
|
||||
proxies = settings.PROXY if proxy else None
|
||||
response = await AsyncRequestUtils(ua=settings.NORMAL_USER_AGENT, proxies=proxies, referer=referer,
|
||||
accept_type="image/avif,image/webp,image/apng,*/*").get_res(url=url)
|
||||
response = await AsyncRequestUtils(
|
||||
ua=settings.NORMAL_USER_AGENT,
|
||||
proxies=proxies,
|
||||
referer=referer,
|
||||
cookies=cookies,
|
||||
accept_type="image/avif,image/webp,image/apng,*/*",
|
||||
).get_res(url=url)
|
||||
if not response:
|
||||
raise HTTPException(status_code=502, detail="Failed to fetch the image from the remote server")
|
||||
logger.warn(f"Failed to fetch image from URL: {url}")
|
||||
return None
|
||||
|
||||
# 验证下载的内容是否为有效图片
|
||||
try:
|
||||
content = response.content
|
||||
Image.open(io.BytesIO(content)).verify()
|
||||
except Exception as e:
|
||||
logger.debug(f"Invalid image format for URL {url}: {e}")
|
||||
raise HTTPException(status_code=502, detail="Invalid image format")
|
||||
logger.warn(f"Invalid image format for URL {url}: {e}")
|
||||
return None
|
||||
|
||||
# 获取请求响应头
|
||||
response_headers = response.headers
|
||||
@@ -138,6 +147,7 @@ async def proxy_img(
|
||||
imgurl: str,
|
||||
proxy: bool = False,
|
||||
cache: bool = False,
|
||||
use_cookies: bool = False,
|
||||
if_none_match: Annotated[str | None, Header()] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token)
|
||||
) -> Response:
|
||||
@@ -148,7 +158,12 @@ async def proxy_img(
|
||||
hosts = [config.config.get("host") for config in MediaServerHelper().get_configs().values() if
|
||||
config and config.config and config.config.get("host")]
|
||||
allowed_domains = set(settings.SECURITY_IMAGE_DOMAINS) | set(hosts)
|
||||
return await fetch_image(url=imgurl, proxy=proxy, use_cache=cache,
|
||||
cookies = (
|
||||
MediaServerChain().get_image_cookies(server=None, image_url=imgurl)
|
||||
if use_cookies
|
||||
else None
|
||||
)
|
||||
return await fetch_image(url=imgurl, proxy=proxy, use_cache=cache, cookies=cookies,
|
||||
if_none_match=if_none_match, allowed_domains=allowed_domains)
|
||||
|
||||
|
||||
@@ -176,7 +191,7 @@ def get_global_setting(token: str):
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
|
||||
# FIXME: 新增敏感配置项时要在此处添加排除项
|
||||
info = settings.dict(
|
||||
info = settings.model_dump(
|
||||
exclude={"SECRET_KEY", "RESOURCE_SECRET_KEY", "API_TOKEN", "TMDB_API_KEY", "TVDB_API_KEY", "FANART_API_KEY",
|
||||
"COOKIECLOUD_KEY", "COOKIECLOUD_PASSWORD", "GITHUB_TOKEN", "REPO_GITHUB_TOKEN", "U115_APP_ID",
|
||||
"ALIPAN_APP_ID", "TVDB_V4_API_KEY", "TVDB_V4_API_PIN"}
|
||||
@@ -197,7 +212,7 @@ async def get_env_setting(_: User = Depends(get_current_active_user_async)):
|
||||
"""
|
||||
查询系统环境变量,包括当前版本号(仅管理员)
|
||||
"""
|
||||
info = settings.dict(
|
||||
info = settings.model_dump(
|
||||
exclude={"SECRET_KEY", "RESOURCE_SECRET_KEY"}
|
||||
)
|
||||
info.update({
|
||||
|
||||
@@ -41,7 +41,7 @@ async def create_user(
|
||||
user = await current_user.async_get_by_name(db, name=user_in.name)
|
||||
if user:
|
||||
return schemas.Response(success=False, message="用户已存在")
|
||||
user_info = user_in.dict()
|
||||
user_info = user_in.model_dump()
|
||||
if user_info.get("password"):
|
||||
user_info["hashed_password"] = get_password_hash(user_info["password"])
|
||||
user_info.pop("password")
|
||||
@@ -59,7 +59,7 @@ async def update_user(
|
||||
"""
|
||||
更新用户
|
||||
"""
|
||||
user_info = user_in.dict()
|
||||
user_info = user_in.model_dump()
|
||||
if user_info.get("password"):
|
||||
# 正则表达式匹配密码包含字母、数字、特殊字符中的至少两项
|
||||
pattern = r'^(?![a-zA-Z]+$)(?!\d+$)(?![^\da-zA-Z\s]+$).{6,50}$'
|
||||
|
||||
@@ -11,7 +11,7 @@ from app.chain.workflow import WorkflowChain
|
||||
from app.core.config import global_vars
|
||||
from app.core.plugin import PluginManager
|
||||
from app.core.security import verify_token
|
||||
from app.core.workflow import WorkFlowManager
|
||||
from app.workflow import WorkFlowManager
|
||||
from app.db import get_async_db, get_db
|
||||
from app.db.models import Workflow
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
@@ -47,7 +47,7 @@ async def create_workflow(workflow: schemas.Workflow,
|
||||
workflow.state = "P"
|
||||
if not workflow.trigger_type:
|
||||
workflow.trigger_type = "timer"
|
||||
workflow_obj = Workflow(**workflow.dict())
|
||||
workflow_obj = Workflow(**workflow.model_dump())
|
||||
await workflow_obj.async_create(db)
|
||||
return schemas.Response(success=True, message="创建工作流成功")
|
||||
|
||||
@@ -277,7 +277,7 @@ def update_workflow(workflow: schemas.Workflow,
|
||||
return schemas.Response(success=False, message="工作流不存在")
|
||||
if not wf.trigger_type:
|
||||
workflow.trigger_type = "timer"
|
||||
wf.update(db, workflow.dict())
|
||||
wf.update(db, workflow.model_dump())
|
||||
# 更新后的工作流对象
|
||||
updated_workflow = workflow_oper.get(workflow.id)
|
||||
# 更新定时任务
|
||||
|
||||
@@ -11,7 +11,7 @@ from fastapi.concurrency import run_in_threadpool
|
||||
from qbittorrentapi import TorrentFilesList
|
||||
from transmission_rpc import File
|
||||
|
||||
from app.core.cache import FileCache, AsyncFileCache
|
||||
from app.core.cache import FileCache, AsyncFileCache, fresh, async_fresh
|
||||
from app.core.config import settings
|
||||
from app.core.context import Context, MediaInfo, TorrentInfo
|
||||
from app.core.event import EventManager
|
||||
@@ -358,9 +358,10 @@ class ChainBase(metaclass=ABCMeta):
|
||||
if tmdbid:
|
||||
doubanid = None
|
||||
bangumiid = None
|
||||
return self.run_module("recognize_media", meta=meta, mtype=mtype,
|
||||
tmdbid=tmdbid, doubanid=doubanid, bangumiid=bangumiid,
|
||||
episode_group=episode_group, cache=cache)
|
||||
with fresh(not cache):
|
||||
return self.run_module("recognize_media", meta=meta, mtype=mtype,
|
||||
tmdbid=tmdbid, doubanid=doubanid, bangumiid=bangumiid,
|
||||
episode_group=episode_group, cache=cache)
|
||||
|
||||
async def async_recognize_media(self, meta: MetaBase = None,
|
||||
mtype: Optional[MediaType] = None,
|
||||
@@ -391,9 +392,10 @@ class ChainBase(metaclass=ABCMeta):
|
||||
if tmdbid:
|
||||
doubanid = None
|
||||
bangumiid = None
|
||||
return await self.async_run_module("async_recognize_media", meta=meta, mtype=mtype,
|
||||
tmdbid=tmdbid, doubanid=doubanid, bangumiid=bangumiid,
|
||||
episode_group=episode_group, cache=cache)
|
||||
async with async_fresh(not cache):
|
||||
return await self.async_run_module("async_recognize_media", meta=meta, mtype=mtype,
|
||||
tmdbid=tmdbid, doubanid=doubanid, bangumiid=bangumiid,
|
||||
episode_group=episode_group, cache=cache)
|
||||
|
||||
def match_doubaninfo(self, name: str, imdbid: Optional[str] = None,
|
||||
mtype: Optional[MediaType] = None, year: Optional[str] = None, season: Optional[int] = None,
|
||||
@@ -850,9 +852,13 @@ class ChainBase(metaclass=ABCMeta):
|
||||
# 渲染消息
|
||||
message = MessageTemplateHelper.render(message=message, meta=meta, mediainfo=mediainfo,
|
||||
torrentinfo=torrentinfo, transferinfo=transferinfo, **kwargs)
|
||||
# 检查消息是否有效
|
||||
if not message:
|
||||
logger.warning("消息为空,跳过发送")
|
||||
return
|
||||
# 保存消息
|
||||
self.messagehelper.put(message, role="user", title=message.title)
|
||||
self.messageoper.add(**message.dict())
|
||||
self.messageoper.add(**message.model_dump())
|
||||
# 发送消息按设置隔离
|
||||
if not message.userid and message.mtype:
|
||||
# 消息隔离设置
|
||||
@@ -899,15 +905,15 @@ class ChainBase(metaclass=ABCMeta):
|
||||
break
|
||||
# 按设定发送
|
||||
self.eventmanager.send_event(etype=EventType.NoticeMessage,
|
||||
data={**send_message.dict(), "type": send_message.mtype})
|
||||
self.messagequeue.send_message("post_message", message=send_message)
|
||||
data={**send_message.model_dump(), "type": send_message.mtype})
|
||||
self.messagequeue.send_message("post_message", message=send_message, **kwargs)
|
||||
if not send_orignal:
|
||||
return
|
||||
# 发送消息事件
|
||||
self.eventmanager.send_event(etype=EventType.NoticeMessage, data={**message.dict(), "type": message.mtype})
|
||||
self.eventmanager.send_event(etype=EventType.NoticeMessage, data={**message.model_dump(), "type": message.mtype})
|
||||
# 按原消息发送
|
||||
self.messagequeue.send_message("post_message", message=message,
|
||||
immediately=True if message.userid else False)
|
||||
immediately=True if message.userid else False, **kwargs)
|
||||
|
||||
async def async_post_message(self,
|
||||
message: Optional[Notification] = None,
|
||||
@@ -929,9 +935,13 @@ class ChainBase(metaclass=ABCMeta):
|
||||
# 渲染消息
|
||||
message = MessageTemplateHelper.render(message=message, meta=meta, mediainfo=mediainfo,
|
||||
torrentinfo=torrentinfo, transferinfo=transferinfo, **kwargs)
|
||||
# 检查消息是否有效
|
||||
if not message:
|
||||
logger.warning("消息为空,跳过发送")
|
||||
return
|
||||
# 保存消息
|
||||
self.messagehelper.put(message, role="user", title=message.title)
|
||||
await self.messageoper.async_add(**message.dict())
|
||||
await self.messageoper.async_add(**message.model_dump())
|
||||
# 发送消息按设置隔离
|
||||
if not message.userid and message.mtype:
|
||||
# 消息隔离设置
|
||||
@@ -978,16 +988,16 @@ class ChainBase(metaclass=ABCMeta):
|
||||
break
|
||||
# 按设定发送
|
||||
await self.eventmanager.async_send_event(etype=EventType.NoticeMessage,
|
||||
data={**send_message.dict(), "type": send_message.mtype})
|
||||
await self.messagequeue.async_send_message("post_message", message=send_message)
|
||||
data={**send_message.model_dump(), "type": send_message.mtype})
|
||||
await self.messagequeue.async_send_message("post_message", message=send_message, **kwargs)
|
||||
if not send_orignal:
|
||||
return
|
||||
# 发送消息事件
|
||||
await self.eventmanager.async_send_event(etype=EventType.NoticeMessage,
|
||||
data={**message.dict(), "type": message.mtype})
|
||||
data={**message.model_dump(), "type": message.mtype})
|
||||
# 按原消息发送
|
||||
await self.messagequeue.async_send_message("post_message", message=message,
|
||||
immediately=True if message.userid else False)
|
||||
immediately=True if message.userid else False, **kwargs)
|
||||
|
||||
def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None:
|
||||
"""
|
||||
@@ -998,7 +1008,7 @@ class ChainBase(metaclass=ABCMeta):
|
||||
"""
|
||||
note_list = [media.to_dict() for media in medias]
|
||||
self.messagehelper.put(message, role="user", note=note_list, title=message.title)
|
||||
self.messageoper.add(**message.dict(), note=note_list)
|
||||
self.messageoper.add(**message.model_dump(), note=note_list)
|
||||
return self.messagequeue.send_message("post_medias_message", message=message, medias=medias,
|
||||
immediately=True if message.userid else False)
|
||||
|
||||
@@ -1011,7 +1021,7 @@ class ChainBase(metaclass=ABCMeta):
|
||||
"""
|
||||
note_list = [torrent.torrent_info.to_dict() for torrent in torrents]
|
||||
self.messagehelper.put(message, role="user", note=note_list, title=message.title)
|
||||
self.messageoper.add(**message.dict(), note=note_list)
|
||||
self.messageoper.add(**message.model_dump(), note=note_list)
|
||||
return self.messagequeue.send_message("post_torrents_message", message=message, torrents=torrents,
|
||||
immediately=True if message.userid else False)
|
||||
|
||||
|
||||
@@ -290,7 +290,7 @@ class DownloadChain(ChainBase):
|
||||
# 登记下载记录
|
||||
downloadhis = DownloadHistoryOper()
|
||||
downloadhis.add(
|
||||
path=str(download_path),
|
||||
path=download_path.as_posix(),
|
||||
type=_media.type.value,
|
||||
title=_media.title,
|
||||
year=_media.year,
|
||||
@@ -331,8 +331,8 @@ class DownloadChain(ChainBase):
|
||||
files_to_add.append({
|
||||
"download_hash": _hash,
|
||||
"downloader": _downloader,
|
||||
"fullpath": str(_save_path / file),
|
||||
"savepath": str(_save_path),
|
||||
"fullpath": (_save_path / file).as_posix(),
|
||||
"savepath": _save_path.as_posix(),
|
||||
"filepath": file,
|
||||
"torrentname": _meta.org_string,
|
||||
})
|
||||
@@ -994,7 +994,7 @@ class DownloadChain(ChainBase):
|
||||
# 发出下载任务删除事件,如需处理辅种,可监听该事件
|
||||
self.eventmanager.send_event(EventType.DownloadDeleted, {
|
||||
"hash": hash_str,
|
||||
"torrents": [torrent.dict() for torrent in torrents]
|
||||
"torrents": [torrent.model_dump() for torrent in torrents]
|
||||
})
|
||||
else:
|
||||
logger.info(f"没有在下载器中查询到 {hash_str} 对应的下载任务")
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from tempfile import NamedTemporaryFile
|
||||
from threading import Lock
|
||||
from typing import Optional, List, Tuple, Union
|
||||
|
||||
@@ -20,6 +22,9 @@ from app.utils.string import StringUtils
|
||||
recognize_lock = Lock()
|
||||
scraping_lock = Lock()
|
||||
|
||||
current_umask = os.umask(0)
|
||||
os.umask(current_umask)
|
||||
|
||||
|
||||
class MediaChain(ChainBase):
|
||||
"""
|
||||
@@ -456,36 +461,65 @@ class MediaChain(ChainBase):
|
||||
"""
|
||||
if not _fileitem or not _content or not _path:
|
||||
return
|
||||
# 保存文件到临时目录
|
||||
tmp_dir = settings.TEMP_PATH / StringUtils.generate_random_str(10)
|
||||
tmp_dir.mkdir(parents=True, exist_ok=True)
|
||||
tmp_file = tmp_dir / _path.name
|
||||
tmp_file.write_bytes(_content)
|
||||
# 获取文件的父目录
|
||||
try:
|
||||
item = storagechain.upload_file(fileitem=_fileitem, path=tmp_file, new_name=_path.name)
|
||||
# 使用tempfile创建临时文件,自动删除
|
||||
with NamedTemporaryFile(delete=True, delete_on_close=False, suffix=_path.suffix) as tmp_file:
|
||||
tmp_file_path = Path(tmp_file.name)
|
||||
# 写入内容
|
||||
if isinstance(_content, bytes):
|
||||
tmp_file.write(_content)
|
||||
else:
|
||||
tmp_file.write(_content.encode('utf-8'))
|
||||
tmp_file.flush()
|
||||
tmp_file.close() # 关闭文件句柄
|
||||
|
||||
# 刮削文件只需要读写权限
|
||||
tmp_file_path.chmod(0o666 & ~current_umask)
|
||||
|
||||
# 上传文件
|
||||
item = storagechain.upload_file(fileitem=_fileitem, path=tmp_file_path, new_name=_path.name)
|
||||
if item:
|
||||
logger.info(f"已保存文件:{item.path}")
|
||||
else:
|
||||
logger.warn(f"文件保存失败:{_path}")
|
||||
finally:
|
||||
if tmp_file.exists():
|
||||
tmp_file.unlink()
|
||||
|
||||
def __download_image(_url: str) -> Optional[bytes]:
|
||||
def __download_and_save_image(_fileitem: schemas.FileItem, _path: Path, _url: str):
|
||||
"""
|
||||
下载图片并保存
|
||||
流式下载图片并直接保存到文件(减少内存占用)
|
||||
:param _fileitem: 关联的媒体文件项
|
||||
:param _path: 图片文件路径
|
||||
:param _url: 图片下载URL
|
||||
"""
|
||||
if not _fileitem or not _url or not _path:
|
||||
return
|
||||
try:
|
||||
logger.info(f"正在下载图片:{_url} ...")
|
||||
r = RequestUtils(proxies=settings.PROXY, ua=settings.NORMAL_USER_AGENT).get_res(url=_url)
|
||||
if r:
|
||||
return r.content
|
||||
else:
|
||||
logger.info(f"{_url} 图片下载失败,请检查网络连通性!")
|
||||
request_utils = RequestUtils(proxies=settings.PROXY, ua=settings.NORMAL_USER_AGENT)
|
||||
with request_utils.get_stream(url=_url) as r:
|
||||
if r and r.status_code == 200:
|
||||
# 使用tempfile创建临时文件,自动删除
|
||||
with NamedTemporaryFile(delete=True, delete_on_close=False, suffix=_path.suffix) as tmp_file:
|
||||
tmp_file_path = Path(tmp_file.name)
|
||||
# 流式写入文件
|
||||
for chunk in r.iter_content(chunk_size=8192):
|
||||
if chunk:
|
||||
tmp_file.write(chunk)
|
||||
tmp_file.flush()
|
||||
tmp_file.close() # 关闭文件句柄
|
||||
|
||||
# 刮削的图片只需要读写权限
|
||||
tmp_file_path.chmod(0o666 & ~current_umask)
|
||||
|
||||
# 上传文件
|
||||
item = storagechain.upload_file(fileitem=_fileitem, path=tmp_file_path,
|
||||
new_name=_path.name)
|
||||
if item:
|
||||
logger.info(f"已保存图片:{item.path}")
|
||||
else:
|
||||
logger.warn(f"图片保存失败:{_path}")
|
||||
else:
|
||||
logger.info(f"{_url} 图片下载失败")
|
||||
except Exception as err:
|
||||
logger.error(f"{_url} 图片下载失败:{str(err)}!")
|
||||
return None
|
||||
|
||||
if not fileitem:
|
||||
return
|
||||
@@ -587,11 +621,8 @@ class MediaChain(ChainBase):
|
||||
image_path = filepath.with_name(image_name)
|
||||
if overwrite or not storagechain.get_file_item(storage=fileitem.storage,
|
||||
path=image_path):
|
||||
# 下载图片
|
||||
content = __download_image(image_url)
|
||||
# 写入图片到当前目录
|
||||
if content:
|
||||
__save_file(_fileitem=fileitem, _path=image_path, _content=content)
|
||||
# 流式下载图片并直接保存
|
||||
__download_and_save_image(_fileitem=fileitem, _path=image_path, _url=image_url)
|
||||
else:
|
||||
logger.info(f"已存在图片文件:{image_path}")
|
||||
else:
|
||||
@@ -637,13 +668,10 @@ class MediaChain(ChainBase):
|
||||
for episode, image_url in image_dict.items():
|
||||
image_path = filepath.with_suffix(Path(image_url).suffix)
|
||||
if overwrite or not storagechain.get_file_item(storage=fileitem.storage, path=image_path):
|
||||
# 下载图片
|
||||
content = __download_image(image_url)
|
||||
# 保存图片文件到当前目录
|
||||
if content:
|
||||
if not parent:
|
||||
parent = storagechain.get_parent_item(fileitem)
|
||||
__save_file(_fileitem=parent, _path=image_path, _content=content)
|
||||
# 流式下载图片并直接保存
|
||||
if not parent:
|
||||
parent = storagechain.get_parent_item(fileitem)
|
||||
__download_and_save_image(_fileitem=parent, _path=image_path, _url=image_url)
|
||||
else:
|
||||
logger.info(f"已存在图片文件:{image_path}")
|
||||
else:
|
||||
@@ -694,13 +722,10 @@ class MediaChain(ChainBase):
|
||||
image_path = filepath.with_name(image_name)
|
||||
if overwrite or not storagechain.get_file_item(storage=fileitem.storage,
|
||||
path=image_path):
|
||||
# 下载图片
|
||||
content = __download_image(image_url)
|
||||
# 保存图片文件到剧集目录
|
||||
if content:
|
||||
if not parent:
|
||||
parent = storagechain.get_parent_item(fileitem)
|
||||
__save_file(_fileitem=parent, _path=image_path, _content=content)
|
||||
# 流式下载图片并直接保存
|
||||
if not parent:
|
||||
parent = storagechain.get_parent_item(fileitem)
|
||||
__download_and_save_image(_fileitem=parent, _path=image_path, _url=image_url)
|
||||
else:
|
||||
logger.info(f"已存在图片文件:{image_path}")
|
||||
else:
|
||||
@@ -730,13 +755,11 @@ class MediaChain(ChainBase):
|
||||
continue
|
||||
if overwrite or not storagechain.get_file_item(storage=fileitem.storage,
|
||||
path=image_path):
|
||||
# 下载图片
|
||||
content = __download_image(image_url)
|
||||
# 保存图片文件到当前目录
|
||||
if content:
|
||||
if not parent:
|
||||
parent = storagechain.get_parent_item(fileitem)
|
||||
__save_file(_fileitem=parent, _path=image_path, _content=content)
|
||||
# 流式下载图片并直接保存
|
||||
if not parent:
|
||||
parent = storagechain.get_parent_item(fileitem)
|
||||
__download_and_save_image(_fileitem=parent, _path=image_path,
|
||||
_url=image_url)
|
||||
else:
|
||||
logger.info(f"已存在图片文件:{image_path}")
|
||||
else:
|
||||
@@ -786,11 +809,8 @@ class MediaChain(ChainBase):
|
||||
image_path = filepath / image_name
|
||||
if overwrite or not storagechain.get_file_item(storage=fileitem.storage,
|
||||
path=image_path):
|
||||
# 下载图片
|
||||
content = __download_image(image_url)
|
||||
# 保存图片文件到当前目录
|
||||
if content:
|
||||
__save_file(_fileitem=fileitem, _path=image_path, _content=content)
|
||||
# 流式下载图片并直接保存
|
||||
__download_and_save_image(_fileitem=fileitem, _path=image_path, _url=image_url)
|
||||
else:
|
||||
logger.info(f"已存在图片文件:{image_path}")
|
||||
else:
|
||||
|
||||
@@ -113,6 +113,16 @@ class MediaServerChain(ChainBase):
|
||||
"""
|
||||
return self.run_module("mediaserver_play_url", server=server, item_id=item_id)
|
||||
|
||||
def get_image_cookies(
|
||||
self, server: Optional[str], image_url: str
|
||||
) -> Optional[str | dict]:
|
||||
"""
|
||||
获取图片的Cookies
|
||||
"""
|
||||
return self.run_module(
|
||||
"mediaserver_image_cookies", server=server, image_url=image_url
|
||||
)
|
||||
|
||||
def sync(self):
|
||||
"""
|
||||
同步媒体库所有数据到本地数据库
|
||||
@@ -167,7 +177,7 @@ class MediaServerChain(ChainBase):
|
||||
for episode in espisodes_info:
|
||||
seasoninfo[episode.season] = episode.episodes
|
||||
# 插入数据
|
||||
item_dict = item.dict()
|
||||
item_dict = item.model_dump()
|
||||
item_dict["seasoninfo"] = seasoninfo
|
||||
item_dict["item_type"] = item_type
|
||||
dboper.add(**item_dict)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import asyncio
|
||||
import re
|
||||
from typing import Any, Optional, Dict, Union, List
|
||||
|
||||
from app.agent import agent_manager
|
||||
from app.chain import ChainBase
|
||||
from app.chain.download import DownloadChain
|
||||
from app.chain.media import MediaChain
|
||||
@@ -163,6 +165,10 @@ class MessageChain(ChainBase):
|
||||
original_message_id=original_message_id, original_chat_id=original_chat_id)
|
||||
else:
|
||||
logger.warning(f"渠道 {channel.value} 不支持回调,但收到了回调消息:{text}")
|
||||
elif text.startswith('/ai') or text.startswith('/AI'):
|
||||
# AI智能体处理
|
||||
self._handle_ai_message(text=text, channel=channel, source=source,
|
||||
userid=userid, username=username)
|
||||
elif text.startswith('/'):
|
||||
# 执行命令
|
||||
self.eventmanager.send_event(
|
||||
@@ -815,3 +821,86 @@ class MessageChain(ChainBase):
|
||||
buttons.append(page_buttons)
|
||||
|
||||
return buttons
|
||||
|
||||
def _handle_ai_message(self, text: str, channel: MessageChannel, source: str,
|
||||
userid: Union[str, int], username: str) -> None:
|
||||
"""
|
||||
处理AI智能体消息
|
||||
"""
|
||||
try:
|
||||
# 检查AI智能体是否启用
|
||||
if not settings.AI_AGENT_ENABLE:
|
||||
self.post_message(Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title="MoviePilot智能助手未启用,请在系统设置中启用"
|
||||
))
|
||||
return
|
||||
|
||||
# 检查LLM配置
|
||||
if not settings.LLM_API_KEY:
|
||||
self.post_message(Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title="MoviePilot智能助未配置,请在系统设置中配置"
|
||||
))
|
||||
return
|
||||
|
||||
# 提取用户消息
|
||||
user_message = text[3:].strip() # 移除 "/ai" 前缀
|
||||
if not user_message:
|
||||
self.post_message(Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title="请输入您的问题或需求"
|
||||
))
|
||||
return
|
||||
|
||||
# 发送处理中消息
|
||||
self.post_message(Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title="MoviePilot助手已收到您的请求,请稍候..."
|
||||
))
|
||||
|
||||
# 生成会话ID
|
||||
session_id = f"user_{userid}_{hash(user_message) % 10000}"
|
||||
|
||||
# 在事件循环中处理
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(
|
||||
agent_manager.process_message(
|
||||
session_id=session_id,
|
||||
user_id=str(userid),
|
||||
message=user_message,
|
||||
channel=channel.value if channel else None,
|
||||
source=source,
|
||||
username=username
|
||||
)
|
||||
)
|
||||
except RuntimeError:
|
||||
# 如果没有事件循环,创建新的
|
||||
asyncio.run(
|
||||
agent_manager.process_message(
|
||||
session_id=session_id,
|
||||
user_id=str(userid),
|
||||
message=user_message,
|
||||
channel=channel.value if channel else None,
|
||||
source=source,
|
||||
username=username
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理AI智能体消息失败: {e}")
|
||||
self.messagehelper.put(f"AI智能体处理失败: {str(e)}", role="system", title="MoviePilot助手")
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from datetime import datetime
|
||||
from typing import Dict, Tuple
|
||||
from typing import List, Optional
|
||||
|
||||
from app.helper.sites import SitesHelper # noqa
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
|
||||
from app.chain import ChainBase
|
||||
@@ -16,7 +17,6 @@ from app.core.event import eventmanager, Event
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.helper.progress import ProgressHelper
|
||||
from app.helper.sites import SitesHelper # noqa
|
||||
from app.helper.torrent import TorrentHelper
|
||||
from app.log import logger
|
||||
from app.schemas import NotExistMediaInfo
|
||||
@@ -86,13 +86,13 @@ class SearchChain(ChainBase):
|
||||
self.save_cache(contexts, self.__result_temp_file)
|
||||
return contexts
|
||||
|
||||
def last_search_results(self) -> List[Context]:
|
||||
def last_search_results(self) -> Optional[List[Context]]:
|
||||
"""
|
||||
获取上次搜索结果
|
||||
"""
|
||||
return self.load_cache(self.__result_temp_file)
|
||||
|
||||
async def async_last_search_results(self) -> List[Context]:
|
||||
async def async_last_search_results(self) -> Optional[List[Context]]:
|
||||
"""
|
||||
异步获取上次搜索结果
|
||||
"""
|
||||
@@ -324,9 +324,6 @@ class SearchChain(ChainBase):
|
||||
:param _torrents: 种子列表
|
||||
:return: 去重后的种子列表
|
||||
"""
|
||||
if not settings.SEARCH_MULTIPLE_NAME:
|
||||
return _torrents
|
||||
# 通过encosure去重
|
||||
return list({f"{t.torrent_info.site_name}_{t.torrent_info.title}_{t.torrent_info.description}": t
|
||||
for t in _torrents}.values())
|
||||
|
||||
@@ -384,16 +381,23 @@ class SearchChain(ChainBase):
|
||||
if search_count > 0:
|
||||
logger.info(f"已搜索 {search_count} 次,强制休眠 1-10 秒 ...")
|
||||
time.sleep(random.randint(1, 10))
|
||||
|
||||
# 搜索站点
|
||||
torrents.extend(
|
||||
self.__search_all_sites(
|
||||
mediainfo=mediainfo,
|
||||
keyword=search_word,
|
||||
sites=sites,
|
||||
area=area
|
||||
) or []
|
||||
)
|
||||
results = self.__search_all_sites(
|
||||
mediainfo=mediainfo,
|
||||
keyword=search_word,
|
||||
sites=sites,
|
||||
area=area
|
||||
) or []
|
||||
# 合并结果
|
||||
|
||||
search_count += 1
|
||||
torrents.extend(results)
|
||||
|
||||
# 有结果则停止
|
||||
if not settings.SEARCH_MULTIPLE_NAME and torrents:
|
||||
logger.info(f"共搜索到 {len(torrents)} 个资源,停止搜索")
|
||||
break
|
||||
|
||||
# 处理结果
|
||||
return self.__parse_result(
|
||||
|
||||
@@ -56,7 +56,7 @@ class SiteChain(ChainBase):
|
||||
if userdata:
|
||||
SiteOper().update_userdata(domain=StringUtils.get_url_domain(site.get("domain")),
|
||||
name=site.get("name"),
|
||||
payload=userdata.dict())
|
||||
payload=userdata.model_dump())
|
||||
# 发送事件
|
||||
eventmanager.send_event(EventType.SiteRefreshed, {
|
||||
"site_id": site.get("id")
|
||||
|
||||
@@ -173,7 +173,7 @@ class StorageChain(ChainBase):
|
||||
dir_item = fileitem if fileitem.type == "dir" else self.get_parent_item(fileitem)
|
||||
if not dir_item:
|
||||
logger.warn(f"【{fileitem.storage}】{fileitem.path} 上级目录不存在")
|
||||
return False
|
||||
return True
|
||||
|
||||
# 查找操作文件项匹配的配置目录(资源目录、媒体库目录)
|
||||
associated_dir = max(
|
||||
|
||||
@@ -150,7 +150,7 @@ class TorrentsChain(ChainBase):
|
||||
return []
|
||||
# 解析RSS
|
||||
rss_items = RssHelper().parse(site.get("rss"), True if site.get("proxy") else False,
|
||||
timeout=int(site.get("timeout") or 30))
|
||||
timeout=int(site.get("timeout") or 30), ua=site.get("ua") if site.get("ua") else None)
|
||||
if rss_items is None:
|
||||
# rss过期,尝试保留原配置生成新的rss
|
||||
self.__renew_rss_url(domain=domain, site=site)
|
||||
|
||||
@@ -33,6 +33,7 @@ from app.schemas.types import TorrentStatus, EventType, MediaType, ProgressKey,
|
||||
SystemConfigKey, ChainEventType, ContentType
|
||||
from app.utils.singleton import Singleton
|
||||
from app.utils.string import StringUtils
|
||||
from app.utils.system import SystemUtils
|
||||
|
||||
downloader_lock = threading.Lock()
|
||||
job_lock = threading.Lock()
|
||||
@@ -329,8 +330,12 @@ class JobManager:
|
||||
# 计算状态为完成的任务数
|
||||
if __mediaid__ not in self._job_view:
|
||||
return 0
|
||||
return sum([task.fileitem.size for task in self._job_view[__mediaid__].tasks if
|
||||
task.state == "completed" and task.fileitem.size is not None])
|
||||
return sum([
|
||||
task.fileitem.size if task.fileitem.size is not None
|
||||
else (SystemUtils.get_directory_size(Path(task.fileitem.path)) if task.fileitem.storage == "local" else 0)
|
||||
for task in self._job_view[__mediaid__].tasks
|
||||
if task.state == "completed"
|
||||
])
|
||||
|
||||
def total(self) -> int:
|
||||
"""
|
||||
@@ -1111,6 +1116,7 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
file_meta=file_meta)
|
||||
if begin_ep is not None:
|
||||
file_meta.begin_episode = begin_ep
|
||||
if part is not None:
|
||||
file_meta.part = part
|
||||
if end_ep is not None:
|
||||
file_meta.end_episode = end_ep
|
||||
@@ -1120,10 +1126,10 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
downloadhis = DownloadHistoryOper()
|
||||
if bluray_dir:
|
||||
# 蓝光原盘,按目录名查询
|
||||
download_history = downloadhis.get_by_path(str(file_path))
|
||||
download_history = downloadhis.get_by_path(file_path.as_posix())
|
||||
else:
|
||||
# 按文件全路径查询
|
||||
download_file = downloadhis.get_file_by_fullpath(str(file_path))
|
||||
download_file = downloadhis.get_file_by_fullpath(file_path.as_posix())
|
||||
if download_file:
|
||||
download_history = downloadhis.get_by_hash(download_file.download_hash)
|
||||
|
||||
@@ -1436,7 +1442,7 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
|
||||
for keyword in exclude_words:
|
||||
if keyword and re.search(r"%s" % keyword, file_path, re.IGNORECASE):
|
||||
logger.debug(f"{file_path} 命中屏蔽词 {keyword}")
|
||||
logger.warn(f"{file_path} 命中屏蔽词 {keyword}")
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -1472,7 +1478,7 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
file_path = save_path / file.name
|
||||
# 如果存在未被屏蔽的媒体文件,则不删除种子
|
||||
if (file_path.suffix in self.all_exts
|
||||
and not self._is_blocked_by_exclude_words(str(file_path), transfer_exclude_words)
|
||||
and not self._is_blocked_by_exclude_words(file_path.as_posix(), transfer_exclude_words)
|
||||
and file_path.exists()):
|
||||
return False
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from pydantic.fields import Callable
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import global_vars
|
||||
from app.core.event import Event, eventmanager
|
||||
from app.core.workflow import WorkFlowManager
|
||||
from app.workflow import WorkFlowManager
|
||||
from app.db.models import Workflow
|
||||
from app.db.workflow_oper import WorkflowOper
|
||||
from app.log import logger
|
||||
@@ -180,7 +180,7 @@ class WorkflowExecutor:
|
||||
"""
|
||||
合并上下文
|
||||
"""
|
||||
for key, value in context.dict().items():
|
||||
for key, value in context.model_dump().items():
|
||||
if not getattr(self.context, key, None):
|
||||
setattr(self.context, key, value)
|
||||
|
||||
|
||||
@@ -215,7 +215,7 @@ class Command(metaclass=Singleton):
|
||||
except Exception as e:
|
||||
logger.error(f"Error occurred during command initialization in background: {e}", exc_info=True)
|
||||
|
||||
def __trigger_register_commands_event(self) -> (Optional[Event], dict):
|
||||
def __trigger_register_commands_event(self) -> tuple[Optional[Event], dict]:
|
||||
"""
|
||||
触发事件,允许调整命令数据
|
||||
"""
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import contextvars
|
||||
import inspect
|
||||
import shutil
|
||||
import tempfile
|
||||
import threading
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager, asynccontextmanager
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Generator, AsyncGenerator, Tuple, Literal, Union
|
||||
@@ -27,6 +29,9 @@ DEFAULT_CACHE_TTL = 365 * 24 * 60 * 60
|
||||
|
||||
lock = threading.Lock()
|
||||
|
||||
# 上下文变量来控制缓存行为
|
||||
_fresh = contextvars.ContextVar('fresh', default=False)
|
||||
|
||||
|
||||
class CacheBackend(ABC):
|
||||
"""
|
||||
@@ -455,7 +460,7 @@ class MemoryBackend(CacheBackend):
|
||||
if region_cache:
|
||||
with lock:
|
||||
region_cache.clear()
|
||||
logger.info(f"Cleared cache for region: {region}")
|
||||
logger.debug(f"Cleared cache for region: {region}")
|
||||
else:
|
||||
# 清除所有区域的缓存
|
||||
for region_cache in self._region_caches.values():
|
||||
@@ -589,13 +594,13 @@ class AsyncMemoryBackend(AsyncCacheBackend):
|
||||
if region_cache:
|
||||
with lock:
|
||||
region_cache.clear()
|
||||
logger.info(f"Cleared cache for region: {region}")
|
||||
logger.debug(f"Cleared cache for region: {region}")
|
||||
else:
|
||||
# 清除所有区域的缓存
|
||||
for region_cache in self._region_caches.values():
|
||||
with lock:
|
||||
region_cache.clear()
|
||||
logger.info("Cleared all cache")
|
||||
logger.info("All cache cleared!")
|
||||
|
||||
async def items(self, region: Optional[str] = DEFAULT_CACHE_REGION) -> AsyncGenerator[Tuple[str, Any], None]:
|
||||
"""
|
||||
@@ -1010,6 +1015,49 @@ class AsyncFileBackend(AsyncCacheBackend):
|
||||
pass
|
||||
|
||||
|
||||
@contextmanager
|
||||
def fresh(fresh: bool = True):
|
||||
"""
|
||||
是否获取新数据(不使用缓存的值)
|
||||
|
||||
Usage:
|
||||
with fresh():
|
||||
result = some_cached_function()
|
||||
"""
|
||||
token = _fresh.set(fresh)
|
||||
logger.debug(f"Setting fresh mode to {fresh}. {id(token):#x}")
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_fresh.reset(token)
|
||||
logger.debug(f"Reset fresh mode. {id(token):#x}")
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_fresh(fresh: bool = True):
|
||||
"""
|
||||
是否获取新数据(不使用缓存的值)
|
||||
|
||||
Usage:
|
||||
async with async_fresh():
|
||||
result = await some_async_cached_function()
|
||||
"""
|
||||
token = _fresh.set(fresh)
|
||||
logger.debug(f"Setting async_fresh mode to {fresh}. {id(token):#x}")
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_fresh.reset(token)
|
||||
logger.debug(f"Reset async_fresh mode. {id(token):#x}")
|
||||
|
||||
def is_fresh() -> bool:
|
||||
"""
|
||||
是否获取新数据
|
||||
"""
|
||||
try:
|
||||
return _fresh.get()
|
||||
except LookupError:
|
||||
return False
|
||||
|
||||
def FileCache(base: Path = settings.TEMP_PATH, ttl: Optional[int] = None) -> CacheBackend:
|
||||
"""
|
||||
获取文件缓存后端实例(Redis或文件系统),ttl仅在Redis环境中有效
|
||||
@@ -1084,16 +1132,6 @@ def cached(region: Optional[str] = None, maxsize: Optional[int] = 1024, ttl: Opt
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
# 检查是否为异步函数
|
||||
is_async = inspect.iscoroutinefunction(func)
|
||||
|
||||
# 根据函数类型选择对应的缓存后端,没有ttl时默认是 LRU 缓存,否则是 TTL 缓存
|
||||
if is_async:
|
||||
# 异步函数使用异步缓存后端
|
||||
cache_backend = AsyncCache(cache_type="ttl" if ttl else "lru", maxsize=maxsize, ttl=ttl)
|
||||
else:
|
||||
# 同步函数使用同步缓存后端
|
||||
cache_backend = Cache(cache_type="ttl" if ttl else "lru", maxsize=maxsize, ttl=ttl)
|
||||
|
||||
def should_cache(value: Any) -> bool:
|
||||
"""
|
||||
@@ -1169,16 +1207,20 @@ def cached(region: Optional[str] = None, maxsize: Optional[int] = 1024, ttl: Opt
|
||||
is_async = inspect.iscoroutinefunction(func)
|
||||
|
||||
if is_async:
|
||||
# 异步函数使用异步缓存后端
|
||||
cache_backend = AsyncCache(cache_type="ttl" if ttl else "lru", maxsize=maxsize, ttl=ttl)
|
||||
# 异步函数的缓存装饰器
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
# 获取缓存键
|
||||
cache_key = __get_cache_key(args, kwargs)
|
||||
# 尝试获取缓存
|
||||
cached_value = await cache_backend.get(cache_key, region=cache_region)
|
||||
if should_cache(cached_value) and await async_is_valid_cache_value(cache_key, cached_value,
|
||||
cache_region):
|
||||
return cached_value
|
||||
|
||||
if not is_fresh():
|
||||
# 尝试获取缓存
|
||||
cached_value = await cache_backend.get(cache_key, region=cache_region)
|
||||
if should_cache(cached_value) and await async_is_valid_cache_value(cache_key, cached_value,
|
||||
cache_region):
|
||||
return cached_value
|
||||
# 执行异步函数并缓存结果
|
||||
result = await func(*args, **kwargs)
|
||||
# 判断是否需要缓存
|
||||
@@ -1198,15 +1240,19 @@ def cached(region: Optional[str] = None, maxsize: Optional[int] = 1024, ttl: Opt
|
||||
async_wrapper.cache_clear = cache_clear
|
||||
return async_wrapper
|
||||
else:
|
||||
# 同步函数使用同步缓存后端
|
||||
cache_backend = Cache(cache_type="ttl" if ttl else "lru", maxsize=maxsize, ttl=ttl)
|
||||
# 同步函数的缓存装饰器
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
# 获取缓存键
|
||||
cache_key = __get_cache_key(args, kwargs)
|
||||
# 尝试获取缓存
|
||||
cached_value = cache_backend.get(cache_key, region=cache_region)
|
||||
if should_cache(cached_value) and is_valid_cache_value(cache_key, cached_value, cache_region):
|
||||
return cached_value
|
||||
|
||||
if not is_fresh():
|
||||
# 尝试获取缓存
|
||||
cached_value = cache_backend.get(cache_key, region=cache_region)
|
||||
if should_cache(cached_value) and is_valid_cache_value(cache_key, cached_value, cache_region):
|
||||
return cached_value
|
||||
# 执行函数并缓存结果
|
||||
result = func(*args, **kwargs)
|
||||
# 判断是否需要缓存
|
||||
|
||||
@@ -11,7 +11,8 @@ from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from dotenv import set_key
|
||||
from pydantic import BaseModel, BaseSettings, validator, Field
|
||||
from pydantic import BaseModel, Field, ConfigDict, model_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
from app.log import logger, log_settings, LogConfigModel
|
||||
from app.schemas import MediaType
|
||||
@@ -49,8 +50,7 @@ class ConfigModel(BaseModel):
|
||||
Pydantic 配置模型,描述所有配置项及其类型和默认值
|
||||
"""
|
||||
|
||||
class Config:
|
||||
extra = "ignore" # 忽略未定义的配置项
|
||||
model_config = ConfigDict(extra="ignore") # 忽略未定义的配置项
|
||||
|
||||
# ==================== 基础应用配置 ====================
|
||||
# 项目名称
|
||||
@@ -75,6 +75,8 @@ class ConfigModel(BaseModel):
|
||||
DEBUG: bool = False
|
||||
# 是否开发模式
|
||||
DEV: bool = False
|
||||
# 高级设置模式
|
||||
ADVANCED_MODE: bool = True
|
||||
|
||||
# ==================== 安全认证配置 ====================
|
||||
# 密钥
|
||||
@@ -87,8 +89,10 @@ class ConfigModel(BaseModel):
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8
|
||||
# RESOURCE_TOKEN过期时间
|
||||
RESOURCE_ACCESS_TOKEN_EXPIRE_SECONDS: int = 60 * 30
|
||||
# 超级管理员
|
||||
# 超级管理员初始用户名
|
||||
SUPERUSER: str = "admin"
|
||||
# 超级管理员初始密码
|
||||
SUPERUSER_PASSWORD: Optional[str] = None
|
||||
# 辅助认证,允许通过外部服务进行认证、单点登录以及自动创建用户
|
||||
AUXILIARY_AUTH_ENABLE: bool = False
|
||||
# API密钥,需要更换
|
||||
@@ -167,7 +171,7 @@ class ConfigModel(BaseModel):
|
||||
|
||||
# ==================== 媒体元数据配置 ====================
|
||||
# 媒体搜索来源 themoviedb/douban/bangumi,多个用,分隔
|
||||
SEARCH_SOURCE: str = "themoviedb,douban,bangumi"
|
||||
SEARCH_SOURCE: str = "themoviedb"
|
||||
# 媒体识别来源 themoviedb/douban
|
||||
RECOGNIZE_SOURCE: str = "themoviedb"
|
||||
# 刮削来源 themoviedb/douban
|
||||
@@ -252,7 +256,7 @@ class ConfigModel(BaseModel):
|
||||
# 订阅搜索时间间隔(小时)
|
||||
SUBSCRIBE_SEARCH_INTERVAL: int = 24
|
||||
# 检查本地媒体库是否存在资源开关
|
||||
LOCAL_EXISTS_SEARCH: bool = False
|
||||
LOCAL_EXISTS_SEARCH: bool = True
|
||||
|
||||
# ==================== 站点配置 ====================
|
||||
# 站点数据刷新间隔(小时)
|
||||
@@ -364,6 +368,8 @@ class ConfigModel(BaseModel):
|
||||
ENCODING_DETECTION_PERFORMANCE_MODE: bool = True
|
||||
# 编码探测的最低置信度阈值
|
||||
ENCODING_DETECTION_MIN_CONFIDENCE: float = 0.8
|
||||
# 主动内存回收时间间隔(分钟),0为不启用
|
||||
MEMORY_GC_INTERVAL: int = 30
|
||||
|
||||
# ==================== 安全配置 ====================
|
||||
# 允许的图片缓存域名
|
||||
@@ -392,24 +398,51 @@ class ConfigModel(BaseModel):
|
||||
|
||||
# ==================== 存储配置 ====================
|
||||
# 对rclone进行快照对比时,是否检查文件夹的修改时间
|
||||
RCLONE_SNAPSHOT_CHECK_FOLDER_MODTIME = True
|
||||
RCLONE_SNAPSHOT_CHECK_FOLDER_MODTIME: bool = True
|
||||
# 对OpenList进行快照对比时,是否检查文件夹的修改时间
|
||||
OPENLIST_SNAPSHOT_CHECK_FOLDER_MODTIME = True
|
||||
OPENLIST_SNAPSHOT_CHECK_FOLDER_MODTIME: bool = True
|
||||
|
||||
# ==================== Docker配置 ====================
|
||||
# Docker Client API地址
|
||||
DOCKER_CLIENT_API: Optional[str] = "tcp://127.0.0.1:38379"
|
||||
|
||||
# ==================== AI智能体配置 ====================
|
||||
# AI智能体开关
|
||||
AI_AGENT_ENABLE: bool = False
|
||||
# LLM提供商 (openai/google/deepseek)
|
||||
LLM_PROVIDER: str = "deepseek"
|
||||
# LLM模型名称
|
||||
LLM_MODEL: str = "deepseek-chat"
|
||||
# LLM API密钥
|
||||
LLM_API_KEY: Optional[str] = None
|
||||
# LLM基础URL(用于自定义API端点)
|
||||
LLM_BASE_URL: Optional[str] = "https://api.deepseek.com"
|
||||
# LLM温度参数
|
||||
LLM_TEMPERATURE: float = 0.1
|
||||
# LLM最大迭代次数
|
||||
LLM_MAX_ITERATIONS: int = 15
|
||||
# LLM工具调用超时时间(秒)
|
||||
LLM_TOOL_TIMEOUT: int = 300
|
||||
# 是否启用详细日志
|
||||
LLM_VERBOSE: bool = False
|
||||
# 最大记忆消息数量
|
||||
LLM_MAX_MEMORY_MESSAGES: int = 50
|
||||
# 记忆保留天数
|
||||
LLM_MEMORY_RETENTION_DAYS: int = 30
|
||||
# Redis记忆保留天数(如果使用Redis)
|
||||
LLM_REDIS_MEMORY_RETENTION_DAYS: int = 7
|
||||
|
||||
|
||||
class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
"""
|
||||
系统配置类
|
||||
"""
|
||||
|
||||
class Config:
|
||||
case_sensitive = True
|
||||
env_file = SystemUtils.get_env_path()
|
||||
env_file_encoding = "utf-8"
|
||||
model_config = SettingsConfigDict(
|
||||
case_sensitive=True,
|
||||
env_file=SystemUtils.get_env_path(),
|
||||
env_file_encoding="utf-8",
|
||||
)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
@@ -506,33 +539,54 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
f"配置项 '{field_name}' 的值 '{value}' 无法转换成正确的类型,使用默认值 '{default}',错误信息: {e}")
|
||||
return default, True
|
||||
|
||||
@validator('*', pre=True, always=True)
|
||||
def generic_type_validator(cls, value: Any, field): # noqa
|
||||
@model_validator(mode='before')
|
||||
@classmethod
|
||||
def generic_type_validator(cls, data: Any): # noqa
|
||||
"""
|
||||
通用校验器,尝试将配置值转换为期望的类型
|
||||
"""
|
||||
if field.name == "API_TOKEN":
|
||||
converted_value, needs_update = cls.validate_api_token(value, value)
|
||||
else:
|
||||
converted_value, needs_update = cls.generic_type_converter(value, value, field.type_, field.default,
|
||||
field.name)
|
||||
if needs_update:
|
||||
cls.update_env_config(field, value, converted_value)
|
||||
return converted_value
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
|
||||
# 处理 API_TOKEN 特殊验证
|
||||
if 'API_TOKEN' in data:
|
||||
converted_value, needs_update = cls.validate_api_token(data['API_TOKEN'], data['API_TOKEN'])
|
||||
if needs_update:
|
||||
cls.update_env_config("API_TOKEN", data["API_TOKEN"], converted_value)
|
||||
data['API_TOKEN'] = converted_value
|
||||
|
||||
# 对其他字段进行类型转换
|
||||
for field_name, field_info in cls.model_fields.items():
|
||||
if field_name not in data:
|
||||
continue
|
||||
value = data[field_name]
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
field = cls.model_fields.get(field_name)
|
||||
if field:
|
||||
converted_value, needs_update = cls.generic_type_converter(
|
||||
value, value, field.annotation, field.default, field_name
|
||||
)
|
||||
if needs_update:
|
||||
cls.update_env_config(field_name, value, converted_value)
|
||||
data[field_name] = converted_value
|
||||
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
def update_env_config(field: Any, original_value: Any, converted_value: Any) -> Tuple[bool, str]:
|
||||
def update_env_config(field_name: str, original_value: Any, converted_value: Any) -> Tuple[bool, str]:
|
||||
"""
|
||||
更新 env 配置
|
||||
"""
|
||||
message = None
|
||||
is_converted = original_value is not None and str(original_value) != str(converted_value)
|
||||
if is_converted:
|
||||
message = f"配置项 '{field.name}' 的值 '{original_value}' 无效,已替换为 '{converted_value}'"
|
||||
message = f"配置项 '{field_name}' 的值 '{original_value}' 无效,已替换为 '{converted_value}'"
|
||||
logger.warning(message)
|
||||
|
||||
if field.name in os.environ:
|
||||
message = f"配置项 '{field.name}' 已在环境变量中设置,请手动更新以保持一致性"
|
||||
if field_name in os.environ:
|
||||
message = f"配置项 '{field_name}' 已在环境变量中设置,请手动更新以保持一致性"
|
||||
logger.warning(message)
|
||||
return False, message
|
||||
else:
|
||||
@@ -542,10 +596,10 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
else:
|
||||
value_to_write = str(converted_value) if converted_value is not None else ""
|
||||
|
||||
set_key(dotenv_path=SystemUtils.get_env_path(), key_to_set=field.name, value_to_set=value_to_write,
|
||||
set_key(dotenv_path=SystemUtils.get_env_path(), key_to_set=field_name, value_to_set=value_to_write,
|
||||
quote_mode="always")
|
||||
if is_converted:
|
||||
logger.info(f"配置项 '{field.name}' 已自动修正并写入到 'app.env' 文件")
|
||||
logger.info(f"配置项 '{field_name}' 已自动修正并写入到 'app.env' 文件")
|
||||
return True, message
|
||||
|
||||
def update_setting(self, key: str, value: Any) -> Tuple[Optional[bool], str]:
|
||||
@@ -559,19 +613,17 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
return False, f"配置项 '{key}' 不存在"
|
||||
|
||||
try:
|
||||
field = self.__fields__[key]
|
||||
field = Settings.model_fields[key]
|
||||
original_value = getattr(self, key)
|
||||
if field.name == "API_TOKEN":
|
||||
if key == "API_TOKEN":
|
||||
converted_value, needs_update = self.validate_api_token(value, original_value)
|
||||
else:
|
||||
converted_value, needs_update = self.generic_type_converter(value,
|
||||
original_value,
|
||||
field.type_,
|
||||
field.default,
|
||||
key)
|
||||
converted_value, needs_update = self.generic_type_converter(
|
||||
value, original_value, field.annotation, field.default, key
|
||||
)
|
||||
# 如果没有抛出异常,则统一使用 converted_value 进行更新
|
||||
if needs_update or str(value) != str(converted_value):
|
||||
success, message = self.update_env_config(field, value, converted_value)
|
||||
success, message = self.update_env_config(key, value, converted_value)
|
||||
# 仅成功更新配置时,才更新内存
|
||||
if success:
|
||||
setattr(self, key, converted_value)
|
||||
|
||||
@@ -250,6 +250,8 @@ class MediaInfo:
|
||||
production_countries: list = field(default_factory=list)
|
||||
# 语种
|
||||
spoken_languages: list = field(default_factory=list)
|
||||
# 所有发行日期
|
||||
release_dates: list = field(default_factory=list)
|
||||
# 状态
|
||||
status: str = None
|
||||
# 标签
|
||||
@@ -257,7 +259,7 @@ class MediaInfo:
|
||||
# 评价数量
|
||||
vote_count: int = None
|
||||
# 流行度
|
||||
popularity: int = None
|
||||
popularity: float = None
|
||||
# 时长
|
||||
runtime: int = None
|
||||
# 下一集
|
||||
@@ -433,6 +435,18 @@ class MediaInfo:
|
||||
if self.release_date:
|
||||
# 年份
|
||||
self.year = self.release_date[:4]
|
||||
# 所有发行日期
|
||||
self.release_dates = [
|
||||
{
|
||||
"date": release_date.get("release_date"),
|
||||
"iso_code": result.get("iso_3166_1"),
|
||||
"note": release_date.get("note"),
|
||||
"type": release_date.get("type"),
|
||||
}
|
||||
for result in info.get("release_dates", {}).get("results", [])
|
||||
for release_date in result.get("release_dates", [])
|
||||
if release_date.get("release_date")
|
||||
]
|
||||
else:
|
||||
# 电视剧
|
||||
self.title = info.get('name')
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import importlib
|
||||
import inspect
|
||||
import random
|
||||
@@ -71,15 +72,26 @@ class EventManager(metaclass=Singleton):
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.__executor = ThreadHelper() # 动态线程池,用于消费事件
|
||||
self.__consumer_threads = [] # 用于保存启动的事件消费者线程
|
||||
self.__event_queue = PriorityQueue() # 优先级队列
|
||||
self.__broadcast_subscribers: Dict[EventType, Dict[str, Callable]] = {} # 广播事件的订阅者
|
||||
self.__chain_subscribers: Dict[ChainEventType, Dict[str, tuple[int, Callable]]] = {} # 链式事件的订阅者
|
||||
self.__disabled_handlers = set() # 禁用的事件处理器集合
|
||||
self.__disabled_classes = set() # 禁用的事件处理器类集合
|
||||
self.__lock = threading.Lock() # 线程锁
|
||||
self.__event = threading.Event() # 退出事件
|
||||
# 动态线程池,用于消费事件
|
||||
self.__executor = ThreadHelper()
|
||||
# 用于保存启动的事件消费者线程
|
||||
self.__consumer_threads = []
|
||||
# 优先级队列
|
||||
self.__event_queue = PriorityQueue()
|
||||
# 广播事件的订阅者
|
||||
self.__broadcast_subscribers: Dict[EventType, Dict[str, Callable]] = {}
|
||||
# 链式事件的订阅者
|
||||
self.__chain_subscribers: Dict[ChainEventType, Dict[str, tuple[int, Callable]]] = {}
|
||||
# 禁用的事件处理器集合
|
||||
self.__disabled_handlers = set()
|
||||
# 禁用的事件处理器类集合
|
||||
self.__disabled_classes = set()
|
||||
# 线程锁
|
||||
self.__lock = threading.Lock()
|
||||
# 退出事件
|
||||
self.__event = threading.Event()
|
||||
# 当前事件循环
|
||||
self.loop = asyncio.get_event_loop()
|
||||
|
||||
def start(self):
|
||||
"""
|
||||
@@ -438,7 +450,15 @@ class EventManager(metaclass=Singleton):
|
||||
isolated_event = Event(event_type=event.event_type,
|
||||
event_data=event_data_copy,
|
||||
priority=event.priority)
|
||||
self.__executor.submit(self.__safe_invoke_handler, handler, isolated_event)
|
||||
if inspect.iscoroutinefunction(handler):
|
||||
# 对于异步函数,直接在事件循环中运行
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.__safe_invoke_handler_async(handler, isolated_event),
|
||||
self.loop
|
||||
)
|
||||
else:
|
||||
# 对于同步函数,在线程池中运行
|
||||
self.__executor.submit(self.__safe_invoke_handler, handler, isolated_event)
|
||||
|
||||
def __safe_invoke_handler(self, handler: Callable, event: Event):
|
||||
"""
|
||||
@@ -566,7 +586,8 @@ class EventManager(metaclass=Singleton):
|
||||
# 插件同步函数在异步环境中运行,避免阻塞
|
||||
await run_in_threadpool(method, event)
|
||||
except Exception as e:
|
||||
self.__handle_event_error(event=event, handler=handler, e=e, module_name=plugin.name)
|
||||
self.__handle_event_error(event=event, module_name=plugin.name,
|
||||
class_name=class_name, method_name=method_name, e=e)
|
||||
|
||||
async def __invoke_module_method_async(self, handler: Any, class_name: str, method_name: str, event: Event):
|
||||
"""
|
||||
|
||||
@@ -94,7 +94,6 @@ class MetaVideo(MetaBase):
|
||||
title = re.sub(r'\d{4}[\s._-]\d{1,2}[\s._-]\d{1,2}', "", title)
|
||||
# 拆分tokens
|
||||
tokens = Tokens(title)
|
||||
self.tokens = tokens
|
||||
# 实例化StreamingPlatforms对象
|
||||
streaming_platforms = StreamingPlatforms()
|
||||
# 解析名称、年份、季、集、资源类型、分辨率等
|
||||
@@ -102,7 +101,7 @@ class MetaVideo(MetaBase):
|
||||
while token:
|
||||
self._index += 1 # 更新当前处理的token索引
|
||||
# Part
|
||||
self.__init_part(token)
|
||||
self.__init_part(token, tokens)
|
||||
# 标题
|
||||
if self._continue_flag:
|
||||
self.__init_name(token)
|
||||
@@ -123,7 +122,7 @@ class MetaVideo(MetaBase):
|
||||
self.__init_resource_type(token)
|
||||
# 流媒体平台
|
||||
if self._continue_flag:
|
||||
self.__init_web_source(token, streaming_platforms)
|
||||
self.__init_web_source(token, tokens, streaming_platforms)
|
||||
# 视频编码
|
||||
if self._continue_flag:
|
||||
self.__init_video_encode(token)
|
||||
@@ -311,7 +310,7 @@ class MetaVideo(MetaBase):
|
||||
self.en_name = token
|
||||
self._last_token_type = "enname"
|
||||
|
||||
def __init_part(self, token: str):
|
||||
def __init_part(self, token: str, tokens: Tokens):
|
||||
"""
|
||||
识别Part
|
||||
"""
|
||||
@@ -327,12 +326,12 @@ class MetaVideo(MetaBase):
|
||||
if re_res:
|
||||
if not self.part:
|
||||
self.part = re_res.group(1)
|
||||
nextv = self.tokens.cur()
|
||||
nextv = tokens.cur()
|
||||
if nextv \
|
||||
and ((nextv.isdigit() and (len(nextv) == 1 or len(nextv) == 2 and nextv.startswith('0')))
|
||||
or nextv.upper() in ['A', 'B', 'C', 'I', 'II', 'III']):
|
||||
self.part = "%s%s" % (self.part, nextv)
|
||||
self.tokens.get_next()
|
||||
tokens.get_next()
|
||||
self._last_token_type = "part"
|
||||
self._continue_flag = False
|
||||
# self._stop_name_flag = False
|
||||
@@ -582,7 +581,7 @@ class MetaVideo(MetaBase):
|
||||
self._effect.append(effect)
|
||||
self._last_token = effect.upper()
|
||||
|
||||
def __init_web_source(self, token: str, streaming_platforms: StreamingPlatforms):
|
||||
def __init_web_source(self, token: str, tokens: Tokens, streaming_platforms: StreamingPlatforms):
|
||||
"""
|
||||
识别流媒体平台
|
||||
"""
|
||||
@@ -594,10 +593,10 @@ class MetaVideo(MetaBase):
|
||||
|
||||
prev_token = None
|
||||
prev_idx = self._index - 2
|
||||
if 0 <= prev_idx < len(self.tokens.tokens):
|
||||
prev_token = self.tokens.tokens[prev_idx]
|
||||
if 0 <= prev_idx < len(tokens.tokens):
|
||||
prev_token = tokens.tokens[prev_idx]
|
||||
|
||||
next_token = self.tokens.peek()
|
||||
next_token = tokens.peek()
|
||||
|
||||
if streaming_platforms.is_streaming_platform(token):
|
||||
platform_name = streaming_platforms.get_streaming_platform_name(token)
|
||||
@@ -616,7 +615,7 @@ class MetaVideo(MetaBase):
|
||||
platform_name = streaming_platforms.get_streaming_platform_name(combined_token)
|
||||
query_range = 2
|
||||
if is_next:
|
||||
self.tokens.get_next()
|
||||
tokens.get_next()
|
||||
break
|
||||
|
||||
if not platform_name:
|
||||
@@ -626,8 +625,8 @@ class MetaVideo(MetaBase):
|
||||
match_start_idx = self._index - query_range
|
||||
match_end_idx = self._index - 1
|
||||
start_index = max(0, match_start_idx - query_range)
|
||||
end_index = min(len(self.tokens.tokens), match_end_idx + 1 + query_range)
|
||||
tokens_to_check = self.tokens.tokens[start_index:end_index]
|
||||
end_index = min(len(tokens.tokens), match_end_idx + 1 + query_range)
|
||||
tokens_to_check = tokens.tokens[start_index:end_index]
|
||||
|
||||
if any(tok and tok.upper() in web_tokens for tok in tokens_to_check):
|
||||
self.web_source = platform_name
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import ast
|
||||
import asyncio
|
||||
import concurrent
|
||||
import concurrent.futures
|
||||
@@ -9,14 +10,15 @@ import time
|
||||
import traceback
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
import threading
|
||||
from typing import Any, Dict, List, Optional, Type, Union, Callable, Tuple
|
||||
|
||||
from fastapi import HTTPException
|
||||
from starlette import status
|
||||
from watchdog.events import FileSystemEventHandler
|
||||
from watchdog.observers import Observer
|
||||
from watchfiles import watch
|
||||
|
||||
from app import schemas
|
||||
from app.core.cache import fresh, async_fresh
|
||||
from app.core.config import settings
|
||||
from app.core.event import eventmanager, Event
|
||||
from app.db.plugindata_oper import PluginDataOper
|
||||
@@ -26,64 +28,12 @@ from app.helper.sites import SitesHelper # noqa
|
||||
from app.log import logger
|
||||
from app.schemas.types import EventType, SystemConfigKey
|
||||
from app.utils.crypto import RSAUtils
|
||||
from app.utils.limit import rate_limit_window
|
||||
from app.utils.object import ObjectUtils
|
||||
from app.utils.singleton import Singleton
|
||||
from app.utils.string import StringUtils
|
||||
from app.utils.system import SystemUtils
|
||||
|
||||
|
||||
class PluginMonitorHandler(FileSystemEventHandler):
|
||||
|
||||
def on_modified(self, event):
|
||||
"""
|
||||
插件文件修改后重载
|
||||
"""
|
||||
if event.is_directory:
|
||||
return
|
||||
# 使用 pathlib 处理文件路径,跳过非 .py 文件以及 pycache 目录中的文件
|
||||
event_path = Path(event.src_path)
|
||||
if not event_path.name.endswith(".py") or "pycache" in event_path.parts:
|
||||
return
|
||||
|
||||
# 读取插件根目录下的__init__.py文件,读取class XXXX(_PluginBase)的类名
|
||||
try:
|
||||
plugins_root = settings.ROOT_PATH / "app" / "plugins"
|
||||
# 确保修改的文件在 plugins 目录下
|
||||
if plugins_root not in event_path.parents:
|
||||
return
|
||||
# 获取插件目录路径,没有找到__init__.py时,说明不是有效包,跳过插件重载
|
||||
# 插件重载目前没有支持app/plugins/plugin/package/__init__.py的场景,这里也不做支持
|
||||
plugin_dir = event_path.parent
|
||||
init_file = plugin_dir / "__init__.py"
|
||||
if not init_file.exists():
|
||||
logger.debug(f"{plugin_dir} 下没有找到 __init__.py,跳过插件重载")
|
||||
return
|
||||
|
||||
with open(init_file, "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
pid = None
|
||||
for line in lines:
|
||||
if line.startswith("class") and "(_PluginBase)" in line:
|
||||
pid = line.split("class ")[1].split("(_PluginBase)")[0].strip()
|
||||
if pid:
|
||||
self.__reload_plugin(pid)
|
||||
except Exception as e:
|
||||
logger.error(f"插件文件修改后重载出错:{str(e)}")
|
||||
|
||||
@staticmethod
|
||||
@rate_limit_window(max_calls=1, window_seconds=2, source="PluginMonitor", enable_logging=False)
|
||||
def __reload_plugin(pid):
|
||||
"""
|
||||
重新加载插件
|
||||
"""
|
||||
try:
|
||||
logger.info(f"插件 {pid} 文件修改,重新加载...")
|
||||
PluginManager().reload_plugin(pid)
|
||||
except Exception as e:
|
||||
logger.error(f"插件文件修改后重载出错:{str(e)}")
|
||||
|
||||
|
||||
class PluginManager(metaclass=Singleton):
|
||||
"""
|
||||
插件管理器
|
||||
@@ -96,8 +46,10 @@ class PluginManager(metaclass=Singleton):
|
||||
self._running_plugins: dict = {}
|
||||
# 配置Key
|
||||
self._config_key: str = "plugin.%s"
|
||||
# 监听器
|
||||
self._observer: Observer = None
|
||||
# 监控线程
|
||||
self._monitor_thread: Optional[threading.Thread] = None
|
||||
# 监控停止事件
|
||||
self._stop_monitor_event = threading.Event()
|
||||
# 开发者模式监测插件修改
|
||||
if settings.DEV or settings.PLUGIN_AUTO_RELOAD:
|
||||
self.__start_monitor()
|
||||
@@ -264,7 +216,6 @@ class PluginManager(metaclass=Singleton):
|
||||
|
||||
# 导入模块
|
||||
module = importlib.import_module(module_name)
|
||||
importlib.reload(module)
|
||||
|
||||
# 检查模块中的类
|
||||
for name, obj in module.__dict__.items():
|
||||
@@ -318,10 +269,9 @@ class PluginManager(metaclass=Singleton):
|
||||
重新加载插件文件修改监测
|
||||
"""
|
||||
if settings.DEV or settings.PLUGIN_AUTO_RELOAD:
|
||||
if self._observer and self._observer.is_alive():
|
||||
logger.info("插件文件修改监测已经在运行中...")
|
||||
else:
|
||||
self.__start_monitor()
|
||||
# 先关闭已有监测,再重新启动
|
||||
self.stop_monitor()
|
||||
self.__start_monitor()
|
||||
else:
|
||||
self.stop_monitor()
|
||||
|
||||
@@ -329,25 +279,123 @@ class PluginManager(metaclass=Singleton):
|
||||
"""
|
||||
启用监测插件文件修改监测
|
||||
"""
|
||||
if self._monitor_thread and self._monitor_thread.is_alive():
|
||||
logger.info("插件文件修改监测已经在运行中...")
|
||||
return
|
||||
|
||||
logger.info("开始监测插件文件修改...")
|
||||
monitor_handler = PluginMonitorHandler()
|
||||
self._observer = Observer()
|
||||
self._observer.schedule(monitor_handler, str(settings.ROOT_PATH / "app" / "plugins"), recursive=True)
|
||||
self._observer.start()
|
||||
|
||||
# 在启动新线程之前,确保停止事件是清除状态
|
||||
self._stop_monitor_event.clear()
|
||||
|
||||
# 创建并启动监控线程
|
||||
self._monitor_thread = threading.Thread(
|
||||
target=self._run_file_watcher,
|
||||
daemon=True
|
||||
)
|
||||
self._monitor_thread.start()
|
||||
|
||||
def stop_monitor(self):
|
||||
"""
|
||||
停止监测插件文件修改监测
|
||||
"""
|
||||
# 停止监测
|
||||
if self._observer and self._observer.is_alive():
|
||||
if self._monitor_thread and self._monitor_thread.is_alive():
|
||||
logger.info("正在停止插件文件修改监测...")
|
||||
self._observer.stop()
|
||||
self._observer.join()
|
||||
self._stop_monitor_event.set()
|
||||
self._monitor_thread.join(timeout=5)
|
||||
if self._monitor_thread.is_alive():
|
||||
logger.warning("插件文件修改监测线程在5秒内未能正常停止。")
|
||||
self._monitor_thread = None
|
||||
logger.info("插件文件修改监测停止完成")
|
||||
else:
|
||||
logger.info("未启用插件文件修改监测,无需停止")
|
||||
|
||||
def _run_file_watcher(self):
|
||||
"""
|
||||
运行 watchfiles 监视器的主循环。
|
||||
"""
|
||||
# 监视插件目录
|
||||
plugins_path = str(settings.ROOT_PATH / "app" / "plugins")
|
||||
logger.info(">>> 监控线程已启动,准备进入watch循环...")
|
||||
# 使用 watchfiles 监视目录变化,并响应变化事件
|
||||
# Todo: yield_on_timeout = True 时,每秒检查停止事件,会返回空集合;后续可以考虑用来做心跳之类的功能?
|
||||
for changes in watch(plugins_path, stop_event=self._stop_monitor_event, rust_timeout=1000,
|
||||
yield_on_timeout=True):
|
||||
# 如果收到停止事件,退出循环
|
||||
if not changes:
|
||||
continue
|
||||
|
||||
# 处理变化事件
|
||||
plugins_to_reload = set()
|
||||
for _change_type, path_str in changes:
|
||||
event_path = Path(path_str)
|
||||
|
||||
# 跳过非 .py 文件以及 pycache 目录中的文件
|
||||
if not event_path.name.endswith(".py") or "__pycache__" in event_path.parts:
|
||||
continue
|
||||
|
||||
# 解析插件ID
|
||||
pid = self._get_plugin_id_from_path(event_path)
|
||||
# 跳过无效插件文件
|
||||
if pid:
|
||||
# 收集需要重载的插件ID,自动去重,避免重复重载
|
||||
plugins_to_reload.add(pid)
|
||||
|
||||
# 触发重载
|
||||
if plugins_to_reload:
|
||||
logger.info(f"检测到插件文件变化,准备重载: {list(plugins_to_reload)}")
|
||||
for pid in plugins_to_reload:
|
||||
try:
|
||||
self.reload_plugin(pid)
|
||||
except Exception as e:
|
||||
logger.error(f"插件 {pid} 热重载失败: {e}", exc_info=True)
|
||||
|
||||
@staticmethod
|
||||
def _get_plugin_id_from_path(event_path: Path) -> Optional[str]:
|
||||
"""
|
||||
根据文件路径解析出插件的ID。
|
||||
:param event_path: 被修改文件的 Path 对象。
|
||||
:return: 插件ID字符串,如果不是有效插件文件则返回 None。
|
||||
"""
|
||||
try:
|
||||
plugins_root = settings.ROOT_PATH / "app" / "plugins"
|
||||
# 确保修改的文件在 plugins 目录下
|
||||
if not event_path.is_relative_to(plugins_root):
|
||||
return None
|
||||
|
||||
try:
|
||||
plugin_dir_name = event_path.relative_to(plugins_root).parts[0]
|
||||
plugin_dir = plugins_root / plugin_dir_name
|
||||
except (ValueError, IndexError):
|
||||
return None
|
||||
|
||||
init_file = plugin_dir / "__init__.py"
|
||||
if not init_file.exists():
|
||||
return None
|
||||
|
||||
# 读取 __init__.py 文件,查找插件主类名
|
||||
with open(init_file, "r", encoding="utf-8") as f:
|
||||
source_code = f.read()
|
||||
|
||||
tree = ast.parse(source_code)
|
||||
|
||||
# 遍历AST,查找继承自 _PluginBase 的类
|
||||
for node in ast.walk(tree):
|
||||
# 检查节点是否为类定义
|
||||
if isinstance(node, ast.ClassDef):
|
||||
# 遍历该类的所有基类
|
||||
for base in node.bases:
|
||||
# 检查基类是否是我们寻找的 _PluginBase
|
||||
# ast.Name 用于处理简单的基类名
|
||||
if isinstance(base, ast.Name) and base.id == '_PluginBase':
|
||||
# 返回这个类的名字
|
||||
return node.name
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"从路径解析插件ID时出错: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def __stop_plugin(plugin: Any):
|
||||
"""
|
||||
@@ -410,6 +458,10 @@ class PluginManager(metaclass=Singleton):
|
||||
except KeyError:
|
||||
# 模块可能已经被删除
|
||||
pass
|
||||
|
||||
importlib.invalidate_caches()
|
||||
logger.debug("已清除查找器的缓存")
|
||||
|
||||
if plugin_id:
|
||||
if modules_to_remove:
|
||||
logger.info(f"插件 {plugin_id} 共清除 {len(modules_to_remove)} 个模块缓存:{modules_to_remove}")
|
||||
@@ -693,6 +745,36 @@ class PluginManager(metaclass=Singleton):
|
||||
logger.error(f"获取插件 {plugin_id} 动作出错:{str(e)}")
|
||||
return ret_actions
|
||||
|
||||
def get_plugin_agent_tools(self, pid: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取插件智能体工具
|
||||
[{
|
||||
"plugin_id": "插件ID",
|
||||
"plugin_name": "插件名称",
|
||||
"tools": [ToolClass1, ToolClass2, ...]
|
||||
}]
|
||||
"""
|
||||
ret_tools = []
|
||||
# 创建字典快照避免并发修改
|
||||
running_plugins_snapshot = dict(self._running_plugins)
|
||||
for plugin_id, plugin in running_plugins_snapshot.items():
|
||||
if pid and pid != plugin_id:
|
||||
continue
|
||||
if hasattr(plugin, "get_agent_tools") and ObjectUtils.check_method(plugin.get_agent_tools):
|
||||
try:
|
||||
if not plugin.get_state():
|
||||
continue
|
||||
tools = plugin.get_agent_tools()
|
||||
if tools:
|
||||
ret_tools.append({
|
||||
"plugin_id": plugin_id,
|
||||
"plugin_name": plugin.plugin_name,
|
||||
"tools": tools
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"获取插件 {plugin_id} 智能体工具出错:{str(e)}")
|
||||
return ret_tools
|
||||
|
||||
@staticmethod
|
||||
def get_plugin_remote_entry(plugin_id: str, dist_path: str) -> str:
|
||||
"""
|
||||
@@ -1024,7 +1106,8 @@ class PluginManager(metaclass=Singleton):
|
||||
# 已安装插件
|
||||
installed_apps = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or []
|
||||
# 获取在线插件
|
||||
online_plugins = PluginHelper().get_plugins(market, package_version, force)
|
||||
with fresh(force):
|
||||
online_plugins = PluginHelper().get_plugins(market, package_version)
|
||||
if online_plugins is None:
|
||||
logger.warning(
|
||||
f"获取{package_version if package_version else ''}插件库失败:{market},请检查 GitHub 网络连接")
|
||||
@@ -1231,7 +1314,8 @@ class PluginManager(metaclass=Singleton):
|
||||
# 已安装插件
|
||||
installed_apps = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or []
|
||||
# 获取在线插件
|
||||
online_plugins = await PluginHelper().async_get_plugins(market, package_version, force)
|
||||
async with async_fresh(force):
|
||||
online_plugins = await PluginHelper().async_get_plugins(market, package_version)
|
||||
if online_plugins is None:
|
||||
logger.warning(
|
||||
f"获取{package_version if package_version else ''}插件库失败:{market},请检查 GitHub 网络连接")
|
||||
|
||||
@@ -66,6 +66,12 @@ class Site(Base):
|
||||
result = await db.execute(select(cls).where(cls.domain == domain))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_get_by_name(cls, db: AsyncSession, name: str):
|
||||
result = await db.execute(select(cls).where(cls.name == name))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_actives(cls, db: Session):
|
||||
|
||||
@@ -85,6 +85,12 @@ class SiteOper(DbOper):
|
||||
"""
|
||||
return await Site.async_get_by_domain(self._db, domain)
|
||||
|
||||
async def async_get_by_name(self, name: str) -> Site:
|
||||
"""
|
||||
异步按名称获取站点
|
||||
"""
|
||||
return await Site.async_get_by_name(self._db, name)
|
||||
|
||||
def get_domains_by_ids(self, ids: List[int]) -> List[str]:
|
||||
"""
|
||||
按ID获取站点域名
|
||||
|
||||
@@ -128,10 +128,10 @@ class TransferHistoryOper(DbOper):
|
||||
self.add_force(
|
||||
src=fileitem.path,
|
||||
src_storage=fileitem.storage,
|
||||
src_fileitem=fileitem.dict(),
|
||||
src_fileitem=fileitem.model_dump(),
|
||||
dest=transferinfo.target_item.path if transferinfo.target_item else None,
|
||||
dest_storage=transferinfo.target_item.storage if transferinfo.target_item else None,
|
||||
dest_fileitem=transferinfo.target_item.dict() if transferinfo.target_item else None,
|
||||
dest_fileitem=transferinfo.target_item.model_dump() if transferinfo.target_item else None,
|
||||
mode=mode,
|
||||
type=mediainfo.type.value,
|
||||
category=mediainfo.category,
|
||||
@@ -159,10 +159,10 @@ class TransferHistoryOper(DbOper):
|
||||
his = self.add_force(
|
||||
src=fileitem.path,
|
||||
src_storage=fileitem.storage,
|
||||
src_fileitem=fileitem.dict(),
|
||||
src_fileitem=fileitem.model_dump(),
|
||||
dest=transferinfo.target_item.path if transferinfo.target_item else None,
|
||||
dest_storage=transferinfo.target_item.storage if transferinfo.target_item else None,
|
||||
dest_fileitem=transferinfo.target_item.dict() if transferinfo.target_item else None,
|
||||
dest_fileitem=transferinfo.target_item.model_dump() if transferinfo.target_item else None,
|
||||
mode=mode,
|
||||
type=mediainfo.type.value,
|
||||
category=mediainfo.category,
|
||||
@@ -188,7 +188,7 @@ class TransferHistoryOper(DbOper):
|
||||
year=meta.year,
|
||||
src=fileitem.path,
|
||||
src_storage=fileitem.storage,
|
||||
src_fileitem=fileitem.dict(),
|
||||
src_fileitem=fileitem.model_dump(),
|
||||
mode=mode,
|
||||
seasons=meta.season,
|
||||
episodes=meta.episode,
|
||||
|
||||
@@ -367,7 +367,6 @@ class TemplateHelper(metaclass=SingletonClass):
|
||||
return rendered
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"模板处理失败: {str(e)}")
|
||||
raise ValueError(f"模板处理失败: {str(e)}") from e
|
||||
|
||||
@staticmethod
|
||||
@@ -713,6 +712,7 @@ class MessageQueueManager(metaclass=SingletonClass):
|
||||
self._running = False
|
||||
logger.info("正在停止消息队列...")
|
||||
self.thread.join()
|
||||
logger.info("消息队列已停止")
|
||||
|
||||
|
||||
class MessageHelper(metaclass=Singleton):
|
||||
|
||||
@@ -48,35 +48,13 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
if self.install_report():
|
||||
self.systemconfig.set(SystemConfigKey.PluginInstallReport, "1")
|
||||
|
||||
def get_plugins(self, repo_url: str, package_version: Optional[str] = None,
|
||||
force: bool = False) -> Optional[Dict[str, dict]]:
|
||||
@cached(maxsize=128, ttl=1800)
|
||||
def get_plugins(self, repo_url: str,
|
||||
package_version: Optional[str] = None) -> Optional[Dict[str, dict]]:
|
||||
"""
|
||||
获取Github所有最新插件列表
|
||||
:param repo_url: Github仓库地址
|
||||
:param package_version: 首选插件版本 (如 "v2", "v3"),如果不指定则获取 v1 版本
|
||||
:param force: 是否强制刷新,忽略缓存
|
||||
"""
|
||||
# 如果强制刷新,直接调用不带缓存的版本
|
||||
if force:
|
||||
return self._get_plugins_uncached(repo_url, package_version)
|
||||
|
||||
# 正常情况下调用带缓存的版本
|
||||
return self._get_plugins_cached(repo_url, package_version)
|
||||
|
||||
@cached(maxsize=64, ttl=1800)
|
||||
def _get_plugins_cached(self, repo_url: str, package_version: Optional[str] = None) -> Optional[Dict[str, dict]]:
|
||||
"""
|
||||
获取Github所有最新插件列表(使用缓存)
|
||||
:param repo_url: Github仓库地址
|
||||
:param package_version: 首选插件版本 (如 "v2", "v3"),如果不指定则获取 v1 版本
|
||||
"""
|
||||
return self._get_plugins_uncached(repo_url, package_version)
|
||||
|
||||
def _get_plugins_uncached(self, repo_url: str, package_version: Optional[str] = None) -> Optional[Dict[str, dict]]:
|
||||
"""
|
||||
获取Github所有最新插件列表(不使用缓存)
|
||||
:param repo_url: Github仓库地址
|
||||
:param package_version: 首选插件版本 (如 "v2", "v3"),如果不指定则获取 v1 版本
|
||||
"""
|
||||
if not repo_url:
|
||||
return None
|
||||
@@ -161,7 +139,7 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
return res.json()
|
||||
return {}
|
||||
|
||||
def install_reg(self, pid: str) -> bool:
|
||||
def install_reg(self, pid: str, repo_url: Optional[str] = None) -> bool:
|
||||
"""
|
||||
安装插件统计
|
||||
"""
|
||||
@@ -170,24 +148,39 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
if not pid:
|
||||
return False
|
||||
install_reg_url = self._install_reg.format(pid=pid)
|
||||
res = RequestUtils(proxies=settings.PROXY, timeout=5).get_res(install_reg_url)
|
||||
res = RequestUtils(
|
||||
proxies=settings.PROXY,
|
||||
content_type="application/json",
|
||||
timeout=5
|
||||
).post(install_reg_url, json={
|
||||
"plugin_id": pid,
|
||||
"repo_url": repo_url
|
||||
})
|
||||
if res and res.status_code == 200:
|
||||
return True
|
||||
return False
|
||||
|
||||
def install_report(self) -> bool:
|
||||
def install_report(self, items: Optional[List[Tuple[str, Optional[str]]]] = None) -> bool:
|
||||
"""
|
||||
上报存量插件安装统计
|
||||
上报存量插件安装统计(批量)。支持上送 repo_url。
|
||||
:param items: 可选,形如 [(plugin_id, repo_url), ...];不传则回落到历史配置,仅上送 plugin_id。
|
||||
"""
|
||||
if not settings.PLUGIN_STATISTIC_SHARE:
|
||||
return False
|
||||
plugins = self.systemconfig.get(SystemConfigKey.UserInstalledPlugins)
|
||||
if not plugins:
|
||||
return False
|
||||
payload_plugins = []
|
||||
if items:
|
||||
for pid, repo_url in items:
|
||||
if pid:
|
||||
payload_plugins.append({"plugin_id": pid, "repo_url": repo_url})
|
||||
else:
|
||||
plugins = self.systemconfig.get(SystemConfigKey.UserInstalledPlugins)
|
||||
if not plugins:
|
||||
return False
|
||||
payload_plugins = [{"plugin_id": plugin, "repo_url": None} for plugin in plugins]
|
||||
res = RequestUtils(proxies=settings.PROXY,
|
||||
content_type="application/json",
|
||||
timeout=5).post(self._install_report,
|
||||
json={"plugins": [{"plugin_id": plugin} for plugin in plugins]})
|
||||
json={"plugins": payload_plugins})
|
||||
return True if res else False
|
||||
|
||||
def install(self, pid: str, repo_url: str, package_version: Optional[str] = None, force_install: bool = False) \
|
||||
@@ -252,16 +245,16 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
# 使用 release 进行安装
|
||||
def prepare_release() -> Tuple[bool, str]:
|
||||
return self.__install_from_release(
|
||||
pid.lower(), user_repo, release_tag
|
||||
pid, user_repo, release_tag
|
||||
)
|
||||
|
||||
return self.__install_flow_sync(pid.lower(), force_install, prepare_release)
|
||||
return self.__install_flow_sync(pid, force_install, prepare_release, repo_url)
|
||||
else:
|
||||
# 如果 release_tag 不存在,说明插件没有发布版本,使用文件列表方式安装
|
||||
def prepare_filelist() -> Tuple[bool, str]:
|
||||
return self.__prepare_content_via_filelist_sync(pid.lower(), user_repo, package_version)
|
||||
|
||||
return self.__install_flow_sync(pid.lower(), force_install, prepare_filelist)
|
||||
return self.__install_flow_sync(pid, force_install, prepare_filelist, repo_url)
|
||||
|
||||
def __get_file_list(self, pid: str, user_repo: str, package_version: Optional[str] = None) -> \
|
||||
Tuple[Optional[list], Optional[str]]:
|
||||
@@ -275,7 +268,7 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
# 如果 package_version 存在(如 "v2"),则加上版本号
|
||||
if package_version:
|
||||
file_api += f".{package_version}"
|
||||
file_api += f"/{pid}"
|
||||
file_api += f"/{pid.lower()}"
|
||||
|
||||
res = self.__request_with_fallback(file_api,
|
||||
headers=settings.REPO_GITHUB_HEADERS(repo=user_repo),
|
||||
@@ -408,8 +401,8 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
:param pid: 插件 ID
|
||||
:return: 备份目录路径
|
||||
"""
|
||||
plugin_dir = PLUGIN_DIR / pid
|
||||
backup_dir = Path(settings.TEMP_PATH) / "plugin_backup" / pid
|
||||
plugin_dir = PLUGIN_DIR / pid.lower()
|
||||
backup_dir = Path(settings.TEMP_PATH) / "plugin_backup" / pid.lower()
|
||||
|
||||
if plugin_dir.exists():
|
||||
# 备份时清理已有的备份目录,防止残留文件影响
|
||||
@@ -429,7 +422,7 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
:param pid: 插件 ID
|
||||
:param backup_dir: 备份目录路径
|
||||
"""
|
||||
plugin_dir = PLUGIN_DIR / pid
|
||||
plugin_dir = PLUGIN_DIR / pid.lower()
|
||||
if plugin_dir.exists():
|
||||
shutil.rmtree(plugin_dir, ignore_errors=True)
|
||||
logger.debug(f"{pid} 已清理插件目录 {plugin_dir}")
|
||||
@@ -446,7 +439,7 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
删除旧插件
|
||||
:param pid: 插件 ID
|
||||
"""
|
||||
plugin_dir = PLUGIN_DIR / pid
|
||||
plugin_dir = PLUGIN_DIR / pid.lower()
|
||||
if plugin_dir.exists():
|
||||
shutil.rmtree(plugin_dir, ignore_errors=True)
|
||||
|
||||
@@ -557,41 +550,42 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
logger.error(f"获取插件 {pid} 元数据失败:{e}")
|
||||
return {}
|
||||
|
||||
def __install_flow_sync(self, pid_lower: str, force_install: bool,
|
||||
prepare_content: Callable[[], Tuple[bool, str]]) -> Tuple[bool, str]:
|
||||
def __install_flow_sync(self, pid: str, force_install: bool,
|
||||
prepare_content: Callable[[], Tuple[bool, str]],
|
||||
repo_url: Optional[str] = None) -> Tuple[bool, str]:
|
||||
"""
|
||||
同步安装统一流程:备份→清理→准备内容→安装依赖→上报
|
||||
prepare_content 负责把插件文件放到 app/plugins/{pid}
|
||||
"""
|
||||
backup_dir = None
|
||||
if not force_install:
|
||||
backup_dir = self.__backup_plugin(pid_lower)
|
||||
backup_dir = self.__backup_plugin(pid)
|
||||
|
||||
self.__remove_old_plugin(pid_lower)
|
||||
self.__remove_old_plugin(pid)
|
||||
|
||||
success, message = prepare_content()
|
||||
if not success:
|
||||
logger.error(f"{pid_lower} 准备插件内容失败:{message}")
|
||||
logger.error(f"{pid} 准备插件内容失败:{message}")
|
||||
if backup_dir:
|
||||
self.__restore_plugin(pid_lower, backup_dir)
|
||||
logger.warning(f"{pid_lower} 插件安装失败,已还原备份插件")
|
||||
self.__restore_plugin(pid, backup_dir)
|
||||
logger.warning(f"{pid} 插件安装失败,已还原备份插件")
|
||||
else:
|
||||
self.__remove_old_plugin(pid_lower)
|
||||
logger.warning(f"{pid_lower} 已清理对应插件目录,请尝试重新安装")
|
||||
self.__remove_old_plugin(pid)
|
||||
logger.warning(f"{pid} 已清理对应插件目录,请尝试重新安装")
|
||||
return False, message
|
||||
|
||||
dependencies_exist, dep_ok, dep_msg = self.__install_dependencies_if_required(pid_lower)
|
||||
dependencies_exist, dep_ok, dep_msg = self.__install_dependencies_if_required(pid)
|
||||
if dependencies_exist and not dep_ok:
|
||||
logger.error(f"{pid_lower} 依赖安装失败:{dep_msg}")
|
||||
logger.error(f"{pid} 依赖安装失败:{dep_msg}")
|
||||
if backup_dir:
|
||||
self.__restore_plugin(pid_lower, backup_dir)
|
||||
logger.warning(f"{pid_lower} 插件安装失败,已还原备份插件")
|
||||
self.__restore_plugin(pid, backup_dir)
|
||||
logger.warning(f"{pid} 插件安装失败,已还原备份插件")
|
||||
else:
|
||||
self.__remove_old_plugin(pid_lower)
|
||||
logger.warning(f"{pid_lower} 已清理对应插件目录,请尝试重新安装")
|
||||
self.__remove_old_plugin(pid)
|
||||
logger.warning(f"{pid} 已清理对应插件目录,请尝试重新安装")
|
||||
return False, dep_msg
|
||||
|
||||
self.install_reg(pid_lower)
|
||||
self.install_reg(pid, repo_url)
|
||||
return True, ""
|
||||
|
||||
def __install_from_release(self, pid: str, user_repo: str, release_tag: str) -> Tuple[bool, str]:
|
||||
@@ -915,35 +909,13 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
logger.error(f"[GitHub] 所有策略均请求失败,URL: {url},请检查网络连接或 GitHub 配置")
|
||||
return None
|
||||
|
||||
async def async_get_plugins(self, repo_url: str, package_version: Optional[str] = None,
|
||||
force: bool = False) -> Optional[Dict[str, dict]]:
|
||||
@cached(maxsize=128, ttl=1800)
|
||||
async def async_get_plugins(self, repo_url: str,
|
||||
package_version: Optional[str] = None) -> Optional[Dict[str, dict]]:
|
||||
"""
|
||||
异步获取Github所有最新插件列表
|
||||
:param repo_url: Github仓库地址
|
||||
:param package_version: 首选插件版本 (如 "v2", "v3"),如果不指定则获取 v1 版本
|
||||
:param force: 是否强制刷新,忽略缓存
|
||||
"""
|
||||
# 异步版本直接调用不带缓存的版本(缓存在异步环境下可能有并发问题)
|
||||
if force:
|
||||
await self._async_get_plugins_cached.cache_clear()
|
||||
return await self._async_get_plugins_cached(repo_url, package_version)
|
||||
|
||||
@cached(maxsize=128, ttl=1800)
|
||||
async def _async_get_plugins_cached(self, repo_url: str,
|
||||
package_version: Optional[str] = None) -> Optional[Dict[str, dict]]:
|
||||
"""
|
||||
获取Github所有最新插件列表(使用缓存)
|
||||
:param repo_url: Github仓库地址
|
||||
:param package_version: 首选插件版本 (如 "v2", "v3"),如果不指定则获取 v1 版本
|
||||
"""
|
||||
return await self._async_get_plugins_uncached(repo_url, package_version)
|
||||
|
||||
async def _async_get_plugins_uncached(self, repo_url: str,
|
||||
package_version: Optional[str] = None) -> Optional[Dict[str, dict]]:
|
||||
"""
|
||||
异步获取Github所有最新插件列表(不使用缓存)
|
||||
:param repo_url: Github仓库地址
|
||||
:param package_version: 首选插件版本 (如 "v2", "v3"),如果不指定则获取 v1 版本
|
||||
"""
|
||||
if not repo_url:
|
||||
return None
|
||||
@@ -980,7 +952,7 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
return res.json()
|
||||
return {}
|
||||
|
||||
async def async_install_reg(self, pid: str) -> bool:
|
||||
async def async_install_reg(self, pid: str, repo_url: Optional[str] = None) -> bool:
|
||||
"""
|
||||
异步安装插件统计
|
||||
"""
|
||||
@@ -989,24 +961,39 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
if not pid:
|
||||
return False
|
||||
install_reg_url = self._install_reg.format(pid=pid)
|
||||
res = await AsyncRequestUtils(proxies=settings.PROXY, timeout=5).get_res(install_reg_url)
|
||||
res = await AsyncRequestUtils(
|
||||
proxies=settings.PROXY,
|
||||
content_type="application/json",
|
||||
timeout=5
|
||||
).post(install_reg_url, json={
|
||||
"plugin_id": pid,
|
||||
"repo_url": repo_url
|
||||
})
|
||||
if res and res.status_code == 200:
|
||||
return True
|
||||
return False
|
||||
|
||||
async def async_install_report(self) -> bool:
|
||||
async def async_install_report(self, items: Optional[List[Tuple[str, Optional[str]]]] = None) -> bool:
|
||||
"""
|
||||
异步上报存量插件安装统计
|
||||
异步上报存量插件安装统计(批量)。支持上送 repo_url。
|
||||
:param items: 可选,形如 [(plugin_id, repo_url), ...];不传则回落到历史配置,仅上送 plugin_id。
|
||||
"""
|
||||
if not settings.PLUGIN_STATISTIC_SHARE:
|
||||
return False
|
||||
plugins = self.systemconfig.get(SystemConfigKey.UserInstalledPlugins)
|
||||
if not plugins:
|
||||
return False
|
||||
payload_plugins = []
|
||||
if items:
|
||||
for pid, repo_url in items:
|
||||
if pid:
|
||||
payload_plugins.append({"plugin_id": pid, "repo_url": repo_url})
|
||||
else:
|
||||
plugins = self.systemconfig.get(SystemConfigKey.UserInstalledPlugins)
|
||||
if not plugins:
|
||||
return False
|
||||
payload_plugins = [{"plugin_id": plugin, "repo_url": None} for plugin in plugins]
|
||||
res = await AsyncRequestUtils(proxies=settings.PROXY,
|
||||
content_type="application/json",
|
||||
timeout=5).post(self._install_report,
|
||||
json={"plugins": [{"plugin_id": plugin} for plugin in plugins]})
|
||||
json={"plugins": payload_plugins})
|
||||
return True if res else False
|
||||
|
||||
async def __async_get_file_list(self, pid: str, user_repo: str, package_version: Optional[str] = None) -> \
|
||||
@@ -1021,7 +1008,7 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
# 如果 package_version 存在(如 "v2"),则加上版本号
|
||||
if package_version:
|
||||
file_api += f".{package_version}"
|
||||
file_api += f"/{pid}"
|
||||
file_api += f"/{pid.lower()}"
|
||||
|
||||
res = await self.__async_request_with_fallback(file_api,
|
||||
headers=settings.REPO_GITHUB_HEADERS(repo=user_repo),
|
||||
@@ -1133,8 +1120,8 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
:param pid: 插件 ID
|
||||
:return: 备份目录路径
|
||||
"""
|
||||
plugin_dir = AsyncPath(PLUGIN_DIR) / pid
|
||||
backup_dir = AsyncPath(settings.TEMP_PATH) / "plugin_backup" / pid
|
||||
plugin_dir = AsyncPath(PLUGIN_DIR) / pid.lower()
|
||||
backup_dir = AsyncPath(settings.TEMP_PATH) / "plugin_backup" / pid.lower()
|
||||
|
||||
if await plugin_dir.exists():
|
||||
# 备份时清理已有的备份目录,防止残留文件影响
|
||||
@@ -1154,7 +1141,7 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
:param pid: 插件 ID
|
||||
:param backup_dir: 备份目录路径
|
||||
"""
|
||||
plugin_dir = AsyncPath(PLUGIN_DIR) / pid
|
||||
plugin_dir = AsyncPath(PLUGIN_DIR) / pid.lower()
|
||||
if await plugin_dir.exists():
|
||||
await aioshutil.rmtree(plugin_dir, ignore_errors=True)
|
||||
logger.debug(f"{pid} 已清理插件目录 {plugin_dir}")
|
||||
@@ -1172,7 +1159,7 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
异步删除旧插件
|
||||
:param pid: 插件 ID
|
||||
"""
|
||||
plugin_dir = AsyncPath(PLUGIN_DIR) / pid
|
||||
plugin_dir = AsyncPath(PLUGIN_DIR) / pid.lower()
|
||||
if await plugin_dir.exists():
|
||||
await aioshutil.rmtree(plugin_dir, ignore_errors=True)
|
||||
|
||||
@@ -1414,16 +1401,16 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
# 使用 release 进行安装
|
||||
async def prepare_release() -> Tuple[bool, str]:
|
||||
return await self.__async_install_from_release(
|
||||
pid.lower(), user_repo, release_tag
|
||||
pid, user_repo, release_tag
|
||||
)
|
||||
|
||||
return await self.__install_flow_async(pid.lower(), force_install, prepare_release)
|
||||
return await self.__install_flow_async(pid, force_install, prepare_release, repo_url)
|
||||
else:
|
||||
# 如果没有 release_tag,则使用文件列表安装方式
|
||||
async def prepare_filelist() -> Tuple[bool, str]:
|
||||
return await self.__prepare_content_via_filelist_async(pid.lower(), user_repo, package_version)
|
||||
return await self.__prepare_content_via_filelist_async(pid, user_repo, package_version)
|
||||
|
||||
return await self.__install_flow_async(pid.lower(), force_install, prepare_filelist)
|
||||
return await self.__install_flow_async(pid, force_install, prepare_filelist, repo_url)
|
||||
|
||||
async def __async_get_plugin_meta(self, pid: str, repo_url: str,
|
||||
package_version: Optional[str]) -> dict:
|
||||
@@ -1438,78 +1425,79 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
logger.warn(f"获取插件 {pid} 元数据失败:{e}")
|
||||
return {}
|
||||
|
||||
async def __install_flow_async(self, pid_lower: str, force_install: bool,
|
||||
prepare_content: Callable[[], Awaitable[Tuple[bool, str]]]) -> Tuple[bool, str]:
|
||||
async def __install_flow_async(self, pid: str, force_install: bool,
|
||||
prepare_content: Callable[[], Awaitable[Tuple[bool, str]]],
|
||||
repo_url: Optional[str] = None) -> Tuple[bool, str]:
|
||||
"""
|
||||
异步安装流程,处理插件内容准备、依赖安装和注册
|
||||
"""
|
||||
backup_dir = None
|
||||
if not force_install:
|
||||
backup_dir = await self.__async_backup_plugin(pid_lower)
|
||||
backup_dir = await self.__async_backup_plugin(pid)
|
||||
|
||||
await self.__async_remove_old_plugin(pid_lower)
|
||||
await self.__async_remove_old_plugin(pid)
|
||||
|
||||
success, message = await prepare_content()
|
||||
if not success:
|
||||
logger.error(f"{pid_lower} 准备插件内容失败:{message}")
|
||||
logger.error(f"{pid} 准备插件内容失败:{message}")
|
||||
if backup_dir:
|
||||
await self.__async_restore_plugin(pid_lower, backup_dir)
|
||||
logger.warning(f"{pid_lower} 插件安装失败,已还原备份插件")
|
||||
await self.__async_restore_plugin(pid, backup_dir)
|
||||
logger.warning(f"{pid} 插件安装失败,已还原备份插件")
|
||||
else:
|
||||
await self.__async_remove_old_plugin(pid_lower)
|
||||
logger.warning(f"{pid_lower} 已清理对应插件目录,请尝试重新安装")
|
||||
await self.__async_remove_old_plugin(pid)
|
||||
logger.warning(f"{pid} 已清理对应插件目录,请尝试重新安装")
|
||||
return False, message
|
||||
|
||||
dependencies_exist, dep_ok, dep_msg = await self.__async_install_dependencies_if_required(pid_lower)
|
||||
dependencies_exist, dep_ok, dep_msg = await self.__async_install_dependencies_if_required(pid)
|
||||
if dependencies_exist and not dep_ok:
|
||||
logger.error(f"{pid_lower} 依赖安装失败:{dep_msg}")
|
||||
logger.error(f"{pid} 依赖安装失败:{dep_msg}")
|
||||
if backup_dir:
|
||||
await self.__async_restore_plugin(pid_lower, backup_dir)
|
||||
logger.warning(f"{pid_lower} 插件安装失败,已还原备份插件")
|
||||
await self.__async_restore_plugin(pid, backup_dir)
|
||||
logger.warning(f"{pid} 插件安装失败,已还原备份插件")
|
||||
else:
|
||||
await self.__async_remove_old_plugin(pid_lower)
|
||||
logger.warning(f"{pid_lower} 已清理对应插件目录,请尝试重新安装")
|
||||
await self.__async_remove_old_plugin(pid)
|
||||
logger.warning(f"{pid} 已清理对应插件目录,请尝试重新安装")
|
||||
return False, dep_msg
|
||||
|
||||
await self.async_install_reg(pid_lower)
|
||||
await self.async_install_reg(pid, repo_url)
|
||||
return True, ""
|
||||
|
||||
def __prepare_content_via_filelist_sync(self, pid_lower: str, user_repo: str,
|
||||
def __prepare_content_via_filelist_sync(self, pid: str, user_repo: str,
|
||||
package_version: Optional[str]) -> Tuple[bool, str]:
|
||||
"""
|
||||
同步准备插件内容,通过文件列表获取插件文件和依赖
|
||||
"""
|
||||
file_list, msg = self.__get_file_list(pid_lower, user_repo, package_version)
|
||||
file_list, msg = self.__get_file_list(pid, user_repo, package_version)
|
||||
if not file_list:
|
||||
return False, msg
|
||||
requirements_file_info = next((f for f in file_list if f.get("name") == "requirements.txt"), None)
|
||||
if requirements_file_info:
|
||||
ok, m = self.__download_and_install_requirements(requirements_file_info, pid_lower, user_repo)
|
||||
ok, m = self.__download_and_install_requirements(requirements_file_info, pid, user_repo)
|
||||
if not ok:
|
||||
logger.debug(f"{pid_lower} 依赖预安装失败:{m}")
|
||||
logger.debug(f"{pid} 依赖预安装失败:{m}")
|
||||
else:
|
||||
logger.debug(f"{pid_lower} 依赖预安装成功")
|
||||
ok, m = self.__download_files(pid_lower, file_list, user_repo, package_version, True)
|
||||
logger.debug(f"{pid} 依赖预安装成功")
|
||||
ok, m = self.__download_files(pid, file_list, user_repo, package_version, True)
|
||||
if not ok:
|
||||
return False, m
|
||||
return True, ""
|
||||
|
||||
async def __prepare_content_via_filelist_async(self, pid_lower: str, user_repo: str,
|
||||
async def __prepare_content_via_filelist_async(self, pid: str, user_repo: str,
|
||||
package_version: Optional[str]) -> Tuple[bool, str]:
|
||||
"""
|
||||
异步准备插件内容,通过文件列表获取插件文件和依赖
|
||||
"""
|
||||
file_list, msg = await self.__async_get_file_list(pid_lower, user_repo, package_version)
|
||||
file_list, msg = await self.__async_get_file_list(pid, user_repo, package_version)
|
||||
if not file_list:
|
||||
return False, msg
|
||||
requirements_file_info = next((f for f in file_list if f.get("name") == "requirements.txt"), None)
|
||||
if requirements_file_info:
|
||||
ok, m = await self.__async_download_and_install_requirements(requirements_file_info, pid_lower, user_repo)
|
||||
ok, m = await self.__async_download_and_install_requirements(requirements_file_info, pid, user_repo)
|
||||
if not ok:
|
||||
logger.debug(f"{pid_lower} 依赖预安装失败:{m}")
|
||||
logger.debug(f"{pid} 依赖预安装失败:{m}")
|
||||
else:
|
||||
logger.debug(f"{pid_lower} 依赖预安装成功")
|
||||
ok, m = await self.__async_download_files(pid_lower, file_list, user_repo, package_version, True)
|
||||
logger.debug(f"{pid} 依赖预安装成功")
|
||||
ok, m = await self.__async_download_files(pid, file_list, user_repo, package_version, True)
|
||||
if not ok:
|
||||
return False, m
|
||||
return True, ""
|
||||
|
||||
@@ -258,10 +258,10 @@ class RedisHelper(metaclass=Singleton):
|
||||
for key in self.client.scan_iter(redis_key):
|
||||
pipe.delete(key)
|
||||
pipe.execute()
|
||||
logger.info(f"Cleared Redis cache for region: {region}")
|
||||
logger.debug(f"Cleared Redis cache for region: {region}")
|
||||
else:
|
||||
self.client.flushdb()
|
||||
logger.info("Cleared all Redis cache")
|
||||
logger.info("All Redis cache Cleared!")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to clear cache, region: {region}, error: {e}")
|
||||
|
||||
@@ -496,7 +496,7 @@ class AsyncRedisHelper(metaclass=Singleton):
|
||||
async for key in self.client.scan_iter(redis_key):
|
||||
await pipe.delete(key)
|
||||
await pipe.execute()
|
||||
logger.info(f"Cleared Redis cache for region (async): {region}")
|
||||
logger.debug(f"Cleared Redis cache for region (async): {region}")
|
||||
else:
|
||||
await self.client.flushdb()
|
||||
logger.info("Cleared all Redis cache (async)")
|
||||
|
||||
@@ -228,13 +228,14 @@ class RssHelper:
|
||||
}
|
||||
|
||||
def parse(self, url, proxy: bool = False,
|
||||
timeout: Optional[int] = 15, headers: dict = None) -> Union[List[dict], None, bool]:
|
||||
timeout: Optional[int] = 15, headers: dict = None, ua: str = None) -> Union[List[dict], None, bool]:
|
||||
"""
|
||||
解析RSS订阅URL,获取RSS中的种子信息
|
||||
:param url: RSS地址
|
||||
:param proxy: 是否使用代理
|
||||
:param timeout: 请求超时
|
||||
:param headers: 自定义请求头
|
||||
:param ua: 自定义User-Agent
|
||||
:return: 种子信息列表,如为None代表Rss过期,如果为False则为错误
|
||||
"""
|
||||
# 开始处理
|
||||
@@ -243,8 +244,9 @@ class RssHelper:
|
||||
return False
|
||||
|
||||
try:
|
||||
ret = RequestUtils(proxies=settings.PROXY if proxy else None,
|
||||
timeout=timeout, headers=headers).get_res(url)
|
||||
ret = RequestUtils(ua=ua,
|
||||
proxies=settings.PROXY if proxy else None,
|
||||
timeout=timeout or 30, headers=headers).get_res(url)
|
||||
if not ret:
|
||||
logger.error(f"获取RSS失败:请求返回空值,URL: {url}")
|
||||
return False
|
||||
@@ -384,6 +386,9 @@ class RssHelper:
|
||||
pubdate = ""
|
||||
if pubdate_nodes and pubdate_nodes[0].text:
|
||||
pubdate = StringUtils.get_time(pubdate_nodes[0].text)
|
||||
if pubdate is not None:
|
||||
# 转为本地时区
|
||||
pubdate = pubdate.astimezone(tz=None)
|
||||
|
||||
# 获取豆瓣昵称
|
||||
nickname_nodes = item.xpath('.//*[local-name()="creator"]')
|
||||
|
||||
@@ -47,7 +47,7 @@ class StorageHelper:
|
||||
if s.type == storage:
|
||||
s.config = conf
|
||||
break
|
||||
SystemConfigOper().set(SystemConfigKey.Storages, [s.dict() for s in storagies])
|
||||
SystemConfigOper().set(SystemConfigKey.Storages, [s.model_dump() for s in storagies])
|
||||
|
||||
def add_storage(self, storage: str, name: str, conf: dict):
|
||||
"""
|
||||
@@ -68,7 +68,7 @@ class StorageHelper:
|
||||
name=name,
|
||||
config=conf
|
||||
))
|
||||
SystemConfigOper().set(SystemConfigKey.Storages, [s.dict() for s in storagies])
|
||||
SystemConfigOper().set(SystemConfigKey.Storages, [s.model_dump() for s in storagies])
|
||||
|
||||
def reset_storage(self, storage: str):
|
||||
"""
|
||||
@@ -79,4 +79,4 @@ class StorageHelper:
|
||||
if s.type == storage:
|
||||
s.config = {}
|
||||
break
|
||||
SystemConfigOper().set(SystemConfigKey.Storages, [s.dict() for s in storagies])
|
||||
SystemConfigOper().set(SystemConfigKey.Storages, [s.model_dump() for s in storagies])
|
||||
|
||||
@@ -131,7 +131,9 @@ class SubscribeHelper(metaclass=WeakSingleton):
|
||||
return []
|
||||
|
||||
@cached(region=_shares_cache_region, maxsize=5, ttl=1800, skip_empty=True)
|
||||
def get_statistic(self, stype: str, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
|
||||
def get_statistic(self, stype: str, page: Optional[int] = 1, count: Optional[int] = 30,
|
||||
genre_id: Optional[int] = None, min_rating: Optional[float] = None,
|
||||
max_rating: Optional[float] = None, sort_type: Optional[str] = None) -> List[dict]:
|
||||
"""
|
||||
获取订阅统计数据
|
||||
"""
|
||||
@@ -139,16 +141,30 @@ class SubscribeHelper(metaclass=WeakSingleton):
|
||||
if not enabled:
|
||||
return []
|
||||
|
||||
res = RequestUtils(proxies=settings.PROXY, timeout=15).get_res(self._sub_statistic, params={
|
||||
params = {
|
||||
"stype": stype,
|
||||
"page": page,
|
||||
"count": count
|
||||
})
|
||||
}
|
||||
|
||||
# 添加可选参数
|
||||
if genre_id is not None:
|
||||
params["genre_id"] = genre_id
|
||||
if min_rating is not None:
|
||||
params["min_rating"] = min_rating
|
||||
if max_rating is not None:
|
||||
params["max_rating"] = max_rating
|
||||
if sort_type is not None:
|
||||
params["sort_type"] = sort_type
|
||||
|
||||
res = RequestUtils(proxies=settings.PROXY, timeout=15).get_res(self._sub_statistic, params=params)
|
||||
|
||||
return self._handle_list_response(res)
|
||||
|
||||
@cached(region=_shares_cache_region, maxsize=5, ttl=1800, skip_empty=True)
|
||||
async def async_get_statistic(self, stype: str, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
|
||||
async def async_get_statistic(self, stype: str, page: Optional[int] = 1, count: Optional[int] = 30,
|
||||
genre_id: Optional[int] = None, min_rating: Optional[float] = None,
|
||||
max_rating: Optional[float] = None, sort_type: Optional[str] = None) -> List[dict]:
|
||||
"""
|
||||
异步获取订阅统计数据
|
||||
"""
|
||||
@@ -156,11 +172,23 @@ class SubscribeHelper(metaclass=WeakSingleton):
|
||||
if not enabled:
|
||||
return []
|
||||
|
||||
res = await AsyncRequestUtils(proxies=settings.PROXY, timeout=15).get_res(self._sub_statistic, params={
|
||||
params = {
|
||||
"stype": stype,
|
||||
"page": page,
|
||||
"count": count
|
||||
})
|
||||
}
|
||||
|
||||
# 添加可选参数
|
||||
if genre_id is not None:
|
||||
params["genre_id"] = genre_id
|
||||
if min_rating is not None:
|
||||
params["min_rating"] = min_rating
|
||||
if max_rating is not None:
|
||||
params["max_rating"] = max_rating
|
||||
if sort_type is not None:
|
||||
params["sort_type"] = sort_type
|
||||
|
||||
res = await AsyncRequestUtils(proxies=settings.PROXY, timeout=15).get_res(self._sub_statistic, params=params)
|
||||
|
||||
return self._handle_list_response(res)
|
||||
|
||||
@@ -358,7 +386,9 @@ class SubscribeHelper(metaclass=WeakSingleton):
|
||||
return self._handle_response(res, clear_cache=False)
|
||||
|
||||
@cached(region=_shares_cache_region, maxsize=1, ttl=1800, skip_empty=True)
|
||||
def get_shares(self, name: Optional[str] = None, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
|
||||
def get_shares(self, name: Optional[str] = None, page: Optional[int] = 1, count: Optional[int] = 30,
|
||||
genre_id: Optional[int] = None, min_rating: Optional[float] = None,
|
||||
max_rating: Optional[float] = None, sort_type: Optional[str] = None) -> List[dict]:
|
||||
"""
|
||||
获取订阅分享数据
|
||||
"""
|
||||
@@ -366,17 +396,30 @@ class SubscribeHelper(metaclass=WeakSingleton):
|
||||
if not enabled:
|
||||
return []
|
||||
|
||||
res = RequestUtils(proxies=settings.PROXY, timeout=15).get_res(self._sub_shares, params={
|
||||
params = {
|
||||
"name": name,
|
||||
"page": page,
|
||||
"count": count
|
||||
})
|
||||
}
|
||||
|
||||
# 添加可选参数
|
||||
if genre_id is not None:
|
||||
params["genre_id"] = genre_id
|
||||
if min_rating is not None:
|
||||
params["min_rating"] = min_rating
|
||||
if max_rating is not None:
|
||||
params["max_rating"] = max_rating
|
||||
if sort_type is not None:
|
||||
params["sort_type"] = sort_type
|
||||
|
||||
res = RequestUtils(proxies=settings.PROXY, timeout=15).get_res(self._sub_shares, params=params)
|
||||
|
||||
return self._handle_list_response(res)
|
||||
|
||||
@cached(region=_shares_cache_region, maxsize=1, ttl=1800, skip_empty=True)
|
||||
async def async_get_shares(self, name: Optional[str] = None, page: Optional[int] = 1, count: Optional[int] = 30) -> \
|
||||
List[dict]:
|
||||
async def async_get_shares(self, name: Optional[str] = None, page: Optional[int] = 1, count: Optional[int] = 30,
|
||||
genre_id: Optional[int] = None, min_rating: Optional[float] = None,
|
||||
max_rating: Optional[float] = None, sort_type: Optional[str] = None) -> List[dict]:
|
||||
"""
|
||||
异步获取订阅分享数据
|
||||
"""
|
||||
@@ -384,11 +427,23 @@ class SubscribeHelper(metaclass=WeakSingleton):
|
||||
if not enabled:
|
||||
return []
|
||||
|
||||
res = await AsyncRequestUtils(proxies=settings.PROXY, timeout=15).get_res(self._sub_shares, params={
|
||||
params = {
|
||||
"name": name,
|
||||
"page": page,
|
||||
"count": count
|
||||
})
|
||||
}
|
||||
|
||||
# 添加可选参数
|
||||
if genre_id is not None:
|
||||
params["genre_id"] = genre_id
|
||||
if min_rating is not None:
|
||||
params["min_rating"] = min_rating
|
||||
if max_rating is not None:
|
||||
params["max_rating"] = max_rating
|
||||
if sort_type is not None:
|
||||
params["sort_type"] = sort_type
|
||||
|
||||
res = await AsyncRequestUtils(proxies=settings.PROXY, timeout=15).get_res(self._sub_shares, params=params)
|
||||
|
||||
return self._handle_list_response(res)
|
||||
|
||||
|
||||
15
app/log.py
15
app/log.py
@@ -11,7 +11,8 @@ from pathlib import Path
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
import click
|
||||
from pydantic import BaseSettings, BaseModel
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
from app.utils.system import SystemUtils
|
||||
|
||||
@@ -21,8 +22,7 @@ class LogConfigModel(BaseModel):
|
||||
Pydantic 配置模型,描述所有配置项及其类型和默认值
|
||||
"""
|
||||
|
||||
class Config:
|
||||
extra = "ignore" # 忽略未定义的配置项
|
||||
model_config = ConfigDict(extra="ignore") # 忽略未定义的配置项
|
||||
|
||||
# 配置文件目录
|
||||
CONFIG_DIR: Optional[str] = None
|
||||
@@ -71,10 +71,11 @@ class LogSettings(BaseSettings, LogConfigModel):
|
||||
"""
|
||||
return self.LOG_MAX_FILE_SIZE * 1024 * 1024
|
||||
|
||||
class Config:
|
||||
case_sensitive = True
|
||||
env_file = SystemUtils.get_env_path()
|
||||
env_file_encoding = "utf-8"
|
||||
model_config = ConfigDict(
|
||||
case_sensitive=True,
|
||||
env_file=SystemUtils.get_env_path(),
|
||||
env_file_encoding="utf-8"
|
||||
)
|
||||
|
||||
|
||||
# 实例化日志设置
|
||||
|
||||
@@ -95,4 +95,4 @@ if __name__ == '__main__':
|
||||
# 更新数据库
|
||||
update_db()
|
||||
# 启动API服务
|
||||
Server.run()
|
||||
Server.run()
|
||||
@@ -232,6 +232,19 @@ class _DownloaderBase(ServiceBase[TService, DownloaderConf]):
|
||||
super().__init__()
|
||||
self._default_config_name: Optional[str] = None
|
||||
|
||||
def init_service(self, service_name: str,
|
||||
service_type: Optional[Union[Type[TService], Callable[..., TService]]] = None):
|
||||
"""
|
||||
初始化服务,获取配置并实例化对应服务
|
||||
|
||||
:param service_name: 服务名称,作为配置匹配的依据
|
||||
:param service_type: 服务的类型,可以是类类型(Type[TService])、工厂函数(Callable)或 None 来跳过实例化
|
||||
"""
|
||||
# 重置默认配置名称
|
||||
self.reset_default_config_name()
|
||||
# 初始化服务
|
||||
super().init_service(service_name, service_type)
|
||||
|
||||
def get_default_config_name(self) -> Optional[str]:
|
||||
"""
|
||||
获取默认服务配置的名称
|
||||
@@ -263,6 +276,12 @@ class _DownloaderBase(ServiceBase[TService, DownloaderConf]):
|
||||
return {}
|
||||
return {conf.name: conf for conf in configs if conf.type == self._service_name and conf.enabled}
|
||||
|
||||
def reset_default_config_name(self):
|
||||
"""
|
||||
重置默认配置名称
|
||||
"""
|
||||
self._default_config_name = None
|
||||
|
||||
|
||||
class _MediaServerBase(ServiceBase[TService, MediaServerConf]):
|
||||
"""
|
||||
|
||||
@@ -984,11 +984,13 @@ class DoubanModule(_ModuleBase):
|
||||
"""
|
||||
if result:
|
||||
doubanid = result.get("id")
|
||||
if doubanid and not str(doubanid).isdigit():
|
||||
doubanid = re.search(r"\d+", doubanid).group(0)
|
||||
result["id"] = doubanid
|
||||
logger.info(f"{imdbid} 查询到豆瓣信息:{result.get('title')}")
|
||||
return result
|
||||
if doubanid:
|
||||
if not str(doubanid).isdigit():
|
||||
doubanid = re.search(r"\d+", doubanid).group(0)
|
||||
result["id"] = doubanid
|
||||
logger.info(f"{imdbid} 查询到豆瓣信息:{result.get('title')}")
|
||||
return result
|
||||
return None
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -10,10 +10,10 @@ from requests import Response
|
||||
from app import schemas
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
from app.schemas import MediaServerItem
|
||||
from app.schemas.types import MediaType
|
||||
from app.utils.http import RequestUtils
|
||||
from app.utils.url import UrlUtils
|
||||
from app.schemas import MediaServerItem
|
||||
|
||||
|
||||
class Emby:
|
||||
@@ -22,9 +22,10 @@ class Emby:
|
||||
_apikey: Optional[str] = None
|
||||
_sync_libraries: List[str] = []
|
||||
user: Optional[Union[str, int]] = None
|
||||
_username: Optional[str] = None
|
||||
|
||||
def __init__(self, host: Optional[str] = None, apikey: Optional[str] = None, play_host: Optional[str] = None,
|
||||
sync_libraries: list = None, **kwargs):
|
||||
username: Optional[str] = None, sync_libraries: list = None, **kwargs):
|
||||
if not host or not apikey:
|
||||
logger.error("Emby服务器配置不完整!")
|
||||
return
|
||||
@@ -35,7 +36,8 @@ class Emby:
|
||||
if self._playhost:
|
||||
self._playhost = UrlUtils.standardize_base_url(self._playhost)
|
||||
self._apikey = apikey
|
||||
self.user = self.get_user(settings.SUPERUSER)
|
||||
self._username = username
|
||||
self.user = self.get_user(username or settings.SUPERUSER)
|
||||
self.folders = self.get_emby_folders()
|
||||
self.serverid = self.get_server_id()
|
||||
self._sync_libraries = sync_libraries or []
|
||||
@@ -139,7 +141,8 @@ class Emby:
|
||||
logger.error(f"连接User/Views 出错:" + str(e))
|
||||
return []
|
||||
|
||||
def get_librarys(self, username: Optional[str] = None, hidden: Optional[bool] = False) -> List[schemas.MediaServerLibrary]:
|
||||
def get_librarys(self, username: Optional[str] = None, hidden: Optional[bool] = False) -> List[
|
||||
schemas.MediaServerLibrary]:
|
||||
"""
|
||||
获取媒体服务器所有媒体库列表
|
||||
"""
|
||||
@@ -567,6 +570,7 @@ class Emby:
|
||||
if library_id != "/":
|
||||
return self.__refresh_emby_library_by_id(library_id)
|
||||
logger.info(f"Emby媒体库刷新完成")
|
||||
return True
|
||||
|
||||
def __get_emby_library_id_by_item(self, item: schemas.RefreshMediaItem) -> Optional[str]:
|
||||
"""
|
||||
@@ -636,7 +640,7 @@ class Emby:
|
||||
item_type=item.get("Type"),
|
||||
title=item.get("Name"),
|
||||
original_title=item.get("OriginalTitle"),
|
||||
year=item.get("ProductionYear"),
|
||||
year=str(item.get("ProductionYear")),
|
||||
tmdbid=int(tmdbid) if tmdbid else None,
|
||||
imdbid=item.get("ProviderIds", {}).get("Imdb"),
|
||||
tvdbid=item.get("ProviderIds", {}).get("Tvdb"),
|
||||
@@ -706,9 +710,9 @@ class Emby:
|
||||
yield items
|
||||
elif item.get("Type") in ["Movie", "Series"]:
|
||||
yield self.__format_item_info(item)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"连接Users/Items出错:" + str(e))
|
||||
return None
|
||||
|
||||
def get_webhook_message(self, form: any, args: dict) -> Optional[schemas.WebhookEventInfo]:
|
||||
"""
|
||||
@@ -1109,7 +1113,8 @@ class Emby:
|
||||
return ""
|
||||
return "%sItems/%s/Images/Primary" % (self._host, item_id)
|
||||
|
||||
def get_resume(self, num: Optional[int] = 12, username: Optional[str] = None) -> Optional[List[schemas.MediaServerPlayItem]]:
|
||||
def get_resume(self, num: Optional[int] = 12, username: Optional[str] = None) -> Optional[
|
||||
List[schemas.MediaServerPlayItem]]:
|
||||
"""
|
||||
获得继续观看
|
||||
"""
|
||||
@@ -1146,7 +1151,7 @@ class Emby:
|
||||
link = self.get_play_url(item.get("Id"))
|
||||
if item_type == MediaType.MOVIE.value:
|
||||
title = item.get("Name")
|
||||
subtitle = item.get("ProductionYear")
|
||||
subtitle = str(item.get("ProductionYear")) if item.get("ProductionYear") else None
|
||||
else:
|
||||
title = f'{item.get("SeriesName")}'
|
||||
subtitle = f'S{item.get("ParentIndexNumber")}:{item.get("IndexNumber")} - {item.get("Name")}'
|
||||
@@ -1178,7 +1183,8 @@ class Emby:
|
||||
logger.error(f"连接Users/Items/Resume出错:" + str(e))
|
||||
return []
|
||||
|
||||
def get_latest(self, num: Optional[int] = 20, username: Optional[str] = None) -> Optional[List[schemas.MediaServerPlayItem]]:
|
||||
def get_latest(self, num: Optional[int] = 20, username: Optional[str] = None) -> Optional[
|
||||
List[schemas.MediaServerPlayItem]]:
|
||||
"""
|
||||
获得最近更新
|
||||
"""
|
||||
@@ -1217,7 +1223,7 @@ class Emby:
|
||||
ret_latest.append(schemas.MediaServerPlayItem(
|
||||
id=item.get("Id"),
|
||||
title=item.get("Name"),
|
||||
subtitle=item.get("ProductionYear"),
|
||||
subtitle=str(item.get("ProductionYear")) if item.get("ProductionYear") else None,
|
||||
type=item_type,
|
||||
image=image,
|
||||
link=link,
|
||||
|
||||
@@ -15,7 +15,7 @@ def transfer_process(path: str) -> Callable[[int | float], None]:
|
||||
"""
|
||||
传输进度回调
|
||||
"""
|
||||
pbar = tqdm(total=100, desc="整理进度", unit="%")
|
||||
pbar = tqdm(total=100, desc="进度", unit="%")
|
||||
progress = ProgressHelper(HashUtils.md5(path))
|
||||
progress.start()
|
||||
|
||||
@@ -23,7 +23,7 @@ def transfer_process(path: str) -> Callable[[int | float], None]:
|
||||
"""
|
||||
更新进度百分比
|
||||
"""
|
||||
percent_value = int(percent)
|
||||
percent_value = round(percent, 2) if isinstance(percent, float) else percent
|
||||
pbar.n = percent_value
|
||||
# 更新进度
|
||||
pbar.refresh()
|
||||
|
||||
@@ -14,6 +14,7 @@ from app.log import logger
|
||||
from app.modules.filemanager import StorageBase
|
||||
from app.modules.filemanager.storages import transfer_process
|
||||
from app.schemas.types import StorageSchema
|
||||
from app.utils.http import RequestUtils
|
||||
from app.utils.singleton import WeakSingleton
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
@@ -251,6 +252,9 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
# 检查会话
|
||||
self._check_session()
|
||||
|
||||
# 错误日志控制
|
||||
no_error_log = kwargs.pop("no_error_log", False)
|
||||
|
||||
try:
|
||||
resp = self.session.request(
|
||||
method, f"{self.base_url}{endpoint}",
|
||||
@@ -273,7 +277,8 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
# 返回数据
|
||||
ret_data = resp.json()
|
||||
if ret_data.get("code"):
|
||||
logger.warn(f"【阿里云盘】{method} {endpoint} 返回:{ret_data.get('code')} {ret_data.get('message')}")
|
||||
if not no_error_log:
|
||||
logger.warn(f"【阿里云盘】{method} {endpoint} 返回:{ret_data.get('code')} {ret_data.get('message')}")
|
||||
|
||||
if result_key:
|
||||
return ret_data.get(result_key)
|
||||
@@ -597,7 +602,7 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
file_size = local_path.stat().st_size
|
||||
|
||||
# 1. 创建文件并检查秒传
|
||||
chunk_size = 100 * 1024 * 1024 # 分片大小 100M
|
||||
chunk_size = 10 * 1024 * 1024 # 分片大小 10M
|
||||
create_res = self._create_file(drive_id=target_dir.drive_id,
|
||||
parent_file_id=target_dir.fileid,
|
||||
file_name=target_name,
|
||||
@@ -729,7 +734,25 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
progress_callback = transfer_process(Path(fileitem.path).as_posix())
|
||||
|
||||
try:
|
||||
with requests.get(download_url, stream=True) as r:
|
||||
# 构建请求头,包含必要的认证信息
|
||||
headers = {
|
||||
"User-Agent": settings.NORMAL_USER_AGENT,
|
||||
"Referer": "https://www.aliyundrive.com/",
|
||||
"Accept": "*/*",
|
||||
"Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8",
|
||||
"Accept-Encoding": "gzip, deflate, br",
|
||||
"Connection": "keep-alive",
|
||||
"Sec-Fetch-Dest": "empty",
|
||||
"Sec-Fetch-Mode": "cors",
|
||||
"Sec-Fetch-Site": "cross-site"
|
||||
}
|
||||
|
||||
# 如果有access_token,添加到请求头
|
||||
if self.access_token:
|
||||
headers["Authorization"] = f"Bearer {self.access_token}"
|
||||
|
||||
request_utils = RequestUtils(headers=headers)
|
||||
with request_utils.get_stream(download_url, raise_exception=True) as r:
|
||||
r.raise_for_status()
|
||||
downloaded_size = 0
|
||||
with open(local_path, "wb") as f:
|
||||
@@ -748,22 +771,13 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
# 完成下载
|
||||
progress_callback(100)
|
||||
logger.info(f"【阿里云盘】下载完成: {fileitem.name}")
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"【阿里云盘】下载网络错误: {fileitem.name} - {str(e)}")
|
||||
# 删除可能部分下载的文件
|
||||
if local_path.exists():
|
||||
local_path.unlink()
|
||||
return None
|
||||
return local_path
|
||||
except Exception as e:
|
||||
logger.error(f"【阿里云盘】下载失败: {fileitem.name} - {str(e)}")
|
||||
# 删除可能部分下载的文件
|
||||
if local_path.exists():
|
||||
local_path.unlink()
|
||||
return None
|
||||
|
||||
return local_path
|
||||
|
||||
def check(self) -> bool:
|
||||
return self.access_token is not None
|
||||
|
||||
@@ -815,7 +829,8 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
json={
|
||||
"drive_id": drive_id or self._default_drive_id,
|
||||
"file_path": path.as_posix()
|
||||
}
|
||||
},
|
||||
no_error_log=True
|
||||
)
|
||||
if not resp:
|
||||
return None
|
||||
|
||||
@@ -4,8 +4,6 @@ from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
|
||||
import requests
|
||||
|
||||
from app import schemas
|
||||
from app.core.cache import cached
|
||||
from app.core.config import settings, global_vars
|
||||
@@ -569,18 +567,22 @@ class Alist(StorageBase, metaclass=WeakSingleton):
|
||||
else:
|
||||
local_path = path / fileitem.name
|
||||
|
||||
with requests.get(download_url, headers=self.__get_header_with_token(), stream=True) as r:
|
||||
r.raise_for_status()
|
||||
with open(local_path, "wb") as f:
|
||||
for chunk in r.iter_content(chunk_size=8192):
|
||||
if global_vars.is_transfer_stopped(fileitem.path):
|
||||
logger.info(f"【OpenList】{fileitem.path} 下载已取消!")
|
||||
return None
|
||||
f.write(chunk)
|
||||
request_utils = RequestUtils(headers=self.__get_header_with_token())
|
||||
try:
|
||||
with request_utils.get_stream(download_url, raise_exception=True) as r:
|
||||
r.raise_for_status()
|
||||
with open(local_path, "wb") as f:
|
||||
for chunk in r.iter_content(chunk_size=8192):
|
||||
if global_vars.is_transfer_stopped(fileitem.path):
|
||||
logger.info(f"【OpenList】{fileitem.path} 下载已取消!")
|
||||
return None
|
||||
f.write(chunk)
|
||||
except Exception as e:
|
||||
logger.error(f"【OpenList】下载文件 {fileitem.path} 失败:{e}")
|
||||
if local_path.exists():
|
||||
return local_path
|
||||
|
||||
if local_path.exists():
|
||||
return local_path
|
||||
return None
|
||||
return local_path
|
||||
|
||||
def upload(
|
||||
self, fileitem: schemas.FileItem, path: Path, new_name: Optional[str] = None, task: bool = False
|
||||
|
||||
@@ -26,8 +26,8 @@ class LocalStorage(StorageBase):
|
||||
"softlink": "软链接"
|
||||
}
|
||||
|
||||
# 文件块大小,默认100MB
|
||||
chunk_size = 100 * 1024 * 1024
|
||||
# 文件块大小,默认10MB
|
||||
chunk_size = 10 * 1024 * 1024
|
||||
|
||||
def init_storage(self):
|
||||
"""
|
||||
|
||||
@@ -39,8 +39,8 @@ class SMB(StorageBase, metaclass=WeakSingleton):
|
||||
"copy": "复制",
|
||||
}
|
||||
|
||||
# 文件块大小,默认100MB
|
||||
chunk_size = 100 * 1024 * 1024
|
||||
# 文件块大小,默认10MB
|
||||
chunk_size = 10 * 1024 * 1024
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -49,6 +49,7 @@ class SMB(StorageBase, metaclass=WeakSingleton):
|
||||
self._host = None
|
||||
self._username = None
|
||||
self._password = None
|
||||
|
||||
self._init_connection()
|
||||
|
||||
def _init_connection(self):
|
||||
@@ -380,19 +381,95 @@ class SMB(StorageBase, metaclass=WeakSingleton):
|
||||
self._check_connection()
|
||||
|
||||
smb_path = self._normalize_path(fileitem.path.rstrip("/"))
|
||||
logger.info(f"【SMB】开始删除: {fileitem.path} (类型: {fileitem.type})")
|
||||
|
||||
# 先检查路径是否存在
|
||||
if not smbclient.path.exists(smb_path):
|
||||
logger.warn(f"【SMB】路径不存在,跳过删除: {fileitem.path}")
|
||||
return True
|
||||
|
||||
if fileitem.type == "dir":
|
||||
# 删除目录
|
||||
smbclient.rmdir(smb_path)
|
||||
# 递归删除目录及其内容
|
||||
logger.debug(f"【SMB】递归删除目录: {smb_path}")
|
||||
self._recursive_delete(smb_path)
|
||||
else:
|
||||
# 删除文件
|
||||
logger.debug(f"【SMB】删除文件: {smb_path}")
|
||||
smbclient.remove(smb_path)
|
||||
|
||||
logger.info(f"【SMB】删除成功: {fileitem.path}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"【SMB】删除失败: {e}")
|
||||
except SMBConnectionError as e:
|
||||
logger.error(f"【SMB】删除失败 - 连接错误: {fileitem.path} - {e}")
|
||||
return False
|
||||
except SMBResponseException as e:
|
||||
logger.error(f"【SMB】删除失败 - SMB响应错误: {fileitem.path} - {e}")
|
||||
return False
|
||||
except SMBException as e:
|
||||
logger.error(f"【SMB】删除失败 - SMB错误: {fileitem.path} - {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"【SMB】删除失败 - 未知错误: {fileitem.path} - {e}")
|
||||
return False
|
||||
|
||||
def _recursive_delete(self, smb_path: str):
|
||||
"""
|
||||
递归删除目录及其所有内容
|
||||
"""
|
||||
try:
|
||||
# 检查路径是否存在
|
||||
if not smbclient.path.exists(smb_path):
|
||||
logger.debug(f"【SMB】路径不存在,跳过删除: {smb_path}")
|
||||
return
|
||||
|
||||
# 如果是文件,直接删除
|
||||
if smbclient.path.isfile(smb_path):
|
||||
logger.debug(f"【SMB】删除文件: {smb_path}")
|
||||
smbclient.remove(smb_path)
|
||||
return
|
||||
|
||||
# 如果是目录,先删除其内容
|
||||
if smbclient.path.isdir(smb_path):
|
||||
logger.debug(f"【SMB】开始删除目录内容: {smb_path}")
|
||||
try:
|
||||
# 列出目录内容
|
||||
entries = smbclient.listdir(smb_path)
|
||||
logger.debug(f"【SMB】目录 {smb_path} 包含 {len(entries)} 个项目")
|
||||
|
||||
for entry in entries:
|
||||
if entry in [".", ".."]:
|
||||
continue
|
||||
entry_path = f"{smb_path}\\{entry}"
|
||||
logger.debug(f"【SMB】递归删除子项: {entry_path}")
|
||||
# 递归删除子项
|
||||
self._recursive_delete(entry_path)
|
||||
|
||||
# 删除空目录
|
||||
logger.debug(f"【SMB】删除空目录: {smb_path}")
|
||||
smbclient.rmdir(smb_path)
|
||||
logger.debug(f"【SMB】目录删除成功: {smb_path}")
|
||||
|
||||
except SMBResponseException as e:
|
||||
# 如果目录不为空,尝试强制删除
|
||||
logger.warn(f"【SMB】目录不为空,尝试强制删除: {smb_path} - {e}")
|
||||
# 使用remove方法尝试删除(某些SMB服务器支持)
|
||||
try:
|
||||
smbclient.remove(smb_path)
|
||||
logger.info(f"【SMB】强制删除目录成功: {smb_path}")
|
||||
except Exception as remove_error:
|
||||
# 如果还是失败,记录错误并抛出异常
|
||||
logger.error(f"【SMB】无法删除非空目录: {smb_path} - {remove_error}")
|
||||
raise SMBConnectionError(f"无法删除非空目录 {smb_path}: {remove_error}")
|
||||
except SMBException as e:
|
||||
logger.error(f"【SMB】SMB操作失败: {smb_path} - {e}")
|
||||
raise SMBConnectionError(f"SMB操作失败 {smb_path}: {e}")
|
||||
|
||||
except SMBConnectionError:
|
||||
# 重新抛出SMB连接错误
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"【SMB】递归删除失败: {smb_path} - {e}")
|
||||
raise SMBConnectionError(f"递归删除失败 {smb_path}: {e}")
|
||||
|
||||
def rename(self, fileitem: schemas.FileItem, name: str) -> bool:
|
||||
"""
|
||||
@@ -584,8 +661,7 @@ class SMB(StorageBase, metaclass=WeakSingleton):
|
||||
析构函数,清理连接
|
||||
"""
|
||||
try:
|
||||
# smbclient 自动管理连接池,但我们可以重置缓存
|
||||
if hasattr(self, '_connected') and self._connected:
|
||||
if self._connected:
|
||||
reset_connection_cache()
|
||||
except Exception as e:
|
||||
logger.debug(f"【SMB】清理连接失败: {e}")
|
||||
|
||||
@@ -91,6 +91,8 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
"refresh_time": int(time.time()),
|
||||
**tokens
|
||||
})
|
||||
else:
|
||||
return None
|
||||
access_token = tokens.get("access_token")
|
||||
if access_token:
|
||||
self.session.headers.update({"Authorization": f"Bearer {access_token}"})
|
||||
@@ -209,6 +211,11 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
# 检查会话
|
||||
self._check_session()
|
||||
|
||||
# 错误日志标志
|
||||
no_error_log = kwargs.pop("no_error_log", False)
|
||||
# 重试次数
|
||||
retry_times = kwargs.pop("retry_limit", 5)
|
||||
|
||||
try:
|
||||
resp = self.session.request(
|
||||
method, f"{self.base_url}{endpoint}",
|
||||
@@ -222,6 +229,8 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
logger.warn(f"【115】{method} 请求 {endpoint} 失败!")
|
||||
return None
|
||||
|
||||
kwargs["retry_limit"] = retry_times
|
||||
|
||||
# 处理速率限制
|
||||
if resp.status_code == 429:
|
||||
reset_time = 5 + int(resp.headers.get("X-RateLimit-Reset", 60))
|
||||
@@ -238,8 +247,8 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
ret_data = resp.json()
|
||||
if ret_data.get("code") != 0:
|
||||
error_msg = ret_data.get("message")
|
||||
logger.warn(f"【115】{method} 请求 {endpoint} 出错:{error_msg}!")
|
||||
retry_times = kwargs.get("retry_limit", 5)
|
||||
if not no_error_log:
|
||||
logger.warn(f"【115】{method} 请求 {endpoint} 出错:{error_msg}")
|
||||
if "已达到当前访问上限" in error_msg:
|
||||
if retry_times <= 0:
|
||||
logger.error(f"【115】{method} 请求 {endpoint} 达到访问上限,重试次数用尽!")
|
||||
@@ -536,8 +545,8 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
security_token=SecurityToken
|
||||
)
|
||||
bucket = oss2.Bucket(auth, endpoint, bucket_name) # noqa
|
||||
# determine_part_size方法用于确定分片大小,设置分片大小为 100M
|
||||
part_size = determine_part_size(file_size, preferred_size=100 * 1024 * 1024)
|
||||
# determine_part_size方法用于确定分片大小,设置分片大小为 10M
|
||||
part_size = determine_part_size(file_size, preferred_size=10 * 1024 * 1024)
|
||||
|
||||
# 初始化进度条
|
||||
logger.info(f"【115】开始上传: {local_path} -> {target_path},分片大小:{StringUtils.str_filesize(part_size)}")
|
||||
@@ -718,7 +727,8 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
"data",
|
||||
data={
|
||||
"path": path.as_posix()
|
||||
}
|
||||
},
|
||||
no_error_log=True
|
||||
)
|
||||
if not resp:
|
||||
return None
|
||||
|
||||
@@ -14,10 +14,10 @@ from app.helper.directory import DirectoryHelper
|
||||
from app.helper.message import TemplateHelper
|
||||
from app.log import logger
|
||||
from app.modules.filemanager.storages import StorageBase
|
||||
from app.schemas import TransferInfo, TmdbEpisode, TransferDirectoryConf, FileItem, TransferInterceptEventData
|
||||
from app.schemas import TransferInfo, TmdbEpisode, TransferDirectoryConf, FileItem, TransferInterceptEventData, \
|
||||
TransferRenameEventData
|
||||
from app.schemas.types import MediaType, ChainEventType
|
||||
from app.utils.system import SystemUtils
|
||||
from app.schemas import TransferRenameEventData
|
||||
|
||||
lock = Lock()
|
||||
|
||||
@@ -129,7 +129,7 @@ class TransHandler:
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify,
|
||||
)
|
||||
return self.result.copy()
|
||||
return self.result.model_copy()
|
||||
else:
|
||||
new_path = target_path / fileitem.name
|
||||
# 整理目录
|
||||
@@ -147,21 +147,18 @@ class TransHandler:
|
||||
fileitem=fileitem,
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify)
|
||||
return self.result.copy()
|
||||
return self.result.model_copy()
|
||||
|
||||
logger.info(f"文件夹 {fileitem.path} 整理成功")
|
||||
# 计算目录下所有文件大小
|
||||
total_size = sum(file.stat().st_size for file in Path(fileitem.path).rglob('*') if file.is_file())
|
||||
# 返回整理后的路径
|
||||
self.__set_result(success=True,
|
||||
fileitem=fileitem,
|
||||
target_item=new_diritem,
|
||||
target_diritem=new_diritem,
|
||||
total_size=total_size,
|
||||
need_scrape=need_scrape,
|
||||
need_notify=need_notify,
|
||||
transfer_type=transfer_type)
|
||||
return self.result.copy()
|
||||
return self.result.model_copy()
|
||||
else:
|
||||
# 整理单个文件
|
||||
if mediainfo.type == MediaType.TV:
|
||||
@@ -174,7 +171,7 @@ class TransHandler:
|
||||
fail_list=[fileitem.path],
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify)
|
||||
return self.result.copy()
|
||||
return self.result.model_copy()
|
||||
|
||||
# 文件结束季为空
|
||||
in_meta.end_season = None
|
||||
@@ -210,7 +207,7 @@ class TransHandler:
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify,
|
||||
)
|
||||
return self.result.copy()
|
||||
return self.result.model_copy()
|
||||
else:
|
||||
new_file = target_path / fileitem.name
|
||||
folder_path = target_path
|
||||
@@ -227,7 +224,7 @@ class TransHandler:
|
||||
fail_list=[fileitem.path],
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify)
|
||||
return self.result.copy()
|
||||
return self.result.model_copy()
|
||||
# 目标文件
|
||||
target_item = target_oper.get_item(new_file)
|
||||
if target_item:
|
||||
@@ -239,7 +236,8 @@ class TransHandler:
|
||||
overflag = True
|
||||
if not overflag:
|
||||
# 目标文件已存在
|
||||
logger.info(f"目的文件系统中已经存在同名文件 {target_file},当前整理覆盖模式设置为 {overwrite_mode}")
|
||||
logger.info(
|
||||
f"目的文件系统中已经存在同名文件 {target_file},当前整理覆盖模式设置为 {overwrite_mode}")
|
||||
if overwrite_mode == 'always':
|
||||
# 总是覆盖同名文件
|
||||
overflag = True
|
||||
@@ -257,7 +255,7 @@ class TransHandler:
|
||||
fail_list=[fileitem.path],
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify)
|
||||
return self.result.copy()
|
||||
return self.result.model_copy()
|
||||
elif overwrite_mode == 'never':
|
||||
# 存在不覆盖
|
||||
self.__set_result(success=False,
|
||||
@@ -268,7 +266,7 @@ class TransHandler:
|
||||
fail_list=[fileitem.path],
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify)
|
||||
return self.result.copy()
|
||||
return self.result.model_copy()
|
||||
elif overwrite_mode == 'latest':
|
||||
# 仅保留最新版本
|
||||
logger.info(f"当前整理覆盖模式设置为仅保留最新版本,将覆盖:{new_file}")
|
||||
@@ -295,7 +293,7 @@ class TransHandler:
|
||||
fail_list=[fileitem.path],
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify)
|
||||
return self.result.copy()
|
||||
return self.result.model_copy()
|
||||
|
||||
logger.info(f"文件 {fileitem.path} 整理成功")
|
||||
self.__set_result(success=True,
|
||||
@@ -305,7 +303,7 @@ class TransHandler:
|
||||
need_scrape=need_scrape,
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify)
|
||||
return self.result.copy()
|
||||
return self.result.model_copy()
|
||||
finally:
|
||||
self.result = None
|
||||
|
||||
@@ -424,7 +422,7 @@ class TransHandler:
|
||||
# 复制文件到新目录
|
||||
target_fileitem = target_oper.get_folder(target_file.parent)
|
||||
if target_fileitem:
|
||||
if source_oper.move(fileitem, Path(target_fileitem.path), target_file.name):
|
||||
if source_oper.copy(fileitem, Path(target_fileitem.path), target_file.name):
|
||||
return target_oper.get_item(target_file), ""
|
||||
else:
|
||||
return None, f"【{target_storage}】{fileitem.path} 复制文件失败"
|
||||
|
||||
@@ -154,7 +154,7 @@ class FilterModule(_ModuleBase):
|
||||
custom_rules = self.rulehelper.get_custom_rules()
|
||||
for rule in custom_rules:
|
||||
logger.info(f"加载自定义规则 {rule.id} - {rule.name}")
|
||||
self.rule_set[rule.id] = rule.dict()
|
||||
self.rule_set[rule.id] = rule.model_dump()
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
|
||||
@@ -33,6 +33,8 @@ class SiteSchema(Enum):
|
||||
MTorrent = "MTorrent"
|
||||
Yema = "Yema"
|
||||
HDDolby = "HDDolby"
|
||||
Zhixing = "Zhixing"
|
||||
Bitpt = "Bitpt"
|
||||
|
||||
|
||||
class SiteParserBase(metaclass=ABCMeta):
|
||||
|
||||
161
app/modules/indexer/parser/bitpt.py
Normal file
161
app/modules/indexer/parser/bitpt.py
Normal file
@@ -0,0 +1,161 @@
|
||||
#
|
||||
# 极速之星 https://bitpt.cn/
|
||||
# author: ThedoRap
|
||||
# time: 2025-10-02
|
||||
#
|
||||
# -*- coding: utf-8 -*-
|
||||
import re
|
||||
from typing import Optional, Tuple
|
||||
from urllib.parse import urljoin, urlencode
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
from app.modules.indexer.parser import SiteParserBase, SiteSchema
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
class BitptSiteUserInfo(SiteParserBase):
|
||||
schema = SiteSchema.Bitpt
|
||||
|
||||
def _parse_site_page(self, html_text: str):
|
||||
self._user_basic_page = "userdetails.php?uid={uid}"
|
||||
self._user_detail_page = None
|
||||
self._user_basic_params = {}
|
||||
self._user_traffic_page = None
|
||||
self._sys_mail_unread_page = None
|
||||
self._user_mail_unread_page = None
|
||||
self._mail_unread_params = {}
|
||||
self._torrent_seeding_base = "browse.php"
|
||||
self._torrent_seeding_params = {"t": "myseed", "st": "2", "d": "desc"}
|
||||
self._torrent_seeding_headers = {}
|
||||
self._addition_headers = {}
|
||||
|
||||
def _parse_logged_in(self, html_text):
|
||||
soup = BeautifulSoup(html_text, 'html.parser')
|
||||
return bool(soup.find(id='userinfotop'))
|
||||
|
||||
def _parse_user_base_info(self, html_text: str):
|
||||
if not html_text:
|
||||
return None
|
||||
soup = BeautifulSoup(html_text, 'html.parser')
|
||||
table = soup.find('table', class_='frmtable')
|
||||
if not table:
|
||||
return
|
||||
|
||||
rows = table.find_all('tr')
|
||||
info_dict = {}
|
||||
for row in rows:
|
||||
cells = row.find_all('td')
|
||||
if len(cells) == 2:
|
||||
key = cells[0].text.strip()
|
||||
value = cells[1].text.strip()
|
||||
info_dict[key] = value
|
||||
|
||||
self.userid = info_dict.get('UID')
|
||||
self.username = info_dict.get('用户名').split('\xa0')[0] if '用户名' in info_dict else None
|
||||
self.user_level = info_dict.get('用户级别') if '用户级别' in info_dict else None
|
||||
self.join_at = StringUtils.unify_datetime_str(info_dict.get('注册时间')) if '注册时间' in info_dict else None
|
||||
|
||||
self.upload = StringUtils.num_filesize(info_dict.get('上传流量')) if '上传流量' in info_dict else 0
|
||||
self.download = StringUtils.num_filesize(info_dict.get('下载流量')) if '下载流量' in info_dict else 0
|
||||
self.ratio = float(info_dict.get('共享率')) if '共享率' in info_dict else 0
|
||||
bonus_str = info_dict.get('星辰', '')
|
||||
self.bonus = float(re.search(r'累计([\d\.]+)', bonus_str).group(1)) if re.search(r'累计([\d\.]+)', bonus_str) else 0
|
||||
self.message_unread = 0
|
||||
|
||||
if hasattr(self, '_torrent_seeding_base') and self._torrent_seeding_base:
|
||||
self.seeding = 0
|
||||
self.seeding_size = 0
|
||||
else:
|
||||
seeding_info = soup.find('div', style="margin:0 auto;width:90%;font-size:14px;margin-top:10px;margin-bottom:10px;text-align:center;")
|
||||
if seeding_info:
|
||||
seeding_link = seeding_info.find_all('a')[1].text if len(seeding_info.find_all('a')) > 1 else ''
|
||||
match = re.search(r'当前上传的种子\((\d+)个, 共([\d\.]+ [KMGT]B)\)', seeding_link)
|
||||
if match:
|
||||
self.seeding = int(match.group(1))
|
||||
self.seeding_size = StringUtils.num_filesize(match.group(2))
|
||||
else:
|
||||
self.seeding = 0
|
||||
self.seeding_size = 0
|
||||
|
||||
def _parse_user_traffic_info(self, html_text: str):
|
||||
pass
|
||||
|
||||
def _parse_user_detail_info(self, html_text: str):
|
||||
pass
|
||||
|
||||
def _parse_user_torrent_seeding_page_info(self, html_text: str) -> Tuple[int, int]:
|
||||
if not html_text:
|
||||
return 0, 0
|
||||
soup = BeautifulSoup(html_text, 'html.parser')
|
||||
torrent_table = soup.find('table', class_='torrenttable')
|
||||
if not torrent_table:
|
||||
return 0, 0
|
||||
rows = torrent_table.find_all('tr')
|
||||
if len(rows) <= 1:
|
||||
return 0, 0
|
||||
torrents = [row for row in rows[1:] if 'btr' in row.get('class', [])]
|
||||
page_seeding = 0
|
||||
page_seeding_size = 0
|
||||
for torrent in torrents:
|
||||
size_td = torrent.find('td', class_='r')
|
||||
if size_td:
|
||||
size_a = size_td.find('a')
|
||||
size_text = size_a.text.strip() if size_a else size_td.text.strip()
|
||||
if size_text:
|
||||
page_seeding += 1
|
||||
page_seeding_size += StringUtils.num_filesize(size_text)
|
||||
return page_seeding, page_seeding_size
|
||||
|
||||
def _parse_message_unread_links(self, html_text: str, msg_links: list) -> Optional[str]:
|
||||
pass
|
||||
|
||||
def _parse_message_content(self, html_text) -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
||||
pass
|
||||
|
||||
def _parse_user_torrent_seeding_info(self, html_text: str):
|
||||
pass
|
||||
|
||||
def parse(self):
|
||||
super().parse()
|
||||
if self._index_html:
|
||||
soup = BeautifulSoup(self._index_html, 'html.parser')
|
||||
user_link = soup.find('a', href=re.compile(r'userdetails\.php\?uid=\d+'))
|
||||
if user_link:
|
||||
uid_match = re.search(r'uid=(\d+)', user_link['href'])
|
||||
if uid_match:
|
||||
self.userid = uid_match.group(1)
|
||||
|
||||
if self.userid and self._user_basic_page:
|
||||
basic_url = self._user_basic_page.format(uid=self.userid)
|
||||
basic_html = self._get_page_content(url=urljoin(self._base_url, basic_url))
|
||||
self._parse_user_base_info(basic_html)
|
||||
|
||||
if hasattr(self, '_torrent_seeding_base') and self._torrent_seeding_base:
|
||||
seeding_base_url = urljoin(self._base_url, self._torrent_seeding_base)
|
||||
params = self._torrent_seeding_params.copy()
|
||||
page_num = 1
|
||||
while True:
|
||||
params['p'] = page_num
|
||||
query_string = urlencode(params)
|
||||
full_url = f"{seeding_base_url}?{query_string}"
|
||||
seeding_html = self._get_page_content(url=full_url)
|
||||
page_seeding, page_seeding_size = self._parse_user_torrent_seeding_page_info(seeding_html)
|
||||
self.seeding += page_seeding
|
||||
self.seeding_size += page_seeding_size
|
||||
if page_seeding == 0:
|
||||
break
|
||||
page_num += 1
|
||||
|
||||
# 🔑 最终对外统一转字符串
|
||||
self.userid = str(self.userid or "")
|
||||
self.username = str(self.username or "")
|
||||
self.user_level = str(self.user_level or "")
|
||||
self.join_at = str(self.join_at or "")
|
||||
|
||||
self.upload = str(self.upload or 0)
|
||||
self.download = str(self.download or 0)
|
||||
self.ratio = str(self.ratio or 0)
|
||||
self.bonus = str(self.bonus or 0.0)
|
||||
self.message_unread = str(self.message_unread or 0)
|
||||
|
||||
self.seeding = str(self.seeding or 0)
|
||||
self.seeding_size = str(self.seeding_size or 0)
|
||||
184
app/modules/indexer/parser/zhixing.py
Normal file
184
app/modules/indexer/parser/zhixing.py
Normal file
@@ -0,0 +1,184 @@
|
||||
#
|
||||
# 知行 http://pt.zhixing.bjtu.edu.cn/
|
||||
# author: ThedoRap
|
||||
# time: 2025-10-02
|
||||
#
|
||||
# -*- coding: utf-8 -*-
|
||||
import re
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from app.modules.indexer.parser import SiteParserBase, SiteSchema
|
||||
from app.utils.string import StringUtils
|
||||
from bs4 import BeautifulSoup
|
||||
from urllib.parse import urljoin
|
||||
|
||||
|
||||
class ZhixingSiteUserInfo(SiteParserBase):
|
||||
schema = SiteSchema.Zhixing
|
||||
|
||||
def _parse_site_page(self, html_text: str):
|
||||
"""
|
||||
获取站点页面地址
|
||||
"""
|
||||
self._user_basic_page = "user/{uid}/"
|
||||
self._user_detail_page = None
|
||||
self._user_basic_params = {}
|
||||
self._user_traffic_page = None
|
||||
self._sys_mail_unread_page = None
|
||||
self._user_mail_unread_page = None
|
||||
self._mail_unread_params = {}
|
||||
self._torrent_seeding_base = "user/{uid}/seeding"
|
||||
self._torrent_seeding_params = {}
|
||||
self._torrent_seeding_headers = {}
|
||||
self._addition_headers = {}
|
||||
|
||||
def _parse_logged_in(self, html_text):
|
||||
"""
|
||||
判断是否登录成功, 通过判断是否存在用户信息
|
||||
"""
|
||||
soup = BeautifulSoup(html_text, 'html.parser')
|
||||
return bool(soup.find(id='um'))
|
||||
|
||||
def _parse_user_base_info(self, html_text: str):
|
||||
"""
|
||||
解析用户基本信息,这里把_parse_user_traffic_info和_parse_user_detail_info合并到这里
|
||||
"""
|
||||
if not html_text:
|
||||
return None
|
||||
soup = BeautifulSoup(html_text, 'html.parser')
|
||||
details_tabs = soup.find_all('div', class_='user-details-tabs')
|
||||
info_dict = {}
|
||||
for tab in details_tabs:
|
||||
for p in tab.find_all('p'):
|
||||
text = p.text.strip()
|
||||
if ':' in text:
|
||||
parts = text.split(':', 1)
|
||||
elif ':' in text:
|
||||
parts = text.split(':', 1)
|
||||
else:
|
||||
continue
|
||||
if len(parts) == 2:
|
||||
key = parts[0].strip()
|
||||
value_text = parts[1].strip()
|
||||
value = re.split(r'\s*\(', value_text)[0].strip().split('查看')[0].strip()
|
||||
info_dict[key] = value
|
||||
|
||||
self._basic_info = info_dict # Save for fallback
|
||||
|
||||
self.userid = info_dict.get('UID')
|
||||
self.username = info_dict.get('用户名')
|
||||
self.user_level = info_dict.get('用户组')
|
||||
self.join_at = StringUtils.unify_datetime_str(info_dict.get('注册时间')) if '注册时间' in info_dict else None
|
||||
|
||||
def num_filesize_safe(s: str):
|
||||
if s:
|
||||
s = s.strip()
|
||||
if re.match(r'^\d+(\.\d+)?$', s):
|
||||
s += ' B'
|
||||
return StringUtils.num_filesize(s) if s else 0
|
||||
|
||||
self.upload = num_filesize_safe(info_dict.get('上传流量')) if '上传流量' in info_dict else 0
|
||||
self.download = num_filesize_safe(info_dict.get('下载流量')) if '下载流量' in info_dict else 0
|
||||
self.ratio = float(info_dict.get('共享率')) if '共享率' in info_dict else 0
|
||||
self.bonus = float(info_dict.get('保种积分')) if '保种积分' in info_dict else 0.0
|
||||
self.message_unread = 0 # 暂无消息解析
|
||||
|
||||
# Temporarily set seeding from basic, will override or fallback later
|
||||
self.seeding = int(info_dict.get('当前保种数量')) if '当前保种数量' in info_dict else 0
|
||||
self.seeding_size = num_filesize_safe(info_dict.get('当前保种容量')) if '当前保种容量' in info_dict else 0
|
||||
|
||||
def _parse_user_traffic_info(self, html_text: str):
|
||||
pass
|
||||
|
||||
def _parse_user_detail_info(self, html_text: str):
|
||||
pass
|
||||
|
||||
def _parse_user_torrent_seeding_page_info(self, html_text: str) -> Tuple[int, int]:
|
||||
"""
|
||||
解析用户做种信息单页,返回本页数量和大小
|
||||
"""
|
||||
if not html_text:
|
||||
return 0, 0
|
||||
soup = BeautifulSoup(html_text, 'html.parser')
|
||||
torrents = soup.find_all('tr', id=re.compile(r'^t\d+'))
|
||||
page_seeding = 0
|
||||
page_seeding_size = 0
|
||||
for torrent in torrents:
|
||||
size_td = torrent.find('td', class_='r')
|
||||
if size_td:
|
||||
size_text = size_td.find('a').text if size_td.find('a') else size_td.text.strip()
|
||||
page_seeding += 1
|
||||
page_seeding_size += StringUtils.num_filesize(size_text)
|
||||
return page_seeding, page_seeding_size
|
||||
|
||||
def _parse_message_unread_links(self, html_text: str, msg_links: list) -> Optional[str]:
|
||||
pass
|
||||
|
||||
def _parse_message_content(self, html_text) -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
||||
pass
|
||||
|
||||
def _parse_user_torrent_seeding_info(self, html_text: str):
|
||||
"""
|
||||
占位,避免抽象类报错
|
||||
"""
|
||||
pass
|
||||
|
||||
def parse(self):
|
||||
"""
|
||||
解析站点信息
|
||||
"""
|
||||
super().parse()
|
||||
# 先从首页解析userid
|
||||
if self._index_html:
|
||||
soup = BeautifulSoup(self._index_html, 'html.parser')
|
||||
user_link = soup.find('a', href=re.compile(r'/user/\d+/'))
|
||||
if user_link:
|
||||
uid_match = re.search(r'/user/(\d+)/', user_link['href'])
|
||||
if uid_match:
|
||||
self.userid = uid_match.group(1)
|
||||
# 如果有userid,则格式化页面
|
||||
if self.userid:
|
||||
if self._user_basic_page:
|
||||
basic_url = self._user_basic_page.format(uid=self.userid)
|
||||
basic_html = self._get_page_content(url=urljoin(self._base_url, basic_url))
|
||||
self._parse_user_base_info(basic_html)
|
||||
if hasattr(self, '_torrent_seeding_base') and self._torrent_seeding_base:
|
||||
self.seeding = 0 # Reset to sum from pages
|
||||
self.seeding_size = 0
|
||||
seeding_base = self._torrent_seeding_base.format(uid=self.userid)
|
||||
seeding_base_url = urljoin(self._base_url, seeding_base)
|
||||
page_num = 1
|
||||
while True:
|
||||
seeding_url = f"{seeding_base_url}/p{page_num}"
|
||||
seeding_html = self._get_page_content(url=seeding_url)
|
||||
page_seeding, page_seeding_size = self._parse_user_torrent_seeding_page_info(seeding_html)
|
||||
self.seeding += page_seeding
|
||||
self.seeding_size += page_seeding_size
|
||||
if page_seeding == 0:
|
||||
break
|
||||
page_num += 1
|
||||
# Fallback to basic if no seeding found from pages
|
||||
if self.seeding == 0 and hasattr(self, '_basic_info'):
|
||||
def num_filesize_safe(s: str):
|
||||
if s:
|
||||
s = s.strip()
|
||||
if re.match(r'^\d+(\.\d+)?$', s):
|
||||
s += ' B'
|
||||
return StringUtils.num_filesize(s) if s else 0
|
||||
self.seeding = int(self._basic_info.get('当前保种数量', 0))
|
||||
self.seeding_size = num_filesize_safe(self._basic_info.get('当前保种容量', ''))
|
||||
|
||||
# 🔑 最终对外统一转字符串,避免 join 报错
|
||||
self.userid = str(self.userid or "")
|
||||
self.username = str(self.username or "")
|
||||
self.user_level = str(self.user_level or "")
|
||||
self.join_at = str(self.join_at or "")
|
||||
|
||||
self.upload = str(self.upload or 0)
|
||||
self.download = str(self.download or 0)
|
||||
self.ratio = str(self.ratio or 0)
|
||||
self.bonus = str(self.bonus or 0.0)
|
||||
self.message_unread = str(self.message_unread or 0)
|
||||
|
||||
self.seeding = str(self.seeding or 0)
|
||||
self.seeding_size = str(self.seeding_size or 0)
|
||||
@@ -75,6 +75,9 @@ class MTorrentSpider:
|
||||
categories = self._tv_category
|
||||
else:
|
||||
categories = self._movie_category
|
||||
# mtorrent搜索imdb需要输入完整imdb链接,参见 https://wiki.m-team.cc/zh-tw/imdbtosearch
|
||||
if keyword and keyword.startswith("tt"):
|
||||
keyword = f"https://www.imdb.com/title/{keyword}"
|
||||
return {
|
||||
"keyword": keyword,
|
||||
"categories": categories,
|
||||
|
||||
@@ -732,7 +732,7 @@ class Jellyfin:
|
||||
item_type=item.get("Type"),
|
||||
title=item.get("Name"),
|
||||
original_title=item.get("OriginalTitle"),
|
||||
year=item.get("ProductionYear"),
|
||||
year=str(item.get("ProductionYear")),
|
||||
tmdbid=int(tmdbid) if tmdbid else None,
|
||||
imdbid=item.get("ProviderIds", {}).get("Imdb"),
|
||||
tvdbid=item.get("ProviderIds", {}).get("Tvdb"),
|
||||
@@ -924,7 +924,7 @@ class Jellyfin:
|
||||
image = self.generate_image_link(item.get("Id"), "Backdrop", False)
|
||||
if item_type == MediaType.MOVIE.value:
|
||||
title = item.get("Name")
|
||||
subtitle = item.get("ProductionYear")
|
||||
subtitle = str(item.get("ProductionYear")) if item.get("ProductionYear") else None
|
||||
else:
|
||||
title = f'{item.get("SeriesName")}'
|
||||
subtitle = f'S{item.get("ParentIndexNumber")}:{item.get("IndexNumber")} - {item.get("Name")}'
|
||||
@@ -984,7 +984,7 @@ class Jellyfin:
|
||||
ret_latest.append(schemas.MediaServerPlayItem(
|
||||
id=item.get("Id"),
|
||||
title=item.get("Name"),
|
||||
subtitle=item.get("ProductionYear"),
|
||||
subtitle=str(item.get("ProductionYear")) if item.get("ProductionYear") else None,
|
||||
type=item_type,
|
||||
image=image,
|
||||
link=link,
|
||||
|
||||
@@ -437,7 +437,7 @@ class Plex:
|
||||
|
||||
@staticmethod
|
||||
def __get_ids(guids: List[Any]) -> dict:
|
||||
def parse_tmdb_id(value: str) -> (bool, int):
|
||||
def parse_tmdb_id(value: str) -> tuple[bool, int]:
|
||||
"""尝试将TMDB ID字符串转换为整数。如果成功,返回(True, int),失败则返回(False, None)。"""
|
||||
try:
|
||||
int_value = int(value)
|
||||
@@ -509,7 +509,7 @@ class Plex:
|
||||
item_type=item.type,
|
||||
title=item.title,
|
||||
original_title=item.originalTitle,
|
||||
year=item.year,
|
||||
year=str(item.year),
|
||||
tmdbid=ids.get("tmdb_id"),
|
||||
imdbid=ids.get("imdb_id"),
|
||||
tvdbid=ids.get("tvdb_id"),
|
||||
@@ -746,7 +746,7 @@ class Plex:
|
||||
item_type = MediaType.MOVIE.value if item.TYPE == "movie" else MediaType.TV.value
|
||||
if item_type == MediaType.MOVIE.value:
|
||||
title = item.title
|
||||
subtitle = item.year
|
||||
subtitle = str(item.year) if item.year else None
|
||||
else:
|
||||
title = item.grandparentTitle
|
||||
subtitle = f"S{item.parentIndex}:E{item.index} - {item.title}"
|
||||
@@ -825,7 +825,7 @@ class Plex:
|
||||
ret_resume.append(schemas.MediaServerPlayItem(
|
||||
id=item.key,
|
||||
title=title,
|
||||
subtitle=item.year,
|
||||
subtitle=str(item.year) if item.year else None,
|
||||
type=item_type,
|
||||
image=image,
|
||||
link=link,
|
||||
|
||||
@@ -51,10 +51,6 @@ class RedisModule(_ModuleBase):
|
||||
"""
|
||||
if settings.CACHE_BACKEND_TYPE != "redis":
|
||||
return None
|
||||
redis_helper = RedisHelper()
|
||||
try:
|
||||
if redis_helper.test():
|
||||
return True, ""
|
||||
return False, "Redis连接失败,请检查配置"
|
||||
finally:
|
||||
redis_helper.close()
|
||||
if RedisHelper().test():
|
||||
return True, ""
|
||||
return False, "Redis连接失败,请检查配置"
|
||||
|
||||
@@ -264,7 +264,7 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
userid=userid, username=username, text=text)
|
||||
return None
|
||||
|
||||
def post_message(self, message: Notification) -> None:
|
||||
def post_message(self, message: Notification, **kwargs) -> None:
|
||||
"""
|
||||
发送消息
|
||||
:param message: 消息
|
||||
|
||||
@@ -120,7 +120,7 @@ class SynologyChatModule(_ModuleBase, _MessageBase[SynologyChat]):
|
||||
logger.debug(f"解析SynologyChat消息失败:{str(err)}")
|
||||
return None
|
||||
|
||||
def post_message(self, message: Notification) -> None:
|
||||
def post_message(self, message: Notification, **kwargs) -> None:
|
||||
"""
|
||||
发送消息
|
||||
:param message: 消息体
|
||||
|
||||
@@ -261,7 +261,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
|
||||
return cleaned
|
||||
|
||||
def post_message(self, message: Notification) -> None:
|
||||
def post_message(self, message: Notification, **kwargs) -> None:
|
||||
"""
|
||||
发送消息
|
||||
:param message: 消息体
|
||||
@@ -283,7 +283,8 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
image=message.image, userid=userid, link=message.link,
|
||||
buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id)
|
||||
original_chat_id=message.original_chat_id,
|
||||
escape_markdown=kwargs.get("escape_markdown"))
|
||||
|
||||
def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None:
|
||||
"""
|
||||
|
||||
@@ -31,7 +31,8 @@ class Telegram:
|
||||
_callback_handlers: Dict[str, Callable] = {} # 存储回调处理器
|
||||
_user_chat_mapping: Dict[str, str] = {} # userid -> chat_id mapping for reply targeting
|
||||
_bot_username: Optional[str] = None # Bot username for mention detection
|
||||
|
||||
_escape_chars = r'_*[]()~`>#+-=|{}.!' # Telegram MarkdownV2
|
||||
_markdown_escape_pattern = re.compile(f'([{re.escape(_escape_chars)}])') # Telegram MarkdownV2 规则转义特殊字符正则pattern
|
||||
def __init__(self, TELEGRAM_TOKEN: Optional[str] = None, TELEGRAM_CHAT_ID: Optional[str] = None, **kwargs):
|
||||
"""
|
||||
初始化参数
|
||||
@@ -52,7 +53,7 @@ class Telegram:
|
||||
else:
|
||||
apihelper.proxy = settings.PROXY
|
||||
# bot
|
||||
_bot = telebot.TeleBot(self._telegram_token, parse_mode="Markdown")
|
||||
_bot = telebot.TeleBot(self._telegram_token, parse_mode="MarkdownV2")
|
||||
# 记录句柄
|
||||
self._bot = _bot
|
||||
# 获取并存储bot用户名用于@检测
|
||||
@@ -215,7 +216,8 @@ class Telegram:
|
||||
userid: Optional[str] = None, link: Optional[str] = None,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
original_message_id: Optional[int] = None,
|
||||
original_chat_id: Optional[str] = None) -> Optional[bool]:
|
||||
original_chat_id: Optional[str] = None,
|
||||
escape_markdown: bool = True) -> Optional[bool]:
|
||||
"""
|
||||
发送Telegram消息
|
||||
:param title: 消息标题
|
||||
@@ -226,7 +228,8 @@ class Telegram:
|
||||
:param buttons: 按钮列表,格式:[[{"text": "按钮文本", "callback_data": "回调数据"}]]
|
||||
:param original_message_id: 原消息ID,如果提供则编辑原消息
|
||||
:param original_chat_id: 原消息的聊天ID,编辑消息时需要
|
||||
:userid: 发送消息的目标用户ID,为空则发给管理员
|
||||
:param escape_markdown: 是否对内容进行Markdown转义
|
||||
|
||||
"""
|
||||
if not self._telegram_token or not self._telegram_chat_id:
|
||||
return None
|
||||
@@ -236,10 +239,20 @@ class Telegram:
|
||||
return False
|
||||
|
||||
try:
|
||||
if title:
|
||||
# 标题总是转义(因为通常标题不包含Markdown格式)
|
||||
title = self.escape_markdown(title)
|
||||
if text:
|
||||
# 对text进行Markdown特殊字符转义
|
||||
text = re.sub(r"([_`])", r"\\\1", text)
|
||||
caption = f"*{title}*\n{text}"
|
||||
if escape_markdown:
|
||||
# 完全转义模式:转义所有特殊字符
|
||||
text = self.escape_markdown(text)
|
||||
else:
|
||||
# 智能转义模式:保留Markdown格式,只转义普通文本中的特殊字符
|
||||
text = self.escape_markdown_smart(text)
|
||||
if title:
|
||||
caption = f"*{title}*\n{text}"
|
||||
else:
|
||||
caption = text
|
||||
else:
|
||||
caption = f"*{title}*"
|
||||
|
||||
@@ -499,7 +512,7 @@ class Telegram:
|
||||
|
||||
if image:
|
||||
# 如果有图片,使用edit_message_media
|
||||
media = InputMediaPhoto(media=image, caption=text, parse_mode="Markdown")
|
||||
media = InputMediaPhoto(media=image, caption=text, parse_mode="MarkdownV2")
|
||||
self._bot.edit_message_media(
|
||||
chat_id=chat_id,
|
||||
message_id=message_id,
|
||||
@@ -512,7 +525,7 @@ class Telegram:
|
||||
chat_id=chat_id,
|
||||
message_id=message_id,
|
||||
text=text,
|
||||
parse_mode="Markdown",
|
||||
parse_mode="MarkdownV2",
|
||||
reply_markup=reply_markup
|
||||
)
|
||||
return True
|
||||
@@ -542,7 +555,7 @@ class Telegram:
|
||||
ret = self._bot.send_photo(chat_id=userid or self._telegram_chat_id,
|
||||
photo=photo,
|
||||
caption=caption,
|
||||
parse_mode="Markdown",
|
||||
parse_mode="MarkdownV2",
|
||||
reply_markup=reply_markup)
|
||||
if ret is None:
|
||||
raise RetryException("发送图片消息失败")
|
||||
@@ -553,12 +566,12 @@ class Telegram:
|
||||
for i in range(0, len(caption), 4095):
|
||||
ret = self._bot.send_message(chat_id=userid or self._telegram_chat_id,
|
||||
text=caption[i:i + 4095],
|
||||
parse_mode="Markdown",
|
||||
parse_mode="MarkdownV2",
|
||||
reply_markup=reply_markup if i == 0 else None)
|
||||
else:
|
||||
ret = self._bot.send_message(chat_id=userid or self._telegram_chat_id,
|
||||
text=caption,
|
||||
parse_mode="Markdown",
|
||||
parse_mode="MarkdownV2",
|
||||
reply_markup=reply_markup)
|
||||
if ret is None:
|
||||
raise RetryException("发送文本消息失败")
|
||||
@@ -597,3 +610,84 @@ class Telegram:
|
||||
self._bot.stop_polling()
|
||||
self._polling_thread.join()
|
||||
logger.info("Telegram消息接收服务已停止")
|
||||
|
||||
def escape_markdown(self, text: str) -> str:
|
||||
# 按 Telegram MarkdownV2 规则转义特殊字符
|
||||
if not isinstance(text, str):
|
||||
return str(text) if text is not None else ""
|
||||
return self._markdown_escape_pattern.sub(r'\\\1', text)
|
||||
|
||||
def escape_markdown_smart(self, text: str) -> str:
|
||||
"""
|
||||
智能转义Markdown文本:只转义不在Markdown标记内的特殊字符
|
||||
这样可以保留已有的Markdown格式(如*粗体*、_斜体_、[链接](url)等),
|
||||
同时转义普通文本中的特殊字符以避免API错误
|
||||
|
||||
注意:Telegram MarkdownV2不支持以下语法,这些字符会被转义:
|
||||
- 标题语法(#、##、###)会被转义为 \#、\##、\###
|
||||
- 列表语法(-、*、+)会被转义为 \-、\*、\+
|
||||
- 引用语法(>)会被转义为 \>
|
||||
|
||||
建议使用加粗文本模拟标题:*标题文本*
|
||||
|
||||
:param text: 要转义的文本
|
||||
:return: 转义后的文本
|
||||
"""
|
||||
if not isinstance(text, str):
|
||||
return str(text) if text is not None else ""
|
||||
|
||||
# 如果没有特殊字符,直接返回
|
||||
if not any(char in self._escape_chars for char in text):
|
||||
return text
|
||||
|
||||
# 标记受保护的区域(Markdown标记内的内容不转义)
|
||||
protected = [False] * len(text)
|
||||
|
||||
# 按优先级匹配Markdown标记(从最复杂到最简单)
|
||||
# 1. 链接:[text](url) - 必须最先匹配
|
||||
link_pattern = r'\[([^\]]*)\]\(([^)]*)\)'
|
||||
for match in re.finditer(link_pattern, text):
|
||||
for i in range(match.start(), match.end()):
|
||||
protected[i] = True
|
||||
|
||||
# 2. 粗体:*text*(单个*,不是**)
|
||||
bold_pattern = r'(?<!\*)\*(?!\*)([^*]+?)(?<!\*)\*(?!\*)'
|
||||
for match in re.finditer(bold_pattern, text):
|
||||
if not any(protected[match.start():match.end()]):
|
||||
for i in range(match.start(), match.end()):
|
||||
protected[i] = True
|
||||
|
||||
# 3. 斜体:_text_(单个_,不是__)
|
||||
italic_pattern = r'(?<!_)_(?!_)([^_]+?)(?<!_)_(?!_)'
|
||||
for match in re.finditer(italic_pattern, text):
|
||||
if not any(protected[match.start():match.end()]):
|
||||
for i in range(match.start(), match.end()):
|
||||
protected[i] = True
|
||||
|
||||
# 4. 代码:`text`
|
||||
code_pattern = r'`([^`]+)`'
|
||||
for match in re.finditer(code_pattern, text):
|
||||
if not any(protected[match.start():match.end()]):
|
||||
for i in range(match.start(), match.end()):
|
||||
protected[i] = True
|
||||
|
||||
# 5. 删除线:~text~
|
||||
strikethrough_pattern = r'~([^~]+)~'
|
||||
for match in re.finditer(strikethrough_pattern, text):
|
||||
if not any(protected[match.start():match.end()]):
|
||||
for i in range(match.start(), match.end()):
|
||||
protected[i] = True
|
||||
|
||||
# 构建结果:只转义未保护区域的特殊字符
|
||||
result = []
|
||||
for i, char in enumerate(text):
|
||||
if protected[i]:
|
||||
# 受保护区域(Markdown标记内),不转义
|
||||
result.append(char)
|
||||
elif char in self._escape_chars:
|
||||
# 未保护区域,转义特殊字符
|
||||
result.append('\\' + char)
|
||||
else:
|
||||
result.append(char)
|
||||
|
||||
return ''.join(result)
|
||||
@@ -747,6 +747,9 @@ class TmdbApi:
|
||||
logger.info("正在从TheDbMovie网站查询:%s ..." % name)
|
||||
tmdb_url = self._build_tmdb_search_url(name)
|
||||
res = RequestUtils(timeout=5, ua=settings.NORMAL_USER_AGENT, proxies=settings.PROXY).get_res(url=tmdb_url)
|
||||
if res is None:
|
||||
logger.error("无法连接TheDbMovie")
|
||||
return None
|
||||
|
||||
# 响应验证
|
||||
response_result = self._validate_response(res)
|
||||
@@ -1857,6 +1860,9 @@ class TmdbApi:
|
||||
tmdb_url = self._build_tmdb_search_url(name)
|
||||
res = await AsyncRequestUtils(timeout=5, ua=settings.NORMAL_USER_AGENT, proxies=settings.PROXY).get_res(
|
||||
url=tmdb_url)
|
||||
if res is None:
|
||||
logger.error("无法连接TheDbMovie")
|
||||
return None
|
||||
|
||||
# 响应验证
|
||||
response_result = self._validate_response(res)
|
||||
|
||||
@@ -8,7 +8,7 @@ from datetime import datetime
|
||||
import requests
|
||||
import requests.exceptions
|
||||
|
||||
from app.core.cache import cached
|
||||
from app.core.cache import cached, fresh, async_fresh
|
||||
from app.core.config import settings
|
||||
from app.utils.http import RequestUtils, AsyncRequestUtils
|
||||
from .exceptions import TMDbException
|
||||
@@ -18,14 +18,12 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class TMDb(object):
|
||||
|
||||
def __init__(self, obj_cached=True, session=None, language=None):
|
||||
def __init__(self, session=None, language=None):
|
||||
self._api_key = settings.TMDB_API_KEY
|
||||
self._language = language or settings.TMDB_LOCALE or "en-US"
|
||||
self._session_id = None
|
||||
self._session = session
|
||||
self._wait_on_rate_limit = True
|
||||
self._debug_enabled = False
|
||||
self._cache_enabled = obj_cached
|
||||
self._proxies = settings.PROXY
|
||||
self._domain = settings.TMDB_API_DOMAIN
|
||||
self._page = None
|
||||
@@ -41,7 +39,6 @@ class TMDb(object):
|
||||
self._remaining = 40
|
||||
self._reset = None
|
||||
self._timeout = 15
|
||||
self.obj_cached = obj_cached
|
||||
|
||||
self.__clear_async_cache__ = False
|
||||
|
||||
@@ -111,36 +108,8 @@ class TMDb(object):
|
||||
def wait_on_rate_limit(self, wait_on_rate_limit):
|
||||
self._wait_on_rate_limit = bool(wait_on_rate_limit)
|
||||
|
||||
@property
|
||||
def debug(self):
|
||||
return self._debug_enabled
|
||||
|
||||
@debug.setter
|
||||
def debug(self, debug):
|
||||
self._debug_enabled = bool(debug)
|
||||
|
||||
@property
|
||||
def cache(self):
|
||||
return self._cache_enabled
|
||||
|
||||
@cache.setter
|
||||
def cache(self, cache):
|
||||
self._cache_enabled = bool(cache)
|
||||
|
||||
@cached(maxsize=settings.CONF.tmdb, ttl=settings.CONF.meta, skip_none=True)
|
||||
def cached_request(self, method, url, data, json,
|
||||
_ts=datetime.strftime(datetime.now(), '%Y%m%d')):
|
||||
return self.request(method, url, data, json)
|
||||
|
||||
@cached(maxsize=settings.CONF.tmdb, ttl=settings.CONF.meta, skip_none=True)
|
||||
async def async_cached_request(self, method, url, data, json,
|
||||
_ts=datetime.strftime(datetime.now(), '%Y%m%d')):
|
||||
if self.__clear_async_cache__:
|
||||
self.__clear_async_cache__ = False
|
||||
await self.async_cached_request.cache_clear()
|
||||
return await self.async_request(method, url, data, json)
|
||||
|
||||
def request(self, method, url, data, json):
|
||||
def request(self, method, url, data, json, **kwargs):
|
||||
if method == "GET":
|
||||
req = self._req.get_res(url, params=data, json=json)
|
||||
else:
|
||||
@@ -149,7 +118,8 @@ class TMDb(object):
|
||||
raise TMDbException("无法连接TheMovieDb,请检查网络连接!")
|
||||
return req
|
||||
|
||||
async def async_request(self, method, url, data, json):
|
||||
@cached(maxsize=settings.CONF.tmdb, ttl=settings.CONF.meta, skip_none=True)
|
||||
async def async_request(self, method, url, data, json, **kwargs):
|
||||
if method == "GET":
|
||||
req = await self._async_req.get_res(url, params=data, json=json)
|
||||
else:
|
||||
@@ -160,7 +130,7 @@ class TMDb(object):
|
||||
|
||||
def cache_clear(self):
|
||||
self.__clear_async_cache__ = True
|
||||
return self.cached_request.cache_clear()
|
||||
return self.request.cache_clear()
|
||||
|
||||
def _validate_api_key(self):
|
||||
if self.api_key is None or self.api_key == "":
|
||||
@@ -204,13 +174,6 @@ class TMDb(object):
|
||||
if "total_pages" in json_data:
|
||||
self._total_pages = json_data["total_pages"]
|
||||
|
||||
if self.debug:
|
||||
logger.info(json_data)
|
||||
if is_async:
|
||||
logger.info(self.async_cached_request.cache_info())
|
||||
else:
|
||||
logger.info(self.cached_request.cache_info())
|
||||
|
||||
@staticmethod
|
||||
def _handle_errors(json_data):
|
||||
if "errors" in json_data:
|
||||
@@ -224,10 +187,9 @@ class TMDb(object):
|
||||
self._validate_api_key()
|
||||
url = self._build_url(action, params)
|
||||
|
||||
if self.cache and self.obj_cached and call_cached and method != "POST":
|
||||
req = self.cached_request(method, url, data, json)
|
||||
else:
|
||||
req = self.request(method, url, data, json)
|
||||
with fresh(not call_cached or method == "POST"):
|
||||
req = self.request(method, url, data, json,
|
||||
_ts=datetime.strftime(datetime.now(), '%Y%m%d'))
|
||||
|
||||
if req is None:
|
||||
return None
|
||||
@@ -253,10 +215,13 @@ class TMDb(object):
|
||||
self._validate_api_key()
|
||||
url = self._build_url(action, params)
|
||||
|
||||
if self.cache and self.obj_cached and call_cached and method != "POST":
|
||||
req = await self.async_cached_request(method, url, data, json)
|
||||
else:
|
||||
req = await self.async_request(method, url, data, json)
|
||||
if self.__clear_async_cache__:
|
||||
self.__clear_async_cache__ = False
|
||||
await self.async_request.cache_clear()
|
||||
|
||||
async with async_fresh(not call_cached or method == "POST"):
|
||||
req = await self.async_request(method, url, data, json,
|
||||
_ts=datetime.strftime(datetime.now(), '%Y%m%d'))
|
||||
|
||||
if req is None:
|
||||
return None
|
||||
|
||||
@@ -154,7 +154,7 @@ class TrimeMediaModule(_ModuleBase, _MediaServerBase[TrimeMedia]):
|
||||
"""
|
||||
source = args.get("source")
|
||||
if source:
|
||||
server: TrimeMedia = self.get_instance(source)
|
||||
server: Optional[TrimeMedia] = self.get_instance(source)
|
||||
if not server:
|
||||
return None
|
||||
result = server.get_webhook_message(body)
|
||||
@@ -247,7 +247,7 @@ class TrimeMediaModule(_ModuleBase, _MediaServerBase[TrimeMedia]):
|
||||
媒体数量统计
|
||||
"""
|
||||
if server:
|
||||
server_obj: TrimeMedia = self.get_instance(server)
|
||||
server_obj: Optional[TrimeMedia] = self.get_instance(server)
|
||||
if not server_obj:
|
||||
return None
|
||||
servers = [server_obj]
|
||||
@@ -268,7 +268,7 @@ class TrimeMediaModule(_ModuleBase, _MediaServerBase[TrimeMedia]):
|
||||
"""
|
||||
媒体库列表
|
||||
"""
|
||||
server_obj: TrimeMedia = self.get_instance(server)
|
||||
server_obj: Optional[TrimeMedia] = self.get_instance(server)
|
||||
if server_obj:
|
||||
return server_obj.get_librarys(hidden=hidden)
|
||||
return None
|
||||
@@ -290,7 +290,7 @@ class TrimeMediaModule(_ModuleBase, _MediaServerBase[TrimeMedia]):
|
||||
|
||||
:return: 返回一个生成器对象,用于逐步获取媒体服务器中的项目
|
||||
"""
|
||||
server_obj: TrimeMedia = self.get_instance(server)
|
||||
server_obj: Optional[TrimeMedia] = self.get_instance(server)
|
||||
if server_obj:
|
||||
return server_obj.get_items(library_id, start_index, limit)
|
||||
return None
|
||||
@@ -301,7 +301,7 @@ class TrimeMediaModule(_ModuleBase, _MediaServerBase[TrimeMedia]):
|
||||
"""
|
||||
媒体库项目详情
|
||||
"""
|
||||
server_obj: TrimeMedia = self.get_instance(server)
|
||||
server_obj: Optional[TrimeMedia] = self.get_instance(server)
|
||||
if server_obj:
|
||||
return server_obj.get_iteminfo(item_id)
|
||||
return None
|
||||
@@ -312,7 +312,9 @@ class TrimeMediaModule(_ModuleBase, _MediaServerBase[TrimeMedia]):
|
||||
"""
|
||||
获取剧集信息
|
||||
"""
|
||||
server_obj: TrimeMedia = self.get_instance(server)
|
||||
if not isinstance(item_id, str):
|
||||
return None
|
||||
server_obj: Optional[TrimeMedia] = self.get_instance(server)
|
||||
if not server_obj:
|
||||
return None
|
||||
_, seasoninfo = server_obj.get_tv_episodes(item_id=item_id)
|
||||
@@ -329,10 +331,10 @@ class TrimeMediaModule(_ModuleBase, _MediaServerBase[TrimeMedia]):
|
||||
"""
|
||||
获取媒体服务器正在播放信息
|
||||
"""
|
||||
server_obj: TrimeMedia = self.get_instance(server)
|
||||
server_obj: Optional[TrimeMedia] = self.get_instance(server)
|
||||
if not server_obj:
|
||||
return []
|
||||
return server_obj.get_resume(num=count)
|
||||
return server_obj.get_resume(num=count) or []
|
||||
|
||||
def mediaserver_play_url(
|
||||
self, server: str, item_id: Union[str, int]
|
||||
@@ -340,7 +342,9 @@ class TrimeMediaModule(_ModuleBase, _MediaServerBase[TrimeMedia]):
|
||||
"""
|
||||
获取媒体库播放地址
|
||||
"""
|
||||
server_obj: TrimeMedia = self.get_instance(server)
|
||||
if not isinstance(item_id, str):
|
||||
return None
|
||||
server_obj: Optional[TrimeMedia] = self.get_instance(server)
|
||||
if not server_obj:
|
||||
return None
|
||||
return server_obj.get_play_url(item_id)
|
||||
@@ -354,10 +358,10 @@ class TrimeMediaModule(_ModuleBase, _MediaServerBase[TrimeMedia]):
|
||||
"""
|
||||
获取媒体服务器最新入库条目
|
||||
"""
|
||||
server_obj: TrimeMedia = self.get_instance(server)
|
||||
server_obj: Optional[TrimeMedia] = self.get_instance(server)
|
||||
if not server_obj:
|
||||
return []
|
||||
return server_obj.get_latest(num=count)
|
||||
return server_obj.get_latest(num=count) or []
|
||||
|
||||
def mediaserver_latest_images(
|
||||
self,
|
||||
@@ -374,7 +378,31 @@ class TrimeMediaModule(_ModuleBase, _MediaServerBase[TrimeMedia]):
|
||||
:param remote: True为外网链接, False为内网链接
|
||||
:return: 图片链接列表
|
||||
"""
|
||||
server_obj: TrimeMedia = self.get_instance(server)
|
||||
server_obj: Optional[TrimeMedia] = self.get_instance(server)
|
||||
if not server_obj:
|
||||
return []
|
||||
return server_obj.get_latest_backdrops(num=count, remote=remote)
|
||||
return server_obj.get_latest_backdrops(num=count, remote=remote) or []
|
||||
|
||||
def mediaserver_image_cookies(
|
||||
self,
|
||||
server: Optional[str] = None,
|
||||
image_url: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> Optional[str | dict]:
|
||||
"""
|
||||
获取飞牛影视服务器的图片Cookies
|
||||
|
||||
:param server: 媒体服务器名称
|
||||
:param image_url: 图片网址
|
||||
"""
|
||||
if not image_url:
|
||||
return None
|
||||
if server:
|
||||
server_obj = self.get_instance(server)
|
||||
if not server_obj:
|
||||
return None
|
||||
return server_obj.get_image_cookies(image_url)
|
||||
else:
|
||||
for server_obj in self.get_instances().values():
|
||||
if cookies := server_obj.get_image_cookies(image_url):
|
||||
return cookies
|
||||
|
||||
@@ -140,13 +140,13 @@ class Api:
|
||||
self._token: Optional[str] = None
|
||||
self._version: Optional[Version] = None
|
||||
self._session = requests.Session()
|
||||
self._request_utils = RequestUtils(session=self._session)
|
||||
self._request_utils = RequestUtils(session=self._session, timeout=10)
|
||||
|
||||
def sys_version(self) -> Optional[Version]:
|
||||
"""
|
||||
飞牛影视版本号
|
||||
"""
|
||||
if (res := self.__request_api("/sys/version")) and res.success:
|
||||
if (res := self.request("/sys/version")) and res.success:
|
||||
if res.data:
|
||||
self._version = Version(
|
||||
frontend=res.data.get("version"),
|
||||
@@ -162,7 +162,7 @@ class Api:
|
||||
:return: 成功返回token 否则返回None
|
||||
"""
|
||||
if (
|
||||
res := self.__request_api(
|
||||
res := self.request(
|
||||
"/login",
|
||||
data={
|
||||
"username": username,
|
||||
@@ -178,7 +178,9 @@ class Api:
|
||||
"""
|
||||
退出账号
|
||||
"""
|
||||
if (res := self.__request_api("/user/logout", method="post")) and res.success:
|
||||
if not self._token:
|
||||
return True
|
||||
if (res := self.request("/user/logout", method="post")) and res.success:
|
||||
if res.data:
|
||||
self._token = None
|
||||
return True
|
||||
@@ -188,7 +190,9 @@ class Api:
|
||||
"""
|
||||
用户列表(仅管理员有权访问)
|
||||
"""
|
||||
if (res := self.__request_api("/manager/user/list")) and res.success:
|
||||
if (res := self.request("/manager/user/list")) and res.success:
|
||||
if not res.data:
|
||||
return []
|
||||
return [
|
||||
User(
|
||||
guid=info.get("guid"),
|
||||
@@ -203,7 +207,7 @@ class Api:
|
||||
"""
|
||||
当前用户信息
|
||||
"""
|
||||
if (res := self.__request_api("/user/info")) and res.success:
|
||||
if (res := self.request("/user/info")) and res.success:
|
||||
_user = User("", "")
|
||||
_user.__dict__.update(res.data)
|
||||
return _user
|
||||
@@ -213,7 +217,7 @@ class Api:
|
||||
"""
|
||||
媒体数量统计
|
||||
"""
|
||||
if (res := self.__request_api("/mediadb/sum")) and res.success:
|
||||
if (res := self.request("/mediadb/sum")) and res.success:
|
||||
sums = MediaDbSummary()
|
||||
sums.__dict__.update(res.data)
|
||||
return sums
|
||||
@@ -223,9 +227,9 @@ class Api:
|
||||
"""
|
||||
媒体库列表(普通用户)
|
||||
"""
|
||||
if (res := self.__request_api("/mediadb/list")) and res.success:
|
||||
if (res := self.request("/mediadb/list")) and res.success:
|
||||
_items = []
|
||||
for info in res.data:
|
||||
for info in res.data or []:
|
||||
mdb = MediaDb(
|
||||
guid=info.get("guid"),
|
||||
category=Category(info.get("category")),
|
||||
@@ -250,9 +254,9 @@ class Api:
|
||||
"""
|
||||
媒体库列表(管理员)
|
||||
"""
|
||||
if (res := self.__request_api("/mdb/list")) and res.success:
|
||||
if (res := self.request("/mdb/list")) and res.success:
|
||||
_items = []
|
||||
for info in res.data:
|
||||
for info in res.data or []:
|
||||
mdb = MediaDb(
|
||||
guid=info.get("guid"),
|
||||
category=Category(info.get("category")),
|
||||
@@ -271,7 +275,7 @@ class Api:
|
||||
"""
|
||||
扫描所有媒体库
|
||||
"""
|
||||
if (res := self.__request_api("/mdb/scanall", method="post")) and res.success:
|
||||
if (res := self.request("/mdb/scanall", method="post")) and res.success:
|
||||
if res.data:
|
||||
return True
|
||||
return False
|
||||
@@ -280,9 +284,7 @@ class Api:
|
||||
"""
|
||||
扫描指定媒体库
|
||||
"""
|
||||
if (
|
||||
res := self.__request_api(f"/mdb/scan/{mdb.guid}", data={})
|
||||
) and res.success:
|
||||
if (res := self.request(f"/mdb/scan/{mdb.guid}", data={})) and res.success:
|
||||
if res.data:
|
||||
return True
|
||||
return False
|
||||
@@ -291,9 +293,7 @@ class Api:
|
||||
"""
|
||||
当前正在运行的任务
|
||||
"""
|
||||
if (
|
||||
res := self.__request_api("/task/running")
|
||||
) and res.success:
|
||||
if (res := self.request("/task/running")) and res.success:
|
||||
if res.data:
|
||||
# TODO 具体正在运行的任务
|
||||
return True
|
||||
@@ -341,7 +341,9 @@ class Api:
|
||||
if exclude_grouped_video:
|
||||
post["exclude_grouped_video"] = 1
|
||||
|
||||
if (res := self.__request_api("/item/list", data=post)) and res.success:
|
||||
if (res := self.request("/item/list", data=post)) and res.success:
|
||||
if not res.data:
|
||||
return []
|
||||
return [self.__build_item(info) for info in res.data.get("list", [])]
|
||||
return None
|
||||
|
||||
@@ -350,8 +352,10 @@ class Api:
|
||||
搜索影片、演员
|
||||
"""
|
||||
if (
|
||||
res := self.__request_api("/search/list", params={"q": keywords})
|
||||
res := self.request("/search/list", params={"q": keywords})
|
||||
) and res.success:
|
||||
if not res.data:
|
||||
return []
|
||||
return [self.__build_item(info) for info in res.data]
|
||||
return None
|
||||
|
||||
@@ -359,7 +363,7 @@ class Api:
|
||||
"""
|
||||
查询媒体详情
|
||||
"""
|
||||
if (res := self.__request_api(f"/item/{guid}")) and res.success:
|
||||
if (res := self.request(f"/item/{guid}")) and res.success:
|
||||
return self.__build_item(res.data)
|
||||
return None
|
||||
|
||||
@@ -370,7 +374,7 @@ class Api:
|
||||
:param delete_file: True删除媒体文件,False仅从媒体库移除
|
||||
"""
|
||||
if (
|
||||
res := self.__request_api(
|
||||
res := self.request(
|
||||
f"/item/{guid}",
|
||||
method="delete",
|
||||
data={"delete_file": 1 if delete_file else 0, "media_guids": []},
|
||||
@@ -384,7 +388,9 @@ class Api:
|
||||
"""
|
||||
查询季列表
|
||||
"""
|
||||
if (res := self.__request_api(f"/season/list/{tv_guid}")) and res.success:
|
||||
if (res := self.request(f"/season/list/{tv_guid}")) and res.success:
|
||||
if not res.data:
|
||||
return []
|
||||
return [self.__build_item(info) for info in res.data]
|
||||
return None
|
||||
|
||||
@@ -392,7 +398,9 @@ class Api:
|
||||
"""
|
||||
查询剧集列表
|
||||
"""
|
||||
if (res := self.__request_api(f"/episode/list/{season_guid}")) and res.success:
|
||||
if (res := self.request(f"/episode/list/{season_guid}")) and res.success:
|
||||
if not res.data:
|
||||
return []
|
||||
return [self.__build_item(info) for info in res.data]
|
||||
return None
|
||||
|
||||
@@ -400,7 +408,9 @@ class Api:
|
||||
"""
|
||||
继续观看列表
|
||||
"""
|
||||
if (res := self.__request_api("/play/list")) and res.success:
|
||||
if (res := self.request("/play/list")) and res.success:
|
||||
if not res.data:
|
||||
return []
|
||||
return [self.__build_item(info) for info in res.data]
|
||||
return None
|
||||
|
||||
@@ -431,7 +441,7 @@ class Api:
|
||||
sign = md5.hexdigest()
|
||||
return f"nonce={nonce}×tamp={ts}&sign={sign}"
|
||||
|
||||
def __request_api(
|
||||
def request(
|
||||
self,
|
||||
api: str,
|
||||
method: Optional[str] = None,
|
||||
@@ -482,6 +492,8 @@ class Api:
|
||||
queries_unquoted = None
|
||||
headers = {
|
||||
"User-Agent": settings.USER_AGENT,
|
||||
"Accept": "application/json",
|
||||
"Referer": self._host,
|
||||
"Authorization": self._token,
|
||||
"authx": self.__get_authx(api_path, json_body or queries_unquoted),
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import app.modules.trimemedia.api as fnapi
|
||||
from app import schemas
|
||||
from app.log import logger
|
||||
from app.schemas import MediaType
|
||||
from app.utils.security import SecurityUtils
|
||||
from app.utils.url import UrlUtils
|
||||
|
||||
|
||||
@@ -13,12 +14,14 @@ class TrimeMedia:
|
||||
_password: Optional[str] = None
|
||||
|
||||
_userinfo: Optional[fnapi.User] = None
|
||||
_host: Optional[str] = None
|
||||
_playhost: Optional[str] = None
|
||||
|
||||
_libraries: dict[str, fnapi.MediaDb] = {}
|
||||
_sync_libraries: List[str] = []
|
||||
|
||||
_api: Optional[fnapi.Api] = None
|
||||
_version: Optional[fnapi.Version] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -34,20 +37,19 @@ class TrimeMedia:
|
||||
return
|
||||
self._username = username
|
||||
self._password = password
|
||||
self._host = host
|
||||
self._sync_libraries = sync_libraries or []
|
||||
|
||||
if (api := self.__create_api(host)) is None:
|
||||
if not self.reconnect():
|
||||
logger.error(f"请检查服务端地址 {host}")
|
||||
return
|
||||
self._api = api
|
||||
if play_api := self.__create_api(play_host):
|
||||
self._playhost = play_api.host
|
||||
if result := self.__create_api(play_host):
|
||||
self._playhost = result.api.host
|
||||
result.api.close()
|
||||
elif play_host:
|
||||
logger.warning(f"请检查外网播放地址 {play_host}")
|
||||
self._playhost = UrlUtils.standardize_base_url(play_host).rstrip("/")
|
||||
|
||||
self.reconnect()
|
||||
|
||||
@property
|
||||
def api(self) -> Optional[fnapi.Api]:
|
||||
"""
|
||||
@@ -55,14 +57,26 @@ class TrimeMedia:
|
||||
"""
|
||||
return self._api
|
||||
|
||||
@property
|
||||
def version(self) -> Optional[fnapi.Version]:
|
||||
"""
|
||||
获得飞牛API的版本
|
||||
"""
|
||||
return self._version
|
||||
|
||||
class _ApiCreateResult:
|
||||
api: fnapi.Api
|
||||
version: fnapi.Version
|
||||
|
||||
@staticmethod
|
||||
def __create_api(host: Optional[str]) -> Optional[fnapi.Api]:
|
||||
def __create_api(host: Optional[str]) -> Optional["TrimeMedia._ApiCreateResult"]:
|
||||
"""
|
||||
创建一个飞牛API
|
||||
|
||||
:param host: 服务端地址
|
||||
:return: 如果地址无效、不可访问则返回None
|
||||
"""
|
||||
|
||||
if not host:
|
||||
return None
|
||||
api_key = "16CCEB3D-AB42-077D-36A1-F355324E4237"
|
||||
@@ -70,21 +84,35 @@ class TrimeMedia:
|
||||
|
||||
if not host.endswith("/v"):
|
||||
# 尝试补上结尾的/v 测试能否正常访问
|
||||
api = fnapi.Api(host + "/v", api_key)
|
||||
if api.sys_version():
|
||||
return api
|
||||
res = TrimeMedia._ApiCreateResult()
|
||||
res.api = fnapi.Api(host + "/v", api_key)
|
||||
if fnver := res.api.sys_version():
|
||||
res.version = fnver
|
||||
return res
|
||||
# 测试用户配置的地址
|
||||
api = fnapi.Api(host, api_key)
|
||||
return api if api.sys_version() else None
|
||||
res = TrimeMedia._ApiCreateResult()
|
||||
res.api = fnapi.Api(host, api_key)
|
||||
if fnver := res.api.sys_version():
|
||||
res.version = fnver
|
||||
return res
|
||||
return None
|
||||
|
||||
def close(self):
|
||||
self.disconnect()
|
||||
|
||||
def is_configured(self) -> bool:
|
||||
return self._api is not None
|
||||
return bool(self._host and self._username and self._password)
|
||||
|
||||
def is_authenticated(self) -> bool:
|
||||
return self.is_configured() and self._api.token is not None
|
||||
"""
|
||||
是否已登录
|
||||
"""
|
||||
return (
|
||||
self.is_configured()
|
||||
and self._api is not None
|
||||
and self._api.token is not None
|
||||
and self._userinfo is not None
|
||||
)
|
||||
|
||||
def is_inactive(self) -> bool:
|
||||
"""
|
||||
@@ -101,10 +129,17 @@ class TrimeMedia:
|
||||
"""
|
||||
if not self.is_configured():
|
||||
return False
|
||||
if (fnver := self._api.sys_version()) is None:
|
||||
self.disconnect()
|
||||
if result := self.__create_api(self._host):
|
||||
self._api = result.api
|
||||
self._version = result.version
|
||||
# 版本号:0.8.53, 服务版本:0.8.23
|
||||
# 版本号:0.8.56, 服务版本:0.8.23 接口/memory/user/list改为/manager/user/list
|
||||
logger.debug(
|
||||
f"版本号:{result.version.frontend}, 服务版本:{result.version.backend}"
|
||||
)
|
||||
else:
|
||||
return False
|
||||
# 版本号:0.8.36, 服务版本:0.8.19
|
||||
logger.debug(f"版本号:{fnver.frontend}, 服务版本:{fnver.backend}")
|
||||
if self._api.login(self._username, self._password) is None:
|
||||
return False
|
||||
self._userinfo = self._api.user_info()
|
||||
@@ -119,9 +154,10 @@ class TrimeMedia:
|
||||
"""
|
||||
断开与飞牛的连接
|
||||
"""
|
||||
if self.is_authenticated():
|
||||
if self._api:
|
||||
self._api.logout()
|
||||
self._api.close()
|
||||
self._api = None
|
||||
self._userinfo = None
|
||||
logger.debug(f"{self._username} 已断开飞牛影视")
|
||||
|
||||
@@ -163,7 +199,8 @@ class TrimeMedia:
|
||||
for img_path in library.posters or []
|
||||
],
|
||||
link=f"{self._playhost or self._api.host}/library/{library.guid}",
|
||||
server_type='trimemedia'
|
||||
server_type="trimemedia",
|
||||
use_cookies=True,
|
||||
)
|
||||
)
|
||||
return libraries
|
||||
@@ -205,10 +242,12 @@ class TrimeMedia:
|
||||
return None
|
||||
if not self.is_configured():
|
||||
return None
|
||||
feiniu = fnapi.Api(self._api.host, self._api.apikey)
|
||||
if token := feiniu.login(username, password):
|
||||
feiniu.logout()
|
||||
return token
|
||||
if result := self.__create_api(self._host):
|
||||
try:
|
||||
return result.api.login(username, password)
|
||||
finally:
|
||||
result.api.logout()
|
||||
result.api.close()
|
||||
|
||||
def get_movies(
|
||||
self, title: str, year: Optional[str] = None, tmdb_id: Optional[int] = None
|
||||
@@ -410,7 +449,7 @@ class TrimeMedia:
|
||||
item_type=item_type,
|
||||
title=item.title,
|
||||
original_title=item.original_title,
|
||||
year=year,
|
||||
year=str(year),
|
||||
tmdbid=item.tmdb_id,
|
||||
imdbid=item.imdb_id,
|
||||
user_state=user_state,
|
||||
@@ -459,7 +498,8 @@ class TrimeMedia:
|
||||
if item.duration and item.ts is not None
|
||||
else 0
|
||||
),
|
||||
server_type='trimemedia',
|
||||
server_type="trimemedia",
|
||||
use_cookies=True,
|
||||
)
|
||||
|
||||
def get_items(
|
||||
@@ -576,6 +616,7 @@ class TrimeMedia:
|
||||
if (item_details := self._api.item(item.guid)) is None:
|
||||
continue
|
||||
if remote:
|
||||
# FIXME 新版飞牛的壁纸无法直接在浏览器中访问
|
||||
img_host = self._playhost or self._api.host
|
||||
else:
|
||||
img_host = self._api.host
|
||||
@@ -604,3 +645,15 @@ class TrimeMedia:
|
||||
)
|
||||
else False
|
||||
)
|
||||
|
||||
def get_image_cookies(self, image_url: str):
|
||||
"""
|
||||
获得指定图片的Cookies
|
||||
"""
|
||||
if not self.is_authenticated():
|
||||
return None
|
||||
if not image_url or not SecurityUtils.is_safe_url(
|
||||
image_url, [self._api.host], strict=True
|
||||
):
|
||||
return None
|
||||
return {"Trim-MC-token": self._api.token}
|
||||
|
||||
@@ -139,7 +139,7 @@ class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]):
|
||||
logger.error(f"VoceChat消息处理发生错误:{str(err)}")
|
||||
return None
|
||||
|
||||
def post_message(self, message: Notification) -> None:
|
||||
def post_message(self, message: Notification, **kwargs) -> None:
|
||||
"""
|
||||
发送消息
|
||||
:param message: 消息内容
|
||||
|
||||
@@ -71,7 +71,7 @@ class WebPushModule(_ModuleBase, _MessageBase):
|
||||
def init_setting(self) -> Tuple[str, Union[str, bool]]:
|
||||
pass
|
||||
|
||||
def post_message(self, message: Notification) -> None:
|
||||
def post_message(self, message: Notification, **kwargs) -> None:
|
||||
"""
|
||||
发送消息
|
||||
:param message: 消息内容
|
||||
|
||||
@@ -184,7 +184,7 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]):
|
||||
logger.error(f"微信消息处理发生错误:{str(err)}")
|
||||
return None
|
||||
|
||||
def post_message(self, message: Notification) -> None:
|
||||
def post_message(self, message: Notification, **kwargs) -> None:
|
||||
"""
|
||||
发送消息
|
||||
:param message: 消息内容
|
||||
|
||||
@@ -46,12 +46,18 @@ class FileMonitorHandler(FileSystemEventHandler):
|
||||
self.callback = callback
|
||||
|
||||
def on_created(self, event: FileSystemEvent):
|
||||
self.callback.event_handler(event=event, text="创建", event_path=event.src_path,
|
||||
file_size=Path(event.src_path).stat().st_size)
|
||||
try:
|
||||
self.callback.event_handler(event=event, text="创建", event_path=event.src_path,
|
||||
file_size=Path(event.src_path).stat().st_size)
|
||||
except Exception as e:
|
||||
logger.error(f"on_created 异常: {e}")
|
||||
|
||||
def on_moved(self, event: FileSystemMovedEvent):
|
||||
self.callback.event_handler(event=event, text="移动", event_path=event.dest_path,
|
||||
file_size=Path(event.dest_path).stat().st_size)
|
||||
try:
|
||||
self.callback.event_handler(event=event, text="移动", event_path=event.dest_path,
|
||||
file_size=Path(event.dest_path).stat().st_size)
|
||||
except Exception as e:
|
||||
logger.error(f"on_moved 异常: {e}")
|
||||
|
||||
|
||||
class Monitor(metaclass=SingletonClass):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Dict, Tuple, Optional
|
||||
from typing import Any, List, Dict, Tuple, Optional, Type
|
||||
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import settings
|
||||
@@ -200,6 +200,20 @@ class _PluginBase(metaclass=ABCMeta):
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_agent_tools(self) -> List[Type]:
|
||||
"""
|
||||
获取插件智能体工具
|
||||
返回工具类列表,每个工具类必须继承自 MoviePilotTool
|
||||
[ToolClass1, ToolClass2, ...]
|
||||
|
||||
对工具类的要求:
|
||||
1、工具类必须继承自 app.agent.tools.base.MoviePilotTool
|
||||
2、工具类需要实现 run 方法(异步方法)
|
||||
3、工具类需要定义 name 和 description 属性
|
||||
4、工具类可以定义 args_schema 来指定输入参数模型
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def stop_service(self):
|
||||
"""
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import gc
|
||||
import inspect
|
||||
import multiprocessing
|
||||
import threading
|
||||
@@ -30,6 +31,7 @@ from app.helper.wallpaper import WallpaperHelper
|
||||
from app.log import logger
|
||||
from app.schemas import Notification, NotificationType, Workflow, ConfigChangeEventData
|
||||
from app.schemas.types import EventType, SystemConfigKey
|
||||
from app.utils.gc import get_memory_usage
|
||||
from app.utils.singleton import SingletonClass
|
||||
from app.utils.timer import TimerUtils
|
||||
|
||||
@@ -181,6 +183,11 @@ class Scheduler(metaclass=SingletonClass):
|
||||
"name": "订阅日历缓存",
|
||||
"func": SubscribeChain().cache_calendar,
|
||||
"running": False
|
||||
},
|
||||
"full_gc": {
|
||||
"name": "主动内存回收",
|
||||
"func": self.full_gc,
|
||||
"running": False
|
||||
}
|
||||
}
|
||||
|
||||
@@ -413,6 +420,19 @@ class Scheduler(metaclass=SingletonClass):
|
||||
}
|
||||
)
|
||||
|
||||
# 主动内存回收
|
||||
if settings.MEMORY_GC_INTERVAL:
|
||||
self._scheduler.add_job(
|
||||
self.start,
|
||||
"interval",
|
||||
id="full_gc",
|
||||
name="主动内存回收",
|
||||
minutes=settings.MEMORY_GC_INTERVAL,
|
||||
kwargs={
|
||||
'job_id': 'full_gc'
|
||||
}
|
||||
)
|
||||
|
||||
# 初始化工作流服务
|
||||
self.init_workflow_jobs()
|
||||
|
||||
@@ -747,6 +767,17 @@ class Scheduler(metaclass=SingletonClass):
|
||||
"""
|
||||
SchedulerChain().clear_cache()
|
||||
|
||||
@staticmethod
|
||||
def full_gc():
|
||||
"""
|
||||
主动内存回收
|
||||
"""
|
||||
memory_before = get_memory_usage()
|
||||
collected = gc.collect()
|
||||
memory_after = get_memory_usage()
|
||||
memory_freed = memory_before - memory_after
|
||||
logger.info(f"主动内存回收完成,回收对象数: {collected},释放内存: {memory_freed:.2f} MB")
|
||||
|
||||
def user_auth(self):
|
||||
"""
|
||||
用户认证检查
|
||||
|
||||
58
app/schemas/agent.py
Normal file
58
app/schemas/agent.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""AI智能体相关数据模型"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Any
|
||||
from pydantic import BaseModel, Field, ConfigDict, field_serializer
|
||||
|
||||
|
||||
class ConversationMemory(BaseModel):
|
||||
"""对话记忆模型"""
|
||||
|
||||
session_id: str = Field(description="会话ID")
|
||||
user_id: Optional[str] = Field(default=None, description="用户ID")
|
||||
title: Optional[str] = Field(default=None, description="会话标题")
|
||||
messages: List[Dict[str, Any]] = Field(default_factory=list, description="消息列表")
|
||||
context: Dict[str, Any] = Field(default_factory=dict, description="会话上下文")
|
||||
created_at: datetime = Field(default_factory=datetime.now, description="创建时间")
|
||||
updated_at: datetime = Field(default_factory=datetime.now, description="更新时间")
|
||||
|
||||
model_config = ConfigDict()
|
||||
|
||||
@field_serializer('created_at', 'updated_at', when_used='json')
|
||||
def serialize_datetime(self, value: datetime) -> str:
|
||||
return value.isoformat()
|
||||
|
||||
|
||||
class AgentState(BaseModel):
|
||||
"""AI智能体状态模型"""
|
||||
|
||||
session_id: str = Field(description="会话ID")
|
||||
current_task: Optional[str] = Field(default=None, description="当前任务")
|
||||
is_thinking: bool = Field(default=False, description="是否正在思考")
|
||||
last_activity: datetime = Field(default_factory=datetime.now, description="最后活动时间")
|
||||
|
||||
model_config = ConfigDict()
|
||||
|
||||
@field_serializer('last_activity', when_used='json')
|
||||
def serialize_datetime(self, value: datetime) -> str:
|
||||
return value.isoformat()
|
||||
|
||||
|
||||
class UserMessage(BaseModel):
|
||||
"""用户消息模型"""
|
||||
|
||||
session_id: str = Field(description="会话ID")
|
||||
content: str = Field(description="消息内容")
|
||||
user_id: Optional[str] = Field(default=None, description="用户ID")
|
||||
channel: Optional[str] = Field(default=None, description="消息渠道")
|
||||
source: Optional[str] = Field(default=None, description="消息来源")
|
||||
|
||||
|
||||
class ToolResult(BaseModel):
|
||||
"""工具执行结果模型"""
|
||||
|
||||
session_id: str = Field(description="会话ID")
|
||||
call_id: str = Field(description="调用ID")
|
||||
success: bool = Field(description="是否成功")
|
||||
result: Optional[str] = Field(default=None, description="执行结果")
|
||||
error: Optional[str] = Field(default=None, description="错误信息")
|
||||
@@ -86,7 +86,7 @@ class MediaInfo(BaseModel):
|
||||
# IMDB ID
|
||||
imdb_id: Optional[str] = None
|
||||
# TVDB ID
|
||||
tvdb_id: Optional[str] = None
|
||||
tvdb_id: Optional[int] = None
|
||||
# 豆瓣ID
|
||||
douban_id: Optional[str] = None
|
||||
# Bangumi ID
|
||||
@@ -158,6 +158,8 @@ class MediaInfo(BaseModel):
|
||||
production_countries: Optional[list] = Field(default_factory=list)
|
||||
# 语种
|
||||
spoken_languages: Optional[list] = Field(default_factory=list)
|
||||
# 所有发行日期
|
||||
release_dates: list = Field(default_factory=list)
|
||||
# 状态
|
||||
status: Optional[str] = None
|
||||
# 标签
|
||||
@@ -167,7 +169,7 @@ class MediaInfo(BaseModel):
|
||||
# 评价数量
|
||||
vote_count: Optional[int] = 0
|
||||
# 流行度
|
||||
popularity: Optional[int] = 0
|
||||
popularity: Optional[float] = 0.0
|
||||
# 时长
|
||||
runtime: Optional[int] = None
|
||||
# 下一集
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any, List, Set, Callable
|
||||
|
||||
from pydantic import BaseModel, Field, root_validator
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from app.schemas.message import MessageChannel
|
||||
from app.schemas.file import FileItem
|
||||
@@ -68,7 +68,8 @@ class AuthCredentials(ChainEventData):
|
||||
channel: Optional[str] = Field(default=None, description="认证渠道")
|
||||
service: Optional[str] = Field(default=None, description="服务名称")
|
||||
|
||||
@root_validator(pre=True)
|
||||
@model_validator(mode='before')
|
||||
@classmethod
|
||||
def check_fields_based_on_grant_type(cls, values): # noqa
|
||||
grant_type = values.get("grant_type")
|
||||
if not grant_type:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Optional, Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class DownloadHistory(BaseModel):
|
||||
@@ -51,8 +51,7 @@ class DownloadHistory(BaseModel):
|
||||
# 自定义剧集组
|
||||
episode_group: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class TransferHistory(BaseModel):
|
||||
@@ -97,5 +96,4 @@ class TransferHistory(BaseModel):
|
||||
# 日期
|
||||
date: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user