mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-14 07:26:50 +00:00
Compare commits
139 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
274fc2d74f | ||
|
|
2f1a448afe | ||
|
|
99cab7c337 | ||
|
|
81f7548579 | ||
|
|
6ebd50bebc | ||
|
|
378ba51f4d | ||
|
|
63a890e85d | ||
|
|
bf4f9921e2 | ||
|
|
167ae65695 | ||
|
|
2affa7c9b8 | ||
|
|
785540e178 | ||
|
|
bcad4c0bc6 | ||
|
|
5af217fbf5 | ||
|
|
128aa2ef23 | ||
|
|
fce1186dd1 | ||
|
|
9a7b11f804 | ||
|
|
b068a06fa8 | ||
|
|
931a42e981 | ||
|
|
e0a20a6697 | ||
|
|
1ef4374899 | ||
|
|
3b7212740b | ||
|
|
4b80b8dc1f | ||
|
|
b7f24827e6 | ||
|
|
1c08a22881 | ||
|
|
8bd848519d | ||
|
|
e19f2aa76d | ||
|
|
4a99e2896f | ||
|
|
de3c83b0aa | ||
|
|
36bdb831be | ||
|
|
1809690915 | ||
|
|
e51b679380 | ||
|
|
10c26de7cb | ||
|
|
ca5ec8af0f | ||
|
|
d1d7b8ce55 | ||
|
|
77f8983307 | ||
|
|
ba415acd37 | ||
|
|
bcf13099ac | ||
|
|
eb2b34d71c | ||
|
|
d0b665f773 | ||
|
|
a1674b1ae5 | ||
|
|
af83681f6a | ||
|
|
bebacf7b20 | ||
|
|
6dc1fcbc3e | ||
|
|
b599ef4509 | ||
|
|
526b6a1119 | ||
|
|
88173db4ce | ||
|
|
e139b1ab22 | ||
|
|
6c1e0058c1 | ||
|
|
c96633eb83 | ||
|
|
91eb35a77b | ||
|
|
d749d59cad | ||
|
|
80396b4d30 | ||
|
|
64b93a009c | ||
|
|
2b32250504 | ||
|
|
9b5f863832 | ||
|
|
fd422d7446 | ||
|
|
5162b2748e | ||
|
|
56c684ec06 | ||
|
|
7e93b33407 | ||
|
|
7662235802 | ||
|
|
e41f9facc7 | ||
|
|
785b8ede11 | ||
|
|
78b198ad70 | ||
|
|
c2c0515991 | ||
|
|
b97fefdb8d | ||
|
|
840da6dd85 | ||
|
|
972d916126 | ||
|
|
e3ed065f5f | ||
|
|
760ebe6113 | ||
|
|
a329d3ad89 | ||
|
|
01f8561582 | ||
|
|
883ea5c996 | ||
|
|
99cf13ed9b | ||
|
|
91c7ef6801 | ||
|
|
84ef5705e7 | ||
|
|
cf2a0cf8c2 | ||
|
|
48c25c40e4 | ||
|
|
996d8ab954 | ||
|
|
fac2546a92 | ||
|
|
728ea6172a | ||
|
|
f59d225029 | ||
|
|
0b178a715f | ||
|
|
e06e5328c2 | ||
|
|
1c14cd0979 | ||
|
|
f9141f5ba2 | ||
|
|
48da5c976c | ||
|
|
fa38c81c08 | ||
|
|
8d5fe5270f | ||
|
|
0dc0d66549 | ||
|
|
f589fcc2d0 | ||
|
|
edd44a0993 | ||
|
|
2aae496742 | ||
|
|
6f72046f86 | ||
|
|
d4a9b446a6 | ||
|
|
95f571e9b9 | ||
|
|
e8aeae5c07 | ||
|
|
ddf6dc0343 | ||
|
|
36d55a9db7 | ||
|
|
7d41379ad5 | ||
|
|
63e928da96 | ||
|
|
5c983b64bc | ||
|
|
b2d36c0e68 | ||
|
|
6123a1620e | ||
|
|
5ae7c10a00 | ||
|
|
b5a6794381 | ||
|
|
6b575f836a | ||
|
|
c83589cac6 | ||
|
|
d64492bda5 | ||
|
|
33d6c75924 | ||
|
|
89f01bad42 | ||
|
|
767496f81b | ||
|
|
147a477365 | ||
|
|
13171f636f | ||
|
|
fea3f0d3e0 | ||
|
|
a3a254c2ea | ||
|
|
bd9d5f7fc0 | ||
|
|
726738ee9e | ||
|
|
725244bb2f | ||
|
|
d2ac2b8990 | ||
|
|
116569223c | ||
|
|
05442a019f | ||
|
|
db67080bf8 | ||
|
|
21fabf7436 | ||
|
|
a8c6516b31 | ||
|
|
f5ca48a56e | ||
|
|
65ceff9824 | ||
|
|
ed73cfdcc7 | ||
|
|
9cb79a7827 | ||
|
|
984f29005a | ||
|
|
805c3719af | ||
|
|
ea646149c0 | ||
|
|
eae1f8ee4d | ||
|
|
8d1de245a6 | ||
|
|
b8ef5d1efc | ||
|
|
e1098b34e8 | ||
|
|
8296f8d2da | ||
|
|
867c83383d | ||
|
|
1354119d6d | ||
|
|
53af7f81bb |
@@ -30,6 +30,8 @@
|
||||
|
||||
API文档:https://api.movie-pilot.org
|
||||
|
||||
MCP工具API文档:详见 [docs/mcp-api.md](docs/mcp-api.md)
|
||||
|
||||
本地运行需要 `Python 3.12`、`Node JS v20.12.1`
|
||||
|
||||
- 克隆主项目 [MoviePilot](https://github.com/jxxghp/MoviePilot)
|
||||
|
||||
@@ -13,7 +13,7 @@ from langchain_core.runnables.history import RunnableWithMessageHistory
|
||||
from app.agent.callback import StreamingCallbackHandler
|
||||
from app.agent.memory import ConversationMemoryManager
|
||||
from app.agent.prompt import PromptManager
|
||||
from app.agent.tools import MoviePilotToolFactory
|
||||
from app.agent.tools.factory import MoviePilotToolFactory
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import settings
|
||||
from app.helper.message import MessageHelper
|
||||
@@ -69,19 +69,31 @@ class MoviePilotAgent:
|
||||
"""初始化LLM模型"""
|
||||
provider = settings.LLM_PROVIDER.lower()
|
||||
api_key = settings.LLM_API_KEY
|
||||
if not api_key:
|
||||
raise ValueError("未配置 LLM_API_KEY")
|
||||
|
||||
if provider == "google":
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
return ChatGoogleGenerativeAI(
|
||||
model=settings.LLM_MODEL,
|
||||
google_api_key=api_key,
|
||||
max_retries=3,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=True,
|
||||
callbacks=[self.callback_handler]
|
||||
)
|
||||
if settings.PROXY_HOST:
|
||||
from langchain_openai import ChatOpenAI
|
||||
return ChatOpenAI(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
max_retries=3,
|
||||
base_url="https://generativelanguage.googleapis.com/v1beta/openai",
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=True,
|
||||
callbacks=[self.callback_handler],
|
||||
stream_usage=True,
|
||||
openai_proxy=settings.PROXY_HOST
|
||||
)
|
||||
else:
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
return ChatGoogleGenerativeAI(
|
||||
model=settings.LLM_MODEL,
|
||||
google_api_key=api_key,
|
||||
max_retries=3,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=True,
|
||||
callbacks=[self.callback_handler]
|
||||
)
|
||||
elif provider == "deepseek":
|
||||
from langchain_deepseek import ChatDeepSeek
|
||||
return ChatDeepSeek(
|
||||
@@ -103,7 +115,8 @@ class MoviePilotAgent:
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=True,
|
||||
callbacks=[self.callback_handler],
|
||||
stream_usage=True
|
||||
stream_usage=True,
|
||||
openai_proxy=settings.PROXY_HOST
|
||||
)
|
||||
|
||||
def _initialize_tools(self) -> List:
|
||||
@@ -221,15 +234,20 @@ class MoviePilotAgent:
|
||||
agent_message = await self.callback_handler.get_message()
|
||||
|
||||
# 发送Agent回复给用户(通过原渠道)
|
||||
await self.send_agent_message(agent_message)
|
||||
if agent_message:
|
||||
# 发送回复
|
||||
await self.send_agent_message(agent_message)
|
||||
|
||||
# 添加Agent回复到记忆
|
||||
await self.memory_manager.add_memory(
|
||||
session_id=self.session_id,
|
||||
user_id=self.user_id,
|
||||
role="agent",
|
||||
content=agent_message
|
||||
)
|
||||
# 添加Agent回复到记忆
|
||||
await self.memory_manager.add_memory(
|
||||
session_id=self.session_id,
|
||||
user_id=self.user_id,
|
||||
role="agent",
|
||||
content=agent_message
|
||||
)
|
||||
else:
|
||||
agent_message = "很抱歉,智能体出错了,未能生成回复内容。"
|
||||
await self.send_agent_message(agent_message)
|
||||
|
||||
return agent_message
|
||||
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
"""MoviePilot工具模块"""
|
||||
|
||||
from .base import MoviePilotTool
|
||||
from app.agent.tools.impl.search_media import SearchMediaTool
|
||||
from app.agent.tools.impl.add_subscribe import AddSubscribeTool
|
||||
from app.agent.tools.impl.search_torrents import SearchTorrentsTool
|
||||
from app.agent.tools.impl.add_download import AddDownloadTool
|
||||
from app.agent.tools.impl.query_subscribes import QuerySubscribesTool
|
||||
from app.agent.tools.impl.query_downloads import QueryDownloadsTool
|
||||
from app.agent.tools.impl.query_downloaders import QueryDownloadersTool
|
||||
from app.agent.tools.impl.query_sites import QuerySitesTool
|
||||
from app.agent.tools.impl.get_recommendations import GetRecommendationsTool
|
||||
from app.agent.tools.impl.query_media_library import QueryMediaLibraryTool
|
||||
from app.agent.tools.impl.send_message import SendMessageTool
|
||||
from .factory import MoviePilotToolFactory
|
||||
|
||||
__all__ = [
|
||||
"MoviePilotTool",
|
||||
"SearchMediaTool",
|
||||
"AddSubscribeTool",
|
||||
"SearchTorrentsTool",
|
||||
"AddDownloadTool",
|
||||
"QuerySubscribesTool",
|
||||
"QueryDownloadsTool",
|
||||
"QueryDownloadersTool",
|
||||
"QuerySitesTool",
|
||||
"GetRecommendationsTool",
|
||||
"QueryMediaLibraryTool",
|
||||
"SendMessageTool",
|
||||
"MoviePilotToolFactory"
|
||||
]
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
"""MoviePilot工具基类"""
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Callable, Any
|
||||
from typing import Callable, Any, Optional
|
||||
|
||||
from langchain.tools import BaseTool
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from app.agent import StreamingCallbackHandler
|
||||
from app.chain import ChainBase
|
||||
from app.log import logger
|
||||
from app.schemas import Notification
|
||||
|
||||
|
||||
@@ -39,10 +40,35 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
if agent_message:
|
||||
await self.send_tool_message(agent_message, title="MoviePilot助手")
|
||||
# 发送执行工具说明
|
||||
explanation = kwargs.get("explanation")
|
||||
if explanation:
|
||||
await self.send_tool_message(f"▶️️{explanation}")
|
||||
return await self.run(**kwargs)
|
||||
# 优先使用工具自定义的提示消息,如果没有则使用 explanation
|
||||
tool_message = self.get_tool_message(**kwargs)
|
||||
if not tool_message:
|
||||
explanation = kwargs.get("explanation")
|
||||
if explanation:
|
||||
tool_message = explanation
|
||||
|
||||
if tool_message:
|
||||
formatted_message = f"⚙️ => {tool_message}"
|
||||
await self.send_tool_message(formatted_message)
|
||||
logger.debug(f'Executing tool {self.name} with args: {kwargs}')
|
||||
result = await self.run(**kwargs)
|
||||
logger.debug(f'Tool {self.name} executed with result: {result}')
|
||||
return result
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""
|
||||
获取工具执行时的友好提示消息
|
||||
|
||||
子类可以重写此方法,根据实际参数生成个性化的提示消息。
|
||||
如果返回 None 或空字符串,将回退使用 explanation 参数。
|
||||
|
||||
Args:
|
||||
**kwargs: 工具的所有参数(包括 explanation)
|
||||
|
||||
Returns:
|
||||
str: 友好的提示消息,如果返回 None 或空字符串则使用 explanation
|
||||
"""
|
||||
return None
|
||||
|
||||
@abstractmethod
|
||||
async def run(self, **kwargs) -> str:
|
||||
@@ -68,6 +94,5 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
username=self._username,
|
||||
title=title,
|
||||
text=message
|
||||
),
|
||||
escape_markdown=False
|
||||
)
|
||||
)
|
||||
|
||||
@@ -4,15 +4,42 @@ from typing import List, Callable
|
||||
|
||||
from app.agent.tools.impl.add_download import AddDownloadTool
|
||||
from app.agent.tools.impl.add_subscribe import AddSubscribeTool
|
||||
from app.agent.tools.impl.update_subscribe import UpdateSubscribeTool
|
||||
from app.agent.tools.impl.search_subscribe import SearchSubscribeTool
|
||||
from app.agent.tools.impl.get_recommendations import GetRecommendationsTool
|
||||
from app.agent.tools.impl.query_downloaders import QueryDownloadersTool
|
||||
from app.agent.tools.impl.query_downloads import QueryDownloadsTool
|
||||
from app.agent.tools.impl.query_media_library import QueryMediaLibraryTool
|
||||
from app.agent.tools.impl.query_download_tasks import QueryDownloadTasksTool
|
||||
from app.agent.tools.impl.query_library_exists import QueryLibraryExistsTool
|
||||
from app.agent.tools.impl.query_library_latest import QueryLibraryLatestTool
|
||||
from app.agent.tools.impl.query_sites import QuerySitesTool
|
||||
from app.agent.tools.impl.update_site import UpdateSiteTool
|
||||
from app.agent.tools.impl.query_site_userdata import QuerySiteUserdataTool
|
||||
from app.agent.tools.impl.test_site import TestSiteTool
|
||||
from app.agent.tools.impl.query_subscribes import QuerySubscribesTool
|
||||
from app.agent.tools.impl.query_subscribe_shares import QuerySubscribeSharesTool
|
||||
from app.agent.tools.impl.query_rule_groups import QueryRuleGroupsTool
|
||||
from app.agent.tools.impl.query_popular_subscribes import QueryPopularSubscribesTool
|
||||
from app.agent.tools.impl.query_subscribe_history import QuerySubscribeHistoryTool
|
||||
from app.agent.tools.impl.delete_subscribe import DeleteSubscribeTool
|
||||
from app.agent.tools.impl.search_media import SearchMediaTool
|
||||
from app.agent.tools.impl.search_person import SearchPersonTool
|
||||
from app.agent.tools.impl.search_person_credits import SearchPersonCreditsTool
|
||||
from app.agent.tools.impl.recognize_media import RecognizeMediaTool
|
||||
from app.agent.tools.impl.scrape_metadata import ScrapeMetadataTool
|
||||
from app.agent.tools.impl.query_episode_schedule import QueryEpisodeScheduleTool
|
||||
from app.agent.tools.impl.search_torrents import SearchTorrentsTool
|
||||
from app.agent.tools.impl.search_web import SearchWebTool
|
||||
from app.agent.tools.impl.send_message import SendMessageTool
|
||||
from app.agent.tools.impl.query_schedulers import QuerySchedulersTool
|
||||
from app.agent.tools.impl.run_scheduler import RunSchedulerTool
|
||||
from app.agent.tools.impl.query_workflows import QueryWorkflowsTool
|
||||
from app.agent.tools.impl.run_workflow import RunWorkflowTool
|
||||
from app.agent.tools.impl.update_site_cookie import UpdateSiteCookieTool
|
||||
from app.agent.tools.impl.delete_download import DeleteDownloadTool
|
||||
from app.agent.tools.impl.query_directory_settings import QueryDirectorySettingsTool
|
||||
from app.agent.tools.impl.list_directory import ListDirectoryTool
|
||||
from app.agent.tools.impl.query_transfer_history import QueryTransferHistoryTool
|
||||
from app.agent.tools.impl.transfer_file import TransferFileTool
|
||||
from app.core.plugin import PluginManager
|
||||
from app.log import logger
|
||||
from .base import MoviePilotTool
|
||||
@@ -29,16 +56,43 @@ class MoviePilotToolFactory:
|
||||
tools = []
|
||||
tool_definitions = [
|
||||
SearchMediaTool,
|
||||
SearchPersonTool,
|
||||
SearchPersonCreditsTool,
|
||||
RecognizeMediaTool,
|
||||
ScrapeMetadataTool,
|
||||
QueryEpisodeScheduleTool,
|
||||
AddSubscribeTool,
|
||||
UpdateSubscribeTool,
|
||||
SearchSubscribeTool,
|
||||
SearchTorrentsTool,
|
||||
SearchWebTool,
|
||||
AddDownloadTool,
|
||||
QuerySubscribesTool,
|
||||
QueryDownloadsTool,
|
||||
QuerySubscribeSharesTool,
|
||||
QueryPopularSubscribesTool,
|
||||
QueryRuleGroupsTool,
|
||||
QuerySubscribeHistoryTool,
|
||||
DeleteSubscribeTool,
|
||||
QueryDownloadTasksTool,
|
||||
DeleteDownloadTool,
|
||||
QueryDownloadersTool,
|
||||
QuerySitesTool,
|
||||
UpdateSiteTool,
|
||||
QuerySiteUserdataTool,
|
||||
TestSiteTool,
|
||||
UpdateSiteCookieTool,
|
||||
GetRecommendationsTool,
|
||||
QueryMediaLibraryTool,
|
||||
SendMessageTool
|
||||
QueryLibraryExistsTool,
|
||||
QueryLibraryLatestTool,
|
||||
QueryDirectorySettingsTool,
|
||||
ListDirectoryTool,
|
||||
QueryTransferHistoryTool,
|
||||
TransferFileTool,
|
||||
SendMessageTool,
|
||||
QuerySchedulersTool,
|
||||
RunSchedulerTool,
|
||||
QueryWorkflowsTool,
|
||||
RunWorkflowTool
|
||||
]
|
||||
# 创建内置工具
|
||||
for ToolClass in tool_definitions:
|
||||
|
||||
@@ -25,7 +25,7 @@ class AddDownloadInput(BaseModel):
|
||||
downloader: Optional[str] = Field(None,
|
||||
description="Name of the downloader to use (optional, uses default if not specified)")
|
||||
save_path: Optional[str] = Field(None,
|
||||
description="Directory path where the downloaded files should be saved (optional, uses default path if not specified)")
|
||||
description="Directory path where the downloaded files should be saved. Using `<storage>:<path>` for remote storage. e.g. rclone:/MP, smb:/server/share/Movies. (optional, uses default path if not specified)")
|
||||
labels: Optional[str] = Field(None,
|
||||
description="Comma-separated list of labels/tags to assign to the download (optional, e.g., 'movie,hd,bluray')")
|
||||
|
||||
@@ -35,6 +35,20 @@ class AddDownloadTool(MoviePilotTool):
|
||||
description: str = "Add torrent download task to the configured downloader (qBittorrent, Transmission, etc.). Downloads the torrent file and starts the download process with specified settings."
|
||||
args_schema: Type[BaseModel] = AddDownloadInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据下载参数生成友好的提示消息"""
|
||||
torrent_title = kwargs.get("torrent_title", "")
|
||||
site_name = kwargs.get("site_name", "")
|
||||
downloader = kwargs.get("downloader")
|
||||
|
||||
message = f"正在添加下载任务: {torrent_title}"
|
||||
if site_name:
|
||||
message += f" (来源: {site_name})"
|
||||
if downloader:
|
||||
message += f" [下载器: {downloader}]"
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, site_name: str, torrent_title: str, torrent_url: str, torrent_description: Optional[str] = None,
|
||||
downloader: Optional[str] = None, save_path: Optional[str] = None,
|
||||
labels: Optional[str] = None, **kwargs) -> str:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""添加订阅工具"""
|
||||
|
||||
from typing import Optional, Type
|
||||
from typing import Optional, Type, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -21,17 +21,55 @@ class AddSubscribeInput(BaseModel):
|
||||
description="Season number for TV shows (optional, if not specified will subscribe to all seasons)")
|
||||
tmdb_id: Optional[str] = Field(None,
|
||||
description="TMDB database ID for precise media identification (optional but recommended for accuracy)")
|
||||
start_episode: Optional[int] = Field(None,
|
||||
description="Starting episode number for TV shows (optional, defaults to 1 if not specified)")
|
||||
total_episode: Optional[int] = Field(None,
|
||||
description="Total number of episodes for TV shows (optional, will be auto-detected from TMDB if not specified)")
|
||||
quality: Optional[str] = Field(None,
|
||||
description="Quality filter as regular expression (optional, e.g., 'BluRay|WEB-DL|HDTV')")
|
||||
resolution: Optional[str] = Field(None,
|
||||
description="Resolution filter as regular expression (optional, e.g., '1080p|720p|2160p')")
|
||||
effect: Optional[str] = Field(None,
|
||||
description="Effect filter as regular expression (optional, e.g., 'HDR|DV|SDR')")
|
||||
filter_groups: Optional[List[str]] = Field(None,
|
||||
description="List of filter rule group names to apply (optional, use query_rule_groups tool to get available rule groups)")
|
||||
sites: Optional[List[int]] = Field(None,
|
||||
description="List of site IDs to search from (optional, use query_sites tool to get available site IDs)")
|
||||
|
||||
|
||||
class AddSubscribeTool(MoviePilotTool):
|
||||
name: str = "add_subscribe"
|
||||
description: str = "Add media subscription to create automated download rules for movies and TV shows. The system will automatically search and download new episodes or releases based on the subscription criteria."
|
||||
description: str = "Add media subscription to create automated download rules for movies and TV shows. The system will automatically search and download new episodes or releases based on the subscription criteria. Supports advanced filtering options like quality, resolution, and effect filters using regular expressions."
|
||||
args_schema: Type[BaseModel] = AddSubscribeInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据订阅参数生成友好的提示消息"""
|
||||
title = kwargs.get("title", "")
|
||||
year = kwargs.get("year", "")
|
||||
media_type = kwargs.get("media_type", "")
|
||||
season = kwargs.get("season")
|
||||
|
||||
message = f"正在添加订阅: {title}"
|
||||
if year:
|
||||
message += f" ({year})"
|
||||
if media_type:
|
||||
message += f" [{media_type}]"
|
||||
if season:
|
||||
message += f" 第{season}季"
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, title: str, year: str, media_type: str,
|
||||
season: Optional[int] = None, tmdb_id: Optional[str] = None, **kwargs) -> str:
|
||||
season: Optional[int] = None, tmdb_id: Optional[str] = None,
|
||||
start_episode: Optional[int] = None, total_episode: Optional[int] = None,
|
||||
quality: Optional[str] = None, resolution: Optional[str] = None,
|
||||
effect: Optional[str] = None, filter_groups: Optional[List[str]] = None,
|
||||
sites: Optional[List[int]] = None, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}, tmdb_id={tmdb_id}")
|
||||
f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, "
|
||||
f"season={season}, tmdb_id={tmdb_id}, start_episode={start_episode}, "
|
||||
f"total_episode={total_episode}, quality={quality}, resolution={resolution}, "
|
||||
f"effect={effect}, filter_groups={filter_groups}, sites={sites}")
|
||||
|
||||
try:
|
||||
subscribe_chain = SubscribeChain()
|
||||
@@ -43,16 +81,53 @@ class AddSubscribeTool(MoviePilotTool):
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(f"无效的 tmdb_id: {tmdb_id},将忽略")
|
||||
|
||||
# 构建额外的订阅参数
|
||||
subscribe_kwargs = {}
|
||||
if start_episode is not None:
|
||||
subscribe_kwargs['start_episode'] = start_episode
|
||||
if total_episode is not None:
|
||||
subscribe_kwargs['total_episode'] = total_episode
|
||||
if quality:
|
||||
subscribe_kwargs['quality'] = quality
|
||||
if resolution:
|
||||
subscribe_kwargs['resolution'] = resolution
|
||||
if effect:
|
||||
subscribe_kwargs['effect'] = effect
|
||||
if filter_groups:
|
||||
subscribe_kwargs['filter_groups'] = filter_groups
|
||||
if sites:
|
||||
subscribe_kwargs['sites'] = sites
|
||||
|
||||
sid, message = await subscribe_chain.async_add(
|
||||
mtype=MediaType(media_type),
|
||||
title=title,
|
||||
year=year,
|
||||
tmdbid=tmdbid_int,
|
||||
season=season,
|
||||
username=self._user_id
|
||||
username=self._user_id,
|
||||
**subscribe_kwargs
|
||||
)
|
||||
if sid:
|
||||
return f"成功添加订阅:{title} ({year})"
|
||||
result_msg = f"成功添加订阅:{title} ({year})"
|
||||
if subscribe_kwargs:
|
||||
params = []
|
||||
if start_episode is not None:
|
||||
params.append(f"开始集数: {start_episode}")
|
||||
if total_episode is not None:
|
||||
params.append(f"总集数: {total_episode}")
|
||||
if quality:
|
||||
params.append(f"质量过滤: {quality}")
|
||||
if resolution:
|
||||
params.append(f"分辨率过滤: {resolution}")
|
||||
if effect:
|
||||
params.append(f"特效过滤: {effect}")
|
||||
if filter_groups:
|
||||
params.append(f"规则组: {', '.join(filter_groups)}")
|
||||
if sites:
|
||||
params.append(f"站点: {', '.join(map(str, sites))}")
|
||||
if params:
|
||||
result_msg += f"\n配置参数: {', '.join(params)}"
|
||||
return result_msg
|
||||
else:
|
||||
return f"添加订阅失败:{message}"
|
||||
except Exception as e:
|
||||
|
||||
76
app/agent/tools/impl/delete_download.py
Normal file
76
app/agent/tools/impl/delete_download.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""删除下载任务工具"""
|
||||
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.download import DownloadChain
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class DeleteDownloadInput(BaseModel):
|
||||
"""删除下载任务工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
task_identifier: str = Field(..., description="Task identifier: can be task hash (unique identifier) or task title/name")
|
||||
downloader: Optional[str] = Field(None, description="Name of specific downloader (optional, if not provided will search all downloaders)")
|
||||
delete_files: Optional[bool] = Field(False, description="Whether to delete downloaded files along with the task (default: False, only removes the task from downloader)")
|
||||
|
||||
|
||||
class DeleteDownloadTool(MoviePilotTool):
|
||||
name: str = "delete_download"
|
||||
description: str = "Delete a download task from the downloader. Can delete by task hash (unique identifier) or task title/name. Optionally specify the downloader name and whether to delete downloaded files."
|
||||
args_schema: Type[BaseModel] = DeleteDownloadInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据删除参数生成友好的提示消息"""
|
||||
task_identifier = kwargs.get("task_identifier", "")
|
||||
downloader = kwargs.get("downloader")
|
||||
delete_files = kwargs.get("delete_files", False)
|
||||
|
||||
message = f"正在删除下载任务: {task_identifier}"
|
||||
if downloader:
|
||||
message += f" [下载器: {downloader}]"
|
||||
if delete_files:
|
||||
message += " (包含文件)"
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, task_identifier: str, downloader: Optional[str] = None,
|
||||
delete_files: Optional[bool] = False, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: task_identifier={task_identifier}, downloader={downloader}, delete_files={delete_files}")
|
||||
|
||||
try:
|
||||
download_chain = DownloadChain()
|
||||
|
||||
# 如果task_identifier看起来像hash(通常是40个字符的十六进制字符串)
|
||||
task_hash = None
|
||||
if len(task_identifier) == 40 and all(c in '0123456789abcdefABCDEF' for c in task_identifier):
|
||||
# 直接使用hash
|
||||
task_hash = task_identifier
|
||||
else:
|
||||
# 通过标题查找任务
|
||||
downloads = download_chain.downloading(name=downloader)
|
||||
for dl in downloads:
|
||||
# 检查标题或名称是否匹配
|
||||
if (task_identifier.lower() in (dl.title or "").lower()) or \
|
||||
(task_identifier.lower() in (dl.name or "").lower()):
|
||||
task_hash = dl.hash
|
||||
break
|
||||
|
||||
if not task_hash:
|
||||
return f"未找到匹配的下载任务:{task_identifier},请使用 query_downloads 工具查询可用的下载任务"
|
||||
|
||||
# 删除下载任务
|
||||
# remove_torrents 支持 delete_file 参数,可以控制是否删除文件
|
||||
result = download_chain.remove_torrents(hashs=[task_hash], downloader=downloader, delete_file=delete_files)
|
||||
|
||||
if result:
|
||||
files_info = "(包含文件)" if delete_files else "(不包含文件)"
|
||||
return f"成功删除下载任务:{task_identifier} {files_info}"
|
||||
else:
|
||||
return f"删除下载任务失败:{task_identifier},请检查任务是否存在或下载器是否可用"
|
||||
except Exception as e:
|
||||
logger.error(f"删除下载任务失败: {e}", exc_info=True)
|
||||
return f"删除下载任务时发生错误: {str(e)}"
|
||||
|
||||
63
app/agent/tools/impl/delete_subscribe.py
Normal file
63
app/agent/tools/impl/delete_subscribe.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""删除订阅工具"""
|
||||
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.core.event import eventmanager
|
||||
from app.db.subscribe_oper import SubscribeOper
|
||||
from app.helper.subscribe import SubscribeHelper
|
||||
from app.log import logger
|
||||
from app.schemas.types import EventType
|
||||
|
||||
|
||||
class DeleteSubscribeInput(BaseModel):
|
||||
"""删除订阅工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
subscribe_id: int = Field(..., description="The ID of the subscription to delete (can be obtained from query_subscribes tool)")
|
||||
|
||||
|
||||
class DeleteSubscribeTool(MoviePilotTool):
|
||||
name: str = "delete_subscribe"
|
||||
description: str = "Delete a media subscription by its ID. This will remove the subscription and stop automatic downloads for that media."
|
||||
args_schema: Type[BaseModel] = DeleteSubscribeInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据删除参数生成友好的提示消息"""
|
||||
subscribe_id = kwargs.get("subscribe_id")
|
||||
return f"正在删除订阅 (ID: {subscribe_id})"
|
||||
|
||||
async def run(self, subscribe_id: int, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: subscribe_id={subscribe_id}")
|
||||
|
||||
try:
|
||||
subscribe_oper = SubscribeOper()
|
||||
# 获取订阅信息
|
||||
subscribe = await subscribe_oper.async_get(subscribe_id)
|
||||
if not subscribe:
|
||||
return f"订阅 ID {subscribe_id} 不存在"
|
||||
|
||||
# 在删除之前获取订阅信息(用于事件)
|
||||
subscribe_info = subscribe.to_dict()
|
||||
|
||||
# 删除订阅
|
||||
subscribe_oper.delete(subscribe_id)
|
||||
|
||||
# 发送事件
|
||||
await eventmanager.async_send_event(EventType.SubscribeDeleted, {
|
||||
"subscribe_id": subscribe_id,
|
||||
"subscribe_info": subscribe_info
|
||||
})
|
||||
|
||||
# 统计订阅
|
||||
SubscribeHelper().sub_done_async({
|
||||
"tmdbid": subscribe.tmdbid,
|
||||
"doubanid": subscribe.doubanid
|
||||
})
|
||||
|
||||
return f"成功删除订阅:{subscribe.name} ({subscribe.year})"
|
||||
except Exception as e:
|
||||
logger.error(f"删除订阅失败: {e}", exc_info=True)
|
||||
return f"删除订阅时发生错误: {str(e)}"
|
||||
|
||||
@@ -14,7 +14,21 @@ class GetRecommendationsInput(BaseModel):
|
||||
"""获取推荐工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
source: Optional[str] = Field("tmdb_trending",
|
||||
description="Recommendation source: 'tmdb_trending' for TMDB trending content, 'douban_hot' for Douban popular content, 'bangumi_calendar' for Bangumi anime calendar")
|
||||
description="Recommendation source: "
|
||||
"'tmdb_trending' for TMDB trending content, "
|
||||
"'tmdb_movies' for TMDB popular movies, "
|
||||
"'tmdb_tvs' for TMDB popular TV shows, "
|
||||
"'douban_hot' for Douban popular content, "
|
||||
"'douban_movie_hot' for Douban hot movies, "
|
||||
"'douban_tv_hot' for Douban hot TV shows, "
|
||||
"'douban_movie_showing' for Douban movies currently showing, "
|
||||
"'douban_movies' for Douban latest movies, "
|
||||
"'douban_tvs' for Douban latest TV shows, "
|
||||
"'douban_movie_top250' for Douban movie TOP250, "
|
||||
"'douban_tv_weekly_chinese' for Douban Chinese TV weekly chart, "
|
||||
"'douban_tv_weekly_global' for Douban global TV weekly chart, "
|
||||
"'douban_tv_animation' for Douban popular animation, "
|
||||
"'bangumi_calendar' for Bangumi anime calendar")
|
||||
media_type: Optional[str] = Field("all",
|
||||
description="Type of media content: '电影' for films, '电视剧' for television series or anime series, 'all' for all types")
|
||||
limit: Optional[int] = Field(20,
|
||||
@@ -26,29 +40,98 @@ class GetRecommendationsTool(MoviePilotTool):
|
||||
description: str = "Get trending and popular media recommendations from various sources. Returns curated lists of popular movies, TV shows, and anime based on different criteria like trending, ratings, or calendar schedules."
|
||||
args_schema: Type[BaseModel] = GetRecommendationsInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据推荐参数生成友好的提示消息"""
|
||||
source = kwargs.get("source", "tmdb_trending")
|
||||
media_type = kwargs.get("media_type", "all")
|
||||
limit = kwargs.get("limit", 20)
|
||||
|
||||
source_map = {
|
||||
"tmdb_trending": "TMDB流行趋势",
|
||||
"tmdb_movies": "TMDB热门电影",
|
||||
"tmdb_tvs": "TMDB热门电视剧",
|
||||
"douban_hot": "豆瓣热门",
|
||||
"douban_movie_hot": "豆瓣热门电影",
|
||||
"douban_tv_hot": "豆瓣热门电视剧",
|
||||
"douban_movie_showing": "豆瓣正在热映",
|
||||
"douban_movies": "豆瓣最新电影",
|
||||
"douban_tvs": "豆瓣最新电视剧",
|
||||
"douban_movie_top250": "豆瓣电影TOP250",
|
||||
"douban_tv_weekly_chinese": "豆瓣国产剧集榜",
|
||||
"douban_tv_weekly_global": "豆瓣全球剧集榜",
|
||||
"douban_tv_animation": "豆瓣热门动漫",
|
||||
"bangumi_calendar": "番组计划"
|
||||
}
|
||||
source_desc = source_map.get(source, source)
|
||||
|
||||
message = f"正在获取推荐: {source_desc}"
|
||||
if media_type != "all":
|
||||
message += f" [{media_type}]"
|
||||
message += f" (限制: {limit}条)"
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, source: Optional[str] = "tmdb_trending",
|
||||
media_type: Optional[str] = "all", limit: Optional[int] = 20, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: source={source}, media_type={media_type}, limit={limit}")
|
||||
try:
|
||||
name_dicts = {
|
||||
"tmdb_trending": "TMDB 热门推荐",
|
||||
"douban_hot": "豆瓣热门推荐",
|
||||
"bangumi_calendar": "番组计划推荐"
|
||||
}
|
||||
recommend_chain = RecommendChain()
|
||||
results = []
|
||||
if source == "tmdb_trending":
|
||||
results = await recommend_chain.async_tmdb_trending(limit=limit)
|
||||
# async_tmdb_trending 只接受 page 参数,返回固定数量的结果
|
||||
# 如果需要限制数量,需要在返回后截取
|
||||
results = await recommend_chain.async_tmdb_trending(page=1)
|
||||
if limit and limit > 0:
|
||||
results = results[:limit]
|
||||
elif source == "tmdb_movies":
|
||||
# async_tmdb_movies 接受 page 参数,返回固定数量的结果
|
||||
results = await recommend_chain.async_tmdb_movies(page=1)
|
||||
if limit and limit > 0:
|
||||
results = results[:limit]
|
||||
elif source == "tmdb_tvs":
|
||||
# async_tmdb_tvs 接受 page 参数,返回固定数量的结果
|
||||
results = await recommend_chain.async_tmdb_tvs(page=1)
|
||||
if limit and limit > 0:
|
||||
results = results[:limit]
|
||||
elif source == "douban_hot":
|
||||
if media_type == "movie":
|
||||
results = await recommend_chain.async_douban_movie_hot(limit=limit)
|
||||
results = await recommend_chain.async_douban_movie_hot(page=1, count=limit)
|
||||
elif media_type == "tv":
|
||||
results = await recommend_chain.async_douban_tv_hot(limit=limit)
|
||||
results = await recommend_chain.async_douban_tv_hot(page=1, count=limit)
|
||||
else: # all
|
||||
results.extend(await recommend_chain.async_douban_movie_hot(limit=limit))
|
||||
results.extend(await recommend_chain.async_douban_tv_hot(limit=limit))
|
||||
results.extend(await recommend_chain.async_douban_movie_hot(page=1, count=limit))
|
||||
results.extend(await recommend_chain.async_douban_tv_hot(page=1, count=limit))
|
||||
elif source == "douban_movie_hot":
|
||||
results = await recommend_chain.async_douban_movie_hot(page=1, count=limit)
|
||||
elif source == "douban_tv_hot":
|
||||
results = await recommend_chain.async_douban_tv_hot(page=1, count=limit)
|
||||
elif source == "douban_movie_showing":
|
||||
results = await recommend_chain.async_douban_movie_showing(page=1, count=limit)
|
||||
elif source == "douban_movies":
|
||||
results = await recommend_chain.async_douban_movies(page=1, count=limit)
|
||||
elif source == "douban_tvs":
|
||||
results = await recommend_chain.async_douban_tvs(page=1, count=limit)
|
||||
elif source == "douban_movie_top250":
|
||||
results = await recommend_chain.async_douban_movie_top250(page=1, count=limit)
|
||||
elif source == "douban_tv_weekly_chinese":
|
||||
results = await recommend_chain.async_douban_tv_weekly_chinese(page=1, count=limit)
|
||||
elif source == "douban_tv_weekly_global":
|
||||
results = await recommend_chain.async_douban_tv_weekly_global(page=1, count=limit)
|
||||
elif source == "douban_tv_animation":
|
||||
results = await recommend_chain.async_douban_tv_animation(page=1, count=limit)
|
||||
elif source == "bangumi_calendar":
|
||||
results = await recommend_chain.async_bangumi_calendar(limit=limit)
|
||||
results = await recommend_chain.async_bangumi_calendar(page=1, count=limit)
|
||||
else:
|
||||
# 不支持的推荐来源
|
||||
supported_sources = [
|
||||
"tmdb_trending", "tmdb_movies", "tmdb_tvs",
|
||||
"douban_hot", "douban_movie_hot", "douban_tv_hot",
|
||||
"douban_movie_showing", "douban_movies", "douban_tvs",
|
||||
"douban_movie_top250", "douban_tv_weekly_chinese",
|
||||
"douban_tv_weekly_global", "douban_tv_animation",
|
||||
"bangumi_calendar"
|
||||
]
|
||||
return f"不支持的推荐来源: {source}。支持的来源包括: {', '.join(supported_sources)}"
|
||||
|
||||
if results:
|
||||
# 限制最多20条结果
|
||||
@@ -57,7 +140,11 @@ class GetRecommendationsTool(MoviePilotTool):
|
||||
# 精简字段,只保留关键信息
|
||||
simplified_results = []
|
||||
for r in limited_results:
|
||||
# r 已经是字典格式(to_dict的结果)
|
||||
# r 应该是字典格式(to_dict的结果),但为了安全起见进行检查
|
||||
if not isinstance(r, dict):
|
||||
logger.warning(f"推荐结果格式异常,跳过: {type(r)}")
|
||||
continue
|
||||
|
||||
simplified = {
|
||||
"title": r.get("title"),
|
||||
"en_title": r.get("en_title"),
|
||||
@@ -67,7 +154,6 @@ class GetRecommendationsTool(MoviePilotTool):
|
||||
"tmdb_id": r.get("tmdb_id"),
|
||||
"imdb_id": r.get("imdb_id"),
|
||||
"douban_id": r.get("douban_id"),
|
||||
"overview": r.get("overview", "")[:200] + "..." if r.get("overview") and len(r.get("overview", "")) > 200 else r.get("overview"),
|
||||
"vote_average": r.get("vote_average"),
|
||||
"poster_path": r.get("poster_path"),
|
||||
"detail_link": r.get("detail_link")
|
||||
|
||||
130
app/agent/tools/impl/list_directory.py
Normal file
130
app/agent/tools/impl/list_directory.py
Normal file
@@ -0,0 +1,130 @@
|
||||
"""查询文件系统目录内容工具"""
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.storage import StorageChain
|
||||
from app.log import logger
|
||||
from app.schemas.file import FileItem
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
|
||||
class ListDirectoryInput(BaseModel):
|
||||
"""查询文件系统目录内容工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
path: str = Field(..., description="Directory path to list contents (e.g., '/home/user/downloads' or 'C:/Downloads')")
|
||||
storage: Optional[str] = Field("local", description="Storage type (default: 'local' for local file system, can be 'smb', 'alist', etc.)")
|
||||
sort_by: Optional[str] = Field("name", description="Sort order: 'name' for alphabetical sorting, 'time' for modification time sorting (default: 'name')")
|
||||
|
||||
|
||||
class ListDirectoryTool(MoviePilotTool):
|
||||
name: str = "list_directory"
|
||||
description: str = "List actual files and folders in a file system directory (NOT configuration). Shows files and subdirectories with their names, types, sizes, and modification times. Returns up to 20 items and the total count if there are more items. Use 'query_directories' to query directory configuration settings."
|
||||
args_schema: Type[BaseModel] = ListDirectoryInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据目录参数生成友好的提示消息"""
|
||||
path = kwargs.get("path", "")
|
||||
storage = kwargs.get("storage", "local")
|
||||
|
||||
message = f"正在查询目录: {path}"
|
||||
if storage != "local":
|
||||
message += f" [存储: {storage}]"
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, path: str, storage: Optional[str] = "local",
|
||||
sort_by: Optional[str] = "name", **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: path={path}, storage={storage}, sort_by={sort_by}")
|
||||
|
||||
try:
|
||||
# 规范化路径
|
||||
if not path:
|
||||
return "错误:路径不能为空"
|
||||
|
||||
# 确保路径格式正确
|
||||
if storage == "local":
|
||||
# 本地路径处理
|
||||
if not path.startswith("/") and not (len(path) > 1 and path[1] == ":"):
|
||||
# 相对路径,尝试转换为绝对路径
|
||||
path = str(Path(path).resolve())
|
||||
else:
|
||||
# 远程存储路径,确保以/开头
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
|
||||
# 创建FileItem
|
||||
fileitem = FileItem(
|
||||
storage=storage or "local",
|
||||
path=path,
|
||||
type="dir"
|
||||
)
|
||||
|
||||
# 查询目录内容
|
||||
storage_chain = StorageChain()
|
||||
file_list = storage_chain.list_files(fileitem, recursion=False)
|
||||
|
||||
if file_list is None:
|
||||
return f"无法访问目录:{path},请检查路径是否正确或存储是否可用"
|
||||
|
||||
if not file_list:
|
||||
return f"目录 {path} 为空"
|
||||
|
||||
# 排序
|
||||
if sort_by == "time":
|
||||
file_list.sort(key=lambda x: x.modify_time or 0, reverse=True)
|
||||
else:
|
||||
# 默认按名称排序(目录优先,然后按名称)
|
||||
file_list.sort(key=lambda x: (
|
||||
0 if x.type == "dir" else 1,
|
||||
StringUtils.natural_sort_key(x.name or "")
|
||||
))
|
||||
|
||||
# 限制返回数量
|
||||
total_count = len(file_list)
|
||||
limited_list = file_list[:20]
|
||||
|
||||
# 转换为字典格式
|
||||
simplified_items = []
|
||||
for item in limited_list:
|
||||
# 格式化文件大小
|
||||
size_str = None
|
||||
if item.size:
|
||||
size_str = StringUtils.str_filesize(item.size)
|
||||
|
||||
# 格式化修改时间
|
||||
modify_time_str = None
|
||||
if item.modify_time:
|
||||
try:
|
||||
modify_time_str = datetime.fromtimestamp(item.modify_time).strftime("%Y-%m-%d %H:%M:%S")
|
||||
except (ValueError, OSError):
|
||||
modify_time_str = str(item.modify_time)
|
||||
|
||||
simplified = {
|
||||
"name": item.name,
|
||||
"type": item.type,
|
||||
"path": item.path,
|
||||
"size": size_str,
|
||||
"modify_time": modify_time_str
|
||||
}
|
||||
# 如果是文件,添加扩展名
|
||||
if item.type == "file" and item.extension:
|
||||
simplified["extension"] = item.extension
|
||||
simplified_items.append(simplified)
|
||||
|
||||
result_json = json.dumps(simplified_items, ensure_ascii=False, indent=2)
|
||||
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 20:
|
||||
return f"注意:目录中共有 {total_count} 个项目,为节省上下文空间,仅显示前 20 个项目。\n\n{result_json}"
|
||||
else:
|
||||
return result_json
|
||||
except Exception as e:
|
||||
logger.error(f"查询目录内容失败: {e}", exc_info=True)
|
||||
return f"查询目录内容时发生错误: {str(e)}"
|
||||
|
||||
134
app/agent/tools/impl/query_directory_settings.py
Normal file
134
app/agent/tools/impl/query_directory_settings.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""查询系统目录设置工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.helper.directory import DirectoryHelper
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QueryDirectorySettingsInput(BaseModel):
|
||||
"""查询系统目录设置工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
directory_type: Optional[str] = Field("all",
|
||||
description="Filter directories by type: 'download' for download directories, 'library' for media library directories, 'all' for all directories")
|
||||
storage_type: Optional[str] = Field("all",
|
||||
description="Filter directories by storage type: 'local' for local storage, 'remote' for remote storage, 'all' for all storage types")
|
||||
name: Optional[str] = Field(None,
|
||||
description="Filter directories by name (partial match, optional)")
|
||||
|
||||
|
||||
class QueryDirectorySettingsTool(MoviePilotTool):
|
||||
name: str = "query_directory_settings"
|
||||
description: str = "Query system directory configuration settings (NOT file listings). Returns configured directory paths, storage types, transfer modes, and other directory-related settings. Use 'list_directory' to list actual files and folders in a directory."
|
||||
args_schema: Type[BaseModel] = QueryDirectorySettingsInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
directory_type = kwargs.get("directory_type", "all")
|
||||
storage_type = kwargs.get("storage_type", "all")
|
||||
name = kwargs.get("name")
|
||||
|
||||
parts = ["正在查询目录配置"]
|
||||
|
||||
if directory_type != "all":
|
||||
type_map = {"download": "下载目录", "library": "媒体库目录"}
|
||||
parts.append(f"类型: {type_map.get(directory_type, directory_type)}")
|
||||
|
||||
if storage_type != "all":
|
||||
storage_map = {"local": "本地存储", "remote": "远程存储"}
|
||||
parts.append(f"存储: {storage_map.get(storage_type, storage_type)}")
|
||||
|
||||
if name:
|
||||
parts.append(f"名称: {name}")
|
||||
|
||||
return " | ".join(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
async def run(self, directory_type: Optional[str] = "all",
|
||||
storage_type: Optional[str] = "all",
|
||||
name: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: directory_type={directory_type}, storage_type={storage_type}, name={name}")
|
||||
|
||||
try:
|
||||
directory_helper = DirectoryHelper()
|
||||
|
||||
# 根据目录类型获取目录列表
|
||||
if directory_type == "download":
|
||||
dirs = directory_helper.get_download_dirs()
|
||||
elif directory_type == "library":
|
||||
dirs = directory_helper.get_library_dirs()
|
||||
else:
|
||||
dirs = directory_helper.get_dirs()
|
||||
|
||||
# 按存储类型过滤
|
||||
filtered_dirs = []
|
||||
for d in dirs:
|
||||
# 按存储类型过滤
|
||||
if storage_type == "local":
|
||||
# 对于下载目录,检查 storage;对于媒体库目录,检查 library_storage
|
||||
if directory_type == "download" and d.storage != "local":
|
||||
continue
|
||||
elif directory_type == "library" and d.library_storage != "local":
|
||||
continue
|
||||
elif directory_type == "all":
|
||||
# 检查是否有本地存储配置
|
||||
if d.download_path and d.storage != "local":
|
||||
continue
|
||||
if d.library_path and d.library_storage != "local":
|
||||
continue
|
||||
elif storage_type == "remote":
|
||||
# 对于下载目录,检查 storage;对于媒体库目录,检查 library_storage
|
||||
if directory_type == "download" and d.storage == "local":
|
||||
continue
|
||||
elif directory_type == "library" and d.library_storage == "local":
|
||||
continue
|
||||
elif directory_type == "all":
|
||||
# 检查是否有远程存储配置
|
||||
if d.download_path and d.storage == "local":
|
||||
continue
|
||||
if d.library_path and d.library_storage == "local":
|
||||
continue
|
||||
|
||||
# 按名称过滤(部分匹配)
|
||||
if name and d.name and name.lower() not in d.name.lower():
|
||||
continue
|
||||
|
||||
filtered_dirs.append(d)
|
||||
|
||||
if filtered_dirs:
|
||||
# 转换为字典格式,只保留关键信息
|
||||
simplified_dirs = []
|
||||
for d in filtered_dirs:
|
||||
simplified = {
|
||||
"name": d.name,
|
||||
"priority": d.priority,
|
||||
"storage": d.storage,
|
||||
"download_path": d.download_path,
|
||||
"library_path": d.library_path,
|
||||
"library_storage": d.library_storage,
|
||||
"media_type": d.media_type,
|
||||
"media_category": d.media_category,
|
||||
"monitor_type": d.monitor_type,
|
||||
"monitor_mode": d.monitor_mode,
|
||||
"transfer_type": d.transfer_type,
|
||||
"overwrite_mode": d.overwrite_mode,
|
||||
"renaming": d.renaming,
|
||||
"scraping": d.scraping,
|
||||
"notify": d.notify,
|
||||
"download_type_folder": d.download_type_folder,
|
||||
"download_category_folder": d.download_category_folder,
|
||||
"library_type_folder": d.library_type_folder,
|
||||
"library_category_folder": d.library_category_folder
|
||||
}
|
||||
simplified_dirs.append(simplified)
|
||||
|
||||
result_json = json.dumps(simplified_dirs, ensure_ascii=False, indent=2)
|
||||
return result_json
|
||||
return "未找到相关目录配置"
|
||||
except Exception as e:
|
||||
logger.error(f"查询系统目录设置失败: {e}", exc_info=True)
|
||||
return f"查询系统目录设置时发生错误: {str(e)}"
|
||||
|
||||
197
app/agent/tools/impl/query_download_tasks.py
Normal file
197
app/agent/tools/impl/query_download_tasks.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""查询下载工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.download import DownloadChain
|
||||
from app.db.downloadhistory_oper import DownloadHistoryOper
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QueryDownloadTasksInput(BaseModel):
|
||||
"""查询下载工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
downloader: Optional[str] = Field(None,
|
||||
description="Name of specific downloader to query (optional, if not provided queries all configured downloaders)")
|
||||
status: Optional[str] = Field("all",
|
||||
description="Filter downloads by status: 'downloading' for active downloads, 'completed' for finished downloads, 'paused' for paused downloads, 'all' for all downloads")
|
||||
hash: Optional[str] = Field(None, description="Query specific download task by hash (optional, if provided will search for this specific task regardless of status)")
|
||||
title: Optional[str] = Field(None, description="Query download tasks by title/name (optional, supports partial match, searches all tasks if provided)")
|
||||
|
||||
|
||||
class QueryDownloadTasksTool(MoviePilotTool):
|
||||
name: str = "query_download_tasks"
|
||||
description: str = "Query download status and list download tasks. Can query all active downloads, or search for specific tasks by hash or title. Shows download progress, completion status, and task details from configured downloaders."
|
||||
args_schema: Type[BaseModel] = QueryDownloadTasksInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
downloader = kwargs.get("downloader")
|
||||
status = kwargs.get("status", "all")
|
||||
hash_value = kwargs.get("hash")
|
||||
title = kwargs.get("title")
|
||||
|
||||
parts = ["正在查询下载任务"]
|
||||
|
||||
if downloader:
|
||||
parts.append(f"下载器: {downloader}")
|
||||
|
||||
if status != "all":
|
||||
status_map = {"downloading": "下载中", "completed": "已完成", "paused": "已暂停"}
|
||||
parts.append(f"状态: {status_map.get(status, status)}")
|
||||
|
||||
if hash_value:
|
||||
parts.append(f"Hash: {hash_value[:8]}...")
|
||||
elif title:
|
||||
parts.append(f"标题: {title}")
|
||||
|
||||
return " | ".join(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
async def run(self, downloader: Optional[str] = None,
|
||||
status: Optional[str] = "all",
|
||||
hash: Optional[str] = None,
|
||||
title: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: downloader={downloader}, status={status}, hash={hash}, title={title}")
|
||||
try:
|
||||
download_chain = DownloadChain()
|
||||
|
||||
# 如果提供了hash,直接查询该hash的任务(不限制状态)
|
||||
if hash:
|
||||
torrents = download_chain.list_torrents(downloader=downloader, hashs=[hash])
|
||||
if not torrents:
|
||||
return f"未找到hash为 {hash} 的下载任务(该任务可能已完成、已删除或不存在)"
|
||||
# 转换为DownloadingTorrent格式
|
||||
downloads = []
|
||||
for torrent in torrents:
|
||||
# 获取下载历史信息
|
||||
history = DownloadHistoryOper().get_by_hash(torrent.hash)
|
||||
if history:
|
||||
torrent.media = {
|
||||
"tmdbid": history.tmdbid,
|
||||
"type": history.type,
|
||||
"title": history.title,
|
||||
"season": history.seasons,
|
||||
"episode": history.episodes,
|
||||
"image": history.image,
|
||||
}
|
||||
torrent.userid = history.userid
|
||||
torrent.username = history.username
|
||||
downloads.append(torrent)
|
||||
filtered_downloads = downloads
|
||||
elif title:
|
||||
# 如果提供了title,查询所有任务并搜索匹配的标题
|
||||
# 查询所有状态的任务
|
||||
all_torrents = download_chain.list_torrents(downloader=downloader) or []
|
||||
filtered_downloads = []
|
||||
for torrent in all_torrents:
|
||||
# 检查标题或名称是否匹配
|
||||
if (title.lower() in (torrent.title or "").lower()) or \
|
||||
(title.lower() in (torrent.name or "").lower()):
|
||||
# 获取下载历史信息
|
||||
history = DownloadHistoryOper().get_by_hash(torrent.hash)
|
||||
if history:
|
||||
torrent.media = {
|
||||
"tmdbid": history.tmdbid,
|
||||
"type": history.type,
|
||||
"title": history.title,
|
||||
"season": history.seasons,
|
||||
"episode": history.episodes,
|
||||
"image": history.image,
|
||||
}
|
||||
torrent.userid = history.userid
|
||||
torrent.username = history.username
|
||||
filtered_downloads.append(torrent)
|
||||
if not filtered_downloads:
|
||||
return f"未找到标题包含 '{title}' 的下载任务"
|
||||
else:
|
||||
# 根据status决定查询方式
|
||||
if status == "downloading":
|
||||
# 如果status为下载中,使用downloading方法
|
||||
downloads = download_chain.downloading(name=downloader)
|
||||
filtered_downloads = []
|
||||
for dl in downloads:
|
||||
if downloader and dl.downloader != downloader:
|
||||
continue
|
||||
filtered_downloads.append(dl)
|
||||
else:
|
||||
# 其他状态(completed、paused、all),使用list_torrents查询所有任务
|
||||
# 查询所有状态的任务
|
||||
all_torrents = download_chain.list_torrents(downloader=downloader) or []
|
||||
filtered_downloads = []
|
||||
for torrent in all_torrents:
|
||||
if downloader and torrent.downloader != downloader:
|
||||
continue
|
||||
# 根据status过滤
|
||||
if status == "completed":
|
||||
# 已完成的任务(state为seeding或completed)
|
||||
if torrent.state not in ["seeding", "completed"]:
|
||||
continue
|
||||
elif status == "paused":
|
||||
# 已暂停的任务
|
||||
if torrent.state != "paused":
|
||||
continue
|
||||
# status == "all" 时不过滤
|
||||
# 获取下载历史信息
|
||||
history = DownloadHistoryOper().get_by_hash(torrent.hash)
|
||||
if history:
|
||||
torrent.media = {
|
||||
"tmdbid": history.tmdbid,
|
||||
"type": history.type,
|
||||
"title": history.title,
|
||||
"season": history.seasons,
|
||||
"episode": history.episodes,
|
||||
"image": history.image,
|
||||
}
|
||||
torrent.userid = history.userid
|
||||
torrent.username = history.username
|
||||
filtered_downloads.append(torrent)
|
||||
if filtered_downloads:
|
||||
# 限制最多20条结果
|
||||
total_count = len(filtered_downloads)
|
||||
limited_downloads = filtered_downloads[:20]
|
||||
# 精简字段,只保留关键信息
|
||||
simplified_downloads = []
|
||||
for d in limited_downloads:
|
||||
simplified = {
|
||||
"downloader": d.downloader,
|
||||
"hash": d.hash,
|
||||
"title": d.title,
|
||||
"name": d.name,
|
||||
"year": d.year,
|
||||
"season_episode": d.season_episode,
|
||||
"size": d.size,
|
||||
"progress": d.progress,
|
||||
"state": d.state,
|
||||
"upspeed": d.upspeed,
|
||||
"dlspeed": d.dlspeed,
|
||||
"left_time": d.left_time
|
||||
}
|
||||
# 精简 media 字段
|
||||
if d.media:
|
||||
simplified["media"] = {
|
||||
"tmdbid": d.media.get("tmdbid"),
|
||||
"type": d.media.get("type"),
|
||||
"title": d.media.get("title"),
|
||||
"season": d.media.get("season"),
|
||||
"episode": d.media.get("episode")
|
||||
}
|
||||
simplified_downloads.append(simplified)
|
||||
result_json = json.dumps(simplified_downloads, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 20:
|
||||
return f"注意:查询结果共找到 {total_count} 条,为节省上下文空间,仅显示前 20 条结果。\n\n{result_json}"
|
||||
|
||||
# 如果查询的是特定hash或title,添加明确的状态信息
|
||||
if hash:
|
||||
return f"找到hash为 {hash} 的下载任务:\n\n{result_json}"
|
||||
elif title:
|
||||
return f"找到 {total_count} 个标题包含 '{title}' 的下载任务:\n\n{result_json}"
|
||||
|
||||
return result_json
|
||||
return "未找到相关下载任务"
|
||||
except Exception as e:
|
||||
logger.error(f"查询下载失败: {e}", exc_info=True)
|
||||
return f"查询下载时发生错误: {str(e)}"
|
||||
@@ -1,7 +1,7 @@
|
||||
"""查询下载器工具"""
|
||||
|
||||
import json
|
||||
from typing import Type
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -21,6 +21,10 @@ class QueryDownloadersTool(MoviePilotTool):
|
||||
description: str = "Query downloader configuration and list all available downloaders. Shows downloader status, connection details, and configuration settings."
|
||||
args_schema: Type[BaseModel] = QueryDownloadersInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""生成友好的提示消息"""
|
||||
return "正在查询下载器配置"
|
||||
|
||||
async def run(self, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}")
|
||||
try:
|
||||
|
||||
@@ -1,80 +0,0 @@
|
||||
"""查询下载工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.download import DownloadChain
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QueryDownloadsInput(BaseModel):
|
||||
"""查询下载工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
downloader: Optional[str] = Field(None,
|
||||
description="Name of specific downloader to query (optional, if not provided queries all configured downloaders)")
|
||||
status: Optional[str] = Field("all",
|
||||
description="Filter downloads by status: 'downloading' for active downloads, 'completed' for finished downloads, 'paused' for paused downloads, 'all' for all downloads")
|
||||
|
||||
|
||||
class QueryDownloadsTool(MoviePilotTool):
|
||||
name: str = "query_downloads"
|
||||
description: str = "Query download status and list all active download tasks. Shows download progress, completion status, and task details from configured downloaders."
|
||||
args_schema: Type[BaseModel] = QueryDownloadsInput
|
||||
|
||||
async def run(self, downloader: Optional[str] = None,
|
||||
status: Optional[str] = "all", **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: downloader={downloader}, status={status}")
|
||||
try:
|
||||
download_chain = DownloadChain()
|
||||
# 使用 DownloadChain.downloading 方法获取正在下载的任务
|
||||
downloads = download_chain.downloading(name=downloader)
|
||||
filtered_downloads = []
|
||||
for dl in downloads:
|
||||
if downloader and dl.downloader != downloader:
|
||||
continue
|
||||
if status != "all" and dl.status != status:
|
||||
continue
|
||||
filtered_downloads.append(dl)
|
||||
if filtered_downloads:
|
||||
# 限制最多20条结果
|
||||
total_count = len(filtered_downloads)
|
||||
limited_downloads = filtered_downloads[:20]
|
||||
# 精简字段,只保留关键信息
|
||||
simplified_downloads = []
|
||||
for d in limited_downloads:
|
||||
simplified = {
|
||||
"downloader": d.downloader,
|
||||
"hash": d.hash,
|
||||
"title": d.title,
|
||||
"name": d.name,
|
||||
"year": d.year,
|
||||
"season_episode": d.season_episode,
|
||||
"size": d.size,
|
||||
"progress": d.progress,
|
||||
"state": d.state,
|
||||
"upspeed": d.upspeed,
|
||||
"dlspeed": d.dlspeed,
|
||||
"left_time": d.left_time
|
||||
}
|
||||
# 精简 media 字段
|
||||
if d.media:
|
||||
simplified["media"] = {
|
||||
"tmdbid": d.media.get("tmdbid"),
|
||||
"type": d.media.get("type"),
|
||||
"title": d.media.get("title"),
|
||||
"season": d.media.get("season"),
|
||||
"episode": d.media.get("episode")
|
||||
}
|
||||
simplified_downloads.append(simplified)
|
||||
result_json = json.dumps(simplified_downloads, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 20:
|
||||
return f"注意:查询结果共找到 {total_count} 条,为节省上下文空间,仅显示前 20 条结果。\n\n{result_json}"
|
||||
return result_json
|
||||
return "未找到相关下载任务"
|
||||
except Exception as e:
|
||||
logger.error(f"查询下载失败: {e}", exc_info=True)
|
||||
return f"查询下载时发生错误: {str(e)}"
|
||||
116
app/agent/tools/impl/query_episode_schedule.py
Normal file
116
app/agent/tools/impl/query_episode_schedule.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""查询剧集上映时间工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.media import MediaChain
|
||||
from app.chain.tmdb import TmdbChain
|
||||
from app.log import logger
|
||||
from app.schemas import MediaType
|
||||
|
||||
|
||||
class QueryEpisodeScheduleInput(BaseModel):
|
||||
"""查询剧集上映时间工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
tmdb_id: int = Field(..., description="TMDB ID of the TV series")
|
||||
season: int = Field(..., description="Season number to query")
|
||||
episode_group: Optional[str] = Field(None, description="Episode group ID (optional)")
|
||||
|
||||
|
||||
class QueryEpisodeScheduleTool(MoviePilotTool):
|
||||
name: str = "query_episode_schedule"
|
||||
description: str = "Query TV series episode air dates and schedule. Returns detailed information for each episode including air date, episode number, title, overview, and other metadata. Filters out episodes without air dates."
|
||||
args_schema: Type[BaseModel] = QueryEpisodeScheduleInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
tmdb_id = kwargs.get("tmdb_id")
|
||||
season = kwargs.get("season")
|
||||
episode_group = kwargs.get("episode_group")
|
||||
|
||||
message = f"正在查询剧集上映时间: TMDB ID {tmdb_id} 第{season}季"
|
||||
if episode_group:
|
||||
message += f" (剧集组: {episode_group})"
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, tmdb_id: int, season: int, episode_group: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: tmdb_id={tmdb_id}, season={season}, episode_group={episode_group}")
|
||||
|
||||
try:
|
||||
# 获取媒体信息(用于获取标题和海报)
|
||||
media_chain = MediaChain()
|
||||
mediainfo = await media_chain.async_recognize_media(tmdbid=tmdb_id, mtype=MediaType.TV)
|
||||
if not mediainfo:
|
||||
return f"未找到 TMDB ID {tmdb_id} 的媒体信息"
|
||||
|
||||
# 获取集列表
|
||||
tmdb_chain = TmdbChain()
|
||||
episodes = await tmdb_chain.async_tmdb_episodes(
|
||||
tmdbid=tmdb_id,
|
||||
season=season,
|
||||
episode_group=episode_group
|
||||
)
|
||||
|
||||
if not episodes:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"未找到 TMDB ID {tmdb_id} 第{season}季的集信息"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# 过滤掉没有上映日期的集,并构建每集的详细信息
|
||||
episode_list = []
|
||||
for episode in episodes:
|
||||
air_date = episode.air_date
|
||||
|
||||
# 过滤掉没有上映日期的数据
|
||||
if not air_date:
|
||||
continue
|
||||
|
||||
episode_info = {
|
||||
"episode_number": episode.episode_number,
|
||||
"name": episode.name,
|
||||
"air_date": air_date,
|
||||
"runtime": episode.runtime,
|
||||
"vote_average": episode.vote_average,
|
||||
"still_path": episode.still_path,
|
||||
"episode_type": episode.episode_type,
|
||||
"season_number": episode.season_number
|
||||
}
|
||||
episode_list.append(episode_info)
|
||||
|
||||
if not episode_list:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"未找到 TMDB ID {tmdb_id} 第{season}季的播出时间信息(所有集都没有播出日期)"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# 按播出日期排序
|
||||
episode_list.sort(key=lambda x: (x["air_date"] or "", x["episode_number"] or 0))
|
||||
|
||||
result = {
|
||||
"success": True,
|
||||
"tmdb_id": tmdb_id,
|
||||
"season": season,
|
||||
"episode_group": episode_group,
|
||||
"series_title": mediainfo.title if mediainfo else None,
|
||||
"series_poster": mediainfo.poster_path if mediainfo else None,
|
||||
"total_episodes": len(episodes),
|
||||
"episodes_with_air_date": len(episode_list),
|
||||
"episodes": episode_list
|
||||
}
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"查询剧集上映时间失败: {str(e)}"
|
||||
logger.error(f"查询剧集上映时间失败: {e}", exc_info=True)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": error_message,
|
||||
"tmdb_id": tmdb_id,
|
||||
"season": season
|
||||
}, ensure_ascii=False)
|
||||
97
app/agent/tools/impl/query_library_exists.py
Normal file
97
app/agent/tools/impl/query_library_exists.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""查询媒体库工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.mediaserver import MediaServerChain
|
||||
from app.core.context import MediaInfo
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType
|
||||
|
||||
|
||||
class QueryLibraryExistsInput(BaseModel):
|
||||
"""查询媒体库工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
media_type: Optional[str] = Field("all",
|
||||
description="Type of media content: '电影' for films, '电视剧' for television series or anime series, 'all' for all types")
|
||||
title: Optional[str] = Field(None,
|
||||
description="Specific media title to check if it exists in the media library (optional, if provided checks for that specific media)")
|
||||
year: Optional[str] = Field(None,
|
||||
description="Release year of the media (optional, helps narrow down search results)")
|
||||
|
||||
|
||||
class QueryLibraryExistsTool(MoviePilotTool):
|
||||
name: str = "query_library_exists"
|
||||
description: str = "Check if a specific media resource already exists in the media library (Plex, Emby, Jellyfin). Use this tool to verify whether a movie or TV series has been successfully processed and added to the media server before performing operations like downloading or subscribing."
|
||||
args_schema: Type[BaseModel] = QueryLibraryExistsInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
media_type = kwargs.get("media_type", "all")
|
||||
title = kwargs.get("title")
|
||||
year = kwargs.get("year")
|
||||
|
||||
parts = ["正在查询媒体库"]
|
||||
|
||||
if title:
|
||||
parts.append(f"标题: {title}")
|
||||
if year:
|
||||
parts.append(f"年份: {year}")
|
||||
if media_type != "all":
|
||||
parts.append(f"类型: {media_type}")
|
||||
|
||||
return " | ".join(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
async def run(self, media_type: Optional[str] = "all",
|
||||
title: Optional[str] = None, year: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: media_type={media_type}, title={title}")
|
||||
try:
|
||||
if not title:
|
||||
return "请提供媒体标题进行查询"
|
||||
|
||||
# 创建 MediaInfo 对象
|
||||
mediainfo = MediaInfo()
|
||||
mediainfo.title = title
|
||||
mediainfo.year = year
|
||||
|
||||
# 转换媒体类型
|
||||
if media_type == "电影":
|
||||
mediainfo.type = MediaType.MOVIE
|
||||
elif media_type == "电视剧":
|
||||
mediainfo.type = MediaType.TV
|
||||
# media_type == "all" 时不设置类型,让媒体服务器自动判断
|
||||
|
||||
# 调用媒体服务器接口实时查询
|
||||
media_chain = MediaServerChain()
|
||||
existsinfo = media_chain.media_exists(mediainfo=mediainfo)
|
||||
|
||||
if not existsinfo:
|
||||
return "媒体库中未找到相关媒体"
|
||||
|
||||
# 如果找到了,获取详细信息
|
||||
result_items = []
|
||||
if existsinfo.itemid and existsinfo.server:
|
||||
iteminfo = media_chain.iteminfo(server=existsinfo.server, item_id=existsinfo.itemid)
|
||||
if iteminfo:
|
||||
# 使用 model_dump() 转换为字典格式
|
||||
item_dict = iteminfo.model_dump(exclude_none=True)
|
||||
result_items.append(item_dict)
|
||||
|
||||
if result_items:
|
||||
return json.dumps(result_items, ensure_ascii=False)
|
||||
|
||||
# 如果找到了但没有详细信息,返回基本信息
|
||||
result_dict = {
|
||||
"type": existsinfo.type.value if existsinfo.type else None,
|
||||
"server": existsinfo.server,
|
||||
"server_type": existsinfo.server_type,
|
||||
"itemid": existsinfo.itemid,
|
||||
"seasons": existsinfo.seasons if existsinfo.seasons else {}
|
||||
}
|
||||
return json.dumps([result_dict], ensure_ascii=False)
|
||||
except Exception as e:
|
||||
logger.error(f"查询媒体库失败: {e}", exc_info=True)
|
||||
return f"查询媒体库时发生错误: {str(e)}"
|
||||
86
app/agent/tools/impl/query_library_latest.py
Normal file
86
app/agent/tools/impl/query_library_latest.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""查询媒体服务器最近入库影片工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.mediaserver import MediaServerChain
|
||||
from app.helper.service import ServiceConfigHelper
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QueryLibraryLatestInput(BaseModel):
|
||||
"""查询媒体服务器最近入库影片工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
server: Optional[str] = Field(None, description="Media server name (optional, if not specified queries all enabled media servers)")
|
||||
count: Optional[int] = Field(20, description="Number of items to return (default: 20)")
|
||||
|
||||
|
||||
class QueryLibraryLatestTool(MoviePilotTool):
|
||||
name: str = "query_library_latest"
|
||||
description: str = "Query the latest media items added to the media server (Plex, Emby, Jellyfin). Returns recently added movies and TV series with their titles, images, links, and other metadata."
|
||||
args_schema: Type[BaseModel] = QueryLibraryLatestInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
server = kwargs.get("server")
|
||||
count = kwargs.get("count", 20)
|
||||
|
||||
parts = ["正在查询媒体服务器最近入库影片"]
|
||||
|
||||
if server:
|
||||
parts.append(f"服务器: {server}")
|
||||
else:
|
||||
parts.append("所有服务器")
|
||||
|
||||
parts.append(f"数量: {count}条")
|
||||
|
||||
return " | ".join(parts)
|
||||
|
||||
async def run(self, server: Optional[str] = None, count: Optional[int] = 20, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: server={server}, count={count}")
|
||||
try:
|
||||
media_chain = MediaServerChain()
|
||||
results = []
|
||||
|
||||
# 如果没有指定服务器,获取所有启用的媒体服务器
|
||||
if not server:
|
||||
mediaservers = ServiceConfigHelper.get_mediaserver_configs()
|
||||
enabled_servers = [ms.name for ms in mediaservers if ms.enabled]
|
||||
|
||||
if not enabled_servers:
|
||||
return "未找到启用的媒体服务器"
|
||||
|
||||
# 遍历所有启用的服务器
|
||||
for server_name in enabled_servers:
|
||||
latest_items = media_chain.latest(server=server_name, count=count, username=self._username)
|
||||
if latest_items:
|
||||
for item in latest_items:
|
||||
item_dict = item.model_dump(exclude_none=True)
|
||||
item_dict["server"] = server_name
|
||||
results.append(item_dict)
|
||||
else:
|
||||
# 查询指定服务器
|
||||
latest_items = media_chain.latest(server=server, count=count, username=self._username)
|
||||
if latest_items:
|
||||
for item in latest_items:
|
||||
item_dict = item.model_dump(exclude_none=True)
|
||||
item_dict["server"] = server
|
||||
results.append(item_dict)
|
||||
|
||||
if not results:
|
||||
server_info = f"服务器 {server}" if server else "所有服务器"
|
||||
return f"未找到 {server_info} 的最近入库影片"
|
||||
|
||||
# 限制返回数量,避免结果过多
|
||||
if len(results) > count:
|
||||
results = results[:count]
|
||||
|
||||
return json.dumps(results, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查询媒体服务器最近入库影片失败: {e}", exc_info=True)
|
||||
return f"查询媒体服务器最近入库影片时发生错误: {str(e)}"
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
"""查询媒体库工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, List, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.db.mediaserver_oper import MediaServerOper
|
||||
from app.log import logger
|
||||
from app.schemas import MediaServerItem
|
||||
|
||||
|
||||
class QueryMediaLibraryInput(BaseModel):
|
||||
"""查询媒体库工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
media_type: Optional[str] = Field("all",
|
||||
description="Type of media content: '电影' for films, '电视剧' for television series or anime series, 'all' for all types")
|
||||
title: Optional[str] = Field(None,
|
||||
description="Specific media title to check if it exists in the media library (optional, if provided checks for that specific media)")
|
||||
year: Optional[str] = Field(None,
|
||||
description="Release year of the media (optional, helps narrow down search results)")
|
||||
|
||||
|
||||
class QueryMediaLibraryTool(MoviePilotTool):
|
||||
name: str = "query_media_library"
|
||||
description: str = "Check if a specific media resource already exists in the media library (Plex, Emby, Jellyfin). Use this tool to verify whether a movie or TV series has been successfully processed and added to the media server before performing operations like downloading or subscribing."
|
||||
args_schema: Type[BaseModel] = QueryMediaLibraryInput
|
||||
|
||||
async def run(self, media_type: Optional[str] = "all",
|
||||
title: Optional[str] = None, year: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: media_type={media_type}, title={title}")
|
||||
try:
|
||||
media_server_oper = MediaServerOper()
|
||||
filtered_medias: List[MediaServerItem] = await media_server_oper.async_exists(title=title, year=year, mtype=media_type)
|
||||
if filtered_medias:
|
||||
return json.dumps([m.to_dict() for m in filtered_medias])
|
||||
return "媒体库中未找到相关媒体"
|
||||
except Exception as e:
|
||||
logger.error(f"查询媒体库失败: {e}", exc_info=True)
|
||||
return f"查询媒体库时发生错误: {str(e)}"
|
||||
152
app/agent/tools/impl/query_popular_subscribes.py
Normal file
152
app/agent/tools/impl/query_popular_subscribes.py
Normal file
@@ -0,0 +1,152 @@
|
||||
"""查询热门订阅工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
import cn2an
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.core.context import MediaInfo
|
||||
from app.helper.subscribe import SubscribeHelper
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType
|
||||
|
||||
|
||||
class QueryPopularSubscribesInput(BaseModel):
|
||||
"""查询热门订阅工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
stype: str = Field(..., description="Media type: '电影' for films, '电视剧' for television series")
|
||||
page: Optional[int] = Field(1, description="Page number for pagination (default: 1)")
|
||||
count: Optional[int] = Field(30, description="Number of items per page (default: 30)")
|
||||
min_sub: Optional[int] = Field(None, description="Minimum number of subscribers filter (optional, e.g., 5)")
|
||||
genre_id: Optional[int] = Field(None, description="Filter by genre ID (optional)")
|
||||
min_rating: Optional[float] = Field(None, description="Minimum rating filter (optional, e.g., 7.5)")
|
||||
max_rating: Optional[float] = Field(None, description="Maximum rating filter (optional, e.g., 10.0)")
|
||||
sort_type: Optional[str] = Field(None, description="Sort type (optional, e.g., 'count', 'rating')")
|
||||
|
||||
|
||||
class QueryPopularSubscribesTool(MoviePilotTool):
|
||||
name: str = "query_popular_subscribes"
|
||||
description: str = "Query popular subscriptions based on user shared data. Shows media with the most subscribers, supports filtering by genre, rating, minimum subscribers, and pagination."
|
||||
args_schema: Type[BaseModel] = QueryPopularSubscribesInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
stype = kwargs.get("stype", "")
|
||||
page = kwargs.get("page", 1)
|
||||
min_sub = kwargs.get("min_sub")
|
||||
min_rating = kwargs.get("min_rating")
|
||||
max_rating = kwargs.get("max_rating")
|
||||
|
||||
parts = [f"正在查询热门订阅 [{stype}]"]
|
||||
|
||||
if min_sub:
|
||||
parts.append(f"最少订阅: {min_sub}")
|
||||
if min_rating:
|
||||
parts.append(f"最低评分: {min_rating}")
|
||||
if max_rating:
|
||||
parts.append(f"最高评分: {max_rating}")
|
||||
if page > 1:
|
||||
parts.append(f"第{page}页")
|
||||
|
||||
return " | ".join(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
async def run(self, stype: str,
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
min_sub: Optional[int] = None,
|
||||
genre_id: Optional[int] = None,
|
||||
min_rating: Optional[float] = None,
|
||||
max_rating: Optional[float] = None,
|
||||
sort_type: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: stype={stype}, page={page}, count={count}, min_sub={min_sub}, "
|
||||
f"genre_id={genre_id}, min_rating={min_rating}, max_rating={max_rating}, sort_type={sort_type}")
|
||||
|
||||
try:
|
||||
if page is None or page < 1:
|
||||
page = 1
|
||||
if count is None or count < 1:
|
||||
count = 30
|
||||
|
||||
subscribe_helper = SubscribeHelper()
|
||||
subscribes = await subscribe_helper.async_get_statistic(
|
||||
stype=stype,
|
||||
page=page,
|
||||
count=count,
|
||||
genre_id=genre_id,
|
||||
min_rating=min_rating,
|
||||
max_rating=max_rating,
|
||||
sort_type=sort_type
|
||||
)
|
||||
|
||||
if not subscribes:
|
||||
return "未找到热门订阅数据(可能订阅统计功能未启用)"
|
||||
|
||||
# 转换为MediaInfo格式并过滤
|
||||
ret_medias = []
|
||||
for sub in subscribes:
|
||||
# 订阅人数
|
||||
subscriber_count = sub.get("count", 0)
|
||||
# 如果设置了最小订阅人数,进行过滤
|
||||
if min_sub and subscriber_count < min_sub:
|
||||
continue
|
||||
|
||||
media = MediaInfo()
|
||||
media.type = MediaType(sub.get("type"))
|
||||
media.tmdb_id = sub.get("tmdbid")
|
||||
# 处理标题
|
||||
title = sub.get("name")
|
||||
season = sub.get("season")
|
||||
if season and int(season) > 1 and media.tmdb_id:
|
||||
# 小写数据转大写
|
||||
season_str = cn2an.an2cn(season, "low")
|
||||
title = f"{title} 第{season_str}季"
|
||||
media.title = title
|
||||
media.year = sub.get("year")
|
||||
media.douban_id = sub.get("doubanid")
|
||||
media.bangumi_id = sub.get("bangumiid")
|
||||
media.tvdb_id = sub.get("tvdbid")
|
||||
media.imdb_id = sub.get("imdbid")
|
||||
media.season = sub.get("season")
|
||||
media.vote_average = sub.get("vote")
|
||||
media.poster_path = sub.get("poster")
|
||||
media.backdrop_path = sub.get("backdrop")
|
||||
media.popularity = subscriber_count
|
||||
ret_medias.append(media)
|
||||
|
||||
if not ret_medias:
|
||||
return "未找到符合条件的热门订阅"
|
||||
|
||||
# 转换为字典格式,只保留关键信息
|
||||
simplified_medias = []
|
||||
for media in ret_medias:
|
||||
media_dict = media.to_dict()
|
||||
simplified = {
|
||||
"type": media_dict.get("type"),
|
||||
"title": media_dict.get("title"),
|
||||
"year": media_dict.get("year"),
|
||||
"tmdb_id": media_dict.get("tmdb_id"),
|
||||
"douban_id": media_dict.get("douban_id"),
|
||||
"bangumi_id": media_dict.get("bangumi_id"),
|
||||
"tvdb_id": media_dict.get("tvdb_id"),
|
||||
"imdb_id": media_dict.get("imdb_id"),
|
||||
"season": media_dict.get("season"),
|
||||
"vote_average": media_dict.get("vote_average"),
|
||||
"poster_path": media_dict.get("poster_path"),
|
||||
"backdrop_path": media_dict.get("backdrop_path"),
|
||||
"popularity": media_dict.get("popularity"), # 订阅人数
|
||||
"subscriber_count": media_dict.get("popularity") # 明确标注为订阅人数
|
||||
}
|
||||
simplified_medias.append(simplified)
|
||||
|
||||
result_json = json.dumps(simplified_medias, ensure_ascii=False, indent=2)
|
||||
|
||||
pagination_info = f"第 {page} 页,每页 {count} 条,共 {len(simplified_medias)} 条结果"
|
||||
|
||||
return f"{pagination_info}\n\n{result_json}"
|
||||
except Exception as e:
|
||||
logger.error(f"查询热门订阅失败: {e}", exc_info=True)
|
||||
return f"查询热门订阅时发生错误: {str(e)}"
|
||||
|
||||
65
app/agent/tools/impl/query_rule_groups.py
Normal file
65
app/agent/tools/impl/query_rule_groups.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""查询规则组工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.helper.rule import RuleHelper
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QueryRuleGroupsInput(BaseModel):
|
||||
"""查询规则组工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
|
||||
|
||||
class QueryRuleGroupsTool(MoviePilotTool):
|
||||
name: str = "query_rule_groups"
|
||||
description: str = "Query all filter rule groups available in the system. Rule groups are used to filter torrents when searching or subscribing. Returns rule group names, media types, and categories, but excludes rule_string to keep results concise."
|
||||
args_schema: Type[BaseModel] = QueryRuleGroupsInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
return "正在查询所有规则组"
|
||||
|
||||
async def run(self, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}")
|
||||
|
||||
try:
|
||||
rule_helper = RuleHelper()
|
||||
rule_groups = rule_helper.get_rule_groups()
|
||||
|
||||
if not rule_groups:
|
||||
return json.dumps({
|
||||
"message": "未找到任何规则组",
|
||||
"rule_groups": []
|
||||
}, ensure_ascii=False, indent=2)
|
||||
|
||||
# 精简字段,过滤掉 rule_string 避免结果过大
|
||||
simplified_groups = []
|
||||
for group in rule_groups:
|
||||
simplified = {
|
||||
"name": group.name,
|
||||
"media_type": group.media_type,
|
||||
"category": group.category
|
||||
}
|
||||
simplified_groups.append(simplified)
|
||||
|
||||
result = {
|
||||
"message": f"找到 {len(simplified_groups)} 个规则组",
|
||||
"rule_groups": simplified_groups
|
||||
}
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"查询规则组失败: {str(e)}"
|
||||
logger.error(f"查询规则组失败: {e}", exc_info=True)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": error_message,
|
||||
"rule_groups": []
|
||||
}, ensure_ascii=False)
|
||||
|
||||
55
app/agent/tools/impl/query_schedulers.py
Normal file
55
app/agent/tools/impl/query_schedulers.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""查询定时服务工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.log import logger
|
||||
from app.scheduler import Scheduler
|
||||
|
||||
|
||||
class QuerySchedulersInput(BaseModel):
|
||||
"""查询定时服务工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
|
||||
|
||||
class QuerySchedulersTool(MoviePilotTool):
|
||||
name: str = "query_schedulers"
|
||||
description: str = "Query scheduled tasks and list all available scheduler jobs. Shows job status, next run time, and provider information."
|
||||
args_schema: Type[BaseModel] = QuerySchedulersInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""生成友好的提示消息"""
|
||||
return "正在查询定时服务"
|
||||
|
||||
async def run(self, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}")
|
||||
try:
|
||||
scheduler = Scheduler()
|
||||
schedulers = scheduler.list()
|
||||
if schedulers:
|
||||
# 转换为字典列表以便JSON序列化
|
||||
schedulers_list = []
|
||||
for s in schedulers:
|
||||
schedulers_list.append({
|
||||
"id": s.id,
|
||||
"name": s.name,
|
||||
"provider": s.provider,
|
||||
"status": s.status,
|
||||
"next_run": s.next_run
|
||||
})
|
||||
result_json = json.dumps(schedulers_list, ensure_ascii=False, indent=2)
|
||||
# 限制最多30条结果
|
||||
total_count = len(schedulers_list)
|
||||
if total_count > 30:
|
||||
limited_schedulers = schedulers_list[:30]
|
||||
limited_json = json.dumps(limited_schedulers, ensure_ascii=False, indent=2)
|
||||
return f"注意:查询结果共找到 {total_count} 条,为节省上下文空间,仅显示前 30 条结果。\n\n{limited_json}"
|
||||
return result_json
|
||||
return "未找到定时服务"
|
||||
except Exception as e:
|
||||
logger.error(f"查询定时服务失败: {e}", exc_info=True)
|
||||
return f"查询定时服务时发生错误: {str(e)}"
|
||||
|
||||
136
app/agent/tools/impl/query_site_userdata.py
Normal file
136
app/agent/tools/impl/query_site_userdata.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""查询站点用户数据工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.db import AsyncSessionFactory
|
||||
from app.db.models.site import Site
|
||||
from app.db.models.siteuserdata import SiteUserData
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QuerySiteUserdataInput(BaseModel):
|
||||
"""查询站点用户数据工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
site_id: int = Field(..., description="The ID of the site to query user data for")
|
||||
workdate: Optional[str] = Field(None, description="Work date to query (optional, format: 'YYYY-MM-DD', if not specified returns latest data)")
|
||||
|
||||
|
||||
class QuerySiteUserdataTool(MoviePilotTool):
|
||||
name: str = "query_site_userdata"
|
||||
description: str = "Query user data for a specific site including username, user level, upload/download statistics, seeding information, bonus points, and other account details. Supports querying data for a specific date or latest data."
|
||||
args_schema: Type[BaseModel] = QuerySiteUserdataInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
site_id = kwargs.get("site_id")
|
||||
workdate = kwargs.get("workdate")
|
||||
|
||||
message = f"正在查询站点 #{site_id} 的用户数据"
|
||||
if workdate:
|
||||
message += f" (日期: {workdate})"
|
||||
else:
|
||||
message += " (最新数据)"
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, site_id: int, workdate: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: site_id={site_id}, workdate={workdate}")
|
||||
|
||||
try:
|
||||
# 获取数据库会话
|
||||
async with AsyncSessionFactory() as db:
|
||||
# 获取站点
|
||||
site = await Site.async_get(db, site_id)
|
||||
if not site:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"站点不存在: {site_id}"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# 获取站点用户数据
|
||||
user_data_list = await SiteUserData.async_get_by_domain(
|
||||
db,
|
||||
domain=site.domain,
|
||||
workdate=workdate
|
||||
)
|
||||
|
||||
if not user_data_list:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"站点 {site.name} ({site.domain}) 暂无用户数据",
|
||||
"site_id": site_id,
|
||||
"site_name": site.name,
|
||||
"site_domain": site.domain,
|
||||
"workdate": workdate
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# 格式化用户数据
|
||||
result = {
|
||||
"success": True,
|
||||
"site_id": site_id,
|
||||
"site_name": site.name,
|
||||
"site_domain": site.domain,
|
||||
"workdate": workdate,
|
||||
"data_count": len(user_data_list),
|
||||
"user_data": []
|
||||
}
|
||||
|
||||
for user_data in user_data_list:
|
||||
# 格式化上传/下载量(转换为可读格式)
|
||||
upload_gb = user_data.upload / (1024 ** 3) if user_data.upload else 0
|
||||
download_gb = user_data.download / (1024 ** 3) if user_data.download else 0
|
||||
seeding_size_gb = user_data.seeding_size / (1024 ** 3) if user_data.seeding_size else 0
|
||||
leeching_size_gb = user_data.leeching_size / (1024 ** 3) if user_data.leeching_size else 0
|
||||
|
||||
user_data_dict = {
|
||||
"domain": user_data.domain,
|
||||
"name": user_data.name,
|
||||
"username": user_data.username,
|
||||
"userid": user_data.userid,
|
||||
"user_level": user_data.user_level,
|
||||
"join_at": user_data.join_at,
|
||||
"bonus": user_data.bonus,
|
||||
"upload": user_data.upload,
|
||||
"upload_gb": round(upload_gb, 2),
|
||||
"download": user_data.download,
|
||||
"download_gb": round(download_gb, 2),
|
||||
"ratio": round(user_data.ratio, 2) if user_data.ratio else 0,
|
||||
"seeding": int(user_data.seeding) if user_data.seeding else 0,
|
||||
"leeching": int(user_data.leeching) if user_data.leeching else 0,
|
||||
"seeding_size": user_data.seeding_size,
|
||||
"seeding_size_gb": round(seeding_size_gb, 2),
|
||||
"leeching_size": user_data.leeching_size,
|
||||
"leeching_size_gb": round(leeching_size_gb, 2),
|
||||
"seeding_info": user_data.seeding_info if user_data.seeding_info else [],
|
||||
"message_unread": user_data.message_unread,
|
||||
"message_unread_contents": user_data.message_unread_contents if user_data.message_unread_contents else [],
|
||||
"err_msg": user_data.err_msg,
|
||||
"updated_day": user_data.updated_day,
|
||||
"updated_time": user_data.updated_time
|
||||
}
|
||||
result["user_data"].append(user_data_dict)
|
||||
|
||||
# 如果有多条数据,只返回最新的(按更新时间排序)
|
||||
if len(result["user_data"]) > 1:
|
||||
result["user_data"].sort(
|
||||
key=lambda x: (x.get("updated_day", ""), x.get("updated_time", "")),
|
||||
reverse=True
|
||||
)
|
||||
result["message"] = f"找到 {len(result['user_data'])} 条数据,显示最新的一条"
|
||||
result["user_data"] = [result["user_data"][0]]
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"查询站点用户数据失败: {str(e)}"
|
||||
logger.error(f"查询站点用户数据失败: {e}", exc_info=True)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": error_message,
|
||||
"site_id": site_id
|
||||
}, ensure_ascii=False)
|
||||
|
||||
@@ -21,9 +21,25 @@ class QuerySitesInput(BaseModel):
|
||||
|
||||
class QuerySitesTool(MoviePilotTool):
|
||||
name: str = "query_sites"
|
||||
description: str = "Query site status and list all configured sites. Shows site name, domain, status, priority, and basic configuration."
|
||||
description: str = "Query site status and list all configured sites. Shows site name, domain, status, priority, and basic configuration. Site priority (pri): smaller values have higher priority (e.g., pri=1 has higher priority than pri=10)."
|
||||
args_schema: Type[BaseModel] = QuerySitesInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
status = kwargs.get("status", "all")
|
||||
name = kwargs.get("name")
|
||||
|
||||
parts = ["正在查询站点"]
|
||||
|
||||
if status != "all":
|
||||
status_map = {"active": "已启用", "inactive": "已禁用"}
|
||||
parts.append(f"状态: {status_map.get(status, status)}")
|
||||
|
||||
if name:
|
||||
parts.append(f"名称: {name}")
|
||||
|
||||
return " | ".join(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
async def run(self, status: Optional[str] = "all", name: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: status={status}, name={name}")
|
||||
try:
|
||||
|
||||
113
app/agent/tools/impl/query_subscribe_history.py
Normal file
113
app/agent/tools/impl/query_subscribe_history.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""查询订阅历史工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.db import AsyncSessionFactory
|
||||
from app.db.models.subscribehistory import SubscribeHistory
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QuerySubscribeHistoryInput(BaseModel):
|
||||
"""查询订阅历史工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
media_type: Optional[str] = Field("all", description="Filter by media type: '电影' for films, '电视剧' for television series, 'all' for all types (default: 'all')")
|
||||
name: Optional[str] = Field(None, description="Filter by media name (partial match, optional)")
|
||||
|
||||
|
||||
class QuerySubscribeHistoryTool(MoviePilotTool):
|
||||
name: str = "query_subscribe_history"
|
||||
description: str = "Query subscription history records. Shows completed subscriptions with their details including name, type, rating, completion date, and other subscription information. Supports filtering by media type and name. Returns up to 30 records."
|
||||
args_schema: Type[BaseModel] = QuerySubscribeHistoryInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
media_type = kwargs.get("media_type", "all")
|
||||
name = kwargs.get("name")
|
||||
|
||||
parts = ["正在查询订阅历史"]
|
||||
|
||||
if media_type != "all":
|
||||
parts.append(f"类型: {media_type}")
|
||||
if name:
|
||||
parts.append(f"名称: {name}")
|
||||
|
||||
return " | ".join(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
async def run(self, media_type: Optional[str] = "all",
|
||||
name: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: media_type={media_type}, name={name}")
|
||||
|
||||
try:
|
||||
# 获取数据库会话
|
||||
async with AsyncSessionFactory() as db:
|
||||
# 根据类型查询
|
||||
if media_type == "all":
|
||||
# 查询所有类型,需要分别查询电影和电视剧
|
||||
movie_history = await SubscribeHistory.async_list_by_type(db, mtype="movie", page=1, count=100)
|
||||
tv_history = await SubscribeHistory.async_list_by_type(db, mtype="tv", page=1, count=100)
|
||||
all_history = list(movie_history) + list(tv_history)
|
||||
# 按日期排序
|
||||
all_history.sort(key=lambda x: x.date or "", reverse=True)
|
||||
else:
|
||||
# 查询指定类型
|
||||
all_history = await SubscribeHistory.async_list_by_type(db, mtype=media_type, page=1, count=100)
|
||||
|
||||
# 按名称过滤
|
||||
filtered_history = []
|
||||
if name:
|
||||
name_lower = name.lower()
|
||||
for record in all_history:
|
||||
if record.name and name_lower in record.name.lower():
|
||||
filtered_history.append(record)
|
||||
else:
|
||||
filtered_history = all_history
|
||||
|
||||
if not filtered_history:
|
||||
return "未找到相关订阅历史记录"
|
||||
|
||||
# 限制最多30条
|
||||
total_count = len(filtered_history)
|
||||
limited_history = filtered_history[:30]
|
||||
|
||||
# 转换为字典格式,只保留关键信息
|
||||
simplified_records = []
|
||||
for record in limited_history:
|
||||
simplified = {
|
||||
"id": record.id,
|
||||
"name": record.name,
|
||||
"year": record.year,
|
||||
"type": record.type,
|
||||
"season": record.season,
|
||||
"tmdbid": record.tmdbid,
|
||||
"doubanid": record.doubanid,
|
||||
"bangumiid": record.bangumiid,
|
||||
"poster": record.poster,
|
||||
"vote": record.vote,
|
||||
"total_episode": record.total_episode,
|
||||
"date": record.date,
|
||||
"username": record.username
|
||||
}
|
||||
# 添加过滤规则信息(如果有)
|
||||
if record.filter:
|
||||
simplified["filter"] = record.filter
|
||||
if record.quality:
|
||||
simplified["quality"] = record.quality
|
||||
if record.resolution:
|
||||
simplified["resolution"] = record.resolution
|
||||
simplified_records.append(simplified)
|
||||
|
||||
result_json = json.dumps(simplified_records, ensure_ascii=False, indent=2)
|
||||
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 30:
|
||||
return f"注意:查询结果共找到 {total_count} 条,为节省上下文空间,仅显示前 30 条结果。\n\n{result_json}"
|
||||
|
||||
return result_json
|
||||
except Exception as e:
|
||||
logger.error(f"查询订阅历史失败: {e}", exc_info=True)
|
||||
return f"查询订阅历史时发生错误: {str(e)}"
|
||||
|
||||
113
app/agent/tools/impl/query_subscribe_shares.py
Normal file
113
app/agent/tools/impl/query_subscribe_shares.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""查询订阅分享工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.helper.subscribe import SubscribeHelper
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QuerySubscribeSharesInput(BaseModel):
|
||||
"""查询订阅分享工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
name: Optional[str] = Field(None, description="Filter shares by media name (partial match, optional)")
|
||||
page: Optional[int] = Field(1, description="Page number for pagination (default: 1)")
|
||||
count: Optional[int] = Field(30, description="Number of items per page (default: 30)")
|
||||
genre_id: Optional[int] = Field(None, description="Filter by genre ID (optional)")
|
||||
min_rating: Optional[float] = Field(None, description="Minimum rating filter (optional, e.g., 7.5)")
|
||||
max_rating: Optional[float] = Field(None, description="Maximum rating filter (optional, e.g., 10.0)")
|
||||
sort_type: Optional[str] = Field(None, description="Sort type (optional, e.g., 'count', 'rating')")
|
||||
|
||||
|
||||
class QuerySubscribeSharesTool(MoviePilotTool):
|
||||
name: str = "query_subscribe_shares"
|
||||
description: str = "Query shared subscriptions from other users. Shows popular subscriptions shared by the community with filtering and pagination support."
|
||||
args_schema: Type[BaseModel] = QuerySubscribeSharesInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
name = kwargs.get("name")
|
||||
page = kwargs.get("page", 1)
|
||||
min_rating = kwargs.get("min_rating")
|
||||
max_rating = kwargs.get("max_rating")
|
||||
|
||||
parts = ["正在查询订阅分享"]
|
||||
|
||||
if name:
|
||||
parts.append(f"名称: {name}")
|
||||
if min_rating:
|
||||
parts.append(f"最低评分: {min_rating}")
|
||||
if max_rating:
|
||||
parts.append(f"最高评分: {max_rating}")
|
||||
if page > 1:
|
||||
parts.append(f"第{page}页")
|
||||
|
||||
return " | ".join(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
async def run(self, name: Optional[str] = None,
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
genre_id: Optional[int] = None,
|
||||
min_rating: Optional[float] = None,
|
||||
max_rating: Optional[float] = None,
|
||||
sort_type: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: name={name}, page={page}, count={count}, genre_id={genre_id}, "
|
||||
f"min_rating={min_rating}, max_rating={max_rating}, sort_type={sort_type}")
|
||||
|
||||
try:
|
||||
if page is None or page < 1:
|
||||
page = 1
|
||||
if count is None or count < 1:
|
||||
count = 30
|
||||
|
||||
subscribe_helper = SubscribeHelper()
|
||||
shares = await subscribe_helper.async_get_shares(
|
||||
name=name,
|
||||
page=page,
|
||||
count=count,
|
||||
genre_id=genre_id,
|
||||
min_rating=min_rating,
|
||||
max_rating=max_rating,
|
||||
sort_type=sort_type
|
||||
)
|
||||
|
||||
if not shares:
|
||||
return "未找到订阅分享数据(可能订阅分享功能未启用)"
|
||||
|
||||
# 简化字段,只保留关键信息
|
||||
simplified_shares = []
|
||||
for share in shares:
|
||||
simplified = {
|
||||
"id": share.get("id"),
|
||||
"name": share.get("name"),
|
||||
"year": share.get("year"),
|
||||
"type": share.get("type"),
|
||||
"season": share.get("season"),
|
||||
"tmdbid": share.get("tmdbid"),
|
||||
"doubanid": share.get("doubanid"),
|
||||
"bangumiid": share.get("bangumiid"),
|
||||
"poster": share.get("poster"),
|
||||
"vote": share.get("vote"),
|
||||
"share_title": share.get("share_title"),
|
||||
"share_comment": share.get("share_comment"),
|
||||
"share_user": share.get("share_user"),
|
||||
"fork_count": share.get("fork_count", 0)
|
||||
}
|
||||
# 截断过长的描述
|
||||
if simplified.get("description") and len(simplified["description"]) > 200:
|
||||
simplified["description"] = simplified["description"][:200] + "..."
|
||||
simplified_shares.append(simplified)
|
||||
|
||||
result_json = json.dumps(simplified_shares, ensure_ascii=False, indent=2)
|
||||
|
||||
pagination_info = f"第 {page} 页,每页 {count} 条,共 {len(simplified_shares)} 条结果"
|
||||
|
||||
return f"{pagination_info}\n\n{result_json}"
|
||||
except Exception as e:
|
||||
logger.error(f"查询订阅分享失败: {e}", exc_info=True)
|
||||
return f"查询订阅分享时发生错误: {str(e)}"
|
||||
|
||||
@@ -16,7 +16,7 @@ class QuerySubscribesInput(BaseModel):
|
||||
status: Optional[str] = Field("all",
|
||||
description="Filter subscriptions by status: 'R' for enabled subscriptions, 'P' for disabled ones, 'all' for all subscriptions")
|
||||
media_type: Optional[str] = Field("all",
|
||||
description="Filter by media type: 'movie' for films, 'tv' for television series, 'all' for all types")
|
||||
description="Filter by media type: '电影' for films, '电视剧' for television series, 'all' for all types")
|
||||
|
||||
|
||||
class QuerySubscribesTool(MoviePilotTool):
|
||||
@@ -24,6 +24,24 @@ class QuerySubscribesTool(MoviePilotTool):
|
||||
description: str = "Query subscription status and list all user subscriptions. Shows active subscriptions, their download status, and configuration details."
|
||||
args_schema: Type[BaseModel] = QuerySubscribesInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
status = kwargs.get("status", "all")
|
||||
media_type = kwargs.get("media_type", "all")
|
||||
|
||||
parts = ["正在查询订阅"]
|
||||
|
||||
# 根据状态过滤条件生成提示
|
||||
if status != "all":
|
||||
status_map = {"R": "已启用", "P": "已禁用"}
|
||||
parts.append(f"状态: {status_map.get(status, status)}")
|
||||
|
||||
# 根据媒体类型过滤条件生成提示
|
||||
if media_type != "all":
|
||||
parts.append(f"类型: {media_type}")
|
||||
|
||||
return " | ".join(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
async def run(self, status: Optional[str] = "all", media_type: Optional[str] = "all", **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: status={status}, media_type={media_type}")
|
||||
try:
|
||||
@@ -37,9 +55,9 @@ class QuerySubscribesTool(MoviePilotTool):
|
||||
continue
|
||||
filtered_subscribes.append(sub)
|
||||
if filtered_subscribes:
|
||||
# 限制最多20条结果
|
||||
# 限制最多50条结果
|
||||
total_count = len(filtered_subscribes)
|
||||
limited_subscribes = filtered_subscribes[:20]
|
||||
limited_subscribes = filtered_subscribes[:50]
|
||||
# 精简字段,只保留关键信息
|
||||
simplified_subscribes = []
|
||||
for s in limited_subscribes:
|
||||
@@ -54,7 +72,6 @@ class QuerySubscribesTool(MoviePilotTool):
|
||||
"bangumiid": s.bangumiid,
|
||||
"poster": s.poster,
|
||||
"vote": s.vote,
|
||||
"description": s.description[:200] + "..." if s.description and len(s.description) > 200 else s.description,
|
||||
"state": s.state,
|
||||
"total_episode": s.total_episode,
|
||||
"lack_episode": s.lack_episode,
|
||||
@@ -64,8 +81,8 @@ class QuerySubscribesTool(MoviePilotTool):
|
||||
simplified_subscribes.append(simplified)
|
||||
result_json = json.dumps(simplified_subscribes, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 20:
|
||||
return f"注意:查询结果共找到 {total_count} 条,为节省上下文空间,仅显示前 20 条结果。\n\n{result_json}"
|
||||
if total_count > 50:
|
||||
return f"注意:查询结果共找到 {total_count} 条,为节省上下文空间,仅显示前 50 条结果。\n\n{result_json}"
|
||||
return result_json
|
||||
return "未找到相关订阅"
|
||||
except Exception as e:
|
||||
|
||||
133
app/agent/tools/impl/query_transfer_history.py
Normal file
133
app/agent/tools/impl/query_transfer_history.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""查询整理历史记录工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
import jieba
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.db import AsyncSessionFactory
|
||||
from app.db.models.transferhistory import TransferHistory
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QueryTransferHistoryInput(BaseModel):
|
||||
"""查询整理历史记录工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
title: Optional[str] = Field(None, description="Search by title (optional, supports partial match)")
|
||||
status: Optional[str] = Field("all",
|
||||
description="Filter by status: 'success' for successful transfers, 'failed' for failed transfers, 'all' for all records (default: 'all')")
|
||||
page: Optional[int] = Field(1, description="Page number for pagination (default: 1, each page contains 30 records)")
|
||||
|
||||
|
||||
class QueryTransferHistoryTool(MoviePilotTool):
|
||||
name: str = "query_transfer_history"
|
||||
description: str = "Query file transfer history records. Shows transfer status, source and destination paths, media information, and transfer details. Supports filtering by title and status."
|
||||
args_schema: Type[BaseModel] = QueryTransferHistoryInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
title = kwargs.get("title")
|
||||
status = kwargs.get("status", "all")
|
||||
page = kwargs.get("page", 1)
|
||||
|
||||
parts = ["正在查询整理历史"]
|
||||
|
||||
if title:
|
||||
parts.append(f"标题: {title}")
|
||||
if status != "all":
|
||||
status_map = {"success": "成功", "failed": "失败"}
|
||||
parts.append(f"状态: {status_map.get(status, status)}")
|
||||
if page > 1:
|
||||
parts.append(f"第{page}页")
|
||||
|
||||
return " | ".join(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
async def run(self, title: Optional[str] = None,
|
||||
status: Optional[str] = "all",
|
||||
page: Optional[int] = 1, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: title={title}, status={status}, page={page}")
|
||||
|
||||
try:
|
||||
# 处理状态参数
|
||||
status_bool = None
|
||||
if status == "success":
|
||||
status_bool = True
|
||||
elif status == "failed":
|
||||
status_bool = False
|
||||
|
||||
# 处理页码参数
|
||||
if page is None or page < 1:
|
||||
page = 1
|
||||
|
||||
# 每页记录数
|
||||
count = 50
|
||||
|
||||
# 获取数据库会话
|
||||
async with AsyncSessionFactory() as db:
|
||||
# 处理标题搜索
|
||||
if title:
|
||||
# 使用 jieba 分词处理标题
|
||||
words = jieba.cut(title, HMM=False)
|
||||
title_search = "%".join(words)
|
||||
# 查询记录
|
||||
result = await TransferHistory.async_list_by_title(
|
||||
db, title=title_search, page=page, count=count, status=status_bool
|
||||
)
|
||||
total = await TransferHistory.async_count_by_title(
|
||||
db, title=title_search, status=status_bool
|
||||
)
|
||||
else:
|
||||
# 查询所有记录
|
||||
result = await TransferHistory.async_list_by_page(
|
||||
db, page=page, count=count, status=status_bool
|
||||
)
|
||||
total = await TransferHistory.async_count(db, status=status_bool)
|
||||
|
||||
if not result:
|
||||
return "未找到相关整理历史记录"
|
||||
|
||||
# 转换为字典格式,只保留关键信息
|
||||
simplified_records = []
|
||||
for record in result:
|
||||
simplified = {
|
||||
"id": record.id,
|
||||
"title": record.title,
|
||||
"year": record.year,
|
||||
"type": record.type,
|
||||
"category": record.category,
|
||||
"seasons": record.seasons,
|
||||
"episodes": record.episodes,
|
||||
"src": record.src,
|
||||
"dest": record.dest,
|
||||
"mode": record.mode,
|
||||
"status": "成功" if record.status else "失败",
|
||||
"date": record.date,
|
||||
"downloader": record.downloader,
|
||||
"download_hash": record.download_hash
|
||||
}
|
||||
# 如果失败,添加错误信息
|
||||
if not record.status and record.errmsg:
|
||||
simplified["errmsg"] = record.errmsg
|
||||
# 添加媒体ID信息(如果有)
|
||||
if record.tmdbid:
|
||||
simplified["tmdbid"] = record.tmdbid
|
||||
if record.imdbid:
|
||||
simplified["imdbid"] = record.imdbid
|
||||
if record.doubanid:
|
||||
simplified["doubanid"] = record.doubanid
|
||||
simplified_records.append(simplified)
|
||||
|
||||
result_json = json.dumps(simplified_records, ensure_ascii=False, indent=2)
|
||||
|
||||
# 计算总页数
|
||||
total_pages = (total + count - 1) // count if total > 0 else 1
|
||||
|
||||
# 构建分页信息
|
||||
pagination_info = f"第 {page}/{total_pages} 页,共 {total} 条记录(每页 {count} 条)"
|
||||
|
||||
return f"{pagination_info}\n\n{result_json}"
|
||||
except Exception as e:
|
||||
logger.error(f"查询整理历史记录失败: {e}", exc_info=True)
|
||||
return f"查询整理历史记录时发生错误: {str(e)}"
|
||||
128
app/agent/tools/impl/query_workflows.py
Normal file
128
app/agent/tools/impl/query_workflows.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""查询工作流工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.db import AsyncSessionFactory
|
||||
from app.db.workflow_oper import WorkflowOper
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QueryWorkflowsInput(BaseModel):
|
||||
"""查询工作流工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
state: Optional[str] = Field("all", description="Filter workflows by state: 'W' for waiting, 'R' for running, 'P' for paused, 'S' for success, 'F' for failed, 'all' for all workflows (default: 'all')")
|
||||
name: Optional[str] = Field(None, description="Filter workflows by name (partial match, optional)")
|
||||
trigger_type: Optional[str] = Field("all", description="Filter workflows by trigger type: 'timer' for scheduled, 'event' for event-triggered, 'manual' for manual, 'all' for all types (default: 'all')")
|
||||
|
||||
|
||||
class QueryWorkflowsTool(MoviePilotTool):
|
||||
name: str = "query_workflows"
|
||||
description: str = "Query workflow list and status. Shows workflow name, description, trigger type, state, execution count, and other workflow details. Supports filtering by state, name, and trigger type."
|
||||
args_schema: Type[BaseModel] = QueryWorkflowsInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
state = kwargs.get("state", "all")
|
||||
name = kwargs.get("name")
|
||||
trigger_type = kwargs.get("trigger_type", "all")
|
||||
|
||||
parts = ["正在查询工作流"]
|
||||
|
||||
if state != "all":
|
||||
state_map = {"W": "等待", "R": "运行中", "P": "暂停", "S": "成功", "F": "失败"}
|
||||
parts.append(f"状态: {state_map.get(state, state)}")
|
||||
|
||||
if trigger_type != "all":
|
||||
trigger_map = {"timer": "定时触发", "event": "事件触发", "manual": "手动触发"}
|
||||
parts.append(f"触发类型: {trigger_map.get(trigger_type, trigger_type)}")
|
||||
|
||||
if name:
|
||||
parts.append(f"名称: {name}")
|
||||
|
||||
return " | ".join(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
async def run(self, state: Optional[str] = "all",
|
||||
name: Optional[str] = None,
|
||||
trigger_type: Optional[str] = "all", **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: state={state}, name={name}, trigger_type={trigger_type}")
|
||||
|
||||
try:
|
||||
# 获取数据库会话
|
||||
async with AsyncSessionFactory() as db:
|
||||
workflow_oper = WorkflowOper(db)
|
||||
workflows = await workflow_oper.async_list()
|
||||
|
||||
# 过滤工作流
|
||||
filtered_workflows = []
|
||||
for wf in workflows:
|
||||
# 按状态过滤
|
||||
if state != "all" and wf.state != state:
|
||||
continue
|
||||
|
||||
# 按触发类型过滤
|
||||
if trigger_type != "all":
|
||||
if trigger_type == "timer" and wf.trigger_type not in ["timer", None]:
|
||||
continue
|
||||
elif trigger_type == "event" and wf.trigger_type != "event":
|
||||
continue
|
||||
elif trigger_type == "manual" and wf.trigger_type != "manual":
|
||||
continue
|
||||
|
||||
# 按名称过滤(部分匹配)
|
||||
if name and wf.name and name.lower() not in wf.name.lower():
|
||||
continue
|
||||
|
||||
filtered_workflows.append(wf)
|
||||
|
||||
if not filtered_workflows:
|
||||
return "未找到相关工作流"
|
||||
|
||||
# 转换为字典格式,只保留关键信息
|
||||
simplified_workflows = []
|
||||
for wf in filtered_workflows:
|
||||
# 状态说明
|
||||
state_map = {
|
||||
"W": "等待",
|
||||
"R": "运行中",
|
||||
"P": "暂停",
|
||||
"S": "成功",
|
||||
"F": "失败"
|
||||
}
|
||||
state_desc = state_map.get(wf.state, wf.state)
|
||||
|
||||
# 触发类型说明
|
||||
trigger_type_map = {
|
||||
"timer": "定时触发",
|
||||
"event": "事件触发",
|
||||
"manual": "手动触发"
|
||||
}
|
||||
trigger_type_desc = trigger_type_map.get(wf.trigger_type, wf.trigger_type or "定时触发")
|
||||
|
||||
simplified = {
|
||||
"id": wf.id,
|
||||
"name": wf.name,
|
||||
"description": wf.description,
|
||||
"trigger_type": trigger_type_desc,
|
||||
"state": state_desc,
|
||||
"run_count": wf.run_count,
|
||||
"timer": wf.timer,
|
||||
"event_type": wf.event_type,
|
||||
"add_time": wf.add_time,
|
||||
"last_time": wf.last_time,
|
||||
"current_action": wf.current_action
|
||||
}
|
||||
# 如果有结果,添加结果信息
|
||||
if wf.result:
|
||||
simplified["result"] = wf.result
|
||||
simplified_workflows.append(simplified)
|
||||
|
||||
result_json = json.dumps(simplified_workflows, ensure_ascii=False, indent=2)
|
||||
return result_json
|
||||
except Exception as e:
|
||||
logger.error(f"查询工作流失败: {e}", exc_info=True)
|
||||
return f"查询工作流时发生错误: {str(e)}"
|
||||
|
||||
162
app/agent/tools/impl/recognize_media.py
Normal file
162
app/agent/tools/impl/recognize_media.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""识别媒体信息工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.media import MediaChain
|
||||
from app.core.context import Context
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class RecognizeMediaInput(BaseModel):
|
||||
"""识别媒体信息工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
title: Optional[str] = Field(None, description="The title of the torrent/media to recognize (required for torrent recognition)")
|
||||
subtitle: Optional[str] = Field(None, description="The subtitle or description of the torrent (optional, helps improve recognition accuracy)")
|
||||
path: Optional[str] = Field(None, description="The file path to recognize (required for file recognition, mutually exclusive with title)")
|
||||
|
||||
|
||||
class RecognizeMediaTool(MoviePilotTool):
|
||||
name: str = "recognize_media"
|
||||
description: str = "Extract/identify media information from torrent titles or file paths (NOT database search). Supports two modes: 1) Extract from torrent title and optional subtitle, 2) Extract from file path. Returns detailed media information. Use 'search_media' to search TMDB database, or 'scrape_metadata' to generate metadata files for existing files."
|
||||
args_schema: Type[BaseModel] = RecognizeMediaInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据识别参数生成友好的提示消息"""
|
||||
title = kwargs.get("title")
|
||||
subtitle = kwargs.get("subtitle")
|
||||
path = kwargs.get("path")
|
||||
|
||||
if path:
|
||||
message = f"正在识别文件媒体信息: {path}"
|
||||
elif title:
|
||||
message = f"正在识别种子媒体信息: {title}"
|
||||
if subtitle:
|
||||
message += f" ({subtitle})"
|
||||
else:
|
||||
message = "正在识别媒体信息"
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, title: Optional[str] = None, subtitle: Optional[str] = None,
|
||||
path: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: title={title}, subtitle={subtitle}, path={path}")
|
||||
|
||||
try:
|
||||
media_chain = MediaChain()
|
||||
context = None
|
||||
|
||||
# 根据提供的参数选择识别方式
|
||||
if path:
|
||||
# 文件路径识别
|
||||
if not path:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": "文件路径不能为空"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
context = await media_chain.async_recognize_by_path(path)
|
||||
if context:
|
||||
return self._format_context_result(context, "文件")
|
||||
else:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"无法识别文件媒体信息: {path}",
|
||||
"path": path
|
||||
}, ensure_ascii=False)
|
||||
|
||||
elif title:
|
||||
# 种子标题识别
|
||||
metainfo = MetaInfo(title, subtitle)
|
||||
mediainfo = await media_chain.async_recognize_by_meta(metainfo)
|
||||
if mediainfo:
|
||||
context = Context(meta_info=metainfo, media_info=mediainfo)
|
||||
return self._format_context_result(context, "种子")
|
||||
else:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"无法识别种子媒体信息: {title}",
|
||||
"title": title,
|
||||
"subtitle": subtitle
|
||||
}, ensure_ascii=False)
|
||||
|
||||
else:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": "必须提供 title(标题)或 path(文件路径)参数之一"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"识别媒体信息失败: {str(e)}"
|
||||
logger.error(f"识别媒体信息失败: {e}", exc_info=True)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": error_message
|
||||
}, ensure_ascii=False)
|
||||
|
||||
def _format_context_result(self, context: Context, source_type: str) -> str:
|
||||
"""格式化识别结果为JSON字符串"""
|
||||
if not context:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": "识别结果为空"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
context_dict = context.to_dict()
|
||||
media_info = context_dict.get("media_info")
|
||||
meta_info = context_dict.get("meta_info")
|
||||
|
||||
# 构建简化的结果
|
||||
result = {
|
||||
"success": True,
|
||||
"source_type": source_type,
|
||||
"media_info": None,
|
||||
"meta_info": None
|
||||
}
|
||||
|
||||
# 处理媒体信息
|
||||
if media_info:
|
||||
result["media_info"] = {
|
||||
"title": media_info.get("title"),
|
||||
"en_title": media_info.get("en_title"),
|
||||
"year": media_info.get("year"),
|
||||
"type": media_info.get("type"),
|
||||
"season": media_info.get("season"),
|
||||
"tmdb_id": media_info.get("tmdb_id"),
|
||||
"imdb_id": media_info.get("imdb_id"),
|
||||
"douban_id": media_info.get("douban_id"),
|
||||
"bangumi_id": media_info.get("bangumi_id"),
|
||||
"overview": media_info.get("overview"),
|
||||
"vote_average": media_info.get("vote_average"),
|
||||
"poster_path": media_info.get("poster_path"),
|
||||
"backdrop_path": media_info.get("backdrop_path"),
|
||||
"detail_link": media_info.get("detail_link"),
|
||||
"title_year": media_info.get("title_year"),
|
||||
"source": media_info.get("source")
|
||||
}
|
||||
|
||||
# 处理元数据信息
|
||||
if meta_info:
|
||||
result["meta_info"] = {
|
||||
"name": meta_info.get("name"),
|
||||
"title": meta_info.get("title"),
|
||||
"year": meta_info.get("year"),
|
||||
"type": meta_info.get("type"),
|
||||
"begin_season": meta_info.get("begin_season"),
|
||||
"end_season": meta_info.get("end_season"),
|
||||
"begin_episode": meta_info.get("begin_episode"),
|
||||
"end_episode": meta_info.get("end_episode"),
|
||||
"total_episode": meta_info.get("total_episode"),
|
||||
"part": meta_info.get("part"),
|
||||
"season_episode": meta_info.get("season_episode"),
|
||||
"episode_list": meta_info.get("episode_list"),
|
||||
"tmdbid": meta_info.get("tmdbid"),
|
||||
"doubanid": meta_info.get("doubanid")
|
||||
}
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
53
app/agent/tools/impl/run_scheduler.py
Normal file
53
app/agent/tools/impl/run_scheduler.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""运行定时服务工具"""
|
||||
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.log import logger
|
||||
from app.scheduler import Scheduler
|
||||
|
||||
|
||||
class RunSchedulerInput(BaseModel):
|
||||
"""运行定时服务工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
job_id: str = Field(..., description="The ID of the scheduled job to run (can be obtained from query_schedulers tool)")
|
||||
|
||||
|
||||
class RunSchedulerTool(MoviePilotTool):
|
||||
name: str = "run_scheduler"
|
||||
description: str = "Manually trigger a scheduled task to run immediately. This will execute the specified scheduler job by its ID."
|
||||
args_schema: Type[BaseModel] = RunSchedulerInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据运行参数生成友好的提示消息"""
|
||||
job_id = kwargs.get("job_id", "")
|
||||
return f"正在运行定时服务 (ID: {job_id})"
|
||||
|
||||
async def run(self, job_id: str, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: job_id={job_id}")
|
||||
|
||||
try:
|
||||
scheduler = Scheduler()
|
||||
# 检查定时服务是否存在
|
||||
schedulers = scheduler.list()
|
||||
job_exists = False
|
||||
job_name = None
|
||||
for s in schedulers:
|
||||
if s.id == job_id:
|
||||
job_exists = True
|
||||
job_name = s.name
|
||||
break
|
||||
|
||||
if not job_exists:
|
||||
return f"定时服务 ID {job_id} 不存在,请使用 query_schedulers 工具查询可用的定时服务"
|
||||
|
||||
# 运行定时服务
|
||||
scheduler.start(job_id)
|
||||
|
||||
return f"成功触发定时服务:{job_name} (ID: {job_id})"
|
||||
except Exception as e:
|
||||
logger.error(f"运行定时服务失败: {e}", exc_info=True)
|
||||
return f"运行定时服务时发生错误: {str(e)}"
|
||||
|
||||
72
app/agent/tools/impl/run_workflow.py
Normal file
72
app/agent/tools/impl/run_workflow.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""执行工作流工具"""
|
||||
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.workflow import WorkflowChain
|
||||
from app.db import AsyncSessionFactory
|
||||
from app.db.workflow_oper import WorkflowOper
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class RunWorkflowInput(BaseModel):
|
||||
"""执行工作流工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
workflow_identifier: str = Field(..., description="Workflow identifier: can be workflow ID (integer as string) or workflow name")
|
||||
from_begin: Optional[bool] = Field(True, description="Whether to run workflow from the beginning (default: True, if False will continue from last executed action)")
|
||||
|
||||
|
||||
class RunWorkflowTool(MoviePilotTool):
|
||||
name: str = "run_workflow"
|
||||
description: str = "Execute a specific workflow manually. Can run workflow by ID or name. Supports running from the beginning or continuing from the last executed action."
|
||||
args_schema: Type[BaseModel] = RunWorkflowInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据工作流参数生成友好的提示消息"""
|
||||
workflow_identifier = kwargs.get("workflow_identifier", "")
|
||||
from_begin = kwargs.get("from_begin", True)
|
||||
|
||||
message = f"正在执行工作流: {workflow_identifier}"
|
||||
if not from_begin:
|
||||
message += " (从上次位置继续)"
|
||||
else:
|
||||
message += " (从头开始)"
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, workflow_identifier: str,
|
||||
from_begin: Optional[bool] = True, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: workflow_identifier={workflow_identifier}, from_begin={from_begin}")
|
||||
|
||||
try:
|
||||
# 获取数据库会话
|
||||
async with AsyncSessionFactory() as db:
|
||||
workflow_oper = WorkflowOper(db)
|
||||
|
||||
# 尝试解析为工作流ID
|
||||
workflow = None
|
||||
if workflow_identifier.isdigit():
|
||||
# 如果是数字,尝试作为工作流ID查询
|
||||
workflow = await workflow_oper.async_get(int(workflow_identifier))
|
||||
|
||||
# 如果不是ID或ID查询失败,尝试按名称查询
|
||||
if not workflow:
|
||||
workflow = await workflow_oper.async_get_by_name(workflow_identifier)
|
||||
|
||||
if not workflow:
|
||||
return f"未找到工作流:{workflow_identifier},请使用 query_workflows 工具查询可用的工作流"
|
||||
|
||||
# 执行工作流
|
||||
workflow_chain = WorkflowChain()
|
||||
state, errmsg = workflow_chain.process(workflow.id, from_begin=from_begin)
|
||||
|
||||
if not state:
|
||||
return f"执行工作流失败:{workflow.name} (ID: {workflow.id})\n错误原因:{errmsg}"
|
||||
else:
|
||||
return f"工作流执行成功:{workflow.name} (ID: {workflow.id})"
|
||||
except Exception as e:
|
||||
logger.error(f"执行工作流失败: {e}", exc_info=True)
|
||||
return f"执行工作流时发生错误: {str(e)}"
|
||||
|
||||
119
app/agent/tools/impl/scrape_metadata.py
Normal file
119
app/agent/tools/impl/scrape_metadata.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""刮削媒体元数据工具"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.media import MediaChain
|
||||
from app.core.config import global_vars
|
||||
from app.core.metainfo import MetaInfoPath
|
||||
from app.log import logger
|
||||
from app.schemas import FileItem
|
||||
|
||||
|
||||
class ScrapeMetadataInput(BaseModel):
|
||||
"""刮削媒体元数据工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
path: str = Field(...,
|
||||
description="Path to the file or directory to scrape metadata for (e.g., '/path/to/file.mkv' or '/path/to/directory')")
|
||||
storage: Optional[str] = Field("local",
|
||||
description="Storage type: 'local' for local storage, 'smb', 'alist', etc. for remote storage (default: 'local')")
|
||||
overwrite: Optional[bool] = Field(False,
|
||||
description="Whether to overwrite existing metadata files (default: False)")
|
||||
|
||||
|
||||
class ScrapeMetadataTool(MoviePilotTool):
|
||||
name: str = "scrape_metadata"
|
||||
description: str = "Generate metadata files (NFO files, posters, backgrounds, etc.) for existing media files or directories. Automatically recognizes media information from the file path and creates metadata files. Supports both local and remote storage. Use 'search_media' to search TMDB database, or 'recognize_media' to extract info from torrent titles/file paths without generating files."
|
||||
args_schema: Type[BaseModel] = ScrapeMetadataInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据刮削参数生成友好的提示消息"""
|
||||
path = kwargs.get("path", "")
|
||||
storage = kwargs.get("storage", "local")
|
||||
overwrite = kwargs.get("overwrite", False)
|
||||
|
||||
message = f"正在刮削媒体元数据: {path}"
|
||||
if storage != "local":
|
||||
message += f" [存储: {storage}]"
|
||||
if overwrite:
|
||||
message += " [覆盖模式]"
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, path: str, storage: Optional[str] = "local",
|
||||
overwrite: Optional[bool] = False, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: path={path}, storage={storage}, overwrite={overwrite}")
|
||||
|
||||
try:
|
||||
# 验证路径
|
||||
if not path:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": "刮削路径不能为空"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# 创建 FileItem
|
||||
fileitem = FileItem(
|
||||
storage=storage,
|
||||
path=path,
|
||||
type="file" if Path(path).suffix else "dir"
|
||||
)
|
||||
|
||||
# 检查本地存储路径是否存在
|
||||
if storage == "local":
|
||||
scrape_path = Path(path)
|
||||
if not scrape_path.exists():
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"刮削路径不存在: {path}"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# 识别媒体信息
|
||||
media_chain = MediaChain()
|
||||
scrape_path = Path(path)
|
||||
meta = MetaInfoPath(scrape_path)
|
||||
mediainfo = await media_chain.async_recognize_by_meta(meta)
|
||||
|
||||
if not mediainfo:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"刮削失败,无法识别媒体信息: {path}",
|
||||
"path": path
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# 在线程池中执行同步的刮削操作
|
||||
await global_vars.loop.run_in_executor(
|
||||
None,
|
||||
lambda: media_chain.scrape_metadata(
|
||||
fileitem=fileitem,
|
||||
meta=meta,
|
||||
mediainfo=mediainfo,
|
||||
overwrite=overwrite
|
||||
)
|
||||
)
|
||||
|
||||
return json.dumps({
|
||||
"success": True,
|
||||
"message": f"{path} 刮削完成",
|
||||
"path": path,
|
||||
"media_info": {
|
||||
"title": mediainfo.title,
|
||||
"year": mediainfo.year,
|
||||
"type": mediainfo.type.value if mediainfo.type else None,
|
||||
"tmdb_id": mediainfo.tmdb_id,
|
||||
"season": mediainfo.season
|
||||
}
|
||||
}, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"刮削媒体元数据失败: {str(e)}"
|
||||
logger.error(f"刮削媒体元数据失败: {e}", exc_info=True)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": error_message,
|
||||
"path": path
|
||||
}, ensure_ascii=False)
|
||||
@@ -24,9 +24,26 @@ class SearchMediaInput(BaseModel):
|
||||
|
||||
class SearchMediaTool(MoviePilotTool):
|
||||
name: str = "search_media"
|
||||
description: str = "Search for media resources including movies, TV shows, anime, etc. Supports searching by title, year, type, and other criteria. Returns detailed media information from TMDB database."
|
||||
description: str = "Search TMDB database for media resources (movies, TV shows, anime, etc.) by title, year, type, and other criteria. Returns detailed media information from TMDB. Use 'recognize_media' to extract info from torrent titles/file paths, or 'scrape_metadata' to generate metadata files."
|
||||
args_schema: Type[BaseModel] = SearchMediaInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据搜索参数生成友好的提示消息"""
|
||||
title = kwargs.get("title", "")
|
||||
year = kwargs.get("year")
|
||||
media_type = kwargs.get("media_type")
|
||||
season = kwargs.get("season")
|
||||
|
||||
message = f"正在搜索媒体: {title}"
|
||||
if year:
|
||||
message += f" ({year})"
|
||||
if media_type:
|
||||
message += f" [{media_type}]"
|
||||
if season:
|
||||
message += f" 第{season}季"
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, title: str, year: Optional[str] = None,
|
||||
media_type: Optional[str] = None, season: Optional[int] = None, **kwargs) -> str:
|
||||
logger.info(
|
||||
@@ -34,17 +51,8 @@ class SearchMediaTool(MoviePilotTool):
|
||||
|
||||
try:
|
||||
media_chain = MediaChain()
|
||||
# 构建搜索标题
|
||||
search_title = title
|
||||
if year:
|
||||
search_title = f"{title} {year}"
|
||||
if media_type:
|
||||
search_title = f"{search_title} {media_type}"
|
||||
if season:
|
||||
search_title = f"{search_title} S{season:02d}"
|
||||
|
||||
# 使用 MediaChain.search 方法
|
||||
meta, results = await media_chain.async_search(title=search_title)
|
||||
meta, results = await media_chain.async_search(title=title)
|
||||
|
||||
# 过滤结果
|
||||
if results:
|
||||
@@ -60,9 +68,9 @@ class SearchMediaTool(MoviePilotTool):
|
||||
filtered_results.append(result)
|
||||
|
||||
if filtered_results:
|
||||
# 限制最多20条结果
|
||||
# 限制最多30条结果
|
||||
total_count = len(filtered_results)
|
||||
limited_results = filtered_results[:20]
|
||||
limited_results = filtered_results[:30]
|
||||
# 精简字段,只保留关键信息
|
||||
simplified_results = []
|
||||
for r in limited_results:
|
||||
@@ -83,8 +91,8 @@ class SearchMediaTool(MoviePilotTool):
|
||||
simplified_results.append(simplified)
|
||||
result_json = json.dumps(simplified_results, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 20:
|
||||
return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 20 条结果。\n\n{result_json}"
|
||||
if total_count > 30:
|
||||
return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 30 条结果。\n\n{result_json}"
|
||||
return result_json
|
||||
else:
|
||||
return f"未找到符合条件的媒体资源: {title}"
|
||||
|
||||
83
app/agent/tools/impl/search_person.py
Normal file
83
app/agent/tools/impl/search_person.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""搜索人物工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.media import MediaChain
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class SearchPersonInput(BaseModel):
|
||||
"""搜索人物工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
name: str = Field(..., description="The name of the person to search for (e.g., 'Tom Hanks', '周杰伦')")
|
||||
|
||||
|
||||
class SearchPersonTool(MoviePilotTool):
|
||||
name: str = "search_person"
|
||||
description: str = "Search for person information including actors, directors, etc. Supports searching by name. Returns detailed person information from TMDB, Douban, or Bangumi database."
|
||||
args_schema: Type[BaseModel] = SearchPersonInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据搜索参数生成友好的提示消息"""
|
||||
name = kwargs.get("name", "")
|
||||
return f"正在搜索人物: {name}"
|
||||
|
||||
async def run(self, name: str, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: name={name}")
|
||||
|
||||
try:
|
||||
media_chain = MediaChain()
|
||||
# 使用 MediaChain.async_search_persons 方法搜索人物
|
||||
persons = await media_chain.async_search_persons(name=name)
|
||||
|
||||
if persons:
|
||||
# 限制最多30条结果
|
||||
total_count = len(persons)
|
||||
limited_persons = persons[:30]
|
||||
# 精简字段,只保留关键信息
|
||||
simplified_results = []
|
||||
for person in limited_persons:
|
||||
simplified = {
|
||||
"name": person.name,
|
||||
"id": person.id,
|
||||
"source": person.source,
|
||||
"profile_path": person.profile_path,
|
||||
"original_name": person.original_name,
|
||||
"known_for_department": person.known_for_department,
|
||||
"popularity": person.popularity,
|
||||
"biography": person.biography[:200] + "..." if person.biography and len(person.biography) > 200 else person.biography,
|
||||
"birthday": person.birthday,
|
||||
"deathday": person.deathday,
|
||||
"place_of_birth": person.place_of_birth,
|
||||
"gender": person.gender,
|
||||
"imdb_id": person.imdb_id,
|
||||
"also_known_as": person.also_known_as[:5] if person.also_known_as else [], # 限制别名数量
|
||||
}
|
||||
# 添加豆瓣特有字段
|
||||
if person.source == "douban":
|
||||
simplified["url"] = person.url
|
||||
simplified["avatar"] = person.avatar
|
||||
simplified["latin_name"] = person.latin_name
|
||||
simplified["roles"] = person.roles[:5] if person.roles else [] # 限制角色数量
|
||||
# 添加Bangumi特有字段
|
||||
if person.source == "bangumi":
|
||||
simplified["career"] = person.career
|
||||
simplified["relation"] = person.relation
|
||||
|
||||
simplified_results.append(simplified)
|
||||
|
||||
result_json = json.dumps(simplified_results, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 30:
|
||||
return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 30 条结果。\n\n{result_json}"
|
||||
return result_json
|
||||
else:
|
||||
return f"未找到相关人物信息: {name}"
|
||||
except Exception as e:
|
||||
error_message = f"搜索人物失败: {str(e)}"
|
||||
logger.error(f"搜索人物失败: {e}", exc_info=True)
|
||||
return error_message
|
||||
85
app/agent/tools/impl/search_person_credits.py
Normal file
85
app/agent/tools/impl/search_person_credits.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""搜索演员参演作品工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.douban import DoubanChain
|
||||
from app.chain.tmdb import TmdbChain
|
||||
from app.chain.bangumi import BangumiChain
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class SearchPersonCreditsInput(BaseModel):
|
||||
"""搜索演员参演作品工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
person_id: int = Field(..., description="The ID of the person/actor to search for credits (e.g., 31 for Tom Hanks in TMDB)")
|
||||
source: str = Field(..., description="The data source: 'tmdb' for TheMovieDB, 'douban' for Douban, 'bangumi' for Bangumi")
|
||||
page: Optional[int] = Field(1, description="Page number for pagination (default: 1)")
|
||||
|
||||
|
||||
class SearchPersonCreditsTool(MoviePilotTool):
|
||||
name: str = "search_person_credits"
|
||||
description: str = "Search for films and TV shows that a person/actor has appeared in (filmography). Supports searching by person ID from TMDB, Douban, or Bangumi database. Returns a list of media works the person has participated in."
|
||||
args_schema: Type[BaseModel] = SearchPersonCreditsInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据搜索参数生成友好的提示消息"""
|
||||
person_id = kwargs.get("person_id", "")
|
||||
source = kwargs.get("source", "")
|
||||
return f"正在搜索人物参演作品: {source} ID {person_id}"
|
||||
|
||||
async def run(self, person_id: int, source: str, page: Optional[int] = 1, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: person_id={person_id}, source={source}, page={page}")
|
||||
|
||||
try:
|
||||
# 根据source选择相应的chain
|
||||
if source.lower() == "tmdb":
|
||||
tmdb_chain = TmdbChain()
|
||||
medias = await tmdb_chain.async_person_credits(person_id=person_id, page=page)
|
||||
elif source.lower() == "douban":
|
||||
douban_chain = DoubanChain()
|
||||
medias = await douban_chain.async_person_credits(person_id=person_id, page=page)
|
||||
elif source.lower() == "bangumi":
|
||||
bangumi_chain = BangumiChain()
|
||||
medias = await bangumi_chain.async_person_credits(person_id=person_id)
|
||||
else:
|
||||
return f"不支持的数据源: {source}。支持的数据源: tmdb, douban, bangumi"
|
||||
|
||||
if medias:
|
||||
# 限制最多30条结果
|
||||
total_count = len(medias)
|
||||
limited_medias = medias[:30]
|
||||
# 精简字段,只保留关键信息
|
||||
simplified_results = []
|
||||
for media in limited_medias:
|
||||
simplified = {
|
||||
"title": media.title,
|
||||
"en_title": media.en_title,
|
||||
"year": media.year,
|
||||
"type": media.type.value if media.type else None,
|
||||
"season": media.season,
|
||||
"tmdb_id": media.tmdb_id,
|
||||
"imdb_id": media.imdb_id,
|
||||
"douban_id": media.douban_id,
|
||||
"overview": media.overview[:200] + "..." if media.overview and len(media.overview) > 200 else media.overview,
|
||||
"vote_average": media.vote_average,
|
||||
"poster_path": media.poster_path,
|
||||
"backdrop_path": media.backdrop_path,
|
||||
"detail_link": media.detail_link
|
||||
}
|
||||
simplified_results.append(simplified)
|
||||
|
||||
result_json = json.dumps(simplified_results, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 30:
|
||||
return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 30 条结果。\n\n{result_json}"
|
||||
return result_json
|
||||
else:
|
||||
return f"未找到人物 ID {person_id} ({source}) 的参演作品"
|
||||
except Exception as e:
|
||||
error_message = f"搜索演员参演作品失败: {str(e)}"
|
||||
logger.error(f"搜索演员参演作品失败: {e}", exc_info=True)
|
||||
return error_message
|
||||
127
app/agent/tools/impl/search_subscribe.py
Normal file
127
app/agent/tools/impl/search_subscribe.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""搜索订阅缺失剧集工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.subscribe import SubscribeChain
|
||||
from app.core.config import global_vars
|
||||
from app.db.subscribe_oper import SubscribeOper
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class SearchSubscribeInput(BaseModel):
|
||||
"""搜索订阅缺失剧集工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
subscribe_id: int = Field(..., description="The ID of the subscription to search for missing episodes")
|
||||
manual: Optional[bool] = Field(False, description="Whether this is a manual search (default: False)")
|
||||
filter_groups: Optional[List[str]] = Field(None,
|
||||
description="List of filter rule group names to apply for this search (optional, use query_rule_groups tool to get available rule groups. If provided, will temporarily update the subscription's filter groups before searching)")
|
||||
|
||||
|
||||
class SearchSubscribeTool(MoviePilotTool):
|
||||
name: str = "search_subscribe"
|
||||
description: str = "Search for missing episodes/resources for a specific subscription. This tool will search torrent sites for the missing episodes of the subscription and automatically download matching resources. Use this when a user wants to search for missing episodes of a specific subscription."
|
||||
args_schema: Type[BaseModel] = SearchSubscribeInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据搜索参数生成友好的提示消息"""
|
||||
subscribe_id = kwargs.get("subscribe_id")
|
||||
manual = kwargs.get("manual", False)
|
||||
|
||||
message = f"正在搜索订阅 #{subscribe_id} 的缺失剧集"
|
||||
if manual:
|
||||
message += "(手动搜索)"
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, subscribe_id: int, manual: Optional[bool] = False,
|
||||
filter_groups: Optional[List[str]] = None, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: subscribe_id={subscribe_id}, manual={manual}, filter_groups={filter_groups}")
|
||||
|
||||
try:
|
||||
# 先验证订阅是否存在
|
||||
subscribe_oper = SubscribeOper()
|
||||
subscribe = subscribe_oper.get(subscribe_id)
|
||||
|
||||
if not subscribe:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"订阅不存在: {subscribe_id}"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# 获取订阅信息用于返回
|
||||
subscribe_info = {
|
||||
"id": subscribe.id,
|
||||
"name": subscribe.name,
|
||||
"year": subscribe.year,
|
||||
"type": subscribe.type,
|
||||
"season": subscribe.season,
|
||||
"state": subscribe.state,
|
||||
"total_episode": subscribe.total_episode,
|
||||
"lack_episode": subscribe.lack_episode,
|
||||
"tmdbid": subscribe.tmdbid,
|
||||
"doubanid": subscribe.doubanid
|
||||
}
|
||||
|
||||
# 检查订阅状态
|
||||
if subscribe.state == "S":
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"订阅 #{subscribe_id} ({subscribe.name}) 已暂停,无法搜索",
|
||||
"subscribe": subscribe_info
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# 如果提供了 filter_groups 参数,先更新订阅的规则组
|
||||
if filter_groups is not None:
|
||||
subscribe_oper.update(subscribe_id, {"filter_groups": filter_groups})
|
||||
logger.info(f"更新订阅 #{subscribe_id} 的规则组为: {filter_groups}")
|
||||
|
||||
# 调用 SubscribeChain 的 search 方法
|
||||
# search 方法是同步的,需要在异步环境中运行
|
||||
subscribe_chain = SubscribeChain()
|
||||
|
||||
# 在线程池中执行同步的搜索操作
|
||||
# 当 sid 有值时,state 参数会被忽略,直接处理该订阅
|
||||
await global_vars.loop.run_in_executor(
|
||||
None,
|
||||
lambda: subscribe_chain.search(
|
||||
sid=subscribe_id,
|
||||
state='R', # 默认状态,当 sid 有值时此参数会被忽略
|
||||
manual=manual
|
||||
)
|
||||
)
|
||||
|
||||
# 重新获取订阅信息以获取更新后的状态
|
||||
updated_subscribe = subscribe_oper.get(subscribe_id)
|
||||
if updated_subscribe:
|
||||
subscribe_info.update({
|
||||
"state": updated_subscribe.state,
|
||||
"lack_episode": updated_subscribe.lack_episode,
|
||||
"last_update": updated_subscribe.last_update,
|
||||
"filter_groups": updated_subscribe.filter_groups
|
||||
})
|
||||
|
||||
# 如果提供了规则组,会在返回信息中显示
|
||||
result = {
|
||||
"success": True,
|
||||
"message": f"订阅 #{subscribe_id} ({subscribe.name}) 搜索完成",
|
||||
"subscribe": subscribe_info
|
||||
}
|
||||
|
||||
if filter_groups is not None:
|
||||
result["message"] += f"(已应用规则组: {', '.join(filter_groups)})"
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"搜索订阅缺失剧集失败: {str(e)}"
|
||||
logger.error(f"搜索订阅缺失剧集失败: {e}", exc_info=True)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": error_message,
|
||||
"subscribe_id": subscribe_id
|
||||
}, ensure_ascii=False)
|
||||
@@ -33,6 +33,26 @@ class SearchTorrentsTool(MoviePilotTool):
|
||||
description: str = "Search for torrent files across configured indexer sites based on media information. Returns available torrent downloads with details like file size, quality, and download links."
|
||||
args_schema: Type[BaseModel] = SearchTorrentsInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据搜索参数生成友好的提示消息"""
|
||||
title = kwargs.get("title", "")
|
||||
year = kwargs.get("year")
|
||||
media_type = kwargs.get("media_type")
|
||||
season = kwargs.get("season")
|
||||
filter_pattern = kwargs.get("filter_pattern")
|
||||
|
||||
message = f"正在搜索种子: {title}"
|
||||
if year:
|
||||
message += f" ({year})"
|
||||
if media_type:
|
||||
message += f" [{media_type}]"
|
||||
if season:
|
||||
message += f" 第{season}季"
|
||||
if filter_pattern:
|
||||
message += f" 过滤: {filter_pattern}"
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, title: str, year: Optional[str] = None,
|
||||
media_type: Optional[str] = None, season: Optional[int] = None,
|
||||
sites: Optional[List[int]] = None, filter_pattern: Optional[str] = None, **kwargs) -> str:
|
||||
|
||||
193
app/agent/tools/impl/search_web.py
Normal file
193
app/agent/tools/impl/search_web.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""搜索网络内容工具"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
from app.utils.http import AsyncRequestUtils
|
||||
|
||||
|
||||
class SearchWebInput(BaseModel):
|
||||
"""搜索网络内容工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
query: str = Field(..., description="The search query string to search for on the web")
|
||||
max_results: Optional[int] = Field(5, description="Maximum number of search results to return (default: 5, max: 10)")
|
||||
|
||||
|
||||
class SearchWebTool(MoviePilotTool):
|
||||
name: str = "search_web"
|
||||
description: str = "Search the web for information when you need to find current information, facts, or references that you're uncertain about. Returns search results with titles, snippets, and URLs. Use this tool to get up-to-date information from the internet."
|
||||
args_schema: Type[BaseModel] = SearchWebInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据搜索参数生成友好的提示消息"""
|
||||
query = kwargs.get("query", "")
|
||||
max_results = kwargs.get("max_results", 5)
|
||||
return f"正在搜索网络内容: {query} (最多返回 {max_results} 条结果)"
|
||||
|
||||
async def run(self, query: str, max_results: Optional[int] = 5, **kwargs) -> str:
|
||||
"""
|
||||
执行网络搜索
|
||||
|
||||
Args:
|
||||
query: 搜索查询字符串
|
||||
max_results: 最大返回结果数(默认5,最大10)
|
||||
|
||||
Returns:
|
||||
格式化的搜索结果JSON字符串
|
||||
"""
|
||||
logger.info(f"执行工具: {self.name}, 参数: query={query}, max_results={max_results}")
|
||||
|
||||
try:
|
||||
# 限制最大结果数
|
||||
max_results = min(max(1, max_results or 5), 10)
|
||||
|
||||
# 使用DuckDuckGo API进行搜索
|
||||
search_results = await self._search_duckduckgo_api(query, max_results)
|
||||
|
||||
if not search_results:
|
||||
return f"未找到与 '{query}' 相关的搜索结果"
|
||||
|
||||
# 裁剪结果以避免占用过多上下文
|
||||
formatted_results = self._format_and_truncate_results(search_results, max_results)
|
||||
|
||||
result_json = json.dumps(formatted_results, ensure_ascii=False, indent=2)
|
||||
return result_json
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"搜索网络内容失败: {str(e)}"
|
||||
logger.error(f"搜索网络内容失败: {e}", exc_info=True)
|
||||
return error_message
|
||||
|
||||
@staticmethod
|
||||
async def _search_duckduckgo_api(query: str, max_results: int) -> list:
|
||||
"""
|
||||
使用DuckDuckGo API进行搜索
|
||||
|
||||
Args:
|
||||
query: 搜索查询
|
||||
max_results: 最大结果数
|
||||
|
||||
Returns:
|
||||
搜索结果列表
|
||||
"""
|
||||
try:
|
||||
# DuckDuckGo Instant Answer API
|
||||
api_url = "https://api.duckduckgo.com/"
|
||||
params = {
|
||||
"q": query,
|
||||
"format": "json",
|
||||
"no_html": "1",
|
||||
"skip_disambig": "1"
|
||||
}
|
||||
|
||||
# 使用代理(如果配置了)
|
||||
http_utils = AsyncRequestUtils(
|
||||
proxies=settings.PROXY,
|
||||
timeout=10
|
||||
)
|
||||
|
||||
data = await http_utils.get_json(api_url, params=params)
|
||||
|
||||
results = []
|
||||
|
||||
if data:
|
||||
# 处理AbstractText(摘要)
|
||||
if data.get("AbstractText"):
|
||||
results.append({
|
||||
"title": data.get("Heading", query),
|
||||
"snippet": data.get("AbstractText", ""),
|
||||
"url": data.get("AbstractURL", ""),
|
||||
"source": "DuckDuckGo Abstract"
|
||||
})
|
||||
|
||||
# 处理RelatedTopics(相关主题)
|
||||
related_topics = data.get("RelatedTopics", [])
|
||||
for topic in related_topics[:max_results - len(results)]:
|
||||
if isinstance(topic, dict):
|
||||
text = topic.get("Text", "")
|
||||
first_url = topic.get("FirstURL", "")
|
||||
if text and first_url:
|
||||
# 提取标题(通常在" - "之前)
|
||||
title = text.split(" - ")[0] if " - " in text else text[:100]
|
||||
snippet = text
|
||||
|
||||
results.append({
|
||||
"title": title.strip(),
|
||||
"snippet": snippet,
|
||||
"url": first_url,
|
||||
"source": "DuckDuckGo Related"
|
||||
})
|
||||
|
||||
# 处理Results(搜索结果)
|
||||
api_results = data.get("Results", [])
|
||||
for result in api_results[:max_results - len(results)]:
|
||||
if isinstance(result, dict):
|
||||
title = result.get("Text", "")
|
||||
url = result.get("FirstURL", "")
|
||||
if title and url:
|
||||
results.append({
|
||||
"title": title,
|
||||
"snippet": result.get("Text", ""),
|
||||
"url": url,
|
||||
"source": "DuckDuckGo Results"
|
||||
})
|
||||
|
||||
return results[:max_results]
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"DuckDuckGo API搜索失败: {e}")
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _format_and_truncate_results(results: list, max_results: int) -> dict:
|
||||
"""
|
||||
格式化并裁剪搜索结果以避免占用过多上下文
|
||||
|
||||
Args:
|
||||
results: 原始搜索结果列表
|
||||
max_results: 最大结果数
|
||||
|
||||
Returns:
|
||||
格式化后的结果字典
|
||||
"""
|
||||
formatted = {
|
||||
"total_results": len(results),
|
||||
"results": []
|
||||
}
|
||||
|
||||
# 限制结果数量
|
||||
limited_results = results[:max_results]
|
||||
|
||||
for idx, result in enumerate(limited_results, 1):
|
||||
title = result.get("title", "")[:200] # 限制标题长度
|
||||
snippet = result.get("snippet", "")
|
||||
url = result.get("url", "")
|
||||
source = result.get("source", "Unknown")
|
||||
|
||||
# 裁剪摘要,避免过长
|
||||
max_snippet_length = 300 # 每个摘要最多300字符
|
||||
if len(snippet) > max_snippet_length:
|
||||
snippet = snippet[:max_snippet_length] + "..."
|
||||
|
||||
# 清理文本,移除多余的空白字符
|
||||
snippet = re.sub(r'\s+', ' ', snippet).strip()
|
||||
|
||||
formatted["results"].append({
|
||||
"rank": idx,
|
||||
"title": title,
|
||||
"snippet": snippet,
|
||||
"url": url,
|
||||
"source": source
|
||||
})
|
||||
|
||||
# 添加提示信息
|
||||
if len(results) > max_results:
|
||||
formatted["note"] = f"注意:共找到 {len(results)} 条结果,为节省上下文空间,仅显示前 {max_results} 条结果。"
|
||||
|
||||
return formatted
|
||||
@@ -21,6 +21,20 @@ class SendMessageTool(MoviePilotTool):
|
||||
description: str = "Send notification message to the user through configured notification channels (Telegram, Slack, WeChat, etc.). Used to inform users about operation results, errors, or important updates."
|
||||
args_schema: Type[BaseModel] = SendMessageInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据消息参数生成友好的提示消息"""
|
||||
message = kwargs.get("message", "")
|
||||
message_type = kwargs.get("message_type", "info")
|
||||
|
||||
type_map = {"info": "信息", "success": "成功", "warning": "警告", "error": "错误"}
|
||||
type_desc = type_map.get(message_type, message_type)
|
||||
|
||||
# 截断过长的消息
|
||||
if len(message) > 50:
|
||||
message = message[:50] + "..."
|
||||
|
||||
return f"正在发送{type_desc}消息: {message}"
|
||||
|
||||
async def run(self, message: str, message_type: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: message={message}, message_type={message_type}")
|
||||
try:
|
||||
|
||||
72
app/agent/tools/impl/test_site.py
Normal file
72
app/agent/tools/impl/test_site.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""测试站点连通性工具"""
|
||||
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.site import SiteChain
|
||||
from app.db.site_oper import SiteOper
|
||||
from app.log import logger
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
|
||||
class TestSiteInput(BaseModel):
|
||||
"""测试站点连通性工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
site_identifier: str = Field(..., description="Site identifier: can be site ID (integer as string), site name, or site domain/URL")
|
||||
|
||||
|
||||
class TestSiteTool(MoviePilotTool):
|
||||
name: str = "test_site"
|
||||
description: str = "Test site connectivity and availability. This will check if a site is accessible and can be logged in. Accepts site ID, site name, or site domain/URL as identifier."
|
||||
args_schema: Type[BaseModel] = TestSiteInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据测试参数生成友好的提示消息"""
|
||||
site_identifier = kwargs.get("site_identifier", "")
|
||||
return f"正在测试站点连通性: {site_identifier}"
|
||||
|
||||
async def run(self, site_identifier: str, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: site_identifier={site_identifier}")
|
||||
|
||||
try:
|
||||
site_oper = SiteOper()
|
||||
site_chain = SiteChain()
|
||||
|
||||
# 尝试解析为站点ID
|
||||
site = None
|
||||
if site_identifier.isdigit():
|
||||
# 如果是数字,尝试作为站点ID查询
|
||||
site = await site_oper.async_get(int(site_identifier))
|
||||
|
||||
# 如果不是ID或ID查询失败,尝试按名称或域名查询
|
||||
if not site:
|
||||
# 尝试按名称查询
|
||||
sites = await site_oper.async_list()
|
||||
for s in sites:
|
||||
if (site_identifier.lower() in (s.name or "").lower()) or \
|
||||
(site_identifier.lower() in (s.domain or "").lower()):
|
||||
site = s
|
||||
break
|
||||
|
||||
# 如果还是没找到,尝试从URL提取域名
|
||||
if not site:
|
||||
domain = StringUtils.get_url_domain(site_identifier)
|
||||
if domain:
|
||||
site = await site_oper.async_get_by_domain(domain)
|
||||
|
||||
if not site:
|
||||
return f"未找到站点:{site_identifier},请使用 query_sites 工具查询可用的站点"
|
||||
|
||||
# 测试站点连通性
|
||||
status, message = site_chain.test(site.domain)
|
||||
|
||||
if status:
|
||||
return f"站点连通性测试成功:{site.name} ({site.domain})\n{message}"
|
||||
else:
|
||||
return f"站点连通性测试失败:{site.name} ({site.domain})\n{message}"
|
||||
except Exception as e:
|
||||
logger.error(f"测试站点连通性失败: {e}", exc_info=True)
|
||||
return f"测试站点连通性时发生错误: {str(e)}"
|
||||
|
||||
134
app/agent/tools/impl/transfer_file.py
Normal file
134
app/agent/tools/impl/transfer_file.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""整理文件或目录工具"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.transfer import TransferChain
|
||||
from app.log import logger
|
||||
from app.schemas import FileItem, MediaType
|
||||
|
||||
|
||||
class TransferFileInput(BaseModel):
|
||||
"""整理文件或目录工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
file_path: str = Field(..., description="Path to the file or directory to transfer (e.g., '/path/to/file.mkv' or '/path/to/directory')")
|
||||
storage: Optional[str] = Field("local", description="Storage type of the source file (default: 'local', can be 'smb', 'alist', etc.)")
|
||||
target_path: Optional[str] = Field(None, description="Target path for the transferred file/directory (optional, uses default library path if not specified)")
|
||||
target_storage: Optional[str] = Field(None, description="Target storage type (optional, uses default storage if not specified)")
|
||||
media_type: Optional[str] = Field(None, description="Media type: '电影' for films, '电视剧' for television series (optional, will be auto-detected if not specified)")
|
||||
tmdbid: Optional[int] = Field(None, description="TMDB ID for precise media identification (optional but recommended for accuracy)")
|
||||
doubanid: Optional[str] = Field(None, description="Douban ID for media identification (optional)")
|
||||
season: Optional[int] = Field(None, description="Season number for TV shows (optional)")
|
||||
transfer_type: Optional[str] = Field(None, description="Transfer mode: 'move' to move files, 'copy' to copy files, 'link' for hard link, 'softlink' for symbolic link (optional, uses default mode if not specified)")
|
||||
background: Optional[bool] = Field(False, description="Whether to run transfer in background (default: False, runs synchronously)")
|
||||
|
||||
|
||||
class TransferFileTool(MoviePilotTool):
|
||||
name: str = "transfer_file"
|
||||
description: str = "Transfer/organize a file or directory to the media library. Automatically recognizes media information and organizes files according to configured rules. Supports custom target paths, media identification, and transfer modes."
|
||||
args_schema: Type[BaseModel] = TransferFileInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据整理参数生成友好的提示消息"""
|
||||
file_path = kwargs.get("file_path", "")
|
||||
media_type = kwargs.get("media_type")
|
||||
transfer_type = kwargs.get("transfer_type")
|
||||
background = kwargs.get("background", False)
|
||||
|
||||
message = f"正在整理文件: {file_path}"
|
||||
if media_type:
|
||||
message += f" [{media_type}]"
|
||||
if transfer_type:
|
||||
transfer_map = {"move": "移动", "copy": "复制", "link": "硬链接", "softlink": "软链接"}
|
||||
message += f" 模式: {transfer_map.get(transfer_type, transfer_type)}"
|
||||
if background:
|
||||
message += " [后台运行]"
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, file_path: str, storage: Optional[str] = "local",
|
||||
target_path: Optional[str] = None,
|
||||
target_storage: Optional[str] = None,
|
||||
media_type: Optional[str] = None,
|
||||
tmdbid: Optional[int] = None,
|
||||
doubanid: Optional[str] = None,
|
||||
season: Optional[int] = None,
|
||||
transfer_type: Optional[str] = None,
|
||||
background: Optional[bool] = False, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: file_path={file_path}, storage={storage}, target_path={target_path}, "
|
||||
f"target_storage={target_storage}, media_type={media_type}, tmdbid={tmdbid}, doubanid={doubanid}, "
|
||||
f"season={season}, transfer_type={transfer_type}, background={background}")
|
||||
|
||||
try:
|
||||
if not file_path:
|
||||
return "错误:必须提供文件或目录路径"
|
||||
|
||||
# 规范化路径
|
||||
if storage == "local":
|
||||
# 本地路径处理
|
||||
if not file_path.startswith("/") and not (len(file_path) > 1 and file_path[1] == ":"):
|
||||
# 相对路径,尝试转换为绝对路径
|
||||
file_path = str(Path(file_path).resolve())
|
||||
else:
|
||||
# 远程存储路径,确保以/开头
|
||||
if not file_path.startswith("/"):
|
||||
file_path = "/" + file_path
|
||||
|
||||
# 创建FileItem
|
||||
fileitem = FileItem(
|
||||
storage=storage or "local",
|
||||
path=file_path,
|
||||
type="dir" if file_path.endswith("/") else "file"
|
||||
)
|
||||
|
||||
# 处理目标路径
|
||||
target_path_obj = None
|
||||
if target_path:
|
||||
target_path_obj = Path(target_path)
|
||||
|
||||
# 处理媒体类型
|
||||
mtype = None
|
||||
if media_type:
|
||||
try:
|
||||
mtype = MediaType(media_type)
|
||||
except ValueError:
|
||||
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'"
|
||||
|
||||
# 调用整理方法
|
||||
transfer_chain = TransferChain()
|
||||
state, errormsg = transfer_chain.manual_transfer(
|
||||
fileitem=fileitem,
|
||||
target_storage=target_storage,
|
||||
target_path=target_path_obj,
|
||||
tmdbid=tmdbid,
|
||||
doubanid=doubanid,
|
||||
mtype=mtype,
|
||||
season=season,
|
||||
transfer_type=transfer_type,
|
||||
background=background
|
||||
)
|
||||
|
||||
if not state:
|
||||
# 处理错误信息
|
||||
if isinstance(errormsg, list):
|
||||
error_text = f"整理完成,{len(errormsg)} 个文件转移失败"
|
||||
if errormsg:
|
||||
error_text += f":\n" + "\n".join(str(e) for e in errormsg[:5]) # 只显示前5个错误
|
||||
if len(errormsg) > 5:
|
||||
error_text += f"\n... 还有 {len(errormsg) - 5} 个错误"
|
||||
else:
|
||||
error_text = str(errormsg)
|
||||
return f"整理失败:{error_text}"
|
||||
else:
|
||||
if background:
|
||||
return f"整理任务已提交到后台运行:{file_path}"
|
||||
else:
|
||||
return f"整理成功:{file_path}"
|
||||
except Exception as e:
|
||||
logger.error(f"整理文件失败: {e}", exc_info=True)
|
||||
return f"整理文件时发生错误: {str(e)}"
|
||||
|
||||
203
app/agent/tools/impl/update_site.py
Normal file
203
app/agent/tools/impl/update_site.py
Normal file
@@ -0,0 +1,203 @@
|
||||
"""更新站点工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.core.event import eventmanager
|
||||
from app.db import AsyncSessionFactory
|
||||
from app.db.models.site import Site
|
||||
from app.log import logger
|
||||
from app.schemas.types import EventType
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
|
||||
class UpdateSiteInput(BaseModel):
|
||||
"""更新站点工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
site_id: int = Field(..., description="The ID of the site to update")
|
||||
name: Optional[str] = Field(None, description="Site name (optional)")
|
||||
url: Optional[str] = Field(None, description="Site URL (optional, will be automatically formatted)")
|
||||
pri: Optional[int] = Field(None, description="Site priority (optional, smaller value = higher priority, e.g., pri=1 has higher priority than pri=10)")
|
||||
rss: Optional[str] = Field(None, description="RSS feed URL (optional)")
|
||||
cookie: Optional[str] = Field(None, description="Site cookie (optional)")
|
||||
ua: Optional[str] = Field(None, description="User-Agent string (optional)")
|
||||
apikey: Optional[str] = Field(None, description="API key (optional)")
|
||||
token: Optional[str] = Field(None, description="API token (optional)")
|
||||
proxy: Optional[int] = Field(None, description="Whether to use proxy: 0 for no, 1 for yes (optional)")
|
||||
filter: Optional[str] = Field(None, description="Filter rule as regular expression (optional)")
|
||||
note: Optional[str] = Field(None, description="Site notes/remarks (optional)")
|
||||
timeout: Optional[int] = Field(None, description="Request timeout in seconds (optional, default: 15)")
|
||||
limit_interval: Optional[int] = Field(None, description="Rate limit interval in seconds (optional)")
|
||||
limit_count: Optional[int] = Field(None, description="Rate limit count per interval (optional)")
|
||||
limit_seconds: Optional[int] = Field(None, description="Rate limit seconds between requests (optional)")
|
||||
is_active: Optional[bool] = Field(None, description="Whether site is active: True for enabled, False for disabled (optional)")
|
||||
downloader: Optional[str] = Field(None, description="Downloader name for this site (optional)")
|
||||
|
||||
|
||||
class UpdateSiteTool(MoviePilotTool):
|
||||
name: str = "update_site"
|
||||
description: str = "Update site configuration including URL, priority, authentication credentials (cookie, UA, API key), proxy settings, rate limits, and other site properties. Supports updating multiple site attributes at once. Site priority (pri): smaller values have higher priority (e.g., pri=1 has higher priority than pri=10)."
|
||||
args_schema: Type[BaseModel] = UpdateSiteInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据更新参数生成友好的提示消息"""
|
||||
site_id = kwargs.get("site_id")
|
||||
fields_updated = []
|
||||
|
||||
if kwargs.get("name"):
|
||||
fields_updated.append("名称")
|
||||
if kwargs.get("url"):
|
||||
fields_updated.append("URL")
|
||||
if kwargs.get("pri") is not None:
|
||||
fields_updated.append("优先级")
|
||||
if kwargs.get("cookie"):
|
||||
fields_updated.append("Cookie")
|
||||
if kwargs.get("ua"):
|
||||
fields_updated.append("User-Agent")
|
||||
if kwargs.get("proxy") is not None:
|
||||
fields_updated.append("代理设置")
|
||||
if kwargs.get("is_active") is not None:
|
||||
fields_updated.append("启用状态")
|
||||
if kwargs.get("downloader"):
|
||||
fields_updated.append("下载器")
|
||||
|
||||
if fields_updated:
|
||||
return f"正在更新站点 #{site_id}: {', '.join(fields_updated)}"
|
||||
return f"正在更新站点 #{site_id}"
|
||||
|
||||
async def run(self, site_id: int,
|
||||
name: Optional[str] = None,
|
||||
url: Optional[str] = None,
|
||||
pri: Optional[int] = None,
|
||||
rss: Optional[str] = None,
|
||||
cookie: Optional[str] = None,
|
||||
ua: Optional[str] = None,
|
||||
apikey: Optional[str] = None,
|
||||
token: Optional[str] = None,
|
||||
proxy: Optional[int] = None,
|
||||
filter: Optional[str] = None,
|
||||
note: Optional[str] = None,
|
||||
timeout: Optional[int] = None,
|
||||
limit_interval: Optional[int] = None,
|
||||
limit_count: Optional[int] = None,
|
||||
limit_seconds: Optional[int] = None,
|
||||
is_active: Optional[bool] = None,
|
||||
downloader: Optional[str] = None,
|
||||
**kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: site_id={site_id}")
|
||||
|
||||
try:
|
||||
# 获取数据库会话
|
||||
async with AsyncSessionFactory() as db:
|
||||
# 获取站点
|
||||
site = await Site.async_get(db, site_id)
|
||||
if not site:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"站点不存在: {site_id}"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# 构建更新字典
|
||||
site_dict = {}
|
||||
|
||||
# 基本信息
|
||||
if name is not None:
|
||||
site_dict["name"] = name
|
||||
|
||||
# URL处理(需要校正格式)
|
||||
if url is not None:
|
||||
_scheme, _netloc = StringUtils.get_url_netloc(url)
|
||||
site_dict["url"] = f"{_scheme}://{_netloc}/"
|
||||
|
||||
if pri is not None:
|
||||
site_dict["pri"] = pri
|
||||
if rss is not None:
|
||||
site_dict["rss"] = rss
|
||||
|
||||
# 认证信息
|
||||
if cookie is not None:
|
||||
site_dict["cookie"] = cookie
|
||||
if ua is not None:
|
||||
site_dict["ua"] = ua
|
||||
if apikey is not None:
|
||||
site_dict["apikey"] = apikey
|
||||
if token is not None:
|
||||
site_dict["token"] = token
|
||||
|
||||
# 配置选项
|
||||
if proxy is not None:
|
||||
site_dict["proxy"] = proxy
|
||||
if filter is not None:
|
||||
site_dict["filter"] = filter
|
||||
if note is not None:
|
||||
site_dict["note"] = note
|
||||
if timeout is not None:
|
||||
site_dict["timeout"] = timeout
|
||||
|
||||
# 流控设置
|
||||
if limit_interval is not None:
|
||||
site_dict["limit_interval"] = limit_interval
|
||||
if limit_count is not None:
|
||||
site_dict["limit_count"] = limit_count
|
||||
if limit_seconds is not None:
|
||||
site_dict["limit_seconds"] = limit_seconds
|
||||
|
||||
# 状态和下载器
|
||||
if is_active is not None:
|
||||
site_dict["is_active"] = is_active
|
||||
if downloader is not None:
|
||||
site_dict["downloader"] = downloader
|
||||
|
||||
# 如果没有要更新的字段
|
||||
if not site_dict:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": "没有提供要更新的字段"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# 更新站点
|
||||
await site.async_update(db, site_dict)
|
||||
|
||||
# 重新获取更新后的站点数据
|
||||
updated_site = await Site.async_get(db, site_id)
|
||||
|
||||
# 发送站点更新事件
|
||||
await eventmanager.async_send_event(EventType.SiteUpdated, {
|
||||
"domain": updated_site.domain if updated_site else site.domain
|
||||
})
|
||||
|
||||
# 构建返回结果
|
||||
result = {
|
||||
"success": True,
|
||||
"message": f"站点 #{site_id} 更新成功",
|
||||
"site_id": site_id,
|
||||
"updated_fields": list(site_dict.keys())
|
||||
}
|
||||
|
||||
if updated_site:
|
||||
result["site"] = {
|
||||
"id": updated_site.id,
|
||||
"name": updated_site.name,
|
||||
"domain": updated_site.domain,
|
||||
"url": updated_site.url,
|
||||
"pri": updated_site.pri,
|
||||
"is_active": updated_site.is_active,
|
||||
"downloader": updated_site.downloader,
|
||||
"proxy": updated_site.proxy,
|
||||
"timeout": updated_site.timeout
|
||||
}
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"更新站点失败: {str(e)}"
|
||||
logger.error(f"更新站点失败: {e}", exc_info=True)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": error_message,
|
||||
"site_id": site_id
|
||||
}, ensure_ascii=False)
|
||||
|
||||
88
app/agent/tools/impl/update_site_cookie.py
Normal file
88
app/agent/tools/impl/update_site_cookie.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""更新站点Cookie和UA工具"""
|
||||
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.site import SiteChain
|
||||
from app.db.site_oper import SiteOper
|
||||
from app.log import logger
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
|
||||
class UpdateSiteCookieInput(BaseModel):
|
||||
"""更新站点Cookie和UA工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
site_identifier: str = Field(..., description="Site identifier: can be site ID (integer as string), site name, or site domain/URL")
|
||||
username: str = Field(..., description="Site login username")
|
||||
password: str = Field(..., description="Site login password")
|
||||
two_step_code: Optional[str] = Field(None, description="Two-step verification code or secret key (optional, required for sites with 2FA enabled)")
|
||||
|
||||
|
||||
class UpdateSiteCookieTool(MoviePilotTool):
|
||||
name: str = "update_site_cookie"
|
||||
description: str = "Update site Cookie and User-Agent by logging in with username and password. This tool can automatically obtain and update the site's authentication credentials. Supports two-step verification for sites that require it. Accepts site ID, site name, or site domain/URL as identifier."
|
||||
args_schema: Type[BaseModel] = UpdateSiteCookieInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据更新参数生成友好的提示消息"""
|
||||
site_identifier = kwargs.get("site_identifier", "")
|
||||
username = kwargs.get("username", "")
|
||||
two_step_code = kwargs.get("two_step_code")
|
||||
|
||||
message = f"正在更新站点Cookie: {site_identifier} (用户: {username})"
|
||||
if two_step_code:
|
||||
message += " [需要两步验证]"
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, site_identifier: str, username: str, password: str,
|
||||
two_step_code: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: site_identifier={site_identifier}, username={username}")
|
||||
|
||||
try:
|
||||
site_oper = SiteOper()
|
||||
site_chain = SiteChain()
|
||||
|
||||
# 尝试解析为站点ID
|
||||
site = None
|
||||
if site_identifier.isdigit():
|
||||
# 如果是数字,尝试作为站点ID查询
|
||||
site = await site_oper.async_get(int(site_identifier))
|
||||
|
||||
# 如果不是ID或ID查询失败,尝试按名称或域名查询
|
||||
if not site:
|
||||
# 尝试按名称查询
|
||||
sites = await site_oper.async_list()
|
||||
for s in sites:
|
||||
if (site_identifier.lower() in (s.name or "").lower()) or \
|
||||
(site_identifier.lower() in (s.domain or "").lower()):
|
||||
site = s
|
||||
break
|
||||
|
||||
# 如果还是没找到,尝试从URL提取域名
|
||||
if not site:
|
||||
domain = StringUtils.get_url_domain(site_identifier)
|
||||
if domain:
|
||||
site = await site_oper.async_get_by_domain(domain)
|
||||
|
||||
if not site:
|
||||
return f"未找到站点:{site_identifier},请使用 query_sites 工具查询可用的站点"
|
||||
|
||||
# 更新站点Cookie和UA
|
||||
status, message = site_chain.update_cookie(
|
||||
site_info=site,
|
||||
username=username,
|
||||
password=password,
|
||||
two_step_code=two_step_code
|
||||
)
|
||||
|
||||
if status:
|
||||
return f"站点【{site.name}】Cookie和UA更新成功\n{message}"
|
||||
else:
|
||||
return f"站点【{site.name}】Cookie和UA更新失败\n错误原因:{message}"
|
||||
except Exception as e:
|
||||
logger.error(f"更新站点Cookie和UA失败: {e}", exc_info=True)
|
||||
return f"更新站点Cookie和UA时发生错误: {str(e)}"
|
||||
|
||||
239
app/agent/tools/impl/update_subscribe.py
Normal file
239
app/agent/tools/impl/update_subscribe.py
Normal file
@@ -0,0 +1,239 @@
|
||||
"""更新订阅工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.core.event import eventmanager
|
||||
from app.db import AsyncSessionFactory
|
||||
from app.db.models.subscribe import Subscribe
|
||||
from app.log import logger
|
||||
from app.schemas.types import EventType
|
||||
|
||||
|
||||
class UpdateSubscribeInput(BaseModel):
|
||||
"""更新订阅工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
subscribe_id: int = Field(..., description="The ID of the subscription to update")
|
||||
name: Optional[str] = Field(None, description="Subscription name/title (optional)")
|
||||
year: Optional[str] = Field(None, description="Release year (optional)")
|
||||
season: Optional[int] = Field(None, description="Season number for TV shows (optional)")
|
||||
total_episode: Optional[int] = Field(None, description="Total number of episodes (optional)")
|
||||
lack_episode: Optional[int] = Field(None, description="Number of missing episodes (optional)")
|
||||
start_episode: Optional[int] = Field(None, description="Starting episode number (optional)")
|
||||
quality: Optional[str] = Field(None, description="Quality filter as regular expression (optional, e.g., 'BluRay|WEB-DL|HDTV')")
|
||||
resolution: Optional[str] = Field(None, description="Resolution filter as regular expression (optional, e.g., '1080p|720p|2160p')")
|
||||
effect: Optional[str] = Field(None, description="Effect filter as regular expression (optional, e.g., 'HDR|DV|SDR')")
|
||||
include: Optional[str] = Field(None, description="Include filter as regular expression (optional)")
|
||||
exclude: Optional[str] = Field(None, description="Exclude filter as regular expression (optional)")
|
||||
filter: Optional[str] = Field(None, description="Filter rule as regular expression (optional)")
|
||||
state: Optional[str] = Field(None, description="Subscription state: 'R' for enabled, 'P' for disabled, 'S' for paused (optional)")
|
||||
sites: Optional[List[int]] = Field(None, description="List of site IDs to search from (optional)")
|
||||
downloader: Optional[str] = Field(None, description="Downloader name (optional)")
|
||||
save_path: Optional[str] = Field(None, description="Save path for downloaded files (optional)")
|
||||
best_version: Optional[int] = Field(None, description="Whether to upgrade to best version: 0 for no, 1 for yes (optional)")
|
||||
custom_words: Optional[str] = Field(None, description="Custom recognition words (optional)")
|
||||
media_category: Optional[str] = Field(None, description="Custom media category (optional)")
|
||||
episode_group: Optional[str] = Field(None, description="Episode group ID (optional)")
|
||||
|
||||
|
||||
class UpdateSubscribeTool(MoviePilotTool):
|
||||
name: str = "update_subscribe"
|
||||
description: str = "Update subscription properties including filters, episode counts, state, and other settings. Supports updating quality/resolution filters, episode tracking, subscription state, and download configuration."
|
||||
args_schema: Type[BaseModel] = UpdateSubscribeInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据更新参数生成友好的提示消息"""
|
||||
subscribe_id = kwargs.get("subscribe_id")
|
||||
fields_updated = []
|
||||
|
||||
if kwargs.get("name"):
|
||||
fields_updated.append("名称")
|
||||
if kwargs.get("total_episode") is not None:
|
||||
fields_updated.append("总集数")
|
||||
if kwargs.get("lack_episode") is not None:
|
||||
fields_updated.append("缺失集数")
|
||||
if kwargs.get("quality"):
|
||||
fields_updated.append("质量过滤")
|
||||
if kwargs.get("resolution"):
|
||||
fields_updated.append("分辨率过滤")
|
||||
if kwargs.get("state"):
|
||||
state_map = {"R": "启用", "P": "禁用", "S": "暂停"}
|
||||
fields_updated.append(f"状态({state_map.get(kwargs.get('state'), kwargs.get('state'))})")
|
||||
if kwargs.get("sites"):
|
||||
fields_updated.append("站点")
|
||||
if kwargs.get("downloader"):
|
||||
fields_updated.append("下载器")
|
||||
|
||||
if fields_updated:
|
||||
return f"正在更新订阅 #{subscribe_id}: {', '.join(fields_updated)}"
|
||||
return f"正在更新订阅 #{subscribe_id}"
|
||||
|
||||
async def run(self, subscribe_id: int,
|
||||
name: Optional[str] = None,
|
||||
year: Optional[str] = None,
|
||||
season: Optional[int] = None,
|
||||
total_episode: Optional[int] = None,
|
||||
lack_episode: Optional[int] = None,
|
||||
start_episode: Optional[int] = None,
|
||||
quality: Optional[str] = None,
|
||||
resolution: Optional[str] = None,
|
||||
effect: Optional[str] = None,
|
||||
include: Optional[str] = None,
|
||||
exclude: Optional[str] = None,
|
||||
filter: Optional[str] = None,
|
||||
state: Optional[str] = None,
|
||||
sites: Optional[List[int]] = None,
|
||||
downloader: Optional[str] = None,
|
||||
save_path: Optional[str] = None,
|
||||
best_version: Optional[int] = None,
|
||||
custom_words: Optional[str] = None,
|
||||
media_category: Optional[str] = None,
|
||||
episode_group: Optional[str] = None,
|
||||
**kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: subscribe_id={subscribe_id}")
|
||||
|
||||
try:
|
||||
# 获取数据库会话
|
||||
async with AsyncSessionFactory() as db:
|
||||
# 获取订阅
|
||||
subscribe = await Subscribe.async_get(db, subscribe_id)
|
||||
if not subscribe:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"订阅不存在: {subscribe_id}"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# 保存旧数据用于事件
|
||||
old_subscribe_dict = subscribe.to_dict()
|
||||
|
||||
# 构建更新字典
|
||||
subscribe_dict = {}
|
||||
|
||||
# 基本信息
|
||||
if name is not None:
|
||||
subscribe_dict["name"] = name
|
||||
if year is not None:
|
||||
subscribe_dict["year"] = year
|
||||
if season is not None:
|
||||
subscribe_dict["season"] = season
|
||||
|
||||
# 集数相关
|
||||
if total_episode is not None:
|
||||
subscribe_dict["total_episode"] = total_episode
|
||||
# 如果总集数增加,缺失集数也要相应增加
|
||||
if total_episode > (subscribe.total_episode or 0):
|
||||
old_lack = subscribe.lack_episode or 0
|
||||
subscribe_dict["lack_episode"] = old_lack + (total_episode - (subscribe.total_episode or 0))
|
||||
# 标记为手动修改过总集数
|
||||
subscribe_dict["manual_total_episode"] = 1
|
||||
|
||||
# 缺失集数处理(只有在没有提供总集数时才单独处理)
|
||||
# 注意:如果 lack_episode 为 0,不更新(避免更新为0)
|
||||
if lack_episode is not None and total_episode is None:
|
||||
if lack_episode > 0:
|
||||
subscribe_dict["lack_episode"] = lack_episode
|
||||
# 如果 lack_episode 为 0,不添加到更新字典中(保持原值或由总集数逻辑处理)
|
||||
|
||||
if start_episode is not None:
|
||||
subscribe_dict["start_episode"] = start_episode
|
||||
|
||||
# 过滤规则
|
||||
if quality is not None:
|
||||
subscribe_dict["quality"] = quality
|
||||
if resolution is not None:
|
||||
subscribe_dict["resolution"] = resolution
|
||||
if effect is not None:
|
||||
subscribe_dict["effect"] = effect
|
||||
if include is not None:
|
||||
subscribe_dict["include"] = include
|
||||
if exclude is not None:
|
||||
subscribe_dict["exclude"] = exclude
|
||||
if filter is not None:
|
||||
subscribe_dict["filter"] = filter
|
||||
|
||||
# 状态
|
||||
if state is not None:
|
||||
valid_states = ["R", "P", "S", "N"]
|
||||
if state not in valid_states:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"无效的订阅状态: {state},有效状态: {', '.join(valid_states)}"
|
||||
}, ensure_ascii=False)
|
||||
subscribe_dict["state"] = state
|
||||
|
||||
# 下载配置
|
||||
if sites is not None:
|
||||
subscribe_dict["sites"] = sites
|
||||
if downloader is not None:
|
||||
subscribe_dict["downloader"] = downloader
|
||||
if save_path is not None:
|
||||
subscribe_dict["save_path"] = save_path
|
||||
if best_version is not None:
|
||||
subscribe_dict["best_version"] = best_version
|
||||
|
||||
# 其他配置
|
||||
if custom_words is not None:
|
||||
subscribe_dict["custom_words"] = custom_words
|
||||
if media_category is not None:
|
||||
subscribe_dict["media_category"] = media_category
|
||||
if episode_group is not None:
|
||||
subscribe_dict["episode_group"] = episode_group
|
||||
|
||||
# 如果没有要更新的字段
|
||||
if not subscribe_dict:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": "没有提供要更新的字段"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# 更新订阅
|
||||
await subscribe.async_update(db, subscribe_dict)
|
||||
|
||||
# 重新获取更新后的订阅数据
|
||||
updated_subscribe = await Subscribe.async_get(db, subscribe_id)
|
||||
|
||||
# 发送订阅调整事件
|
||||
await eventmanager.async_send_event(EventType.SubscribeModified, {
|
||||
"subscribe_id": subscribe_id,
|
||||
"old_subscribe_info": old_subscribe_dict,
|
||||
"subscribe_info": updated_subscribe.to_dict() if updated_subscribe else {},
|
||||
})
|
||||
|
||||
# 构建返回结果
|
||||
result = {
|
||||
"success": True,
|
||||
"message": f"订阅 #{subscribe_id} 更新成功",
|
||||
"subscribe_id": subscribe_id,
|
||||
"updated_fields": list(subscribe_dict.keys())
|
||||
}
|
||||
|
||||
if updated_subscribe:
|
||||
result["subscribe"] = {
|
||||
"id": updated_subscribe.id,
|
||||
"name": updated_subscribe.name,
|
||||
"year": updated_subscribe.year,
|
||||
"type": updated_subscribe.type,
|
||||
"season": updated_subscribe.season,
|
||||
"state": updated_subscribe.state,
|
||||
"total_episode": updated_subscribe.total_episode,
|
||||
"lack_episode": updated_subscribe.lack_episode,
|
||||
"start_episode": updated_subscribe.start_episode,
|
||||
"quality": updated_subscribe.quality,
|
||||
"resolution": updated_subscribe.resolution,
|
||||
"effect": updated_subscribe.effect
|
||||
}
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"更新订阅失败: {str(e)}"
|
||||
logger.error(f"更新订阅失败: {e}", exc_info=True)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": error_message,
|
||||
"subscribe_id": subscribe_id
|
||||
}, ensure_ascii=False)
|
||||
|
||||
187
app/agent/tools/manager.py
Normal file
187
app/agent/tools/manager.py
Normal file
@@ -0,0 +1,187 @@
|
||||
"""MoviePilot工具管理器
|
||||
用于HTTP API调用工具
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.agent.tools.factory import MoviePilotToolFactory
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class ToolDefinition:
|
||||
"""工具定义"""
|
||||
|
||||
def __init__(self, name: str, description: str, input_schema: Dict[str, Any]):
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.input_schema = input_schema
|
||||
|
||||
|
||||
class MoviePilotToolsManager:
|
||||
"""MoviePilot工具管理器(用于HTTP API)"""
|
||||
|
||||
def __init__(self, user_id: str = "api_user", session_id: str = "api_session"):
|
||||
"""
|
||||
初始化工具管理器
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
session_id: 会话ID
|
||||
"""
|
||||
self.user_id = user_id
|
||||
self.session_id = session_id
|
||||
self.tools: List[Any] = []
|
||||
self._load_tools()
|
||||
|
||||
def _load_tools(self):
|
||||
"""加载所有MoviePilot工具"""
|
||||
try:
|
||||
# 创建工具实例
|
||||
self.tools = MoviePilotToolFactory.create_tools(
|
||||
session_id=self.session_id,
|
||||
user_id=self.user_id,
|
||||
channel=None,
|
||||
source="api",
|
||||
username="API Client",
|
||||
callback_handler=None
|
||||
)
|
||||
logger.info(f"成功加载 {len(self.tools)} 个工具")
|
||||
except Exception as e:
|
||||
logger.error(f"加载工具失败: {e}", exc_info=True)
|
||||
self.tools = []
|
||||
|
||||
def list_tools(self) -> List[ToolDefinition]:
|
||||
"""
|
||||
列出所有可用的工具
|
||||
|
||||
Returns:
|
||||
工具定义列表
|
||||
"""
|
||||
tools_list = []
|
||||
for tool in self.tools:
|
||||
# 获取工具的输入参数模型
|
||||
args_schema = getattr(tool, 'args_schema', None)
|
||||
if args_schema:
|
||||
# 将Pydantic模型转换为JSON Schema
|
||||
input_schema = self._convert_to_json_schema(args_schema)
|
||||
else:
|
||||
# 如果没有args_schema,使用基本信息
|
||||
input_schema = {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
|
||||
tools_list.append(ToolDefinition(
|
||||
name=tool.name,
|
||||
description=tool.description or "",
|
||||
input_schema=input_schema
|
||||
))
|
||||
|
||||
return tools_list
|
||||
|
||||
def get_tool(self, tool_name: str) -> Optional[Any]:
|
||||
"""
|
||||
获取指定工具实例
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
|
||||
Returns:
|
||||
工具实例,如果未找到返回None
|
||||
"""
|
||||
for tool in self.tools:
|
||||
if tool.name == tool_name:
|
||||
return tool
|
||||
return None
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> str:
|
||||
"""
|
||||
调用工具
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
arguments: 工具参数
|
||||
|
||||
Returns:
|
||||
工具执行结果(字符串)
|
||||
"""
|
||||
tool_instance = self.get_tool(tool_name)
|
||||
|
||||
if not tool_instance:
|
||||
error_msg = json.dumps({
|
||||
"error": f"工具 '{tool_name}' 未找到"
|
||||
}, ensure_ascii=False)
|
||||
return error_msg
|
||||
|
||||
try:
|
||||
# 调用工具的run方法
|
||||
result = await tool_instance.run(**arguments)
|
||||
|
||||
# 确保返回字符串
|
||||
if isinstance(result, str):
|
||||
return result
|
||||
else:
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
logger.error(f"调用工具 {tool_name} 时发生错误: {e}", exc_info=True)
|
||||
error_msg = json.dumps({
|
||||
"error": f"调用工具 '{tool_name}' 时发生错误: {str(e)}"
|
||||
}, ensure_ascii=False)
|
||||
return error_msg
|
||||
|
||||
@staticmethod
|
||||
def _convert_to_json_schema(args_schema: Any) -> Dict[str, Any]:
|
||||
"""
|
||||
将Pydantic模型转换为JSON Schema
|
||||
|
||||
Args:
|
||||
args_schema: Pydantic模型类
|
||||
|
||||
Returns:
|
||||
JSON Schema字典
|
||||
"""
|
||||
# 获取Pydantic模型的字段信息
|
||||
schema = args_schema.model_json_schema()
|
||||
|
||||
# 构建JSON Schema
|
||||
properties = {}
|
||||
required = []
|
||||
|
||||
if "properties" in schema:
|
||||
for field_name, field_info in schema["properties"].items():
|
||||
# 转换字段类型
|
||||
field_type = field_info.get("type", "string")
|
||||
field_description = field_info.get("description", "")
|
||||
|
||||
# 处理可选字段
|
||||
if field_name not in schema.get("required", []):
|
||||
# 可选字段
|
||||
default_value = field_info.get("default")
|
||||
properties[field_name] = {
|
||||
"type": field_type,
|
||||
"description": field_description
|
||||
}
|
||||
if default_value is not None:
|
||||
properties[field_name]["default"] = default_value
|
||||
else:
|
||||
properties[field_name] = {
|
||||
"type": field_type,
|
||||
"description": field_description
|
||||
}
|
||||
required.append(field_name)
|
||||
|
||||
# 处理枚举类型
|
||||
if "enum" in field_info:
|
||||
properties[field_name]["enum"] = field_info["enum"]
|
||||
|
||||
# 处理数组类型
|
||||
if field_type == "array" and "items" in field_info:
|
||||
properties[field_name]["items"] = field_info["items"]
|
||||
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required
|
||||
}
|
||||
@@ -2,7 +2,7 @@ from fastapi import APIRouter
|
||||
|
||||
from app.api.endpoints import login, user, webhook, message, site, subscribe, \
|
||||
media, douban, search, plugin, tmdb, history, system, download, dashboard, \
|
||||
transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent
|
||||
transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent, mcp
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(login.router, prefix="/login", tags=["login"])
|
||||
@@ -28,3 +28,4 @@ api_router.include_router(discover.router, prefix="/discover", tags=["discover"]
|
||||
api_router.include_router(recommend.router, prefix="/recommend", tags=["recommend"])
|
||||
api_router.include_router(workflow.router, prefix="/workflow", tags=["workflow"])
|
||||
api_router.include_router(torrent.router, prefix="/torrent", tags=["torrent"])
|
||||
api_router.include_router(mcp.router, prefix="/mcp", tags=["mcp"])
|
||||
|
||||
@@ -6,12 +6,13 @@ from app import schemas
|
||||
from app.chain.download import DownloadChain
|
||||
from app.chain.media import MediaChain
|
||||
from app.core.context import MediaInfo, Context, TorrentInfo
|
||||
from app.core.event import eventmanager
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.core.security import verify_token
|
||||
from app.db.models.user import User
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.db.user_oper import get_current_active_user
|
||||
from app.schemas.types import SystemConfigKey
|
||||
from app.schemas.types import ChainEventType, SystemConfigKey
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -66,8 +67,8 @@ def add(
|
||||
torrent_in: schemas.TorrentInfo,
|
||||
tmdbid: Annotated[int | None, Body()] = None,
|
||||
doubanid: Annotated[str | None, Body()] = None,
|
||||
bangumiid: Annotated[int | None, Body()] = None,
|
||||
downloader: Annotated[str | None, Body()] = None,
|
||||
# 保存路径, 支持<storage>:<path>, 如rclone:/MP, smb:/server/share/Movies等
|
||||
save_path: Annotated[str | None, Body()] = None,
|
||||
current_user: User = Depends(get_current_active_user)) -> Any:
|
||||
"""
|
||||
@@ -76,9 +77,13 @@ def add(
|
||||
# 元数据
|
||||
metainfo = MetaInfo(title=torrent_in.title, subtitle=torrent_in.description)
|
||||
# 媒体信息
|
||||
mediainfo = MediaChain().recognize_media(meta=metainfo, tmdbid=tmdbid, doubanid=doubanid, bangumiid=bangumiid)
|
||||
mediainfo = MediaChain().recognize_media(meta=metainfo, tmdbid=tmdbid, doubanid=doubanid)
|
||||
if not mediainfo:
|
||||
return schemas.Response(success=False, message="无法识别媒体信息")
|
||||
# 尝试使用辅助识别,如果有注册响应事件的话
|
||||
if eventmanager.check(ChainEventType.NameRecognize):
|
||||
mediainfo = MediaChain().recognize_help(title=torrent_in.title, org_meta=metainfo)
|
||||
if not mediainfo:
|
||||
return schemas.Response(success=False, message="无法识别媒体信息")
|
||||
# 种子信息
|
||||
torrentinfo = TorrentInfo()
|
||||
torrentinfo.from_dict(torrent_in.model_dump())
|
||||
@@ -88,6 +93,7 @@ def add(
|
||||
media_info=mediainfo,
|
||||
torrent_info=torrentinfo
|
||||
)
|
||||
|
||||
did = DownloadChain().download_single(context=context, username=current_user.name,
|
||||
downloader=downloader, save_path=save_path, source="Manual")
|
||||
if not did:
|
||||
|
||||
@@ -10,7 +10,7 @@ from app.core import security
|
||||
from app.core.config import settings
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.helper.sites import SitesHelper # noqa
|
||||
from app.helper.wallpaper import WallpaperHelper
|
||||
from app.helper.image import WallpaperHelper
|
||||
from app.schemas.types import SystemConfigKey
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
161
app/api/endpoints/mcp.py
Normal file
161
app/api/endpoints/mcp.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""工具API端点
|
||||
通过HTTP API暴露MoviePilot的智能体工具功能
|
||||
"""
|
||||
|
||||
from typing import List, Any, Dict, Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from app import schemas
|
||||
from app.agent.tools.manager import MoviePilotToolsManager
|
||||
from app.core.security import verify_apikey
|
||||
from app.log import logger
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# 全局工具管理器实例(单例模式,按用户ID缓存)
|
||||
_tools_managers: Dict[str, MoviePilotToolsManager] = {}
|
||||
|
||||
|
||||
def get_tools_manager(user_id: str = "mcp_user", session_id: str = "mcp_session") -> MoviePilotToolsManager:
|
||||
"""
|
||||
获取工具管理器实例(按用户ID缓存)
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
MoviePilotToolsManager实例
|
||||
"""
|
||||
global _tools_managers
|
||||
# 使用用户ID作为缓存键
|
||||
cache_key = f"{user_id}_{session_id}"
|
||||
if cache_key not in _tools_managers:
|
||||
_tools_managers[cache_key] = MoviePilotToolsManager(
|
||||
user_id=user_id,
|
||||
session_id=session_id
|
||||
)
|
||||
return _tools_managers[cache_key]
|
||||
|
||||
|
||||
@router.get("/tools", summary="列出所有可用工具", response_model=List[Dict[str, Any]])
|
||||
async def list_tools(
|
||||
_: Annotated[str, Depends(verify_apikey)]
|
||||
) -> Any:
|
||||
"""
|
||||
获取所有可用的工具列表
|
||||
|
||||
返回每个工具的名称、描述和参数定义
|
||||
"""
|
||||
try:
|
||||
manager = get_tools_manager()
|
||||
# 获取所有工具定义
|
||||
tools = manager.list_tools()
|
||||
|
||||
# 转换为字典格式
|
||||
tools_list = []
|
||||
for tool in tools:
|
||||
tool_dict = {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"inputSchema": tool.input_schema
|
||||
}
|
||||
tools_list.append(tool_dict)
|
||||
|
||||
return tools_list
|
||||
except Exception as e:
|
||||
logger.error(f"获取工具列表失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"获取工具列表失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/tools/call", summary="调用工具", response_model=schemas.ToolCallResponse)
|
||||
async def call_tool(
|
||||
request: schemas.ToolCallRequest,
|
||||
|
||||
) -> Any:
|
||||
"""
|
||||
调用指定的工具
|
||||
|
||||
Returns:
|
||||
工具执行结果
|
||||
"""
|
||||
try:
|
||||
# 使用当前用户ID创建管理器实例
|
||||
manager = get_tools_manager()
|
||||
|
||||
# 调用工具
|
||||
result_text = await manager.call_tool(request.tool_name, request.arguments)
|
||||
|
||||
return schemas.ToolCallResponse(
|
||||
success=True,
|
||||
result=result_text
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"调用工具 {request.tool_name} 失败: {e}", exc_info=True)
|
||||
return schemas.ToolCallResponse(
|
||||
success=False,
|
||||
error=f"调用工具失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/tools/{tool_name}", summary="获取工具详情", response_model=Dict[str, Any])
|
||||
async def get_tool_info(
|
||||
tool_name: str,
|
||||
_: Annotated[str, Depends(verify_apikey)]
|
||||
) -> Any:
|
||||
"""
|
||||
获取指定工具的详细信息
|
||||
|
||||
Returns:
|
||||
工具的详细信息,包括名称、描述和参数定义
|
||||
"""
|
||||
try:
|
||||
manager = get_tools_manager()
|
||||
# 获取所有工具
|
||||
tools = manager.list_tools()
|
||||
|
||||
# 查找指定工具
|
||||
for tool in tools:
|
||||
if tool.name == tool_name:
|
||||
return {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"inputSchema": tool.input_schema
|
||||
}
|
||||
|
||||
raise HTTPException(status_code=404, detail=f"工具 '{tool_name}' 未找到")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取工具信息失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"获取工具信息失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/tools/{tool_name}/schema", summary="获取工具参数Schema", response_model=Dict[str, Any])
|
||||
async def get_tool_schema(
|
||||
tool_name: str,
|
||||
_: Annotated[str, Depends(verify_apikey)]
|
||||
) -> Any:
|
||||
"""
|
||||
获取指定工具的参数Schema(JSON Schema格式)
|
||||
|
||||
Returns:
|
||||
工具的JSON Schema定义
|
||||
"""
|
||||
try:
|
||||
manager = get_tools_manager()
|
||||
# 获取所有工具
|
||||
tools = manager.list_tools()
|
||||
|
||||
# 查找指定工具
|
||||
for tool in tools:
|
||||
if tool.name == tool_name:
|
||||
return tool.input_schema
|
||||
|
||||
raise HTTPException(status_code=404, detail=f"工具 '{tool_name}' 未找到")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取工具Schema失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"获取工具Schema失败: {str(e)}")
|
||||
@@ -85,25 +85,26 @@ async def search(title: str,
|
||||
return obj.get("source")
|
||||
return obj.source
|
||||
|
||||
result = []
|
||||
media_chain = MediaChain()
|
||||
if type == "media":
|
||||
_, medias = await media_chain.async_search(title=title)
|
||||
if medias:
|
||||
result = [media.to_dict() for media in medias]
|
||||
result = [media.to_dict() for media in medias] if medias else []
|
||||
elif type == "collection":
|
||||
result = await media_chain.async_search_collections(name=title)
|
||||
else:
|
||||
result = await media_chain.async_search_persons(name=title)
|
||||
if result:
|
||||
# 按设置的顺序对结果进行排序
|
||||
setting_order = settings.SEARCH_SOURCE.split(',') or []
|
||||
sort_order = {}
|
||||
for index, source in enumerate(setting_order):
|
||||
sort_order[source] = index
|
||||
result = sorted(result, key=lambda x: sort_order.get(__get_source(x), 4))
|
||||
return result[(page - 1) * count:page * count]
|
||||
return []
|
||||
collections = await media_chain.async_search_collections(name=title)
|
||||
result = [collection.to_dict() for collection in collections] if collections else []
|
||||
else: # person
|
||||
persons = await media_chain.async_search_persons(name=title)
|
||||
result = [person.model_dump() for person in persons] if persons else []
|
||||
|
||||
if not result:
|
||||
return []
|
||||
|
||||
# 排序和分页
|
||||
setting_order = settings.SEARCH_SOURCE.split(',') if settings.SEARCH_SOURCE else []
|
||||
sort_order = {source: index for index, source in enumerate(setting_order)}
|
||||
|
||||
sorted_result = sorted(result, key=lambda x: sort_order.get(__get_source(x), 4))
|
||||
return sorted_result[(page - 1) * count:page * count]
|
||||
|
||||
|
||||
@router.post("/scrape/{storage}", summary="刮削媒体信息", response_model=schemas.Response)
|
||||
|
||||
@@ -219,10 +219,10 @@ async def read_userdata(
|
||||
status_code=404,
|
||||
detail=f"站点 {site_id} 不存在",
|
||||
)
|
||||
user_data = await SiteUserData.async_get_by_domain(db, domain=site.domain, workdate=workdate)
|
||||
if not user_data:
|
||||
user_datas = await SiteUserData.async_get_by_domain(db, domain=site.domain, workdate=workdate)
|
||||
if not user_datas:
|
||||
return schemas.Response(success=False, data=[])
|
||||
return schemas.Response(success=True, data=user_data)
|
||||
return schemas.Response(success=True, data=[data.to_dict() for data in user_datas])
|
||||
|
||||
|
||||
@router.get("/test/{site_id}", summary="连接测试", response_model=schemas.Response)
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import re
|
||||
from collections import deque
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union, Annotated
|
||||
|
||||
import aiofiles
|
||||
import pillow_avif # noqa 用于自动注册AVIF支持
|
||||
from PIL import Image
|
||||
from anyio import Path as AsyncPath
|
||||
from app.helper.sites import SitesHelper # noqa # noqa
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Header, Request, Response
|
||||
@@ -19,7 +16,6 @@ from app import schemas
|
||||
from app.chain.mediaserver import MediaServerChain
|
||||
from app.chain.search import SearchChain
|
||||
from app.chain.system import SystemChain
|
||||
from app.core.cache import AsyncFileCache
|
||||
from app.core.config import global_vars, settings
|
||||
from app.core.event import eventmanager
|
||||
from app.core.metainfo import MetaInfo
|
||||
@@ -29,12 +25,14 @@ from app.db.models import User
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.db.user_oper import get_current_active_superuser, get_current_active_superuser_async, \
|
||||
get_current_active_user_async
|
||||
from app.helper.llm import LLMHelper
|
||||
from app.helper.mediaserver import MediaServerHelper
|
||||
from app.helper.message import MessageHelper
|
||||
from app.helper.progress import ProgressHelper
|
||||
from app.helper.rule import RuleHelper
|
||||
from app.helper.subscribe import SubscribeHelper
|
||||
from app.helper.system import SystemHelper
|
||||
from app.helper.image import ImageHelper
|
||||
from app.log import logger
|
||||
from app.scheduler import Scheduler
|
||||
from app.schemas import ConfigChangeEventData
|
||||
@@ -50,7 +48,7 @@ router = APIRouter()
|
||||
|
||||
async def fetch_image(
|
||||
url: str,
|
||||
proxy: bool = False,
|
||||
proxy: Optional[bool] = None,
|
||||
use_cache: bool = False,
|
||||
if_none_match: Optional[str] = None,
|
||||
cookies: Optional[str | dict] = None,
|
||||
@@ -69,77 +67,24 @@ async def fetch_image(
|
||||
logger.warn(f"Blocked unsafe image URL: {url}")
|
||||
return None
|
||||
|
||||
# 缓存路径
|
||||
sanitized_path = SecurityUtils.sanitize_url_path(url)
|
||||
cache_path = Path("images") / sanitized_path
|
||||
if not cache_path.suffix:
|
||||
# 没有文件类型,则添加后缀,在恶意文件类型和实际需求下的折衷选择
|
||||
cache_path = cache_path.with_suffix(".jpg")
|
||||
|
||||
# 缓存对像,缓存过期时间为全局图片缓存天数
|
||||
cache_backend = AsyncFileCache(base=settings.CACHE_PATH,
|
||||
ttl=settings.GLOBAL_IMAGE_CACHE_DAYS * 24 * 3600)
|
||||
|
||||
if use_cache:
|
||||
content = await cache_backend.get(cache_path.as_posix(), region="images")
|
||||
if content:
|
||||
# 检查 If-None-Match
|
||||
etag = HashUtils.md5(content)
|
||||
headers = RequestUtils.generate_cache_headers(etag, max_age=86400 * 7)
|
||||
if if_none_match == etag:
|
||||
return Response(status_code=304, headers=headers)
|
||||
# 返回缓存图片
|
||||
return Response(
|
||||
content=content,
|
||||
media_type=UrlUtils.get_mime_type(url, "image/jpeg"),
|
||||
headers=headers
|
||||
)
|
||||
|
||||
# 请求远程图片
|
||||
referer = "https://movie.douban.com/" if "doubanio.com" in url else None
|
||||
proxies = settings.PROXY if proxy else None
|
||||
response = await AsyncRequestUtils(
|
||||
ua=settings.NORMAL_USER_AGENT,
|
||||
proxies=proxies,
|
||||
referer=referer,
|
||||
content = await ImageHelper().async_fetch_image(
|
||||
url=url,
|
||||
proxy=proxy,
|
||||
use_cache=use_cache,
|
||||
cookies=cookies,
|
||||
accept_type="image/avif,image/webp,image/apng,*/*",
|
||||
).get_res(url=url)
|
||||
if not response:
|
||||
logger.warn(f"Failed to fetch image from URL: {url}")
|
||||
return None
|
||||
|
||||
# 验证下载的内容是否为有效图片
|
||||
try:
|
||||
content = response.content
|
||||
Image.open(io.BytesIO(content)).verify()
|
||||
except Exception as e:
|
||||
logger.warn(f"Invalid image format for URL {url}: {e}")
|
||||
return None
|
||||
|
||||
# 获取请求响应头
|
||||
response_headers = response.headers
|
||||
cache_control_header = response_headers.get("Cache-Control", "")
|
||||
cache_directive, max_age = RequestUtils.parse_cache_control(cache_control_header)
|
||||
|
||||
# 保存缓存
|
||||
if use_cache:
|
||||
await cache_backend.set(cache_path.as_posix(), content, region="images")
|
||||
logger.debug(f"Image cached at {cache_path.as_posix()}")
|
||||
|
||||
# 检查 If-None-Match
|
||||
etag = HashUtils.md5(content)
|
||||
if if_none_match == etag:
|
||||
headers = RequestUtils.generate_cache_headers(etag, cache_directive, max_age)
|
||||
return Response(status_code=304, headers=headers)
|
||||
|
||||
# 响应
|
||||
headers = RequestUtils.generate_cache_headers(etag, cache_directive, max_age)
|
||||
return Response(
|
||||
content=content,
|
||||
media_type=response_headers.get("Content-Type") or UrlUtils.get_mime_type(url, "image/jpeg"),
|
||||
headers=headers
|
||||
)
|
||||
if content:
|
||||
# 检查 If-None-Match
|
||||
etag = HashUtils.md5(content)
|
||||
headers = RequestUtils.generate_cache_headers(etag, max_age=86400 * 7)
|
||||
if if_none_match == etag:
|
||||
return Response(status_code=304, headers=headers)
|
||||
# 返回缓存图片
|
||||
return Response(
|
||||
content=content,
|
||||
media_type=UrlUtils.get_mime_type(url, "image/jpeg"),
|
||||
headers=headers
|
||||
)
|
||||
|
||||
|
||||
@router.get("/img/{proxy}", summary="图片代理")
|
||||
@@ -177,8 +122,7 @@ async def cache_img(
|
||||
本地缓存图片文件,支持 HTTP 缓存,如果启用全局图片缓存,则使用磁盘缓存
|
||||
"""
|
||||
# 如果没有启用全局图片缓存,则不使用磁盘缓存
|
||||
proxy = "doubanio.com" not in url
|
||||
return await fetch_image(url=url, proxy=proxy, use_cache=settings.GLOBAL_IMAGE_CACHE,
|
||||
return await fetch_image(url=url, use_cache=settings.GLOBAL_IMAGE_CACHE,
|
||||
if_none_match=if_none_match)
|
||||
|
||||
|
||||
@@ -247,13 +191,11 @@ async def set_env_setting(env: dict,
|
||||
)
|
||||
|
||||
if success_updates:
|
||||
for key in success_updates.keys():
|
||||
# 发送配置变更事件
|
||||
await eventmanager.async_send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData(
|
||||
key=key,
|
||||
value=getattr(settings, key, None),
|
||||
change_type="update"
|
||||
))
|
||||
# 发送配置变更事件
|
||||
await eventmanager.async_send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData(
|
||||
key=success_updates.keys(),
|
||||
change_type="update"
|
||||
))
|
||||
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
@@ -338,6 +280,18 @@ async def set_setting(
|
||||
return schemas.Response(success=False, message=f"配置项 '{key}' 不存在")
|
||||
|
||||
|
||||
@router.get("/llm-models", summary="获取LLM模型列表", response_model=schemas.Response)
|
||||
async def get_llm_models(provider: str, api_key: str, base_url: Optional[str] = None, _: User = Depends(get_current_active_user_async)):
|
||||
"""
|
||||
获取LLM模型列表
|
||||
"""
|
||||
try:
|
||||
models = LLMHelper().get_models(provider, api_key, base_url)
|
||||
return schemas.Response(success=True, data=models)
|
||||
except Exception as e:
|
||||
return schemas.Response(success=False, message=str(e))
|
||||
|
||||
|
||||
@router.get("/message", summary="实时消息")
|
||||
async def get_message(request: Request, role: Optional[str] = "system",
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token)):
|
||||
|
||||
@@ -19,7 +19,7 @@ from app.db.mediaserver_oper import MediaServerOper
|
||||
from app.helper.directory import DirectoryHelper
|
||||
from app.helper.torrent import TorrentHelper
|
||||
from app.log import logger
|
||||
from app.schemas import ExistMediaInfo, NotExistMediaInfo, DownloadingTorrent, Notification, ResourceSelectionEventData, \
|
||||
from app.schemas import ExistMediaInfo, FileURI, NotExistMediaInfo, DownloadingTorrent, Notification, ResourceSelectionEventData, \
|
||||
ResourceDownloadEventData
|
||||
from app.schemas.types import MediaType, TorrentStatus, EventType, MessageChannel, NotificationType, ContentType, \
|
||||
ChainEventType
|
||||
@@ -162,7 +162,7 @@ class DownloadChain(ChainBase):
|
||||
:param channel: 通知渠道
|
||||
:param source: 来源(消息通知、Subscribe、Manual等)
|
||||
:param downloader: 下载器
|
||||
:param save_path: 保存路径
|
||||
:param save_path: 保存路径, 支持<storage>:<path>, 如rclone:/MP, smb:/server/share/Movies等
|
||||
:param userid: 用户ID
|
||||
:param username: 调用下载的用户名/插件名
|
||||
:param label: 自定义标签
|
||||
@@ -232,13 +232,14 @@ class DownloadChain(ChainBase):
|
||||
# 获取种子文件的文件夹名和文件清单
|
||||
_folder_name, _file_list = TorrentHelper().get_fileinfo_from_torrent_content(torrent_content)
|
||||
|
||||
storage = 'local'
|
||||
# 下载目录
|
||||
if save_path:
|
||||
# 下载目录使用自定义的
|
||||
download_dir = Path(save_path)
|
||||
else:
|
||||
# 根据媒体信息查询下载目录配置
|
||||
dir_info = DirectoryHelper().get_dir(_media, storage="local", include_unsorted=True)
|
||||
dir_info = DirectoryHelper().get_dir(_media, include_unsorted=True)
|
||||
storage = dir_info.storage if dir_info else storage
|
||||
# 拼装子目录
|
||||
if dir_info:
|
||||
# 一级目录
|
||||
@@ -259,6 +260,8 @@ class DownloadChain(ChainBase):
|
||||
self.messagehelper.put(f"{_media.type.value} {_media.title_year} 未找到下载目录!",
|
||||
title="下载失败", role="system")
|
||||
return None
|
||||
fileURI = FileURI(storage=storage, path=download_dir.as_posix())
|
||||
download_dir = Path(fileURI.uri)
|
||||
|
||||
# 添加下载
|
||||
result: Optional[tuple] = self.download(content=torrent_content,
|
||||
@@ -400,7 +403,7 @@ class DownloadChain(ChainBase):
|
||||
根据缺失数据,自动种子列表中组合择优下载
|
||||
:param contexts: 资源上下文列表
|
||||
:param no_exists: 缺失的剧集信息
|
||||
:param save_path: 保存路径
|
||||
:param save_path: 保存路径, 支持<storage>:<path>, 如rclone:/MP, smb:/server/share/Movies等
|
||||
:param channel: 通知渠道
|
||||
:param source: 来源(消息通知、订阅、手工下载等)
|
||||
:param userid: 用户ID
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import asyncio
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Optional, Dict, Union, List
|
||||
|
||||
from app.agent import agent_manager
|
||||
@@ -8,7 +10,7 @@ from app.chain.download import DownloadChain
|
||||
from app.chain.media import MediaChain
|
||||
from app.chain.search import SearchChain
|
||||
from app.chain.subscribe import SubscribeChain
|
||||
from app.core.config import settings
|
||||
from app.core.config import settings, global_vars
|
||||
from app.core.context import MediaInfo, Context
|
||||
from app.core.meta import MetaBase
|
||||
from app.db.user_oper import UserOper
|
||||
@@ -35,6 +37,10 @@ class MessageChain(ChainBase):
|
||||
_cache_file = "__user_messages__"
|
||||
# 每页数据量
|
||||
_page_size: int = 8
|
||||
# 用户会话信息 {userid: (session_id, last_time)}
|
||||
_user_sessions: Dict[Union[str, int], tuple] = {}
|
||||
# 会话超时时间(分钟)
|
||||
_session_timeout_minutes: int = 15
|
||||
|
||||
@staticmethod
|
||||
def __get_noexits_info(
|
||||
@@ -158,19 +164,15 @@ class MessageChain(ChainBase):
|
||||
)
|
||||
# 处理消息
|
||||
if text.startswith('CALLBACK:'):
|
||||
# 处理按钮回调(适配支持回调的渠道)
|
||||
# 处理按钮回调(适配支持回调的渠),优先级最高
|
||||
if ChannelCapabilityManager.supports_callbacks(channel):
|
||||
self._handle_callback(text=text, channel=channel, source=source,
|
||||
userid=userid, username=username,
|
||||
original_message_id=original_message_id, original_chat_id=original_chat_id)
|
||||
else:
|
||||
logger.warning(f"渠道 {channel.value} 不支持回调,但收到了回调消息:{text}")
|
||||
elif text.startswith('/ai') or text.startswith('/AI'):
|
||||
# AI智能体处理
|
||||
self._handle_ai_message(text=text, channel=channel, source=source,
|
||||
userid=userid, username=username)
|
||||
elif text.startswith('/'):
|
||||
# 执行命令
|
||||
elif text.startswith('/') and not text.lower().startswith('/ai'):
|
||||
# 执行特定命令命令(但不是/ai)
|
||||
self.eventmanager.send_event(
|
||||
EventType.CommandExcute,
|
||||
{
|
||||
@@ -180,265 +182,226 @@ class MessageChain(ChainBase):
|
||||
"source": source
|
||||
}
|
||||
)
|
||||
elif text.isdigit():
|
||||
# 用户选择了具体的条目
|
||||
# 缓存
|
||||
cache_data: dict = user_cache.get(userid).copy()
|
||||
# 选择项目
|
||||
if not cache_data \
|
||||
or not cache_data.get('items') \
|
||||
or len(cache_data.get('items')) < int(text):
|
||||
# 发送消息
|
||||
self.post_message(Notification(channel=channel, source=source, title="输入有误!", userid=userid))
|
||||
return
|
||||
try:
|
||||
# 选择的序号
|
||||
_choice = int(text) + _current_page * self._page_size - 1
|
||||
# 缓存类型
|
||||
cache_type: str = cache_data.get('type')
|
||||
# 缓存列表
|
||||
cache_list: list = cache_data.get('items').copy()
|
||||
# 选择
|
||||
elif text.lower().startswith('/ai'):
|
||||
# 用户指定AI智能体消息响应
|
||||
self._handle_ai_message(text=text, channel=channel, source=source,
|
||||
userid=userid, username=username)
|
||||
elif settings.AI_AGENT_ENABLE and settings.AI_AGENT_GLOBAL:
|
||||
# 普通消息,全局智能体响应
|
||||
self._handle_ai_message(text=text, channel=channel, source=source,
|
||||
userid=userid, username=username)
|
||||
else:
|
||||
# 非智能体普通消息响应
|
||||
if text.isdigit():
|
||||
# 用户选择了具体的条目
|
||||
# 缓存
|
||||
cache_data: dict = user_cache.get(userid).copy()
|
||||
# 选择项目
|
||||
if not cache_data \
|
||||
or not cache_data.get('items') \
|
||||
or len(cache_data.get('items')) < int(text):
|
||||
# 发送消息
|
||||
self.post_message(Notification(channel=channel, source=source, title="输入有误!", userid=userid))
|
||||
return
|
||||
try:
|
||||
if cache_type in ["Search", "ReSearch"]:
|
||||
# 当前媒体信息
|
||||
mediainfo: MediaInfo = cache_list[_choice]
|
||||
_current_media = mediainfo
|
||||
# 查询缺失的媒体信息
|
||||
exist_flag, no_exists = DownloadChain().get_no_exists_info(meta=_current_meta,
|
||||
mediainfo=_current_media)
|
||||
if exist_flag and cache_type == "Search":
|
||||
# 媒体库中已存在
|
||||
# 选择的序号
|
||||
_choice = int(text) + _current_page * self._page_size - 1
|
||||
# 缓存类型
|
||||
cache_type: str = cache_data.get('type')
|
||||
# 缓存列表
|
||||
cache_list: list = cache_data.get('items').copy()
|
||||
# 选择
|
||||
try:
|
||||
if cache_type in ["Search", "ReSearch"]:
|
||||
# 当前媒体信息
|
||||
mediainfo: MediaInfo = cache_list[_choice]
|
||||
_current_media = mediainfo
|
||||
# 查询缺失的媒体信息
|
||||
exist_flag, no_exists = DownloadChain().get_no_exists_info(meta=_current_meta,
|
||||
mediainfo=_current_media)
|
||||
if exist_flag and cache_type == "Search":
|
||||
# 媒体库中已存在
|
||||
self.post_message(
|
||||
Notification(channel=channel,
|
||||
source=source,
|
||||
title=f"【{_current_media.title_year}"
|
||||
f"{_current_meta.sea} 媒体库中已存在,如需重新下载请发送:搜索 名称 或 下载 名称】",
|
||||
userid=userid))
|
||||
return
|
||||
elif exist_flag:
|
||||
# 没有缺失,但要全量重新搜索和下载
|
||||
no_exists = self.__get_noexits_info(_current_meta, _current_media)
|
||||
# 发送缺失的媒体信息
|
||||
messages = []
|
||||
if no_exists and cache_type == "Search":
|
||||
# 发送缺失消息
|
||||
mediakey = mediainfo.tmdb_id or mediainfo.douban_id
|
||||
messages = [
|
||||
f"第 {sea} 季缺失 {StringUtils.str_series(no_exist.episodes) if no_exist.episodes else no_exist.total_episode} 集"
|
||||
for sea, no_exist in no_exists.get(mediakey).items()]
|
||||
elif no_exists:
|
||||
# 发送总集数的消息
|
||||
mediakey = mediainfo.tmdb_id or mediainfo.douban_id
|
||||
messages = [
|
||||
f"第 {sea} 季总 {no_exist.total_episode} 集"
|
||||
for sea, no_exist in no_exists.get(mediakey).items()]
|
||||
if messages:
|
||||
self.post_message(Notification(channel=channel,
|
||||
source=source,
|
||||
title=f"{mediainfo.title_year}:\n" + "\n".join(messages),
|
||||
userid=userid))
|
||||
# 搜索种子,过滤掉不需要的剧集,以便选择
|
||||
logger.info(f"开始搜索 {mediainfo.title_year} ...")
|
||||
self.post_message(
|
||||
Notification(channel=channel,
|
||||
source=source,
|
||||
title=f"【{_current_media.title_year}"
|
||||
f"{_current_meta.sea} 媒体库中已存在,如需重新下载请发送:搜索 名称 或 下载 名称】",
|
||||
title=f"开始搜索 {mediainfo.type.value} {mediainfo.title_year} ...",
|
||||
userid=userid))
|
||||
return
|
||||
elif exist_flag:
|
||||
# 没有缺失,但要全量重新搜索和下载
|
||||
no_exists = self.__get_noexits_info(_current_meta, _current_media)
|
||||
# 发送缺失的媒体信息
|
||||
messages = []
|
||||
if no_exists and cache_type == "Search":
|
||||
# 发送缺失消息
|
||||
mediakey = mediainfo.tmdb_id or mediainfo.douban_id
|
||||
messages = [
|
||||
f"第 {sea} 季缺失 {StringUtils.str_series(no_exist.episodes) if no_exist.episodes else no_exist.total_episode} 集"
|
||||
for sea, no_exist in no_exists.get(mediakey).items()]
|
||||
elif no_exists:
|
||||
# 发送总集数的消息
|
||||
mediakey = mediainfo.tmdb_id or mediainfo.douban_id
|
||||
messages = [
|
||||
f"第 {sea} 季总 {no_exist.total_episode} 集"
|
||||
for sea, no_exist in no_exists.get(mediakey).items()]
|
||||
if messages:
|
||||
self.post_message(Notification(channel=channel,
|
||||
source=source,
|
||||
title=f"{mediainfo.title_year}:\n" + "\n".join(messages),
|
||||
userid=userid))
|
||||
# 搜索种子,过滤掉不需要的剧集,以便选择
|
||||
logger.info(f"开始搜索 {mediainfo.title_year} ...")
|
||||
self.post_message(
|
||||
Notification(channel=channel,
|
||||
source=source,
|
||||
title=f"开始搜索 {mediainfo.type.value} {mediainfo.title_year} ...",
|
||||
userid=userid))
|
||||
# 开始搜索
|
||||
contexts = SearchChain().process(mediainfo=mediainfo,
|
||||
no_exists=no_exists)
|
||||
if not contexts:
|
||||
# 没有数据
|
||||
self.post_message(Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
title=f"{mediainfo.title}"
|
||||
f"{_current_meta.sea} 未搜索到需要的资源!",
|
||||
userid=userid))
|
||||
return
|
||||
# 搜索结果排序
|
||||
contexts = TorrentHelper().sort_torrents(contexts)
|
||||
try:
|
||||
# 判断是否设置自动下载
|
||||
auto_download_user = settings.AUTO_DOWNLOAD_USER
|
||||
# 匹配到自动下载用户
|
||||
if auto_download_user \
|
||||
and (auto_download_user == "all"
|
||||
or any(userid == user for user in auto_download_user.split(","))):
|
||||
logger.info(f"用户 {userid} 在自动下载用户中,开始自动择优下载 ...")
|
||||
# 自动选择下载
|
||||
self.__auto_download(channel=channel,
|
||||
source=source,
|
||||
cache_list=contexts,
|
||||
userid=userid,
|
||||
username=username,
|
||||
no_exists=no_exists)
|
||||
else:
|
||||
# 更新缓存
|
||||
user_cache[userid] = {
|
||||
"type": "Torrent",
|
||||
"items": contexts
|
||||
}
|
||||
_current_page = 0
|
||||
# 保存缓存
|
||||
self.save_cache(user_cache, self._cache_file)
|
||||
# 删除原消息
|
||||
if (original_message_id and original_chat_id and
|
||||
ChannelCapabilityManager.supports_deletion(channel)):
|
||||
self.delete_message(
|
||||
channel=channel,
|
||||
source=source,
|
||||
message_id=original_message_id,
|
||||
chat_id=original_chat_id
|
||||
)
|
||||
# 发送种子数据
|
||||
logger.info(f"搜索到 {len(contexts)} 条数据,开始发送选择消息 ...")
|
||||
self.__post_torrents_message(channel=channel,
|
||||
source=source,
|
||||
title=mediainfo.title,
|
||||
items=contexts[:self._page_size],
|
||||
userid=userid,
|
||||
total=len(contexts))
|
||||
finally:
|
||||
contexts.clear()
|
||||
del contexts
|
||||
elif cache_type in ["Subscribe", "ReSubscribe"]:
|
||||
# 订阅或洗版媒体
|
||||
mediainfo: MediaInfo = cache_list[_choice]
|
||||
# 洗版标识
|
||||
best_version = False
|
||||
# 查询缺失的媒体信息
|
||||
if cache_type == "Subscribe":
|
||||
exist_flag, _ = DownloadChain().get_no_exists_info(meta=_current_meta,
|
||||
mediainfo=mediainfo)
|
||||
if exist_flag:
|
||||
# 开始搜索
|
||||
contexts = SearchChain().process(mediainfo=mediainfo,
|
||||
no_exists=no_exists)
|
||||
if not contexts:
|
||||
# 没有数据
|
||||
self.post_message(Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
title=f"【{mediainfo.title_year}"
|
||||
f"{_current_meta.sea} 媒体库中已存在,如需洗版请发送:洗版 XXX】",
|
||||
title=f"{mediainfo.title}"
|
||||
f"{_current_meta.sea} 未搜索到需要的资源!",
|
||||
userid=userid))
|
||||
return
|
||||
else:
|
||||
best_version = True
|
||||
# 转换用户名
|
||||
mp_name = UserOper().get_name(**{f"{channel.name.lower()}_userid": userid}) if channel else None
|
||||
# 添加订阅,状态为N
|
||||
SubscribeChain().add(title=mediainfo.title,
|
||||
year=mediainfo.year,
|
||||
mtype=mediainfo.type,
|
||||
tmdbid=mediainfo.tmdb_id,
|
||||
season=_current_meta.begin_season,
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=mp_name or username,
|
||||
best_version=best_version)
|
||||
elif cache_type == "Torrent":
|
||||
if int(text) == 0:
|
||||
# 自动选择下载,强制下载模式
|
||||
self.__auto_download(channel=channel,
|
||||
# 搜索结果排序
|
||||
contexts = TorrentHelper().sort_torrents(contexts)
|
||||
try:
|
||||
# 判断是否设置自动下载
|
||||
auto_download_user = settings.AUTO_DOWNLOAD_USER
|
||||
# 匹配到自动下载用户
|
||||
if auto_download_user \
|
||||
and (auto_download_user == "all"
|
||||
or any(userid == user for user in auto_download_user.split(","))):
|
||||
logger.info(f"用户 {userid} 在自动下载用户中,开始自动择优下载 ...")
|
||||
# 自动选择下载
|
||||
self.__auto_download(channel=channel,
|
||||
source=source,
|
||||
cache_list=contexts,
|
||||
userid=userid,
|
||||
username=username,
|
||||
no_exists=no_exists)
|
||||
else:
|
||||
# 更新缓存
|
||||
user_cache[userid] = {
|
||||
"type": "Torrent",
|
||||
"items": contexts
|
||||
}
|
||||
_current_page = 0
|
||||
# 保存缓存
|
||||
self.save_cache(user_cache, self._cache_file)
|
||||
# 删除原消息
|
||||
if (original_message_id and original_chat_id and
|
||||
ChannelCapabilityManager.supports_deletion(channel)):
|
||||
self.delete_message(
|
||||
channel=channel,
|
||||
source=source,
|
||||
message_id=original_message_id,
|
||||
chat_id=original_chat_id
|
||||
)
|
||||
# 发送种子数据
|
||||
logger.info(f"搜索到 {len(contexts)} 条数据,开始发送选择消息 ...")
|
||||
self.__post_torrents_message(channel=channel,
|
||||
source=source,
|
||||
title=mediainfo.title,
|
||||
items=contexts[:self._page_size],
|
||||
userid=userid,
|
||||
total=len(contexts))
|
||||
finally:
|
||||
contexts.clear()
|
||||
del contexts
|
||||
elif cache_type in ["Subscribe", "ReSubscribe"]:
|
||||
# 订阅或洗版媒体
|
||||
mediainfo: MediaInfo = cache_list[_choice]
|
||||
# 洗版标识
|
||||
best_version = False
|
||||
# 查询缺失的媒体信息
|
||||
if cache_type == "Subscribe":
|
||||
exist_flag, _ = DownloadChain().get_no_exists_info(meta=_current_meta,
|
||||
mediainfo=mediainfo)
|
||||
if exist_flag:
|
||||
self.post_message(Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
title=f"【{mediainfo.title_year}"
|
||||
f"{_current_meta.sea} 媒体库中已存在,如需洗版请发送:洗版 XXX】",
|
||||
userid=userid))
|
||||
return
|
||||
else:
|
||||
best_version = True
|
||||
# 转换用户名
|
||||
mp_name = UserOper().get_name(
|
||||
**{f"{channel.name.lower()}_userid": userid}) if channel else None
|
||||
# 添加订阅,状态为N
|
||||
SubscribeChain().add(title=mediainfo.title,
|
||||
year=mediainfo.year,
|
||||
mtype=mediainfo.type,
|
||||
tmdbid=mediainfo.tmdb_id,
|
||||
season=_current_meta.begin_season,
|
||||
channel=channel,
|
||||
source=source,
|
||||
cache_list=cache_list,
|
||||
userid=userid,
|
||||
username=username)
|
||||
else:
|
||||
# 下载种子
|
||||
context: Context = cache_list[_choice]
|
||||
# 下载
|
||||
DownloadChain().download_single(context, channel=channel, source=source,
|
||||
userid=userid, username=username)
|
||||
username=mp_name or username,
|
||||
best_version=best_version)
|
||||
elif cache_type == "Torrent":
|
||||
if int(text) == 0:
|
||||
# 自动选择下载,强制下载模式
|
||||
self.__auto_download(channel=channel,
|
||||
source=source,
|
||||
cache_list=cache_list,
|
||||
userid=userid,
|
||||
username=username)
|
||||
else:
|
||||
# 下载种子
|
||||
context: Context = cache_list[_choice]
|
||||
# 下载
|
||||
DownloadChain().download_single(context, channel=channel, source=source,
|
||||
userid=userid, username=username)
|
||||
finally:
|
||||
cache_list.clear()
|
||||
del cache_list
|
||||
finally:
|
||||
cache_list.clear()
|
||||
del cache_list
|
||||
finally:
|
||||
cache_data.clear()
|
||||
del cache_data
|
||||
elif text.lower() == "p":
|
||||
# 上一页
|
||||
cache_data: dict = user_cache.get(userid).copy()
|
||||
if not cache_data:
|
||||
# 没有缓存
|
||||
self.post_message(Notification(
|
||||
channel=channel, source=source, title="输入有误!", userid=userid))
|
||||
return
|
||||
try:
|
||||
if _current_page == 0:
|
||||
# 第一页
|
||||
cache_data.clear()
|
||||
del cache_data
|
||||
elif text.lower() == "p":
|
||||
# 上一页
|
||||
cache_data: dict = user_cache.get(userid).copy()
|
||||
if not cache_data:
|
||||
# 没有缓存
|
||||
self.post_message(Notification(
|
||||
channel=channel, source=source, title="已经是第一页了!", userid=userid))
|
||||
channel=channel, source=source, title="输入有误!", userid=userid))
|
||||
return
|
||||
# 减一页
|
||||
_current_page -= 1
|
||||
cache_type: str = cache_data.get('type')
|
||||
# 产生副本,避免修改原值
|
||||
cache_list: list = cache_data.get('items').copy()
|
||||
try:
|
||||
if _current_page == 0:
|
||||
start = 0
|
||||
end = self._page_size
|
||||
else:
|
||||
start = _current_page * self._page_size
|
||||
end = start + self._page_size
|
||||
if cache_type == "Torrent":
|
||||
# 发送种子数据
|
||||
self.__post_torrents_message(channel=channel,
|
||||
source=source,
|
||||
title=_current_media.title,
|
||||
items=cache_list[start:end],
|
||||
userid=userid,
|
||||
total=len(cache_list),
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id)
|
||||
else:
|
||||
# 发送媒体数据
|
||||
self.__post_medias_message(channel=channel,
|
||||
source=source,
|
||||
title=_current_meta.name,
|
||||
items=cache_list[start:end],
|
||||
userid=userid,
|
||||
total=len(cache_list),
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id)
|
||||
finally:
|
||||
cache_list.clear()
|
||||
del cache_list
|
||||
finally:
|
||||
cache_data.clear()
|
||||
del cache_data
|
||||
elif text.lower() == "n":
|
||||
# 下一页
|
||||
cache_data: dict = user_cache.get(userid).copy()
|
||||
if not cache_data:
|
||||
# 没有缓存
|
||||
self.post_message(Notification(
|
||||
channel=channel, source=source, title="输入有误!", userid=userid))
|
||||
return
|
||||
try:
|
||||
cache_type: str = cache_data.get('type')
|
||||
# 产生副本,避免修改原值
|
||||
cache_list: list = cache_data.get('items').copy()
|
||||
total = len(cache_list)
|
||||
# 加一页
|
||||
cache_list = cache_list[(_current_page + 1) * self._page_size:(_current_page + 2) * self._page_size]
|
||||
if not cache_list:
|
||||
# 没有数据
|
||||
self.post_message(Notification(
|
||||
channel=channel, source=source, title="已经是最后一页了!", userid=userid))
|
||||
return
|
||||
else:
|
||||
# 第一页
|
||||
self.post_message(Notification(
|
||||
channel=channel, source=source, title="已经是第一页了!", userid=userid))
|
||||
return
|
||||
# 减一页
|
||||
_current_page -= 1
|
||||
cache_type: str = cache_data.get('type')
|
||||
# 产生副本,避免修改原值
|
||||
cache_list: list = cache_data.get('items').copy()
|
||||
try:
|
||||
# 加一页
|
||||
_current_page += 1
|
||||
if _current_page == 0:
|
||||
start = 0
|
||||
end = self._page_size
|
||||
else:
|
||||
start = _current_page * self._page_size
|
||||
end = start + self._page_size
|
||||
if cache_type == "Torrent":
|
||||
# 发送种子数据
|
||||
self.__post_torrents_message(channel=channel,
|
||||
source=source,
|
||||
title=_current_media.title,
|
||||
items=cache_list,
|
||||
items=cache_list[start:end],
|
||||
userid=userid,
|
||||
total=total,
|
||||
total=len(cache_list),
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id)
|
||||
else:
|
||||
@@ -446,93 +409,144 @@ class MessageChain(ChainBase):
|
||||
self.__post_medias_message(channel=channel,
|
||||
source=source,
|
||||
title=_current_meta.name,
|
||||
items=cache_list,
|
||||
items=cache_list[start:end],
|
||||
userid=userid,
|
||||
total=total,
|
||||
total=len(cache_list),
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id)
|
||||
finally:
|
||||
cache_list.clear()
|
||||
del cache_list
|
||||
finally:
|
||||
cache_data.clear()
|
||||
del cache_data
|
||||
else:
|
||||
# 搜索或订阅
|
||||
if text.startswith("订阅"):
|
||||
# 订阅
|
||||
content = re.sub(r"订阅[::\s]*", "", text)
|
||||
action = "Subscribe"
|
||||
elif text.startswith("洗版"):
|
||||
# 洗版
|
||||
content = re.sub(r"洗版[::\s]*", "", text)
|
||||
action = "ReSubscribe"
|
||||
elif text.startswith("搜索") or text.startswith("下载"):
|
||||
# 重新搜索/下载
|
||||
content = re.sub(r"(搜索|下载)[::\s]*", "", text)
|
||||
action = "ReSearch"
|
||||
elif text.startswith("#") \
|
||||
or re.search(r"^请[问帮你]", text) \
|
||||
or re.search(r"[??]$", text) \
|
||||
or StringUtils.count_words(text) > 10 \
|
||||
or text.find("继续") != -1:
|
||||
# 聊天
|
||||
content = text
|
||||
action = "Chat"
|
||||
elif StringUtils.is_link(text):
|
||||
# 链接
|
||||
content = text
|
||||
action = "Link"
|
||||
else:
|
||||
# 搜索
|
||||
content = text
|
||||
action = "Search"
|
||||
|
||||
if action in ["Search", "ReSearch", "Subscribe", "ReSubscribe"]:
|
||||
# 搜索
|
||||
meta, medias = MediaChain().search(content)
|
||||
# 识别
|
||||
if not meta.name:
|
||||
self.post_message(Notification(
|
||||
channel=channel, source=source, title="无法识别输入内容!", userid=userid))
|
||||
return
|
||||
# 开始搜索
|
||||
if not medias:
|
||||
self.post_message(Notification(
|
||||
channel=channel, source=source, title=f"{meta.name} 没有找到对应的媒体信息!", userid=userid))
|
||||
return
|
||||
logger.info(f"搜索到 {len(medias)} 条相关媒体信息")
|
||||
try:
|
||||
# 记录当前状态
|
||||
_current_meta = meta
|
||||
# 保存缓存
|
||||
user_cache[userid] = {
|
||||
'type': action,
|
||||
'items': medias
|
||||
}
|
||||
self.save_cache(user_cache, self._cache_file)
|
||||
_current_page = 0
|
||||
_current_media = None
|
||||
# 发送媒体列表
|
||||
self.__post_medias_message(channel=channel,
|
||||
source=source,
|
||||
title=meta.name,
|
||||
items=medias[:self._page_size],
|
||||
userid=userid, total=len(medias))
|
||||
finally:
|
||||
medias.clear()
|
||||
del medias
|
||||
cache_data.clear()
|
||||
del cache_data
|
||||
elif text.lower() == "n":
|
||||
# 下一页
|
||||
cache_data: dict = user_cache.get(userid).copy()
|
||||
if not cache_data:
|
||||
# 没有缓存
|
||||
self.post_message(Notification(
|
||||
channel=channel, source=source, title="输入有误!", userid=userid))
|
||||
return
|
||||
try:
|
||||
cache_type: str = cache_data.get('type')
|
||||
# 产生副本,避免修改原值
|
||||
cache_list: list = cache_data.get('items').copy()
|
||||
total = len(cache_list)
|
||||
# 加一页
|
||||
cache_list = cache_list[(_current_page + 1) * self._page_size:(_current_page + 2) * self._page_size]
|
||||
if not cache_list:
|
||||
# 没有数据
|
||||
self.post_message(Notification(
|
||||
channel=channel, source=source, title="已经是最后一页了!", userid=userid))
|
||||
return
|
||||
else:
|
||||
try:
|
||||
# 加一页
|
||||
_current_page += 1
|
||||
if cache_type == "Torrent":
|
||||
# 发送种子数据
|
||||
self.__post_torrents_message(channel=channel,
|
||||
source=source,
|
||||
title=_current_media.title,
|
||||
items=cache_list,
|
||||
userid=userid,
|
||||
total=total,
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id)
|
||||
else:
|
||||
# 发送媒体数据
|
||||
self.__post_medias_message(channel=channel,
|
||||
source=source,
|
||||
title=_current_meta.name,
|
||||
items=cache_list,
|
||||
userid=userid,
|
||||
total=total,
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id)
|
||||
finally:
|
||||
cache_list.clear()
|
||||
del cache_list
|
||||
finally:
|
||||
cache_data.clear()
|
||||
del cache_data
|
||||
else:
|
||||
# 广播事件
|
||||
self.eventmanager.send_event(
|
||||
EventType.UserMessage,
|
||||
{
|
||||
"text": content,
|
||||
"userid": userid,
|
||||
"channel": channel,
|
||||
"source": source
|
||||
}
|
||||
)
|
||||
# 搜索或订阅
|
||||
if text.startswith("订阅"):
|
||||
# 订阅
|
||||
content = re.sub(r"订阅[::\s]*", "", text)
|
||||
action = "Subscribe"
|
||||
elif text.startswith("洗版"):
|
||||
# 洗版
|
||||
content = re.sub(r"洗版[::\s]*", "", text)
|
||||
action = "ReSubscribe"
|
||||
elif text.startswith("搜索") or text.startswith("下载"):
|
||||
# 重新搜索/下载
|
||||
content = re.sub(r"(搜索|下载)[::\s]*", "", text)
|
||||
action = "ReSearch"
|
||||
elif text.startswith("#") \
|
||||
or re.search(r"^请[问帮你]", text) \
|
||||
or re.search(r"[??]$", text) \
|
||||
or StringUtils.count_words(text) > 10 \
|
||||
or text.find("继续") != -1:
|
||||
# 聊天
|
||||
content = text
|
||||
action = "Chat"
|
||||
elif StringUtils.is_link(text):
|
||||
# 链接
|
||||
content = text
|
||||
action = "Link"
|
||||
else:
|
||||
# 搜索
|
||||
content = text
|
||||
action = "Search"
|
||||
|
||||
if action in ["Search", "ReSearch", "Subscribe", "ReSubscribe"]:
|
||||
# 搜索
|
||||
meta, medias = MediaChain().search(content)
|
||||
# 识别
|
||||
if not meta.name:
|
||||
self.post_message(Notification(
|
||||
channel=channel, source=source, title="无法识别输入内容!", userid=userid))
|
||||
return
|
||||
# 开始搜索
|
||||
if not medias:
|
||||
self.post_message(Notification(
|
||||
channel=channel, source=source, title=f"{meta.name} 没有找到对应的媒体信息!",
|
||||
userid=userid))
|
||||
return
|
||||
logger.info(f"搜索到 {len(medias)} 条相关媒体信息")
|
||||
try:
|
||||
# 记录当前状态
|
||||
_current_meta = meta
|
||||
# 保存缓存
|
||||
user_cache[userid] = {
|
||||
'type': action,
|
||||
'items': medias
|
||||
}
|
||||
self.save_cache(user_cache, self._cache_file)
|
||||
_current_page = 0
|
||||
_current_media = None
|
||||
# 发送媒体列表
|
||||
self.__post_medias_message(channel=channel,
|
||||
source=source,
|
||||
title=meta.name,
|
||||
items=medias[:self._page_size],
|
||||
userid=userid, total=len(medias))
|
||||
finally:
|
||||
medias.clear()
|
||||
del medias
|
||||
else:
|
||||
# 广播事件
|
||||
self.eventmanager.send_event(
|
||||
EventType.UserMessage,
|
||||
{
|
||||
"text": content,
|
||||
"userid": userid,
|
||||
"channel": channel,
|
||||
"source": source
|
||||
}
|
||||
)
|
||||
finally:
|
||||
user_cache.clear()
|
||||
del user_cache
|
||||
@@ -822,8 +836,86 @@ class MessageChain(ChainBase):
|
||||
|
||||
return buttons
|
||||
|
||||
@staticmethod
|
||||
def _get_or_create_session_id(userid: Union[str, int]) -> str:
|
||||
"""
|
||||
获取或创建会话ID
|
||||
如果用户上次会话在15分钟内,则复用相同的会话ID;否则创建新的会话ID
|
||||
"""
|
||||
current_time = datetime.now()
|
||||
|
||||
# 检查用户是否有已存在的会话
|
||||
if userid in MessageChain._user_sessions:
|
||||
session_id, last_time = MessageChain._user_sessions[userid]
|
||||
|
||||
# 计算时间差
|
||||
time_diff = current_time - last_time
|
||||
|
||||
# 如果时间差小于等于15分钟,复用会话ID
|
||||
if time_diff <= timedelta(minutes=MessageChain._session_timeout_minutes):
|
||||
# 更新最后使用时间
|
||||
MessageChain._user_sessions[userid] = (session_id, current_time)
|
||||
logger.info(
|
||||
f"复用会话ID: {session_id}, 用户: {userid}, 距离上次会话: {time_diff.total_seconds() / 60:.1f}分钟")
|
||||
return session_id
|
||||
|
||||
# 创建新的会话ID
|
||||
new_session_id = f"user_{userid}_{int(time.time())}"
|
||||
MessageChain._user_sessions[userid] = (new_session_id, current_time)
|
||||
logger.info(f"创建新会话ID: {new_session_id}, 用户: {userid}")
|
||||
return new_session_id
|
||||
|
||||
@staticmethod
|
||||
def clear_user_session(userid: Union[str, int]) -> bool:
|
||||
"""
|
||||
清除指定用户的会话信息
|
||||
返回是否成功清除
|
||||
"""
|
||||
if userid in MessageChain._user_sessions:
|
||||
session_id, _ = MessageChain._user_sessions.pop(userid)
|
||||
logger.info(f"已清除用户 {userid} 的会话: {session_id}")
|
||||
return True
|
||||
return False
|
||||
|
||||
def remote_clear_session(self, channel: MessageChannel, userid: Union[str, int], source: Optional[str] = None):
|
||||
"""
|
||||
清除用户会话(远程命令接口)
|
||||
"""
|
||||
# 获取并清除会话信息
|
||||
session_id = None
|
||||
if userid in MessageChain._user_sessions:
|
||||
session_id, _ = MessageChain._user_sessions.pop(userid)
|
||||
logger.info(f"已清除用户 {userid} 的会话: {session_id}")
|
||||
|
||||
# 如果有会话ID,同时清除智能体的会话记忆
|
||||
if session_id:
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
agent_manager.clear_session(
|
||||
session_id=session_id,
|
||||
user_id=str(userid)
|
||||
),
|
||||
global_vars.loop
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"清除智能体会话记忆失败: {e}")
|
||||
|
||||
self.post_message(Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
title="智能体会话已清除,下次将创建新的会话",
|
||||
userid=userid
|
||||
))
|
||||
else:
|
||||
self.post_message(Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
title="您当前没有活跃的智能体会话",
|
||||
userid=userid
|
||||
))
|
||||
|
||||
def _handle_ai_message(self, text: str, channel: MessageChannel, source: str,
|
||||
userid: Union[str, int], username: str) -> None:
|
||||
userid: Union[str, int], username: str) -> None:
|
||||
"""
|
||||
处理AI智能体消息
|
||||
"""
|
||||
@@ -839,19 +931,11 @@ class MessageChain(ChainBase):
|
||||
))
|
||||
return
|
||||
|
||||
# 检查LLM配置
|
||||
if not settings.LLM_API_KEY:
|
||||
self.post_message(Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title="MoviePilot智能助未配置,请在系统设置中配置"
|
||||
))
|
||||
return
|
||||
|
||||
# 提取用户消息
|
||||
user_message = text[3:].strip() # 移除 "/ai" 前缀
|
||||
if text.lower().startswith("/ai"):
|
||||
user_message = text[3:].strip() # 移除 "/ai" 前缀(大小写不敏感)
|
||||
else:
|
||||
user_message = text.strip() # 按原消息处理
|
||||
if not user_message:
|
||||
self.post_message(Notification(
|
||||
channel=channel,
|
||||
@@ -862,45 +946,22 @@ class MessageChain(ChainBase):
|
||||
))
|
||||
return
|
||||
|
||||
# 发送处理中消息
|
||||
self.post_message(Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title="MoviePilot助手已收到您的请求,请稍候..."
|
||||
))
|
||||
# 生成或复用会话ID
|
||||
session_id = self._get_or_create_session_id(userid)
|
||||
|
||||
# 生成会话ID
|
||||
session_id = f"user_{userid}_{hash(user_message) % 10000}"
|
||||
|
||||
# 在事件循环中处理
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(
|
||||
agent_manager.process_message(
|
||||
session_id=session_id,
|
||||
user_id=str(userid),
|
||||
message=user_message,
|
||||
channel=channel.value if channel else None,
|
||||
source=source,
|
||||
username=username
|
||||
)
|
||||
)
|
||||
except RuntimeError:
|
||||
# 如果没有事件循环,创建新的
|
||||
asyncio.run(
|
||||
agent_manager.process_message(
|
||||
session_id=session_id,
|
||||
user_id=str(userid),
|
||||
message=user_message,
|
||||
channel=channel.value if channel else None,
|
||||
source=source,
|
||||
username=username
|
||||
)
|
||||
)
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
agent_manager.process_message(
|
||||
session_id=session_id,
|
||||
user_id=str(userid),
|
||||
message=user_message,
|
||||
channel=channel.value if channel else None,
|
||||
source=source,
|
||||
username=username
|
||||
),
|
||||
global_vars.loop
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理AI智能体消息失败: {e}")
|
||||
self.messagehelper.put(f"AI智能体处理失败: {str(e)}", role="system", title="MoviePilot助手")
|
||||
|
||||
|
||||
@@ -1,21 +1,17 @@
|
||||
import io
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
import pillow_avif # noqa 用于自动注册AVIF支持
|
||||
from PIL import Image
|
||||
|
||||
from app.chain import ChainBase
|
||||
from app.chain.bangumi import BangumiChain
|
||||
from app.chain.douban import DoubanChain
|
||||
from app.chain.tmdb import TmdbChain
|
||||
from app.core.cache import cached, FileCache
|
||||
from app.core.cache import cached
|
||||
from app.core.config import settings, global_vars
|
||||
from app.helper.image import ImageHelper
|
||||
from app.log import logger
|
||||
from app.schemas import MediaType
|
||||
from app.utils.common import log_execution_time
|
||||
from app.utils.http import RequestUtils
|
||||
from app.utils.security import SecurityUtils
|
||||
from app.utils.singleton import Singleton
|
||||
|
||||
|
||||
@@ -103,40 +99,7 @@ class RecommendChain(ChainBase, metaclass=Singleton):
|
||||
请求并保存图片
|
||||
:param url: 图片路径
|
||||
"""
|
||||
# 生成缓存路径
|
||||
sanitized_path = SecurityUtils.sanitize_url_path(url)
|
||||
cache_path = Path("images") / sanitized_path
|
||||
# 没有文件类型,则添加后缀,在恶意文件类型和实际需求下的折衷选择
|
||||
if not cache_path.suffix:
|
||||
cache_path = cache_path.with_suffix(".jpg")
|
||||
|
||||
# 获取缓存后端,并设置缓存时间为全局配置的缓存天数
|
||||
cache_backend = FileCache(base=settings.CACHE_PATH,
|
||||
ttl=settings.GLOBAL_IMAGE_CACHE_DAYS * 24 * 3600)
|
||||
|
||||
# 本地存在缓存图片,则直接跳过
|
||||
if cache_backend.get(cache_path.as_posix(), region="images"):
|
||||
logger.debug(f"Cache hit: Image already exists at {cache_path}")
|
||||
return
|
||||
|
||||
# 请求远程图片
|
||||
referer = "https://movie.douban.com/" if "doubanio.com" in url else None
|
||||
proxies = settings.PROXY if not referer else None
|
||||
response = RequestUtils(ua=settings.NORMAL_USER_AGENT, proxies=proxies, referer=referer).get_res(url=url)
|
||||
if not response:
|
||||
logger.debug(f"Empty response for URL: {url}")
|
||||
return
|
||||
|
||||
# 验证下载的内容是否为有效图片
|
||||
try:
|
||||
Image.open(io.BytesIO(response.content)).verify()
|
||||
except Exception as e:
|
||||
logger.debug(f"Invalid image format for URL {url}: {e}")
|
||||
return
|
||||
|
||||
# 保存缓存
|
||||
cache_backend.set(cache_path.as_posix(), response.content, region="images")
|
||||
logger.debug(f"Successfully cached image at {cache_path} for URL: {url}")
|
||||
ImageHelper().fetch_image(url=url)
|
||||
|
||||
@log_execution_time(logger=logger)
|
||||
@cached(ttl=recommend_ttl, region=recommend_cache_region)
|
||||
|
||||
@@ -949,7 +949,7 @@ class SubscribeChain(ChainBase):
|
||||
and torrent_mediainfo.douban_id != mediainfo.douban_id:
|
||||
continue
|
||||
logger.info(
|
||||
f'{mediainfo.title_year} 通过媒体信ID匹配到可选资源:{torrent_info.site_name} - {torrent_info.title}')
|
||||
f'{mediainfo.title_year} 通过媒体ID匹配到可选资源:{torrent_info.site_name} - {torrent_info.title}')
|
||||
else:
|
||||
continue
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Any, Union, Dict, Optional
|
||||
|
||||
from app.chain import ChainBase
|
||||
from app.chain.download import DownloadChain
|
||||
from app.chain.message import MessageChain
|
||||
from app.chain.site import SiteChain
|
||||
from app.chain.subscribe import SubscribeChain
|
||||
from app.chain.system import SystemChain
|
||||
@@ -140,6 +141,12 @@ class Command(metaclass=Singleton):
|
||||
"description": "当前版本",
|
||||
"category": "管理",
|
||||
"data": {}
|
||||
},
|
||||
"/clear_session": {
|
||||
"func": MessageChain().remote_clear_session,
|
||||
"description": "清除会话",
|
||||
"category": "管理",
|
||||
"data": {}
|
||||
}
|
||||
}
|
||||
# 插件命令集合
|
||||
|
||||
@@ -1024,13 +1024,11 @@ def fresh(fresh: bool = True):
|
||||
with fresh():
|
||||
result = some_cached_function()
|
||||
"""
|
||||
token = _fresh.set(fresh)
|
||||
logger.debug(f"Setting fresh mode to {fresh}. {id(token):#x}")
|
||||
token = _fresh.set(fresh or is_fresh())
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_fresh.reset(token)
|
||||
logger.debug(f"Reset fresh mode. {id(token):#x}")
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_fresh(fresh: bool = True):
|
||||
@@ -1041,13 +1039,11 @@ async def async_fresh(fresh: bool = True):
|
||||
async with async_fresh():
|
||||
result = await some_async_cached_function()
|
||||
"""
|
||||
token = _fresh.set(fresh)
|
||||
logger.debug(f"Setting async_fresh mode to {fresh}. {id(token):#x}")
|
||||
token = _fresh.set(fresh or is_fresh())
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_fresh.reset(token)
|
||||
logger.debug(f"Reset async_fresh mode. {id(token):#x}")
|
||||
|
||||
def is_fresh() -> bool:
|
||||
"""
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
@@ -6,6 +7,7 @@ import re
|
||||
import secrets
|
||||
import sys
|
||||
import threading
|
||||
from asyncio import AbstractEventLoop
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
from urllib.parse import urlparse
|
||||
@@ -409,6 +411,8 @@ class ConfigModel(BaseModel):
|
||||
# ==================== AI智能体配置 ====================
|
||||
# AI智能体开关
|
||||
AI_AGENT_ENABLE: bool = False
|
||||
# 合局AI智能体
|
||||
AI_AGENT_GLOBAL: bool = False
|
||||
# LLM提供商 (openai/google/deepseek)
|
||||
LLM_PROVIDER: str = "deepseek"
|
||||
# LLM模型名称
|
||||
@@ -852,6 +856,8 @@ class GlobalVar(object):
|
||||
EMERGENCY_STOP_WORKFLOWS: List[int] = []
|
||||
# 需应急停止文件整理
|
||||
EMERGENCY_STOP_TRANSFER: List[str] = []
|
||||
# 当前事件循环
|
||||
CURRENT_EVENT_LOOP: AbstractEventLoop = asyncio.get_event_loop()
|
||||
|
||||
def stop_system(self):
|
||||
"""
|
||||
@@ -916,6 +922,19 @@ class GlobalVar(object):
|
||||
return True
|
||||
return False
|
||||
|
||||
@property
|
||||
def loop(self) -> AbstractEventLoop:
|
||||
"""
|
||||
当前循环
|
||||
"""
|
||||
return self.CURRENT_EVENT_LOOP
|
||||
|
||||
def set_loop(self, loop: AbstractEventLoop):
|
||||
"""
|
||||
设置循环
|
||||
"""
|
||||
self.CURRENT_EVENT_LOOP = loop
|
||||
|
||||
|
||||
# 全局标识
|
||||
global_vars = GlobalVar()
|
||||
|
||||
@@ -11,6 +11,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union, Any
|
||||
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
|
||||
from app.core.config import global_vars
|
||||
from app.helper.thread import ThreadHelper
|
||||
from app.log import logger
|
||||
from app.schemas import ChainEventData
|
||||
@@ -90,8 +91,6 @@ class EventManager(metaclass=Singleton):
|
||||
self.__lock = threading.Lock()
|
||||
# 退出事件
|
||||
self.__event = threading.Event()
|
||||
# 当前事件循环
|
||||
self.loop = asyncio.get_event_loop()
|
||||
|
||||
def start(self):
|
||||
"""
|
||||
@@ -454,7 +453,7 @@ class EventManager(metaclass=Singleton):
|
||||
# 对于异步函数,直接在事件循环中运行
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.__safe_invoke_handler_async(handler, isolated_event),
|
||||
self.loop
|
||||
global_vars.loop
|
||||
)
|
||||
else:
|
||||
# 对于同步函数,在线程池中运行
|
||||
|
||||
@@ -6,11 +6,11 @@ import importlib.util
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
import threading
|
||||
from typing import Any, Dict, List, Optional, Type, Union, Callable, Tuple
|
||||
|
||||
from fastapi import HTTPException
|
||||
@@ -20,7 +20,7 @@ from watchfiles import watch
|
||||
from app import schemas
|
||||
from app.core.cache import fresh, async_fresh
|
||||
from app.core.config import settings
|
||||
from app.core.event import eventmanager, Event
|
||||
from app.core.event import eventmanager
|
||||
from app.db.plugindata_oper import PluginDataOper
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.helper.plugin import PluginHelper
|
||||
@@ -28,16 +28,16 @@ from app.helper.sites import SitesHelper # noqa
|
||||
from app.log import logger
|
||||
from app.schemas.types import EventType, SystemConfigKey
|
||||
from app.utils.crypto import RSAUtils
|
||||
from app.utils.mixins import ConfigReloadMixin
|
||||
from app.utils.object import ObjectUtils
|
||||
from app.utils.singleton import Singleton
|
||||
from app.utils.string import StringUtils
|
||||
from app.utils.system import SystemUtils
|
||||
|
||||
|
||||
class PluginManager(metaclass=Singleton):
|
||||
"""
|
||||
插件管理器
|
||||
"""
|
||||
class PluginManager(ConfigReloadMixin, metaclass=Singleton):
|
||||
"""插件管理器"""
|
||||
CONFIG_WATCH = {"DEV", "PLUGIN_AUTO_RELOAD"}
|
||||
|
||||
def __init__(self):
|
||||
# 插件列表
|
||||
@@ -250,20 +250,12 @@ class PluginManager(metaclass=Singleton):
|
||||
"""
|
||||
return self._plugins
|
||||
|
||||
@eventmanager.register(EventType.ConfigChanged)
|
||||
def handle_config_changed(self, event: Event):
|
||||
"""
|
||||
处理配置变更事件
|
||||
:param event: 事件对象
|
||||
"""
|
||||
if not event:
|
||||
return
|
||||
event_data: schemas.ConfigChangeEventData = event.event_data
|
||||
if event_data.key not in ['DEV', 'PLUGIN_AUTO_RELOAD']:
|
||||
return
|
||||
logger.info("配置变更,重新加载插件文件修改监测...")
|
||||
def on_config_changed(self):
|
||||
self.reload_monitor()
|
||||
|
||||
def get_reload_name(self) -> str:
|
||||
return "插件文件修改监测"
|
||||
|
||||
def reload_monitor(self):
|
||||
"""
|
||||
重新加载插件文件修改监测
|
||||
|
||||
@@ -261,7 +261,7 @@ def verify_apitoken(token: Annotated[str, Security(__get_api_token)]) -> str:
|
||||
def verify_apikey(apikey: Annotated[str, Security(__get_api_key)]) -> str:
|
||||
"""
|
||||
使用 API Key 进行身份认证
|
||||
:param apikey: API Key,从 URL 查询参数中获取 apikey=xxx
|
||||
:param apikey: API Key,从 URL 查询参数中获取 apikey=xxx,或请求头中获取 X-API-KEY=xxx
|
||||
:return: 返回校验通过的 API Key
|
||||
"""
|
||||
return __verify_key(apikey, settings.API_TOKEN, "apikey")
|
||||
|
||||
@@ -65,6 +65,14 @@ class MediaServerItem(Base):
|
||||
@classmethod
|
||||
@db_query
|
||||
def exists_by_title(cls, db: Session, title: str, mtype: str, year: str):
|
||||
if not mtype and not year:
|
||||
return db.query(cls).filter(cls.title == title).first()
|
||||
elif not year:
|
||||
return db.query(cls).filter(cls.title == title,
|
||||
cls.item_type == mtype).first()
|
||||
elif not mtype:
|
||||
return db.query(cls).filter(cls.title == title,
|
||||
cls.year == str(year)).first()
|
||||
return db.query(cls).filter(cls.title == title,
|
||||
cls.item_type == mtype,
|
||||
cls.year == str(year)).first()
|
||||
@@ -85,7 +93,16 @@ class MediaServerItem(Base):
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_exists_by_title(cls, db: AsyncSession, title: str, mtype: str, year: str):
|
||||
result = await db.execute(select(cls).filter(cls.title == title,
|
||||
if not mtype and not year:
|
||||
result = await db.execute(select(cls).filter(cls.title == title))
|
||||
elif not year:
|
||||
result = await db.execute(select(cls).filter(cls.title == title,
|
||||
cls.item_type == mtype))
|
||||
elif not mtype:
|
||||
result = await db.execute(select(cls).filter(cls.title == title,
|
||||
cls.year == str(year)))
|
||||
else:
|
||||
result = await db.execute(select(cls).filter(cls.title == title,
|
||||
cls.item_type == mtype,
|
||||
cls.year == str(year)))
|
||||
return result.scalars().first()
|
||||
|
||||
@@ -29,6 +29,12 @@ class SiteOper(DbOper):
|
||||
"""
|
||||
return Site.get(self._db, sid)
|
||||
|
||||
async def async_get(self, sid: int) -> Site:
|
||||
"""
|
||||
异步查询单个站点
|
||||
"""
|
||||
return await Site.async_get(self._db, sid)
|
||||
|
||||
def list(self) -> List[Site]:
|
||||
"""
|
||||
获取站点列表
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from app import schemas
|
||||
from app.core.context import MediaInfo
|
||||
@@ -9,7 +9,7 @@ from app.log import logger
|
||||
from app.schemas.types import SystemConfigKey
|
||||
from app.utils.system import SystemUtils
|
||||
|
||||
JINJA2_VAR_PATTERN = re.compile(r"\{\{.*?\}\}", re.DOTALL)
|
||||
JINJA2_VAR_PATTERN = re.compile(r"\{\{.*?}}", re.DOTALL)
|
||||
|
||||
|
||||
class DirectoryHelper:
|
||||
@@ -51,7 +51,7 @@ class DirectoryHelper:
|
||||
"""
|
||||
return [d for d in self.get_library_dirs() if d.library_storage == "local"]
|
||||
|
||||
def get_dir(self, media: MediaInfo, include_unsorted: Optional[bool] = False,
|
||||
def get_dir(self, media: Optional[MediaInfo], include_unsorted: Optional[bool] = False,
|
||||
storage: Optional[str] = None, src_path: Path = None,
|
||||
target_storage: Optional[str] = None, dest_path: Path = None
|
||||
) -> Optional[schemas.TransferDirectoryConf]:
|
||||
@@ -64,11 +64,8 @@ class DirectoryHelper:
|
||||
:param src_path: 源目录,有值时直接匹配
|
||||
:param dest_path: 目标目录,有值时直接匹配
|
||||
"""
|
||||
# 处理类型
|
||||
if not media:
|
||||
return None
|
||||
# 电影/电视剧
|
||||
media_type = media.type.value
|
||||
media_type = media.type.value if media else None
|
||||
dirs = self.get_dirs()
|
||||
|
||||
# 如果存在源目录,并源目录为任一下载目录的子目录时,则进行源目录匹配,否则,允许源目录按同盘优先的逻辑匹配
|
||||
@@ -93,7 +90,7 @@ class DirectoryHelper:
|
||||
if dest_path and dest_path != Path(d.library_path):
|
||||
continue
|
||||
# 目录类型为全部的,符合条件
|
||||
if not d.media_type:
|
||||
if not media_type or not d.media_type:
|
||||
matched_dirs.append(d)
|
||||
continue
|
||||
# 目录类型相等,目录类别为全部,符合条件
|
||||
@@ -109,11 +106,27 @@ class DirectoryHelper:
|
||||
# 优先源目录同盘
|
||||
for matched_dir in matched_dirs:
|
||||
matched_path = Path(matched_dir.download_path)
|
||||
if SystemUtils.is_same_disk(matched_path, src_path):
|
||||
if self._is_same_source((src_path, storage or "local"), (matched_path, matched_dir.library_storage)):
|
||||
return matched_dir
|
||||
return matched_dirs[0]
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _is_same_source(src: Tuple[Path, str], tar: Tuple[Path, str]) -> bool:
|
||||
"""
|
||||
判断源目录和目标目录是否在同一存储盘
|
||||
|
||||
:param src: 源目录路径和存储类型
|
||||
:param tar: 目标目录路径和存储类型
|
||||
:return: 是否在同一存储盘
|
||||
"""
|
||||
src_path, src_storage = src
|
||||
tar_path, tar_storage = tar
|
||||
if "local" == tar_storage == src_storage:
|
||||
return SystemUtils.is_same_disk(src_path, tar_path)
|
||||
# 网络存储,直接比较类型
|
||||
return src_storage == tar_storage
|
||||
|
||||
@staticmethod
|
||||
def get_media_root_path(rename_format: str, rename_path: Path) -> Optional[Path]:
|
||||
"""
|
||||
|
||||
@@ -14,10 +14,8 @@ from threading import Lock
|
||||
from typing import Dict, Optional
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.event import Event, eventmanager
|
||||
from app.log import logger
|
||||
from app.schemas import ConfigChangeEventData
|
||||
from app.schemas.types import EventType
|
||||
from app.utils.mixins import ConfigReloadMixin
|
||||
from app.utils.singleton import Singleton
|
||||
|
||||
# 定义一个全局线程池执行器
|
||||
@@ -69,25 +67,23 @@ def enable_doh(enable: bool):
|
||||
socket.getaddrinfo = _orig_getaddrinfo
|
||||
|
||||
|
||||
class DohHelper(metaclass=Singleton):
|
||||
class DohHelper(ConfigReloadMixin, metaclass=Singleton):
|
||||
"""
|
||||
DoH帮助类,用于处理DNS over HTTPS解析。
|
||||
"""
|
||||
CONFIG_WATCH = {"DOH_ENABLE", "DOH_DOMAINS", "DOH_RESOLVERS"}
|
||||
|
||||
def __init__(self):
|
||||
enable_doh(settings.DOH_ENABLE)
|
||||
|
||||
@eventmanager.register(EventType.ConfigChanged)
|
||||
def handle_config_changed(self, event: Event):
|
||||
if not event:
|
||||
return
|
||||
event_data: ConfigChangeEventData = event.event_data
|
||||
if event_data.key not in ["DOH_ENABLE", "DOH_DOMAINS", "DOH_RESOLVERS"]:
|
||||
return
|
||||
def on_config_changed(self):
|
||||
with _doh_lock:
|
||||
# DOH配置有变动的情况下,清空缓存
|
||||
_doh_cache.clear()
|
||||
enable_doh(settings.DOH_ENABLE)
|
||||
|
||||
def get_reload_name(self):
|
||||
return 'DoH'
|
||||
|
||||
def _doh_query(resolver: str, host: str) -> Optional[str]:
|
||||
"""
|
||||
|
||||
@@ -1,10 +1,17 @@
|
||||
import io
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from app.chain.mediaserver import MediaServerChain
|
||||
from app.chain.tmdb import TmdbChain
|
||||
from app.core.cache import cached
|
||||
from app.core.cache import cached, FileCache, AsyncFileCache
|
||||
from app.core.config import settings
|
||||
from app.utils.http import RequestUtils
|
||||
from app.log import logger
|
||||
from app.utils.http import RequestUtils, AsyncRequestUtils
|
||||
from app.utils.ip import IpUtils
|
||||
from app.utils.security import SecurityUtils
|
||||
from app.utils.singleton import Singleton
|
||||
|
||||
|
||||
@@ -161,3 +168,121 @@ class WallpaperHelper(metaclass=Singleton):
|
||||
return wallpaper_list
|
||||
else:
|
||||
return []
|
||||
|
||||
|
||||
class ImageHelper(metaclass=Singleton):
|
||||
|
||||
def __init__(self):
|
||||
_base_path = settings.CACHE_PATH
|
||||
_ttl = settings.GLOBAL_IMAGE_CACHE_DAYS * 24 * 3600
|
||||
self.file_cache = FileCache(base=_base_path, ttl=_ttl)
|
||||
self.async_file_cache = AsyncFileCache(base=_base_path, ttl=_ttl)
|
||||
|
||||
@staticmethod
|
||||
def _prepare_cache_path(url: str) -> str:
|
||||
"""缓存路径"""
|
||||
sanitized_path = SecurityUtils.sanitize_url_path(url)
|
||||
cache_path = Path(sanitized_path)
|
||||
if not cache_path.suffix:
|
||||
cache_path = cache_path.with_suffix(".jpg")
|
||||
return cache_path.as_posix()
|
||||
|
||||
@staticmethod
|
||||
def _validate_image(content: bytes) -> bool:
|
||||
"""验证图片"""
|
||||
if not content:
|
||||
return False
|
||||
try:
|
||||
Image.open(io.BytesIO(content)).verify()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warn(f"Invalid image format: {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _get_request_params(url: str, proxy: Optional[bool], cookies: Optional[str | dict]) -> dict:
|
||||
"""获取参数"""
|
||||
referer = "https://movie.douban.com/" if "doubanio.com" in url else None
|
||||
if proxy is None:
|
||||
proxies = settings.PROXY if not (referer or IpUtils.is_internal(url)) else None
|
||||
else:
|
||||
proxies = settings.PROXY if proxy else None
|
||||
return {
|
||||
"ua": settings.NORMAL_USER_AGENT,
|
||||
"proxies": proxies,
|
||||
"referer": referer,
|
||||
"cookies": cookies,
|
||||
"accept_type": "image/avif,image/webp,image/apng,*/*",
|
||||
}
|
||||
|
||||
def fetch_image(
|
||||
self,
|
||||
url: str,
|
||||
proxy: Optional[bool] = None,
|
||||
use_cache: bool = True,
|
||||
cookies: Optional[str | dict] = None) -> Optional[bytes]:
|
||||
"""
|
||||
获取图片(同步版本)
|
||||
"""
|
||||
if not url:
|
||||
return None
|
||||
|
||||
cache_path = self._prepare_cache_path(url)
|
||||
|
||||
# 检查缓存
|
||||
if use_cache:
|
||||
content = self.file_cache.get(cache_path, region="images")
|
||||
if content:
|
||||
return content
|
||||
|
||||
# 请求远程图片
|
||||
params = self._get_request_params(url, proxy, cookies)
|
||||
response = RequestUtils(**params).get_res(url=url)
|
||||
if not response:
|
||||
logger.warn(f"Failed to fetch image from URL: {url}")
|
||||
return None
|
||||
|
||||
content = response.content
|
||||
# 验证图片
|
||||
if not self._validate_image(content):
|
||||
return None
|
||||
|
||||
# 保存缓存
|
||||
self.file_cache.set(cache_path, content, region="images")
|
||||
return content
|
||||
|
||||
async def async_fetch_image(
|
||||
self,
|
||||
url: str,
|
||||
proxy: Optional[bool] = None,
|
||||
use_cache: bool = True,
|
||||
cookies: Optional[str | dict] = None) -> Optional[bytes]:
|
||||
"""
|
||||
获取图片(异步版本)
|
||||
"""
|
||||
if not url:
|
||||
return None
|
||||
|
||||
cache_path = self._prepare_cache_path(url)
|
||||
|
||||
# 检查缓存
|
||||
if use_cache:
|
||||
content = await self.async_file_cache.get(cache_path, region="images")
|
||||
if content:
|
||||
return content
|
||||
|
||||
# 请求远程图片
|
||||
params = self._get_request_params(url, proxy, cookies)
|
||||
response = await AsyncRequestUtils(**params).get_res(url=url)
|
||||
if not response:
|
||||
logger.warn(f"Failed to fetch image from URL: {url}")
|
||||
return None
|
||||
|
||||
content = response.content
|
||||
# 验证图片
|
||||
if not self._validate_image(content):
|
||||
return None
|
||||
|
||||
# 保存缓存
|
||||
await self.async_file_cache.set(cache_path, content, region="images")
|
||||
return content
|
||||
44
app/helper/llm.py
Normal file
44
app/helper/llm.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""LLM模型相关辅助功能"""
|
||||
from typing import List
|
||||
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class LLMHelper:
|
||||
"""LLM模型相关辅助功能"""
|
||||
|
||||
def get_models(self, provider: str, api_key: str, base_url: str = None) -> List[str]:
|
||||
"""获取模型列表"""
|
||||
logger.info(f"获取 {provider} 模型列表...")
|
||||
if provider == "google":
|
||||
return self._get_google_models(api_key)
|
||||
else:
|
||||
return self._get_openai_compatible_models(provider, api_key, base_url)
|
||||
|
||||
@staticmethod
|
||||
def _get_google_models(api_key: str) -> List[str]:
|
||||
"""获取Google模型列表"""
|
||||
try:
|
||||
import google.generativeai as genai
|
||||
genai.configure(api_key=api_key)
|
||||
models = genai.list_models()
|
||||
return [m.name for m in models if 'generateContent' in m.supported_generation_methods]
|
||||
except Exception as e:
|
||||
logger.error(f"获取Google模型列表失败:{e}")
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def _get_openai_compatible_models(provider: str, api_key: str, base_url: str = None) -> List[str]:
|
||||
"""获取OpenAI兼容模型列表"""
|
||||
try:
|
||||
from openai import OpenAI
|
||||
|
||||
if provider == "deepseek":
|
||||
base_url = base_url or "https://api.deepseek.com"
|
||||
|
||||
client = OpenAI(api_key=api_key, base_url=base_url)
|
||||
models = client.models.list()
|
||||
return [model.id for model in models.data]
|
||||
except Exception as e:
|
||||
logger.error(f"获取 {provider} 模型列表失败:{e}")
|
||||
raise e
|
||||
@@ -7,10 +7,8 @@ import redis
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.event import eventmanager, Event
|
||||
from app.log import logger
|
||||
from app.schemas import ConfigChangeEventData
|
||||
from app.schemas.types import EventType
|
||||
from app.utils.mixins import ConfigReloadMixin
|
||||
from app.utils.singleton import Singleton
|
||||
|
||||
# 类型缓存集合,针对非容器简单类型
|
||||
@@ -74,16 +72,17 @@ def deserialize(value: bytes) -> Any:
|
||||
raise ValueError("Unknown serialization format")
|
||||
|
||||
|
||||
class RedisHelper(metaclass=Singleton):
|
||||
class RedisHelper(ConfigReloadMixin, metaclass=Singleton):
|
||||
"""
|
||||
Redis连接和操作助手类,单例模式
|
||||
|
||||
|
||||
特性:
|
||||
- 管理Redis连接池和客户端
|
||||
- 提供序列化和反序列化功能
|
||||
- 支持内存限制和淘汰策略设置
|
||||
- 提供键名生成和区域管理功能
|
||||
"""
|
||||
CONFIG_WATCH = {"CACHE_BACKEND_TYPE", "CACHE_BACKEND_URL", "CACHE_REDIS_MAXMEMORY"}
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
@@ -114,25 +113,17 @@ class RedisHelper(metaclass=Singleton):
|
||||
self.client = None
|
||||
raise RuntimeError("Redis connection failed") from e
|
||||
|
||||
@eventmanager.register(EventType.ConfigChanged)
|
||||
def handle_config_changed(self, event: Event):
|
||||
"""
|
||||
处理配置变更事件,更新Redis设置
|
||||
:param event: 事件对象
|
||||
"""
|
||||
if not event:
|
||||
return
|
||||
event_data: ConfigChangeEventData = event.event_data
|
||||
if event_data.key not in ['CACHE_BACKEND_TYPE', 'CACHE_BACKEND_URL', 'CACHE_REDIS_MAXMEMORY']:
|
||||
return
|
||||
logger.info("配置变更,重连Redis...")
|
||||
def on_config_changed(self):
|
||||
self.close()
|
||||
self._connect()
|
||||
|
||||
def get_reload_name(self):
|
||||
return "Redis"
|
||||
|
||||
def set_memory_limit(self, policy: Optional[str] = "allkeys-lru"):
|
||||
"""
|
||||
动态设置Redis最大内存和内存淘汰策略
|
||||
|
||||
|
||||
:param policy: 淘汰策略(如'allkeys-lru')
|
||||
"""
|
||||
try:
|
||||
@@ -310,10 +301,10 @@ class RedisHelper(metaclass=Singleton):
|
||||
logger.debug("Redis connection closed")
|
||||
|
||||
|
||||
class AsyncRedisHelper(metaclass=Singleton):
|
||||
class AsyncRedisHelper(ConfigReloadMixin, metaclass=Singleton):
|
||||
"""
|
||||
异步Redis连接和操作助手类,单例模式
|
||||
|
||||
|
||||
特性:
|
||||
- 管理异步Redis连接池和客户端
|
||||
- 提供序列化和反序列化功能
|
||||
@@ -321,6 +312,7 @@ class AsyncRedisHelper(metaclass=Singleton):
|
||||
- 提供键名生成和区域管理功能
|
||||
- 所有操作都是异步的
|
||||
"""
|
||||
CONFIG_WATCH = {"CACHE_BACKEND_TYPE", "CACHE_BACKEND_URL", "CACHE_REDIS_MAXMEMORY"}
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
@@ -351,25 +343,17 @@ class AsyncRedisHelper(metaclass=Singleton):
|
||||
self.client = None
|
||||
raise RuntimeError("Redis async connection failed") from e
|
||||
|
||||
@eventmanager.register(EventType.ConfigChanged)
|
||||
async def handle_config_changed(self, event: Event):
|
||||
"""
|
||||
处理配置变更事件,更新Redis设置
|
||||
:param event: 事件对象
|
||||
"""
|
||||
if not event:
|
||||
return
|
||||
event_data: ConfigChangeEventData = event.event_data
|
||||
if event_data.key not in ['CACHE_BACKEND_TYPE', 'CACHE_BACKEND_URL', 'CACHE_REDIS_MAXMEMORY']:
|
||||
return
|
||||
logger.info("配置变更,重连Redis (async)...")
|
||||
async def on_config_changed(self):
|
||||
await self.close()
|
||||
await self._connect()
|
||||
|
||||
def get_reload_name(self):
|
||||
return "Redis (async)"
|
||||
|
||||
async def set_memory_limit(self, policy: Optional[str] = "allkeys-lru"):
|
||||
"""
|
||||
动态设置Redis最大内存和内存淘汰策略
|
||||
|
||||
|
||||
:param policy: 淘汰策略(如'allkeys-lru')
|
||||
"""
|
||||
try:
|
||||
|
||||
@@ -8,35 +8,32 @@ from typing import Tuple
|
||||
import docker
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.event import eventmanager, Event
|
||||
from app.log import logger
|
||||
from app.schemas import ConfigChangeEventData
|
||||
from app.schemas.types import EventType
|
||||
from app.utils.mixins import ConfigReloadMixin
|
||||
from app.utils.system import SystemUtils
|
||||
|
||||
|
||||
class SystemHelper:
|
||||
class SystemHelper(ConfigReloadMixin):
|
||||
"""
|
||||
系统工具类,提供系统相关的操作和判断
|
||||
"""
|
||||
CONFIG_WATCH = {
|
||||
"DEBUG",
|
||||
"LOG_LEVEL",
|
||||
"LOG_MAX_FILE_SIZE",
|
||||
"LOG_BACKUP_COUNT",
|
||||
"LOG_FILE_FORMAT",
|
||||
"LOG_CONSOLE_FORMAT",
|
||||
}
|
||||
|
||||
__system_flag_file = "/var/log/nginx/__moviepilot__"
|
||||
|
||||
@eventmanager.register(EventType.ConfigChanged)
|
||||
def handle_config_changed(self, event: Event):
|
||||
"""
|
||||
处理配置变更事件,更新日志设置
|
||||
:param event: 事件对象
|
||||
"""
|
||||
if not event:
|
||||
return
|
||||
event_data: ConfigChangeEventData = event.event_data
|
||||
if event_data.key not in ['DEBUG', 'LOG_LEVEL', 'LOG_MAX_FILE_SIZE', 'LOG_BACKUP_COUNT',
|
||||
'LOG_FILE_FORMAT', 'LOG_CONSOLE_FORMAT']:
|
||||
return
|
||||
logger.info("配置变更,更新日志设置...")
|
||||
def on_config_changed(self):
|
||||
logger.update_loggers()
|
||||
|
||||
def get_reload_name(self):
|
||||
return "日志设置"
|
||||
|
||||
@staticmethod
|
||||
def can_restart() -> bool:
|
||||
"""
|
||||
|
||||
@@ -6,8 +6,7 @@ from urllib.parse import unquote
|
||||
|
||||
from torrentool.api import Torrent
|
||||
|
||||
from app.core.cache import FileCache
|
||||
from app.core.cache import TTLCache
|
||||
from app.core.cache import TTLCache, FileCache
|
||||
from app.core.config import settings
|
||||
from app.core.context import Context, TorrentInfo, MediaInfo
|
||||
from app.core.meta import MetaBase
|
||||
|
||||
@@ -1,18 +1,26 @@
|
||||
from abc import abstractmethod, ABCMeta
|
||||
from typing import Generic, Tuple, Union, TypeVar, Type, Dict, Optional, Callable
|
||||
from pathlib import Path
|
||||
|
||||
from app.helper.service import ServiceConfigHelper
|
||||
from app.schemas import Notification, NotificationConf, MediaServerConf, DownloaderConf
|
||||
from app.schemas.types import ModuleType, DownloaderType, MediaServerType, MessageChannel, StorageSchema, \
|
||||
OtherModulesType
|
||||
OtherModulesType, SystemConfigKey
|
||||
from app.utils.mixins import ConfigReloadMixin
|
||||
|
||||
|
||||
class _ModuleBase(metaclass=ABCMeta):
|
||||
class _ModuleBase(ConfigReloadMixin, metaclass=ABCMeta):
|
||||
"""
|
||||
模块基类,实现对应方法,在有需要时会被自动调用,返回None代表不启用该模块,将继续执行下一模块
|
||||
输入参数与输出参数一致的,或没有输出的,可以被多个模块重复实现
|
||||
"""
|
||||
|
||||
def on_config_changed(self):
|
||||
self.init_module()
|
||||
|
||||
def get_reload_name(self):
|
||||
return self.get_name()
|
||||
|
||||
@abstractmethod
|
||||
def init_module(self) -> None:
|
||||
"""
|
||||
@@ -177,6 +185,7 @@ class _MessageBase(ServiceBase[TService, NotificationConf]):
|
||||
"""
|
||||
消息基类
|
||||
"""
|
||||
CONFIG_WATCH = {SystemConfigKey.Notifications.value}
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
@@ -224,6 +233,7 @@ class _DownloaderBase(ServiceBase[TService, DownloaderConf]):
|
||||
"""
|
||||
下载器基类
|
||||
"""
|
||||
CONFIG_WATCH = {SystemConfigKey.Downloaders.value}
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
@@ -281,12 +291,37 @@ class _DownloaderBase(ServiceBase[TService, DownloaderConf]):
|
||||
重置默认配置名称
|
||||
"""
|
||||
self._default_config_name = None
|
||||
|
||||
def normalize_path(self, path: Path, downloader: Optional[str]) -> str:
|
||||
"""
|
||||
根据下载器配置和路径映射,规范化下载路径
|
||||
|
||||
:param path: 存储路径
|
||||
:param downloader: 下载器名称
|
||||
:return: 规范化后发送给下载器的路径
|
||||
"""
|
||||
dir = path.as_posix()
|
||||
conf = self.get_config(downloader)
|
||||
if conf and conf.path_mapping:
|
||||
for (storage_path, download_path) in conf.path_mapping:
|
||||
storage_path = Path(storage_path.strip()).as_posix()
|
||||
download_path = Path(download_path.strip()).as_posix()
|
||||
if dir.startswith(storage_path):
|
||||
dir = dir.replace(storage_path, download_path, 1)
|
||||
break
|
||||
# 去掉存储协议前缀 if any, 下载器无法识别
|
||||
for s in StorageSchema:
|
||||
prefix = f"{s.value}:"
|
||||
if dir.startswith(prefix):
|
||||
return dir[len(prefix):]
|
||||
return dir
|
||||
|
||||
|
||||
class _MediaServerBase(ServiceBase[TService, MediaServerConf]):
|
||||
"""
|
||||
媒体服务器基类
|
||||
"""
|
||||
CONFIG_WATCH = {SystemConfigKey.MediaServers.value}
|
||||
|
||||
def get_configs(self) -> Dict[str, MediaServerConf]:
|
||||
"""
|
||||
|
||||
@@ -2,11 +2,11 @@ from typing import Any, Generator, List, Optional, Tuple, Union
|
||||
|
||||
from app import schemas
|
||||
from app.core.context import MediaInfo
|
||||
from app.core.event import eventmanager, Event
|
||||
from app.core.event import eventmanager
|
||||
from app.log import logger
|
||||
from app.modules import _MediaServerBase, _ModuleBase
|
||||
from app.modules.emby.emby import Emby
|
||||
from app.schemas.types import MediaType, ModuleType, ChainEventType, MediaServerType, SystemConfigKey, EventType
|
||||
from app.schemas.types import MediaType, ModuleType, ChainEventType, MediaServerType
|
||||
|
||||
|
||||
class EmbyModule(_ModuleBase, _MediaServerBase[Emby]):
|
||||
@@ -18,20 +18,6 @@ class EmbyModule(_ModuleBase, _MediaServerBase[Emby]):
|
||||
super().init_service(service_name=Emby.__name__.lower(),
|
||||
service_type=lambda conf: Emby(**conf.config, sync_libraries=conf.sync_libraries))
|
||||
|
||||
@eventmanager.register(EventType.ConfigChanged)
|
||||
def handle_config_changed(self, event: Event):
|
||||
"""
|
||||
处理配置变更事件
|
||||
:param event: 事件对象
|
||||
"""
|
||||
if not event:
|
||||
return
|
||||
event_data: schemas.ConfigChangeEventData = event.event_data
|
||||
if event_data.key not in [SystemConfigKey.MediaServers.value]:
|
||||
return
|
||||
logger.info("配置变更,重新初始化Emby模块...")
|
||||
self.init_module()
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "Emby"
|
||||
|
||||
@@ -640,7 +640,7 @@ class Emby:
|
||||
item_type=item.get("Type"),
|
||||
title=item.get("Name"),
|
||||
original_title=item.get("OriginalTitle"),
|
||||
year=str(item.get("ProductionYear")),
|
||||
year=item.get("ProductionYear"),
|
||||
tmdbid=int(tmdbid) if tmdbid else None,
|
||||
imdbid=item.get("ProviderIds", {}).get("Imdb"),
|
||||
tvdbid=item.get("ProviderIds", {}).get("Tvdb"),
|
||||
|
||||
@@ -37,6 +37,7 @@ class SMB(StorageBase, metaclass=WeakSingleton):
|
||||
transtype = {
|
||||
"move": "移动",
|
||||
"copy": "复制",
|
||||
"link": "硬链接",
|
||||
}
|
||||
|
||||
# 文件块大小,默认10MB
|
||||
@@ -635,7 +636,39 @@ class SMB(StorageBase, metaclass=WeakSingleton):
|
||||
return False
|
||||
|
||||
def link(self, fileitem: schemas.FileItem, target_file: Path) -> bool:
|
||||
pass
|
||||
"""
|
||||
硬链接文件
|
||||
Samba服务器需要开启 unix extensions 支持
|
||||
"""
|
||||
try:
|
||||
self._check_connection()
|
||||
src_path = self._normalize_path(fileitem.path)
|
||||
dst_path = self._normalize_path(target_file)
|
||||
|
||||
# 检查源文件是否存在
|
||||
if not smbclient.path.exists(src_path):
|
||||
raise FileNotFoundError(f"源文件不存在: {src_path}")
|
||||
|
||||
# 确保目标路径的父目录存在
|
||||
dst_parent = "\\".join(dst_path.rsplit("\\", 1)[:-1])
|
||||
if dst_parent and not smbclient.path.exists(dst_parent):
|
||||
logger.info(f"【SMB】创建目标目录: {dst_parent}")
|
||||
smbclient.makedirs(dst_parent, exist_ok=True)
|
||||
|
||||
# 尝试创建硬链接
|
||||
smbclient.link(src_path, dst_path)
|
||||
logger.info(f"【SMB】硬链接创建成功: {src_path} -> {dst_path}")
|
||||
return True
|
||||
|
||||
except SMBResponseException as e:
|
||||
# SMB协议错误,可能不支持硬链接
|
||||
logger.error(f"【SMB】创建硬链接失败(当前Samba服务器可能不支持硬链接): {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"【SMB】创建硬链接失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
|
||||
def softlink(self, fileitem: schemas.FileItem, target_file: Path) -> bool:
|
||||
pass
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import secrets
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from threading import Lock
|
||||
from typing import List, Optional, Tuple, Union, Dict
|
||||
from hashlib import sha256
|
||||
|
||||
import oss2
|
||||
import requests
|
||||
import httpx
|
||||
from oss2 import SizedFileAdapter, determine_part_size
|
||||
from oss2.models import PartInfo
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
|
||||
from app import schemas
|
||||
from app.core.config import settings, global_vars
|
||||
@@ -19,8 +20,10 @@ from app.modules.filemanager.storages import transfer_process
|
||||
from app.schemas.types import StorageSchema
|
||||
from app.utils.singleton import WeakSingleton
|
||||
from app.utils.string import StringUtils
|
||||
from app.utils.limit import QpsRateLimiter
|
||||
|
||||
lock = threading.Lock()
|
||||
|
||||
lock = Lock()
|
||||
|
||||
|
||||
class NoCheckInException(Exception):
|
||||
@@ -36,10 +39,7 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
schema = StorageSchema.U115
|
||||
|
||||
# 支持的整理方式
|
||||
transtype = {
|
||||
"move": "移动",
|
||||
"copy": "复制"
|
||||
}
|
||||
transtype = {"move": "移动", "copy": "复制"}
|
||||
# 基础url
|
||||
base_url = "https://proapi.115.com"
|
||||
|
||||
@@ -52,18 +52,28 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._auth_state = {}
|
||||
self.session = requests.Session()
|
||||
self.session = httpx.Client(follow_redirects=True, timeout=20.0)
|
||||
self._init_session()
|
||||
self.qps_limiter: Dict[str, QpsRateLimiter] = {
|
||||
"/open/ufile/files": QpsRateLimiter(4),
|
||||
"/open/folder/get_info": QpsRateLimiter(3),
|
||||
"/open/ufile/move": QpsRateLimiter(2),
|
||||
"/open/ufile/copy": QpsRateLimiter(2),
|
||||
"/open/ufile/update": QpsRateLimiter(2),
|
||||
"/open/ufile/delete": QpsRateLimiter(2),
|
||||
}
|
||||
|
||||
def _init_session(self):
|
||||
"""
|
||||
初始化带速率限制的会话
|
||||
"""
|
||||
self.session.headers.update({
|
||||
"User-Agent": "W115Storage/2.0",
|
||||
"Accept-Encoding": "gzip, deflate",
|
||||
"Content-Type": "application/x-www-form-urlencoded"
|
||||
})
|
||||
self.session.headers.update(
|
||||
{
|
||||
"User-Agent": "W115Storage/2.0",
|
||||
"Accept-Encoding": "gzip, deflate",
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
}
|
||||
)
|
||||
|
||||
def _check_session(self):
|
||||
"""
|
||||
@@ -87,10 +97,7 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
if expires_in and refresh_time + expires_in < int(time.time()):
|
||||
tokens = self.__refresh_access_token(refresh_token)
|
||||
if tokens:
|
||||
self.set_config({
|
||||
"refresh_time": int(time.time()),
|
||||
**tokens
|
||||
})
|
||||
self.set_config({"refresh_time": int(time.time()), **tokens})
|
||||
else:
|
||||
return None
|
||||
access_token = tokens.get("access_token")
|
||||
@@ -105,7 +112,7 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
# 生成PKCE参数
|
||||
code_verifier = secrets.token_urlsafe(96)[:128]
|
||||
code_challenge = base64.b64encode(
|
||||
hashlib.sha256(code_verifier.encode("utf-8")).digest()
|
||||
sha256(code_verifier.encode("utf-8")).digest()
|
||||
).decode("utf-8")
|
||||
# 请求设备码
|
||||
resp = self.session.post(
|
||||
@@ -113,8 +120,8 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
data={
|
||||
"client_id": settings.U115_APP_ID,
|
||||
"code_challenge": code_challenge,
|
||||
"code_challenge_method": "sha256"
|
||||
}
|
||||
"code_challenge_method": "sha256",
|
||||
},
|
||||
)
|
||||
if resp is None:
|
||||
return {}, "网络错误"
|
||||
@@ -126,13 +133,11 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
"code_verifier": code_verifier,
|
||||
"uid": result["data"]["uid"],
|
||||
"time": result["data"]["time"],
|
||||
"sign": result["data"]["sign"]
|
||||
"sign": result["data"]["sign"],
|
||||
}
|
||||
|
||||
# 生成二维码内容
|
||||
return {
|
||||
"codeContent": result['data']['qrcode']
|
||||
}, ""
|
||||
return {"codeContent": result["data"]["qrcode"]}, ""
|
||||
|
||||
def check_login(self) -> Optional[Tuple[dict, str]]:
|
||||
"""
|
||||
@@ -146,8 +151,8 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
params={
|
||||
"uid": self._auth_state["uid"],
|
||||
"time": self._auth_state["time"],
|
||||
"sign": self._auth_state["sign"]
|
||||
}
|
||||
"sign": self._auth_state["sign"],
|
||||
},
|
||||
)
|
||||
if resp is None:
|
||||
return {}, "网络错误"
|
||||
@@ -156,11 +161,11 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
return {}, result.get("message")
|
||||
if result["data"]["status"] == 2:
|
||||
tokens = self.__get_access_token()
|
||||
self.set_config({
|
||||
"refresh_time": int(time.time()),
|
||||
**tokens
|
||||
})
|
||||
return {"status": result["data"]["status"], "tip": result["data"]["msg"]}, ""
|
||||
self.set_config({"refresh_time": int(time.time()), **tokens})
|
||||
return {
|
||||
"status": result["data"]["status"],
|
||||
"tip": result["data"]["msg"],
|
||||
}, ""
|
||||
except Exception as e:
|
||||
return {}, str(e)
|
||||
|
||||
@@ -174,8 +179,8 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
"https://passportapi.115.com/open/deviceCodeToToken",
|
||||
data={
|
||||
"uid": self._auth_state["uid"],
|
||||
"code_verifier": self._auth_state["code_verifier"]
|
||||
}
|
||||
"code_verifier": self._auth_state["code_verifier"],
|
||||
},
|
||||
)
|
||||
if resp is None:
|
||||
raise Exception("获取 access_token 失败")
|
||||
@@ -190,21 +195,24 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
"""
|
||||
resp = self.session.post(
|
||||
"https://passportapi.115.com/open/refreshToken",
|
||||
data={
|
||||
"refresh_token": refresh_token
|
||||
}
|
||||
data={"refresh_token": refresh_token},
|
||||
)
|
||||
if resp is None:
|
||||
logger.error(f"【115】刷新 access_token 失败:refresh_token={refresh_token}")
|
||||
logger.error(
|
||||
f"【115】刷新 access_token 失败:refresh_token={refresh_token}"
|
||||
)
|
||||
return None
|
||||
result = resp.json()
|
||||
if result.get("code") != 0:
|
||||
logger.warn(f"【115】刷新 access_token 失败:{result.get('code')} - {result.get('message')}!")
|
||||
logger.warn(
|
||||
f"【115】刷新 access_token 失败:{result.get('code')} - {result.get('message')}!"
|
||||
)
|
||||
return None
|
||||
return result.get("data")
|
||||
|
||||
def _request_api(self, method: str, endpoint: str,
|
||||
result_key: Optional[str] = None, **kwargs) -> Optional[Union[dict, list]]:
|
||||
def _request_api(
|
||||
self, method: str, endpoint: str, result_key: Optional[str] = None, **kwargs
|
||||
) -> Optional[Union[dict, list]]:
|
||||
"""
|
||||
带错误处理和速率限制的API请求
|
||||
"""
|
||||
@@ -216,12 +224,13 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
# 重试次数
|
||||
retry_times = kwargs.pop("retry_limit", 5)
|
||||
|
||||
# qps 速率限制
|
||||
if endpoint in self.qps_limiter:
|
||||
self.qps_limiter[endpoint].acquire()
|
||||
|
||||
try:
|
||||
resp = self.session.request(
|
||||
method, f"{self.base_url}{endpoint}",
|
||||
**kwargs
|
||||
)
|
||||
except requests.exceptions.RequestException as e:
|
||||
resp = self.session.request(method, f"{self.base_url}{endpoint}", **kwargs)
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"【115】{method} 请求 {endpoint} 网络错误: {str(e)}")
|
||||
return None
|
||||
|
||||
@@ -241,7 +250,21 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
return self._request_api(method, endpoint, result_key, **kwargs)
|
||||
|
||||
# 处理请求错误
|
||||
resp.raise_for_status()
|
||||
try:
|
||||
resp.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
if retry_times <= 0:
|
||||
logger.error(
|
||||
f"【115】{method} 请求 {endpoint} 错误 {e},重试次数用尽!"
|
||||
)
|
||||
return None
|
||||
kwargs["retry_limit"] = retry_times - 1
|
||||
sleep_duration = 2 ** (5 - retry_times + 1)
|
||||
logger.info(
|
||||
f"【115】{method} 请求 {endpoint} 错误 {e},等待 {sleep_duration} 秒后重试..."
|
||||
)
|
||||
time.sleep(sleep_duration)
|
||||
return self._request_api(method, endpoint, result_key, **kwargs)
|
||||
|
||||
# 返回数据
|
||||
ret_data = resp.json()
|
||||
@@ -251,10 +274,14 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
logger.warn(f"【115】{method} 请求 {endpoint} 出错:{error_msg}")
|
||||
if "已达到当前访问上限" in error_msg:
|
||||
if retry_times <= 0:
|
||||
logger.error(f"【115】{method} 请求 {endpoint} 达到访问上限,重试次数用尽!")
|
||||
logger.error(
|
||||
f"【115】{method} 请求 {endpoint} 达到访问上限,重试次数用尽!"
|
||||
)
|
||||
return None
|
||||
kwargs["retry_limit"] = retry_times - 1
|
||||
logger.info(f"【115】{method} 请求 {endpoint} 达到访问上限,等待 {self.retry_delay} 秒后重试...")
|
||||
logger.info(
|
||||
f"【115】{method} 请求 {endpoint} 达到访问上限,等待 {self.retry_delay} 秒后重试..."
|
||||
)
|
||||
time.sleep(self.retry_delay)
|
||||
return self._request_api(method, endpoint, result_key, **kwargs)
|
||||
return None
|
||||
@@ -269,26 +296,15 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
计算文件SHA1(符合115规范)
|
||||
size: 前多少字节
|
||||
"""
|
||||
sha1 = hashlib.sha1()
|
||||
with open(filepath, 'rb') as f:
|
||||
sha1 = hashes.Hash(hashes.SHA1())
|
||||
with open(filepath, "rb") as f:
|
||||
if size:
|
||||
chunk = f.read(size)
|
||||
sha1.update(chunk)
|
||||
else:
|
||||
while chunk := f.read(8192):
|
||||
sha1.update(chunk)
|
||||
return sha1.hexdigest()
|
||||
|
||||
def _delay_get_item(self, path: Path) -> Optional[schemas.FileItem]:
|
||||
"""
|
||||
自动延迟重试 get_item 模块
|
||||
"""
|
||||
for i in range(1, 4):
|
||||
time.sleep(2 ** i)
|
||||
fileitem = self.get_item(path)
|
||||
if fileitem:
|
||||
return fileitem
|
||||
return None
|
||||
return sha1.finalize().hex()
|
||||
|
||||
def init_storage(self):
|
||||
pass
|
||||
@@ -304,7 +320,7 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
return [item]
|
||||
return []
|
||||
if fileitem.path == "/":
|
||||
cid = '0'
|
||||
cid = "0"
|
||||
else:
|
||||
cid = fileitem.fileid
|
||||
if not cid:
|
||||
@@ -322,29 +338,37 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
"GET",
|
||||
"/open/ufile/files",
|
||||
"data",
|
||||
params={"cid": int(cid), "limit": 1000, "offset": offset, "cur": True, "show_dir": 1}
|
||||
params={
|
||||
"cid": int(cid),
|
||||
"limit": 1000,
|
||||
"offset": offset,
|
||||
"cur": True,
|
||||
"show_dir": 1,
|
||||
},
|
||||
)
|
||||
if resp is None:
|
||||
raise FileNotFoundError(f"【115】{fileitem.path} 检索出错!")
|
||||
if not resp:
|
||||
break
|
||||
for item in resp:
|
||||
# 更新缓存
|
||||
path = f"{fileitem.path}{item['fn']}"
|
||||
file_path = path + ("/" if item["fc"] == "0" else "")
|
||||
items.append(schemas.FileItem(
|
||||
storage=self.schema.value,
|
||||
fileid=str(item["fid"]),
|
||||
parent_fileid=cid,
|
||||
name=item["fn"],
|
||||
basename=Path(item["fn"]).stem,
|
||||
extension=item["ico"] if item["fc"] == "1" else None,
|
||||
type="dir" if item["fc"] == "0" else "file",
|
||||
path=file_path,
|
||||
size=item["fs"] if item["fc"] == "1" else None,
|
||||
modify_time=item["upt"],
|
||||
pickcode=item["pc"]
|
||||
))
|
||||
parent_path = Path(fileitem.path) # noqa
|
||||
item_name = item["fn"]
|
||||
full_path = parent_path / item_name
|
||||
items.append(
|
||||
schemas.FileItem(
|
||||
storage=self.schema.value,
|
||||
fileid=str(item["fid"]),
|
||||
parent_fileid=cid,
|
||||
name=item["fn"],
|
||||
basename=Path(item["fn"]).stem,
|
||||
extension=item["ico"] if item["fc"] == "1" else None,
|
||||
type="dir" if item["fc"] == "0" else "file",
|
||||
path=full_path.as_posix() + ("/" if item["fc"] == "0" else ""),
|
||||
size=item["fs"] if item["fc"] == "1" else None,
|
||||
modify_time=item["upt"],
|
||||
pickcode=item["pc"],
|
||||
)
|
||||
)
|
||||
|
||||
if len(resp) < 1000:
|
||||
break
|
||||
@@ -352,7 +376,9 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
|
||||
return items
|
||||
|
||||
def create_folder(self, parent_item: schemas.FileItem, name: str) -> Optional[schemas.FileItem]:
|
||||
def create_folder(
|
||||
self, parent_item: schemas.FileItem, name: str
|
||||
) -> Optional[schemas.FileItem]:
|
||||
"""
|
||||
创建目录
|
||||
"""
|
||||
@@ -360,10 +386,7 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
resp = self._request_api(
|
||||
"POST",
|
||||
"/open/folder/add",
|
||||
data={
|
||||
"pid": int(parent_item.fileid or "0"),
|
||||
"file_name": name
|
||||
}
|
||||
data={"pid": int(parent_item.fileid or "0"), "file_name": name},
|
||||
)
|
||||
if not resp:
|
||||
return None
|
||||
@@ -376,15 +399,19 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
return schemas.FileItem(
|
||||
storage=self.schema.value,
|
||||
fileid=str(resp["data"]["file_id"]),
|
||||
path=str(new_path) + "/",
|
||||
path=new_path.as_posix() + "/",
|
||||
name=name,
|
||||
basename=name,
|
||||
type="dir",
|
||||
modify_time=int(time.time())
|
||||
modify_time=int(time.time()),
|
||||
)
|
||||
|
||||
def upload(self, target_dir: schemas.FileItem, local_path: Path,
|
||||
new_name: Optional[str] = None) -> Optional[schemas.FileItem]:
|
||||
def upload(
|
||||
self,
|
||||
target_dir: schemas.FileItem,
|
||||
local_path: Path,
|
||||
new_name: Optional[str] = None,
|
||||
) -> Optional[schemas.FileItem]:
|
||||
"""
|
||||
实现带秒传、断点续传和二次认证的文件上传
|
||||
"""
|
||||
@@ -409,13 +436,9 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
"file_size": file_size,
|
||||
"target": target_param,
|
||||
"fileid": file_sha1,
|
||||
"preid": file_preid
|
||||
"preid": file_preid,
|
||||
}
|
||||
init_resp = self._request_api(
|
||||
"POST",
|
||||
"/open/upload/init",
|
||||
data=init_data
|
||||
)
|
||||
init_resp = self._request_api("POST", "/open/upload/init", data=init_data)
|
||||
if not init_resp:
|
||||
return None
|
||||
if not init_resp.get("state"):
|
||||
@@ -444,19 +467,15 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
# 取2392148-2392298之间的内容(包含2392148、2392298)的sha1
|
||||
f.seek(start)
|
||||
chunk = f.read(end - start + 1)
|
||||
sign_val = hashlib.sha1(chunk).hexdigest().upper()
|
||||
sha1 = hashes.Hash(hashes.SHA1())
|
||||
sha1.update(chunk)
|
||||
sign_val = sha1.finalize().hex().upper()
|
||||
# 重新初始化请求
|
||||
# sign_key,sign_val(根据sign_check计算的值大写的sha1值)
|
||||
init_data.update({
|
||||
"pick_code": pick_code,
|
||||
"sign_key": sign_key,
|
||||
"sign_val": sign_val
|
||||
})
|
||||
init_resp = self._request_api(
|
||||
"POST",
|
||||
"/open/upload/init",
|
||||
data=init_data
|
||||
init_data.update(
|
||||
{"pick_code": pick_code, "sign_key": sign_key, "sign_val": sign_val}
|
||||
)
|
||||
init_resp = self._request_api("POST", "/open/upload/init", data=init_data)
|
||||
if not init_resp:
|
||||
return None
|
||||
if not init_resp.get("state"):
|
||||
@@ -485,32 +504,30 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
"GET",
|
||||
"/open/folder/get_info",
|
||||
"data",
|
||||
params={
|
||||
"file_id": int(file_id)
|
||||
}
|
||||
params={"file_id": int(file_id)},
|
||||
)
|
||||
if info_resp:
|
||||
return schemas.FileItem(
|
||||
storage=self.schema.value,
|
||||
fileid=str(info_resp["file_id"]),
|
||||
path=str(target_path) + ("/" if info_resp["file_category"] == "0" else ""),
|
||||
path=target_path.as_posix()
|
||||
+ ("/" if info_resp["file_category"] == "0" else ""),
|
||||
type="file" if info_resp["file_category"] == "1" else "dir",
|
||||
name=info_resp["file_name"],
|
||||
basename=Path(info_resp["file_name"]).stem,
|
||||
extension=Path(info_resp["file_name"]).suffix[1:] if info_resp[
|
||||
"file_category"] == "1" else None,
|
||||
extension=Path(info_resp["file_name"]).suffix[1:]
|
||||
if info_resp["file_category"] == "1"
|
||||
else None,
|
||||
pickcode=info_resp["pick_code"],
|
||||
size=StringUtils.num_filesize(info_resp['size']) if info_resp["file_category"] == "1" else None,
|
||||
modify_time=info_resp["utime"]
|
||||
size=StringUtils.num_filesize(info_resp["size"])
|
||||
if info_resp["file_category"] == "1"
|
||||
else None,
|
||||
modify_time=info_resp["utime"],
|
||||
)
|
||||
return self._delay_get_item(target_path)
|
||||
return self.get_item(target_path)
|
||||
|
||||
# Step 4: 获取上传凭证
|
||||
token_resp = self._request_api(
|
||||
"GET",
|
||||
"/open/upload/get_token",
|
||||
"data"
|
||||
)
|
||||
token_resp = self._request_api("GET", "/open/upload/get_token", "data")
|
||||
if not token_resp:
|
||||
logger.warn("【115】获取上传凭证失败")
|
||||
return None
|
||||
@@ -530,8 +547,8 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
"file_size": file_size,
|
||||
"target": target_param,
|
||||
"fileid": file_sha1,
|
||||
"pick_code": pick_code
|
||||
}
|
||||
"pick_code": pick_code,
|
||||
},
|
||||
)
|
||||
if resume_resp:
|
||||
logger.debug(f"【115】上传 Step 5 断点续传结果: {resume_resp}")
|
||||
@@ -542,25 +559,25 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
auth = oss2.StsAuth(
|
||||
access_key_id=AccessKeyId,
|
||||
access_key_secret=AccessKeySecret,
|
||||
security_token=SecurityToken
|
||||
security_token=SecurityToken,
|
||||
)
|
||||
bucket = oss2.Bucket(auth, endpoint, bucket_name) # noqa
|
||||
# determine_part_size方法用于确定分片大小,设置分片大小为 10M
|
||||
part_size = determine_part_size(file_size, preferred_size=10 * 1024 * 1024)
|
||||
|
||||
# 初始化进度条
|
||||
logger.info(f"【115】开始上传: {local_path} -> {target_path},分片大小:{StringUtils.str_filesize(part_size)}")
|
||||
logger.info(
|
||||
f"【115】开始上传: {local_path} -> {target_path},分片大小:{StringUtils.str_filesize(part_size)}"
|
||||
)
|
||||
progress_callback = transfer_process(local_path.as_posix())
|
||||
|
||||
# 初始化分片
|
||||
upload_id = bucket.init_multipart_upload(object_name,
|
||||
params={
|
||||
"encoding-type": "url",
|
||||
"sequential": ""
|
||||
}).upload_id
|
||||
upload_id = bucket.init_multipart_upload(
|
||||
object_name, params={"encoding-type": "url", "sequential": ""}
|
||||
).upload_id
|
||||
parts = []
|
||||
# 逐个上传分片
|
||||
with open(local_path, 'rb') as fileobj:
|
||||
with open(local_path, "rb") as fileobj:
|
||||
part_number = 1
|
||||
offset = 0
|
||||
while offset < file_size:
|
||||
@@ -569,9 +586,15 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
return None
|
||||
num_to_upload = min(part_size, file_size - offset)
|
||||
# 调用SizedFileAdapter(fileobj, size)方法会生成一个新的文件对象,重新计算起始追加位置。
|
||||
logger.info(f"【115】开始上传 {target_name} 分片 {part_number}: {offset} -> {offset + num_to_upload}")
|
||||
result = bucket.upload_part(object_name, upload_id, part_number,
|
||||
data=SizedFileAdapter(fileobj, num_to_upload))
|
||||
logger.info(
|
||||
f"【115】开始上传 {target_name} 分片 {part_number}: {offset} -> {offset + num_to_upload}"
|
||||
)
|
||||
result = bucket.upload_part(
|
||||
object_name,
|
||||
upload_id,
|
||||
part_number,
|
||||
data=SizedFileAdapter(fileobj, num_to_upload),
|
||||
)
|
||||
parts.append(PartInfo(part_number, result.etag))
|
||||
logger.info(f"【115】{target_name} 分片 {part_number} 上传完成")
|
||||
offset += num_to_upload
|
||||
@@ -585,15 +608,18 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
|
||||
# 请求头
|
||||
headers = {
|
||||
'X-oss-callback': encode_callback(callback["callback"]),
|
||||
'x-oss-callback-var': encode_callback(callback["callback_var"]),
|
||||
'x-oss-forbid-overwrite': 'false'
|
||||
"X-oss-callback": encode_callback(callback["callback"]),
|
||||
"x-oss-callback-var": encode_callback(callback["callback_var"]),
|
||||
"x-oss-forbid-overwrite": "false",
|
||||
}
|
||||
try:
|
||||
result = bucket.complete_multipart_upload(object_name, upload_id, parts,
|
||||
headers=headers)
|
||||
result = bucket.complete_multipart_upload(
|
||||
object_name, upload_id, parts, headers=headers
|
||||
)
|
||||
if result.status == 200:
|
||||
logger.debug(f"【115】上传 Step 6 回调结果:{result.resp.response.json()}")
|
||||
logger.debug(
|
||||
f"【115】上传 Step 6 回调结果:{result.resp.response.json()}"
|
||||
)
|
||||
logger.info(f"【115】{target_name} 上传成功")
|
||||
else:
|
||||
logger.warn(f"【115】{target_name} 上传失败,错误码: {result.status}")
|
||||
@@ -602,10 +628,12 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
if e.code == "FileAlreadyExists":
|
||||
logger.warn(f"【115】{target_name} 已存在")
|
||||
else:
|
||||
logger.error(f"【115】{target_name} 上传失败: {e.status}, 错误码: {e.code}, 详情: {e.message}")
|
||||
logger.error(
|
||||
f"【115】{target_name} 上传失败: {e.status}, 错误码: {e.code}, 详情: {e.message}"
|
||||
)
|
||||
return None
|
||||
# 返回结果
|
||||
return self._delay_get_item(target_path)
|
||||
return self.get_item(target_path)
|
||||
|
||||
def download(self, fileitem: schemas.FileItem, path: Path = None) -> Optional[Path]:
|
||||
"""
|
||||
@@ -617,12 +645,7 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
return None
|
||||
|
||||
download_info = self._request_api(
|
||||
"POST",
|
||||
"/open/ufile/downurl",
|
||||
"data",
|
||||
data={
|
||||
"pick_code": detail.pickcode
|
||||
}
|
||||
"POST", "/open/ufile/downurl", "data", data={"pick_code": detail.pickcode}
|
||||
)
|
||||
if not download_info:
|
||||
logger.error(f"【115】获取下载链接失败: {fileitem.name}")
|
||||
@@ -643,28 +666,26 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
progress_callback = transfer_process(Path(fileitem.path).as_posix())
|
||||
|
||||
try:
|
||||
with self.session.get(download_url, stream=True) as r:
|
||||
with self.session.stream("GET", download_url) as r:
|
||||
r.raise_for_status()
|
||||
downloaded_size = 0
|
||||
|
||||
with open(local_path, "wb") as f:
|
||||
for chunk in r.iter_content(chunk_size=self.chunk_size):
|
||||
for chunk in r.iter_bytes(chunk_size=self.chunk_size):
|
||||
if global_vars.is_transfer_stopped(fileitem.path):
|
||||
logger.info(f"【115】{fileitem.path} 下载已取消!")
|
||||
r.close()
|
||||
return None
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
downloaded_size += len(chunk)
|
||||
# 更新进度
|
||||
if file_size:
|
||||
progress = (downloaded_size * 100) / file_size
|
||||
progress_callback(progress)
|
||||
f.write(chunk)
|
||||
downloaded_size += len(chunk)
|
||||
if file_size:
|
||||
progress = (downloaded_size * 100) / file_size
|
||||
progress_callback(progress)
|
||||
|
||||
# 完成下载
|
||||
progress_callback(100)
|
||||
logger.info(f"【115】下载完成: {fileitem.name}")
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"【115】下载网络错误: {fileitem.name} - {str(e)}")
|
||||
# 删除可能部分下载的文件
|
||||
if local_path.exists():
|
||||
@@ -688,14 +709,10 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
"""
|
||||
try:
|
||||
self._request_api(
|
||||
"POST",
|
||||
"/open/ufile/delete",
|
||||
data={
|
||||
"file_ids": int(fileitem.fileid)
|
||||
}
|
||||
"POST", "/open/ufile/delete", data={"file_ids": int(fileitem.fileid)}
|
||||
)
|
||||
return True
|
||||
except requests.exceptions.HTTPError:
|
||||
except httpx.HTTPError:
|
||||
return False
|
||||
|
||||
def rename(self, fileitem: schemas.FileItem, name: str) -> bool:
|
||||
@@ -705,10 +722,7 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
resp = self._request_api(
|
||||
"POST",
|
||||
"/open/ufile/update",
|
||||
data={
|
||||
"file_id": int(fileitem.fileid),
|
||||
"file_name": name
|
||||
}
|
||||
data={"file_id": int(fileitem.fileid), "file_name": name},
|
||||
)
|
||||
if not resp:
|
||||
return False
|
||||
@@ -725,10 +739,8 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
"POST",
|
||||
"/open/folder/get_info",
|
||||
"data",
|
||||
data={
|
||||
"path": path.as_posix()
|
||||
},
|
||||
no_error_log=True
|
||||
data={"path": path.as_posix()},
|
||||
no_error_log=True,
|
||||
)
|
||||
if not resp:
|
||||
return None
|
||||
@@ -739,10 +751,12 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
type="file" if resp["file_category"] == "1" else "dir",
|
||||
name=resp["file_name"],
|
||||
basename=Path(resp["file_name"]).stem,
|
||||
extension=Path(resp["file_name"]).suffix[1:] if resp["file_category"] == "1" else None,
|
||||
extension=Path(resp["file_name"]).suffix[1:]
|
||||
if resp["file_category"] == "1"
|
||||
else None,
|
||||
pickcode=resp["pick_code"],
|
||||
size=resp['size_byte'] if resp["file_category"] == "1" else None,
|
||||
modify_time=resp["utime"]
|
||||
size=resp["size_byte"] if resp["file_category"] == "1" else None,
|
||||
modify_time=resp["utime"],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"【115】获取文件信息失败: {str(e)}")
|
||||
@@ -753,7 +767,9 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
获取指定路径的文件夹,如不存在则创建
|
||||
"""
|
||||
|
||||
def __find_dir(_fileitem: schemas.FileItem, _name: str) -> Optional[schemas.FileItem]:
|
||||
def __find_dir(
|
||||
_fileitem: schemas.FileItem, _name: str
|
||||
) -> Optional[schemas.FileItem]:
|
||||
"""
|
||||
查找下级目录中匹配名称的目录
|
||||
"""
|
||||
@@ -808,13 +824,13 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
data={
|
||||
"file_id": int(fileitem.fileid),
|
||||
"pid": int(dest_fileitem.fileid),
|
||||
}
|
||||
},
|
||||
)
|
||||
if not resp:
|
||||
return False
|
||||
if resp["state"]:
|
||||
new_path = Path(path) / fileitem.name
|
||||
new_item = self._delay_get_item(new_path)
|
||||
new_item = self.get_item(new_path)
|
||||
if not new_item:
|
||||
return False
|
||||
if self.rename(new_item, new_name):
|
||||
@@ -840,13 +856,13 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
data={
|
||||
"file_ids": int(fileitem.fileid),
|
||||
"to_cid": int(dest_fileitem.fileid),
|
||||
}
|
||||
},
|
||||
)
|
||||
if not resp:
|
||||
return False
|
||||
if resp["state"]:
|
||||
new_path = Path(path) / fileitem.name
|
||||
new_file = self._delay_get_item(new_path)
|
||||
new_file = self.get_item(new_path)
|
||||
if not new_file:
|
||||
return False
|
||||
if self.rename(new_file, new_name):
|
||||
@@ -864,17 +880,12 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
获取带有企业级配额信息的存储使用情况
|
||||
"""
|
||||
try:
|
||||
resp = self._request_api(
|
||||
"GET",
|
||||
"/open/user/info",
|
||||
"data"
|
||||
)
|
||||
resp = self._request_api("GET", "/open/user/info", "data")
|
||||
if not resp:
|
||||
return None
|
||||
space = resp["rt_space_info"]
|
||||
return schemas.StorageUsage(
|
||||
total=space["all_total"]["size"],
|
||||
available=space["all_remain"]["size"]
|
||||
total=space["all_total"]["size"], available=space["all_remain"]["size"]
|
||||
)
|
||||
except NoCheckInException:
|
||||
return None
|
||||
|
||||
@@ -418,6 +418,9 @@ class TransHandler:
|
||||
return None, f"{fileitem.path} {fileitem.storage} 下载失败"
|
||||
elif fileitem.storage == target_storage:
|
||||
# 同一网盘
|
||||
if not source_oper.is_support_transtype(transfer_type):
|
||||
return None, f"存储 {fileitem.storage} 不支持 {transfer_type} 整理方式"
|
||||
|
||||
if transfer_type == "copy":
|
||||
# 复制文件到新目录
|
||||
target_fileitem = target_oper.get_folder(target_file.parent)
|
||||
@@ -438,6 +441,11 @@ class TransHandler:
|
||||
return None, f"【{target_storage}】{fileitem.path} 移动文件失败"
|
||||
else:
|
||||
return None, f"【{target_storage}】{target_file.parent} 目录获取失败"
|
||||
elif transfer_type == "link":
|
||||
if source_oper.link(fileitem, target_file):
|
||||
return target_oper.get_item(target_file), ""
|
||||
else:
|
||||
return None, f"【{target_storage}】{fileitem.path} 创建硬链接失败"
|
||||
else:
|
||||
return None, f"不支持的整理方式:{transfer_type}"
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ class TYemaSiteUserInfo(SiteParserBase):
|
||||
user_info = detail.get("data", {})
|
||||
self.userid = user_info.get("id")
|
||||
self.username = user_info.get("name")
|
||||
self.user_level = user_info.get("level")
|
||||
self.user_level = str(user_info.get("level")) if user_info.get("level") is not None else None
|
||||
self.join_at = StringUtils.unify_datetime_str(user_info.get("registerTime"))
|
||||
|
||||
self.upload = user_info.get('uploadSize')
|
||||
|
||||
@@ -2,12 +2,12 @@ from typing import Any, Generator, List, Optional, Tuple, Union
|
||||
|
||||
from app import schemas
|
||||
from app.core.context import MediaInfo
|
||||
from app.core.event import eventmanager, Event
|
||||
from app.core.event import eventmanager
|
||||
from app.log import logger
|
||||
from app.modules import _MediaServerBase, _ModuleBase
|
||||
from app.modules.jellyfin.jellyfin import Jellyfin
|
||||
from app.schemas import AuthCredentials, AuthInterceptCredentials
|
||||
from app.schemas.types import MediaType, ModuleType, ChainEventType, MediaServerType, SystemConfigKey, EventType
|
||||
from app.schemas.types import MediaType, ModuleType, ChainEventType, MediaServerType
|
||||
|
||||
|
||||
class JellyfinModule(_ModuleBase, _MediaServerBase[Jellyfin]):
|
||||
@@ -19,20 +19,6 @@ class JellyfinModule(_ModuleBase, _MediaServerBase[Jellyfin]):
|
||||
super().init_service(service_name=Jellyfin.__name__.lower(),
|
||||
service_type=lambda conf: Jellyfin(**conf.config, sync_libraries=conf.sync_libraries))
|
||||
|
||||
@eventmanager.register(EventType.ConfigChanged)
|
||||
def handle_config_changed(self, event: Event):
|
||||
"""
|
||||
处理配置变更事件
|
||||
:param event: 事件对象
|
||||
"""
|
||||
if not event:
|
||||
return
|
||||
event_data: schemas.ConfigChangeEventData = event.event_data
|
||||
if event_data.key not in [SystemConfigKey.MediaServers.value]:
|
||||
return
|
||||
logger.info("配置变更,重新初始化Jellyfin模块...")
|
||||
self.init_module()
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "Jellyfin"
|
||||
|
||||
@@ -732,7 +732,7 @@ class Jellyfin:
|
||||
item_type=item.get("Type"),
|
||||
title=item.get("Name"),
|
||||
original_title=item.get("OriginalTitle"),
|
||||
year=str(item.get("ProductionYear")),
|
||||
year=item.get("ProductionYear"),
|
||||
tmdbid=int(tmdbid) if tmdbid else None,
|
||||
imdbid=item.get("ProviderIds", {}).get("Imdb"),
|
||||
tvdbid=item.get("ProviderIds", {}).get("Tvdb"),
|
||||
|
||||
@@ -2,12 +2,12 @@ from typing import Optional, Tuple, Union, Any, List, Generator
|
||||
|
||||
from app import schemas
|
||||
from app.core.context import MediaInfo
|
||||
from app.core.event import eventmanager, Event
|
||||
from app.core.event import eventmanager
|
||||
from app.log import logger
|
||||
from app.modules import _ModuleBase, _MediaServerBase
|
||||
from app.modules.plex.plex import Plex
|
||||
from app.schemas import AuthCredentials, AuthInterceptCredentials
|
||||
from app.schemas.types import MediaType, ModuleType, ChainEventType, MediaServerType, SystemConfigKey, EventType
|
||||
from app.schemas.types import MediaType, ModuleType, ChainEventType, MediaServerType
|
||||
|
||||
|
||||
class PlexModule(_ModuleBase, _MediaServerBase[Plex]):
|
||||
@@ -19,20 +19,6 @@ class PlexModule(_ModuleBase, _MediaServerBase[Plex]):
|
||||
super().init_service(service_name=Plex.__name__.lower(),
|
||||
service_type=lambda conf: Plex(**conf.config, sync_libraries=conf.sync_libraries))
|
||||
|
||||
@eventmanager.register(EventType.ConfigChanged)
|
||||
def handle_config_changed(self, event: Event):
|
||||
"""
|
||||
处理配置变更事件
|
||||
:param event: 事件对象
|
||||
"""
|
||||
if not event:
|
||||
return
|
||||
event_data: schemas.ConfigChangeEventData = event.event_data
|
||||
if event_data.key not in [SystemConfigKey.MediaServers.value]:
|
||||
return
|
||||
logger.info("配置变更,重新初始化Plex模块...")
|
||||
self.init_module()
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "Plex"
|
||||
|
||||
@@ -509,7 +509,7 @@ class Plex:
|
||||
item_type=item.type,
|
||||
title=item.title,
|
||||
original_title=item.originalTitle,
|
||||
year=str(item.year),
|
||||
year=item.year,
|
||||
tmdbid=ids.get("tmdb_id"),
|
||||
imdbid=ids.get("imdb_id"),
|
||||
tvdbid=ids.get("tvdb_id"),
|
||||
|
||||
@@ -7,13 +7,12 @@ from torrentool.torrent import Torrent
|
||||
from app import schemas
|
||||
from app.core.cache import FileCache
|
||||
from app.core.config import settings
|
||||
from app.core.event import eventmanager, Event
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.log import logger
|
||||
from app.modules import _ModuleBase, _DownloaderBase
|
||||
from app.modules.qbittorrent.qbittorrent import Qbittorrent
|
||||
from app.schemas import TransferTorrent, DownloadingTorrent
|
||||
from app.schemas.types import TorrentStatus, ModuleType, DownloaderType, SystemConfigKey, EventType
|
||||
from app.schemas.types import TorrentStatus, ModuleType, DownloaderType
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
|
||||
@@ -26,20 +25,6 @@ class QbittorrentModule(_ModuleBase, _DownloaderBase[Qbittorrent]):
|
||||
super().init_service(service_name=Qbittorrent.__name__.lower(),
|
||||
service_type=Qbittorrent)
|
||||
|
||||
@eventmanager.register(EventType.ConfigChanged)
|
||||
def handle_config_changed(self, event: Event):
|
||||
"""
|
||||
处理配置变更事件
|
||||
:param event: 事件对象
|
||||
"""
|
||||
if not event:
|
||||
return
|
||||
event_data: schemas.ConfigChangeEventData = event.event_data
|
||||
if event_data.key not in [SystemConfigKey.Downloaders.value]:
|
||||
return
|
||||
logger.info("配置变更,重新加载Qbittorrent模块...")
|
||||
self.init_module()
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "Qbittorrent"
|
||||
@@ -165,7 +150,7 @@ class QbittorrentModule(_ModuleBase, _DownloaderBase[Qbittorrent]):
|
||||
# 添加任务
|
||||
state = server.add_torrent(
|
||||
content=content,
|
||||
download_dir=str(download_dir),
|
||||
download_dir=self.normalize_path(download_dir, downloader),
|
||||
is_paused=is_paused,
|
||||
tag=tags,
|
||||
cookie=cookie,
|
||||
|
||||
@@ -3,12 +3,11 @@ import re
|
||||
from typing import Optional, Union, List, Tuple, Any
|
||||
|
||||
from app.core.context import MediaInfo, Context
|
||||
from app.core.event import eventmanager, Event
|
||||
from app.log import logger
|
||||
from app.modules import _ModuleBase, _MessageBase
|
||||
from app.modules.slack.slack import Slack
|
||||
from app.schemas import MessageChannel, CommingMessage, Notification, ConfigChangeEventData
|
||||
from app.schemas.types import ModuleType, SystemConfigKey, EventType
|
||||
from app.schemas import MessageChannel, CommingMessage, Notification
|
||||
from app.schemas.types import ModuleType
|
||||
|
||||
|
||||
class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
@@ -21,20 +20,6 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
service_type=Slack)
|
||||
self._channel = MessageChannel.Slack
|
||||
|
||||
@eventmanager.register(EventType.ConfigChanged)
|
||||
def handle_config_changed(self, event: Event):
|
||||
"""
|
||||
处理配置变更事件
|
||||
:param event: 事件对象
|
||||
"""
|
||||
if not event:
|
||||
return
|
||||
event_data: ConfigChangeEventData = event.event_data
|
||||
if event_data.key not in [SystemConfigKey.Notifications.value]:
|
||||
return
|
||||
logger.info("配置变更,重新加载Slack模块...")
|
||||
self.init_module()
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "Slack"
|
||||
|
||||
@@ -5,11 +5,13 @@ from typing import Tuple, Union
|
||||
|
||||
from lxml import etree
|
||||
|
||||
from app.chain.storage import StorageChain
|
||||
from app.core.config import settings
|
||||
from app.core.context import Context
|
||||
from app.helper.torrent import TorrentHelper
|
||||
from app.log import logger
|
||||
from app.modules import _ModuleBase
|
||||
from app.schemas.file import FileURI
|
||||
from app.schemas.types import ModuleType, OtherModulesType
|
||||
from app.utils.http import RequestUtils
|
||||
from app.utils.string import StringUtils
|
||||
@@ -87,15 +89,33 @@ class SubtitleModule(_ModuleBase):
|
||||
# 获取种子信息
|
||||
folder_name, _ = TorrentHelper().get_fileinfo_from_torrent_content(torrent_content)
|
||||
# 文件保存目录,如果是单文件种子,则folder_name是空,此时文件保存目录就是下载目录
|
||||
download_dir = download_dir / folder_name
|
||||
storageChain = StorageChain()
|
||||
# 等待目录存在
|
||||
working_dir_item = None
|
||||
# split download_dir into storage and path
|
||||
fileURI = FileURI.from_uri(download_dir.as_posix())
|
||||
storage = fileURI.storage
|
||||
download_dir = Path(fileURI.path)
|
||||
for _ in range(30):
|
||||
if download_dir.exists():
|
||||
found = storageChain.get_file_item(storage, download_dir / folder_name)
|
||||
if found:
|
||||
working_dir_item = found
|
||||
break
|
||||
time.sleep(1)
|
||||
# 目录仍然不存在,且有文件夹名,则创建目录
|
||||
if not download_dir.exists() and folder_name:
|
||||
download_dir.mkdir(parents=True, exist_ok=True)
|
||||
if not working_dir_item and folder_name:
|
||||
parent_dir_item = storageChain.get_file_item(storage, download_dir)
|
||||
if parent_dir_item:
|
||||
working_dir_item = storageChain.create_folder(
|
||||
parent_dir_item,
|
||||
folder_name
|
||||
)
|
||||
else:
|
||||
logger.error(f"下载根目录不存在,无法创建字幕文件夹:{download_dir}")
|
||||
return
|
||||
if not working_dir_item:
|
||||
logger.error(f"下载目录不存在,无法保存字幕:{download_dir / folder_name}")
|
||||
return
|
||||
# 读取网站代码
|
||||
request = RequestUtils(cookies=torrent.site_cookie, ua=torrent.site_ua)
|
||||
res = request.get_res(torrent.page_url)
|
||||
@@ -144,12 +164,12 @@ class SubtitleModule(_ModuleBase):
|
||||
shutil.unpack_archive(zip_file, zip_path, format='zip')
|
||||
# 遍历转移文件
|
||||
for sub_file in SystemUtils.list_files(zip_path, settings.RMT_SUBEXT):
|
||||
target_sub_file = download_dir / sub_file.name
|
||||
if target_sub_file.exists():
|
||||
target_sub_file = Path(working_dir_item.path) / Path(sub_file.name)
|
||||
if storageChain.get_file_item(storage, target_sub_file):
|
||||
logger.info(f"字幕文件已存在:{target_sub_file}")
|
||||
continue
|
||||
logger.info(f"转移字幕 {sub_file} 到 {target_sub_file} ...")
|
||||
SystemUtils.copy(sub_file, target_sub_file)
|
||||
storageChain.upload_file(working_dir_item, sub_file)
|
||||
# 删除临时文件
|
||||
try:
|
||||
shutil.rmtree(zip_path)
|
||||
@@ -160,9 +180,12 @@ class SubtitleModule(_ModuleBase):
|
||||
sub_file = settings.TEMP_PATH / file_name
|
||||
# 保存
|
||||
sub_file.write_bytes(ret.content)
|
||||
target_sub_file = download_dir / sub_file.name
|
||||
logger.info(f"转移字幕 {sub_file} 到 {target_sub_file}")
|
||||
SystemUtils.copy(sub_file, target_sub_file)
|
||||
target_sub_file = Path(working_dir_item.path) / Path(sub_file.name)
|
||||
if storageChain.get_file_item(storage, target_sub_file):
|
||||
logger.info(f"字幕文件已存在:{target_sub_file}")
|
||||
continue
|
||||
logger.info(f"转移字幕 {sub_file} 到 {target_sub_file} ...")
|
||||
storageChain.upload_file(working_dir_item, sub_file)
|
||||
else:
|
||||
logger.error(f"下载字幕文件失败:{sublink}")
|
||||
continue
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
from typing import Optional, Union, List, Tuple, Any
|
||||
|
||||
from app.core.context import MediaInfo, Context
|
||||
from app.core.event import eventmanager, Event
|
||||
from app.log import logger
|
||||
from app.modules import _ModuleBase, _MessageBase
|
||||
from app.modules.synologychat.synologychat import SynologyChat
|
||||
from app.schemas import MessageChannel, CommingMessage, Notification, ConfigChangeEventData
|
||||
from app.schemas.types import ModuleType, SystemConfigKey, EventType
|
||||
from app.schemas import MessageChannel, CommingMessage, Notification
|
||||
from app.schemas.types import ModuleType
|
||||
|
||||
|
||||
class SynologyChatModule(_ModuleBase, _MessageBase[SynologyChat]):
|
||||
@@ -19,20 +18,6 @@ class SynologyChatModule(_ModuleBase, _MessageBase[SynologyChat]):
|
||||
service_type=SynologyChat)
|
||||
self._channel = MessageChannel.SynologyChat
|
||||
|
||||
@eventmanager.register(EventType.ConfigChanged)
|
||||
def handle_config_changed(self, event: Event):
|
||||
"""
|
||||
处理配置变更事件
|
||||
:param event: 事件对象
|
||||
"""
|
||||
if not event:
|
||||
return
|
||||
event_data: ConfigChangeEventData = event.event_data
|
||||
if event_data.key not in [SystemConfigKey.Notifications.value]:
|
||||
return
|
||||
logger.info("配置变更,重新加载SynologyChat模块...")
|
||||
self.init_module()
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "Synology Chat"
|
||||
|
||||
@@ -41,7 +41,7 @@ class SynologyChat:
|
||||
def send_msg(self, title: str, text: Optional[str] = None, image: Optional[str] = None,
|
||||
userid: Optional[str] = None, link: Optional[str] = None) -> Optional[bool]:
|
||||
"""
|
||||
发送Telegram消息
|
||||
发送SynologyChat消息
|
||||
:param title: 消息标题
|
||||
:param text: 消息内容
|
||||
:param image: 消息图片地址
|
||||
|
||||
@@ -1,18 +1,16 @@
|
||||
import copy
|
||||
import json
|
||||
import re
|
||||
from typing import Dict
|
||||
from typing import Optional, Union, List, Tuple, Any
|
||||
from typing import Dict, Optional, Union, List, Tuple, Any
|
||||
|
||||
from app.core.context import MediaInfo, Context
|
||||
from app.core.event import Event
|
||||
from app.core.event import eventmanager
|
||||
from app.log import logger
|
||||
from app.modules import _ModuleBase, _MessageBase
|
||||
from app.modules.telegram.telegram import Telegram
|
||||
from app.schemas import MessageChannel, CommingMessage, Notification, CommandRegisterEventData, ConfigChangeEventData, \
|
||||
from app.schemas import MessageChannel, CommingMessage, Notification, CommandRegisterEventData, \
|
||||
NotificationConf
|
||||
from app.schemas.types import ModuleType, ChainEventType, SystemConfigKey, EventType
|
||||
from app.schemas.types import ModuleType, ChainEventType
|
||||
from app.utils.structures import DictUtils
|
||||
|
||||
|
||||
@@ -26,20 +24,6 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
service_type=Telegram)
|
||||
self._channel = MessageChannel.Telegram
|
||||
|
||||
@eventmanager.register(EventType.ConfigChanged)
|
||||
def handle_config_changed(self, event: Event):
|
||||
"""
|
||||
处理配置变更事件
|
||||
:param event: 事件对象
|
||||
"""
|
||||
if not event:
|
||||
return
|
||||
event_data: ConfigChangeEventData = event.event_data
|
||||
if event_data.key not in [SystemConfigKey.Notifications.value]:
|
||||
return
|
||||
logger.info("配置变更,重新加载Telegram模块...")
|
||||
self.init_module()
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "Telegram"
|
||||
@@ -283,8 +267,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
image=message.image, userid=userid, link=message.link,
|
||||
buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id,
|
||||
escape_markdown=kwargs.get("escape_markdown"))
|
||||
original_chat_id=message.original_chat_id)
|
||||
|
||||
def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None:
|
||||
"""
|
||||
|
||||
@@ -1,19 +1,19 @@
|
||||
import asyncio
|
||||
import re
|
||||
import threading
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from threading import Event
|
||||
from typing import Optional, List, Dict, Callable
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import telebot
|
||||
from telebot import apihelper
|
||||
from telebot.types import InputFile, InlineKeyboardMarkup, InlineKeyboardButton
|
||||
from telebot.types import InputMediaPhoto
|
||||
from telebot import TeleBot, apihelper
|
||||
from telebot.types import BotCommand, InlineKeyboardMarkup, InlineKeyboardButton, InputMediaPhoto
|
||||
from telegramify_markdown import standardize, telegramify
|
||||
from telegramify_markdown.type import ContentTypes, SentType
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.context import MediaInfo, Context
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.helper.thread import ThreadHelper
|
||||
from app.helper.image import ImageHelper
|
||||
from app.log import logger
|
||||
from app.utils.common import retry
|
||||
from app.utils.http import RequestUtils
|
||||
@@ -26,13 +26,11 @@ class RetryException(Exception):
|
||||
|
||||
class Telegram:
|
||||
_ds_url = f"http://127.0.0.1:{settings.PORT}/api/v1/message?token={settings.API_TOKEN}"
|
||||
_event = Event()
|
||||
_bot: telebot.TeleBot = None
|
||||
_bot: TeleBot = None
|
||||
_callback_handlers: Dict[str, Callable] = {} # 存储回调处理器
|
||||
_user_chat_mapping: Dict[str, str] = {} # userid -> chat_id mapping for reply targeting
|
||||
_bot_username: Optional[str] = None # Bot username for mention detection
|
||||
_escape_chars = r'_*[]()~`>#+-=|{}.!' # Telegram MarkdownV2
|
||||
_markdown_escape_pattern = re.compile(f'([{re.escape(_escape_chars)}])') # Telegram MarkdownV2 规则转义特殊字符正则pattern
|
||||
|
||||
def __init__(self, TELEGRAM_TOKEN: Optional[str] = None, TELEGRAM_CHAT_ID: Optional[str] = None, **kwargs):
|
||||
"""
|
||||
初始化参数
|
||||
@@ -53,7 +51,7 @@ class Telegram:
|
||||
else:
|
||||
apihelper.proxy = settings.PROXY
|
||||
# bot
|
||||
_bot = telebot.TeleBot(self._telegram_token, parse_mode="MarkdownV2")
|
||||
_bot = TeleBot(self._telegram_token, parse_mode="MarkdownV2")
|
||||
# 记录句柄
|
||||
self._bot = _bot
|
||||
# 获取并存储bot用户名用于@检测
|
||||
@@ -216,8 +214,7 @@ class Telegram:
|
||||
userid: Optional[str] = None, link: Optional[str] = None,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
original_message_id: Optional[int] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
escape_markdown: bool = True) -> Optional[bool]:
|
||||
original_chat_id: Optional[str] = None) -> Optional[bool]:
|
||||
"""
|
||||
发送Telegram消息
|
||||
:param title: 消息标题
|
||||
@@ -228,7 +225,6 @@ class Telegram:
|
||||
:param buttons: 按钮列表,格式:[[{"text": "按钮文本", "callback_data": "回调数据"}]]
|
||||
:param original_message_id: 原消息ID,如果提供则编辑原消息
|
||||
:param original_chat_id: 原消息的聊天ID,编辑消息时需要
|
||||
:param escape_markdown: 是否对内容进行Markdown转义
|
||||
|
||||
"""
|
||||
if not self._telegram_token or not self._telegram_chat_id:
|
||||
@@ -239,22 +235,14 @@ class Telegram:
|
||||
return False
|
||||
|
||||
try:
|
||||
if title:
|
||||
# 标题总是转义(因为通常标题不包含Markdown格式)
|
||||
title = self.escape_markdown(title)
|
||||
if text:
|
||||
if escape_markdown:
|
||||
# 完全转义模式:转义所有特殊字符
|
||||
text = self.escape_markdown(text)
|
||||
else:
|
||||
# 智能转义模式:保留Markdown格式,只转义普通文本中的特殊字符
|
||||
text = self.escape_markdown_smart(text)
|
||||
if title:
|
||||
caption = f"*{title}*\n{text}"
|
||||
else:
|
||||
caption = text
|
||||
if title and text:
|
||||
caption = f"**{title}**\n{text}"
|
||||
elif title:
|
||||
caption = f"**{title}**"
|
||||
elif text:
|
||||
caption = text
|
||||
else:
|
||||
caption = f"*{title}*"
|
||||
caption = ""
|
||||
|
||||
if link:
|
||||
caption = f"{caption}\n[查看详情]({link})"
|
||||
@@ -512,7 +500,7 @@ class Telegram:
|
||||
|
||||
if image:
|
||||
# 如果有图片,使用edit_message_media
|
||||
media = InputMediaPhoto(media=image, caption=text, parse_mode="MarkdownV2")
|
||||
media = InputMediaPhoto(media=image, caption=standardize(text), parse_mode="MarkdownV2")
|
||||
self._bot.edit_message_media(
|
||||
chat_id=chat_id,
|
||||
message_id=message_id,
|
||||
@@ -524,7 +512,7 @@ class Telegram:
|
||||
self._bot.edit_message_text(
|
||||
chat_id=chat_id,
|
||||
message_id=message_id,
|
||||
text=text,
|
||||
text=standardize(text),
|
||||
parse_mode="MarkdownV2",
|
||||
reply_markup=reply_markup
|
||||
)
|
||||
@@ -533,49 +521,114 @@ class Telegram:
|
||||
logger.error(f"编辑消息失败:{str(e)}")
|
||||
return False
|
||||
|
||||
@retry(RetryException, logger=logger)
|
||||
def __send_request(self, userid: Optional[str] = None, image="", caption="",
|
||||
reply_markup: Optional[InlineKeyboardMarkup] = None) -> bool:
|
||||
"""
|
||||
向Telegram发送报文
|
||||
:param reply_markup: 内联键盘
|
||||
"""
|
||||
if image:
|
||||
res = RequestUtils(proxies=settings.PROXY, ua=settings.NORMAL_USER_AGENT).get_res(image)
|
||||
if res is None:
|
||||
raise Exception("获取图片失败")
|
||||
if res.content:
|
||||
# 使用随机标识构建图片文件的完整路径,并写入图片内容到文件
|
||||
image_file = Path(settings.TEMP_PATH) / "telegram" / str(uuid.uuid4())
|
||||
if not image_file.parent.exists():
|
||||
image_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
image_file.write_bytes(res.content)
|
||||
photo = InputFile(image_file)
|
||||
# 发送图片到Telegram
|
||||
ret = self._bot.send_photo(chat_id=userid or self._telegram_chat_id,
|
||||
photo=photo,
|
||||
caption=caption,
|
||||
parse_mode="MarkdownV2",
|
||||
reply_markup=reply_markup)
|
||||
if ret is None:
|
||||
raise RetryException("发送图片消息失败")
|
||||
return True
|
||||
# 按4096分段循环发送消息
|
||||
ret = None
|
||||
if len(caption) > 4095:
|
||||
for i in range(0, len(caption), 4095):
|
||||
ret = self._bot.send_message(chat_id=userid or self._telegram_chat_id,
|
||||
text=caption[i:i + 4095],
|
||||
parse_mode="MarkdownV2",
|
||||
reply_markup=reply_markup if i == 0 else None)
|
||||
else:
|
||||
ret = self._bot.send_message(chat_id=userid or self._telegram_chat_id,
|
||||
text=caption,
|
||||
parse_mode="MarkdownV2",
|
||||
reply_markup=reply_markup)
|
||||
if ret is None:
|
||||
raise RetryException("发送文本消息失败")
|
||||
return True if ret else False
|
||||
kwargs = {
|
||||
'chat_id': userid or self._telegram_chat_id,
|
||||
'parse_mode': "MarkdownV2",
|
||||
'reply_markup': reply_markup
|
||||
}
|
||||
|
||||
# 处理图片
|
||||
image = self.__process_image(image)
|
||||
|
||||
try:
|
||||
# 图片消息的标题长度限制为1024,文本消息为4096
|
||||
caption_limit = 1024 if image else 4096
|
||||
if len(caption) < caption_limit:
|
||||
ret = self.__send_short_message(image, caption, **kwargs)
|
||||
else:
|
||||
sent_idx = set()
|
||||
ret = self.__send_long_message(image, caption, sent_idx, **kwargs)
|
||||
|
||||
return ret is not None
|
||||
except Exception as e:
|
||||
logger.error(f"发送Telegram消息失败: {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def __process_image(image_url: Optional[str]) -> Optional[bytes]:
|
||||
"""
|
||||
处理图片URL,获取图片内容
|
||||
"""
|
||||
if not image_url:
|
||||
return None
|
||||
image = ImageHelper().fetch_image(image_url)
|
||||
if not image:
|
||||
logger.warn(f"图片获取失败: {image_url},仅发送文本消息")
|
||||
return image
|
||||
|
||||
@retry(RetryException, logger=logger)
|
||||
def __send_short_message(self, image: Optional[bytes], caption: str, **kwargs):
|
||||
"""
|
||||
发送短消息
|
||||
"""
|
||||
try:
|
||||
if image:
|
||||
return self._bot.send_photo(
|
||||
photo=image,
|
||||
caption=standardize(caption),
|
||||
**kwargs
|
||||
)
|
||||
else:
|
||||
return self._bot.send_message(
|
||||
text=standardize(caption),
|
||||
**kwargs
|
||||
)
|
||||
except Exception:
|
||||
raise RetryException(f"发送{'图片' if image else '文本'}消息失败")
|
||||
|
||||
@retry(RetryException, logger=logger)
|
||||
def __send_long_message(self, image: Optional[bytes], caption: str, sent_idx: set, **kwargs):
|
||||
"""
|
||||
发送长消息
|
||||
"""
|
||||
try:
|
||||
reply_markup = kwargs.pop("reply_markup", None)
|
||||
|
||||
boxs: SentType = ThreadHelper().submit(lambda x: asyncio.run(telegramify(x)), caption).result()
|
||||
|
||||
ret = None
|
||||
for i, item in enumerate(boxs):
|
||||
if i in sent_idx:
|
||||
# 跳过已发送消息
|
||||
continue
|
||||
|
||||
current_reply_markup = reply_markup if i == 0 else None
|
||||
|
||||
if item.content_type == ContentTypes.TEXT and (i != 0 or not image):
|
||||
ret = self._bot.send_message(**kwargs,
|
||||
text=item.content,
|
||||
reply_markup=current_reply_markup
|
||||
)
|
||||
|
||||
elif item.content_type == ContentTypes.PHOTO or (image and i == 0):
|
||||
ret = self._bot.send_photo(**kwargs,
|
||||
photo=(getattr(item, "file_name", ""),
|
||||
getattr(item, "file_data", image)),
|
||||
caption=getattr(item, "caption", item.content),
|
||||
reply_markup=current_reply_markup
|
||||
)
|
||||
|
||||
elif item.content_type == ContentTypes.FILE:
|
||||
ret = self._bot.send_document(**kwargs,
|
||||
document=(item.file_name, item.file_data),
|
||||
caption=item.caption,
|
||||
reply_markup=current_reply_markup
|
||||
)
|
||||
|
||||
sent_idx.add(i)
|
||||
|
||||
return ret
|
||||
except Exception as e:
|
||||
try:
|
||||
raise RetryException(f"消息 [{i + 1}/{len(boxs)}] 发送失败") from e
|
||||
except NameError:
|
||||
raise
|
||||
|
||||
def register_commands(self, commands: Dict[str, dict]):
|
||||
"""
|
||||
@@ -588,7 +641,7 @@ class Telegram:
|
||||
self._bot.delete_my_commands()
|
||||
self._bot.set_my_commands(
|
||||
commands=[
|
||||
telebot.types.BotCommand(cmd[1:], str(desc.get("description"))) for cmd, desc in
|
||||
BotCommand(cmd[1:], str(desc.get("description"))) for cmd, desc in
|
||||
commands.items()
|
||||
]
|
||||
)
|
||||
@@ -610,84 +663,3 @@ class Telegram:
|
||||
self._bot.stop_polling()
|
||||
self._polling_thread.join()
|
||||
logger.info("Telegram消息接收服务已停止")
|
||||
|
||||
def escape_markdown(self, text: str) -> str:
|
||||
# 按 Telegram MarkdownV2 规则转义特殊字符
|
||||
if not isinstance(text, str):
|
||||
return str(text) if text is not None else ""
|
||||
return self._markdown_escape_pattern.sub(r'\\\1', text)
|
||||
|
||||
def escape_markdown_smart(self, text: str) -> str:
|
||||
"""
|
||||
智能转义Markdown文本:只转义不在Markdown标记内的特殊字符
|
||||
这样可以保留已有的Markdown格式(如*粗体*、_斜体_、[链接](url)等),
|
||||
同时转义普通文本中的特殊字符以避免API错误
|
||||
|
||||
注意:Telegram MarkdownV2不支持以下语法,这些字符会被转义:
|
||||
- 标题语法(#、##、###)会被转义为 \#、\##、\###
|
||||
- 列表语法(-、*、+)会被转义为 \-、\*、\+
|
||||
- 引用语法(>)会被转义为 \>
|
||||
|
||||
建议使用加粗文本模拟标题:*标题文本*
|
||||
|
||||
:param text: 要转义的文本
|
||||
:return: 转义后的文本
|
||||
"""
|
||||
if not isinstance(text, str):
|
||||
return str(text) if text is not None else ""
|
||||
|
||||
# 如果没有特殊字符,直接返回
|
||||
if not any(char in self._escape_chars for char in text):
|
||||
return text
|
||||
|
||||
# 标记受保护的区域(Markdown标记内的内容不转义)
|
||||
protected = [False] * len(text)
|
||||
|
||||
# 按优先级匹配Markdown标记(从最复杂到最简单)
|
||||
# 1. 链接:[text](url) - 必须最先匹配
|
||||
link_pattern = r'\[([^\]]*)\]\(([^)]*)\)'
|
||||
for match in re.finditer(link_pattern, text):
|
||||
for i in range(match.start(), match.end()):
|
||||
protected[i] = True
|
||||
|
||||
# 2. 粗体:*text*(单个*,不是**)
|
||||
bold_pattern = r'(?<!\*)\*(?!\*)([^*]+?)(?<!\*)\*(?!\*)'
|
||||
for match in re.finditer(bold_pattern, text):
|
||||
if not any(protected[match.start():match.end()]):
|
||||
for i in range(match.start(), match.end()):
|
||||
protected[i] = True
|
||||
|
||||
# 3. 斜体:_text_(单个_,不是__)
|
||||
italic_pattern = r'(?<!_)_(?!_)([^_]+?)(?<!_)_(?!_)'
|
||||
for match in re.finditer(italic_pattern, text):
|
||||
if not any(protected[match.start():match.end()]):
|
||||
for i in range(match.start(), match.end()):
|
||||
protected[i] = True
|
||||
|
||||
# 4. 代码:`text`
|
||||
code_pattern = r'`([^`]+)`'
|
||||
for match in re.finditer(code_pattern, text):
|
||||
if not any(protected[match.start():match.end()]):
|
||||
for i in range(match.start(), match.end()):
|
||||
protected[i] = True
|
||||
|
||||
# 5. 删除线:~text~
|
||||
strikethrough_pattern = r'~([^~]+)~'
|
||||
for match in re.finditer(strikethrough_pattern, text):
|
||||
if not any(protected[match.start():match.end()]):
|
||||
for i in range(match.start(), match.end()):
|
||||
protected[i] = True
|
||||
|
||||
# 构建结果:只转义未保护区域的特殊字符
|
||||
result = []
|
||||
for i, char in enumerate(text):
|
||||
if protected[i]:
|
||||
# 受保护区域(Markdown标记内),不转义
|
||||
result.append(char)
|
||||
elif char in self._escape_chars:
|
||||
# 未保护区域,转义特殊字符
|
||||
result.append('\\' + char)
|
||||
else:
|
||||
result.append(char)
|
||||
|
||||
return ''.join(result)
|
||||
@@ -14,7 +14,6 @@ from app.modules.themoviedb.category import CategoryHelper
|
||||
from app.modules.themoviedb.scraper import TmdbScraper
|
||||
from app.modules.themoviedb.tmdb_cache import TmdbCache
|
||||
from app.modules.themoviedb.tmdbapi import TmdbApi
|
||||
from app.schemas import MediaPerson
|
||||
from app.schemas.types import MediaType, MediaImageType, ModuleType, MediaRecognizeType
|
||||
from app.utils.http import RequestUtils
|
||||
|
||||
@@ -23,6 +22,7 @@ class TheMovieDbModule(_ModuleBase):
|
||||
"""
|
||||
TMDB媒体信息匹配
|
||||
"""
|
||||
CONFIG_WATCH = {"PROXY_HOST", "TMDB_API_DOMAIN", "TMDB_API_KEY", "TMDB_LOCALE"}
|
||||
|
||||
# 元数据缓存
|
||||
cache: TmdbCache = None
|
||||
@@ -39,6 +39,12 @@ class TheMovieDbModule(_ModuleBase):
|
||||
self.category = CategoryHelper()
|
||||
self.scraper = TmdbScraper()
|
||||
|
||||
def on_config_changed(self):
|
||||
# 停止模块
|
||||
self.stop()
|
||||
# 初始化模块
|
||||
self.init_module()
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "TheMovieDb"
|
||||
@@ -635,7 +641,7 @@ class TheMovieDbModule(_ModuleBase):
|
||||
return medias
|
||||
return []
|
||||
|
||||
def search_persons(self, name: str) -> Optional[List[MediaPerson]]:
|
||||
def search_persons(self, name: str) -> Optional[List[schemas.MediaPerson]]:
|
||||
"""
|
||||
搜索人物信息
|
||||
"""
|
||||
@@ -645,10 +651,10 @@ class TheMovieDbModule(_ModuleBase):
|
||||
return []
|
||||
results = self.tmdb.search_persons(name)
|
||||
if results:
|
||||
return [MediaPerson(source='themoviedb', **person) for person in results]
|
||||
return [schemas.MediaPerson(source='themoviedb', **person) for person in results]
|
||||
return []
|
||||
|
||||
async def async_search_persons(self, name: str) -> Optional[List[MediaPerson]]:
|
||||
async def async_search_persons(self, name: str) -> Optional[List[schemas.MediaPerson]]:
|
||||
"""
|
||||
异步搜索人物信息
|
||||
"""
|
||||
@@ -658,7 +664,7 @@ class TheMovieDbModule(_ModuleBase):
|
||||
return []
|
||||
results = await self.tmdb.async_search_persons(name)
|
||||
if results:
|
||||
return [MediaPerson(source='themoviedb', **person) for person in results]
|
||||
return [schemas.MediaPerson(source='themoviedb', **person) for person in results]
|
||||
return []
|
||||
|
||||
def search_collections(self, name: str) -> Optional[List[MediaInfo]]:
|
||||
|
||||
@@ -643,17 +643,23 @@ class TmdbApi:
|
||||
reverse=True
|
||||
)
|
||||
for tv in tvs:
|
||||
# 年份
|
||||
# 使用年份、名称匹配
|
||||
tv_year = tv.get('first_air_date')[0:4] if tv.get('first_air_date') else None
|
||||
if (self.__compare_names(name, tv.get('name'))
|
||||
or self.__compare_names(name, tv.get('original_name'))) \
|
||||
and (tv_year == str(season_year)):
|
||||
return tv
|
||||
# 匹配别名、译名
|
||||
# 获取别名、译名重新匹配
|
||||
if not tv.get("names"):
|
||||
tv = self.get_info(mtype=MediaType.TV, tmdbid=tv.get("id"))
|
||||
if not tv or not self.__compare_names(name, tv.get("names")):
|
||||
if not tv or not (
|
||||
self.__compare_names(name, tv.get("name"))
|
||||
or self.__compare_names(name, tv.get("original_name"))
|
||||
or self.__compare_names(name, tv.get("names"))):
|
||||
continue
|
||||
if tv_year == str(season_year):
|
||||
return tv
|
||||
# 季年份匹配
|
||||
if __season_match(tv_info=tv, _season_year=season_year):
|
||||
return tv
|
||||
return {}
|
||||
@@ -744,11 +750,11 @@ class TmdbApi:
|
||||
if validation_result is not None:
|
||||
return validation_result
|
||||
|
||||
logger.info("正在从TheDbMovie网站查询:%s ..." % name)
|
||||
logger.info("正在从TheMovieDb网站查询:%s ..." % name)
|
||||
tmdb_url = self._build_tmdb_search_url(name)
|
||||
res = RequestUtils(timeout=5, ua=settings.NORMAL_USER_AGENT, proxies=settings.PROXY).get_res(url=tmdb_url)
|
||||
if res is None:
|
||||
logger.error("无法连接TheDbMovie")
|
||||
logger.error("无法连接TheMovieDb")
|
||||
return None
|
||||
|
||||
# 响应验证
|
||||
|
||||
@@ -7,13 +7,12 @@ from transmission_rpc import File
|
||||
from app import schemas
|
||||
from app.core.cache import FileCache
|
||||
from app.core.config import settings
|
||||
from app.core.event import eventmanager, Event
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.log import logger
|
||||
from app.modules import _ModuleBase, _DownloaderBase
|
||||
from app.modules.transmission.transmission import Transmission
|
||||
from app.schemas import TransferTorrent, DownloadingTorrent
|
||||
from app.schemas.types import TorrentStatus, ModuleType, DownloaderType, SystemConfigKey, EventType
|
||||
from app.schemas.types import TorrentStatus, ModuleType, DownloaderType
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
|
||||
@@ -26,20 +25,6 @@ class TransmissionModule(_ModuleBase, _DownloaderBase[Transmission]):
|
||||
super().init_service(service_name=Transmission.__name__.lower(),
|
||||
service_type=Transmission)
|
||||
|
||||
@eventmanager.register(EventType.ConfigChanged)
|
||||
def handle_config_changed(self, event: Event):
|
||||
"""
|
||||
处理配置变更事件
|
||||
:param event: 事件对象
|
||||
"""
|
||||
if not event:
|
||||
return
|
||||
event_data: schemas.ConfigChangeEventData = event.event_data
|
||||
if event_data.key not in [SystemConfigKey.Downloaders.value]:
|
||||
return
|
||||
logger.info("配置变更,重新加载Transmission模块...")
|
||||
self.init_module()
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "Transmission"
|
||||
@@ -166,7 +151,7 @@ class TransmissionModule(_ModuleBase, _DownloaderBase[Transmission]):
|
||||
# 添加任务
|
||||
torrent = server.add_torrent(
|
||||
content=content,
|
||||
download_dir=str(download_dir),
|
||||
download_dir=self.normalize_path(download_dir, downloader),
|
||||
is_paused=is_paused,
|
||||
labels=labels,
|
||||
cookie=cookie
|
||||
|
||||
@@ -2,12 +2,12 @@ from typing import Any, Generator, List, Optional, Tuple, Union
|
||||
|
||||
from app import schemas
|
||||
from app.core.context import MediaInfo
|
||||
from app.core.event import eventmanager, Event
|
||||
from app.core.event import eventmanager
|
||||
from app.log import logger
|
||||
from app.modules import _MediaServerBase, _ModuleBase
|
||||
from app.modules.trimemedia.trimemedia import TrimeMedia
|
||||
from app.schemas import AuthCredentials, AuthInterceptCredentials
|
||||
from app.schemas.types import ChainEventType, MediaServerType, MediaType, ModuleType, SystemConfigKey, EventType
|
||||
from app.schemas.types import ChainEventType, MediaServerType, MediaType, ModuleType
|
||||
|
||||
|
||||
class TrimeMediaModule(_ModuleBase, _MediaServerBase[TrimeMedia]):
|
||||
@@ -23,20 +23,6 @@ class TrimeMediaModule(_ModuleBase, _MediaServerBase[TrimeMedia]):
|
||||
),
|
||||
)
|
||||
|
||||
@eventmanager.register(EventType.ConfigChanged)
|
||||
def handle_config_changed(self, event: Event):
|
||||
"""
|
||||
处理配置变更事件
|
||||
:param event: 事件对象
|
||||
"""
|
||||
if not event:
|
||||
return
|
||||
event_data: schemas.ConfigChangeEventData = event.event_data
|
||||
if event_data.key not in [SystemConfigKey.MediaServers.value]:
|
||||
return
|
||||
logger.info("配置变更,重新加载飞牛影视模块...")
|
||||
self.init_module()
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "飞牛影视"
|
||||
|
||||
@@ -449,7 +449,7 @@ class TrimeMedia:
|
||||
item_type=item_type,
|
||||
title=item.title,
|
||||
original_title=item.original_title,
|
||||
year=str(year),
|
||||
year=year,
|
||||
tmdbid=item.tmdb_id,
|
||||
imdbid=item.imdb_id,
|
||||
user_state=user_state,
|
||||
|
||||
@@ -2,12 +2,11 @@ import json
|
||||
from typing import Optional, Union, List, Tuple, Any, Dict
|
||||
|
||||
from app.core.context import Context, MediaInfo
|
||||
from app.core.event import eventmanager, Event
|
||||
from app.log import logger
|
||||
from app.modules import _ModuleBase, _MessageBase
|
||||
from app.modules.vocechat.vocechat import VoceChat
|
||||
from app.schemas import MessageChannel, CommingMessage, Notification, ConfigChangeEventData
|
||||
from app.schemas.types import ModuleType, SystemConfigKey, EventType
|
||||
from app.schemas import MessageChannel, CommingMessage, Notification
|
||||
from app.schemas.types import ModuleType
|
||||
|
||||
|
||||
class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]):
|
||||
@@ -20,20 +19,6 @@ class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]):
|
||||
service_type=VoceChat)
|
||||
self._channel = MessageChannel.VoceChat
|
||||
|
||||
@eventmanager.register(EventType.ConfigChanged)
|
||||
def handle_config_changed(self, event: Event):
|
||||
"""
|
||||
处理配置变更事件
|
||||
:param event: 事件对象
|
||||
"""
|
||||
if not event:
|
||||
return
|
||||
event_data: ConfigChangeEventData = event.event_data
|
||||
if event_data.key not in [SystemConfigKey.Notifications.value]:
|
||||
return
|
||||
logger.info("配置变更,重新加载VoceChat模块...")
|
||||
self.init_module()
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "VoceChat"
|
||||
|
||||
@@ -4,11 +4,10 @@ from typing import Union, Tuple
|
||||
from pywebpush import webpush, WebPushException
|
||||
|
||||
from app.core.config import global_vars, settings
|
||||
from app.core.event import eventmanager, Event
|
||||
from app.log import logger
|
||||
from app.modules import _ModuleBase, _MessageBase
|
||||
from app.schemas import Notification, ConfigChangeEventData
|
||||
from app.schemas.types import ModuleType, MessageChannel, SystemConfigKey, EventType
|
||||
from app.schemas import Notification
|
||||
from app.schemas.types import ModuleType, MessageChannel
|
||||
|
||||
|
||||
class WebPushModule(_ModuleBase, _MessageBase):
|
||||
@@ -20,20 +19,6 @@ class WebPushModule(_ModuleBase, _MessageBase):
|
||||
super().init_service(service_name=self.get_name().lower())
|
||||
self._channel = MessageChannel.WebPush
|
||||
|
||||
@eventmanager.register(EventType.ConfigChanged)
|
||||
def handle_config_changed(self, event: Event):
|
||||
"""
|
||||
处理配置变更事件
|
||||
:param event: 事件对象
|
||||
"""
|
||||
if not event:
|
||||
return
|
||||
event_data: ConfigChangeEventData = event.event_data
|
||||
if event_data.key not in [SystemConfigKey.Notifications.value]:
|
||||
return
|
||||
logger.info("配置变更,重新加载WebPush模块...")
|
||||
self.init_module()
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "WebPush"
|
||||
|
||||
@@ -3,13 +3,13 @@ import xml.dom.minidom
|
||||
from typing import Optional, Union, List, Tuple, Any, Dict
|
||||
|
||||
from app.core.context import Context, MediaInfo
|
||||
from app.core.event import Event, eventmanager
|
||||
from app.core.event import eventmanager
|
||||
from app.log import logger
|
||||
from app.modules import _ModuleBase, _MessageBase
|
||||
from app.modules.wechat.WXBizMsgCrypt3 import WXBizMsgCrypt
|
||||
from app.modules.wechat.wechat import WeChat
|
||||
from app.schemas import MessageChannel, CommingMessage, Notification, CommandRegisterEventData, ConfigChangeEventData
|
||||
from app.schemas.types import ModuleType, ChainEventType, SystemConfigKey, EventType
|
||||
from app.schemas import MessageChannel, CommingMessage, Notification, CommandRegisterEventData
|
||||
from app.schemas.types import ModuleType, ChainEventType
|
||||
from app.utils.dom import DomUtils
|
||||
from app.utils.structures import DictUtils
|
||||
|
||||
@@ -24,20 +24,6 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]):
|
||||
service_type=WeChat)
|
||||
self._channel = MessageChannel.Wechat
|
||||
|
||||
@eventmanager.register(EventType.ConfigChanged)
|
||||
def handle_config_changed(self, event: Event):
|
||||
"""
|
||||
处理配置变更事件
|
||||
:param event: 事件对象
|
||||
"""
|
||||
if not event:
|
||||
return
|
||||
event_data: ConfigChangeEventData = event.event_data
|
||||
if event_data.key not in [SystemConfigKey.Notifications.value]:
|
||||
return
|
||||
logger.info("配置变更,重新加载Wechat模块...")
|
||||
self.init_module()
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "微信"
|
||||
|
||||
@@ -169,8 +169,8 @@ class WeChat:
|
||||
:param link: 跳转链接
|
||||
:return: 发送状态,错误信息
|
||||
"""
|
||||
if not title:
|
||||
logger.error("消息标题不能为空")
|
||||
if not title and not text:
|
||||
logger.error("消息标题和内容不能都为空")
|
||||
return False
|
||||
if text:
|
||||
formatted_text = text.replace("\n\n", "\n")
|
||||
|
||||
@@ -17,13 +17,12 @@ from app.chain.storage import StorageChain
|
||||
from app.chain.transfer import TransferChain
|
||||
from app.core.cache import TTLCache, FileCache
|
||||
from app.core.config import settings
|
||||
from app.core.event import Event, eventmanager
|
||||
from app.helper.directory import DirectoryHelper
|
||||
from app.helper.message import MessageHelper
|
||||
from app.log import logger
|
||||
from app.schemas import ConfigChangeEventData
|
||||
from app.schemas import FileItem
|
||||
from app.schemas.types import SystemConfigKey, EventType
|
||||
from app.schemas.types import SystemConfigKey
|
||||
from app.utils.mixins import ConfigReloadMixin
|
||||
from app.utils.singleton import SingletonClass
|
||||
from app.utils.system import SystemUtils
|
||||
|
||||
@@ -60,10 +59,11 @@ class FileMonitorHandler(FileSystemEventHandler):
|
||||
logger.error(f"on_moved 异常: {e}")
|
||||
|
||||
|
||||
class Monitor(metaclass=SingletonClass):
|
||||
class Monitor(ConfigReloadMixin, metaclass=SingletonClass):
|
||||
"""
|
||||
目录监控处理链,单例模式
|
||||
"""
|
||||
CONFIG_WATCH = {SystemConfigKey.Directories.value}
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -84,20 +84,12 @@ class Monitor(metaclass=SingletonClass):
|
||||
# 启动目录监控和文件整理
|
||||
self.init()
|
||||
|
||||
@eventmanager.register(EventType.ConfigChanged)
|
||||
def handle_config_changed(self, event: Event):
|
||||
"""
|
||||
处理配置变更事件
|
||||
:param event: 事件对象
|
||||
"""
|
||||
if not event:
|
||||
return
|
||||
event_data: ConfigChangeEventData = event.event_data
|
||||
if event_data.key not in [SystemConfigKey.Directories.value]:
|
||||
return
|
||||
logger.info("配置变更事件触发,重新初始化目录监控...")
|
||||
def on_config_changed(self):
|
||||
self.init()
|
||||
|
||||
def get_reload_name(self):
|
||||
return "目录监控"
|
||||
|
||||
def save_snapshot(self, storage: str, snapshot: Dict, file_count: int = 0,
|
||||
last_snapshot_time: Optional[float] = None):
|
||||
"""
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user