mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-15 23:16:45 +00:00
Compare commits
77 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
30059eff4f | ||
|
|
bc289b48c8 | ||
|
|
067d8b99b8 | ||
|
|
00a6a9c42d | ||
|
|
070425d446 | ||
|
|
7405883444 | ||
|
|
66959937ed | ||
|
|
e431efbcba | ||
|
|
ba00baa5a0 | ||
|
|
0fb5d4a164 | ||
|
|
1ac717b67f | ||
|
|
273cbd447e | ||
|
|
cee41567a2 | ||
|
|
1aae5eb1a6 | ||
|
|
28a4c81aff | ||
|
|
5e077cd64d | ||
|
|
e3f957a59b | ||
|
|
55c62a3ab5 | ||
|
|
22e7eef1bd | ||
|
|
d6524907f3 | ||
|
|
357db334cd | ||
|
|
f8bed3909b | ||
|
|
182bbdde91 | ||
|
|
2c70f990c2 | ||
|
|
0b01a6aa91 | ||
|
|
e557dffbc6 | ||
|
|
7f33b0b1b8 | ||
|
|
41ddf77a5b | ||
|
|
8c657ce41d | ||
|
|
3ff3b9ed4a | ||
|
|
ef43419ecd | ||
|
|
2ca375c214 | ||
|
|
cbd45c1d0f | ||
|
|
2592ea3464 | ||
|
|
73ac97cd96 | ||
|
|
e014663e97 | ||
|
|
58592e961f | ||
|
|
9a99b9ce82 | ||
|
|
8c6dca1751 | ||
|
|
cf488d5f5f | ||
|
|
515584d34c | ||
|
|
fb2becc7f2 | ||
|
|
0f8ceb0fac | ||
|
|
a70bf18770 | ||
|
|
2de83c44ab | ||
|
|
7b99f09810 | ||
|
|
6b4ba8bfad | ||
|
|
0c6cfc5020 | ||
|
|
abd9733e7f | ||
|
|
98c3ae5e76 | ||
|
|
bb5a657469 | ||
|
|
7797532350 | ||
|
|
c3a5106adc | ||
|
|
c5fd935dd0 | ||
|
|
ec375a19ae | ||
|
|
51e940617c | ||
|
|
58ec8bd437 | ||
|
|
a096395086 | ||
|
|
4bd08bd915 | ||
|
|
2c849cfa7a | ||
|
|
501d530d1d | ||
|
|
91fc4327f4 | ||
|
|
8d56c67079 | ||
|
|
e52d43458e | ||
|
|
9b125bf9b0 | ||
|
|
0716c65269 | ||
|
|
ba3ce4f1b5 | ||
|
|
07f72b0cdc | ||
|
|
bda19df87f | ||
|
|
5d82fae2b0 | ||
|
|
0813b87221 | ||
|
|
961ecfc720 | ||
|
|
81f30ef25a | ||
|
|
140b0d3df2 | ||
|
|
b3d69d7de4 | ||
|
|
8e65564fb8 | ||
|
|
06ce9bd4de |
@@ -7,7 +7,7 @@ from langchain.agents import AgentExecutor, create_openai_tools_agent
|
||||
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langchain_community.callbacks import get_openai_callback
|
||||
from langchain_core.chat_history import InMemoryChatMessageHistory
|
||||
from langchain_core.messages import HumanMessage, AIMessage, ToolCall
|
||||
from langchain_core.messages import HumanMessage, AIMessage, ToolCall, ToolMessage, SystemMessage
|
||||
from langchain_core.runnables.history import RunnableWithMessageHistory
|
||||
|
||||
from app.agent.callback import StreamingCallbackHandler
|
||||
@@ -56,9 +56,6 @@ class MoviePilotAgent:
|
||||
# 工具
|
||||
self.tools = self._initialize_tools()
|
||||
|
||||
# 会话存储
|
||||
self.session_store = self._initialize_session_store()
|
||||
|
||||
# 提示词模板
|
||||
self.prompt = self._initialize_prompt()
|
||||
|
||||
@@ -127,7 +124,8 @@ class MoviePilotAgent:
|
||||
channel=self.channel,
|
||||
source=self.source,
|
||||
username=self.username,
|
||||
callback_handler=self.callback_handler
|
||||
callback_handler=self.callback_handler,
|
||||
memory_mananger=self.memory_manager
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -137,34 +135,36 @@ class MoviePilotAgent:
|
||||
|
||||
def get_session_history(self, session_id: str) -> InMemoryChatMessageHistory:
|
||||
"""获取会话历史"""
|
||||
if session_id not in self.session_store:
|
||||
chat_history = InMemoryChatMessageHistory()
|
||||
messages: List[dict] = self.memory_manager.get_recent_messages_for_agent(
|
||||
session_id=session_id,
|
||||
user_id=self.user_id
|
||||
)
|
||||
if messages:
|
||||
for msg in messages:
|
||||
if msg.get("role") == "user":
|
||||
chat_history.add_user_message(HumanMessage(content=msg.get("content", "")))
|
||||
elif msg.get("role") == "agent":
|
||||
chat_history.add_ai_message(AIMessage(content=msg.get("content", "")))
|
||||
elif msg.get("role") == "tool_call":
|
||||
metadata = msg.get("metadata", {})
|
||||
chat_history.add_ai_message(AIMessage(
|
||||
chat_history = InMemoryChatMessageHistory()
|
||||
messages: List[dict] = self.memory_manager.get_recent_messages_for_agent(
|
||||
session_id=session_id,
|
||||
user_id=self.user_id
|
||||
)
|
||||
if messages:
|
||||
for msg in messages:
|
||||
if msg.get("role") == "user":
|
||||
chat_history.add_message(HumanMessage(content=msg.get("content", "")))
|
||||
elif msg.get("role") == "agent":
|
||||
chat_history.add_message(AIMessage(content=msg.get("content", "")))
|
||||
elif msg.get("role") == "tool_call":
|
||||
metadata = msg.get("metadata", {})
|
||||
chat_history.add_message(
|
||||
AIMessage(
|
||||
content=msg.get("content", ""),
|
||||
tool_calls=[ToolCall(
|
||||
id=metadata.get("call_id"),
|
||||
name=metadata.get("tool_name"),
|
||||
args=metadata.get("parameters"),
|
||||
)]
|
||||
))
|
||||
elif msg.get("role") == "tool_result":
|
||||
chat_history.add_ai_message(AIMessage(content=msg.get("content", "")))
|
||||
elif msg.get("role") == "system":
|
||||
chat_history.add_ai_message(AIMessage(content=msg.get("content", "")))
|
||||
self.session_store[session_id] = chat_history
|
||||
return self.session_store[session_id]
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
id=metadata.get("call_id"),
|
||||
name=metadata.get("tool_name"),
|
||||
args=metadata.get("parameters"),
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
elif msg.get("role") == "tool_result":
|
||||
chat_history.add_message(ToolMessage(content=msg.get("content", "")))
|
||||
elif msg.get("role") == "system":
|
||||
chat_history.add_message(SystemMessage(content=msg.get("content", "")))
|
||||
return chat_history
|
||||
|
||||
@staticmethod
|
||||
def _initialize_prompt() -> ChatPromptTemplate:
|
||||
@@ -306,8 +306,6 @@ class MoviePilotAgent:
|
||||
|
||||
async def cleanup(self):
|
||||
"""清理智能体资源"""
|
||||
if self.session_id in self.session_store:
|
||||
del self.session_store[self.session_id]
|
||||
logger.info(f"MoviePilot智能体已清理: session_id={self.session_id}")
|
||||
|
||||
|
||||
|
||||
@@ -45,17 +45,27 @@ class ConversationMemoryManager:
|
||||
|
||||
logger.info("对话记忆管理器已关闭")
|
||||
|
||||
@staticmethod
|
||||
def get_memory_key(session_id: str, user_id: str):
|
||||
"""计算内存Key"""
|
||||
return f"{user_id}:{session_id}" if user_id else session_id
|
||||
|
||||
@staticmethod
|
||||
def get_redis_key(session_id: str, user_id: str):
|
||||
"""计算Redis Key"""
|
||||
return f"agent_memory:{user_id}:{session_id}" if user_id else f"agent_memory:{session_id}"
|
||||
|
||||
async def get_memory(self, session_id: str, user_id: str) -> ConversationMemory:
|
||||
"""获取会话记忆"""
|
||||
# 首先检查缓存
|
||||
cache_key = f"{user_id}:{session_id}" if user_id else session_id
|
||||
cache_key = self.get_memory_key(session_id, user_id)
|
||||
if cache_key in self.memory_cache:
|
||||
return self.memory_cache[cache_key]
|
||||
|
||||
# 尝试从Redis加载
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
try:
|
||||
redis_key = f"agent_memory:{user_id}:{session_id}" if user_id else f"agent_memory:{session_id}"
|
||||
redis_key = self.get_redis_key(session_id, user_id)
|
||||
memory_data = await self.redis_helper.get(redis_key, region="AI_AGENT")
|
||||
if memory_data:
|
||||
memory_dict = json.loads(memory_data) if isinstance(memory_data, str) else memory_data
|
||||
@@ -180,15 +190,13 @@ class ConversationMemoryManager:
|
||||
|
||||
如果消息Token数量超过模型最大上下文长度的阀值,会自动进行摘要裁剪
|
||||
"""
|
||||
cache_key = f"{user_id}:{session_id}" if user_id else session_id
|
||||
cache_key = self.get_memory_key(session_id, user_id)
|
||||
memory = self.memory_cache.get(cache_key)
|
||||
if not memory:
|
||||
return []
|
||||
|
||||
# 获取所有消息
|
||||
messages = memory.messages
|
||||
|
||||
return messages
|
||||
return memory.messages
|
||||
|
||||
async def get_recent_messages(
|
||||
self,
|
||||
@@ -218,7 +226,7 @@ class ConversationMemoryManager:
|
||||
del self.memory_cache[cache_key]
|
||||
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
redis_key = f"agent_memory:{user_id}:{session_id}" if user_id else f"agent_memory:{session_id}"
|
||||
redis_key = self.get_redis_key(session_id, user_id)
|
||||
await self.redis_helper.delete(redis_key, region="AI_AGENT")
|
||||
|
||||
logger.info(f"会话记忆已清空: session_id={session_id}, user_id={user_id}")
|
||||
@@ -229,14 +237,14 @@ class ConversationMemoryManager:
|
||||
Redis中的记忆会自动通过TTL机制过期,无需手动清理
|
||||
"""
|
||||
# 更新内存缓存
|
||||
cache_key = f"{memory.user_id}:{memory.session_id}" if memory.user_id else memory.session_id
|
||||
cache_key = self.get_memory_key(memory.session_id, memory.user_id)
|
||||
self.memory_cache[cache_key] = memory
|
||||
|
||||
# 保存到Redis,设置TTL自动过期
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
try:
|
||||
memory_dict = memory.model_dump()
|
||||
redis_key = f"agent_memory:{memory.user_id}:{memory.session_id}" if memory.user_id else f"agent_memory:{memory.session_id}"
|
||||
redis_key = self.get_redis_key(memory.session_id, memory.user_id)
|
||||
ttl = int(timedelta(days=settings.LLM_REDIS_MEMORY_RETENTION_DAYS).total_seconds())
|
||||
await self.redis_helper.set(
|
||||
redis_key,
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
"""MoviePilot工具基类"""
|
||||
import json
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Callable, Any, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain.tools import BaseTool
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from app.agent import StreamingCallbackHandler
|
||||
from app.agent import StreamingCallbackHandler, ConversationMemoryManager
|
||||
from app.chain import ChainBase
|
||||
from app.log import logger
|
||||
from app.schemas import Notification
|
||||
@@ -24,6 +25,7 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
_source: str = PrivateAttr(default=None)
|
||||
_username: str = PrivateAttr(default=None)
|
||||
_callback_handler: StreamingCallbackHandler = PrivateAttr(default=None)
|
||||
_memory_manager: ConversationMemoryManager = PrivateAttr(default=None)
|
||||
|
||||
def __init__(self, session_id: str, user_id: str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
@@ -35,24 +37,53 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
|
||||
async def _arun(self, **kwargs) -> str:
|
||||
"""异步运行工具"""
|
||||
# 发送运行工具前的消息
|
||||
# 发送和记忆工具调用前的信息
|
||||
agent_message = await self._callback_handler.get_message()
|
||||
if agent_message:
|
||||
# 发送消息
|
||||
await self.send_tool_message(agent_message, title="MoviePilot助手")
|
||||
# 发送执行工具说明
|
||||
# 优先使用工具自定义的提示消息,如果没有则使用 explanation
|
||||
|
||||
# 记忆工具调用
|
||||
await self._memory_manager.add_memory(
|
||||
session_id=self._session_id,
|
||||
user_id=self._user_id,
|
||||
role="tool_call",
|
||||
content=agent_message,
|
||||
metadata={
|
||||
"call_id": self.__class__.__name__,
|
||||
"tool_name": self.__class__.__name__,
|
||||
"parameters": kwargs
|
||||
}
|
||||
)
|
||||
|
||||
# 发送执行工具说明,优先使用工具自定义的提示消息,如果没有则使用 explanation
|
||||
tool_message = self.get_tool_message(**kwargs)
|
||||
if not tool_message:
|
||||
explanation = kwargs.get("explanation")
|
||||
if explanation:
|
||||
tool_message = explanation
|
||||
|
||||
if tool_message:
|
||||
formatted_message = f"⚙️ => {tool_message}"
|
||||
await self.send_tool_message(formatted_message)
|
||||
|
||||
logger.debug(f'Executing tool {self.name} with args: {kwargs}')
|
||||
result = await self.run(**kwargs)
|
||||
logger.debug(f'Tool {self.name} executed with result: {result}')
|
||||
|
||||
# 记忆工具调用结果
|
||||
if isinstance(result, str):
|
||||
formated_result = result
|
||||
elif isinstance(result, int, float):
|
||||
formated_result = str(result)
|
||||
else:
|
||||
formated_result = json.dumps(result, ensure_ascii=False, indent=2)
|
||||
await self._memory_manager.add_memory(
|
||||
session_id=self._session_id,
|
||||
user_id=self._user_id,
|
||||
role="tool_result",
|
||||
content=formated_result
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
@@ -84,6 +115,10 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
"""设置回调处理器"""
|
||||
self._callback_handler = callback_handler
|
||||
|
||||
def set_memory_manager(self, memory_manager: ConversationMemoryManager):
|
||||
"""设置记忆客理器"""
|
||||
self._memory_manager = memory_manager
|
||||
|
||||
async def send_tool_message(self, message: str, title: str = ""):
|
||||
"""发送工具消息"""
|
||||
await ToolChain().async_post_message(
|
||||
|
||||
@@ -51,7 +51,7 @@ class MoviePilotToolFactory:
|
||||
@staticmethod
|
||||
def create_tools(session_id: str, user_id: str,
|
||||
channel: str = None, source: str = None, username: str = None,
|
||||
callback_handler: Callable = None) -> List[MoviePilotTool]:
|
||||
callback_handler: Callable = None, memory_mananger: Callable = None) -> List[MoviePilotTool]:
|
||||
"""创建MoviePilot工具列表"""
|
||||
tools = []
|
||||
tool_definitions = [
|
||||
@@ -102,6 +102,7 @@ class MoviePilotToolFactory:
|
||||
)
|
||||
tool.set_message_attr(channel=channel, source=source, username=username)
|
||||
tool.set_callback_handler(callback_handler=callback_handler)
|
||||
tool.set_memory_manager(memory_manager=memory_mananger)
|
||||
tools.append(tool)
|
||||
|
||||
# 加载插件提供的工具
|
||||
@@ -124,6 +125,7 @@ class MoviePilotToolFactory:
|
||||
)
|
||||
tool.set_message_attr(channel=channel, source=source, username=username)
|
||||
tool.set_callback_handler(callback_handler=callback_handler)
|
||||
tool.set_memory_manager(memory_manager=memory_mananger)
|
||||
tools.append(tool)
|
||||
plugin_tools_count += 1
|
||||
logger.debug(f"成功加载插件 {plugin_name}({plugin_id}) 的工具: {ToolClass.__name__}")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""查询下载工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
from typing import Optional, Type, List, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -9,6 +9,8 @@ from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.download import DownloadChain
|
||||
from app.db.downloadhistory_oper import DownloadHistoryOper
|
||||
from app.log import logger
|
||||
from app.schemas import TransferTorrent, DownloadingTorrent
|
||||
from app.schemas.types import TorrentStatus
|
||||
|
||||
|
||||
class QueryDownloadTasksInput(BaseModel):
|
||||
@@ -27,6 +29,27 @@ class QueryDownloadTasksTool(MoviePilotTool):
|
||||
description: str = "Query download status and list download tasks. Can query all active downloads, or search for specific tasks by hash or title. Shows download progress, completion status, and task details from configured downloaders."
|
||||
args_schema: Type[BaseModel] = QueryDownloadTasksInput
|
||||
|
||||
def _get_all_torrents(self, download_chain: DownloadChain, downloader: Optional[str] = None) -> List[Union[TransferTorrent, DownloadingTorrent]]:
|
||||
"""
|
||||
查询所有状态的任务(包括下载中和已完成的任务)
|
||||
"""
|
||||
all_torrents = []
|
||||
# 查询正在下载的任务
|
||||
downloading_torrents = download_chain.list_torrents(
|
||||
downloader=downloader,
|
||||
status=TorrentStatus.DOWNLOADING
|
||||
) or []
|
||||
all_torrents.extend(downloading_torrents)
|
||||
|
||||
# 查询已完成的任务(可转移状态)
|
||||
transfer_torrents = download_chain.list_torrents(
|
||||
downloader=downloader,
|
||||
status=TorrentStatus.TRANSFER
|
||||
) or []
|
||||
all_torrents.extend(transfer_torrents)
|
||||
|
||||
return all_torrents
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
downloader = kwargs.get("downloader")
|
||||
@@ -60,7 +83,7 @@ class QueryDownloadTasksTool(MoviePilotTool):
|
||||
|
||||
# 如果提供了hash,直接查询该hash的任务(不限制状态)
|
||||
if hash:
|
||||
torrents = download_chain.list_torrents(downloader=downloader, hashs=[hash])
|
||||
torrents = download_chain.list_torrents(downloader=downloader, hashs=[hash]) or []
|
||||
if not torrents:
|
||||
return f"未找到hash为 {hash} 的下载任务(该任务可能已完成、已删除或不存在)"
|
||||
# 转换为DownloadingTorrent格式
|
||||
@@ -84,14 +107,25 @@ class QueryDownloadTasksTool(MoviePilotTool):
|
||||
elif title:
|
||||
# 如果提供了title,查询所有任务并搜索匹配的标题
|
||||
# 查询所有状态的任务
|
||||
all_torrents = download_chain.list_torrents(downloader=downloader) or []
|
||||
all_torrents = self._get_all_torrents(download_chain, downloader)
|
||||
filtered_downloads = []
|
||||
title_lower = title.lower()
|
||||
for torrent in all_torrents:
|
||||
# 检查标题或名称是否匹配
|
||||
if (title.lower() in (torrent.title or "").lower()) or \
|
||||
(title.lower() in (torrent.name or "").lower()):
|
||||
# 获取下载历史信息
|
||||
history = DownloadHistoryOper().get_by_hash(torrent.hash)
|
||||
# 获取下载历史信息
|
||||
history = DownloadHistoryOper().get_by_hash(torrent.hash)
|
||||
|
||||
# 检查标题或名称是否匹配(包括下载历史中的标题)
|
||||
matched = False
|
||||
# 检查torrent的title和name字段
|
||||
if (title_lower in (torrent.title or "").lower()) or \
|
||||
(title_lower in (torrent.name or "").lower()):
|
||||
matched = True
|
||||
# 检查下载历史中的标题
|
||||
if history and history.title:
|
||||
if title_lower in history.title.lower():
|
||||
matched = True
|
||||
|
||||
if matched:
|
||||
if history:
|
||||
torrent.media = {
|
||||
"tmdbid": history.tmdbid,
|
||||
@@ -110,7 +144,7 @@ class QueryDownloadTasksTool(MoviePilotTool):
|
||||
# 根据status决定查询方式
|
||||
if status == "downloading":
|
||||
# 如果status为下载中,使用downloading方法
|
||||
downloads = download_chain.downloading(name=downloader)
|
||||
downloads = download_chain.downloading(name=downloader) or []
|
||||
filtered_downloads = []
|
||||
for dl in downloads:
|
||||
if downloader and dl.downloader != downloader:
|
||||
@@ -119,7 +153,7 @@ class QueryDownloadTasksTool(MoviePilotTool):
|
||||
else:
|
||||
# 其他状态(completed、paused、all),使用list_torrents查询所有任务
|
||||
# 查询所有状态的任务
|
||||
all_torrents = download_chain.list_torrents(downloader=downloader) or []
|
||||
all_torrents = self._get_all_torrents(download_chain, downloader)
|
||||
filtered_downloads = []
|
||||
for torrent in all_torrents:
|
||||
if downloader and torrent.downloader != downloader:
|
||||
|
||||
@@ -3,8 +3,10 @@
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.agent import ConversationMemoryManager
|
||||
from app.agent.tools.factory import MoviePilotToolFactory
|
||||
from app.log import logger
|
||||
|
||||
@@ -21,7 +23,7 @@ class ToolDefinition:
|
||||
class MoviePilotToolsManager:
|
||||
"""MoviePilot工具管理器(用于HTTP API)"""
|
||||
|
||||
def __init__(self, user_id: str = "api_user", session_id: str = "api_session"):
|
||||
def __init__(self, user_id: str = "api_user", session_id: str = uuid.uuid4()):
|
||||
"""
|
||||
初始化工具管理器
|
||||
|
||||
@@ -32,6 +34,7 @@ class MoviePilotToolsManager:
|
||||
self.user_id = user_id
|
||||
self.session_id = session_id
|
||||
self.tools: List[Any] = []
|
||||
self.memory_manager = ConversationMemoryManager()
|
||||
self._load_tools()
|
||||
|
||||
def _load_tools(self):
|
||||
@@ -44,7 +47,8 @@ class MoviePilotToolsManager:
|
||||
channel=None,
|
||||
source="api",
|
||||
username="API Client",
|
||||
callback_handler=None
|
||||
callback_handler=None,
|
||||
memory_mananger=None,
|
||||
)
|
||||
logger.info(f"成功加载 {len(self.tools)} 个工具")
|
||||
except Exception as e:
|
||||
@@ -96,6 +100,73 @@ class MoviePilotToolsManager:
|
||||
return tool
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _normalize_arguments(tool_instance: Any, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
根据工具的参数schema规范化参数类型
|
||||
|
||||
Args:
|
||||
tool_instance: 工具实例
|
||||
arguments: 原始参数
|
||||
|
||||
Returns:
|
||||
规范化后的参数
|
||||
"""
|
||||
# 获取工具的参数schema
|
||||
args_schema = getattr(tool_instance, 'args_schema', None)
|
||||
if not args_schema:
|
||||
return arguments
|
||||
|
||||
# 获取schema中的字段定义
|
||||
try:
|
||||
schema = args_schema.model_json_schema()
|
||||
properties = schema.get("properties", {})
|
||||
except Exception as e:
|
||||
logger.warning(f"获取工具schema失败: {e}")
|
||||
return arguments
|
||||
|
||||
# 规范化参数
|
||||
normalized = {}
|
||||
for key, value in arguments.items():
|
||||
if key not in properties:
|
||||
# 参数不在schema中,保持原样
|
||||
normalized[key] = value
|
||||
continue
|
||||
|
||||
field_info = properties[key]
|
||||
field_type = field_info.get("type")
|
||||
|
||||
# 处理 anyOf 类型(例如 Optional[int] 会生成 anyOf)
|
||||
any_of = field_info.get("anyOf")
|
||||
if any_of and not field_type:
|
||||
# 从 anyOf 中提取实际类型
|
||||
for type_option in any_of:
|
||||
if "type" in type_option and type_option["type"] != "null":
|
||||
field_type = type_option["type"]
|
||||
break
|
||||
|
||||
# 根据类型进行转换
|
||||
if field_type == "integer" and isinstance(value, str):
|
||||
try:
|
||||
normalized[key] = int(value)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(f"无法将参数 {key}='{value}' 转换为整数,保持原值")
|
||||
normalized[key] = value
|
||||
elif field_type == "number" and isinstance(value, str):
|
||||
try:
|
||||
normalized[key] = float(value)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(f"无法将参数 {key}='{value}' 转换为浮点数,保持原值")
|
||||
normalized[key] = value
|
||||
elif field_type == "boolean" and isinstance(value, str):
|
||||
# 转换字符串为布尔值
|
||||
normalized[key] = value.lower() in ("true", "1", "yes", "on")
|
||||
else:
|
||||
# 其他类型保持原样
|
||||
normalized[key] = value
|
||||
|
||||
return normalized
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> str:
|
||||
"""
|
||||
调用工具
|
||||
@@ -116,14 +187,21 @@ class MoviePilotToolsManager:
|
||||
return error_msg
|
||||
|
||||
try:
|
||||
# 规范化参数类型
|
||||
normalized_arguments = self._normalize_arguments(tool_instance, arguments)
|
||||
|
||||
# 调用工具的run方法
|
||||
result = await tool_instance.run(**arguments)
|
||||
result = await tool_instance.run(**normalized_arguments)
|
||||
|
||||
# 确保返回字符串
|
||||
if isinstance(result, str):
|
||||
return result
|
||||
formated_result = result
|
||||
elif isinstance(result, int, float):
|
||||
formated_result = str(result)
|
||||
else:
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
formated_result = json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
return formated_result
|
||||
except Exception as e:
|
||||
logger.error(f"调用工具 {tool_name} 时发生错误: {e}", exc_info=True)
|
||||
error_msg = json.dumps({
|
||||
|
||||
@@ -2,11 +2,12 @@ from fastapi import APIRouter
|
||||
|
||||
from app.api.endpoints import login, user, webhook, message, site, subscribe, \
|
||||
media, douban, search, plugin, tmdb, history, system, download, dashboard, \
|
||||
transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent, mcp
|
||||
transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent, mcp, mfa
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(login.router, prefix="/login", tags=["login"])
|
||||
api_router.include_router(user.router, prefix="/user", tags=["user"])
|
||||
api_router.include_router(mfa.router, prefix="/mfa", tags=["mfa"])
|
||||
api_router.include_router(site.router, prefix="/site", tags=["site"])
|
||||
api_router.include_router(message.router, prefix="/message", tags=["message"])
|
||||
api_router.include_router(webhook.router, prefix="/webhook", tags=["webhook"])
|
||||
|
||||
@@ -29,6 +29,13 @@ def login_access_token(
|
||||
mfa_code=otp_password)
|
||||
|
||||
if not success:
|
||||
# 如果是需要MFA验证,返回特殊标识
|
||||
if user_or_message == "MFA_REQUIRED":
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="需要双重验证,请提供验证码或使用通行密钥",
|
||||
headers={"X-MFA-Required": "true"}
|
||||
)
|
||||
raise HTTPException(status_code=401, detail=user_or_message)
|
||||
|
||||
# 用户等级
|
||||
@@ -50,7 +57,7 @@ def login_access_token(
|
||||
avatar=user_or_message.avatar,
|
||||
level=level,
|
||||
permissions=user_or_message.permissions or {},
|
||||
widzard=show_wizard
|
||||
wizard=show_wizard
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -2,43 +2,251 @@
|
||||
通过HTTP API暴露MoviePilot的智能体工具功能
|
||||
"""
|
||||
|
||||
from typing import List, Any, Dict, Annotated
|
||||
from typing import List, Any, Dict, Annotated, Union
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
|
||||
from app import schemas
|
||||
from app.agent.tools.manager import MoviePilotToolsManager
|
||||
from app.core.security import verify_apikey
|
||||
from app.log import logger
|
||||
|
||||
# 导入版本号
|
||||
try:
|
||||
from version import APP_VERSION
|
||||
except ImportError:
|
||||
APP_VERSION = "unknown"
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# 全局工具管理器实例(单例模式,按用户ID缓存)
|
||||
_tools_managers: Dict[str, MoviePilotToolsManager] = {}
|
||||
# MCP 协议版本
|
||||
MCP_PROTOCOL_VERSIONS = ["2025-11-25", "2025-06-18", "2024-11-05"]
|
||||
MCP_PROTOCOL_VERSION = MCP_PROTOCOL_VERSIONS[0] # 默认使用最新版本
|
||||
|
||||
|
||||
def get_tools_manager(user_id: str = "mcp_user", session_id: str = "mcp_session") -> MoviePilotToolsManager:
|
||||
def get_tools_manager() -> MoviePilotToolsManager:
|
||||
"""
|
||||
获取工具管理器实例(按用户ID缓存)
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
session_id: 会话ID
|
||||
|
||||
获取工具管理器实例
|
||||
|
||||
Returns:
|
||||
MoviePilotToolsManager实例
|
||||
"""
|
||||
global _tools_managers
|
||||
# 使用用户ID作为缓存键
|
||||
cache_key = f"{user_id}_{session_id}"
|
||||
if cache_key not in _tools_managers:
|
||||
_tools_managers[cache_key] = MoviePilotToolsManager(
|
||||
user_id=user_id,
|
||||
session_id=session_id
|
||||
)
|
||||
return _tools_managers[cache_key]
|
||||
return MoviePilotToolsManager()
|
||||
|
||||
|
||||
def create_jsonrpc_response(request_id: Union[str, int, None], result: Any) -> Dict[str, Any]:
|
||||
"""创建 JSON-RPC 成功响应"""
|
||||
response = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"result": result
|
||||
}
|
||||
return response
|
||||
|
||||
|
||||
def create_jsonrpc_error(request_id: Union[str, int, None], code: int, message: str, data: Any = None) -> Dict[str, Any]:
|
||||
"""创建 JSON-RPC 错误响应"""
|
||||
error = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"error": {
|
||||
"code": code,
|
||||
"message": message
|
||||
}
|
||||
}
|
||||
if data is not None:
|
||||
error["error"]["data"] = data
|
||||
return error
|
||||
|
||||
|
||||
# ==================== MCP JSON-RPC 端点 ====================
|
||||
|
||||
@router.post("", summary="MCP JSON-RPC 端点", response_model=None)
|
||||
async def mcp_jsonrpc(
|
||||
request: Request,
|
||||
_: Annotated[str, Depends(verify_apikey)] = None
|
||||
) -> Union[JSONResponse, Response]:
|
||||
"""
|
||||
MCP 标准 JSON-RPC 2.0 端点
|
||||
|
||||
处理所有 MCP 协议消息(初始化、工具列表、工具调用等)
|
||||
"""
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception as e:
|
||||
logger.error(f"解析请求体失败: {e}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_jsonrpc_error(None, -32700, "Parse error", str(e))
|
||||
)
|
||||
|
||||
# 验证 JSON-RPC 格式
|
||||
if not isinstance(body, dict) or body.get("jsonrpc") != "2.0":
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_jsonrpc_error(body.get("id"), -32600, "Invalid Request")
|
||||
)
|
||||
|
||||
method = body.get("method")
|
||||
params = body.get("params", {})
|
||||
request_id = body.get("id")
|
||||
|
||||
# 如果有 id,则为请求;没有 id 则为通知
|
||||
is_notification = request_id is None
|
||||
|
||||
try:
|
||||
# 处理初始化请求
|
||||
if method == "initialize":
|
||||
result = await handle_initialize(params)
|
||||
return JSONResponse(content=create_jsonrpc_response(request_id, result))
|
||||
|
||||
# 处理已初始化通知
|
||||
elif method == "notifications/initialized":
|
||||
if is_notification:
|
||||
return Response(status_code=204)
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={"error": "initialized must be a notification"}
|
||||
)
|
||||
|
||||
# 处理工具列表请求
|
||||
if method == "tools/list":
|
||||
result = await handle_tools_list()
|
||||
return JSONResponse(content=create_jsonrpc_response(request_id, result))
|
||||
|
||||
# 处理工具调用请求
|
||||
elif method == "tools/call":
|
||||
result = await handle_tools_call(params)
|
||||
return JSONResponse(content=create_jsonrpc_response(request_id, result))
|
||||
|
||||
# 处理 ping 请求
|
||||
elif method == "ping":
|
||||
return JSONResponse(content=create_jsonrpc_response(request_id, {}))
|
||||
|
||||
# 未知方法
|
||||
else:
|
||||
return JSONResponse(
|
||||
content=create_jsonrpc_error(request_id, -32601, f"Method not found: {method}")
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.warning(f"MCP 请求参数错误: {e}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_jsonrpc_error(request_id, -32602, "Invalid params", str(e))
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"处理 MCP 请求失败: {e}", exc_info=True)
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_jsonrpc_error(request_id, -32603, "Internal error", str(e))
|
||||
)
|
||||
|
||||
|
||||
async def handle_initialize(params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""处理初始化请求"""
|
||||
protocol_version = params.get("protocolVersion")
|
||||
client_info = params.get("clientInfo", {})
|
||||
|
||||
logger.info(f"MCP 初始化请求: 客户端={client_info.get('name')}, 协议版本={protocol_version}")
|
||||
|
||||
# 版本协商:选择客户端和服务器都支持的版本
|
||||
negotiated_version = MCP_PROTOCOL_VERSION
|
||||
if protocol_version in MCP_PROTOCOL_VERSIONS:
|
||||
# 客户端版本在支持列表中,使用客户端版本
|
||||
negotiated_version = protocol_version
|
||||
logger.info(f"使用客户端协议版本: {negotiated_version}")
|
||||
else:
|
||||
# 客户端版本不支持,使用服务器默认版本
|
||||
logger.warning(f"协议版本不匹配: 客户端={protocol_version}, 使用服务器版本={negotiated_version}")
|
||||
|
||||
return {
|
||||
"protocolVersion": negotiated_version,
|
||||
"capabilities": {
|
||||
"tools": {
|
||||
"listChanged": False # 暂不支持工具列表变更通知
|
||||
},
|
||||
"logging": {}
|
||||
},
|
||||
"serverInfo": {
|
||||
"name": "MoviePilot",
|
||||
"version": APP_VERSION,
|
||||
"description": "MoviePilot MCP Server - 电影自动化管理工具",
|
||||
},
|
||||
"instructions": "MoviePilot MCP 服务器,提供媒体管理、订阅、下载等工具。"
|
||||
}
|
||||
|
||||
|
||||
async def handle_tools_list() -> Dict[str, Any]:
|
||||
"""处理工具列表请求"""
|
||||
manager = get_tools_manager()
|
||||
tools = manager.list_tools()
|
||||
|
||||
# 转换为 MCP 工具格式
|
||||
mcp_tools = []
|
||||
for tool in tools:
|
||||
mcp_tool = {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"inputSchema": tool.input_schema
|
||||
}
|
||||
mcp_tools.append(mcp_tool)
|
||||
|
||||
return {
|
||||
"tools": mcp_tools
|
||||
}
|
||||
|
||||
|
||||
async def handle_tools_call(params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""处理工具调用请求"""
|
||||
tool_name = params.get("name")
|
||||
arguments = params.get("arguments", {})
|
||||
|
||||
if not tool_name:
|
||||
raise ValueError("Missing tool name")
|
||||
|
||||
manager = get_tools_manager()
|
||||
|
||||
try:
|
||||
result_text = await manager.call_tool(tool_name, arguments)
|
||||
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": result_text
|
||||
}
|
||||
]
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"工具调用失败: {tool_name}, 错误: {e}", exc_info=True)
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"错误: {str(e)}"
|
||||
}
|
||||
],
|
||||
"isError": True
|
||||
}
|
||||
|
||||
|
||||
@router.delete("", summary="终止 MCP 会话", response_model=None)
|
||||
async def delete_mcp_session(
|
||||
_: Annotated[str, Depends(verify_apikey)] = None
|
||||
) -> Union[JSONResponse, Response]:
|
||||
"""
|
||||
终止 MCP 会话(无状态模式下仅返回成功)
|
||||
"""
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
|
||||
|
||||
# ==================== 兼容的 RESTful API 端点 ====================
|
||||
|
||||
@router.get("/tools", summary="列出所有可用工具", response_model=List[Dict[str, Any]])
|
||||
async def list_tools(
|
||||
_: Annotated[str, Depends(verify_apikey)]
|
||||
@@ -72,7 +280,7 @@ async def list_tools(
|
||||
@router.post("/tools/call", summary="调用工具", response_model=schemas.ToolCallResponse)
|
||||
async def call_tool(
|
||||
request: schemas.ToolCallRequest,
|
||||
|
||||
_: Annotated[str, Depends(verify_apikey)] = None
|
||||
) -> Any:
|
||||
"""
|
||||
调用指定的工具
|
||||
|
||||
@@ -82,7 +82,7 @@ def exists(media_in: schemas.MediaInfo,
|
||||
mediainfo.from_dict(media_in.model_dump())
|
||||
existsinfo: schemas.ExistMediaInfo = MediaServerChain().media_exists(mediainfo=mediainfo)
|
||||
if not existsinfo:
|
||||
return []
|
||||
return {}
|
||||
if media_in.season:
|
||||
return {
|
||||
media_in.season: existsinfo.seasons.get(media_in.season) or []
|
||||
|
||||
463
app/api/endpoints/mfa.py
Normal file
463
app/api/endpoints/mfa.py
Normal file
@@ -0,0 +1,463 @@
|
||||
"""
|
||||
MFA (Multi-Factor Authentication) API 端点
|
||||
包含 OTP 和 PassKey 相关功能
|
||||
"""
|
||||
from datetime import timedelta
|
||||
from typing import Any, Annotated, Optional
|
||||
|
||||
from app.helper.sites import SitesHelper
|
||||
from fastapi import APIRouter, Depends, HTTPException, Body
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app import schemas
|
||||
from app.core import security
|
||||
from app.core.config import settings
|
||||
from app.db import get_async_db
|
||||
from app.db.models.passkey import PassKey
|
||||
from app.db.models.user import User
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.db.user_oper import get_current_active_user, get_current_active_user_async
|
||||
from app.helper.passkey import PassKeyHelper
|
||||
from app.log import logger
|
||||
from app.schemas.types import SystemConfigKey
|
||||
from app.utils.otp import OtpUtils
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# ==================== 请求模型 ====================
|
||||
|
||||
class OtpVerifyRequest(schemas.BaseModel):
|
||||
"""OTP验证请求"""
|
||||
uri: str
|
||||
otpPassword: str
|
||||
|
||||
class OtpDisableRequest(schemas.BaseModel):
|
||||
"""OTP禁用请求"""
|
||||
password: str
|
||||
|
||||
class PassKeyDeleteRequest(schemas.BaseModel):
|
||||
"""PassKey删除请求"""
|
||||
passkey_id: int
|
||||
password: str
|
||||
|
||||
# ==================== 通用 MFA 接口 ====================
|
||||
|
||||
@router.get('/status/{username}', summary='判断用户是否开启双重验证(MFA)', response_model=schemas.Response)
|
||||
async def mfa_status(username: str, db: AsyncSession = Depends(get_async_db)) -> Any:
|
||||
"""
|
||||
检查指定用户是否启用了任何双重验证方式(OTP 或 PassKey)
|
||||
"""
|
||||
user: User = await User.async_get_by_name(db, username)
|
||||
if not user:
|
||||
return schemas.Response(success=False)
|
||||
|
||||
# 检查是否启用了OTP
|
||||
has_otp = user.is_otp
|
||||
|
||||
# 检查是否有PassKey
|
||||
has_passkey = bool(await PassKey.async_get_by_user_id(db=db, user_id=user.id))
|
||||
|
||||
# 只要有任何一种验证方式,就需要双重验证
|
||||
return schemas.Response(success=(has_otp or has_passkey))
|
||||
|
||||
|
||||
# ==================== OTP 相关接口 ====================
|
||||
|
||||
@router.post('/otp/generate', summary='生成 OTP 验证 URI', response_model=schemas.Response)
|
||||
def otp_generate(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
) -> Any:
|
||||
"""生成 OTP 密钥及对应的 URI"""
|
||||
secret, uri = OtpUtils.generate_secret_key(current_user.name)
|
||||
return schemas.Response(success=secret != "", data={'secret': secret, 'uri': uri})
|
||||
|
||||
|
||||
@router.post('/otp/verify', summary='绑定并验证 OTP', response_model=schemas.Response)
|
||||
async def otp_verify(
|
||||
data: OtpVerifyRequest,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
current_user: User = Depends(get_current_active_user_async)
|
||||
) -> Any:
|
||||
"""验证用户输入的 OTP 码,验证通过后正式开启 OTP 验证"""
|
||||
if not OtpUtils.is_legal(data.uri, data.otpPassword):
|
||||
return schemas.Response(success=False, message="验证码错误")
|
||||
await current_user.async_update_otp_by_name(db, current_user.name, True, OtpUtils.get_secret(data.uri))
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.post('/otp/disable', summary='关闭当前用户的 OTP 验证', response_model=schemas.Response)
|
||||
async def otp_disable(
|
||||
data: OtpDisableRequest,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
current_user: User = Depends(get_current_active_user_async)
|
||||
) -> Any:
|
||||
"""关闭当前用户的 OTP 验证功能"""
|
||||
# 安全检查:如果存在 PassKey,不允许关闭 OTP
|
||||
has_passkey = bool(await PassKey.async_get_by_user_id(db=db, user_id=current_user.id))
|
||||
if has_passkey:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="您已注册通行密钥,为了防止域名配置变更导致无法登录,请先删除所有通行密钥再关闭 OTP 验证"
|
||||
)
|
||||
|
||||
# 验证密码
|
||||
if not security.verify_password(data.password, str(current_user.hashed_password)):
|
||||
return schemas.Response(success=False, message="密码错误")
|
||||
await current_user.async_update_otp_by_name(db, current_user.name, False, "")
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
# ==================== PassKey 相关接口 ====================
|
||||
|
||||
class PassKeyRegistrationStart(schemas.BaseModel):
|
||||
"""PassKey注册开始请求"""
|
||||
name: str = "通行密钥"
|
||||
|
||||
|
||||
class PassKeyRegistrationFinish(schemas.BaseModel):
|
||||
"""PassKey注册完成请求"""
|
||||
credential: dict
|
||||
challenge: str
|
||||
name: str = "通行密钥"
|
||||
|
||||
|
||||
class PassKeyAuthenticationStart(schemas.BaseModel):
|
||||
"""PassKey认证开始请求"""
|
||||
username: Optional[str] = None
|
||||
|
||||
|
||||
class PassKeyAuthenticationFinish(schemas.BaseModel):
|
||||
"""PassKey认证完成请求"""
|
||||
credential: dict
|
||||
challenge: str
|
||||
|
||||
|
||||
@router.post("/passkey/register/start", summary="开始注册 PassKey", response_model=schemas.Response)
|
||||
def passkey_register_start(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
) -> Any:
|
||||
"""开始注册 PassKey - 生成注册选项"""
|
||||
try:
|
||||
# 安全检查:必须先启用 OTP
|
||||
if not current_user.is_otp:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="为了确保在域名配置错误时仍能找回访问权限,请先启用 OTP 验证码再注册通行密钥"
|
||||
)
|
||||
|
||||
# 获取用户已有的PassKey
|
||||
existing_passkeys = PassKey.get_by_user_id(db=None, user_id=current_user.id)
|
||||
existing_credentials = [
|
||||
{
|
||||
'credential_id': pk.credential_id,
|
||||
'transports': pk.transports
|
||||
}
|
||||
for pk in existing_passkeys
|
||||
] if existing_passkeys else None
|
||||
|
||||
# 生成注册选项
|
||||
options_json, challenge = PassKeyHelper.generate_registration_options(
|
||||
user_id=current_user.id,
|
||||
username=current_user.name,
|
||||
display_name=current_user.settings.get('nickname') if current_user.settings else None,
|
||||
existing_credentials=existing_credentials
|
||||
)
|
||||
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
data={
|
||||
'options': options_json,
|
||||
'challenge': challenge
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"生成PassKey注册选项失败: {e}")
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message=f"生成注册选项失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/passkey/register/finish", summary="完成注册 PassKey", response_model=schemas.Response)
|
||||
def passkey_register_finish(
|
||||
passkey_req: PassKeyRegistrationFinish,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
) -> Any:
|
||||
"""完成注册 PassKey - 验证并保存凭证"""
|
||||
try:
|
||||
# 验证注册响应
|
||||
credential_id, public_key, sign_count, aaguid = PassKeyHelper.verify_registration_response(
|
||||
credential=passkey_req.credential,
|
||||
expected_challenge=passkey_req.challenge
|
||||
)
|
||||
|
||||
# 提取transports
|
||||
transports = None
|
||||
if 'response' in passkey_req.credential and 'transports' in passkey_req.credential['response']:
|
||||
transports = ','.join(passkey_req.credential['response']['transports'])
|
||||
|
||||
# 保存到数据库
|
||||
passkey = PassKey(
|
||||
user_id=current_user.id,
|
||||
credential_id=credential_id,
|
||||
public_key=public_key,
|
||||
sign_count=sign_count,
|
||||
name=passkey_req.name or "通行密钥",
|
||||
aaguid=aaguid,
|
||||
transports=transports
|
||||
)
|
||||
passkey.create()
|
||||
|
||||
logger.info(f"用户 {current_user.name} 成功注册PassKey: {passkey_req.name}")
|
||||
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
message="通行密钥注册成功"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"注册PassKey失败: {e}")
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message=f"注册失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/passkey/authenticate/start", summary="开始 PassKey 认证", response_model=schemas.Response)
|
||||
def passkey_authenticate_start(
|
||||
passkey_req: PassKeyAuthenticationStart = Body(...)
|
||||
) -> Any:
|
||||
"""开始 PassKey 认证 - 生成认证选项"""
|
||||
try:
|
||||
existing_credentials = None
|
||||
|
||||
# 如果指定了用户名,只允许该用户的PassKey
|
||||
if passkey_req.username:
|
||||
user = User.get_by_name(db=None, name=passkey_req.username)
|
||||
if not user:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="用户不存在"
|
||||
)
|
||||
|
||||
existing_passkeys = PassKey.get_by_user_id(db=None, user_id=user.id)
|
||||
if not existing_passkeys:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="该用户未注册通行密钥"
|
||||
)
|
||||
|
||||
existing_credentials = [
|
||||
{
|
||||
'credential_id': pk.credential_id,
|
||||
'transports': pk.transports
|
||||
}
|
||||
for pk in existing_passkeys
|
||||
]
|
||||
|
||||
# 生成认证选项
|
||||
options_json, challenge = PassKeyHelper.generate_authentication_options(
|
||||
existing_credentials=existing_credentials
|
||||
)
|
||||
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
data={
|
||||
'options': options_json,
|
||||
'challenge': challenge
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"生成PassKey认证选项失败: {e}")
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message=f"生成认证选项失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/passkey/authenticate/finish", summary="完成 PassKey 认证", response_model=schemas.Token)
|
||||
def passkey_authenticate_finish(
|
||||
passkey_req: PassKeyAuthenticationFinish
|
||||
) -> Any:
|
||||
"""完成 PassKey 认证 - 验证凭证并返回 token"""
|
||||
try:
|
||||
# 从credential中提取credential_id
|
||||
credential_id_raw = passkey_req.credential.get('id') or passkey_req.credential.get('rawId')
|
||||
if not credential_id_raw:
|
||||
raise HTTPException(status_code=400, detail="无效的凭证")
|
||||
|
||||
# 标准化凭证ID
|
||||
credential_id = PassKeyHelper.standardize_credential_id(credential_id_raw)
|
||||
|
||||
# 查找PassKey
|
||||
passkey = PassKey.get_by_credential_id(db=None, credential_id=credential_id)
|
||||
if not passkey:
|
||||
raise HTTPException(status_code=401, detail="通行密钥不存在或已失效")
|
||||
|
||||
# 获取用户
|
||||
user = User.get_by_id(db=None, user_id=passkey.user_id)
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(status_code=401, detail="用户不存在或已禁用")
|
||||
|
||||
# 验证认证响应
|
||||
success, new_sign_count = PassKeyHelper.verify_authentication_response(
|
||||
credential=passkey_req.credential,
|
||||
expected_challenge=passkey_req.challenge,
|
||||
credential_public_key=passkey.public_key,
|
||||
credential_current_sign_count=passkey.sign_count
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=401, detail="通行密钥验证失败")
|
||||
|
||||
# 更新使用时间和签名计数
|
||||
passkey.update_last_used(db=None, sign_count=new_sign_count)
|
||||
|
||||
logger.info(f"用户 {user.name} 通过PassKey认证成功")
|
||||
|
||||
# 生成token
|
||||
level = SitesHelper().auth_level
|
||||
show_wizard = not SystemConfigOper().get(SystemConfigKey.SetupWizardState) and not settings.ADVANCED_MODE
|
||||
|
||||
return schemas.Token(
|
||||
access_token=security.create_access_token(
|
||||
userid=user.id,
|
||||
username=user.name,
|
||||
super_user=user.is_superuser,
|
||||
expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES),
|
||||
level=level
|
||||
),
|
||||
token_type="bearer",
|
||||
super_user=user.is_superuser,
|
||||
user_id=user.id,
|
||||
user_name=user.name,
|
||||
avatar=user.avatar,
|
||||
level=level,
|
||||
permissions=user.permissions or {},
|
||||
wizard=show_wizard
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"PassKey认证失败: {e}")
|
||||
raise HTTPException(status_code=401, detail=f"认证失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/passkey/list", summary="获取当前用户的 PassKey 列表", response_model=schemas.Response)
|
||||
def passkey_list(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
) -> Any:
|
||||
"""获取当前用户的所有 PassKey"""
|
||||
try:
|
||||
passkeys = PassKey.get_by_user_id(db=None, user_id=current_user.id)
|
||||
|
||||
key_list = [
|
||||
{
|
||||
'id': pk.id,
|
||||
'name': pk.name,
|
||||
'created_at': pk.created_at.isoformat() if pk.created_at else None,
|
||||
'last_used_at': pk.last_used_at.isoformat() if pk.last_used_at else None,
|
||||
'aaguid': pk.aaguid,
|
||||
'transports': pk.transports
|
||||
}
|
||||
for pk in passkeys
|
||||
] if passkeys else []
|
||||
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
data=key_list
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"获取PassKey列表失败: {e}")
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message=f"获取列表失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/passkey/delete", summary="删除 PassKey", response_model=schemas.Response)
|
||||
async def passkey_delete(
|
||||
data: PassKeyDeleteRequest,
|
||||
current_user: User = Depends(get_current_active_user_async)
|
||||
) -> Any:
|
||||
"""删除指定的 PassKey"""
|
||||
try:
|
||||
# 验证密码
|
||||
if not security.verify_password(data.password, str(current_user.hashed_password)):
|
||||
return schemas.Response(success=False, message="密码错误")
|
||||
|
||||
success = PassKey.delete_by_id(db=None, passkey_id=data.passkey_id, user_id=current_user.id)
|
||||
|
||||
if success:
|
||||
logger.info(f"用户 {current_user.name} 删除了PassKey: {data.passkey_id}")
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
message="通行密钥已删除"
|
||||
)
|
||||
else:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="通行密钥不存在或无权删除"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"删除PassKey失败: {e}")
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message=f"删除失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/passkey/verify", summary="PassKey 二次验证", response_model=schemas.Response)
|
||||
def passkey_verify_mfa(
|
||||
passkey_req: PassKeyAuthenticationFinish,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
) -> Any:
|
||||
"""使用 PassKey 进行二次验证(MFA)"""
|
||||
try:
|
||||
# 从credential中提取credential_id
|
||||
credential_id_raw = passkey_req.credential.get('id') or passkey_req.credential.get('rawId')
|
||||
if not credential_id_raw:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="无效的凭证"
|
||||
)
|
||||
|
||||
# 标准化凭证ID
|
||||
credential_id = PassKeyHelper.standardize_credential_id(credential_id_raw)
|
||||
|
||||
# 查找PassKey(必须属于当前用户)
|
||||
passkey = PassKey.get_by_credential_id(db=None, credential_id=credential_id)
|
||||
if not passkey or passkey.user_id != current_user.id:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="通行密钥不存在或不属于当前用户"
|
||||
)
|
||||
|
||||
# 验证认证响应
|
||||
success, new_sign_count = PassKeyHelper.verify_authentication_response(
|
||||
credential=passkey_req.credential,
|
||||
expected_challenge=passkey_req.challenge,
|
||||
credential_public_key=passkey.public_key,
|
||||
credential_current_sign_count=passkey.sign_count
|
||||
)
|
||||
|
||||
if not success:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="通行密钥验证失败"
|
||||
)
|
||||
|
||||
# 更新使用时间和签名计数
|
||||
passkey.update_last_used(db=None, sign_count=new_sign_count)
|
||||
|
||||
logger.info(f"用户 {current_user.name} 通过PassKey二次验证成功")
|
||||
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
message="二次验证成功"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"PassKey二次验证失败: {e}")
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message=f"验证失败: {str(e)}"
|
||||
)
|
||||
@@ -134,18 +134,24 @@ def get_global_setting(token: str):
|
||||
if token != "moviepilot":
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
|
||||
# FIXME: 新增敏感配置项时要在此处添加排除项
|
||||
# 白名单模式,仅包含前端业务逻辑必需的字段
|
||||
info = settings.model_dump(
|
||||
exclude={"SECRET_KEY", "RESOURCE_SECRET_KEY", "API_TOKEN", "TMDB_API_KEY", "TVDB_API_KEY", "FANART_API_KEY",
|
||||
"COOKIECLOUD_KEY", "COOKIECLOUD_PASSWORD", "GITHUB_TOKEN", "REPO_GITHUB_TOKEN", "U115_APP_ID",
|
||||
"ALIPAN_APP_ID", "TVDB_V4_API_KEY", "TVDB_V4_API_PIN"}
|
||||
include={
|
||||
"TMDB_IMAGE_DOMAIN",
|
||||
"GLOBAL_IMAGE_CACHE",
|
||||
"ADVANCED_MODE",
|
||||
"RECOGNIZE_SOURCE",
|
||||
"SEARCH_SOURCE"
|
||||
}
|
||||
)
|
||||
# 追加用户唯一ID和订阅分享管理权限
|
||||
share_admin = SubscribeHelper().is_admin_user()
|
||||
info.update({
|
||||
"USER_UNIQUE_ID": SubscribeHelper().get_user_uuid(),
|
||||
"SUBSCRIBE_SHARE_MANAGE": share_admin,
|
||||
"WORKFLOW_SHARE_MANAGE": share_admin
|
||||
"WORKFLOW_SHARE_MANAGE": share_admin,
|
||||
"FRONTEND_VERSION": SystemChain.get_frontend_version(),
|
||||
"BACKEND_VERSION": APP_VERSION
|
||||
})
|
||||
return schemas.Response(success=True,
|
||||
data=info)
|
||||
|
||||
@@ -111,45 +111,6 @@ async def upload_avatar(user_id: int, db: AsyncSession = Depends(get_async_db),
|
||||
return schemas.Response(success=True, message=file.filename)
|
||||
|
||||
|
||||
@router.post('/otp/generate', summary='生成otp验证uri', response_model=schemas.Response)
|
||||
def otp_generate(
|
||||
current_user: User = Depends(get_current_active_user)
|
||||
) -> Any:
|
||||
secret, uri = OtpUtils.generate_secret_key(current_user.name)
|
||||
return schemas.Response(success=secret != "", data={'secret': secret, 'uri': uri})
|
||||
|
||||
|
||||
@router.post('/otp/judge', summary='判断otp验证是否通过', response_model=schemas.Response)
|
||||
async def otp_judge(
|
||||
data: dict,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
current_user: User = Depends(get_current_active_user_async)
|
||||
) -> Any:
|
||||
uri = data.get("uri")
|
||||
otp_password = data.get("otpPassword")
|
||||
if not OtpUtils.is_legal(uri, otp_password):
|
||||
return schemas.Response(success=False, message="验证码错误")
|
||||
await current_user.async_update_otp_by_name(db, current_user.name, True, OtpUtils.get_secret(uri))
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.post('/otp/disable', summary='关闭当前用户的otp验证', response_model=schemas.Response)
|
||||
async def otp_disable(
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
current_user: User = Depends(get_current_active_user_async)
|
||||
) -> Any:
|
||||
await current_user.async_update_otp_by_name(db, current_user.name, False, "")
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.get('/otp/{userid}', summary='判断当前用户是否开启otp验证', response_model=schemas.Response)
|
||||
async def otp_enable(userid: str, db: AsyncSession = Depends(get_async_db)) -> Any:
|
||||
user: User = await User.async_get_by_name(db, userid)
|
||||
if not user:
|
||||
return schemas.Response(success=False)
|
||||
return schemas.Response(success=user.is_otp)
|
||||
|
||||
|
||||
@router.get("/config/{key}", summary="查询用户配置", response_model=schemas.Response)
|
||||
def get_config(key: str,
|
||||
current_user: User = Depends(get_current_active_user)):
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Annotated, Callable, Any, Dict, Optional
|
||||
|
||||
import aiofiles
|
||||
from anyio import Path as AsyncPath
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, Request, Response
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Request, Response
|
||||
from fastapi.responses import PlainTextResponse
|
||||
from fastapi.routing import APIRoute
|
||||
|
||||
@@ -128,9 +128,12 @@ async def get_cookie(
|
||||
@cookie_router.post("/get/{uuid}")
|
||||
async def post_cookie(
|
||||
uuid: Annotated[str, Path(min_length=5, pattern="^[a-zA-Z0-9]+$")],
|
||||
request: schemas.CookiePassword):
|
||||
request: Optional[schemas.CookiePassword] = Body(None)):
|
||||
"""
|
||||
POST 下载加密数据
|
||||
"""
|
||||
data = await load_encrypt_data(uuid)
|
||||
return get_decrypted_cookie_data(uuid, request.password, data["encrypted"])
|
||||
if request is not None:
|
||||
return get_decrypted_cookie_data(uuid, request.password, data["encrypted"])
|
||||
else:
|
||||
return data
|
||||
|
||||
@@ -4,6 +4,7 @@ import pickle
|
||||
import traceback
|
||||
from abc import ABCMeta
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, Any, Tuple, List, Set, Union, Dict
|
||||
|
||||
@@ -849,6 +850,8 @@ class ChainBase(metaclass=ABCMeta):
|
||||
:param kwargs: 其他参数(覆盖业务对象属性值)
|
||||
:return: 成功或失败
|
||||
"""
|
||||
# 添加格式化的时间参数
|
||||
kwargs.setdefault('current_time', datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
|
||||
# 渲染消息
|
||||
message = MessageTemplateHelper.render(message=message, meta=meta, mediainfo=mediainfo,
|
||||
torrentinfo=torrentinfo, transferinfo=transferinfo, **kwargs)
|
||||
@@ -932,6 +935,8 @@ class ChainBase(metaclass=ABCMeta):
|
||||
:param kwargs: 其他参数(覆盖业务对象属性值)
|
||||
:return: 成功或失败
|
||||
"""
|
||||
# 添加格式化的时间参数
|
||||
kwargs.setdefault('current_time', datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
|
||||
# 渲染消息
|
||||
message = MessageTemplateHelper.render(message=message, meta=meta, mediainfo=mediainfo,
|
||||
torrentinfo=torrentinfo, transferinfo=transferinfo, **kwargs)
|
||||
|
||||
@@ -618,7 +618,7 @@ class MediaChain(ChainBase):
|
||||
should_scrape = True # 未知类型默认刮削
|
||||
|
||||
if should_scrape:
|
||||
image_path = filepath.with_name(image_name)
|
||||
image_path = filepath / image_name
|
||||
if overwrite or not storagechain.get_file_item(storage=fileitem.storage,
|
||||
path=image_path):
|
||||
# 流式下载图片并直接保存
|
||||
|
||||
@@ -195,10 +195,14 @@ class MessageChain(ChainBase):
|
||||
if text.isdigit():
|
||||
# 用户选择了具体的条目
|
||||
# 缓存
|
||||
cache_data: dict = user_cache.get(userid).copy()
|
||||
cache_data: dict = user_cache.get(userid)
|
||||
if not cache_data:
|
||||
# 发送消息
|
||||
self.post_message(Notification(channel=channel, source=source, title="输入有误!", userid=userid))
|
||||
return
|
||||
cache_data = cache_data.copy()
|
||||
# 选择项目
|
||||
if not cache_data \
|
||||
or not cache_data.get('items') \
|
||||
if not cache_data.get('items') \
|
||||
or len(cache_data.get('items')) < int(text):
|
||||
# 发送消息
|
||||
self.post_message(Notification(channel=channel, source=source, title="输入有误!", userid=userid))
|
||||
@@ -370,12 +374,13 @@ class MessageChain(ChainBase):
|
||||
del cache_data
|
||||
elif text.lower() == "p":
|
||||
# 上一页
|
||||
cache_data: dict = user_cache.get(userid).copy()
|
||||
cache_data: dict = user_cache.get(userid)
|
||||
if not cache_data:
|
||||
# 没有缓存
|
||||
self.post_message(Notification(
|
||||
channel=channel, source=source, title="输入有误!", userid=userid))
|
||||
return
|
||||
cache_data = cache_data.copy()
|
||||
try:
|
||||
if _current_page == 0:
|
||||
# 第一页
|
||||
@@ -422,12 +427,13 @@ class MessageChain(ChainBase):
|
||||
del cache_data
|
||||
elif text.lower() == "n":
|
||||
# 下一页
|
||||
cache_data: dict = user_cache.get(userid).copy()
|
||||
cache_data: dict = user_cache.get(userid)
|
||||
if not cache_data:
|
||||
# 没有缓存
|
||||
self.post_message(Notification(
|
||||
channel=channel, source=source, title="输入有误!", userid=userid))
|
||||
return
|
||||
cache_data = cache_data.copy()
|
||||
try:
|
||||
cache_type: str = cache_data.get('type')
|
||||
# 产生副本,避免修改原值
|
||||
|
||||
@@ -42,7 +42,7 @@ class SubscribeChain(ChainBase):
|
||||
_LOCK_TIMOUT = 3600 * 2
|
||||
|
||||
@staticmethod
|
||||
def __get_event_meida(_mediaid: str, _meta: MetaBase) -> Optional[MediaInfo]:
|
||||
def __get_event_media(_mediaid: str, _meta: MetaBase) -> Optional[MediaInfo]:
|
||||
"""
|
||||
广播事件解析媒体信息
|
||||
"""
|
||||
@@ -158,7 +158,7 @@ class SubscribeChain(ChainBase):
|
||||
mediainfo = MediaInfo(tmdb_info=tmdbinfo)
|
||||
elif mediaid:
|
||||
# 未知前缀,广播事件解析媒体信息
|
||||
mediainfo = self.__get_event_meida(mediaid, metainfo)
|
||||
mediainfo = self.__get_event_media(mediaid, metainfo)
|
||||
else:
|
||||
# 使用TMDBID识别
|
||||
mediainfo = self.recognize_media(meta=metainfo, mtype=mtype, tmdbid=tmdbid,
|
||||
@@ -169,7 +169,7 @@ class SubscribeChain(ChainBase):
|
||||
mediainfo = self.recognize_media(meta=metainfo, mtype=mtype, doubanid=doubanid, cache=False)
|
||||
elif mediaid:
|
||||
# 未知前缀,广播事件解析媒体信息
|
||||
mediainfo = self.__get_event_meida(mediaid, metainfo)
|
||||
mediainfo = self.__get_event_media(mediaid, metainfo)
|
||||
if mediainfo:
|
||||
# 豆瓣标题处理
|
||||
meta = MetaInfo(mediainfo.title)
|
||||
|
||||
@@ -52,7 +52,10 @@ class UserChain(ChainBase):
|
||||
success, user_or_message = self.password_authenticate(credentials=credentials)
|
||||
if success:
|
||||
# 如果用户启用了二次验证码,则进一步验证
|
||||
if not self._verify_mfa(user_or_message, credentials.mfa_code):
|
||||
mfa_result = self._verify_mfa(user_or_message, credentials.mfa_code)
|
||||
if mfa_result == "MFA_REQUIRED":
|
||||
return False, "MFA_REQUIRED"
|
||||
elif not mfa_result:
|
||||
return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE
|
||||
logger.info(f"用户 {username} 通过密码认证成功")
|
||||
return True, user_or_message
|
||||
@@ -63,7 +66,10 @@ class UserChain(ChainBase):
|
||||
aux_success, aux_user_or_message = self.auxiliary_authenticate(credentials=credentials)
|
||||
if aux_success:
|
||||
# 辅助认证成功后再验证二次验证码
|
||||
if not self._verify_mfa(aux_user_or_message, credentials.mfa_code):
|
||||
mfa_result = self._verify_mfa(aux_user_or_message, credentials.mfa_code)
|
||||
if mfa_result == "MFA_REQUIRED":
|
||||
return False, "MFA_REQUIRED"
|
||||
elif not mfa_result:
|
||||
return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE
|
||||
return True, aux_user_or_message
|
||||
else:
|
||||
@@ -159,22 +165,46 @@ class UserChain(ChainBase):
|
||||
return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE
|
||||
|
||||
@staticmethod
|
||||
def _verify_mfa(user: User, mfa_code: Optional[str]) -> bool:
|
||||
def _verify_mfa(user: User, mfa_code: Optional[str]) -> Union[bool, str]:
|
||||
"""
|
||||
验证 MFA(二次验证码)
|
||||
检查用户是否启用了 OTP 或 PassKey,如果启用了任何一种,都需要提供验证
|
||||
|
||||
:param user: 用户对象
|
||||
:param mfa_code: 二次验证码
|
||||
:return: 如果验证成功返回 True,否则返回 False
|
||||
:param mfa_code: 二次验证码(如果提供了则验证OTP)
|
||||
:return:
|
||||
- 如果验证成功返回 True
|
||||
- 如果需要MFA但未提供,返回 "MFA_REQUIRED"
|
||||
- 如果MFA验证失败,返回 False
|
||||
"""
|
||||
if not user.is_otp:
|
||||
# 检查用户是否有PassKey
|
||||
from app.db.models.passkey import PassKey
|
||||
has_passkey = bool(PassKey.get_by_user_id(db=None, user_id=user.id))
|
||||
|
||||
# 如果用户既没有启用OTP也没有PassKey,直接通过
|
||||
if not user.is_otp and not has_passkey:
|
||||
return True
|
||||
|
||||
# 如果用户启用了OTP或PassKey,但没有提供验证码,需要进行二次验证
|
||||
if not mfa_code:
|
||||
logger.info(f"用户 {user.name} 缺少 MFA 认证码")
|
||||
return False
|
||||
if not OtpUtils.check(str(user.otp_secret), mfa_code):
|
||||
logger.info(f"用户 {user.name} 的 MFA 认证失败")
|
||||
return False
|
||||
logger.info(f"用户 {user.name} 已启用双重验证(OTP: {user.is_otp}, PassKey: {has_passkey}),需要提供验证码")
|
||||
return "MFA_REQUIRED"
|
||||
|
||||
# 如果提供了验证码,且用户启用了 OTP,则验证 OTP
|
||||
if user.is_otp:
|
||||
if not OtpUtils.check(str(user.otp_secret), mfa_code):
|
||||
logger.info(f"用户 {user.name} 的 MFA 认证失败")
|
||||
return False
|
||||
# OTP 验证成功
|
||||
return True
|
||||
|
||||
# 用户未启用 OTP,此时提供的 mfa_code 无效;如果启用了 PassKey,则仍需通过 PassKey 验证
|
||||
if has_passkey:
|
||||
logger.info(
|
||||
f"用户 {user.name} 未启用 OTP,但已启用 PassKey,提供的 MFA 验证码将被忽略,仍需通过 PassKey 验证"
|
||||
)
|
||||
return "MFA_REQUIRED"
|
||||
|
||||
return True
|
||||
|
||||
def _process_auth_success(self, username: str, credentials: AuthCredentials) -> bool:
|
||||
|
||||
@@ -393,6 +393,8 @@ class ConfigModel(BaseModel):
|
||||
])
|
||||
# 允许的图片文件后缀格式
|
||||
SECURITY_IMAGE_SUFFIXES: list = Field(default=[".jpg", ".jpeg", ".png", ".webp", ".gif", ".svg", ".avif"])
|
||||
# PassKey 是否强制用户验证(生物识别等)
|
||||
PASSKEY_REQUIRE_UV: bool = True
|
||||
|
||||
# ==================== 工作流配置 ====================
|
||||
# 工作流数据共享
|
||||
@@ -407,6 +409,8 @@ class ConfigModel(BaseModel):
|
||||
# ==================== Docker配置 ====================
|
||||
# Docker Client API地址
|
||||
DOCKER_CLIENT_API: Optional[str] = "tcp://127.0.0.1:38379"
|
||||
# Playwright浏览器类型,chromium/firefox
|
||||
PLAYWRIGHT_BROWSER_TYPE: str = "chromium"
|
||||
|
||||
# ==================== AI智能体配置 ====================
|
||||
# AI智能体开关
|
||||
@@ -430,9 +434,9 @@ class ConfigModel(BaseModel):
|
||||
# 是否启用详细日志
|
||||
LLM_VERBOSE: bool = False
|
||||
# 最大记忆消息数量
|
||||
LLM_MAX_MEMORY_MESSAGES: int = 50
|
||||
# 记忆保留天数
|
||||
LLM_MEMORY_RETENTION_DAYS: int = 30
|
||||
LLM_MAX_MEMORY_MESSAGES: int = 30
|
||||
# 内存记忆保留天数
|
||||
LLM_MEMORY_RETENTION_DAYS: int = 1
|
||||
# Redis记忆保留天数(如果使用Redis)
|
||||
LLM_REDIS_MEMORY_RETENTION_DAYS: int = 7
|
||||
|
||||
|
||||
@@ -95,18 +95,20 @@ class TorrentInfo:
|
||||
if upload_volume_factor is None or download_volume_factor is None:
|
||||
return "未知"
|
||||
free_strs = {
|
||||
"1.0 1.0": "普通",
|
||||
"1.0 0.0": "免费",
|
||||
"2.0 1.0": "2X",
|
||||
"4.0 1.0": "4X",
|
||||
"2.0 0.0": "2X免费",
|
||||
"4.0 0.0": "4X免费",
|
||||
"1.0 0.5": "50%",
|
||||
"2.0 0.5": "2X 50%",
|
||||
"1.0 0.7": "70%",
|
||||
"1.0 0.3": "30%"
|
||||
"1.00 1.00": "普通",
|
||||
"1.00 0.00": "免费",
|
||||
"2.00 1.00": "2X",
|
||||
"4.00 1.00": "4X",
|
||||
"2.00 0.00": "2X免费",
|
||||
"4.00 0.00": "4X免费",
|
||||
"1.00 0.50": "50%",
|
||||
"2.00 0.50": "2X 50%",
|
||||
"1.00 0.70": "70%",
|
||||
"1.00 0.30": "30%",
|
||||
"1.00 0.75": "75%",
|
||||
"1.00 0.25": "25%"
|
||||
}
|
||||
return free_strs.get('%.1f %.1f' % (upload_volume_factor, download_volume_factor), "未知")
|
||||
return free_strs.get('%.2f %.2f' % (upload_volume_factor, download_volume_factor), "未知")
|
||||
|
||||
@property
|
||||
def volume_factor(self):
|
||||
|
||||
@@ -71,12 +71,14 @@ def MetaInfoPath(path: Path) -> MetaBase:
|
||||
file_meta = MetaInfo(title=path.name)
|
||||
# 上级目录元数据
|
||||
dir_meta = MetaInfo(title=path.parent.name)
|
||||
# 合并元数据
|
||||
file_meta.merge(dir_meta)
|
||||
if file_meta.type == MediaType.TV or dir_meta.type != MediaType.TV:
|
||||
# 合并元数据
|
||||
file_meta.merge(dir_meta)
|
||||
# 上上级目录元数据
|
||||
root_meta = MetaInfo(title=path.parent.parent.name)
|
||||
# 合并元数据
|
||||
file_meta.merge(root_meta)
|
||||
if file_meta.type == MediaType.TV or root_meta.type != MediaType.TV:
|
||||
# 合并元数据
|
||||
file_meta.merge(root_meta)
|
||||
return file_meta
|
||||
|
||||
|
||||
|
||||
@@ -454,7 +454,6 @@ class Base:
|
||||
|
||||
@db_update
|
||||
def update(self, db: Session, payload: dict):
|
||||
payload = {k: v for k, v in payload.items() if v is not None}
|
||||
for key, value in payload.items():
|
||||
setattr(self, key, value)
|
||||
if inspect(self).detached:
|
||||
@@ -462,7 +461,6 @@ class Base:
|
||||
|
||||
@async_db_update
|
||||
async def async_update(self, db: AsyncSession, payload: dict):
|
||||
payload = {k: v for k, v in payload.items() if v is not None}
|
||||
for key, value in payload.items():
|
||||
setattr(self, key, value)
|
||||
if inspect(self).detached:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from .downloadhistory import DownloadHistory, DownloadFiles
|
||||
from .mediaserver import MediaServerItem
|
||||
from .passkey import PassKey
|
||||
from .plugindata import PluginData
|
||||
from .site import Site
|
||||
from .siteicon import SiteIcon
|
||||
|
||||
126
app/db/models/passkey.py
Normal file
126
app/db/models/passkey.py
Normal file
@@ -0,0 +1,126 @@
|
||||
from sqlalchemy import Column, Integer, String, Boolean, DateTime, Text, select, ForeignKey
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
from datetime import datetime
|
||||
|
||||
from app.db import Base, db_query, db_update, async_db_query, async_db_update, get_id_column
|
||||
|
||||
|
||||
class PassKey(Base):
|
||||
"""
|
||||
用户PassKey凭证表
|
||||
"""
|
||||
# ID
|
||||
id = get_id_column()
|
||||
# 用户ID
|
||||
user_id = Column(Integer, ForeignKey('user.id'), nullable=False, index=True)
|
||||
# 凭证ID (credential_id)
|
||||
credential_id = Column(String, nullable=False, unique=True, index=True)
|
||||
# 凭证公钥
|
||||
public_key = Column(Text, nullable=False)
|
||||
# 签名计数器
|
||||
sign_count = Column(Integer, default=0)
|
||||
# 凭证名称(用户自定义)
|
||||
name = Column(String, default="通行密钥")
|
||||
# AAGUID (Authenticator Attestation GUID)
|
||||
aaguid = Column(String, nullable=True)
|
||||
# 创建时间
|
||||
created_at = Column(DateTime, default=datetime.now)
|
||||
# 最后使用时间
|
||||
last_used_at = Column(DateTime, nullable=True)
|
||||
# 是否启用
|
||||
is_active = Column(Boolean, default=True)
|
||||
# 传输方式 (usb, nfc, ble, internal)
|
||||
transports = Column(String, nullable=True)
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_by_user_id(cls, db: Session, user_id: int):
|
||||
"""获取用户的所有PassKey"""
|
||||
return db.query(cls).filter(cls.user_id == user_id, cls.is_active.is_(True)).all()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_get_by_user_id(cls, db: AsyncSession, user_id: int):
|
||||
"""异步获取用户的所有PassKey"""
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.user_id == user_id, cls.is_active.is_(True))
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_by_credential_id(cls, db: Session, credential_id: str):
|
||||
"""根据凭证ID获取PassKey"""
|
||||
return db.query(cls).filter(cls.credential_id == credential_id, cls.is_active.is_(True)).first()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_get_by_credential_id(cls, db: AsyncSession, credential_id: str):
|
||||
"""异步根据凭证ID获取PassKey"""
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.credential_id == credential_id, cls.is_active.is_(True))
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_by_id(cls, db: Session, passkey_id: int):
|
||||
"""根据ID获取PassKey"""
|
||||
return db.query(cls).filter(cls.id == passkey_id).first()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_get_by_id(cls, db: AsyncSession, passkey_id: int):
|
||||
"""异步根据ID获取PassKey"""
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.id == passkey_id)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@classmethod
|
||||
@db_update
|
||||
def delete_by_id(cls, db: Session, passkey_id: int, user_id: int):
|
||||
"""删除指定用户的PassKey"""
|
||||
passkey = db.query(cls).filter(
|
||||
cls.id == passkey_id,
|
||||
cls.user_id == user_id
|
||||
).first()
|
||||
if passkey:
|
||||
passkey.delete(db, passkey.id)
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
@async_db_update
|
||||
async def async_delete_by_id(cls, db: AsyncSession, passkey_id: int, user_id: int):
|
||||
"""异步删除指定用户的PassKey"""
|
||||
result = await db.execute(
|
||||
select(cls).filter(
|
||||
cls.id == passkey_id,
|
||||
cls.user_id == user_id
|
||||
)
|
||||
)
|
||||
passkey = result.scalars().first()
|
||||
if passkey:
|
||||
await passkey.async_delete(db, passkey.id)
|
||||
return True
|
||||
return False
|
||||
|
||||
@db_update
|
||||
def update_last_used(self, db: Session, sign_count: int):
|
||||
"""更新最后使用时间和签名计数"""
|
||||
self.update(db, {
|
||||
'last_used_at': datetime.now(),
|
||||
'sign_count': sign_count
|
||||
})
|
||||
return True
|
||||
|
||||
@async_db_update
|
||||
async def async_update_last_used(self, db: AsyncSession, sign_count: int):
|
||||
"""异步更新最后使用时间和签名计数"""
|
||||
await self.async_update(db, {
|
||||
'last_used_at': datetime.now(),
|
||||
'sign_count': sign_count
|
||||
})
|
||||
return True
|
||||
@@ -10,7 +10,7 @@ from app.utils.http import RequestUtils, cookie_parse
|
||||
|
||||
|
||||
class PlaywrightHelper:
|
||||
def __init__(self, browser_type="chromium"):
|
||||
def __init__(self, browser_type=settings.PLAYWRIGHT_BROWSER_TYPE):
|
||||
self.browser_type = browser_type
|
||||
|
||||
@staticmethod
|
||||
|
||||
352
app/helper/passkey.py
Normal file
352
app/helper/passkey.py
Normal file
@@ -0,0 +1,352 @@
|
||||
"""
|
||||
PassKey WebAuthn 辅助工具类
|
||||
"""
|
||||
import base64
|
||||
import json
|
||||
import binascii
|
||||
from typing import Optional, Tuple, List, Dict, Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from webauthn import (
|
||||
generate_registration_options,
|
||||
verify_registration_response,
|
||||
generate_authentication_options,
|
||||
verify_authentication_response,
|
||||
options_to_json
|
||||
)
|
||||
from webauthn.helpers import (
|
||||
parse_registration_credential_json,
|
||||
parse_authentication_credential_json
|
||||
)
|
||||
from webauthn.helpers.structs import (
|
||||
PublicKeyCredentialDescriptor,
|
||||
AuthenticatorTransport,
|
||||
UserVerificationRequirement,
|
||||
AuthenticatorAttachment,
|
||||
ResidentKeyRequirement,
|
||||
AuthenticatorSelectionCriteria
|
||||
)
|
||||
from webauthn.helpers.cose import COSEAlgorithmIdentifier
|
||||
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class PassKeyHelper:
|
||||
"""
|
||||
PassKey WebAuthn 辅助类
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def get_rp_id() -> str:
|
||||
"""
|
||||
获取 Relying Party ID
|
||||
"""
|
||||
if settings.APP_DOMAIN:
|
||||
app_domain = settings.APP_DOMAIN.strip()
|
||||
# 确保存在协议前缀,以便 urlparse 正确解析主机和端口
|
||||
if not app_domain.startswith(('http://', 'https://')):
|
||||
app_domain = f'https://{app_domain}'
|
||||
parsed = urlparse(app_domain)
|
||||
host = parsed.hostname
|
||||
if host:
|
||||
return host
|
||||
# 从 APP_DOMAIN 中提取域名
|
||||
host = settings.APP_DOMAIN.replace('https://', '').replace('http://', '')
|
||||
# 移除端口号
|
||||
if ':' in host:
|
||||
host = host.split(':')[0]
|
||||
return host
|
||||
# 只有在未配置 APP_DOMAIN 时,才默认为 localhost
|
||||
return 'localhost'
|
||||
|
||||
@staticmethod
|
||||
def get_rp_name() -> str:
|
||||
"""
|
||||
获取 Relying Party 名称
|
||||
"""
|
||||
return "MoviePilot"
|
||||
|
||||
@staticmethod
|
||||
def get_origin() -> str:
|
||||
"""
|
||||
获取源地址
|
||||
"""
|
||||
if settings.APP_DOMAIN:
|
||||
return settings.APP_DOMAIN.rstrip('/')
|
||||
# 如果未配置APP_DOMAIN,使用默认的localhost地址
|
||||
return f'http://localhost:{settings.NGINX_PORT}'
|
||||
|
||||
@staticmethod
|
||||
def standardize_credential_id(credential_id: str) -> str:
|
||||
"""
|
||||
标准化凭证ID(Base64 URL Safe)
|
||||
"""
|
||||
try:
|
||||
# Base64解码并重新编码以标准化格式
|
||||
decoded = base64.urlsafe_b64decode(credential_id + '==')
|
||||
return base64.urlsafe_b64encode(decoded).decode('utf-8').rstrip('=')
|
||||
except (binascii.Error, TypeError, ValueError) as e:
|
||||
logger.error(f"标准化凭证ID失败: {e}")
|
||||
return credential_id
|
||||
|
||||
@staticmethod
|
||||
def generate_registration_options(
|
||||
user_id: int,
|
||||
username: str,
|
||||
display_name: Optional[str] = None,
|
||||
existing_credentials: Optional[List[Dict[str, Any]]] = None
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
生成注册选项
|
||||
|
||||
:param user_id: 用户ID
|
||||
:param username: 用户名
|
||||
:param display_name: 显示名称
|
||||
:param existing_credentials: 已存在的凭证列表
|
||||
:return: (options_json, challenge)
|
||||
"""
|
||||
try:
|
||||
# 用户信息
|
||||
user_id_bytes = str(user_id).encode('utf-8')
|
||||
|
||||
# 排除已有的凭证
|
||||
exclude_credentials = []
|
||||
if existing_credentials:
|
||||
for cred in existing_credentials:
|
||||
try:
|
||||
exclude_credentials.append(
|
||||
PublicKeyCredentialDescriptor(
|
||||
id=base64.urlsafe_b64decode(cred['credential_id'] + '=='),
|
||||
transports=[
|
||||
AuthenticatorTransport(t) for t in cred.get('transports', '').split(',') if t
|
||||
] if cred.get('transports') else None
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"解析凭证失败: {e}")
|
||||
continue
|
||||
|
||||
# 用户验证要求
|
||||
uv_requirement = UserVerificationRequirement.REQUIRED if settings.PASSKEY_REQUIRE_UV \
|
||||
else UserVerificationRequirement.PREFERRED
|
||||
|
||||
# 生成注册选项
|
||||
options = generate_registration_options(
|
||||
rp_id=PassKeyHelper.get_rp_id(),
|
||||
rp_name=PassKeyHelper.get_rp_name(),
|
||||
user_id=user_id_bytes,
|
||||
user_name=username,
|
||||
user_display_name=display_name or username,
|
||||
exclude_credentials=exclude_credentials if exclude_credentials else None,
|
||||
authenticator_selection=AuthenticatorSelectionCriteria(
|
||||
authenticator_attachment=None,
|
||||
resident_key=ResidentKeyRequirement.REQUIRED,
|
||||
user_verification=uv_requirement,
|
||||
),
|
||||
supported_pub_key_algs=[
|
||||
COSEAlgorithmIdentifier.ECDSA_SHA_256,
|
||||
COSEAlgorithmIdentifier.RSASSA_PKCS1_v1_5_SHA_256,
|
||||
]
|
||||
)
|
||||
|
||||
# 转换为JSON
|
||||
options_json = options_to_json(options)
|
||||
|
||||
# 提取challenge(用于后续验证)
|
||||
challenge = base64.urlsafe_b64encode(options.challenge).decode('utf-8').rstrip('=')
|
||||
|
||||
return options_json, challenge
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成注册选项失败: {e}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def _get_verified_origin(credential: Dict[str, Any], rp_id: str, default_origin: str) -> str:
|
||||
"""
|
||||
在 localhost 环境下获取并验证实际 Origin,否则返回默认值
|
||||
"""
|
||||
if not settings.APP_DOMAIN and rp_id == 'localhost':
|
||||
try:
|
||||
# 解析 clientDataJSON 获取实际的 origin
|
||||
client_data_json = json.loads(
|
||||
base64.urlsafe_b64decode(
|
||||
credential['response']['clientDataJSON'].replace('-', '+').replace('_', '/') + '=='
|
||||
).decode('utf-8')
|
||||
)
|
||||
actual_origin = client_data_json.get('origin', '')
|
||||
hostname = urlparse(actual_origin).hostname
|
||||
|
||||
if hostname in ['localhost', '127.0.0.1']:
|
||||
logger.info(f"本地环境,使用动态 origin: {actual_origin}")
|
||||
return actual_origin
|
||||
except Exception as e:
|
||||
logger.warning(f"无法提取动态 origin: {e}")
|
||||
return default_origin
|
||||
|
||||
@staticmethod
|
||||
def verify_registration_response(
|
||||
credential: Dict[str, Any],
|
||||
expected_challenge: str,
|
||||
expected_origin: Optional[str] = None,
|
||||
expected_rp_id: Optional[str] = None
|
||||
) -> Tuple[str, str, int, Optional[str]]:
|
||||
"""
|
||||
验证注册响应
|
||||
|
||||
:param credential: 客户端返回的凭证
|
||||
:param expected_challenge: 期望的challenge
|
||||
:param expected_origin: 期望的源地址
|
||||
:param expected_rp_id: 期望的RP ID
|
||||
:return: (credential_id, public_key, sign_count, aaguid)
|
||||
"""
|
||||
try:
|
||||
# 准备验证参数
|
||||
origin = expected_origin or PassKeyHelper.get_origin()
|
||||
rp_id = expected_rp_id or PassKeyHelper.get_rp_id()
|
||||
|
||||
# 解码challenge
|
||||
challenge_bytes = base64.urlsafe_b64decode(expected_challenge + '==')
|
||||
|
||||
# 构建RegistrationCredential对象
|
||||
registration_credential = parse_registration_credential_json(json.dumps(credential))
|
||||
|
||||
# 获取并验证 Origin
|
||||
origin = PassKeyHelper._get_verified_origin(credential, rp_id, origin)
|
||||
|
||||
# 验证注册响应
|
||||
verification = verify_registration_response(
|
||||
credential=registration_credential,
|
||||
expected_challenge=challenge_bytes,
|
||||
expected_rp_id=rp_id,
|
||||
expected_origin=origin,
|
||||
require_user_verification=settings.PASSKEY_REQUIRE_UV
|
||||
)
|
||||
|
||||
# 提取信息
|
||||
credential_id = base64.urlsafe_b64encode(verification.credential_id).decode('utf-8').rstrip('=')
|
||||
public_key = base64.urlsafe_b64encode(verification.credential_public_key).decode('utf-8').rstrip('=')
|
||||
sign_count = verification.sign_count
|
||||
# aaguid 可能已经是字符串格式,也可能是bytes
|
||||
if verification.aaguid:
|
||||
if isinstance(verification.aaguid, bytes):
|
||||
aaguid = verification.aaguid.hex()
|
||||
else:
|
||||
aaguid = str(verification.aaguid)
|
||||
else:
|
||||
aaguid = None
|
||||
|
||||
return credential_id, public_key, sign_count, aaguid
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"验证注册响应失败: {e}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def generate_authentication_options(
|
||||
existing_credentials: Optional[List[Dict[str, Any]]] = None,
|
||||
user_verification: Optional[str] = None
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
生成认证选项
|
||||
|
||||
:param existing_credentials: 已存在的凭证列表(用于限制可用凭证)
|
||||
:param user_verification: 用户验证要求,如果不指定则从配置中读取
|
||||
:return: (options_json, challenge)
|
||||
"""
|
||||
try:
|
||||
# 允许的凭证
|
||||
allow_credentials = []
|
||||
if existing_credentials:
|
||||
for cred in existing_credentials:
|
||||
try:
|
||||
allow_credentials.append(
|
||||
PublicKeyCredentialDescriptor(
|
||||
id=base64.urlsafe_b64decode(cred['credential_id'] + '=='),
|
||||
transports=[
|
||||
AuthenticatorTransport(t) for t in cred.get('transports', '').split(',') if t
|
||||
] if cred.get('transports') else None
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"解析凭证失败: {e}")
|
||||
continue
|
||||
|
||||
# 用户验证要求
|
||||
if not user_verification:
|
||||
uv_requirement = UserVerificationRequirement.REQUIRED if settings.PASSKEY_REQUIRE_UV \
|
||||
else UserVerificationRequirement.PREFERRED
|
||||
else:
|
||||
uv_requirement = UserVerificationRequirement(user_verification)
|
||||
|
||||
# 生成认证选项
|
||||
options = generate_authentication_options(
|
||||
rp_id=PassKeyHelper.get_rp_id(),
|
||||
allow_credentials=allow_credentials if allow_credentials else None,
|
||||
user_verification=uv_requirement
|
||||
)
|
||||
|
||||
# 转换为JSON
|
||||
options_json = options_to_json(options)
|
||||
|
||||
# 提取challenge
|
||||
challenge = base64.urlsafe_b64encode(options.challenge).decode('utf-8').rstrip('=')
|
||||
|
||||
return options_json, challenge
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成认证选项失败: {e}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def verify_authentication_response(
|
||||
credential: Dict[str, Any],
|
||||
expected_challenge: str,
|
||||
credential_public_key: str,
|
||||
credential_current_sign_count: int,
|
||||
expected_origin: Optional[str] = None,
|
||||
expected_rp_id: Optional[str] = None
|
||||
) -> Tuple[bool, int]:
|
||||
"""
|
||||
验证认证响应
|
||||
|
||||
:param credential: 客户端返回的凭证
|
||||
:param expected_challenge: 期望的challenge
|
||||
:param credential_public_key: 凭证公钥
|
||||
:param credential_current_sign_count: 当前签名计数
|
||||
:param expected_origin: 期望的源地址
|
||||
:param expected_rp_id: 期望的RP ID
|
||||
:return: (验证成功, 新的签名计数)
|
||||
"""
|
||||
try:
|
||||
# 准备验证参数
|
||||
origin = expected_origin or PassKeyHelper.get_origin()
|
||||
rp_id = expected_rp_id or PassKeyHelper.get_rp_id()
|
||||
|
||||
# 解码
|
||||
challenge_bytes = base64.urlsafe_b64decode(expected_challenge + '==')
|
||||
public_key_bytes = base64.urlsafe_b64decode(credential_public_key + '==')
|
||||
|
||||
# 构建AuthenticationCredential对象
|
||||
authentication_credential = parse_authentication_credential_json(json.dumps(credential))
|
||||
|
||||
# 获取并验证 Origin
|
||||
origin = PassKeyHelper._get_verified_origin(credential, rp_id, origin)
|
||||
|
||||
# 验证认证响应
|
||||
verification = verify_authentication_response(
|
||||
credential=authentication_credential,
|
||||
expected_challenge=challenge_bytes,
|
||||
expected_rp_id=rp_id,
|
||||
expected_origin=origin,
|
||||
credential_public_key=public_key_bytes,
|
||||
credential_current_sign_count=credential_current_sign_count,
|
||||
require_user_verification=settings.PASSKEY_REQUIRE_UV
|
||||
)
|
||||
|
||||
return True, verification.new_sign_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"验证认证响应失败: {e}")
|
||||
return False, credential_current_sign_count
|
||||
216
app/modules/discord/__init__.py
Normal file
216
app/modules/discord/__init__.py
Normal file
@@ -0,0 +1,216 @@
|
||||
import json
|
||||
from typing import Optional, Union, List, Tuple, Any
|
||||
|
||||
from app.core.context import MediaInfo, Context
|
||||
from app.log import logger
|
||||
from app.modules import _ModuleBase, _MessageBase
|
||||
from app.schemas import MessageChannel, CommingMessage, Notification
|
||||
from app.schemas.types import ModuleType
|
||||
|
||||
try:
|
||||
from app.modules.discord.discord import Discord
|
||||
except Exception as err: # ImportError or other load issues
|
||||
Discord = None
|
||||
logger.error(f"Discord 模块未加载,缺少依赖或初始化错误:{err}")
|
||||
|
||||
|
||||
class DiscordModule(_ModuleBase, _MessageBase[Discord]):
|
||||
|
||||
def init_module(self) -> None:
|
||||
"""
|
||||
初始化模块
|
||||
"""
|
||||
if not Discord:
|
||||
logger.error("Discord 依赖未就绪(需要安装 discord.py==2.6.4),模块未启动")
|
||||
return
|
||||
self.stop()
|
||||
super().init_service(service_name=Discord.__name__.lower(),
|
||||
service_type=Discord)
|
||||
self._channel = MessageChannel.Discord
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "Discord"
|
||||
|
||||
@staticmethod
|
||||
def get_type() -> ModuleType:
|
||||
"""
|
||||
获取模块类型
|
||||
"""
|
||||
return ModuleType.Notification
|
||||
|
||||
@staticmethod
|
||||
def get_subtype() -> MessageChannel:
|
||||
"""
|
||||
获取模块子类型
|
||||
"""
|
||||
return MessageChannel.Discord
|
||||
|
||||
@staticmethod
|
||||
def get_priority() -> int:
|
||||
"""
|
||||
获取模块优先级,数字越小优先级越高,只有同一接口下优先级才生效
|
||||
"""
|
||||
return 4
|
||||
|
||||
def stop(self):
|
||||
"""
|
||||
停止模块
|
||||
"""
|
||||
for client in self.get_instances().values():
|
||||
client.stop()
|
||||
|
||||
def test(self) -> Optional[Tuple[bool, str]]:
|
||||
"""
|
||||
测试模块连接性
|
||||
"""
|
||||
if not self.get_instances():
|
||||
return None
|
||||
for name, client in self.get_instances().items():
|
||||
state = client.get_state()
|
||||
if not state:
|
||||
return False, f"Discord {name} Bot 未就绪"
|
||||
return True, ""
|
||||
|
||||
def init_setting(self) -> Tuple[str, Union[str, bool]]:
|
||||
pass
|
||||
|
||||
def message_parser(self, source: str, body: Any, form: Any, args: Any) -> Optional[CommingMessage]:
|
||||
"""
|
||||
解析消息内容,返回字典,注意以下约定值:
|
||||
userid: 用户ID
|
||||
username: 用户名
|
||||
text: 内容
|
||||
:param source: 消息来源
|
||||
:param body: 请求体
|
||||
:param form: 表单
|
||||
:param args: 参数
|
||||
:return: 渠道、消息体
|
||||
"""
|
||||
client_config = self.get_config(source)
|
||||
if not client_config:
|
||||
return None
|
||||
try:
|
||||
msg_json: dict = json.loads(body)
|
||||
except Exception as e:
|
||||
logger.debug(f"解析 Discord 消息失败:{str(e)}")
|
||||
return None
|
||||
|
||||
if not msg_json:
|
||||
return None
|
||||
|
||||
msg_type = msg_json.get("type")
|
||||
userid = msg_json.get("userid")
|
||||
username = msg_json.get("username")
|
||||
|
||||
if msg_type == "interaction":
|
||||
callback_data = msg_json.get("callback_data")
|
||||
message_id = msg_json.get("message_id")
|
||||
chat_id = msg_json.get("chat_id")
|
||||
if callback_data and userid:
|
||||
logger.info(f"收到来自 {client_config.name} 的 Discord 按钮回调:"
|
||||
f"userid={userid}, username={username}, callback_data={callback_data}")
|
||||
return CommingMessage(
|
||||
channel=MessageChannel.Discord,
|
||||
source=client_config.name,
|
||||
userid=userid,
|
||||
username=username,
|
||||
text=f"CALLBACK:{callback_data}",
|
||||
is_callback=True,
|
||||
callback_data=callback_data,
|
||||
message_id=message_id,
|
||||
chat_id=str(chat_id) if chat_id else None
|
||||
)
|
||||
return None
|
||||
|
||||
if msg_type == "message":
|
||||
text = msg_json.get("text")
|
||||
chat_id = msg_json.get("chat_id")
|
||||
if text and userid:
|
||||
logger.info(f"收到来自 {client_config.name} 的 Discord 消息:"
|
||||
f"userid={userid}, username={username}, text={text}")
|
||||
return CommingMessage(channel=MessageChannel.Discord, source=client_config.name,
|
||||
userid=userid, username=username, text=text,
|
||||
chat_id=str(chat_id) if chat_id else None)
|
||||
return None
|
||||
|
||||
def post_message(self, message: Notification, **kwargs) -> None:
|
||||
"""
|
||||
发送通知消息
|
||||
:param message: 消息通知对象
|
||||
"""
|
||||
for conf in self.get_configs().values():
|
||||
if not self.check_message(message, conf.name):
|
||||
continue
|
||||
targets = message.targets
|
||||
userid = message.userid
|
||||
if not userid and targets is not None:
|
||||
userid = targets.get('discord_userid')
|
||||
if not userid:
|
||||
logger.warn("用户没有指定 Discord 用户ID,消息无法发送")
|
||||
return
|
||||
client: Discord = self.get_instance(conf.name)
|
||||
if client:
|
||||
client.send_msg(title=message.title, text=message.text,
|
||||
image=message.image, userid=userid, link=message.link,
|
||||
buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id,
|
||||
mtype=message.mtype)
|
||||
|
||||
def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None:
|
||||
"""
|
||||
发送媒体信息选择列表
|
||||
:param message: 消息体
|
||||
:param medias: 媒体信息
|
||||
:return: 成功或失败
|
||||
"""
|
||||
for conf in self.get_configs().values():
|
||||
if not self.check_message(message, conf.name):
|
||||
continue
|
||||
client: Discord = self.get_instance(conf.name)
|
||||
if client:
|
||||
client.send_medias_msg(title=message.title, medias=medias, userid=message.userid,
|
||||
buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id)
|
||||
|
||||
def post_torrents_message(self, message: Notification, torrents: List[Context]) -> None:
|
||||
"""
|
||||
发送种子信息选择列表
|
||||
:param message: 消息体
|
||||
:param torrents: 种子信息
|
||||
:return: 成功或失败
|
||||
"""
|
||||
for conf in self.get_configs().values():
|
||||
if not self.check_message(message, conf.name):
|
||||
continue
|
||||
client: Discord = self.get_instance(conf.name)
|
||||
if client:
|
||||
client.send_torrents_msg(title=message.title, torrents=torrents,
|
||||
userid=message.userid, buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id)
|
||||
|
||||
def delete_message(self, channel: MessageChannel, source: str,
|
||||
message_id: str, chat_id: Optional[str] = None) -> bool:
|
||||
"""
|
||||
删除消息
|
||||
:param channel: 消息渠道
|
||||
:param source: 指定的消息源
|
||||
:param message_id: 消息ID(Slack中为时间戳)
|
||||
:param chat_id: 聊天ID(频道ID)
|
||||
:return: 删除是否成功
|
||||
"""
|
||||
success = False
|
||||
for conf in self.get_configs().values():
|
||||
if channel != self._channel:
|
||||
break
|
||||
if source != conf.name:
|
||||
continue
|
||||
client: Discord = self.get_instance(conf.name)
|
||||
if client:
|
||||
result = client.delete_msg(message_id=message_id, chat_id=chat_id)
|
||||
if result:
|
||||
success = True
|
||||
return success
|
||||
606
app/modules/discord/discord.py
Normal file
606
app/modules/discord/discord.py
Normal file
@@ -0,0 +1,606 @@
|
||||
import asyncio
|
||||
import re
|
||||
import threading
|
||||
from typing import Optional, List, Dict, Any, Tuple, Union
|
||||
|
||||
import discord
|
||||
from discord import app_commands
|
||||
import httpx
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.context import MediaInfo, Context
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.log import logger
|
||||
from app.schemas.types import NotificationType
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
# Discord embed 字段解析白名单
|
||||
# 只有这些消息类型会使用复杂的字段解析逻辑
|
||||
PARSE_FIELD_TYPES = {
|
||||
NotificationType.Download, # 资源下载
|
||||
NotificationType.Organize, # 整理入库
|
||||
NotificationType.Subscribe, # 订阅
|
||||
NotificationType.Manual, # 手动处理
|
||||
}
|
||||
|
||||
|
||||
class Discord:
|
||||
"""
|
||||
Discord Bot 通知与交互实现(基于 discord.py 2.6.4)
|
||||
"""
|
||||
|
||||
def __init__(self, DISCORD_BOT_TOKEN: Optional[str] = None,
|
||||
DISCORD_GUILD_ID: Optional[Union[str, int]] = None,
|
||||
DISCORD_CHANNEL_ID: Optional[Union[str, int]] = None,
|
||||
**kwargs):
|
||||
if not DISCORD_BOT_TOKEN:
|
||||
logger.error("Discord Bot Token 未配置!")
|
||||
return
|
||||
|
||||
self._token = DISCORD_BOT_TOKEN
|
||||
self._guild_id = self._to_int(DISCORD_GUILD_ID)
|
||||
self._channel_id = self._to_int(DISCORD_CHANNEL_ID)
|
||||
base_ds_url = f"http://127.0.0.1:{settings.PORT}/api/v1/message/"
|
||||
self._ds_url = f"{base_ds_url}?token={settings.API_TOKEN}"
|
||||
if kwargs.get("name"):
|
||||
self._ds_url = f"{self._ds_url}&source={kwargs.get('name')}"
|
||||
|
||||
intents = discord.Intents.default()
|
||||
intents.message_content = True
|
||||
intents.messages = True
|
||||
intents.guilds = True
|
||||
|
||||
self._client: Optional[discord.Client] = discord.Client(
|
||||
intents=intents,
|
||||
proxy=settings.PROXY_HOST
|
||||
)
|
||||
self._tree: Optional[app_commands.CommandTree] = None
|
||||
self._loop: asyncio.AbstractEventLoop = asyncio.new_event_loop()
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
self._ready_event = threading.Event()
|
||||
self._user_dm_cache: Dict[str, discord.DMChannel] = {}
|
||||
self._broadcast_channel = None
|
||||
self._bot_user_id: Optional[int] = None
|
||||
|
||||
self._register_events()
|
||||
self._start()
|
||||
|
||||
@staticmethod
|
||||
def _to_int(val: Optional[Union[str, int]]) -> Optional[int]:
|
||||
try:
|
||||
return int(val) if val is not None and str(val).strip() else None
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
def _register_events(self):
|
||||
@self._client.event
|
||||
async def on_ready():
|
||||
self._bot_user_id = self._client.user.id if self._client.user else None
|
||||
self._ready_event.set()
|
||||
logger.info(f"Discord Bot 已登录:{self._client.user}")
|
||||
|
||||
@self._client.event
|
||||
async def on_message(message: discord.Message):
|
||||
if message.author.bot:
|
||||
return
|
||||
if not self._should_process_message(message):
|
||||
return
|
||||
|
||||
cleaned_text = self._clean_bot_mention(message.content or "")
|
||||
username = message.author.display_name or message.author.global_name or message.author.name
|
||||
payload = {
|
||||
"type": "message",
|
||||
"userid": str(message.author.id),
|
||||
"username": username,
|
||||
"user_tag": str(message.author),
|
||||
"text": cleaned_text,
|
||||
"message_id": str(message.id),
|
||||
"chat_id": str(message.channel.id),
|
||||
"channel_type": "dm" if isinstance(message.channel, discord.DMChannel) else "guild"
|
||||
}
|
||||
await self._post_to_ds(payload)
|
||||
|
||||
@self._client.event
|
||||
async def on_interaction(interaction: discord.Interaction):
|
||||
if interaction.type == discord.InteractionType.component:
|
||||
data = interaction.data or {}
|
||||
callback_data = data.get("custom_id")
|
||||
if not callback_data:
|
||||
return
|
||||
try:
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
except Exception as e:
|
||||
logger.error(f"处理 Discord 交互响应失败:{e}")
|
||||
|
||||
username = (interaction.user.display_name or interaction.user.global_name or interaction.user.name) \
|
||||
if interaction.user else None
|
||||
payload = {
|
||||
"type": "interaction",
|
||||
"userid": str(interaction.user.id) if interaction.user else None,
|
||||
"username": username,
|
||||
"user_tag": str(interaction.user) if interaction.user else None,
|
||||
"callback_data": callback_data,
|
||||
"message_id": str(interaction.message.id) if interaction.message else None,
|
||||
"chat_id": str(interaction.channel.id) if interaction.channel else None
|
||||
}
|
||||
await self._post_to_ds(payload)
|
||||
|
||||
def _start(self):
|
||||
if self._thread:
|
||||
return
|
||||
|
||||
def runner():
|
||||
asyncio.set_event_loop(self._loop)
|
||||
try:
|
||||
self._loop.create_task(self._client.start(self._token))
|
||||
self._loop.run_forever()
|
||||
except Exception as err:
|
||||
logger.error(f"Discord Bot 启动失败:{err}")
|
||||
finally:
|
||||
try:
|
||||
self._loop.run_until_complete(self._client.close())
|
||||
except Exception as err:
|
||||
logger.debug(f"Discord Bot 关闭失败:{err}")
|
||||
|
||||
self._thread = threading.Thread(target=runner, daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def stop(self):
|
||||
if not self._client or not self._loop or not self._thread:
|
||||
return
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(self._client.close(), self._loop).result(timeout=10)
|
||||
except Exception as err:
|
||||
logger.error(f"关闭 Discord Bot 失败:{err}")
|
||||
finally:
|
||||
try:
|
||||
self._loop.call_soon_threadsafe(self._loop.stop)
|
||||
except Exception as err:
|
||||
logger.error(f"停止 Discord 事件循环失败:{err}")
|
||||
self._ready_event.clear()
|
||||
|
||||
def get_state(self) -> bool:
|
||||
return self._ready_event.is_set() and self._client is not None
|
||||
|
||||
def send_msg(self, title: str, text: Optional[str] = None, image: Optional[str] = None,
|
||||
userid: Optional[str] = None, link: Optional[str] = None,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
original_message_id: Optional[Union[int, str]] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
mtype: Optional['NotificationType'] = None) -> Optional[bool]:
|
||||
if not self.get_state():
|
||||
return False
|
||||
if not title and not text:
|
||||
logger.warn("标题和内容不能同时为空")
|
||||
return False
|
||||
|
||||
try:
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._send_message(title=title, text=text, image=image, userid=userid,
|
||||
link=link, buttons=buttons,
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id,
|
||||
mtype=mtype),
|
||||
self._loop)
|
||||
return future.result(timeout=30)
|
||||
except Exception as err:
|
||||
logger.error(f"发送 Discord 消息失败:{err}")
|
||||
return False
|
||||
|
||||
def send_medias_msg(self, medias: List[MediaInfo], userid: Optional[str] = None, title: Optional[str] = None,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
original_message_id: Optional[Union[int, str]] = None,
|
||||
original_chat_id: Optional[str] = None) -> Optional[bool]:
|
||||
if not self.get_state() or not medias:
|
||||
return False
|
||||
title = title or "媒体列表"
|
||||
try:
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._send_list_message(
|
||||
embeds=self._build_media_embeds(medias, title),
|
||||
userid=userid,
|
||||
buttons=self._build_default_buttons(len(medias)) if not buttons else buttons,
|
||||
fallback_buttons=buttons,
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id
|
||||
),
|
||||
self._loop
|
||||
)
|
||||
return future.result(timeout=30)
|
||||
except Exception as err:
|
||||
logger.error(f"发送 Discord 媒体列表失败:{err}")
|
||||
return False
|
||||
|
||||
def send_torrents_msg(self, torrents: List[Context], userid: Optional[str] = None, title: Optional[str] = None,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
original_message_id: Optional[Union[int, str]] = None,
|
||||
original_chat_id: Optional[str] = None) -> Optional[bool]:
|
||||
if not self.get_state() or not torrents:
|
||||
return False
|
||||
title = title or "种子列表"
|
||||
try:
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._send_list_message(
|
||||
embeds=self._build_torrent_embeds(torrents, title),
|
||||
userid=userid,
|
||||
buttons=self._build_default_buttons(len(torrents)) if not buttons else buttons,
|
||||
fallback_buttons=buttons,
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id
|
||||
),
|
||||
self._loop
|
||||
)
|
||||
return future.result(timeout=30)
|
||||
except Exception as err:
|
||||
logger.error(f"发送 Discord 种子列表失败:{err}")
|
||||
return False
|
||||
|
||||
def delete_msg(self, message_id: Union[str, int], chat_id: Optional[str] = None) -> Optional[bool]:
|
||||
if not self.get_state():
|
||||
return False
|
||||
try:
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._delete_message(message_id=message_id, chat_id=chat_id),
|
||||
self._loop
|
||||
)
|
||||
return future.result(timeout=15)
|
||||
except Exception as err:
|
||||
logger.error(f"删除 Discord 消息失败:{err}")
|
||||
return False
|
||||
|
||||
async def _send_message(self, title: str, text: Optional[str], image: Optional[str],
|
||||
userid: Optional[str], link: Optional[str],
|
||||
buttons: Optional[List[List[dict]]],
|
||||
original_message_id: Optional[Union[int, str]],
|
||||
original_chat_id: Optional[str],
|
||||
mtype: Optional['NotificationType'] = None) -> bool:
|
||||
channel = await self._resolve_channel(userid=userid, chat_id=original_chat_id)
|
||||
if not channel:
|
||||
logger.error("未找到可用的 Discord 频道或私聊")
|
||||
return False
|
||||
|
||||
embed = self._build_embed(title=title, text=text, image=image, link=link, mtype=mtype)
|
||||
view = self._build_view(buttons=buttons, link=link)
|
||||
content = None
|
||||
|
||||
if original_message_id and original_chat_id:
|
||||
return await self._edit_message(chat_id=original_chat_id, message_id=original_message_id,
|
||||
content=content, embed=embed, view=view)
|
||||
|
||||
await channel.send(content=content, embed=embed, view=view)
|
||||
return True
|
||||
|
||||
async def _send_list_message(self, embeds: List[discord.Embed],
|
||||
userid: Optional[str],
|
||||
buttons: Optional[List[List[dict]]],
|
||||
fallback_buttons: Optional[List[List[dict]]],
|
||||
original_message_id: Optional[Union[int, str]],
|
||||
original_chat_id: Optional[str]) -> bool:
|
||||
channel = await self._resolve_channel(userid=userid, chat_id=original_chat_id)
|
||||
if not channel:
|
||||
logger.error("未找到可用的 Discord 频道或私聊")
|
||||
return False
|
||||
|
||||
view = self._build_view(buttons=buttons if buttons else fallback_buttons)
|
||||
embeds = embeds[:10] if embeds else [] # Discord 单条消息最多 10 个 embed
|
||||
|
||||
if original_message_id and original_chat_id:
|
||||
return await self._edit_message(chat_id=original_chat_id, message_id=original_message_id,
|
||||
content=None, embed=None, view=view, embeds=embeds)
|
||||
|
||||
await channel.send(embed=embeds[0] if len(embeds) == 1 else None,
|
||||
embeds=embeds if len(embeds) > 1 else None,
|
||||
view=view)
|
||||
return True
|
||||
|
||||
async def _edit_message(self, chat_id: Union[str, int], message_id: Union[str, int],
|
||||
content: Optional[str], embed: Optional[discord.Embed],
|
||||
view: Optional[discord.ui.View], embeds: Optional[List[discord.Embed]] = None) -> bool:
|
||||
channel = await self._resolve_channel(chat_id=str(chat_id))
|
||||
if not channel:
|
||||
logger.error(f"未找到要编辑的 Discord 频道:{chat_id}")
|
||||
return False
|
||||
try:
|
||||
message = await channel.fetch_message(int(message_id))
|
||||
kwargs: Dict[str, Any] = {"content": content, "view": view}
|
||||
if embeds:
|
||||
if len(embeds) == 1:
|
||||
kwargs["embed"] = embeds[0]
|
||||
else:
|
||||
kwargs["embeds"] = embeds
|
||||
elif embed:
|
||||
kwargs["embed"] = embed
|
||||
await message.edit(**kwargs)
|
||||
return True
|
||||
except Exception as err:
|
||||
logger.error(f"编辑 Discord 消息失败:{err}")
|
||||
return False
|
||||
|
||||
async def _delete_message(self, message_id: Union[str, int], chat_id: Optional[str]) -> bool:
|
||||
channel = await self._resolve_channel(chat_id=chat_id)
|
||||
if not channel:
|
||||
logger.error("删除 Discord 消息时未找到频道")
|
||||
return False
|
||||
try:
|
||||
message = await channel.fetch_message(int(message_id))
|
||||
await message.delete()
|
||||
return True
|
||||
except Exception as err:
|
||||
logger.error(f"删除 Discord 消息失败:{err}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _build_embed(title: str, text: Optional[str], image: Optional[str],
|
||||
link: Optional[str], mtype: Optional['NotificationType'] = None) -> discord.Embed:
|
||||
fields: List[Dict[str, str]] = []
|
||||
desc_lines: List[str] = []
|
||||
should_parse_fields = mtype in PARSE_FIELD_TYPES if mtype else False
|
||||
def _collect_spans(s: str, left: str, right: str) -> List[Tuple[int, int]]:
|
||||
spans: List[Tuple[int, int]] = []
|
||||
start = 0
|
||||
while True:
|
||||
l_idx = s.find(left, start)
|
||||
if l_idx == -1:
|
||||
break
|
||||
r_idx = s.find(right, l_idx + 1)
|
||||
if r_idx == -1:
|
||||
break
|
||||
spans.append((l_idx, r_idx))
|
||||
start = r_idx + 1
|
||||
return spans
|
||||
|
||||
def _find_colon_index(s: str, m: re.Match) -> Optional[int]:
|
||||
segment = s[m.start():m.end()]
|
||||
for i, ch in enumerate(segment):
|
||||
if ch in (":", ":"):
|
||||
return m.start() + i
|
||||
return None
|
||||
|
||||
if text:
|
||||
# 处理上游未反序列化的 "\n" 等转义换行,避免被当成普通字符
|
||||
if "\\n" in text or "\\r" in text:
|
||||
text = text.replace("\\r\\n", "\n").replace("\\n", "\n").replace("\\r", "\n")
|
||||
if not should_parse_fields:
|
||||
desc_lines.append(text.strip())
|
||||
else:
|
||||
# 匹配形如 "字段:值" 的片段,字段名不允许包含常见分隔符;
|
||||
# 下一个字段需以顿号/逗号/分号等分隔开,且不能是 URL 协议开头,避免值里出现 URL 的":" 被误拆
|
||||
name_re = r"[A-Za-z0-9\u4e00-\u9fa5_\-&]+"
|
||||
pair_pattern = re.compile(
|
||||
rf"({name_re})[::](.*?)(?=(?:[,,。;;、]+\s*(?!https?://|ftp://|ftps://|magnet:){name_re}[::])|$)",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
for line in text.splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
matches = list(pair_pattern.finditer(line))
|
||||
if matches:
|
||||
book_spans = _collect_spans(line, "《", "》") + _collect_spans(line, "【", "】")
|
||||
if book_spans:
|
||||
has_book_colon = False
|
||||
for m in matches:
|
||||
colon_idx = _find_colon_index(line, m)
|
||||
if colon_idx is not None and any(l < colon_idx < r for l, r in book_spans):
|
||||
has_book_colon = True
|
||||
break
|
||||
if has_book_colon:
|
||||
desc_lines.append(line)
|
||||
continue
|
||||
# 若整行只是 URL/时间等自然包含":"的内容,则不当作字段
|
||||
url_like_names = {"http", "https", "ftp", "ftps", "magnet"}
|
||||
if all(m.group(1).lower() in url_like_names or m.group(1).isdigit() for m in matches):
|
||||
desc_lines.append(line)
|
||||
continue
|
||||
last_end = 0
|
||||
for m in matches:
|
||||
# 追加匹配前的非空文本到描述
|
||||
prefix = line[last_end:m.start()].strip(" ,,;;。、")
|
||||
# 仅当前缀不全是分隔符/空白时才记录
|
||||
if prefix and prefix.strip(" ,,;;。、"):
|
||||
desc_lines.append(prefix)
|
||||
name = m.group(1).strip()
|
||||
value = m.group(2).strip(" ,,;;。、\t") or "-"
|
||||
if name:
|
||||
fields.append({"name": name, "value": value, "inline": False})
|
||||
last_end = m.end()
|
||||
# 匹配末尾后的文本
|
||||
suffix = line[last_end:].strip(" ,,;;。、")
|
||||
if suffix and suffix.strip(" ,,;;。、"):
|
||||
desc_lines.append(suffix)
|
||||
else:
|
||||
desc_lines.append(line)
|
||||
description = "\n".join(desc_lines).strip()
|
||||
if not description and not fields and text:
|
||||
description = text.strip()
|
||||
embed = discord.Embed(
|
||||
title=title,
|
||||
url=link or "https://github.com/jxxghp/MoviePilot",
|
||||
description=description if description else None,
|
||||
color=0xE67E22
|
||||
)
|
||||
for field in fields:
|
||||
embed.add_field(name=field["name"], value=field["value"], inline=False)
|
||||
if image:
|
||||
embed.set_image(url=image)
|
||||
return embed
|
||||
|
||||
@staticmethod
|
||||
def _build_media_embeds(medias: List[MediaInfo], title: str) -> List[discord.Embed]:
|
||||
embeds: List[discord.Embed] = []
|
||||
for index, media in enumerate(medias[:10], start=1):
|
||||
overview = media.get_overview_string(80)
|
||||
desc_parts = [
|
||||
f"{media.type.value} | {media.vote_star}" if media.vote_star else media.type.value,
|
||||
overview
|
||||
]
|
||||
embed = discord.Embed(
|
||||
title=f"{index}. {media.title_year}",
|
||||
url=media.detail_link or discord.Embed.Empty,
|
||||
description="\n".join([p for p in desc_parts if p]),
|
||||
color=0x5865F2
|
||||
)
|
||||
if media.get_poster_image():
|
||||
embed.set_thumbnail(url=media.get_poster_image())
|
||||
embeds.append(embed)
|
||||
if embeds:
|
||||
embeds[0].set_author(name=title)
|
||||
return embeds
|
||||
|
||||
@staticmethod
|
||||
def _build_torrent_embeds(torrents: List[Context], title: str) -> List[discord.Embed]:
|
||||
embeds: List[discord.Embed] = []
|
||||
for index, context in enumerate(torrents[:10], start=1):
|
||||
torrent = context.torrent_info
|
||||
meta = MetaInfo(torrent.title, torrent.description)
|
||||
title_text = f"{meta.season_episode} {meta.resource_term} {meta.video_term} {meta.release_group}"
|
||||
title_text = re.sub(r"\s+", " ", title_text).strip()
|
||||
detail = [
|
||||
f"{torrent.site_name} | {StringUtils.str_filesize(torrent.size)} | {torrent.volume_factor} | {torrent.seeders}↑",
|
||||
meta.resource_term,
|
||||
meta.video_term
|
||||
]
|
||||
embed = discord.Embed(
|
||||
title=f"{index}. {title_text or torrent.title}",
|
||||
url=torrent.page_url or discord.Embed.Empty,
|
||||
description="\n".join([d for d in detail if d]),
|
||||
color=0x00A86B
|
||||
)
|
||||
poster = getattr(torrent, "poster", None)
|
||||
if poster:
|
||||
embed.set_thumbnail(url=poster)
|
||||
embeds.append(embed)
|
||||
if embeds:
|
||||
embeds[0].set_author(name=title)
|
||||
return embeds
|
||||
|
||||
@staticmethod
|
||||
def _build_default_buttons(count: int) -> List[List[dict]]:
|
||||
buttons: List[List[dict]] = []
|
||||
max_rows = 5
|
||||
max_per_row = 5
|
||||
capped = min(count, max_rows * max_per_row)
|
||||
for idx in range(1, capped + 1):
|
||||
row_idx = (idx - 1) // max_per_row
|
||||
if len(buttons) <= row_idx:
|
||||
buttons.append([])
|
||||
buttons[row_idx].append({"text": f"选择 {idx}", "callback_data": str(idx)})
|
||||
if count > capped:
|
||||
logger.warn(f"按钮数量超过 Discord 限制,仅展示前 {capped} 个")
|
||||
return buttons
|
||||
|
||||
@staticmethod
|
||||
def _build_view(buttons: Optional[List[List[dict]]], link: Optional[str] = None) -> Optional[discord.ui.View]:
|
||||
has_buttons = buttons and any(buttons)
|
||||
if not has_buttons and not link:
|
||||
return None
|
||||
|
||||
view = discord.ui.View(timeout=None)
|
||||
if buttons:
|
||||
for row_index, button_row in enumerate(buttons[:5]):
|
||||
for button in button_row[:5]:
|
||||
if "url" in button:
|
||||
btn = discord.ui.Button(label=button.get("text", "链接"),
|
||||
url=button["url"],
|
||||
style=discord.ButtonStyle.link)
|
||||
else:
|
||||
custom_id = (button.get("callback_data") or button.get("text") or f"btn-{row_index}")[:99]
|
||||
btn = discord.ui.Button(label=button.get("text", "选择")[:80],
|
||||
custom_id=custom_id,
|
||||
style=discord.ButtonStyle.primary)
|
||||
view.add_item(btn)
|
||||
elif link:
|
||||
view.add_item(discord.ui.Button(label="查看详情", url=link, style=discord.ButtonStyle.link))
|
||||
return view
|
||||
|
||||
async def _resolve_channel(self, userid: Optional[str] = None, chat_id: Optional[str] = None):
|
||||
# 优先使用明确的聊天 ID
|
||||
if chat_id:
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if channel:
|
||||
return channel
|
||||
try:
|
||||
return await self._client.fetch_channel(int(chat_id))
|
||||
except Exception as err:
|
||||
logger.warn(f"通过 chat_id 获取 Discord 频道失败:{err}")
|
||||
|
||||
# 私聊
|
||||
if userid:
|
||||
dm = await self._get_dm_channel(str(userid))
|
||||
if dm:
|
||||
return dm
|
||||
|
||||
# 配置的广播频道
|
||||
if self._broadcast_channel:
|
||||
return self._broadcast_channel
|
||||
if self._channel_id:
|
||||
channel = self._client.get_channel(self._channel_id)
|
||||
if not channel:
|
||||
try:
|
||||
channel = await self._client.fetch_channel(self._channel_id)
|
||||
except Exception as err:
|
||||
logger.warn(f"通过配置的频道ID获取 Discord 频道失败:{err}")
|
||||
channel = None
|
||||
self._broadcast_channel = channel
|
||||
if channel:
|
||||
return channel
|
||||
|
||||
# 按 Guild 寻找一个可用文本频道
|
||||
target_guilds = []
|
||||
if self._guild_id:
|
||||
guild = self._client.get_guild(self._guild_id)
|
||||
if guild:
|
||||
target_guilds.append(guild)
|
||||
else:
|
||||
target_guilds = list(self._client.guilds)
|
||||
|
||||
for guild in target_guilds:
|
||||
for channel in guild.text_channels:
|
||||
if guild.me and channel.permissions_for(guild.me).send_messages:
|
||||
self._broadcast_channel = channel
|
||||
return channel
|
||||
return None
|
||||
|
||||
async def _get_dm_channel(self, userid: str) -> Optional[discord.DMChannel]:
|
||||
if userid in self._user_dm_cache:
|
||||
return self._user_dm_cache.get(userid)
|
||||
try:
|
||||
user_obj = self._client.get_user(int(userid)) or await self._client.fetch_user(int(userid))
|
||||
if not user_obj:
|
||||
return None
|
||||
dm = user_obj.dm_channel or await user_obj.create_dm()
|
||||
if dm:
|
||||
self._user_dm_cache[userid] = dm
|
||||
return dm
|
||||
except Exception as err:
|
||||
logger.error(f"获取 Discord 私聊失败:{err}")
|
||||
return None
|
||||
|
||||
def _should_process_message(self, message: discord.Message) -> bool:
|
||||
if isinstance(message.channel, discord.DMChannel):
|
||||
return True
|
||||
content = message.content or ""
|
||||
# 仅处理 @Bot 或斜杠命令
|
||||
if self._client.user and self._client.user.mentioned_in(message):
|
||||
return True
|
||||
if content.startswith("/"):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _clean_bot_mention(self, content: str) -> str:
|
||||
if not content:
|
||||
return ""
|
||||
if self._bot_user_id:
|
||||
mention_pattern = rf"<@!?{self._bot_user_id}>"
|
||||
content = re.sub(mention_pattern, "", content).strip()
|
||||
return content
|
||||
|
||||
async def _post_to_ds(self, payload: Dict[str, Any]) -> None:
|
||||
try:
|
||||
proxy = None
|
||||
if settings.PROXY:
|
||||
proxy = settings.PROXY.get("https") or settings.PROXY.get("http")
|
||||
async with httpx.AsyncClient(timeout=10, verify=False, proxy=proxy) as client:
|
||||
await client.post(self._ds_url, json=payload)
|
||||
except Exception as err:
|
||||
logger.error(f"转发 Discord 消息失败:{err}")
|
||||
@@ -15,9 +15,9 @@ class GazelleSiteUserInfo(SiteParserBase):
|
||||
html_text = self._prepare_html_text(html_text)
|
||||
html = etree.HTML(html_text)
|
||||
try:
|
||||
tmps = html.xpath('//a[contains(@href, "user.php?id=")]')
|
||||
tmps = html.xpath('//a[contains(@href, "user.php?id=") or contains(@href, "user?id=")]')
|
||||
if tmps:
|
||||
user_id_match = re.search(r"user.php\?id=(\d+)", tmps[0].attrib['href'])
|
||||
user_id_match = re.search(r"user(?:\.php)?\?id=(\d+)", tmps[0].attrib['href'])
|
||||
if user_id_match and user_id_match.group().strip():
|
||||
self.userid = user_id_match.group(1)
|
||||
self._torrent_seeding_page = f"torrents.php?type=seeding&userid={self.userid}"
|
||||
@@ -42,13 +42,13 @@ class GazelleSiteUserInfo(SiteParserBase):
|
||||
|
||||
self.ratio = 0.0 if self.download <= 0.0 else round(self.upload / self.download, 3)
|
||||
|
||||
tmps = html.xpath('//a[contains(@href, "bonus.php")]/@data-tooltip')
|
||||
tmps = html.xpath('//a[contains(@href, "bonus")]/@data-tooltip')
|
||||
if tmps:
|
||||
bonus_match = re.search(r"([\d,.]+)", tmps[0])
|
||||
if bonus_match and bonus_match.group(1).strip():
|
||||
self.bonus = StringUtils.str_float(bonus_match.group(1))
|
||||
else:
|
||||
tmps = html.xpath('//a[contains(@href, "bonus.php")]')
|
||||
tmps = html.xpath('//a[contains(@href, "bonus")]')
|
||||
if tmps:
|
||||
bonus_text = tmps[0].xpath("string(.)")
|
||||
bonus_match = re.search(r"([\d,.]+)", bonus_text)
|
||||
@@ -142,7 +142,7 @@ class GazelleSiteUserInfo(SiteParserBase):
|
||||
|
||||
# 是否存在下页数据
|
||||
next_page = None
|
||||
next_page_text = html.xpath('//a[contains(.//text(), "Next") or contains(.//text(), "下一页")]/@href')
|
||||
next_page_text = html.xpath('//a[contains(.//text(), "Next") or contains(.//text(), "下一页") or contains(@title, "下一页") or contains(@title, "Next")]/@href')
|
||||
if next_page_text:
|
||||
next_page = next_page_text[-1].strip()
|
||||
finally:
|
||||
|
||||
@@ -2,6 +2,7 @@ import base64
|
||||
import json
|
||||
import re
|
||||
from typing import Tuple, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from app.core.config import settings
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
@@ -25,6 +26,9 @@ class MTorrentSpider:
|
||||
_size = 100
|
||||
_searchurl = "https://api.%s/api/torrent/search"
|
||||
_downloadurl = "https://api.%s/api/torrent/genDlToken"
|
||||
_subtitle_list_url = "https://api.%s/api/subtitle/list"
|
||||
_subtitle_genlink_url = "https://api.%s/api/subtitle/genlink"
|
||||
_subtitle_download_url ="https://api.%s/api/subtitle/dlV2?credential=%s"
|
||||
_pageurl = "%sdetail/%s"
|
||||
_timeout = 15
|
||||
|
||||
@@ -114,24 +118,36 @@ class MTorrentSpider:
|
||||
labels_value = self._labels.get(result.get('labels') or "0") or ""
|
||||
if labels_value:
|
||||
labels = labels_value.split()
|
||||
status = result.get('status', {})
|
||||
torrent = {
|
||||
'title': result.get('name'),
|
||||
'description': result.get('smallDescr'),
|
||||
'enclosure': self.__get_download_url(result.get('id')),
|
||||
'pubdate': StringUtils.format_timestamp(result.get('createdDate')),
|
||||
'size': int(result.get('size') or '0'),
|
||||
'seeders': int(result.get('status', {}).get("seeders") or '0'),
|
||||
'peers': int(result.get('status', {}).get("leechers") or '0'),
|
||||
'grabs': int(result.get('status', {}).get("timesCompleted") or '0'),
|
||||
'downloadvolumefactor': self.__get_downloadvolumefactor(result.get('status', {}).get("discount")),
|
||||
'uploadvolumefactor': self.__get_uploadvolumefactor(result.get('status', {}).get("discount")),
|
||||
'seeders': int(status.get("seeders") or '0'),
|
||||
'peers': int(status.get("leechers") or '0'),
|
||||
'grabs': int(status.get("timesCompleted") or '0'),
|
||||
'downloadvolumefactor': self.__get_downloadvolumefactor(status.get("discount")),
|
||||
'uploadvolumefactor': self.__get_uploadvolumefactor(status.get("discount")),
|
||||
'page_url': self._pageurl % (self._url, result.get('id')),
|
||||
'imdbid': self.__find_imdbid(result.get('imdb')),
|
||||
'labels': labels,
|
||||
'category': category
|
||||
}
|
||||
if discount_end_time := (result.get('status') or {}).get('discountEndTime'):
|
||||
if discount_end_time := status.get('discountEndTime'):
|
||||
torrent['freedate'] = StringUtils.format_timestamp(discount_end_time)
|
||||
# 解析全站促销时的规则(当前馒头只有下载促销)
|
||||
if promotion_rule := status.get("promotionRule"):
|
||||
discount = promotion_rule.get("discount", "NORMAL")
|
||||
torrent["downloadvolumefactor"] = self.__get_downloadvolumefactor(discount)
|
||||
if end_time := promotion_rule.get("endTime"):
|
||||
torrent["freedate"] = StringUtils.format_timestamp(end_time)
|
||||
if mall_single_free := status.get("mallSingleFree"):
|
||||
if mall_single_free.get("status") == "ONGOING":
|
||||
torrent["downloadvolumefactor"] = self.__get_downloadvolumefactor("FREE")
|
||||
if end_date := mall_single_free.get("endDate"):
|
||||
torrent["freedate"] = StringUtils.format_timestamp(end_date)
|
||||
torrents.append(torrent)
|
||||
return torrents
|
||||
|
||||
@@ -262,3 +278,110 @@ class MTorrentSpider:
|
||||
# base64编码
|
||||
base64_str = base64.b64encode(json.dumps(params).encode('utf-8')).decode('utf-8')
|
||||
return f"[{base64_str}]{url}"
|
||||
|
||||
def get_subtitle_links(self, page_url: str) -> List[str]:
|
||||
"""
|
||||
获取指定页面的字幕下载链接
|
||||
|
||||
:param page_url: 种子详情页网址
|
||||
:type page_url: str
|
||||
:return: 字幕下载链接
|
||||
:rtype: List[str]
|
||||
"""
|
||||
if not page_url:
|
||||
return []
|
||||
# 从馒头的详情页网址中提取种子id
|
||||
torrent_id = urlparse(page_url).path.rsplit("/", 1)[-1].strip()
|
||||
if not torrent_id:
|
||||
return []
|
||||
return self.get_subtitle_links_by_id(torrent_id)
|
||||
|
||||
def get_subtitle_links_by_id(self, torrent_id: str) -> List[str]:
|
||||
"""
|
||||
获取指定种子的字幕下载链接
|
||||
|
||||
:param torrent_id: 种子ID
|
||||
:type torrent_id: str
|
||||
:return: 字幕下载链接
|
||||
:rtype: List[str]
|
||||
"""
|
||||
results = []
|
||||
try:
|
||||
for subtitle_id in self.__subtitle_ids(torrent_id) or []:
|
||||
if link := self.__subtitle_genlink(subtitle_id):
|
||||
results.append(link)
|
||||
except Exception as e:
|
||||
logger.error(f"{self._name} 获取字幕失败:{e}")
|
||||
return results
|
||||
|
||||
def __subtitle_ids(self, torrent_id: str) -> Optional[List[str]]:
|
||||
"""
|
||||
获取指定种子的字幕列表
|
||||
|
||||
:param torrent_id: 种子ID
|
||||
:type torrent_id: str
|
||||
:return: 字幕ID
|
||||
:rtype: List[str] | None
|
||||
"""
|
||||
url = self._subtitle_list_url % self._domain
|
||||
# 发送请求
|
||||
res = RequestUtils(
|
||||
headers={
|
||||
"Accept": "application/json, text/plain, */*",
|
||||
"User-Agent": f"{self._ua}",
|
||||
"x-api-key": self._apikey,
|
||||
},
|
||||
proxies=self._proxy,
|
||||
timeout=self._timeout,
|
||||
).post_res(url, data={"id": torrent_id})
|
||||
if res and res.status_code == 200:
|
||||
result = res.json()
|
||||
if int(result.get("code", -1)) == 0:
|
||||
return [item["id"] for item in result.get("data", []) if "id" in item]
|
||||
else:
|
||||
logger.warn(
|
||||
f"{self._name} 获取字幕列表失败,返回:{result.get("message", "未知")}"
|
||||
)
|
||||
return None
|
||||
elif res is not None:
|
||||
logger.warn(f"{self._name} 获取字幕列表失败,错误码:{res.status_code}")
|
||||
return None
|
||||
else:
|
||||
logger.warn(f"{self._name} 获取字幕列表失败,无法连接 {self._domain}")
|
||||
return None
|
||||
|
||||
def __subtitle_genlink(self, subtitle_id: str) -> Optional[str]:
|
||||
"""
|
||||
获取字幕下载链接
|
||||
|
||||
:param subtitle_id: 字幕ID
|
||||
:type subtitle_id: str
|
||||
:return: 下载链接
|
||||
:rtype: str | None
|
||||
"""
|
||||
url = self._subtitle_genlink_url % self._domain
|
||||
# 发送请求
|
||||
res = RequestUtils(
|
||||
headers={
|
||||
"Accept": "application/json, text/plain, */*",
|
||||
"User-Agent": f"{self._ua}",
|
||||
"x-api-key": self._apikey,
|
||||
},
|
||||
proxies=self._proxy,
|
||||
timeout=self._timeout,
|
||||
).post_res(url, data={"id": subtitle_id})
|
||||
if res and res.status_code == 200:
|
||||
result = res.json()
|
||||
if int(result.get("code", -1)) == 0 and isinstance(result.get("data"), str):
|
||||
return self._subtitle_download_url % (self._domain, result["data"])
|
||||
else:
|
||||
logger.warn(
|
||||
f"{self._name} 获取字幕下载链接失败,返回:{result.get("message", "未知")}"
|
||||
)
|
||||
return None
|
||||
elif res is not None:
|
||||
logger.warn(f"{self._name} 获取字幕下载链接失败,错误码:{res.status_code}")
|
||||
return None
|
||||
else:
|
||||
logger.warn(f"{self._name} 获取字幕下载链接失败,无法连接 {self._domain}")
|
||||
return None
|
||||
|
||||
@@ -124,12 +124,12 @@ class QbittorrentModule(_ModuleBase, _DownloaderBase[Qbittorrent]):
|
||||
return None, None, None, "下载内容为空"
|
||||
|
||||
# 读取种子的名称
|
||||
torrent, content = __get_torrent_info()
|
||||
torrent_from_file, content = __get_torrent_info()
|
||||
# 检查是否为磁力链接
|
||||
is_magnet = isinstance(content, str) and content.startswith("magnet:") or isinstance(content,
|
||||
bytes) and content.startswith(
|
||||
b"magnet:")
|
||||
if not torrent and not is_magnet:
|
||||
if not torrent_from_file and not is_magnet:
|
||||
return None, None, None, f"添加种子任务失败:无法读取种子文件"
|
||||
|
||||
# 获取下载器
|
||||
@@ -170,8 +170,8 @@ class QbittorrentModule(_ModuleBase, _DownloaderBase[Qbittorrent]):
|
||||
try:
|
||||
for torrent in torrents:
|
||||
# 名称与大小相等则认为是同一个种子
|
||||
if torrent.get("name") == torrent.name \
|
||||
and torrent.get("total_size") == torrent.total_size:
|
||||
if torrent.get("name") == getattr(torrent_from_file, 'name', '') \
|
||||
and torrent.get("total_size") == getattr(torrent_from_file, 'total_size', 0):
|
||||
torrent_hash = torrent.get("hash")
|
||||
torrent_tags = [str(tag).strip() for tag in torrent.get("tags").split(',')]
|
||||
logger.warn(f"下载器中已存在该种子任务:{torrent_hash} - {torrent.get('name')}")
|
||||
|
||||
@@ -8,9 +8,13 @@ from lxml import etree
|
||||
from app.chain.storage import StorageChain
|
||||
from app.core.config import settings
|
||||
from app.core.context import Context
|
||||
from app.db.site_oper import SiteOper
|
||||
from app.helper.sites import SitesHelper # noqa
|
||||
from app.helper.torrent import TorrentHelper
|
||||
from app.log import logger
|
||||
from app.modules import _ModuleBase
|
||||
from app.modules.indexer.spider.mtorrent import MTorrentSpider
|
||||
from app.schemas import TorrentInfo
|
||||
from app.schemas.file import FileURI
|
||||
from app.schemas.types import ModuleType, OtherModulesType
|
||||
from app.utils.http import RequestUtils
|
||||
@@ -25,7 +29,9 @@ class SubtitleModule(_ModuleBase):
|
||||
|
||||
# 站点详情页字幕下载链接识别XPATH
|
||||
_SITE_SUBTITLE_XPATH = [
|
||||
'//td[@class="rowhead"][text()="字幕"]/following-sibling::td//a[not(@class)]/@href',
|
||||
'//td[@class="rowhead"][text()="字幕"]/following-sibling::td//a/@href',
|
||||
'//div[contains(@class, "font-bold")][text()="字幕"]/following-sibling::div[1]//a[not(@class)]/@href', # 憨憨
|
||||
]
|
||||
|
||||
def init_module(self) -> None:
|
||||
@@ -65,6 +71,58 @@ class SubtitleModule(_ModuleBase):
|
||||
def test(self):
|
||||
pass
|
||||
|
||||
def _get_subtitle_links(self, torrent: TorrentInfo):
|
||||
"""
|
||||
获取字幕链接
|
||||
"""
|
||||
# API请求方式的站点需要特殊处理
|
||||
if torrent.site is not None:
|
||||
site = SiteOper().get(torrent.site)
|
||||
if indexer := SitesHelper().get_indexer(site.domain):
|
||||
if indexer.get("parser") == "mTorrent":
|
||||
return MTorrentSpider(indexer).get_subtitle_links(
|
||||
torrent.page_url
|
||||
)
|
||||
# TODO 其它采用API访问的站点
|
||||
# 普通站点通过解析网站代码的方式获取
|
||||
request = RequestUtils(
|
||||
cookies=torrent.site_cookie,
|
||||
ua=torrent.site_ua,
|
||||
proxies=settings.PROXY if torrent.site_proxy else None,
|
||||
)
|
||||
res = request.get_res(torrent.page_url)
|
||||
if res and res.status_code == 200:
|
||||
if not res.text:
|
||||
logger.warn(f"读取页面代码失败:{torrent.page_url}")
|
||||
return []
|
||||
html = etree.HTML(res.text)
|
||||
try:
|
||||
sublink_list = []
|
||||
for xpath in self._SITE_SUBTITLE_XPATH:
|
||||
sublinks = html.xpath(xpath)
|
||||
if sublinks:
|
||||
for sublink in sublinks:
|
||||
if not sublink:
|
||||
continue
|
||||
if not sublink.startswith("http"):
|
||||
base_url = StringUtils.get_base_url(torrent.page_url)
|
||||
if sublink.startswith("/"):
|
||||
sublink = "%s%s" % (base_url, sublink)
|
||||
else:
|
||||
sublink = "%s/%s" % (base_url, sublink)
|
||||
sublink_list.append(sublink)
|
||||
# 已成功获取了链接,后续xpath可以忽略
|
||||
break
|
||||
return sublink_list
|
||||
finally:
|
||||
if html is not None:
|
||||
del html
|
||||
elif res is not None:
|
||||
logger.warn(f"连接 {torrent.page_url} 失败,状态码:{res.status_code}")
|
||||
else:
|
||||
logger.warn(f"无法打开链接:{torrent.page_url}")
|
||||
return None
|
||||
|
||||
def download_added(self, context: Context, download_dir: Path, torrent_content: Union[str, bytes] = None):
|
||||
"""
|
||||
添加下载任务成功后,从站点下载字幕,保存到下载目录
|
||||
@@ -117,83 +175,60 @@ class SubtitleModule(_ModuleBase):
|
||||
logger.error(f"下载目录不存在,无法保存字幕:{download_dir / folder_name}")
|
||||
return
|
||||
# 读取网站代码
|
||||
request = RequestUtils(cookies=torrent.site_cookie, ua=torrent.site_ua)
|
||||
res = request.get_res(torrent.page_url)
|
||||
if res and res.status_code == 200:
|
||||
if not res.text:
|
||||
logger.warn(f"读取页面代码失败:{torrent.page_url}")
|
||||
return
|
||||
html = etree.HTML(res.text)
|
||||
try:
|
||||
sublink_list = []
|
||||
for xpath in self._SITE_SUBTITLE_XPATH:
|
||||
sublinks = html.xpath(xpath)
|
||||
if sublinks:
|
||||
for sublink in sublinks:
|
||||
if not sublink:
|
||||
continue
|
||||
if not sublink.startswith("http"):
|
||||
base_url = StringUtils.get_base_url(torrent.page_url)
|
||||
if sublink.startswith("/"):
|
||||
sublink = "%s%s" % (base_url, sublink)
|
||||
else:
|
||||
sublink = "%s/%s" % (base_url, sublink)
|
||||
sublink_list.append(sublink)
|
||||
finally:
|
||||
if html is not None:
|
||||
del html
|
||||
# 下载所有字幕文件
|
||||
for sublink in sublink_list:
|
||||
logger.info(f"找到字幕下载链接:{sublink},开始下载...")
|
||||
# 下载
|
||||
ret = request.get_res(sublink)
|
||||
if ret and ret.status_code == 200:
|
||||
# 保存ZIP
|
||||
file_name = TorrentHelper.get_url_filename(ret, sublink)
|
||||
if not file_name:
|
||||
logger.warn(f"链接不是字幕文件:{sublink}")
|
||||
continue
|
||||
if file_name.lower().endswith(".zip"):
|
||||
# ZIP包
|
||||
zip_file = settings.TEMP_PATH / file_name
|
||||
# 保存
|
||||
zip_file.write_bytes(ret.content)
|
||||
# 解压路径
|
||||
zip_path = zip_file.with_name(zip_file.stem)
|
||||
# 解压文件
|
||||
shutil.unpack_archive(zip_file, zip_path, format='zip')
|
||||
# 遍历转移文件
|
||||
for sub_file in SystemUtils.list_files(zip_path, settings.RMT_SUBEXT):
|
||||
target_sub_file = Path(working_dir_item.path) / Path(sub_file.name)
|
||||
if storageChain.get_file_item(storage, target_sub_file):
|
||||
logger.info(f"字幕文件已存在:{target_sub_file}")
|
||||
continue
|
||||
logger.info(f"转移字幕 {sub_file} 到 {target_sub_file} ...")
|
||||
storageChain.upload_file(working_dir_item, sub_file)
|
||||
# 删除临时文件
|
||||
try:
|
||||
shutil.rmtree(zip_path)
|
||||
zip_file.unlink()
|
||||
except Exception as err:
|
||||
logger.error(f"删除临时文件失败:{str(err)}")
|
||||
else:
|
||||
sub_file = settings.TEMP_PATH / file_name
|
||||
# 保存
|
||||
sub_file.write_bytes(ret.content)
|
||||
sublink_list = self._get_subtitle_links(torrent)
|
||||
if not sublink_list:
|
||||
logger.warn(f"{torrent.page_url} 页面未找到字幕下载链接")
|
||||
return
|
||||
# 下载所有字幕文件
|
||||
request = RequestUtils(
|
||||
cookies=torrent.site_cookie,
|
||||
ua=torrent.site_ua,
|
||||
proxies=settings.PROXY if torrent.site_proxy else None,
|
||||
)
|
||||
for sublink in sublink_list:
|
||||
logger.info(f"找到字幕下载链接:{sublink},开始下载...")
|
||||
# 下载
|
||||
ret = request.get_res(sublink)
|
||||
if ret and ret.status_code == 200:
|
||||
# 保存ZIP
|
||||
file_name = TorrentHelper.get_url_filename(ret, sublink)
|
||||
if not file_name:
|
||||
logger.warn(f"链接不是字幕文件:{sublink}")
|
||||
continue
|
||||
if file_name.lower().endswith(".zip"):
|
||||
# ZIP包
|
||||
zip_file = settings.TEMP_PATH / file_name
|
||||
# 保存
|
||||
zip_file.write_bytes(ret.content)
|
||||
# 解压路径
|
||||
zip_path = zip_file.with_name(zip_file.stem)
|
||||
# 解压文件
|
||||
shutil.unpack_archive(zip_file, zip_path, format='zip')
|
||||
# 遍历转移文件
|
||||
for sub_file in SystemUtils.list_files(zip_path, settings.RMT_SUBEXT):
|
||||
target_sub_file = Path(working_dir_item.path) / Path(sub_file.name)
|
||||
if storageChain.get_file_item(storage, target_sub_file):
|
||||
logger.info(f"字幕文件已存在:{target_sub_file}")
|
||||
continue
|
||||
logger.info(f"转移字幕 {sub_file} 到 {target_sub_file} ...")
|
||||
storageChain.upload_file(working_dir_item, sub_file)
|
||||
# 删除临时文件
|
||||
try:
|
||||
shutil.rmtree(zip_path)
|
||||
zip_file.unlink()
|
||||
except Exception as err:
|
||||
logger.error(f"删除临时文件失败:{str(err)}")
|
||||
else:
|
||||
logger.error(f"下载字幕文件失败:{sublink}")
|
||||
continue
|
||||
if sublink_list:
|
||||
logger.info(f"{torrent.page_url} 页面字幕下载完成")
|
||||
sub_file = settings.TEMP_PATH / file_name
|
||||
# 保存
|
||||
sub_file.write_bytes(ret.content)
|
||||
target_sub_file = Path(working_dir_item.path) / Path(sub_file.name)
|
||||
if storageChain.get_file_item(storage, target_sub_file):
|
||||
logger.info(f"字幕文件已存在:{target_sub_file}")
|
||||
continue
|
||||
logger.info(f"转移字幕 {sub_file} 到 {target_sub_file} ...")
|
||||
storageChain.upload_file(working_dir_item, sub_file)
|
||||
else:
|
||||
logger.warn(f"{torrent.page_url} 页面未找到字幕下载链接")
|
||||
elif res is not None:
|
||||
logger.warn(f"连接 {torrent.page_url} 失败,状态码:{res.status_code}")
|
||||
else:
|
||||
logger.warn(f"无法打开链接:{torrent.page_url}")
|
||||
logger.error(f"下载字幕文件失败:{sublink}")
|
||||
continue
|
||||
logger.info(f"{torrent.page_url} 页面字幕下载完成")
|
||||
|
||||
@@ -78,7 +78,7 @@ class TheMovieDbModule(_ModuleBase):
|
||||
"""
|
||||
测试模块连接性
|
||||
"""
|
||||
ret = RequestUtils(proxies=settings.PROXY).get_res(
|
||||
ret = RequestUtils(ua=settings.NORMAL_USER_AGENT, proxies=settings.PROXY).get_res(
|
||||
f"https://{settings.TMDB_API_DOMAIN}/3/movie/550?api_key={settings.TMDB_API_KEY}")
|
||||
if ret and ret.status_code == 200:
|
||||
return True, ""
|
||||
|
||||
@@ -125,12 +125,12 @@ class TransmissionModule(_ModuleBase, _DownloaderBase[Transmission]):
|
||||
return None, None, None, "下载内容为空"
|
||||
|
||||
# 读取种子的名称
|
||||
torrent, content = __get_torrent_info()
|
||||
torrent_from_file, content = __get_torrent_info()
|
||||
# 检查是否为磁力链接
|
||||
is_magnet = isinstance(content, str) and content.startswith("magnet:") or isinstance(content,
|
||||
bytes) and content.startswith(
|
||||
b"magnet:")
|
||||
if not torrent and not is_magnet:
|
||||
if not torrent_from_file and not is_magnet:
|
||||
return None, None, None, f"添加种子任务失败:无法读取种子文件"
|
||||
|
||||
# 获取下载器
|
||||
@@ -149,7 +149,7 @@ class TransmissionModule(_ModuleBase, _DownloaderBase[Transmission]):
|
||||
else:
|
||||
labels = None
|
||||
# 添加任务
|
||||
torrent = server.add_torrent(
|
||||
added_torrent = server.add_torrent(
|
||||
content=content,
|
||||
download_dir=self.normalize_path(download_dir, downloader),
|
||||
is_paused=is_paused,
|
||||
@@ -159,7 +159,7 @@ class TransmissionModule(_ModuleBase, _DownloaderBase[Transmission]):
|
||||
# TR 始终使用原始种子布局, 返回"Original"
|
||||
torrent_layout = "Original"
|
||||
|
||||
if not torrent:
|
||||
if not added_torrent:
|
||||
# 查询所有下载器的种子
|
||||
torrents, error = server.get_torrents()
|
||||
if error:
|
||||
@@ -168,7 +168,7 @@ class TransmissionModule(_ModuleBase, _DownloaderBase[Transmission]):
|
||||
try:
|
||||
for torrent in torrents:
|
||||
# 名称与大小相等则认为是同一个种子
|
||||
if torrent.name == torrent.name and torrent.total_size == torrent.total_size:
|
||||
if torrent.name == getattr(torrent_from_file, 'name', '') and torrent.total_size == getattr(torrent_from_file, 'total_size', 0):
|
||||
torrent_hash = torrent.hashString
|
||||
logger.warn(f"下载器中已存在该种子任务:{torrent_hash} - {torrent.name}")
|
||||
# 给种子打上标签
|
||||
@@ -189,7 +189,7 @@ class TransmissionModule(_ModuleBase, _DownloaderBase[Transmission]):
|
||||
del torrents
|
||||
return None, None, None, f"添加种子任务失败:{content}"
|
||||
else:
|
||||
torrent_hash = torrent.hashString
|
||||
torrent_hash = added_torrent.hashString
|
||||
if is_paused:
|
||||
# 选择文件
|
||||
torrent_files = server.get_files(torrent_hash)
|
||||
|
||||
@@ -221,6 +221,22 @@ class ChannelCapabilityManager:
|
||||
max_button_text_length=25,
|
||||
fallback_enabled=True
|
||||
),
|
||||
MessageChannel.Discord: ChannelCapabilities(
|
||||
channel=MessageChannel.Discord,
|
||||
capabilities={
|
||||
ChannelCapability.INLINE_BUTTONS,
|
||||
ChannelCapability.MESSAGE_EDITING,
|
||||
ChannelCapability.MESSAGE_DELETION,
|
||||
ChannelCapability.CALLBACK_QUERIES,
|
||||
ChannelCapability.RICH_TEXT,
|
||||
ChannelCapability.IMAGES,
|
||||
ChannelCapability.LINKS
|
||||
},
|
||||
max_buttons_per_row=5,
|
||||
max_button_rows=5,
|
||||
max_button_text_length=80,
|
||||
fallback_enabled=True
|
||||
),
|
||||
MessageChannel.SynologyChat: ChannelCapabilities(
|
||||
channel=MessageChannel.SynologyChat,
|
||||
capabilities={
|
||||
|
||||
@@ -21,7 +21,7 @@ class Token(BaseModel):
|
||||
# 详细权限
|
||||
permissions: Optional[dict] = Field(default_factory=dict)
|
||||
# 是否显示配置向导
|
||||
widzard: Optional[bool] = None
|
||||
wizard: Optional[bool] = None
|
||||
|
||||
|
||||
class TokenPayload(BaseModel):
|
||||
|
||||
@@ -265,6 +265,7 @@ class MessageChannel(Enum):
|
||||
Wechat = "微信"
|
||||
Telegram = "Telegram"
|
||||
Slack = "Slack"
|
||||
Discord = "Discord"
|
||||
SynologyChat = "SynologyChat"
|
||||
VoceChat = "VoceChat"
|
||||
Web = "Web"
|
||||
|
||||
@@ -94,6 +94,7 @@ COPY --from=prepare_venv --chmod=777 ${VENV_PATH} ${VENV_PATH}
|
||||
|
||||
# playwright 环境
|
||||
RUN playwright install-deps chromium \
|
||||
&& playwright install-deps firefox \
|
||||
&& apt-get autoremove -y \
|
||||
&& apt-get clean \
|
||||
&& rm -rf \
|
||||
|
||||
@@ -231,9 +231,9 @@ chown moviepilot:moviepilot /etc/hosts /tmp
|
||||
|
||||
# 下载浏览器内核
|
||||
if [[ "$HTTPS_PROXY" =~ ^https?:// ]] || [[ "$HTTPS_PROXY" =~ ^https?:// ]] || [[ "$PROXY_HOST" =~ ^https?:// ]]; then
|
||||
HTTPS_PROXY="${HTTPS_PROXY:-${https_proxy:-$PROXY_HOST}}" gosu moviepilot:moviepilot playwright install chromium
|
||||
HTTPS_PROXY="${HTTPS_PROXY:-${https_proxy:-$PROXY_HOST}}" gosu moviepilot:moviepilot playwright install ${PLAYWRIGHT_BROWSER_TYPE:-chromium}
|
||||
else
|
||||
gosu moviepilot:moviepilot playwright install chromium
|
||||
gosu moviepilot:moviepilot playwright install ${PLAYWRIGHT_BROWSER_TYPE:-chromium}
|
||||
fi
|
||||
|
||||
# 证书管理
|
||||
|
||||
213
docs/mcp-api.md
213
docs/mcp-api.md
@@ -1,9 +1,77 @@
|
||||
# MoviePilot 工具API文档
|
||||
# MoviePilot MCP (Model Context Protocol) API 文档
|
||||
|
||||
MoviePilot的智能体工具已通过HTTP API暴露,可以通过RESTful API调用所有工具。
|
||||
MoviePilot 实现了标准的 **Model Context Protocol (MCP)**,允许 AI 智能体(如 Claude, GPT 等)直接调用 MoviePilot 的功能进行媒体管理、搜索、订阅和下载。
|
||||
|
||||
## API端点
|
||||
## 1. 基础信息
|
||||
|
||||
* **基础路径**: `/api/v1/mcp`
|
||||
* **协议版本**: `2025-11-25, 2025-06-18, 2024-11-05`
|
||||
* **传输协议**: HTTP (JSON-RPC 2.0)
|
||||
* **认证方式**:
|
||||
* Header: `X-API-KEY: <你的API_KEY>`
|
||||
* Query: `?apikey=<你的API_KEY>`
|
||||
|
||||
## 2. 标准 MCP 协议 (JSON-RPC 2.0)
|
||||
|
||||
### 端点
|
||||
**POST** `/api/v1/mcp`
|
||||
|
||||
### 支持的方法
|
||||
- `initialize`: 初始化会话,协商协议版本和能力。
|
||||
- `notifications/initialized`: 客户端确认初始化完成。
|
||||
- `tools/list`: 获取可用工具列表。
|
||||
- `tools/call`: 调用特定工具。
|
||||
- `ping`: 连接存活检测。
|
||||
|
||||
---
|
||||
|
||||
## 4. 客户端配置示例
|
||||
|
||||
### Claude Desktop (Anthropic)
|
||||
|
||||
在Claude Desktop的配置文件中添加MoviePilot的MCP服务器配置:
|
||||
|
||||
**macOS**: `~/Library/Application Support/Claude/claude_desktop_config.json`
|
||||
**Windows**: `%APPDATA%\Claude\claude_desktop_config.json`
|
||||
|
||||
使用请求头方式:
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"moviepilot": {
|
||||
"url": "http://localhost:3001/api/v1/mcp",
|
||||
"headers": {
|
||||
"X-API-KEY": "your_api_key_here"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
或使用查询参数方式:
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"moviepilot": {
|
||||
"url": "http://localhost:3001/api/v1/mcp?apikey=your_api_key_here"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 5. 错误码说明
|
||||
|
||||
| 错误码 | 消息 | 说明 |
|
||||
| :--- | :--- | :--- |
|
||||
| -32700 | Parse error | JSON 格式错误 |
|
||||
| -32600 | Invalid Request | 无效的 JSON-RPC 请求 |
|
||||
| -32601 | Method not found | 方法不存在 |
|
||||
| -32602 | Invalid params | 参数验证失败 |
|
||||
| -32002 | Session not found | 会话不存在或已过期 |
|
||||
| -32003 | Not initialized | 会话未完成初始化流程 |
|
||||
| -32603 | Internal error | 服务器内部错误 |
|
||||
|
||||
## 6. RESTful API
|
||||
所有工具相关的API端点都在 `/api/v1/mcp` 路径下(保持向后兼容)。
|
||||
|
||||
### 1. 列出所有工具
|
||||
@@ -137,142 +205,3 @@ MoviePilot的智能体工具已通过HTTP API暴露,可以通过RESTful API调
|
||||
"required": ["title", "year", "media_type"]
|
||||
}
|
||||
```
|
||||
|
||||
## MCP客户端配置
|
||||
|
||||
MoviePilot的MCP工具可以通过HTTP协议在支持MCP的客户端中使用。以下是常见MCP客户端的配置方法:
|
||||
|
||||
### Claude Desktop (Anthropic)
|
||||
|
||||
在Claude Desktop的配置文件中添加MoviePilot的MCP服务器配置:
|
||||
|
||||
**macOS**: `~/Library/Application Support/Claude/claude_desktop_config.json`
|
||||
**Windows**: `%APPDATA%\Claude\claude_desktop_config.json`
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"moviepilot": {
|
||||
"command": "npx",
|
||||
"args": [
|
||||
"-y",
|
||||
"@modelcontextprotocol/server-http",
|
||||
"http://localhost:3001/api/v1/mcp"
|
||||
],
|
||||
"env": {
|
||||
"X-API-KEY": "your_api_key_here"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**注意**: 如果MCP HTTP服务器不支持环境变量传递API Key,可以使用查询参数方式:
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"moviepilot": {
|
||||
"command": "npx",
|
||||
"args": [
|
||||
"-y",
|
||||
"@modelcontextprotocol/server-http",
|
||||
"http://localhost:3001/api/v1/mcp?apikey=your_api_key_here"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 其他支持MCP的聊天客户端
|
||||
|
||||
对于其他支持MCP协议的聊天客户端(如其他AI聊天助手、对话机器人等),通常可以通过配置文件或设置界面添加HTTP协议的MCP服务器。配置格式可能因客户端而异,但通常需要以下信息:
|
||||
|
||||
**配置参数**:
|
||||
1. **服务器类型**: HTTP
|
||||
2. **服务器地址**: `http://your-moviepilot-host:3001/api/v1/mcp`
|
||||
3. **认证方式**:
|
||||
- 在HTTP请求头中添加 `X-API-KEY: <your_api_key>`
|
||||
- 或在URL查询参数中添加 `apikey=<your_api_key>`
|
||||
|
||||
**示例配置**(通用格式):
|
||||
|
||||
使用请求头方式:
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"moviepilot": {
|
||||
"url": "http://localhost:3001/api/v1/mcp",
|
||||
"headers": {
|
||||
"X-API-KEY": "your_api_key_here"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
或使用查询参数方式:
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"moviepilot": {
|
||||
"url": "http://localhost:3001/api/v1/mcp?apikey=your_api_key_here"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**支持的端点**:
|
||||
- `GET /tools` - 列出所有工具
|
||||
- `POST /tools/call` - 调用工具
|
||||
- `GET /tools/{tool_name}` - 获取工具详情
|
||||
- `GET /tools/{tool_name}/schema` - 获取工具参数Schema
|
||||
|
||||
配置完成后,您就可以在聊天对话中使用MoviePilot的各种工具,例如:
|
||||
- 添加媒体订阅
|
||||
- 查询下载历史
|
||||
- 搜索媒体资源
|
||||
- 管理媒体服务器
|
||||
- 等等...
|
||||
|
||||
### 获取API Key
|
||||
|
||||
API Key可以在MoviePilot的系统设置中生成和查看。请妥善保管您的API Key,不要泄露给他人。
|
||||
|
||||
## 认证
|
||||
|
||||
所有MCP API端点都需要认证。**仅支持API Key认证方式**:
|
||||
|
||||
- **请求头方式**: 在请求头中添加 `X-API-KEY: <api_key>`
|
||||
- **查询参数方式**: 在URL查询参数中添加 `apikey=<api_key>`
|
||||
|
||||
**获取API Key**: 在MoviePilot系统设置中生成和查看API Key。请妥善保管您的API Key,不要泄露给他人。
|
||||
|
||||
## 错误处理
|
||||
|
||||
API会返回标准的HTTP状态码:
|
||||
|
||||
- `200 OK`: 请求成功
|
||||
- `400 Bad Request`: 请求参数错误
|
||||
- `401 Unauthorized`: 未认证或API Key无效
|
||||
- `404 Not Found`: 工具不存在
|
||||
- `500 Internal Server Error`: 服务器内部错误
|
||||
|
||||
错误响应格式:
|
||||
```json
|
||||
{
|
||||
"detail": "错误描述信息"
|
||||
}
|
||||
```
|
||||
|
||||
## 架构说明
|
||||
|
||||
工具API通过FastAPI端点暴露,使用HTTP协议与客户端通信。所有工具共享相同的实现,确保功能一致性。
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. **用户上下文**: API调用会使用当前认证用户的ID作为工具执行的用户上下文
|
||||
2. **会话隔离**: 每个API请求使用独立的会话ID
|
||||
3. **参数验证**: 工具参数会根据JSON Schema进行验证
|
||||
4. **错误日志**: 所有工具调用错误都会记录到MoviePilot日志系统
|
||||
|
||||
|
||||
@@ -43,6 +43,7 @@ cf_clearance~=0.31.0
|
||||
torrentool~=1.2.0
|
||||
slack-bolt~=1.23.0
|
||||
slack-sdk~=3.35.0
|
||||
discord.py==2.6.4
|
||||
chardet~=5.2.0
|
||||
starlette~=0.46.2
|
||||
PyVirtualDisplay~=3.0
|
||||
@@ -61,6 +62,7 @@ cachetools~=6.1.0
|
||||
fast-bencode~=1.1.7
|
||||
pystray~=0.19.5
|
||||
pyotp~=2.9.0
|
||||
webauthn~=2.7.0
|
||||
Pinyin2Hanzi~=0.1.1
|
||||
pywebpush~=2.0.3
|
||||
aiopathlib~=0.6.0
|
||||
@@ -88,4 +90,4 @@ langchain-google-genai~=2.0.10
|
||||
langchain-deepseek~=0.1.4
|
||||
langchain-experimental~=0.3.4
|
||||
openai~=1.108.2
|
||||
google-generativeai~=0.8.5
|
||||
google-generativeai~=0.8.5
|
||||
|
||||
@@ -1117,4 +1117,19 @@ meta_cases = [{
|
||||
"audio_codec": "",
|
||||
"tmdbid": 19995
|
||||
}
|
||||
}, {
|
||||
"path": "/movies/DouBan_IMDB.TOP250.Movies.Mixed.Collection.20240501.FRDS/为奴十二年.12.Years.a.Slave.2013.BluRay.1080p.x265.10bit.2Audio.MNHD-FRDS/12.Years.a.Slave.2013.BluRay.1080p.x265.10bit.2Audio.MNHD-FRDS.mkv",
|
||||
"target": {
|
||||
"type": "未知",
|
||||
"cn_name": "",
|
||||
"en_name": "12 Years A Slave",
|
||||
"year": "2013",
|
||||
"part": "",
|
||||
"season": "",
|
||||
"episode": "",
|
||||
"restype": "BluRay",
|
||||
"pix": "1080p",
|
||||
"video_codec": "x265 10bit",
|
||||
"audio_codec": "2Audio"
|
||||
}
|
||||
}]
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
APP_VERSION = 'v2.8.8'
|
||||
FRONTEND_VERSION = 'v2.8.8'
|
||||
APP_VERSION = 'v2.9.2'
|
||||
FRONTEND_VERSION = 'v2.9.2'
|
||||
|
||||
Reference in New Issue
Block a user