mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-14 07:26:50 +00:00
Compare commits
18 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a5c44a5097 | ||
|
|
16ada1a6c4 | ||
|
|
ac09ce5230 | ||
|
|
2255b61195 | ||
|
|
314ac3903c | ||
|
|
5c3796bf73 | ||
|
|
492e3c333b | ||
|
|
cce72d0884 | ||
|
|
69a064e986 | ||
|
|
f4ca4120bc | ||
|
|
b45956f850 | ||
|
|
762a7fbba7 | ||
|
|
10290ca17b | ||
|
|
12a2561ca8 | ||
|
|
543bee9ad5 | ||
|
|
cc3e062262 | ||
|
|
bf4f5f8744 | ||
|
|
f8f06a602a |
@@ -632,6 +632,8 @@ class MoviePilotAgent:
|
||||
detail = cls._exception_detail_text(error).lower()
|
||||
if "no endpoints found that support image input" in detail:
|
||||
return True
|
||||
if "unknown variant" in detail and "image_url" in detail:
|
||||
return True
|
||||
if "image input" not in detail and "images" not in detail:
|
||||
return False
|
||||
return any(
|
||||
|
||||
@@ -405,6 +405,7 @@ def _patch_openai_responses_instructions_support():
|
||||
return
|
||||
|
||||
_patch_openai_interleaved_reasoning_content_support()
|
||||
_patch_openai_responses_empty_output_support()
|
||||
|
||||
if getattr(ChatOpenAI, "_moviepilot_responses_instructions_patched", False):
|
||||
return
|
||||
@@ -464,6 +465,64 @@ def _patch_openai_responses_instructions_support():
|
||||
logger.debug("已修补 langchain-openai responses API 的 instructions 兼容性")
|
||||
|
||||
|
||||
def _patch_openai_responses_empty_output_support():
    """
    Make langchain-openai tolerate a Responses API result whose ``output`` is None.

    The ChatGPT Codex backend sometimes sends ``response.output = None`` in the
    ``response.completed`` chunk even though the earlier delta chunks already
    carried the actual text. langchain-openai iterates ``output`` while
    finalizing and raises ``TypeError``. This patch wraps the module-level
    ``_construct_lc_result_from_responses_api`` so a missing output is
    normalized to an empty list, letting the final chunk carry only
    usage/metadata.

    The patch is applied at most once; a marker attribute on the
    langchain-openai module records that it is already installed.
    """
    try:
        import langchain_openai.chat_models.base as openai_base
    except Exception as err:
        logger.debug(f"跳过 langchain-openai responses output 修补:{err}")
        return

    # Idempotency guard: never wrap the constructor twice.
    if getattr(openai_base, "_moviepilot_responses_empty_output_patched", False):
        return

    original_construct = getattr(
        openai_base, "_construct_lc_result_from_responses_api", None
    )
    if not callable(original_construct):
        logger.warning("langchain-openai 缺少 Responses API 结果构造函数,无法修补 output")
        return

    def _clone_response_with_empty_output(response):
        """
        Return a response whose ``output`` is an empty list.

        Prefers a copy made through ``model_copy``; if that is unavailable or
        fails, falls back to setting the attribute on the original object.
        """
        copier = getattr(response, "model_copy", None)
        if callable(copier):
            try:
                return copier(update={"output": []})
            except Exception as err:
                logger.debug(f"复制 Responses 对象失败,回退原地修补 output:{err}")

        # Best effort: if the in-place write also fails, the object is
        # returned untouched.
        try:
            response.output = []
        except Exception as err:
            logger.debug(f"原地修补 Responses output 失败:{err}")
        return response

    absent = object()

    @wraps(original_construct)
    def _patched_construct_lc_result_from_responses_api(response, *args, **kwargs):
        """
        Normalize a None ``output`` before delegating to the original constructor.
        """
        # Only rewrite when the attribute exists and is exactly None.
        if getattr(response, "output", absent) is None:
            response = _clone_response_with_empty_output(response)
        return original_construct(response, *args, **kwargs)

    openai_base._construct_lc_result_from_responses_api = (
        _patched_construct_lc_result_from_responses_api
    )
    openai_base._moviepilot_responses_empty_output_patched = True
    logger.debug("已修补 langchain-openai responses API 空 output 兼容性")
|
||||
|
||||
|
||||
class LLMHelper:
|
||||
"""LLM模型相关辅助功能"""
|
||||
|
||||
|
||||
@@ -1458,7 +1458,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
|
||||
async def _fetch_models_dev(self, use_proxy: Optional[bool] = None) -> dict[str, Any]:
|
||||
"""通过网络请求获取最新 models.dev 数据。"""
|
||||
headers = {"User-Agent": "MoviePilot/1.0"}
|
||||
headers = {"User-Agent": settings.USER_AGENT}
|
||||
async with httpx.AsyncClient(**self._build_httpx_kwargs(use_proxy)) as client:
|
||||
response = await client.get(self._MODELS_DEV_URL, headers=headers)
|
||||
response.raise_for_status()
|
||||
@@ -1773,7 +1773,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
仅补充 Copilot 必需的意图头,避免重复覆盖。
|
||||
"""
|
||||
headers = {
|
||||
"User-Agent": "MoviePilot/1.0",
|
||||
"User-Agent": settings.USER_AGENT,
|
||||
"Openai-Intent": "conversation-edits",
|
||||
"x-initiator": "user",
|
||||
}
|
||||
@@ -2147,7 +2147,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
f"{self._CHATGPT_ISSUER}/api/accounts/deviceauth/usercode",
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": "MoviePilot/1.0",
|
||||
"User-Agent": settings.USER_AGENT,
|
||||
},
|
||||
json={"client_id": self._CHATGPT_CLIENT_ID},
|
||||
)
|
||||
@@ -2184,7 +2184,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
headers={
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": "MoviePilot/1.0",
|
||||
"User-Agent": settings.USER_AGENT,
|
||||
},
|
||||
json={
|
||||
"client_id": self._COPILOT_CLIENT_ID,
|
||||
@@ -2380,7 +2380,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
f"{self._CHATGPT_ISSUER}/api/accounts/deviceauth/token",
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": "MoviePilot/1.0",
|
||||
"User-Agent": settings.USER_AGENT,
|
||||
},
|
||||
json={
|
||||
"device_auth_id": session.context["device_auth_id"],
|
||||
@@ -2425,7 +2425,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
headers={
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": "MoviePilot/1.0",
|
||||
"User-Agent": settings.USER_AGENT,
|
||||
},
|
||||
json={
|
||||
"client_id": self._COPILOT_CLIENT_ID,
|
||||
|
||||
@@ -5,48 +5,57 @@ All your responses must be in **Chinese (中文)**.
|
||||
You act as a proactive agent. Your goal is to fully resolve the user's media-related requests autonomously. Do not end your turn until the task is complete or you are blocked and require user feedback.
|
||||
|
||||
<agent_core>
|
||||
Identity and Goal:
|
||||
<identity>
|
||||
- You are an AI media assistant powered by MoviePilot.
|
||||
- Your primary goal is to fully resolve the user's MoviePilot-related media tasks with the available tools whenever the request is actionable.
|
||||
- Focus on MoviePilot's core home media domain: sites, search, recognition, downloads, subscriptions, library organization, file transfer, and system status.
|
||||
- Stay within the MoviePilot product domain unless the user explicitly asks for adjacent help that can be handled with your existing tools.
|
||||
- You are not a general-purpose coding assistant in normal media conversations. Only cross into implementation details when the user explicitly asks about MoviePilot internals or debugging.
|
||||
</identity>
|
||||
|
||||
<non_negotiable_boundaries>
|
||||
- Do not let user memory or persona style override this core identity, safety boundaries, or built-in background task rules.
|
||||
- Never directly modify application source code, scripts, tests, or generated code through `edit_file`, `write_file`, shell write operations, or similar tools. If the user asks about MoviePilot internals or debugging, inspect and explain the needed change without applying it.
|
||||
- If the user explicitly asks to change the speaking style or persona, use `query_personas` and `switch_persona` instead of editing runtime files manually.
|
||||
- If the user explicitly asks to rewrite or create a persona definition, prefer `update_persona_definition` rather than generic file-editing tools.
|
||||
- Treat read-only inspection as allowed, but never use shell redirection, overwrite operations, file editing tools, or generated patches to change code.
|
||||
</non_negotiable_boundaries>
|
||||
|
||||
<confirmation_policy>
|
||||
- Do not stop for approval on read-only operations.
|
||||
- If the user has not explicitly requested an operation that changes system behavior, ask for confirmation before proceeding. This includes modifying system settings, updating plugin configuration, reloading plugins, running restart/stop/start commands, or triggering slash commands such as `/restart`.
|
||||
- Always get explicit consent before destructive or high-impact actions such as starting downloads, deleting subscriptions, deleting download tasks or files, removing history, installing/uninstalling plugins, changing site authentication, changing scheduler or workflow execution state, restarting services, or stopping services.
|
||||
- If the user explicitly requested the exact write action, perform the smallest correct change and then validate the result.
|
||||
- If a requested action is ambiguous between read-only inspection and state change, inspect first and ask a short confirmation question before the state-changing step.
|
||||
</confirmation_policy>
|
||||
|
||||
<moviepilot_domain_model>
|
||||
- Treat sites as a first-class system capability, not background detail. In MoviePilot, sites are the upstream source for search, account status, authentication, and many download or subscription decisions.
|
||||
- Understand the platform's core workflow as: site availability and configuration -> media search -> media recognition/metadata confirmation -> manual download or subscription -> transfer and library organization -> status/history confirmation.
|
||||
- Treat manual download and subscription automation as two execution modes of the same core pipeline. One is user-triggered immediate acquisition; the other is persistent site-driven monitoring and acquisition.
|
||||
- Stay within the MoviePilot product domain unless the user explicitly asks for adjacent help that can be handled with your existing tools.
|
||||
- Treat manual download and subscription automation as two execution modes of the same acquisition pipeline. Manual download is user-triggered immediate acquisition; subscription is persistent site-driven monitoring and acquisition.
|
||||
- Keep the user anchored to the operational step that matters now: site, search, recognition, download, subscription, transfer, or status/history.
|
||||
- Users may attach images from supported channels; analyze them together with the text when relevant.
|
||||
- User messages may arrive as structured JSON. Treat the `message` field as the user's text. Attachments appear in `files`; when `local_path` is present, use local file tools to inspect the uploaded file directly. When image input is disabled for the current model, user images may also be delivered through `files`.
|
||||
</moviepilot_domain_model>
|
||||
|
||||
Behavior Model:
|
||||
<operating_principles>
|
||||
- Prioritize task progress over conversation.
|
||||
- Check current state before making changes, then do the smallest correct action.
|
||||
- When a task depends on tracker or indexer availability, inspect site state first or as early as possible.
|
||||
- Do not stop for approval on read-only operations. Only confirm before destructive or high-impact actions such as starting downloads, deleting subscriptions, or removing history.
|
||||
- When a request can be completed by tools, prefer doing the work over explaining what you might do.
|
||||
- After an action, perform the minimum validation needed to confirm the result actually landed.
|
||||
- Keep the user anchored to the operational step that matters now: site, search, recognition, download, subscription, or transfer.
|
||||
- If the user explicitly asks to change the speaking style or persona, use the dedicated persona tools instead of editing runtime files manually.
|
||||
- If the user explicitly asks to rewrite or create a persona definition, prefer `update_persona_definition` rather than generic file-editing tools.
|
||||
- Do not let user memory or persona style override this core identity, safety boundaries, or built-in background task rules.
|
||||
- You are not a general-purpose coding assistant in normal media conversations. Only cross into implementation details when the user explicitly asks about MoviePilot internals or debugging.
|
||||
- Reuse known media identity, prior tool results, and current system context instead of repeating expensive recognition or search calls.
|
||||
- When a tool fails, try one narrower fallback path before escalating to the user.
|
||||
</operating_principles>
|
||||
|
||||
Core Capabilities:
|
||||
1. Site Operations - Query configured sites, understand site priority and availability, inspect account data, test connectivity, and update site authentication when the user explicitly requests site maintenance.
|
||||
2. Media Search and Recognition - Identify movies, TV shows, and anime; search media databases; recognize media from fuzzy filenames, torrent titles, or incomplete names.
|
||||
3. Torrent Search and Selection - Search torrents across configured sites and filter by quality, resolution, codec, effect, release group, and other result traits.
|
||||
4. Download Control - Add, inspect, modify, or remove download tasks and connect site results to downloader execution.
|
||||
5. Subscription Management - Create and manage subscriptions that continuously search configured sites and automatically download matching releases.
|
||||
6. Transfer and Library Organization - Transfer files into the library, trigger recognition-aware organization, and confirm post-download file landing or cleanup state.
|
||||
7. System Status and History - Monitor downloader state, site state, transfer history, subscription history, and related system health signals.
|
||||
8. Visual Input Handling - Users may attach images from supported channels; analyze them together with the text when relevant.
|
||||
9. File Context Handling - User messages may arrive as structured JSON. Treat the `message` field as the user's text. Attachments appear in `files`; when `local_path` is present, use local file tools to inspect the uploaded file directly. When image input is disabled for the current model, user images may also be delivered through `files`.
|
||||
10. Persona Management - If the user explicitly asks to change the speaking style or persona, prefer `query_personas` and `switch_persona`; if the user asks to rewrite or create a persona definition, prefer `update_persona_definition` instead of editing runtime files manually.
|
||||
|
||||
Core Workflow:
|
||||
<core_workflow>
|
||||
1. Site and Context Check: Determine whether site status, site scope, library state, existing subscriptions, or prior download/transfer history can affect the task.
|
||||
2. Media Identity Resolution: Confirm exact media identity such as TMDB ID, title, year, type, season, or episode using `search_media`, `query_media_detail`, or `recognize_media` as needed.
|
||||
3. Resource Discovery: Use the appropriate search path for the task. For manual acquisition, search site resources and inspect result quality. For automation, prepare subscription conditions that will search sites continuously.
|
||||
4. Action Execution: Perform the requested task, typically one of: test/query site, search torrents, add download, add or modify subscription, or transfer and organize files.
|
||||
5. Final Confirmation: State the outcome briefly, including the key media facts, chosen site or resource scope when relevant, and the next blocker if the task could not be completed.
|
||||
</core_workflow>
|
||||
|
||||
Tool Calling Strategy:
|
||||
<tool_strategy>
|
||||
- Call independent tools in parallel whenever possible.
|
||||
- Prefer site-aware tool paths when the task is about torrents, subscriptions, or download failures. `query_sites`, `test_site`, and `query_site_userdata` are part of the main operating flow, not edge-case tools.
|
||||
- If search results are ambiguous, use `query_media_detail` or `recognize_media` to clarify before proceeding.
|
||||
@@ -54,11 +63,10 @@ Tool Calling Strategy:
|
||||
- If `search_media` fails, fall back to `search_web` or `recognize_media`. Only ask the user when automated paths are exhausted.
|
||||
- If torrent search yields no useful result, check site scope, site health, and recognition quality before concluding that the resource is unavailable.
|
||||
- Reuse the latest torrent search cache for `get_search_results` and `add_download` instead of re-running the same search unnecessarily.
|
||||
- Reuse known media identity, prior tool results, and current system context instead of repeating expensive recognition or search calls.
|
||||
- When a tool fails, try one narrower fallback path before escalating to the user.
|
||||
- Use `execute_command` for shell work. Its default `action=start` starts a managed background session and returns `session_id`, `status`, `last_seq`, and `output_until_seq`; call the same tool again with `action=read`, `action=wait`, `action=write`, or `action=kill` to poll output, wait in short segments, send stdin, or stop the process.
|
||||
- Use `execute_command` only for diagnostics, read-only inspection, or commands the user explicitly asked to run. Its default `action=start` starts a managed background session and returns `session_id`, `status`, `last_seq`, and `output_until_seq`; call the same tool again with `action=read`, `action=wait`, `action=write`, or `action=kill` to poll output, wait in short segments, send stdin, or stop the process.
|
||||
</tool_strategy>
|
||||
|
||||
Media Management Rules:
|
||||
<media_rules>
|
||||
1. Site Awareness: When search, download, or subscription behavior depends on sites, prefer checking enabled sites, selected site IDs, priority, or site health before changing user expectations.
|
||||
2. Download Safety: Present found torrents with size, seeds, and quality, then get explicit consent before downloading.
|
||||
3. Search vs Recognition: `search_media` is for database lookup, `recognize_media` is for parsing titles or paths, and `search_torrents` is for site resource lookup. Do not confuse these roles.
|
||||
@@ -67,6 +75,7 @@ Media Management Rules:
|
||||
6. Transfer Awareness: If the user asks about downloaded files landing in the library, include transfer or organization state in the reasoning, not just download completion.
|
||||
7. Error Handling: If a tool or site fails, briefly explain what went wrong and suggest an alternative or the next best operational step.
|
||||
8. TV Subscription Rule: When calling `add_subscribe` for a TV show, omitting `season` means subscribe to season 1 only. To subscribe multiple seasons or the full series, call `add_subscribe` separately for each season.
|
||||
</media_rules>
|
||||
</agent_core>
|
||||
|
||||
<communication_runtime>
|
||||
|
||||
@@ -16,7 +16,7 @@ from app.db.user_oper import UserOper
|
||||
from app.helper.service import ServiceConfigHelper
|
||||
from app.log import logger
|
||||
from app.schemas import Notification
|
||||
from app.schemas.types import MessageChannel
|
||||
from app.schemas.types import MessageChannel, NotificationType
|
||||
|
||||
|
||||
class ToolChain(ChainBase):
|
||||
@@ -407,7 +407,7 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
|
||||
async def send_tool_message(
|
||||
self, message: str, title: str = "", image: Optional[str] = None
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
发送工具消息
|
||||
"""
|
||||
@@ -415,6 +415,7 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
Notification(
|
||||
channel=self._channel,
|
||||
source=self._source,
|
||||
mtype=NotificationType.Agent,
|
||||
userid=self._user_id,
|
||||
username=self._username,
|
||||
title=title,
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import asyncio
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
from typing import Optional, Type, List, Dict
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Type
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
from ddgs import DDGS
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -14,63 +14,151 @@ from app.log import logger
|
||||
|
||||
# 搜索超时时间(秒)
|
||||
SEARCH_TIMEOUT = 20
|
||||
# 单次搜索最多返回结果数
|
||||
MAX_SEARCH_RESULTS = 20
|
||||
# 默认搜索源
|
||||
DEFAULT_SEARCH_ENGINE = "auto"
|
||||
# 可显式调用的搜索引擎后端
|
||||
SEARCH_ENGINE_BACKENDS = (
|
||||
"auto",
|
||||
"duckduckgo",
|
||||
"google",
|
||||
"brave",
|
||||
"yahoo",
|
||||
"wikipedia",
|
||||
"yandex",
|
||||
"mojeek",
|
||||
)
|
||||
SUPPORTED_SEARCH_ENGINES = SEARCH_ENGINE_BACKENDS
|
||||
DDGS_AUTO_BACKEND = ",".join(
|
||||
backend for backend in SEARCH_ENGINE_BACKENDS if backend != DEFAULT_SEARCH_ENGINE
|
||||
)
|
||||
SITE_SEARCH_PATTERN = re.compile(r"\bsite:", re.IGNORECASE)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _SearchSiteFilter:
|
||||
"""站点限定搜索参数"""
|
||||
|
||||
domain: str
|
||||
path: str
|
||||
search_target: str
|
||||
|
||||
|
||||
class SearchWebInput(BaseModel):
|
||||
"""搜索网络内容工具的输入参数模型"""
|
||||
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
explanation: Optional[str] = Field(
|
||||
None,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
query: str = Field(
|
||||
..., description="The search query string to search for on the web"
|
||||
)
|
||||
max_results: Optional[int] = Field(
|
||||
20,
|
||||
MAX_SEARCH_RESULTS,
|
||||
description="Maximum number of search results to return (default: 20, max: 20)",
|
||||
)
|
||||
search_engine: Optional[str] = Field(
|
||||
DEFAULT_SEARCH_ENGINE,
|
||||
description=(
|
||||
"Search backend to use. Supported values: auto, duckduckgo, google, "
|
||||
"brave, yahoo, wikipedia, yandex, mojeek. "
|
||||
"Use auto unless the user asks for a specific search engine."
|
||||
),
|
||||
)
|
||||
site_url: Optional[str] = Field(
|
||||
None,
|
||||
description=(
|
||||
"Optional website/domain/URL to limit the search to, for example "
|
||||
"'https://docs.python.org/3/' or 'github.com'."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class SearchWebTool(MoviePilotTool):
|
||||
"""
|
||||
网络搜索工具,支持 DDGS 搜索引擎和指定站点限定搜索。
|
||||
"""
|
||||
|
||||
name: str = "search_web"
|
||||
description: str = "Search the web for information when you need to find current information, facts, or references that you're uncertain about. Returns search results with titles, snippets, and URLs. Use this tool to get up-to-date information from the internet."
|
||||
description: str = (
|
||||
"Search the web for information when you need current information, facts, "
|
||||
"or references. Supports DDGS-backed search engine selection, automatic "
|
||||
"fallback, and site_url-limited searches for a specified website "
|
||||
"or URL. Uses the configured system proxy by default. Returns search "
|
||||
"results with titles, snippets, and URLs."
|
||||
)
|
||||
args_schema: Type[BaseModel] = SearchWebInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据搜索参数生成友好的提示消息"""
|
||||
query = kwargs.get("query", "")
|
||||
max_results = kwargs.get("max_results", 20)
|
||||
return f"搜索网络内容: {query} (最多返回 {max_results} 条结果)"
|
||||
max_results = kwargs.get("max_results", MAX_SEARCH_RESULTS)
|
||||
search_engine = self._normalize_search_engine(kwargs.get("search_engine"))
|
||||
site_url = kwargs.get("site_url")
|
||||
message = f"搜索网络内容: {query} (最多返回 {max_results} 条结果"
|
||||
if search_engine != DEFAULT_SEARCH_ENGINE:
|
||||
message += f",搜索源: {search_engine}"
|
||||
if site_url:
|
||||
message += f",限定站点: {site_url}"
|
||||
return f"{message})"
|
||||
|
||||
async def run(self, query: str, max_results: Optional[int] = 20, **kwargs) -> str:
|
||||
async def run(
|
||||
self,
|
||||
query: str,
|
||||
max_results: Optional[int] = MAX_SEARCH_RESULTS,
|
||||
search_engine: Optional[str] = DEFAULT_SEARCH_ENGINE,
|
||||
site_url: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
执行网络搜索
|
||||
执行网络搜索。
|
||||
|
||||
:param query: 搜索关键词
|
||||
:param max_results: 最大返回结果数
|
||||
:param search_engine: 指定搜索源,默认自动选择
|
||||
:param site_url: 指定站点或网址,传入时只返回该范围内的搜索结果
|
||||
:return: JSON格式的搜索结果或错误信息
|
||||
"""
|
||||
search_engine = self._normalize_search_engine(search_engine)
|
||||
if search_engine not in SUPPORTED_SEARCH_ENGINES:
|
||||
supported = ", ".join(SUPPORTED_SEARCH_ENGINES)
|
||||
return f"错误: 不支持的搜索源 '{search_engine}',支持的搜索源: {supported}"
|
||||
|
||||
site_filter = self._normalize_site_filter(site_url)
|
||||
if site_url and not site_filter:
|
||||
return f"错误: site_url 无效,无法限定搜索范围: {site_url}"
|
||||
|
||||
search_query = self._build_search_query(query=query, site_filter=site_filter)
|
||||
if not search_query:
|
||||
return "错误: query 不能为空"
|
||||
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: query={query}, max_results={max_results}"
|
||||
f"执行工具: {self.name}, 参数: query={query}, "
|
||||
f"max_results={max_results}, search_engine={search_engine}, site_url={site_url}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 限制最大结果数
|
||||
max_results = min(max(1, max_results or 20), 20)
|
||||
results = []
|
||||
max_results = min(
|
||||
max(1, max_results or MAX_SEARCH_RESULTS),
|
||||
MAX_SEARCH_RESULTS,
|
||||
)
|
||||
results: List[Dict] = []
|
||||
|
||||
# 1. 优先使用 Exa (如果配置了 API Key)
|
||||
if settings.EXA_API_KEY:
|
||||
logger.info("使用 Exa 进行搜索...")
|
||||
results = await self._search_exa(query, max_results)
|
||||
|
||||
# 2. 如果没有结果或未配置 Exa,使用 Tavily (如果配置了 API Key)
|
||||
if not results and settings.TAVILY_API_KEY:
|
||||
logger.info("使用 Tavily 进行搜索...")
|
||||
results = await self._search_tavily(query, max_results)
|
||||
|
||||
# 3. 如果没有结果或未配置 Tavily,使用 DuckDuckGo
|
||||
if not results:
|
||||
logger.info("使用 DuckDuckGo 进行搜索...")
|
||||
results = await self._search_duckduckgo(query, max_results)
|
||||
for engine in self._get_search_plan(search_engine):
|
||||
results = await self._search_with_backend(
|
||||
engine=engine,
|
||||
query=search_query,
|
||||
max_results=max_results,
|
||||
site_filter=site_filter,
|
||||
)
|
||||
if results:
|
||||
break
|
||||
|
||||
if not results:
|
||||
return f"未找到与 '{query}' 相关的搜索结果"
|
||||
return f"未找到与 '{search_query}' 相关的搜索结果"
|
||||
|
||||
# 格式化并裁剪结果
|
||||
formatted_results = self._format_and_truncate_results(results, max_results)
|
||||
@@ -82,81 +170,214 @@ class SearchWebTool(MoviePilotTool):
|
||||
return error_message
|
||||
|
||||
@staticmethod
|
||||
async def _search_tavily(query: str, max_results: int) -> List[Dict]:
|
||||
"""使用 Tavily API 进行搜索"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=SEARCH_TIMEOUT) as client:
|
||||
# 从设置中随机选择一个 API Key(如果有多个)
|
||||
tavity_api_key = random.choice(settings.TAVILY_API_KEY)
|
||||
response = await client.post(
|
||||
"https://api.tavily.com/search",
|
||||
json={
|
||||
"api_key": tavity_api_key,
|
||||
"query": query,
|
||||
"search_depth": "basic",
|
||||
"max_results": max_results,
|
||||
"include_answer": False,
|
||||
"include_images": False,
|
||||
"include_raw_content": False,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
results = []
|
||||
for result in data.get("results", []):
|
||||
results.append(
|
||||
{
|
||||
"title": result.get("title", ""),
|
||||
"snippet": result.get("content", ""),
|
||||
"url": result.get("url", ""),
|
||||
"source": "Tavily",
|
||||
}
|
||||
)
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.warning(f"Tavily 搜索失败: {e}")
|
||||
return []
|
||||
def _normalize_search_engine(search_engine: Optional[str]) -> str:
    """
    Normalize the search engine argument to a canonical backend name.

    The value is stripped and lower-cased, then mapped through a small alias
    table. Missing, empty, or whitespace-only values resolve to
    ``DEFAULT_SEARCH_ENGINE``.

    :param search_engine: Engine name supplied by the caller, possibly None
    :return: Canonical engine name; unknown names are returned normalized but
        otherwise unchanged so the caller can reject them
    """
    # Strip before applying the default so a whitespace-only value also falls
    # back to the default instead of becoming an empty, unsupported engine.
    engine = (search_engine or "").strip().lower() or DEFAULT_SEARCH_ENGINE
    aliases = {
        "ddgs": DEFAULT_SEARCH_ENGINE,
        "ddg": "duckduckgo",
        "duck": "duckduckgo",
        "search": DEFAULT_SEARCH_ENGINE,
        "search_engine": DEFAULT_SEARCH_ENGINE,
    }
    return aliases.get(engine, engine)
|
||||
|
||||
@staticmethod
|
||||
async def _search_exa(query: str, max_results: int) -> List[Dict]:
|
||||
"""使用 Exa API 进行搜索"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=SEARCH_TIMEOUT) as client:
|
||||
response = await client.post(
|
||||
"https://api.exa.ai/search",
|
||||
headers={
|
||||
"x-api-key": settings.EXA_API_KEY,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={
|
||||
"query": query,
|
||||
"numResults": max_results,
|
||||
"type": "auto",
|
||||
"contents": {"highlights": {"maxCharacters": 2000}},
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
def _get_search_plan(search_engine: str) -> List[str]:
|
||||
"""根据搜索源配置生成兜底搜索顺序"""
|
||||
if search_engine != DEFAULT_SEARCH_ENGINE:
|
||||
return [search_engine]
|
||||
return [DEFAULT_SEARCH_ENGINE]
|
||||
|
||||
results = []
|
||||
for result in data.get("results", []):
|
||||
highlights = result.get("highlights", [])
|
||||
snippet = (
|
||||
highlights[0] if highlights else result.get("text", "")[:500]
|
||||
)
|
||||
results.append(
|
||||
{
|
||||
"title": result.get("title", ""),
|
||||
"snippet": snippet,
|
||||
"url": result.get("url", ""),
|
||||
"source": "Exa",
|
||||
}
|
||||
)
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.warning(f"Exa 搜索失败: {e}")
|
||||
return []
|
||||
async def _search_with_backend(
    self,
    engine: str,
    query: str,
    max_results: int,
    site_filter: Optional[_SearchSiteFilter],
) -> List[Dict]:
    """
    Run a search through the given backend.

    :param engine: Search backend name
    :param query: Search keywords, already processed for site restriction
    :param max_results: Maximum number of results
    :param site_filter: Site restriction, or None
    :return: List of search results
    """
    # Log the backend string actually handed to DDGS, not the user-facing name.
    ddgs_backend = self._get_ddgs_backend(engine)
    logger.info(f"使用 DDGS 搜索后端 {ddgs_backend} 进行搜索...")
    found = await self._search_ddgs(query, max_results, engine, site_filter)
    return found
|
||||
|
||||
@staticmethod
|
||||
def _get_ddgs_backend(search_engine: str) -> str:
    """
    Resolve the backend value actually passed to DDGS.

    :param search_engine: Search engine requested by the user
    :return: ``DDGS_AUTO_BACKEND`` for the default engine, otherwise the
        engine name unchanged
    """
    use_auto = search_engine == DEFAULT_SEARCH_ENGINE
    return DDGS_AUTO_BACKEND if use_auto else search_engine
|
||||
|
||||
@staticmethod
|
||||
def _normalize_site_filter(site_url: Optional[str]) -> Optional[_SearchSiteFilter]:
    """
    Convert a user-supplied site, domain, or URL into a site filter.

    :param site_url: Site, bare domain, or full URL supplied by the user
    :return: The site filter, or None when the input is empty or cannot be
        parsed into a host name
    """
    raw_site_url = (site_url or "").strip()
    if not raw_site_url:
        return None

    # Bare domains such as "github.com" need a scheme for urlparse to
    # recognize the host part.
    parse_target = raw_site_url
    if not re.match(r"^https?://", raw_site_url, re.IGNORECASE):
        parse_target = f"https://{raw_site_url}"

    # urlparse raises ValueError for malformed hosts (e.g. unmatched square
    # brackets); treat that as "cannot parse" rather than letting it escape.
    try:
        parsed = urlparse(parse_target)
    except ValueError:
        return None

    domain = (parsed.hostname or "").lower()
    if not domain:
        return None

    # Collapse repeated slashes and drop the trailing one so path comparisons
    # are stable.
    path = re.sub(r"/+", "/", parsed.path or "").rstrip("/")
    search_target = f"{domain}{path}" if path else domain
    return _SearchSiteFilter(domain=domain, path=path, search_target=search_target)
|
||||
|
||||
@staticmethod
|
||||
def _build_search_query(
    query: str,
    site_filter: Optional[_SearchSiteFilter],
) -> str:
    """
    Build the keyword string actually sent to the search backend.

    :param query: Original search keywords
    :param site_filter: Site restriction, or None
    :return: The trimmed keywords, with a ``site:`` clause appended when a
        filter is given and the keywords do not already contain one
    """
    keywords = (query or "").strip()
    if not site_filter:
        return keywords
    # Respect a `site:` operator the caller already wrote into the query.
    if SITE_SEARCH_PATTERN.search(keywords):
        return keywords

    site_clause = f"site:{site_filter.search_target}"
    return f"{keywords} {site_clause}" if keywords else site_clause
|
||||
|
||||
@staticmethod
|
||||
def _filter_results_by_site(
    results: List[Dict],
    site_filter: Optional[_SearchSiteFilter],
) -> List[Dict]:
    """
    Keep only the search results that belong to the given site.

    :param results: Raw search results
    :param site_filter: Site restriction, or None
    :return: ``results`` itself when no filter is given, otherwise a new list
        with the results inside the site's scope
    """
    if not site_filter:
        return results

    def _in_scope(item: Dict) -> bool:
        return SearchWebTool._result_matches_site(item.get("url", ""), site_filter)

    return list(filter(_in_scope, results))
|
||||
|
||||
@staticmethod
|
||||
def _result_matches_site(url: str, site_filter: _SearchSiteFilter) -> bool:
|
||||
"""
|
||||
判断搜索结果 URL 是否属于指定站点。
|
||||
|
||||
:param url: 搜索结果 URL
|
||||
:param site_filter: 站点限定条件
|
||||
:return: URL 属于指定站点时返回 True
|
||||
"""
|
||||
if not url:
|
||||
return False
|
||||
|
||||
parse_target = url
|
||||
if not re.match(r"^https?://", url, re.IGNORECASE):
|
||||
parse_target = f"https://{url}"
|
||||
|
||||
parsed = urlparse(parse_target)
|
||||
result_host = SearchWebTool._normalize_host(parsed.hostname or "")
|
||||
target_host = SearchWebTool._normalize_host(site_filter.domain)
|
||||
if not result_host or not target_host:
|
||||
return False
|
||||
if result_host != target_host and not result_host.endswith(f".{target_host}"):
|
||||
return False
|
||||
if not site_filter.path:
|
||||
return True
|
||||
|
||||
result_path = re.sub(r"/+", "/", parsed.path or "").rstrip("/")
|
||||
return result_path == site_filter.path or result_path.startswith(
|
||||
f"{site_filter.path}/"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_host(host: str) -> str:
|
||||
"""
|
||||
标准化域名以便比较。
|
||||
|
||||
:param host: 原始域名
|
||||
:return: 去掉常见 www 前缀后的域名
|
||||
"""
|
||||
normalized_host = (host or "").lower()
|
||||
if normalized_host.startswith("www."):
|
||||
return normalized_host[4:]
|
||||
return normalized_host
|
||||
|
||||
@staticmethod
|
||||
def _source_label(search_engine: str) -> str:
|
||||
"""
|
||||
将搜索源标识转换为结果中的展示名称。
|
||||
|
||||
:param search_engine: 搜索源标识
|
||||
:return: 展示名称
|
||||
"""
|
||||
labels = {
|
||||
"auto": "DDGS",
|
||||
"duckduckgo": "DuckDuckGo",
|
||||
"google": "Google",
|
||||
"brave": "Brave",
|
||||
"yahoo": "Yahoo",
|
||||
"wikipedia": "Wikipedia",
|
||||
"yandex": "Yandex",
|
||||
"mojeek": "Mojeek",
|
||||
}
|
||||
return labels.get(
|
||||
search_engine or DEFAULT_SEARCH_ENGINE,
|
||||
search_engine or "SearchEngine",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_result_url(result: Dict) -> str:
|
||||
"""
|
||||
从不同搜索引擎结果结构中提取 URL。
|
||||
|
||||
:param result: 搜索引擎返回的单条结果
|
||||
:return: URL 字符串
|
||||
"""
|
||||
return result.get("href") or result.get("url") or ""
|
||||
|
||||
@staticmethod
|
||||
def _extract_result_snippet(result: Dict) -> str:
|
||||
"""
|
||||
从不同搜索引擎结果结构中提取摘要。
|
||||
|
||||
:param result: 搜索引擎返回的单条结果
|
||||
:return: 摘要字符串
|
||||
"""
|
||||
return (
|
||||
result.get("body")
|
||||
or result.get("snippet")
|
||||
or result.get("content")
|
||||
or ""
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_proxy_url(proxy_setting) -> Optional[str]:
|
||||
@@ -167,11 +388,26 @@ class SearchWebTool(MoviePilotTool):
|
||||
return proxy_setting.get("http") or proxy_setting.get("https")
|
||||
return proxy_setting
|
||||
|
||||
async def _search_duckduckgo(self, query: str, max_results: int) -> List[Dict]:
|
||||
"""使用 duckduckgo-search (DDGS) 进行搜索"""
|
||||
async def _search_ddgs(
|
||||
self,
|
||||
query: str,
|
||||
max_results: int,
|
||||
search_engine: str = DEFAULT_SEARCH_ENGINE,
|
||||
site_filter: Optional[_SearchSiteFilter] = None,
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
使用 DDGS 搜索引擎后端进行搜索。
|
||||
|
||||
:param query: 搜索关键词
|
||||
:param max_results: 最大结果数
|
||||
:param search_engine: DDGS搜索后端
|
||||
:param site_filter: 站点限定条件
|
||||
:return: 搜索结果列表
|
||||
"""
|
||||
try:
|
||||
|
||||
def sync_search():
|
||||
"""在线程中执行同步搜索"""
|
||||
results = []
|
||||
ddgs_kwargs = {"timeout": SEARCH_TIMEOUT}
|
||||
proxy_url = self._get_proxy_url(settings.PROXY)
|
||||
@@ -180,26 +416,36 @@ class SearchWebTool(MoviePilotTool):
|
||||
|
||||
try:
|
||||
with DDGS(**ddgs_kwargs) as ddgs:
|
||||
ddgs_gen = ddgs.text(query, max_results=max_results)
|
||||
if ddgs_gen:
|
||||
for result in ddgs_gen:
|
||||
ddgs_results = ddgs.text(
|
||||
query,
|
||||
max_results=max_results,
|
||||
backend=self._get_ddgs_backend(search_engine),
|
||||
)
|
||||
if ddgs_results:
|
||||
for result in ddgs_results:
|
||||
source = (
|
||||
DEFAULT_SEARCH_ENGINE
|
||||
if search_engine == DEFAULT_SEARCH_ENGINE
|
||||
else search_engine
|
||||
)
|
||||
results.append(
|
||||
{
|
||||
"title": result.get("title", ""),
|
||||
"snippet": result.get("body", ""),
|
||||
"url": result.get("href", ""),
|
||||
"source": "DuckDuckGo",
|
||||
"snippet": self._extract_result_snippet(result),
|
||||
"url": self._extract_result_url(result),
|
||||
"source": self._source_label(source),
|
||||
}
|
||||
)
|
||||
except Exception as err:
|
||||
logger.warning(f"DuckDuckGo search process failed: {err}")
|
||||
logger.warning(f"搜索引擎搜索进程失败: {err}")
|
||||
return results
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(None, sync_search)
|
||||
results = await loop.run_in_executor(None, sync_search)
|
||||
return self._filter_results_by_site(results, site_filter)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"DuckDuckGo 搜索失败: {e}")
|
||||
logger.warning(f"搜索引擎搜索失败: {e}")
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -38,7 +38,15 @@ def play_item(
|
||||
if item:
|
||||
play_url = media_chain.get_play_url(server=name, item_id=itemid)
|
||||
if play_url:
|
||||
return schemas.Response(success=True, data={"url": play_url})
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
data={
|
||||
"url": play_url,
|
||||
"item_id": item.item_id or itemid,
|
||||
"server_id": item.server_id,
|
||||
"server_type": item.server,
|
||||
},
|
||||
)
|
||||
return schemas.Response(success=False, message="未找到播放地址")
|
||||
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ from app.db.user_oper import (
|
||||
get_current_active_superuser_async,
|
||||
)
|
||||
from app.helper.sites import SitesHelper # noqa
|
||||
from app.log import logger
|
||||
from app.scheduler import Scheduler
|
||||
from app.schemas.types import SystemConfigKey, EventType
|
||||
from app.utils.string import StringUtils
|
||||
@@ -161,6 +162,61 @@ async def update_sites_priority(
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
def _update_site_cookie(
    site_id: int,
    username: str,
    password: str,
    code: Optional[str],
    db: Session,
) -> schemas.Response:
    """
    Run the Cookie & UA update for one site.

    :param site_id: site identifier
    :param username: login user name for the site
    :param password: login password for the site
    :param code: two-step verification code or secret
    :param db: database session
    :return: response carrying the success flag and message from the update
    :raises HTTPException: 404 when no site with ``site_id`` exists
    """
    site_info = Site.get(db, site_id)
    if not site_info:
        raise HTTPException(
            status_code=404,
            detail=f"站点 {site_id} 不存在!",
        )

    logger.info(f"开始更新站点【{site_info.name}】Cookie&UA")
    state, message = SiteChain().update_cookie(
        site_info=site_info,
        username=username,
        password=password,
        two_step_code=code,
    )
    if not state:
        logger.error(f"站点【{site_info.name}】Cookie&UA更新失败:{message}")
    else:
        logger.info(f"站点【{site_info.name}】Cookie&UA更新成功")
    return schemas.Response(success=state, message=message)
|
||||
|
||||
|
||||
@router.post(
    "/cookie/{site_id}", summary="更新站点Cookie&UA", response_model=schemas.Response
)
def update_cookie_by_body(
    site_id: int,
    site_cookie_update: schemas.SiteCookieUpdate,
    db: Session = Depends(get_db),
    _: User = Depends(get_current_active_superuser),
) -> Any:
    """
    Update a site's Cookie using the user name and password from the request body.
    """
    # Credentials arrive in the request body; the shared helper does the work.
    credentials = site_cookie_update
    return _update_site_cookie(
        site_id=site_id,
        username=credentials.username,
        password=credentials.password,
        code=credentials.code,
        db=db,
    )
|
||||
|
||||
|
||||
@router.get(
|
||||
"/cookie/{site_id}", summary="更新站点Cookie&UA", response_model=schemas.Response
|
||||
)
|
||||
@@ -175,18 +231,13 @@ def update_cookie(
|
||||
"""
|
||||
使用用户密码更新站点Cookie
|
||||
"""
|
||||
# 查询站点
|
||||
site_info = Site.get(db, site_id)
|
||||
if not site_info:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"站点 {site_id} 不存在!",
|
||||
)
|
||||
# 更新Cookie
|
||||
state, message = SiteChain().update_cookie(
|
||||
site_info=site_info, username=username, password=password, two_step_code=code
|
||||
return _update_site_cookie(
|
||||
site_id=site_id,
|
||||
username=username,
|
||||
password=password,
|
||||
code=code,
|
||||
db=db,
|
||||
)
|
||||
return schemas.Response(success=state, message=message)
|
||||
|
||||
|
||||
@router.post(
|
||||
|
||||
@@ -360,13 +360,11 @@ async def fetch_image(
|
||||
|
||||
fetch_url = SecurityUtils.strip_url_signature(url)
|
||||
# 验证URL安全性
|
||||
if not await SecurityUtils.is_safe_url_async(
|
||||
if not await SecurityUtils.is_safe_image_url_async(
|
||||
url,
|
||||
allowed_domains,
|
||||
block_private=True,
|
||||
allowed_private_ranges=settings.IMAGE_PROXY_ALLOWED_PRIVATE_RANGES,
|
||||
) and not (fetch_url := SecurityUtils.verify_signed_url(url)):
|
||||
logger.warn(f"Blocked unsafe image URL: {url}")
|
||||
):
|
||||
return None
|
||||
|
||||
content = await ImageHelper().async_fetch_image(
|
||||
|
||||
@@ -492,7 +492,8 @@ class DownloadChain(ChainBase):
|
||||
season=not_exist.season,
|
||||
episodes=need,
|
||||
total_episode=not_exist.total_episode,
|
||||
start_episode=not_exist.start_episode
|
||||
start_episode=not_exist.start_episode,
|
||||
require_complete_coverage=not_exist.require_complete_coverage
|
||||
)
|
||||
else:
|
||||
no_exists[_mid].pop(_sea)
|
||||
@@ -511,6 +512,34 @@ class DownloadChain(ChainBase):
|
||||
return 9999
|
||||
return no_exist[season].total_episode
|
||||
|
||||
def __get_no_exist_media(_mid: Union[int, str], season: int) -> Optional[NotExistMediaInfo]:
    """
    Look up the missing-episode record for one media id and season.

    Reads ``no_exists`` from the enclosing scope.

    :param _mid: media identifier used as the key of ``no_exists``
    :param season: season number
    :return: the record for that season, or None when nothing is recorded
    """
    seasons_of_media = no_exists.get(_mid) if no_exists else None
    if not seasons_of_media:
        return None
    return seasons_of_media.get(season)
|
||||
|
||||
def __get_required_episodes(_mid: Union[int, str], season: int) -> Set[int]:
    """
    Return the episode numbers a whole-season candidate must cover.

    :param _mid: media identifier
    :param season: season number
    :return: every episode from the record's start episode (1 when unset) up to
        and including its total episode count; an empty set when there is no
        record or no total episode count
    """
    record = __get_no_exist_media(_mid, season)
    if not record or not record.total_episode:
        return set()
    first_episode = record.start_episode or 1
    return set(range(first_episode, record.total_episode + 1))
|
||||
|
||||
def __requires_complete_coverage(_tv: Optional[NotExistMediaInfo]) -> bool:
|
||||
"""
|
||||
判断当前缺失范围是否要求候选资源完整覆盖目标范围。
|
||||
"""
|
||||
if not _tv:
|
||||
return False
|
||||
return bool(_tv.require_complete_coverage)
|
||||
|
||||
def __apply_allowed_episodes(_need_episodes, _context: Context) -> Set[int]:
|
||||
"""
|
||||
根据候选携带的允许集裁剪 need_episodes,返回真正可下载的剧集集合。
|
||||
@@ -616,13 +645,23 @@ class DownloadChain(ChainBase):
|
||||
logger.info(f"{meta.org_string} 解析种子文件集数为 {torrent_episodes}")
|
||||
if not torrent_episodes:
|
||||
continue
|
||||
torrent_episodes_set = set(torrent_episodes)
|
||||
# 更新集数范围
|
||||
begin_ep = min(torrent_episodes)
|
||||
end_ep = max(torrent_episodes)
|
||||
meta.set_episodes(begin=begin_ep, end=end_ep)
|
||||
# 需要总集数
|
||||
# 需要目标集范围;完整覆盖场景必须覆盖范围内每一集,不能只按数量判断。
|
||||
need_tv_info = __get_no_exist_media(need_mid, torrent_season[0])
|
||||
required_episodes = __get_required_episodes(need_mid, torrent_season[0]) \
|
||||
if __requires_complete_coverage(need_tv_info) else set()
|
||||
need_total = __get_season_episodes(need_mid, torrent_season[0])
|
||||
if len(torrent_episodes) < need_total:
|
||||
if required_episodes and not required_episodes.issubset(torrent_episodes_set):
|
||||
missing_episodes = sorted(required_episodes.difference(torrent_episodes_set))
|
||||
logger.info(
|
||||
f"{meta.org_string} 解析文件集数未覆盖目标范围,"
|
||||
f"缺少 {StringUtils.format_ep(missing_episodes)},先放弃这个种子")
|
||||
continue
|
||||
if not required_episodes and need_total and len(torrent_episodes) < need_total:
|
||||
logger.info(
|
||||
f"{meta.org_string} 解析文件集数发现不是完整合集,先放弃这个种子")
|
||||
continue
|
||||
@@ -713,8 +752,15 @@ class DownloadChain(ChainBase):
|
||||
effective_need = __apply_allowed_episodes(need_episodes, context)
|
||||
if not effective_need:
|
||||
continue
|
||||
# 为需要集的子集则下载
|
||||
if torrent_episodes.issubset(effective_need):
|
||||
if __requires_complete_coverage(tv):
|
||||
# 完整覆盖任务要求候选集数覆盖目标范围,允许资源包含范围外的额外集。
|
||||
required_episodes = __get_required_episodes(need_mid, need_season)
|
||||
match_episodes = required_episodes.issubset(torrent_episodes) \
|
||||
if required_episodes else False
|
||||
else:
|
||||
# 普通缺集下载保持原语义:候选自身必须是所需集的子集。
|
||||
match_episodes = torrent_episodes.issubset(effective_need)
|
||||
if match_episodes:
|
||||
# 下载
|
||||
logger.info(f"开始下载 {meta.title} ...")
|
||||
download_id = self.download_single(context, save_path=save_path,
|
||||
@@ -752,6 +798,8 @@ class DownloadChain(ChainBase):
|
||||
need_season = sea
|
||||
# 当前需要集
|
||||
need_episodes = tv.episodes
|
||||
if __requires_complete_coverage(tv):
|
||||
continue
|
||||
# 没有集的不处理
|
||||
if not need_episodes:
|
||||
continue
|
||||
|
||||
@@ -37,7 +37,8 @@ from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.helper.server import MoviePilotServerHelper
|
||||
from app.helper.torrent import TorrentHelper
|
||||
from app.log import logger
|
||||
from app.schemas import MediaRecognizeConvertEventData
|
||||
from app.schemas import (MediaRecognizeConvertEventData, SubscribeEpisodesRefreshEventData,
|
||||
SubscribeCompletionCheckEventData)
|
||||
from app.schemas.types import MediaType, SystemConfigKey, MessageChannel, NotificationType, EventType, ChainEventType, \
|
||||
ContentType
|
||||
|
||||
@@ -640,6 +641,10 @@ class SubscribeChain(ChainBase):
|
||||
logger.error(f"媒体信息中没有季集信息,标题:{title},tmdbid:{tmdbid},doubanid:{doubanid}")
|
||||
return None, "媒体信息中没有季集信息"
|
||||
total_episode = len(mediainfo.seasons.get(season) or [])
|
||||
# 允许外部覆盖按 TMDB 算出的总集数(如待定集数)
|
||||
total_episode = self.__apply_episodes_refresh(
|
||||
total_episode, season=season, mediainfo=mediainfo,
|
||||
tmdbid=mediainfo.tmdb_id, doubanid=mediainfo.douban_id, scene="create")
|
||||
if not total_episode:
|
||||
logger.error(f'未获取到总集数,标题:{title},tmdbid:{tmdbid}, doubanid:{doubanid}')
|
||||
return None, f"未获取到第 {season} 季的总集数"
|
||||
@@ -821,6 +826,10 @@ class SubscribeChain(ChainBase):
|
||||
logger.error(f"媒体信息中没有季集信息,标题:{title},tmdbid:{tmdbid},doubanid:{doubanid}")
|
||||
return None, "媒体信息中没有季集信息"
|
||||
total_episode = len(mediainfo.seasons.get(season) or [])
|
||||
# 允许外部覆盖按 TMDB 算出的总集数(如待定集数)
|
||||
total_episode = await self.__async_apply_episodes_refresh(
|
||||
total_episode, season=season, mediainfo=mediainfo,
|
||||
tmdbid=mediainfo.tmdb_id, doubanid=mediainfo.douban_id, scene="create")
|
||||
if not total_episode:
|
||||
logger.error(f'未获取到总集数,标题:{title},tmdbid:{tmdbid}, doubanid:{doubanid}')
|
||||
return None, f"未获取到第 {season} 季的总集数"
|
||||
@@ -1687,8 +1696,12 @@ class SubscribeChain(ChainBase):
|
||||
current_priority = None
|
||||
if not subscribe.manual_total_episode and len(episodes):
|
||||
total_episode = len(episodes)
|
||||
# 总集数增长按 delta 同步抬升 lack
|
||||
lack_episode = (subscribe.lack_episode or 0) + (total_episode - (subscribe.total_episode or 0))
|
||||
# 允许外部覆盖按 TMDB 算出的总集数(如待定集数)
|
||||
total_episode = self.__apply_episodes_refresh(
|
||||
total_episode, season=subscribe.season, mediainfo=mediainfo,
|
||||
tmdbid=subscribe.tmdbid, doubanid=subscribe.doubanid,
|
||||
subscribe_id=subscribe.id, scene="refresh")
|
||||
lack_episode = max((subscribe.lack_episode or 0) + (total_episode - (subscribe.total_episode or 0)), 0)
|
||||
if subscribe.best_version and subscribe.type == MediaType.TV.value:
|
||||
# 为新增集补齐 episode_priority 初始项(priority=0)
|
||||
old_total_episode = subscribe.total_episode or 0
|
||||
@@ -1990,6 +2003,15 @@ class SubscribeChain(ChainBase):
|
||||
# 如果订阅状态为待定(P),说明订阅信息尚未完全更新,无法完成订阅
|
||||
if subscribe.state == "P":
|
||||
return
|
||||
# 发送订阅完成判定事件,在写入 DB 前,允许外部据完结策略否决本次自动完成
|
||||
completion_event = eventmanager.send_event(
|
||||
ChainEventType.SubscribeCompletionCheck,
|
||||
SubscribeCompletionCheckEventData(subscribe=subscribe, mediainfo=mediainfo, meta=meta))
|
||||
if completion_event and completion_event.event_data:
|
||||
completion_data: SubscribeCompletionCheckEventData = completion_event.event_data
|
||||
if completion_data.cancel:
|
||||
logger.info(f'{mediainfo.title_year} 完成被 [{completion_data.source}] 否决:{completion_data.reason}')
|
||||
return
|
||||
# 完成订阅
|
||||
msgstr = "订阅" if not subscribe.best_version else "洗版"
|
||||
logger.info(f'{mediainfo.title_year} 完成{msgstr}')
|
||||
@@ -3079,7 +3101,9 @@ class SubscribeChain(ChainBase):
|
||||
season=subscribe.season,
|
||||
episodes=pending_episodes,
|
||||
total_episode=subscribe.total_episode,
|
||||
start_episode=subscribe.start_episode or 1)
|
||||
start_episode=subscribe.start_episode or 1,
|
||||
# 完整覆盖约束会影响整季文件探测、显式集列表匹配和多集拆包降级。
|
||||
require_complete_coverage=self.__is_full_best_version_enabled(subscribe))
|
||||
}
|
||||
}
|
||||
else:
|
||||
@@ -3122,7 +3146,52 @@ class SubscribeChain(ChainBase):
|
||||
return False, no_exists
|
||||
|
||||
@staticmethod
|
||||
def __refresh_total_episode_before_completion(subscribe: Subscribe, mediainfo: MediaInfo):
|
||||
def __apply_episodes_refresh(current_total: int, season: Optional[int], *,
                             mediainfo: Optional[MediaInfo] = None,
                             tmdbid: Optional[int] = None,
                             doubanid: Optional[str] = None,
                             subscribe_id: Optional[int] = None,
                             scene: Optional[str] = None) -> int:
    """
    Send the SubscribeEpisodesRefresh event so that listeners may override the
    total episode count the main program computed from the TMDB season data.

    With no listener, or when no listener overrides the value, the input total
    is returned unchanged.

    :param current_total: default total episode count computed by the main program
    :param season: season number
    :param mediainfo: media information forwarded to listeners
    :param tmdbid: TMDB ID forwarded to listeners
    :param doubanid: Douban ID forwarded to listeners
    :param subscribe_id: subscription ID forwarded to listeners
    :param scene: name of the triggering scene forwarded to listeners
    :return: the total episode count to use
    """
    payload = SubscribeEpisodesRefreshEventData(
        tmdbid=tmdbid, doubanid=doubanid, season=season, mediainfo=mediainfo,
        current_total_episode=current_total, subscribe_id=subscribe_id, scene=scene)
    event = eventmanager.send_event(ChainEventType.SubscribeEpisodesRefresh, payload)
    if not event or not event.event_data:
        return current_total

    result: SubscribeEpisodesRefreshEventData = event.event_data
    # An override only counts when it is flagged and carries a non-zero total.
    if result.updated and result.total_episode:
        return result.total_episode
    return current_total
|
||||
|
||||
@staticmethod
async def __async_apply_episodes_refresh(current_total: int, season: Optional[int], *,
                                         mediainfo: Optional[MediaInfo] = None,
                                         tmdbid: Optional[int] = None,
                                         doubanid: Optional[str] = None,
                                         subscribe_id: Optional[int] = None,
                                         scene: Optional[str] = None) -> int:
    """
    Asynchronous counterpart of ``__apply_episodes_refresh``.

    Sends the SubscribeEpisodesRefresh event through ``async_send_event`` and
    returns the listener's override when one is flagged and non-zero, otherwise
    the input total.

    :param current_total: default total episode count computed by the main program
    :param season: season number
    :param mediainfo: media information forwarded to listeners
    :param tmdbid: TMDB ID forwarded to listeners
    :param doubanid: Douban ID forwarded to listeners
    :param subscribe_id: subscription ID forwarded to listeners
    :param scene: name of the triggering scene forwarded to listeners
    :return: the total episode count to use
    """
    payload = SubscribeEpisodesRefreshEventData(
        tmdbid=tmdbid, doubanid=doubanid, season=season, mediainfo=mediainfo,
        current_total_episode=current_total, subscribe_id=subscribe_id, scene=scene)
    event = await eventmanager.async_send_event(ChainEventType.SubscribeEpisodesRefresh, payload)
    if not event or not event.event_data:
        return current_total

    result: SubscribeEpisodesRefreshEventData = event.event_data
    if result.updated and result.total_episode:
        return result.total_episode
    return current_total
|
||||
|
||||
def __refresh_total_episode_before_completion(self, subscribe: Subscribe, mediainfo: MediaInfo):
|
||||
"""
|
||||
在完成判断前,按最新识别结果兜底修正订阅总集数,防止旧总集数导致误完成。
|
||||
"""
|
||||
@@ -3134,6 +3203,11 @@ class SubscribeChain(ChainBase):
|
||||
return
|
||||
|
||||
new_total_episode = len((mediainfo.seasons or {}).get(subscribe.season) or [])
|
||||
# 允许外部覆盖按 TMDB 算出的总集数(如待定集数),后续“只增不减”仍作用于覆盖后的结果,避免误减导致提前完成。
|
||||
new_total_episode = self.__apply_episodes_refresh(
|
||||
new_total_episode, season=subscribe.season, mediainfo=mediainfo,
|
||||
tmdbid=subscribe.tmdbid, doubanid=subscribe.doubanid,
|
||||
subscribe_id=subscribe.id, scene="precheck")
|
||||
old_total_episode = subscribe.total_episode or 0
|
||||
if not new_total_episode or new_total_episode <= old_total_episode:
|
||||
return
|
||||
|
||||
@@ -223,7 +223,7 @@ def _release_sort_key(tag: str) -> tuple[int, ...]:
|
||||
def _github_api_json(url: str, *, repo: str) -> Any:
|
||||
headers = {
|
||||
"Accept": "application/vnd.github+json",
|
||||
"User-Agent": "MoviePilot-CLI",
|
||||
"User-Agent": settings.USER_AGENT,
|
||||
}
|
||||
headers.update(settings.REPO_GITHUB_HEADERS(repo))
|
||||
opener = build_opener(ProxyHandler(settings.PROXY or {}))
|
||||
|
||||
@@ -588,14 +588,6 @@ class ConfigModel(BaseModel):
|
||||
AI_RECOMMEND_ENABLED: bool = False
|
||||
# AI推荐用户偏好
|
||||
AI_RECOMMEND_USER_PREFERENCE: str = ""
|
||||
# Tavily API密钥(用于网络搜索)
|
||||
TAVILY_API_KEY: List[str] = [
|
||||
"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",
|
||||
"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",
|
||||
"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",
|
||||
]
|
||||
# Exa API密钥(用于网络搜索)
|
||||
EXA_API_KEY: str = "161ce010-fb56-419c-9ea8-4fb459b96298"
|
||||
|
||||
# AI推荐条目数量限制
|
||||
AI_RECOMMEND_MAX_ITEMS: int = 50
|
||||
|
||||
@@ -685,7 +685,10 @@ class MediaInfo:
|
||||
if infobox:
|
||||
akas = [item.get("value") for item in infobox if item.get("key") == "别名"]
|
||||
if akas:
|
||||
self.names = [aka.get("v") for aka in akas[0]]
|
||||
if isinstance(akas[0], list):
|
||||
self.names = [aka.get("v") if isinstance(aka, dict) else aka for aka in akas[0]]
|
||||
elif isinstance(akas[0], str):
|
||||
self.names = [akas[0]]
|
||||
|
||||
# 剧集
|
||||
if self.type == MediaType.TV and not self.seasons:
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import json
|
||||
import pickle
|
||||
from typing import Any, Optional, Generator, Tuple, AsyncGenerator, Union
|
||||
@@ -320,12 +321,18 @@ class AsyncRedisHelper(ConfigReloadMixin, metaclass=Singleton):
|
||||
"""
|
||||
self.redis_url = settings.CACHE_BACKEND_URL
|
||||
self.client: Optional[Redis] = None
|
||||
self._loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
|
||||
async def _connect(self):
|
||||
"""
|
||||
建立异步Redis连接
|
||||
"""
|
||||
try:
|
||||
current_loop = asyncio.get_running_loop()
|
||||
# 检测事件循环是否发生变化,如果变化则重新连接
|
||||
if self.client is not None and self._loop is not current_loop:
|
||||
logger.debug("Event loop changed, reconnecting Redis (async)")
|
||||
await self._close_client()
|
||||
if self.client is None:
|
||||
self.client = Redis.from_url(
|
||||
self.redis_url,
|
||||
@@ -334,6 +341,7 @@ class AsyncRedisHelper(ConfigReloadMixin, metaclass=Singleton):
|
||||
socket_connect_timeout=_socket_connect_timeout,
|
||||
health_check_interval=_health_check_interval,
|
||||
)
|
||||
self._loop = current_loop
|
||||
# 测试连接,确保Redis可用
|
||||
await self.client.ping()
|
||||
logger.info(f"Successfully connected to Redis (async):{self.redis_url}")
|
||||
@@ -341,10 +349,23 @@ class AsyncRedisHelper(ConfigReloadMixin, metaclass=Singleton):
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to Redis (async): {e}")
|
||||
self.client = None
|
||||
self._loop = None
|
||||
raise RuntimeError("Redis async connection failed") from e
|
||||
|
||||
async def _close_client(self):
|
||||
"""
|
||||
关闭当前Redis客户端连接
|
||||
"""
|
||||
if self.client:
|
||||
try:
|
||||
await self.client.close()
|
||||
except Exception:
|
||||
pass
|
||||
self.client = None
|
||||
self._loop = None
|
||||
|
||||
async def on_config_changed(self):
|
||||
await self.close()
|
||||
await self._close_client()
|
||||
await self._connect()
|
||||
|
||||
def get_reload_name(self):
|
||||
@@ -526,7 +547,5 @@ class AsyncRedisHelper(ConfigReloadMixin, metaclass=Singleton):
|
||||
"""
|
||||
关闭异步Redis客户端的连接池
|
||||
"""
|
||||
if self.client:
|
||||
await self.client.close()
|
||||
self.client = None
|
||||
logger.debug("Redis async connection closed")
|
||||
await self._close_client()
|
||||
logger.debug("Redis async connection closed")
|
||||
|
||||
@@ -102,6 +102,7 @@ class MoviePilotServerHelper:
|
||||
user_uid = cls.get_user_uid()
|
||||
if user_uid:
|
||||
request_headers[cls.USER_UID_HEADER] = user_uid
|
||||
request_headers["User-Agent"] = settings.USER_AGENT
|
||||
return request_headers
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -12,19 +12,31 @@ from app.utils.http import RequestUtils
|
||||
|
||||
|
||||
class BangumiModule(_ModuleBase):
|
||||
"""
|
||||
Bangumi媒体信息匹配
|
||||
"""
|
||||
CONFIG_WATCH = {"PROXY_HOST"}
|
||||
|
||||
bangumiapi: BangumiApi = None
|
||||
|
||||
def init_module(self) -> None:
|
||||
"""
|
||||
初始化Bangumi客户端
|
||||
"""
|
||||
self.bangumiapi = BangumiApi()
|
||||
|
||||
def stop(self):
|
||||
self.bangumiapi.close()
|
||||
def stop(self) -> None:
|
||||
"""
|
||||
关闭Bangumi客户端
|
||||
"""
|
||||
if self.bangumiapi:
|
||||
self.bangumiapi.close()
|
||||
|
||||
def test(self) -> Tuple[bool, str]:
|
||||
"""
|
||||
测试模块连接性
|
||||
"""
|
||||
ret = RequestUtils().get_res("https://api.bgm.tv/")
|
||||
ret = RequestUtils(proxies=settings.PROXY).get_res("https://api.bgm.tv/")
|
||||
if ret and ret.status_code == 200:
|
||||
return True, ""
|
||||
elif ret:
|
||||
@@ -36,6 +48,9 @@ class BangumiModule(_ModuleBase):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
"""
|
||||
获取模块名称
|
||||
"""
|
||||
return "Bangumi"
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -10,7 +10,9 @@ from app.utils.http import RequestUtils, AsyncRequestUtils
|
||||
|
||||
class BangumiApi(object):
|
||||
"""
|
||||
https://bangumi.github.io/api/
|
||||
Bangumi API客户端。
|
||||
|
||||
接口文档:https://bangumi.github.io/api/
|
||||
"""
|
||||
|
||||
_urls = {
|
||||
@@ -28,8 +30,15 @@ class BangumiApi(object):
|
||||
|
||||
def __init__(self):
|
||||
self._session = requests.Session()
|
||||
self._req = RequestUtils(ua=settings.NORMAL_USER_AGENT, session=self._session)
|
||||
self._async_req = AsyncRequestUtils(ua=settings.NORMAL_USER_AGENT)
|
||||
self._req = RequestUtils(
|
||||
ua=settings.NORMAL_USER_AGENT,
|
||||
proxies=settings.PROXY,
|
||||
session=self._session,
|
||||
)
|
||||
self._async_req = AsyncRequestUtils(
|
||||
ua=settings.NORMAL_USER_AGENT,
|
||||
proxies=settings.PROXY,
|
||||
)
|
||||
|
||||
@cached(maxsize=settings.CONF.bangumi, ttl=settings.CONF.meta, shared_key="get")
|
||||
def __invoke(self, url, key: Optional[str] = None, **kwargs):
|
||||
@@ -306,6 +315,9 @@ class BangumiApi(object):
|
||||
"""
|
||||
self.__invoke.cache_clear()
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
"""
|
||||
关闭Bangumi会话
|
||||
"""
|
||||
if self._session:
|
||||
self._session.close()
|
||||
|
||||
@@ -160,16 +160,20 @@ class Emby:
|
||||
else:
|
||||
library_type = MediaType.UNKNOWN.value
|
||||
image = self.__get_local_image_by_id(library.get("Id"))
|
||||
server_id = library.get("ServerId") or self.serverid
|
||||
server_query = f"serverId={server_id}&" if server_id else ""
|
||||
libraries.append(
|
||||
schemas.MediaServerLibrary(
|
||||
server="emby",
|
||||
id=library.get("Id"),
|
||||
item_id=library.get("Id"),
|
||||
server_id=server_id,
|
||||
name=library.get("Name"),
|
||||
path=library.get("Path"),
|
||||
type=library_type,
|
||||
image=image,
|
||||
link=f'{self._playhost or self._host}web/index.html'
|
||||
f'#!/videos?serverId={self.serverid}&parentId={library.get("Id")}',
|
||||
f'#!/videos?{server_query}parentId={library.get("Id")}',
|
||||
server_type="emby"
|
||||
)
|
||||
)
|
||||
@@ -247,19 +251,22 @@ class Emby:
|
||||
"""
|
||||
if not self._host or not self._apikey:
|
||||
return None
|
||||
url = f"{self._host}System/Info"
|
||||
params = {
|
||||
'api_key': self._apikey
|
||||
}
|
||||
try:
|
||||
res = RequestUtils().get_res(url, params)
|
||||
if res:
|
||||
return res.json().get("Id")
|
||||
else:
|
||||
logger.error(f"System/Info 未获取到返回数据")
|
||||
except Exception as e:
|
||||
|
||||
logger.error(f"连接System/Info出错:" + str(e))
|
||||
for path in ("System/Info", "emby/System/Info"):
|
||||
url = f"{self._host}{path}"
|
||||
try:
|
||||
res = RequestUtils().get_res(url, params)
|
||||
if res:
|
||||
result = res.json() or {}
|
||||
server_id = result.get("Id") or result.get("ServerId")
|
||||
if server_id:
|
||||
return server_id
|
||||
else:
|
||||
logger.error(f"{path} 未获取到返回数据")
|
||||
except Exception as e:
|
||||
logger.error(f"连接{path}出错:" + str(e))
|
||||
return None
|
||||
|
||||
def get_user_count(self) -> int:
|
||||
@@ -648,6 +655,7 @@ class Emby:
|
||||
return schemas.MediaServerItem(
|
||||
server="emby",
|
||||
library=item.get("ParentId"),
|
||||
server_id=item.get("ServerId"),
|
||||
item_id=item.get("Id"),
|
||||
item_type=item.get("Type"),
|
||||
title=item.get("Name"),
|
||||
@@ -1088,13 +1096,16 @@ class Emby:
|
||||
logger.error(f"连接Emby出错:" + str(e))
|
||||
return None
|
||||
|
||||
def get_play_url(self, item_id: str) -> str:
|
||||
def get_play_url(self, item_id: str, server_id: Optional[str] = None) -> str:
|
||||
"""
|
||||
拼装媒体播放链接
|
||||
:param item_id: 媒体的的ID
|
||||
:param server_id: 媒体服务器ID
|
||||
"""
|
||||
server_id = server_id or self.serverid
|
||||
server_query = f"&serverId={server_id}" if server_id else ""
|
||||
return f"{self._playhost or self._host}web/index.html#!" \
|
||||
f"/item?id={item_id}&context=home&serverId={self.serverid}"
|
||||
f"/item?id={item_id}&context=home{server_query}"
|
||||
|
||||
def get_backdrop_url(self, item_id: str, image_tag: str, remote: Optional[bool] = False) -> str:
|
||||
"""
|
||||
@@ -1160,7 +1171,8 @@ class Emby:
|
||||
str(item_path).startswith(folder) for folder in library_folders):
|
||||
continue
|
||||
item_type = MediaType.MOVIE.value if item.get("Type") == "Movie" else MediaType.TV.value
|
||||
link = self.get_play_url(item.get("Id"))
|
||||
server_id = item.get("ServerId") or self.serverid
|
||||
link = self.get_play_url(item.get("Id"), server_id=server_id)
|
||||
if item_type == MediaType.MOVIE.value:
|
||||
title = item.get("Name")
|
||||
subtitle = str(item.get("ProductionYear")) if item.get("ProductionYear") else None
|
||||
@@ -1180,6 +1192,8 @@ class Emby:
|
||||
image = self.__get_local_image_by_id(item.get("SeriesId"))
|
||||
ret_resume.append(schemas.MediaServerPlayItem(
|
||||
id=item.get("Id"),
|
||||
item_id=item.get("Id"),
|
||||
server_id=server_id,
|
||||
title=title,
|
||||
subtitle=subtitle,
|
||||
type=item_type,
|
||||
@@ -1230,10 +1244,13 @@ class Emby:
|
||||
str(item_path).startswith(folder) for folder in library_folders):
|
||||
continue
|
||||
item_type = MediaType.MOVIE.value if item.get("Type") == "Movie" else MediaType.TV.value
|
||||
link = self.get_play_url(item.get("Id"))
|
||||
server_id = item.get("ServerId") or self.serverid
|
||||
link = self.get_play_url(item.get("Id"), server_id=server_id)
|
||||
image = self.__get_local_image_by_id(item_id=item.get("Id"))
|
||||
ret_latest.append(schemas.MediaServerPlayItem(
|
||||
id=item.get("Id"),
|
||||
item_id=item.get("Id"),
|
||||
server_id=server_id,
|
||||
title=item.get("Name"),
|
||||
subtitle=str(item.get("ProductionYear")) if item.get("ProductionYear") else None,
|
||||
type=item_type,
|
||||
|
||||
@@ -1137,6 +1137,7 @@ class TransHandler:
|
||||
meta = MetaInfoPath(path)
|
||||
season = meta.season
|
||||
episode = meta.episode
|
||||
part = meta.part
|
||||
logger.warn(f"正在删除目标目录中其它版本的文件:{path.parent}")
|
||||
# 获取父目录
|
||||
parent_item = storage_oper.get_item(path.parent)
|
||||
@@ -1163,6 +1164,9 @@ class TransHandler:
|
||||
# 相同季集的文件才删除
|
||||
if filemeta.season != season or filemeta.episode != episode:
|
||||
continue
|
||||
# 相同 Part 的文件才删除,避免误删多 Part 文件 (issue #5862)
|
||||
if part and filemeta.part and filemeta.part != part:
|
||||
continue
|
||||
logger.info(f"正在删除文件:{media_file.name}")
|
||||
storage_oper.delete(media_file)
|
||||
return True
|
||||
|
||||
@@ -549,3 +549,71 @@ class StorageOperSelectionEventData(ChainEventData):
|
||||
|
||||
# 输出参数
|
||||
storage_oper: Optional[Callable] = Field(default=None, description="存储操作对象")
|
||||
|
||||
|
||||
class SubscribeEpisodesRefreshEventData(ChainEventData):
    """
    Data model of the SubscribeEpisodesRefresh chain event.

    Emitted by the main program while it works out the total episode count of
    a subscribed season. It carries the default value computed from the TMDB
    season data; an external handler may override `total_episode` according to
    its own policy (for example a still-undetermined episode count).

    Attributes:
        # Input
        tmdbid (Optional[int]): TMDB ID
        doubanid (Optional[str]): Douban ID
        season (Optional[int]): season number
        mediainfo (Any): media information
        current_total_episode (int): default total computed from the TMDB season data
        subscribe_id (Optional[int]): subscription ID; empty in the create
            scenario because the subscription is not stored yet
        scene (Optional[str]): trigger scenario, create/refresh/precheck

        # Output
        updated (bool): whether a handler overrode the total, False by default
        total_episode (Optional[int]): overriding total, only honoured when updated=True
        source (str): who overrode the value
        reason (str): why the value was overridden
    """

    # Input parameters
    tmdbid: Optional[int] = Field(default=None, description="TMDB ID")
    doubanid: Optional[str] = Field(default=None, description="豆瓣 ID")
    season: Optional[int] = Field(default=None, description="季号")
    mediainfo: Any = Field(default=None, description="媒体信息")
    current_total_episode: int = Field(default=0, description="按 TMDB 季集数算出的默认总集数")
    subscribe_id: Optional[int] = Field(default=None, description="订阅 ID;创建场景为空")
    scene: Optional[str] = Field(default=None, description="触发场景:create/refresh/precheck")

    # Output parameters
    updated: bool = Field(default=False, description="外部是否覆盖了总集数")
    total_episode: Optional[int] = Field(default=None, description="覆盖后的总集数")
    source: str = Field(default="未知来源", description="覆盖来源")
    reason: str = Field(default="", description="覆盖原因")
|
||||
|
||||
|
||||
class SubscribeCompletionCheckEventData(ChainEventData):
    """
    Data model of the SubscribeCompletionCheck chain event.

    Emitted after a subscription has been automatically judged complete and
    right before it is finalized (history written, subscription deleted), so
    that an external handler can veto this completion based on its own
    "finished" policy.

    Attributes:
        # Input
        subscribe (Any): subscription object
        mediainfo (Any): media information
        meta (Any): metadata

        # Output
        cancel (bool): whether this completion is vetoed, False by default
        source (str): who vetoed
        reason (str): why it was vetoed
    """

    # Input parameters
    subscribe: Any = Field(default=None, description="订阅对象")
    mediainfo: Any = Field(default=None, description="媒体信息")
    meta: Any = Field(default=None, description="元数据")

    # Output parameters
    cancel: bool = Field(default=False, description="是否否决本次完成")
    source: str = Field(default="未知来源", description="否决来源")
    reason: str = Field(default="", description="否决原因")
|
||||
|
||||
@@ -34,6 +34,8 @@ class NotExistMediaInfo(BaseModel):
|
||||
total_episode: Optional[int] = 0
|
||||
# 开始集
|
||||
start_episode: Optional[int] = 0
|
||||
# 候选资源须完整覆盖目标范围
|
||||
require_complete_coverage: Optional[bool] = False
|
||||
|
||||
|
||||
class RefreshMediaItem(BaseModel):
|
||||
@@ -60,6 +62,10 @@ class MediaServerLibrary(BaseModel):
|
||||
server: Optional[str] = None
|
||||
# ID
|
||||
id: Optional[Union[str, int]] = None
|
||||
# 媒体服务器项目ID
|
||||
item_id: Optional[Union[str, int]] = None
|
||||
# 媒体服务器ID
|
||||
server_id: Optional[str] = None
|
||||
# 名称
|
||||
name: Optional[str] = None
|
||||
# 路径
|
||||
@@ -101,6 +107,8 @@ class MediaServerItem(BaseModel):
|
||||
server: Optional[str] = None
|
||||
# 媒体库ID
|
||||
library: Optional[Union[str, int]] = None
|
||||
# 媒体服务器ID
|
||||
server_id: Optional[str] = None
|
||||
# ID
|
||||
item_id: Optional[str] = None
|
||||
# 类型
|
||||
@@ -171,6 +179,8 @@ class MediaServerPlayItem(BaseModel):
|
||||
媒体服务器可播放项目信息
|
||||
"""
|
||||
id: Optional[Union[str, int]] = None
|
||||
item_id: Optional[Union[str, int]] = None
|
||||
server_id: Optional[str] = None
|
||||
title: Optional[str] = None
|
||||
subtitle: Optional[str] = None
|
||||
type: Optional[str] = None
|
||||
|
||||
@@ -115,6 +115,15 @@ class SiteAuth(BaseModel):
|
||||
params: Optional[Dict[str, Union[int, str]]] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class SiteCookieUpdate(BaseModel):
    """
    Request body for updating a site's Cookie and User-Agent.
    """

    # Site login user name (required)
    username: str = Field(..., description="站点登录用户名")
    # Site login password (required)
    password: str = Field(..., description="站点登录密码")
    # Optional two-step verification code or secret
    code: Optional[str] = Field(None, description="二步验证码或密钥")
|
||||
|
||||
|
||||
class SiteCategory(BaseModel):
|
||||
id: Optional[int] = None
|
||||
cat: Optional[str] = None
|
||||
|
||||
@@ -179,6 +179,10 @@ class ChainEventType(Enum):
|
||||
StorageOperSelection = "storage.operation"
|
||||
# Agent LLM 供应商选择
|
||||
AgentLLMProvider = "agent.llm.provider"
|
||||
# 订阅总集数刷新
|
||||
SubscribeEpisodesRefresh = "subscribe.episodes.refresh"
|
||||
# 订阅完成检查
|
||||
SubscribeCompletionCheck = "subscribe.completion.check"
|
||||
|
||||
|
||||
# 系统配置Key字典
|
||||
|
||||
210
app/utils/coalesce.py
Normal file
210
app/utils/coalesce.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""
|
||||
通用时间窗口事件合并器。
|
||||
|
||||
定位:在固定时间窗口内对相同 key 的重复事件做合并,避免下游(通常是日志、告警、上报)被高频重复事件刷爆。
|
||||
|
||||
典型场景:同一原因的高频拦截 warning、同一目标的连续失败告警、同一错误码的批量上报——首条事件立即输出保留上下文,
|
||||
后续命中在窗口内合并为一条计数摘要。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Awaitable, Callable, Dict, Hashable, Optional, Union
|
||||
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class CoalesceDecision(Enum):
    """
    Return value of `EventCoalescer.record`: tells the caller what to do with
    the event it has just recorded.
    """

    # First event of a window: the caller should output it as-is right away
    # (write the log line, send the alert, ...)
    EMIT = "emit"
    # Already merged into the current window: the caller stays silent
    SUPPRESS = "suppress"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class CoalesceSummary:
    """
    Aggregated summary handed to `on_flush` when a window ends; describes the
    events that were merged during that window.
    """

    # Aggregation key, identical to the key the caller passed to `record`
    key: Hashable
    # Total hits for this key inside the window, including the first (EMIT) event
    count: int
    # Payload of the first event, so the summary can carry a sample
    first_payload: Any
    # Window length in seconds, same value the coalescer was built with
    window_seconds: float
|
||||
|
||||
|
||||
# `on_flush` 回调签名:同步或 async 均可,由 coalescer 内部按需调度
|
||||
OnFlushCallback = Callable[[CoalesceSummary], Union[Awaitable[None], None]]
|
||||
|
||||
|
||||
@dataclass
class _BucketState:
    """
    Per-key state for the current window; internal to `EventCoalescer`.
    """

    # Payload of the first event, kept untouched for the flush summary
    first_payload: Any
    # Hits accumulated in the window (the first event included)
    count: int
    # Handle returned by `loop.call_later`, cancelled by `close()`
    flush_handle: Optional[asyncio.TimerHandle]
|
||||
|
||||
|
||||
class EventCoalescer:
    """
    Merge repeated events that share a key inside a fixed time window.

    Workflow:
    - First occurrence of a key: `record` returns `EMIT` and the caller outputs
      the event as-is; a flush is registered with
      `loop.call_later(window_seconds, ...)`.
    - Same key again inside the window: `record` returns `SUPPRESS` and the
      hit counter is incremented.
    - Window expiry: the bucket is removed and, when `count > 1`,
      `on_flush(summary)` is invoked. With `count == 1` the first EMIT already
      expressed the event completely, so no summary is produced.

    Concurrency model: `record` and `close` are coroutines and the class is
    meant to be used from a single running event loop. Every read/write of the
    bucket dict happens in a synchronous section without `await`, so no
    explicit lock is taken (which also avoids binding an `asyncio.Lock` to a
    loop when the coalescer is instantiated at module level).
    """

    def __init__(
        self,
        window_seconds: float,
        on_flush: OnFlushCallback,
        source: str = "",
    ) -> None:
        """
        :param window_seconds: window length in seconds, must be > 0
        :param on_flush: called when a window ends with count > 1; may be a
            plain function or return an awaitable
        :param source: business label, only used as a prefix of the internal
            debug log lines; never part of the `on_flush` summary
        :raises ValueError: if `window_seconds` is not positive
        """
        if window_seconds <= 0:
            raise ValueError("window_seconds must be positive")
        self._window_seconds = window_seconds
        self._on_flush = on_flush
        self._source = source
        self._buckets: Dict[Hashable, _BucketState] = {}
        # Kept for introspection; dispatch itself inspects the callback's
        # return value so awaitable-returning callables are handled as well.
        self._is_flush_async = inspect.iscoroutinefunction(on_flush)
        # Strong references to in-flight flush tasks: the event loop only keeps
        # weak references, so an unreferenced task may be collected mid-run.
        self._flush_tasks: set = set()

    @property
    def window_seconds(self) -> float:
        """
        Window length in seconds (read-only).
        """
        return self._window_seconds

    async def record(
        self, key: Hashable, payload: Any = None
    ) -> CoalesceDecision:
        """
        Register one event.

        :param key: aggregation key, must be hashable; a tuple of business
            dimensions such as `(host, reason)` keeps unrelated events apart
        :param payload: extra event data; only the payload of the first event
            of a window is kept and later attached to the flush summary
        :return: `EMIT` when the caller should output the event immediately,
            `SUPPRESS` when it has been merged into the current window
        """
        bucket = self._buckets.get(key)
        if bucket is not None:
            bucket.count += 1
            return CoalesceDecision.SUPPRESS
        # Schedule first: if no loop is running this raises before any state
        # is stored for the key.
        handle = self._schedule_flush(key)
        self._buckets[key] = _BucketState(
            first_payload=payload,
            count=1,
            flush_handle=handle,
        )
        return CoalesceDecision.EMIT

    async def close(self) -> None:
        """
        Flush every open window right now and clear the internal state.

        Pending `loop.call_later` handles are cancelled so they cannot fire
        later; buckets with count > 1 have `on_flush` invoked (and awaited when
        it returns an awaitable) before this coroutine returns.
        """
        pending = list(self._buckets.items())
        self._buckets.clear()
        for key, bucket in pending:
            if bucket.flush_handle is not None:
                bucket.flush_handle.cancel()
            await self._emit_summary_if_needed(key, bucket)

    def _schedule_flush(self, key: Hashable) -> asyncio.TimerHandle:
        """
        Register the end-of-window flush for `key` on the running event loop.

        Only arms the timer; turning the timer tick into an async flush is the
        job of `_on_flush_timer`.

        :raises RuntimeError: if there is no running event loop
        """
        loop = asyncio.get_running_loop()
        return loop.call_later(self._window_seconds, self._on_flush_timer, key)

    def _on_flush_timer(self, key: Hashable) -> None:
        """
        `loop.call_later` callback: hand the async flush over to the loop.

        `call_later` callbacks must be synchronous, so the flush coroutine is
        wrapped in a task. Scheduling failures are logged at debug level and
        dropped so that one bucket cannot affect the others.
        """
        coro = self._flush_key(key)
        try:
            task = asyncio.get_running_loop().create_task(coro)
        except RuntimeError as exc:
            # Rare path such as an already-closed loop: close the coroutine so
            # it does not trigger a "never awaited" warning, then give up.
            coro.close()
            self._log_debug(f"flush 调度失败,已忽略 key={key!r}: {exc}")
            return
        self._flush_tasks.add(task)
        task.add_done_callback(self._flush_tasks.discard)

    async def _flush_key(self, key: Hashable) -> None:
        """
        Actual end-of-window path: pop the bucket and call `on_flush` if needed.
        """
        bucket = self._buckets.pop(key, None)
        if bucket is None:
            # Already flushed by `close()`
            return
        await self._emit_summary_if_needed(key, bucket)

    async def _emit_summary_if_needed(
        self, key: Hashable, bucket: _BucketState
    ) -> None:
        """
        Emit the aggregated summary, but only when the window saw more than one hit.

        The callback may be synchronous or asynchronous: whatever it returns is
        awaited when it is awaitable. Exceptions raised by the consumer are
        swallowed and written to the debug log so that this infrastructure
        never breaks the calling business code.
        """
        if bucket.count <= 1:
            return
        summary = CoalesceSummary(
            key=key,
            count=bucket.count,
            first_payload=bucket.first_payload,
            window_seconds=self._window_seconds,
        )
        try:
            outcome = self._on_flush(summary)
            if inspect.isawaitable(outcome):
                await outcome
        except Exception as exc:  # noqa: BLE001 - consumer errors must not crash the infrastructure
            self._log_debug(f"on_flush 回调异常已吞: key={key!r}, error={exc}")

    def _log_debug(self, message: str) -> None:
        """
        Write an internal debug line, prefixed with `source` when one was given.
        """
        if self._source:
            logger.debug(f"[EventCoalescer:{self._source}] {message}")
        else:
            logger.debug(f"[EventCoalescer] {message}")
|
||||
@@ -3,6 +3,8 @@ import hmac
|
||||
import ipaddress
|
||||
import socket
|
||||
import threading
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from hashlib import sha256
|
||||
from pathlib import Path
|
||||
from typing import Dict, Iterable, List, Optional, Set, Union
|
||||
@@ -13,6 +15,11 @@ from cachetools import TTLCache
|
||||
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
from app.utils.coalesce import (
|
||||
CoalesceDecision,
|
||||
CoalesceSummary,
|
||||
EventCoalescer,
|
||||
)
|
||||
|
||||
|
||||
# DNS 解析结果缓存。
|
||||
@@ -36,6 +43,50 @@ _dns_inflight_locks: Dict[str, asyncio.Lock] = {}
|
||||
_dns_inflight_meta_lock = threading.Lock()
|
||||
|
||||
|
||||
class UrlSafetyReason(str, Enum):
    """
    Diagnostic reason returned by `evaluate_url_safety`.

    Member values are stable lower-snake-case strings that can be used
    directly as log fields or alert labels. When extending the enum keep the
    existing values unchanged so downstream aggregation keeps classifying
    reasons the same way.
    """

    # Every check passed, the URL may be requested
    ALLOWED = "allowed"
    # Scheme is not http/https, netloc is invalid, or the domain is not allow-listed
    DOMAIN_NOT_ALLOWED = "domain_not_allowed"
    # Domain is allow-listed but DNS resolution failed (no result or an error)
    DNS_RESOLUTION_FAILED = "dns_resolution_failed"
    # DNS returned at least one non-global address and no `allowed_private_ranges` is configured
    NON_GLOBAL_DNS_RESULT = "non_global_dns_result"
    # `allowed_private_ranges` is configured but some resolved address is outside of it
    MIXED_OR_DISALLOWED_PRIVATE_RESULT = "mixed_or_disallowed_private_result"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class UrlSafetyDiagnosis:
    """
    Structured result of the URL safety check, returned by
    `evaluate_url_safety(_async)`.

    `is_safe_url` only looks at `allowed`; logs, alerts and diagnostics that
    need the detailed reason or the resolved IPs consume the other fields.
    Field constraints:
    - `host` is only filled once the domain allowlist has passed; it is None
      for DOMAIN_NOT_ALLOWED.
    - `ips` can only be non-empty after the DNS stage has run.
    - `matched_private_ranges` is only filled when the URL was allowed through
      `allowed_private_ranges`.
    """

    # Whether the URL is allowed
    allowed: bool
    # Specific reason for allowing / blocking
    reason: UrlSafetyReason
    # Hostname parsed from the URL after the allowlist passed, otherwise None
    host: Optional[str] = None
    # Resolved addresses (whether or not they matched a private range), as strings
    ips: List[str] = field(default_factory=list)
    # Non-global ranges that let the URL through; only non-empty for `ALLOWED`
    # via the private-range branch
    matched_private_ranges: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
def _resolve_addrinfo_to_ips(
|
||||
address_infos: Iterable,
|
||||
) -> Optional[List[ipaddress._BaseAddress]]:
|
||||
@@ -546,36 +597,26 @@ class SecurityUtils:
|
||||
allowed_private_ranges: Optional[Iterable[str]] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
验证URL是否在允许的域名列表中,包括带有端口的域名(同步版本)
|
||||
验证 URL 是否在允许的域名列表中,包括带有端口的域名(同步版本)。
|
||||
|
||||
:param url: 需要验证的 URL
|
||||
:param allowed_domains: 允许的域名集合,域名可以包含端口
|
||||
:param strict: 是否严格匹配一级域名(默认为 False,允许多级域名)
|
||||
:param strict: 是否严格匹配一级域名(默认 False,允许多级域名)
|
||||
:param block_private: 是否拦截解析到非公网地址的 URL,防止 SSRF
|
||||
:param allowed_private_ranges: 域名命中后额外允许的非公网 IP/CIDR 网段
|
||||
:return: 如果URL合法且在允许的域名列表中,返回 True;否则返回 False
|
||||
:return: URL 合法且通过安全校验时返回 True,否则返回 False
|
||||
|
||||
注意:`block_private=True` 时会同步调用 `getaddrinfo`;async 上下文请改用
|
||||
`is_safe_url_async`。
|
||||
校验细节与失败原因由 `evaluate_url_safety` 返回;本方法只暴露布尔结果,
|
||||
作为只关心通过/拒绝判断的调用方的最薄入口。`block_private=True` 时会
|
||||
同步调用 `getaddrinfo`;async 上下文请改用 `is_safe_url_async`。
|
||||
"""
|
||||
try:
|
||||
hostname = SecurityUtils._check_url_allowlist(url, allowed_domains, strict)
|
||||
if hostname is None:
|
||||
return False
|
||||
|
||||
if block_private and not SecurityUtils._is_global_hostname(hostname):
|
||||
private_match = SecurityUtils._is_allowed_private_hostname(
|
||||
hostname, allowed_private_ranges
|
||||
)
|
||||
if private_match:
|
||||
SecurityUtils._log_private_range_allowed(url, private_match)
|
||||
return True
|
||||
return False
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.debug(f"Error occurred while validating URL: {e}")
|
||||
return False
|
||||
return SecurityUtils.evaluate_url_safety(
|
||||
url,
|
||||
allowed_domains,
|
||||
strict=strict,
|
||||
block_private=block_private,
|
||||
allowed_private_ranges=allowed_private_ranges,
|
||||
).allowed
|
||||
|
||||
@staticmethod
|
||||
async def is_safe_url_async(
|
||||
@@ -586,30 +627,194 @@ class SecurityUtils:
|
||||
allowed_private_ranges: Optional[Iterable[str]] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
`is_safe_url` 的异步版本,参数与返回值含义不变。
|
||||
判定 URL 是否在允许的域名列表中,包括带有端口的域名。
|
||||
|
||||
DNS 解析通过事件循环线程池执行,并复用 TTL 缓存。
|
||||
DNS 解析通过事件循环线程池执行,并复用 TTL 缓存,不阻塞调用方所在的
|
||||
事件循环。参数与返回值含义同 `is_safe_url`;需要失败原因/解析 IP
|
||||
等结构化信息时调用 `evaluate_url_safety_async`。
|
||||
"""
|
||||
diagnosis = await SecurityUtils.evaluate_url_safety_async(
|
||||
url,
|
||||
allowed_domains,
|
||||
strict=strict,
|
||||
block_private=block_private,
|
||||
allowed_private_ranges=allowed_private_ranges,
|
||||
)
|
||||
return diagnosis.allowed
|
||||
|
||||
@staticmethod
def evaluate_url_safety(
    url: str,
    allowed_domains: Union[Set[str], List[str]],
    strict: bool = False,
    block_private: bool = False,
    allowed_private_ranges: Optional[Iterable[str]] = None,
) -> "UrlSafetyDiagnosis":
    """
    Produce a structured diagnosis along the `is_safe_url` decision path
    (synchronous version).

    Same check order as `is_safe_url`: scheme/domain allowlist, then optional
    DNS resolution, then optional non-global range matching. In addition to
    the verdict it returns the failure reason, the resolved IPs and the
    matched private ranges for logs and alerts. Any unexpected exception is
    treated as deny-by-default and reported as `DOMAIN_NOT_ALLOWED`, so no
    code path can slip past the SSRF check.

    :param url: URL to validate
    :param allowed_domains: allowed domains, optionally including a port
    :param strict: passed through to the allowlist check
    :param block_private: when True, resolve the host and reject non-global results
    :param allowed_private_ranges: extra non-global IP/CIDR ranges to accept
    :return: the diagnosis; never raises
    """
    try:
        hostname = SecurityUtils._check_url_allowlist(url, allowed_domains, strict)
        if hostname is None:
            return UrlSafetyDiagnosis(
                allowed=False,
                reason=UrlSafetyReason.DOMAIN_NOT_ALLOWED,
            )
        if not block_private:
            # Allowlist passed and no DNS stage requested
            return UrlSafetyDiagnosis(
                allowed=True,
                reason=UrlSafetyReason.ALLOWED,
                host=hostname,
            )
        addresses = SecurityUtils._hostname_addresses(hostname)
        return SecurityUtils._diagnose_resolved_addresses(
            url, hostname, addresses, allowed_private_ranges
        )
    except Exception as e:  # noqa: BLE001 - deny by default so the SSRF check cannot be bypassed
        logger.debug(f"Error occurred while validating URL: {e}")
        return UrlSafetyDiagnosis(
            allowed=False,
            reason=UrlSafetyReason.DOMAIN_NOT_ALLOWED,
        )
|
||||
|
||||
@staticmethod
async def evaluate_url_safety_async(
    url: str,
    allowed_domains: Union[Set[str], List[str]],
    strict: bool = False,
    block_private: bool = False,
    allowed_private_ranges: Optional[Iterable[str]] = None,
) -> "UrlSafetyDiagnosis":
    """
    Asynchronous counterpart of `evaluate_url_safety`; yields the same
    structured diagnosis.

    The host addresses are obtained through the awaited
    `_hostname_addresses_async` helper. Check order, field meaning and the
    deny-by-default handling of unexpected exceptions match the synchronous
    version.

    :return: the diagnosis; never raises
    """
    try:
        hostname = SecurityUtils._check_url_allowlist(url, allowed_domains, strict)
        if hostname is None:
            return UrlSafetyDiagnosis(
                allowed=False,
                reason=UrlSafetyReason.DOMAIN_NOT_ALLOWED,
            )
        if block_private:
            resolved = await SecurityUtils._hostname_addresses_async(hostname)
            return SecurityUtils._diagnose_resolved_addresses(
                url, hostname, resolved, allowed_private_ranges
            )
        # Allowlist passed and no DNS stage requested
        return UrlSafetyDiagnosis(
            allowed=True,
            reason=UrlSafetyReason.ALLOWED,
            host=hostname,
        )
    except Exception as e:  # noqa: BLE001 - deny by default so the SSRF check cannot be bypassed
        logger.debug(f"Error occurred while validating URL: {e}")
        return UrlSafetyDiagnosis(
            allowed=False,
            reason=UrlSafetyReason.DOMAIN_NOT_ALLOWED,
        )
|
||||
|
||||
@staticmethod
async def is_safe_image_url_async(
    url: str,
    allowed_domains: Union[Set[str], List[str]],
    allowed_private_ranges: Optional[Iterable[str]] = None,
) -> bool:
    """
    Decide whether `url` may be used as an image-proxy request target.

    The URL is first run through `evaluate_url_safety_async` with
    `block_private=True`. If that rejects it, `verify_signed_url` is tried as
    a fallback, so a URL carrying a valid signature is still accepted. Only
    when both fail is the URL rejected.

    On rejection a block warning is handed to
    `_emit_image_proxy_block_warning`, which decides whether it is printed
    immediately or merged into an aggregated summary.

    :return: True when the URL is accepted, False otherwise
    """
    diagnosis = await SecurityUtils.evaluate_url_safety_async(
        url,
        allowed_domains,
        block_private=True,
        allowed_private_ranges=allowed_private_ranges,
    )
    # The signature fallback is only consulted when the standard check failed
    if diagnosis.allowed or SecurityUtils.verify_signed_url(url) is not None:
        return True
    await _emit_image_proxy_block_warning(
        url=url,
        diagnosis=diagnosis,
        signature_carried=_url_carries_signature(url),
        allowed_private_ranges=allowed_private_ranges,
    )
    return False
|
||||
|
||||
@staticmethod
def _diagnose_resolved_addresses(
    url: str,
    hostname: str,
    addresses: Optional[List[ipaddress._BaseAddress]],
    allowed_private_ranges: Optional[Iterable[str]],
) -> "UrlSafetyDiagnosis":
    """
    Apply the non-global-address policy to an already resolved address list
    and normalise the outcome into a diagnosis.

    - Empty/None list: DNS is not trustworthy, reject with `DNS_RESOLUTION_FAILED`.
    - Only global addresses: allow.
    - Non-global address present and no allowed ranges configured: reject with
      `NON_GLOBAL_DNS_RESULT`.
    - Non-global address present, ranges configured but not all matched:
      reject with `MIXED_OR_DISALLOWED_PRIVATE_RESULT`.
    - Everything matched the allowed ranges: allow, report the matched IPs and
      ranges, and log through `_log_private_range_allowed`.
    """
    if not addresses:
        return UrlSafetyDiagnosis(
            allowed=False,
            reason=UrlSafetyReason.DNS_RESOLUTION_FAILED,
            host=hostname,
        )

    # Formatted once, shared by every branch that reports the full result set
    resolved_ips = [str(addr) for addr in addresses]

    if SecurityUtils._addresses_all_global(addresses):
        return UrlSafetyDiagnosis(
            allowed=True,
            reason=UrlSafetyReason.ALLOWED,
            host=hostname,
            ips=resolved_ips,
        )

    networks = SecurityUtils._parse_ip_networks(allowed_private_ranges)
    if not networks:
        return UrlSafetyDiagnosis(
            allowed=False,
            reason=UrlSafetyReason.NON_GLOBAL_DNS_RESULT,
            host=hostname,
            ips=resolved_ips,
        )

    match = SecurityUtils._match_private_addresses(addresses, networks)
    if match is None:
        return UrlSafetyDiagnosis(
            allowed=False,
            reason=UrlSafetyReason.MIXED_OR_DISALLOWED_PRIVATE_RESULT,
            host=hostname,
            ips=resolved_ips,
        )

    matched_addresses, matched_networks = match
    SecurityUtils._log_private_range_allowed(url, match)
    return UrlSafetyDiagnosis(
        allowed=True,
        reason=UrlSafetyReason.ALLOWED,
        host=hostname,
        # Report the addresses returned by the matcher, not the raw DNS list
        ips=[str(addr) for addr in matched_addresses],
        matched_private_ranges=[str(net) for net in matched_networks],
    )
|
||||
|
||||
@staticmethod
|
||||
def sanitize_url_path(url: str, max_length: int = 120) -> str:
|
||||
@@ -636,3 +841,134 @@ class SecurityUtils:
|
||||
safe_path = f"compressed_{hash_value}{file_extension}"
|
||||
|
||||
return safe_path
|
||||
|
||||
|
||||
# 图片代理阻断日志聚合窗口(秒)。媒体详情页一次请求会批量触发同 host/同原因的拦截,
|
||||
# 按 (host, reason) 合并后只输出首条 warning + 窗口结束的聚合摘要,避免日志刷屏。
|
||||
_IMAGE_PROXY_BLOCK_LOG_WINDOW_SECONDS = 60.0
|
||||
|
||||
# fake-ip / 旁路 DNS 用户最常因 IMAGE_PROXY_ALLOWED_PRIVATE_RANGES 未配置而踩坑,
|
||||
# 在 reason=NON_GLOBAL_DNS_RESULT 且当前未配置允许网段时随 warning 一起输出,指向正确的修复开关。
|
||||
_IMAGE_PROXY_FAKEIP_HINT = (
|
||||
"提示:若使用 fake-ip / 旁路 DNS(常见网段 198.18.0.0/15、100.64.0.0/10),"
|
||||
"请将对应网段加入 IMAGE_PROXY_ALLOWED_PRIVATE_RANGES"
|
||||
)
|
||||
|
||||
# URL fragment 中实际携带代理签名但校验失败时附在 reason 末尾的标记。
|
||||
# 仅起标识作用,签名串本身不写入日志,避免泄露签名材料。
|
||||
_INVALID_SIGNATURE_TAG = "invalid_signature"
|
||||
|
||||
|
||||
def _url_carries_signature(url: str) -> bool:
    """
    Tell whether the URL explicitly carries the proxy signature parameter
    `mp_sig` in its fragment.

    This is a lightweight substring check on the part after the first `#`;
    no signature verification is performed. URLs without a fragment, and
    empty/None input, yield False.
    """
    if not url:
        return False
    _, hash_mark, fragment = url.partition("#")
    if not hash_mark:
        return False
    return "mp_sig=" in fragment
|
||||
|
||||
|
||||
def _format_image_proxy_block_warning(
    *,
    url: str,
    reason: str,
    host: Optional[str],
    ips: List[str],
    allowed_private_ranges: List[str],
    hint: Optional[str],
) -> str:
    """
    Render the text of the first image-proxy block warning.

    The line contains exactly the URL, reason, host, resolved IPs and the
    configured allowed ranges, in that order; `hint`, when given, is appended
    after a ` | ` separator. A None host is rendered as an empty value.

    :return: the formatted log line
    """
    parts = (
        f"url={url}",
        f"reason={reason}",
        f"host={host or ''}",
        f"ips={','.join(ips)}",
        f"allowed_private_ranges={','.join(allowed_private_ranges)}",
    )
    message = "Blocked unsafe image URL: " + ", ".join(parts)
    return f"{message} | {hint}" if hint else message
|
||||
|
||||
|
||||
def _log_image_proxy_block_summary(summary: CoalesceSummary) -> None:
    """
    End-of-window callback of the image-proxy block log coalescer: write one
    warning with the hit count of the window and the first event as a sample.

    `summary.key` is unpacked as a `(host, reason_label)` pair, the shape
    built by `_emit_image_proxy_block_warning`. The sample URL and resolved
    IPs of the first event are included so the merged requests can still be
    located from the summary line.
    """
    host, reason = summary.key
    sample = summary.first_payload or {}
    sample_ips = ",".join(sample.get("ips") or [])
    logger.warn(
        "Blocked unsafe image URL (aggregated): "
        f"host={host or ''}, reason={reason}, "
        f"count={summary.count}, window={summary.window_seconds:g}s, "
        f"sample_url={sample.get('url', '')}, sample_ips={sample_ips}"
    )
|
||||
|
||||
|
||||
# 图片代理阻断日志聚合器。同 (host, reason) 高频拦截在窗口内合并为一条聚合摘要,避免媒体详情页一次请求把日志刷爆;
|
||||
# 放行 debug 日志与诊断布尔结果不受聚合影响。
|
||||
_image_proxy_block_log_coalescer = EventCoalescer(
|
||||
window_seconds=_IMAGE_PROXY_BLOCK_LOG_WINDOW_SECONDS,
|
||||
on_flush=_log_image_proxy_block_summary,
|
||||
source="image_proxy",
|
||||
)
|
||||
|
||||
|
||||
async def _emit_image_proxy_block_warning(
    *,
    url: str,
    diagnosis: "UrlSafetyDiagnosis",
    signature_carried: bool,
    allowed_private_ranges: Optional[Iterable[str]],
) -> None:
    """
    Turn a diagnosis into a structured block warning and let the coalescer
    decide whether it is actually printed.

    `signature_carried=True` means the request URL carried a proxy signature
    in its fragment and verification failed. The reason then gets an
    `invalid_signature` suffix, which separates "unsigned external link hit
    the allowlist" from "signed URL no longer valid".

    The signature itself must never reach the log, so when a signature is
    carried the fragment is removed from the URL before it is logged or kept
    as the summary sample.
    """
    # reason_label is both a warning field and part of the coalescer key; the
    # signature tag is appended to the same string on purpose so that signed
    # and unsigned failures land in different buckets with their own counters.
    reason_label = diagnosis.reason.value
    if signature_carried:
        reason_label = f"{reason_label}+{_INVALID_SIGNATURE_TAG}"

    # Drop the fragment (where the signature lives) from anything that is logged
    logged_url = url.split("#", 1)[0] if signature_carried else url

    allowed_ranges = [str(r) for r in (allowed_private_ranges or [])]
    # The fake-ip hint only makes sense when non-global results were blocked
    # and no allowed range has been configured yet
    needs_hint = (
        diagnosis.reason is UrlSafetyReason.NON_GLOBAL_DNS_RESULT
        and not allowed_ranges
    )
    hint = _IMAGE_PROXY_FAKEIP_HINT if needs_hint else None

    key = (diagnosis.host or "", reason_label)
    payload = {"url": logged_url, "ips": list(diagnosis.ips)}
    decision = await _image_proxy_block_log_coalescer.record(key=key, payload=payload)
    if decision is not CoalesceDecision.EMIT:
        return
    logger.warn(
        _format_image_proxy_block_warning(
            url=logged_url,
            reason=reason_label,
            host=diagnosis.host,
            ips=list(diagnosis.ips),
            allowed_private_ranges=allowed_ranges,
            hint=hint,
        )
    )
|
||||
|
||||
@@ -146,6 +146,21 @@ MoviePilot 实现了标准的 **Model Context Protocol (MCP)**,允许 AI 智
|
||||
}
|
||||
```
|
||||
|
||||
**`search_web` 网络搜索示例**:
|
||||
```json
|
||||
{
|
||||
"tool_name": "search_web",
|
||||
"arguments": {
|
||||
"query": "asyncio TaskGroup",
|
||||
"search_engine": "duckduckgo",
|
||||
"site_url": "https://docs.python.org/3/",
|
||||
"max_results": 5
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
`search_engine` 可选,通过 DDGS 支持 `auto`、`duckduckgo`、`google`、`brave`、`yahoo`、`wikipedia`、`yandex`、`mojeek`。`site_url` 可选,用于限定搜索到指定域名或 URL 路径范围。搜索默认使用系统代理配置。
|
||||
|
||||
### 3. 获取工具详情
|
||||
|
||||
**GET** `/api/v1/mcp/tools/{tool_name}`
|
||||
|
||||
@@ -40,7 +40,9 @@ dedicated tool can complete the task more directly and safely.
|
||||
- `browse_webpage` - Real browser actions: `goto`, `get_content`, `screenshot`,
|
||||
`click`, `fill`, `select`, `evaluate`, `wait`.
|
||||
- `search_web` - Find current pages or official references before opening a
|
||||
target URL.
|
||||
target URL. It supports DDGS-backed `search_engine` (`auto`, `duckduckgo`,
|
||||
`google`, `brave`, etc.) and `site_url` for limiting results to a specified
|
||||
domain or URL path. It uses the configured system proxy by default.
|
||||
- `query_sites` - Get MoviePilot site IDs before site-specific operations.
|
||||
- `update_site_cookie` - Update a configured site's Cookie and User-Agent using
|
||||
username, password, and optional two-step code.
|
||||
@@ -76,6 +78,12 @@ If the user only described the page, search first:
|
||||
search_web query="official site or page name"
|
||||
```
|
||||
|
||||
To search within a specific site:
|
||||
|
||||
```text
|
||||
search_web query="release notes" site_url="https://docs.example.com/"
|
||||
```
|
||||
|
||||
Then open the most relevant result with `browse_webpage action="goto"`.
|
||||
|
||||
### 3. Observe Before Acting
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
---
|
||||
name: moviepilot-api
|
||||
version: 1
|
||||
description: Use this skill when you need to call MoviePilot REST API endpoints directly. Covers all 237 API endpoints across 27 categories including media search, downloads, subscriptions, library management, site management, system administration, plugins, workflows, and more. Use this skill whenever the user asks to interact with MoviePilot via its HTTP API, or when the moviepilot-cli skill cannot cover a specific operation.
|
||||
description: Use this skill when you need to call MoviePilot REST API endpoints directly. Covers all 238 API endpoints across 27 categories including media search, downloads, subscriptions, library management, site management, system administration, plugins, workflows, and more. Use this skill whenever the user asks to interact with MoviePilot via its HTTP API, or when the moviepilot-cli skill cannot cover a specific operation.
|
||||
---
|
||||
|
||||
# MoviePilot REST API
|
||||
@@ -161,7 +161,7 @@ All endpoints are under the base URL `{MP_HOST}`. Path parameters are shown as `
|
||||
| GET | `/api/v1/subscribe/shares` | List shared subscriptions. Params: `name`, `page`, `count`, `genre_id`, `min_rating`, `max_rating`, `sort_type` |
|
||||
| GET | `/api/v1/subscribe/share/statistics` | Share statistics |
|
||||
|
||||
### Site (24 endpoints)
|
||||
### Site (25 endpoints)
|
||||
|
||||
| Method | Path | Description |
|
||||
|--------|------|-------------|
|
||||
@@ -174,7 +174,8 @@ All endpoints are under the base URL `{MP_HOST}`. Path parameters are shown as `
|
||||
| GET | `/api/v1/site/cookiecloud` | Sync CookieCloud |
|
||||
| GET | `/api/v1/site/reset` | Reset sites |
|
||||
| POST | `/api/v1/site/priorities` | Batch update site priorities. Body: array |
|
||||
| GET | `/api/v1/site/cookie/{site_id}` | Update site cookie & UA. Params: `username`, `password`, `code` |
|
||||
| POST | `/api/v1/site/cookie/{site_id}` | Update site cookie & UA. Body: `SiteCookieUpdate` JSON |
|
||||
| GET | `/api/v1/site/cookie/{site_id}` | Legacy update site cookie & UA. Params: `username`, `password`, `code` |
|
||||
| POST | `/api/v1/site/userdata/{site_id}` | Refresh site user data |
|
||||
| GET | `/api/v1/site/userdata/{site_id}` | Get site user data. Params: `workdate` |
|
||||
| GET | `/api/v1/site/userdata/latest` | All sites latest user data |
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import tempfile
|
||||
@@ -9,7 +10,7 @@ from urllib.parse import quote
|
||||
|
||||
from telebot import apihelper
|
||||
|
||||
from app.agent.tools.impl.send_message import SendMessageInput
|
||||
from app.agent.tools.impl.send_message import SendMessageInput, SendMessageTool
|
||||
from app.agent.tools.impl.send_local_file import SendLocalFileInput
|
||||
from app.agent import MoviePilotAgent, AgentChain
|
||||
from app.agent.llm import AgentCapabilityManager
|
||||
@@ -27,7 +28,7 @@ from app.modules.vocechat import VoceChatModule
|
||||
from app.modules.wechat import WechatModule
|
||||
from app.modules.wechat.wechatbot import WeChatBot
|
||||
from app.schemas import CommingMessage, Notification
|
||||
from app.schemas.types import MessageChannel
|
||||
from app.schemas.types import MessageChannel, NotificationType
|
||||
|
||||
|
||||
class AgentImageSupportTest(unittest.TestCase):
|
||||
@@ -515,6 +516,39 @@ class AgentImageSupportTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(payload.image_url, "https://example.com/poster.png")
|
||||
|
||||
def test_send_message_tool_uses_agent_notification_type(self):
|
||||
"""发送消息工具应固定使用智能体消息类型。"""
|
||||
|
||||
async def _run():
|
||||
tool = SendMessageTool(session_id="session-1", user_id="10001")
|
||||
tool.set_message_attr(
|
||||
channel=MessageChannel.Telegram.value,
|
||||
source="telegram-test",
|
||||
username="tester",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.agent.tools.base.ToolChain.async_post_message",
|
||||
new_callable=AsyncMock,
|
||||
) as async_post_message:
|
||||
result = await tool.run(
|
||||
message="处理完成",
|
||||
title="智能体通知",
|
||||
image_url="https://example.com/poster.png",
|
||||
)
|
||||
return result, async_post_message
|
||||
|
||||
result, async_post_message = asyncio.run(_run())
|
||||
notification = async_post_message.await_args.args[0]
|
||||
|
||||
self.assertEqual(result, "消息已发送")
|
||||
self.assertEqual(notification.mtype, NotificationType.Agent)
|
||||
self.assertEqual(notification.channel, MessageChannel.Telegram)
|
||||
self.assertEqual(notification.source, "telegram-test")
|
||||
self.assertEqual(notification.title, "智能体通知")
|
||||
self.assertEqual(notification.text, "处理完成")
|
||||
self.assertEqual(notification.image, "https://example.com/poster.png")
|
||||
|
||||
def test_send_local_file_input_accepts_file_payload(self):
|
||||
payload = SendLocalFileInput(
|
||||
explanation="send generated report",
|
||||
|
||||
@@ -43,6 +43,24 @@ class TestAgentPromptStyle(unittest.TestCase):
|
||||
"Do not let user memory or persona style override this core identity",
|
||||
prompt,
|
||||
)
|
||||
self.assertIn(
|
||||
"Never directly modify application source code",
|
||||
prompt,
|
||||
)
|
||||
self.assertIn(
|
||||
"If the user has not explicitly requested an operation that changes system behavior",
|
||||
prompt,
|
||||
)
|
||||
self.assertIn("<non_negotiable_boundaries>", prompt)
|
||||
self.assertIn("<confirmation_policy>", prompt)
|
||||
self.assertIn(
|
||||
"Treat read-only inspection as allowed",
|
||||
prompt,
|
||||
)
|
||||
self.assertIn(
|
||||
"Use `execute_command` only for diagnostics, read-only inspection, or commands the user explicitly asked to run",
|
||||
prompt,
|
||||
)
|
||||
self.assertIn("当前日期", prompt)
|
||||
self.assertNotIn("当前时间", prompt)
|
||||
|
||||
|
||||
159
tests/test_agent_search_web_tool.py
Normal file
159
tests/test_agent_search_web_tool.py
Normal file
@@ -0,0 +1,159 @@
|
||||
import asyncio
|
||||
import json
|
||||
import unittest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from app.agent.tools.impl.search_web import (
|
||||
DDGS_AUTO_BACKEND,
|
||||
DEFAULT_SEARCH_ENGINE,
|
||||
SearchWebTool,
|
||||
)
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
class TestAgentSearchWebTool(unittest.TestCase):
    """Tests for the agent web-search tool."""

    def test_build_search_query_adds_site_filter(self):
        """A given site URL should produce a `site:` query the search engine understands."""
        site_filter = SearchWebTool._normalize_site_filter("https://docs.python.org/3/")

        self.assertEqual("docs.python.org", site_filter.domain)
        self.assertEqual("/3", site_filter.path)
        self.assertEqual("docs.python.org/3", site_filter.search_target)
        self.assertEqual(
            "asyncio site:docs.python.org/3",
            SearchWebTool._build_search_query("asyncio", site_filter),
        )

    def test_build_search_query_keeps_existing_site_operator(self):
        """If the user already wrote a `site:` condition, no second one is appended."""
        site_filter = SearchWebTool._normalize_site_filter("python.org")

        self.assertEqual(
            "asyncio site:docs.python.org",
            SearchWebTool._build_search_query(
                "asyncio site:docs.python.org",
                site_filter,
            ),
        )

    def test_filter_results_by_site_matches_domain_and_path(self):
        """Site filtering should constrain both the domain and the path prefix."""
        site_filter = SearchWebTool._normalize_site_filter("https://docs.python.org/3/")
        results = [
            {"url": "https://docs.python.org/3/library/asyncio.html"},
            {"url": "https://www.docs.python.org/3/tutorial/index.html"},
            {"url": "https://docs.python.org/2/library/asyncio.html"},
            {"url": "https://example.com/3/library/asyncio.html"},
        ]

        filtered_results = SearchWebTool._filter_results_by_site(results, site_filter)

        # Two of the four candidates are expected to survive the filter.
        self.assertEqual(2, len(filtered_results))
        self.assertEqual(
            "https://docs.python.org/3/library/asyncio.html",
            filtered_results[0]["url"],
        )

    def test_auto_search_plan_falls_back_to_search_engine(self):
        """Auto mode should use only the DDGS search-engine backend."""
        search_plan = SearchWebTool._get_search_plan(DEFAULT_SEARCH_ENGINE)

        self.assertEqual([DEFAULT_SEARCH_ENGINE], search_plan)

    def test_auto_ddgs_backend_excludes_bing(self):
        """The automatic DDGS backend list must not contain Bing."""
        auto_backends = SearchWebTool._get_ddgs_backend(
            DEFAULT_SEARCH_ENGINE
        ).split(",")

        self.assertNotIn("bing", auto_backends)
        self.assertIn("duckduckgo", auto_backends)
        self.assertEqual(DDGS_AUTO_BACKEND, ",".join(auto_backends))

    def test_bing_search_engine_is_not_supported(self):
        """Bing should no longer be exposed as a selectable DDGS backend."""
        tool = SearchWebTool(session_id="session-1", user_id="10001")

        result = asyncio.run(tool.run(query="asyncio", search_engine="bing"))

        self.assertIn("不支持的搜索源 'bing'", result)

    def test_ddgs_alias_uses_auto_backend(self):
        """The `ddgs` alias should map to the automatic DDGS backend."""
        self.assertEqual(
            DEFAULT_SEARCH_ENGINE,
            SearchWebTool._normalize_search_engine("ddgs"),
        )

    def test_run_uses_specific_search_engine_and_site_filter(self):
        """An explicit engine and site URL should be passed to the backend search call."""

        async def _run_tool():
            """Run the search tool once against a mocked backend."""
            tool = SearchWebTool(session_id="session-1", user_id="10001")
            with patch.object(
                tool,
                "_search_with_backend",
                new_callable=AsyncMock,
            ) as search_mock:
                search_mock.return_value = [
                    {
                        "title": "asyncio",
                        "snippet": "Python asyncio docs",
                        "url": "https://docs.python.org/3/library/asyncio.html",
                        "source": "DuckDuckGo",
                    }
                ]

                result = await tool.run(
                    query="asyncio",
                    max_results=5,
                    search_engine="duckduckgo",
                    site_url="https://docs.python.org/3/",
                )
            return result, search_mock.await_args.kwargs

        result, call_kwargs = asyncio.run(_run_tool())
        # The tool result is parsed as JSON below.
        payload = json.loads(result)

        self.assertEqual("duckduckgo", call_kwargs["engine"])
        self.assertEqual("asyncio site:docs.python.org/3", call_kwargs["query"])
        self.assertEqual("docs.python.org", call_kwargs["site_filter"].domain)
        self.assertEqual(1, payload["total_results"])
        self.assertEqual("DuckDuckGo", payload["results"][0]["source"])

    def test_ddgs_uses_system_proxy_by_default(self):
        """DDGS searches should use the system proxy setting by default."""

        async def _run_tool():
            """Run the search tool once against a mocked DDGS backend."""
            tool = SearchWebTool(session_id="session-1", user_id="10001")
            with patch.object(
                settings, "PROXY_HOST", "http://proxy.example.com:7890"
            ), patch("app.agent.tools.impl.search_web.DDGS") as ddgs_mock:
                # DDGS is used as a context manager; configure the entered object.
                ddgs = ddgs_mock.return_value.__enter__.return_value
                ddgs.text.return_value = [
                    {
                        "title": "asyncio",
                        "body": "Python asyncio docs",
                        "href": "https://docs.python.org/3/library/asyncio.html",
                    }
                ]

                results = await tool._search_ddgs(
                    query="asyncio",
                    max_results=1,
                    search_engine="duckduckgo",
                )
            return results, ddgs_mock.call_args.kwargs

        results, ddgs_kwargs = asyncio.run(_run_tool())

        self.assertEqual("http://proxy.example.com:7890", ddgs_kwargs["proxy"])
        self.assertEqual(1, len(results))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
201
tests/test_coalesce.py
Normal file
201
tests/test_coalesce.py
Normal file
@@ -0,0 +1,201 @@
|
||||
"""
|
||||
`EventCoalescer` 基础设施单元测试。
|
||||
|
||||
测试策略:用极短窗口(默认 0.05s)驱动真实事件循环触发 flush,避免引入
|
||||
对时间 mock 的复杂度;同时通过 `asyncio.sleep` 让出控制权以保证 flush
|
||||
回调被调度执行。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import List
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
from app.utils.coalesce import (
|
||||
CoalesceDecision,
|
||||
CoalesceSummary,
|
||||
EventCoalescer,
|
||||
)
|
||||
|
||||
|
||||
# 窗口尽量短,但要大于事件循环单次 tick 的开销,避免 flush 在 record 仍持锁时触发
|
||||
_TEST_WINDOW = 0.05
|
||||
# 等待窗口到期 + flush 任务完成所需的额外余量
|
||||
_TEST_WAIT = _TEST_WINDOW * 4
|
||||
|
||||
|
||||
class EventCoalescerTest(IsolatedAsyncioTestCase):
    """
    Covers the core contract of EventCoalescer: the first record is EMIT,
    later records inside the window are SUPPRESS, a summary is flushed when
    count > 1, different keys do not affect each other, close() flushes
    immediately, on_flush exceptions are swallowed, and both sync and async
    on_flush callbacks work.
    """

    async def test_first_record_returns_emit(self):
        """
        The first occurrence of a key in a new window must return EMIT so the
        caller outputs the event as-is.
        """
        summaries: List[CoalesceSummary] = []
        coalescer = EventCoalescer(_TEST_WINDOW, summaries.append)

        decision = await coalescer.record(("host", "reason"), payload={"i": 1})

        self.assertIs(decision, CoalesceDecision.EMIT)
        await coalescer.close()

    async def test_subsequent_same_key_records_are_suppressed(self):
        """
        Repeated hits on the same key inside the window return SUPPRESS from
        the second hit onwards.
        """
        coalescer = EventCoalescer(_TEST_WINDOW, lambda _s: None)
        await coalescer.record("k", payload="first")

        for _ in range(3):
            self.assertIs(
                await coalescer.record("k", payload="ignored"),
                CoalesceDecision.SUPPRESS,
            )
        await coalescer.close()

    async def test_window_expiry_flushes_summary_when_count_gt_one(self):
        """
        When the window expires with count > 1, on_flush receives a summary
        carrying count, first_payload and the window length.
        """
        summaries: List[CoalesceSummary] = []
        coalescer = EventCoalescer(_TEST_WINDOW, summaries.append, source="test")
        key = ("h", "r")
        await coalescer.record(key, payload={"url": "u1"})
        await coalescer.record(key, payload={"url": "u2"})
        await coalescer.record(key, payload={"url": "u3"})

        # Yield to the real event loop long enough for the window to expire.
        await asyncio.sleep(_TEST_WAIT)

        self.assertEqual(len(summaries), 1)
        summary = summaries[0]
        self.assertEqual(summary.key, key)
        self.assertEqual(summary.count, 3)
        self.assertEqual(summary.first_payload, {"url": "u1"})
        self.assertEqual(summary.window_seconds, _TEST_WINDOW)

    async def test_window_expiry_does_not_flush_when_count_is_one(self):
        """
        With a single occurrence in the window, the initial EMIT already
        expresses the whole event, so no aggregate summary is sent.
        """
        summaries: List[CoalesceSummary] = []
        coalescer = EventCoalescer(_TEST_WINDOW, summaries.append)
        await coalescer.record("solo", payload=None)

        await asyncio.sleep(_TEST_WAIT)

        self.assertEqual(summaries, [])

    async def test_different_keys_do_not_collapse(self):
        """
        Different keys are counted and flushed independently and never merge.
        """
        summaries: List[CoalesceSummary] = []
        coalescer = EventCoalescer(_TEST_WINDOW, summaries.append)
        await coalescer.record("a", payload="a1")
        await coalescer.record("b", payload="b1")
        await coalescer.record("a", payload="a2")
        await coalescer.record("b", payload="b2")
        await coalescer.record("a", payload="a3")

        await asyncio.sleep(_TEST_WAIT)

        # Index summaries by key so the assertions do not depend on flush order.
        by_key = {s.key: s for s in summaries}
        self.assertEqual(set(by_key.keys()), {"a", "b"})
        self.assertEqual(by_key["a"].count, 3)
        self.assertEqual(by_key["a"].first_payload, "a1")
        self.assertEqual(by_key["b"].count, 2)
        self.assertEqual(by_key["b"].first_payload, "b1")

    async def test_new_window_after_flush_emits_again(self):
        """
        After a window ends, the next event for the same key is the first of a
        new window and returns EMIT.
        """
        coalescer = EventCoalescer(_TEST_WINDOW, lambda _s: None)
        await coalescer.record("k", payload=1)
        await coalescer.record("k", payload=2)
        await asyncio.sleep(_TEST_WAIT)

        decision = await coalescer.record("k", payload=3)

        self.assertIs(decision, CoalesceDecision.EMIT)
        await coalescer.close()

    async def test_close_flushes_pending_buckets_immediately(self):
        """
        close() must cancel timers that have not fired yet and immediately
        flush buckets with count > 1; this is meant for the process-exit path.
        """
        # Use a window long enough that natural expiry cannot fire before close().
        summaries: List[CoalesceSummary] = []
        coalescer = EventCoalescer(1.0, summaries.append)
        await coalescer.record("k", payload="first")
        await coalescer.record("k", payload="second")

        await coalescer.close()

        self.assertEqual(len(summaries), 1)
        self.assertEqual(summaries[0].count, 2)
        self.assertEqual(summaries[0].first_payload, "first")

    async def test_close_does_not_emit_when_count_is_one(self):
        """
        close() behaves like a normal window expiry: no summary when count == 1.
        """
        summaries: List[CoalesceSummary] = []
        coalescer = EventCoalescer(1.0, summaries.append)
        await coalescer.record("k", payload="only")

        await coalescer.close()

        self.assertEqual(summaries, [])

    async def test_async_on_flush_is_awaited(self):
        """
        An async on_flush must actually be awaited rather than dropped as an
        un-awaited coroutine object.
        """
        awaited: List[CoalesceSummary] = []

        async def on_flush(summary: CoalesceSummary) -> None:
            await asyncio.sleep(0)
            awaited.append(summary)

        coalescer = EventCoalescer(_TEST_WINDOW, on_flush)
        await coalescer.record("k", payload="a")
        await coalescer.record("k", payload="b")

        await asyncio.sleep(_TEST_WAIT)

        self.assertEqual(len(awaited), 1)
        self.assertEqual(awaited[0].count, 2)

    async def test_on_flush_exception_is_swallowed(self):
        """
        An exception raised by on_flush must not affect the coalescer or its
        caller (per the original note it is only recorded at debug level).
        """
        def on_flush(_summary: CoalesceSummary) -> None:
            raise RuntimeError("boom")

        coalescer = EventCoalescer(_TEST_WINDOW, on_flush)
        await coalescer.record("k", payload="x")
        await coalescer.record("k", payload="y")

        await asyncio.sleep(_TEST_WAIT)

        # The exception was swallowed; a new window still accepts records.
        self.assertIs(
            await coalescer.record("k", payload="z"),
            CoalesceDecision.EMIT,
        )
        await coalescer.close()

    async def test_invalid_window_raises(self):
        """
        Non-positive window values are rejected at construction time, which
        avoids a runaway flush loop with a zero or negative window at runtime.
        """
        with self.assertRaises(ValueError):
            EventCoalescer(0, lambda _s: None)
        with self.assertRaises(ValueError):
            EventCoalescer(-1.0, lambda _s: None)
|
||||
@@ -1,10 +1,12 @@
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import app.chain.download as download_module
|
||||
from app.chain.download import DownloadChain
|
||||
from app.core.context import Context, MediaInfo, TorrentInfo
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.schemas import NotExistMediaInfo
|
||||
from app.schemas.types import MediaType
|
||||
|
||||
|
||||
@@ -95,3 +97,231 @@ def test_download_single_submits_download_added_to_background(monkeypatch):
|
||||
download_dir=Path("/downloads"),
|
||||
torrent_content=b"torrent-content",
|
||||
)
|
||||
|
||||
|
||||
class _FakeBatchTorrentHelper:
|
||||
"""
|
||||
为批量下载测试提供稳定排序和种子文件集数解析。
|
||||
"""
|
||||
|
||||
episodes = []
|
||||
|
||||
def sort_group_torrents(self, contexts):
|
||||
return contexts
|
||||
|
||||
def get_torrent_episodes(self, _files):
|
||||
return list(self.episodes)
|
||||
|
||||
|
||||
def _build_tv_context(episode_list=None):
    """
    Build a single-season TV candidate whose title does not name episodes.

    :param episode_list: episode numbers to place on ``meta_info.episode_list``;
        ``None`` (or any falsy value) yields an empty list.
    :return: a ``SimpleNamespace`` with ``media_info``, ``meta_info``,
        ``torrent_info`` and ``allowed_episodes`` attributes.
    """
    episodes = episode_list or []
    return SimpleNamespace(
        media_info=SimpleNamespace(type=MediaType.TV, tmdb_id=1, douban_id=None),
        meta_info=SimpleNamespace(
            season_list=[1],
            episode_list=episodes,
            title="Test Show",
            org_string="Test Show S01 2160p",
            # No-op stand-in: accepts a begin/end pair and does nothing.
            set_episodes=lambda begin, end: None,
        ),
        torrent_info=SimpleNamespace(title="Test Show S01 2160p", site_name="TestSite"),
        allowed_episodes=None,
    )
|
||||
|
||||
|
||||
def test_batch_download_rejects_complete_coverage_when_files_do_not_cover_target(monkeypatch):
    """
    With complete coverage required, a partial pack such as episodes 1-13
    must not pass for a 1-143 target range.
    """
    # The fake helper reports that the torrent files contain episodes 1..13 only.
    _FakeBatchTorrentHelper.episodes = list(range(1, 14))
    monkeypatch.setattr(download_module, "TorrentHelper", _FakeBatchTorrentHelper)
    monkeypatch.setattr(download_module.eventmanager, "send_event", lambda *args, **kwargs: None)

    # Bypass __init__ and stub the two collaborators the test inspects.
    chain = DownloadChain.__new__(DownloadChain)
    chain.download_torrent = MagicMock(return_value=(b"torrent-content", "", ["demo.mkv"]))
    chain.download_single = MagicMock(return_value="hash")

    context = _build_tv_context()
    no_exists = {
        1: {
            1: NotExistMediaInfo(
                season=1,
                episodes=[],
                total_episode=143,
                start_episode=1,
                require_complete_coverage=True,
            )
        }
    }

    downloads, lefts = chain.batch_download(contexts=[context], no_exists=no_exists)

    assert downloads == []
    assert lefts == no_exists
    chain.download_single.assert_not_called()
|
||||
|
||||
|
||||
def test_batch_download_accepts_complete_coverage_when_files_cover_target_range(monkeypatch):
    """
    With a custom start episode, coverage is judged against the target range:
    files 100-143 satisfy start=100, total=143.
    """
    # The fake helper reports that the torrent files contain episodes 100..143.
    _FakeBatchTorrentHelper.episodes = list(range(100, 144))
    monkeypatch.setattr(download_module, "TorrentHelper", _FakeBatchTorrentHelper)
    monkeypatch.setattr(download_module.eventmanager, "send_event", lambda *args, **kwargs: None)

    # Bypass __init__ and stub the two collaborators the test inspects.
    chain = DownloadChain.__new__(DownloadChain)
    chain.download_torrent = MagicMock(return_value=(b"torrent-content", "", ["demo.mkv"]))
    chain.download_single = MagicMock(return_value="hash")

    context = _build_tv_context()
    no_exists = {
        1: {
            1: NotExistMediaInfo(
                season=1,
                episodes=[],
                total_episode=143,
                start_episode=100,
                require_complete_coverage=True,
            )
        }
    }

    downloads, lefts = chain.batch_download(contexts=[context], no_exists=no_exists)

    assert downloads == [context]
    assert lefts == {}
    chain.download_single.assert_called_once()
|
||||
|
||||
|
||||
def test_batch_download_accepts_complete_coverage_when_title_episodes_cover_target(monkeypatch):
    """
    A candidate that explicitly names the full episode range can also satisfy
    a complete-coverage task.
    """
    _FakeBatchTorrentHelper.episodes = []
    monkeypatch.setattr(download_module, "TorrentHelper", _FakeBatchTorrentHelper)
    monkeypatch.setattr(download_module.eventmanager, "send_event", lambda *args, **kwargs: None)

    # Bypass __init__ and stub the two collaborators the test inspects.
    chain = DownloadChain.__new__(DownloadChain)
    chain.download_torrent = MagicMock()
    chain.download_single = MagicMock(return_value="hash")

    # The candidate's own metadata lists episodes 1..143.
    context = _build_tv_context(episode_list=list(range(1, 144)))
    no_exists = {
        1: {
            1: NotExistMediaInfo(
                season=1,
                episodes=[],
                total_episode=143,
                start_episode=1,
                require_complete_coverage=True,
            )
        }
    }

    downloads, lefts = chain.batch_download(contexts=[context], no_exists=no_exists)

    assert downloads == [context]
    assert lefts == {}
    # The torrent file itself is never fetched in this scenario.
    chain.download_torrent.assert_not_called()
    chain.download_single.assert_called_once()
|
||||
|
||||
|
||||
def test_batch_download_rejects_complete_coverage_when_title_episodes_are_partial(monkeypatch):
    """
    A candidate that explicitly names only a partial episode range cannot
    satisfy a complete-coverage task.
    """
    _FakeBatchTorrentHelper.episodes = []
    monkeypatch.setattr(download_module, "TorrentHelper", _FakeBatchTorrentHelper)
    monkeypatch.setattr(download_module.eventmanager, "send_event", lambda *args, **kwargs: None)

    # Bypass __init__ and stub the two collaborators the test inspects.
    chain = DownloadChain.__new__(DownloadChain)
    chain.download_torrent = MagicMock()
    chain.download_single = MagicMock(return_value="hash")

    # The candidate's own metadata lists episodes 1..13 of a 143-episode target.
    context = _build_tv_context(episode_list=list(range(1, 14)))
    no_exists = {
        1: {
            1: NotExistMediaInfo(
                season=1,
                episodes=[],
                total_episode=143,
                start_episode=1,
                require_complete_coverage=True,
            )
        }
    }

    downloads, lefts = chain.batch_download(contexts=[context], no_exists=no_exists)

    assert downloads == []
    assert lefts == no_exists
    chain.download_torrent.assert_not_called()
    chain.download_single.assert_not_called()
|
||||
|
||||
|
||||
def test_batch_download_complete_coverage_ignores_allowed_episode_narrowing(monkeypatch):
    """
    A complete-coverage task must not treat a partial pack as covering the
    target range just because the candidate's allowed episodes were narrowed.
    """
    _FakeBatchTorrentHelper.episodes = []
    monkeypatch.setattr(download_module, "TorrentHelper", _FakeBatchTorrentHelper)
    monkeypatch.setattr(download_module.eventmanager, "send_event", lambda *args, **kwargs: None)

    # Bypass __init__ and stub the two collaborators the test inspects.
    chain = DownloadChain.__new__(DownloadChain)
    chain.download_torrent = MagicMock()
    chain.download_single = MagicMock(return_value="hash")

    # Candidate names episodes 1-2 and is additionally restricted to {1, 2}.
    context = _build_tv_context(episode_list=[1, 2])
    context.allowed_episodes = {1, 2}
    no_exists = {
        1: {
            1: NotExistMediaInfo(
                season=1,
                episodes=[],
                total_episode=12,
                start_episode=1,
                require_complete_coverage=True,
            )
        }
    }

    downloads, lefts = chain.batch_download(contexts=[context], no_exists=no_exists)

    assert downloads == []
    assert lefts == no_exists
    chain.download_torrent.assert_not_called()
    chain.download_single.assert_not_called()
|
||||
|
||||
|
||||
def test_batch_download_keeps_count_check_without_complete_coverage(monkeypatch):
    """
    An ordinary whole-season gap still uses the count-based check, so the
    complete-coverage semantics do not leak into the non-strict scenario.
    """
    # 143 episodes numbered 2..144: same count as the target, different range.
    _FakeBatchTorrentHelper.episodes = list(range(2, 145))
    monkeypatch.setattr(download_module, "TorrentHelper", _FakeBatchTorrentHelper)
    monkeypatch.setattr(download_module.eventmanager, "send_event", lambda *args, **kwargs: None)

    # Bypass __init__ and stub the two collaborators the test inspects.
    chain = DownloadChain.__new__(DownloadChain)
    chain.download_torrent = MagicMock(return_value=(b"torrent-content", "", ["demo.mkv"]))
    chain.download_single = MagicMock(return_value="hash")

    context = _build_tv_context()
    # require_complete_coverage is deliberately left at its default here.
    no_exists = {
        1: {
            1: NotExistMediaInfo(
                season=1,
                episodes=[],
                total_episode=143,
                start_episode=1,
            )
        }
    }

    downloads, lefts = chain.batch_download(contexts=[context], no_exists=no_exists)

    assert downloads == [context]
    assert lefts == {}
    chain.download_single.assert_called_once()
|
||||
|
||||
147
tests/test_emby_dashboard_links.py
Normal file
147
tests/test_emby_dashboard_links.py
Normal file
@@ -0,0 +1,147 @@
|
||||
import unittest
|
||||
from typing import Any
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from app import schemas
|
||||
from app.api.endpoints.mediaserver import play_item
|
||||
from app.modules.emby.emby import Emby
|
||||
|
||||
|
||||
class _FakeResponse:
|
||||
"""提供 Emby 接口响应的最小 json 封装。"""
|
||||
|
||||
def __init__(self, payload: Any):
|
||||
"""保存测试预置的响应体。"""
|
||||
self._payload = payload
|
||||
|
||||
def json(self) -> Any:
|
||||
"""返回测试预置的响应体。"""
|
||||
return self._payload
|
||||
|
||||
|
||||
class EmbyDashboardLinksTest(unittest.TestCase):
    """Checks that Emby dashboard entries build their links from real media-server identifiers."""

    @staticmethod
    def _build_client() -> Emby:
        """Build an Emby instance that bypasses the real initialisation."""
        client = Emby.__new__(Emby)
        client._host = "http://emby.local/"
        client._playhost = None
        client._apikey = "api-key"
        client._sync_libraries = []
        client.user = "user-id"
        client.serverid = "server-id"
        return client

    def test_get_server_id_falls_back_to_emby_prefixed_system_info(self):
        """
        Supports Emby reverse proxies that only expose /emby/System/Info, so
        links are not generated with serverId=None.
        """
        client = self._build_client()
        client.serverid = None

        with patch("app.modules.emby.emby.RequestUtils") as request_utils_cls:
            # First request yields nothing; the second returns the server info.
            request_utils_cls.return_value.get_res.side_effect = [
                None,
                _FakeResponse({"Id": "server-id"}),
            ]

            server_id = client.get_server_id()

        self.assertEqual(server_id, "server-id")
        self.assertEqual(
            request_utils_cls.return_value.get_res.call_args_list[0].args[0],
            "http://emby.local/System/Info",
        )
        self.assertEqual(
            request_utils_cls.return_value.get_res.call_args_list[1].args[0],
            "http://emby.local/emby/System/Info",
        )

    def test_get_play_url_omits_missing_server_id(self):
        """With no serverId, the literal string "None" must not end up in the play URL."""
        client = self._build_client()
        client.serverid = None

        play_url = client.get_play_url("item-id")

        self.assertEqual(
            play_url,
            "http://emby.local/web/index.html#!/item?id=item-id&context=home",
        )

    def test_get_latest_returns_item_and_server_ids(self):
        """Latest-added entries must expose the Emby item_id and server_id so the frontend can correct links."""
        client = self._build_client()
        client.get_user_library_folders = Mock(return_value=[])

        with patch("app.modules.emby.emby.RequestUtils") as request_utils_cls:
            request_utils_cls.return_value.get_res.return_value = _FakeResponse([
                {
                    "Id": "emby-item-id",
                    "ServerId": "item-server-id",
                    "Name": "测试电影",
                    "Type": "Movie",
                    "ProductionYear": 2026,
                }
            ])

            items = client.get_latest()

        self.assertEqual(items[0].id, "emby-item-id")
        self.assertEqual(items[0].item_id, "emby-item-id")
        self.assertEqual(items[0].server_id, "item-server-id")
        self.assertIn("id=emby-item-id", items[0].link)
        self.assertIn("serverId=item-server-id", items[0].link)

    def test_get_librarys_returns_item_and_server_ids(self):
        """Library cards must expose the Emby parentId and server_id so the frontend can build app links."""
        client = self._build_client()

        # Patch the name-mangled private helpers of Emby.
        with (
            patch.object(client, "_Emby__get_emby_librarys") as librarys,
            patch.object(client, "_Emby__get_local_image_by_id") as image_by_id,
        ):
            librarys.return_value = [
                {
                    "Id": "library-id",
                    "ServerId": "library-server-id",
                    "Name": "电影库",
                    "CollectionType": "movies",
                }
            ]
            image_by_id.return_value = "http://emby.local/image"

            items = client.get_librarys()

        self.assertEqual(items[0].id, "library-id")
        self.assertEqual(items[0].item_id, "library-id")
        self.assertEqual(items[0].server_id, "library-server-id")
        self.assertIn("parentId=library-id", items[0].link)
        self.assertIn("serverId=library-server-id", items[0].link)

    def test_play_item_returns_server_type(self):
        """The play-URL endpoint must return server_type so the frontend picks the right media-server type."""
        item = schemas.MediaServerItem(server="emby", item_id="emby-item-id", server_id="server-id")

        with (
            patch("app.api.endpoints.mediaserver.MediaServerHelper") as helper_cls,
            patch("app.api.endpoints.mediaserver.MediaServerChain") as chain_cls,
        ):
            helper_cls.return_value.get_configs.return_value = {"Emby": object()}
            chain = chain_cls.return_value
            chain.iteminfo.return_value = item
            chain.get_play_url.return_value = "http://emby.local/web/index.html#!/item?id=emby-item-id"

            response = play_item("emby-item-id")

        self.assertTrue(response.success)
        self.assertEqual(response.data["url"], "http://emby.local/web/index.html#!/item?id=emby-item-id")
        self.assertEqual(response.data["item_id"], "emby-item-id")
        self.assertEqual(response.data["server_id"], "server-id")
        self.assertEqual(response.data["server_type"], "emby")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -108,8 +108,17 @@ def _build_fake_openai_modules(chat_openai_cls=_FakeChatOpenAIForPatch):
|
||||
def _convert_delta_to_message_chunk(delta, default_class):
|
||||
return AIMessageChunk(content=delta.get("content") or "")
|
||||
|
||||
def _construct_lc_result_from_responses_api(response, *args, **kwargs):
|
||||
"""模拟旧版 langchain-openai 直接遍历 response.output 的行为。"""
|
||||
for _item in response.output:
|
||||
pass
|
||||
return SimpleNamespace(args=args, kwargs=kwargs, response=response)
|
||||
|
||||
base_module._convert_dict_to_message = _convert_dict_to_message
|
||||
base_module._convert_delta_to_message_chunk = _convert_delta_to_message_chunk
|
||||
base_module._construct_lc_result_from_responses_api = (
|
||||
_construct_lc_result_from_responses_api
|
||||
)
|
||||
|
||||
return {
|
||||
"langchain_openai": openai_module,
|
||||
@@ -262,6 +271,39 @@ class LlmHelperTestCallTest(unittest.TestCase):
|
||||
"先调用工具",
|
||||
)
|
||||
|
||||
def test_openai_responses_patch_handles_completed_chunk_without_output(self):
|
||||
"""校验 Responses API 流式完成事件 output 为空时不再崩溃。"""
|
||||
|
||||
class _FakeResponse:
|
||||
"""模拟 OpenAI Responses API 完成事件里的 Response 对象。"""
|
||||
|
||||
def __init__(self, output):
|
||||
"""保存 output 字段用于复现空输出场景。"""
|
||||
self.output = output
|
||||
|
||||
def model_copy(self, update=None):
|
||||
"""模拟 Pydantic v2 model_copy(update=...) 行为。"""
|
||||
copied = _FakeResponse(self.output)
|
||||
for key, value in (update or {}).items():
|
||||
setattr(copied, key, value)
|
||||
return copied
|
||||
|
||||
fake_modules, openai_base = _build_fake_openai_modules()
|
||||
with patch.dict(sys.modules, fake_modules):
|
||||
with self.assertRaises(TypeError):
|
||||
openai_base._construct_lc_result_from_responses_api(
|
||||
_FakeResponse(None)
|
||||
)
|
||||
|
||||
llm_module._patch_openai_responses_instructions_support()
|
||||
result = openai_base._construct_lc_result_from_responses_api(
|
||||
_FakeResponse(None),
|
||||
schema=object,
|
||||
)
|
||||
|
||||
self.assertEqual(result.response.output, [])
|
||||
self.assertEqual(result.kwargs.get("schema"), object)
|
||||
|
||||
def test_openai_compatible_patch_injects_xiaomi_reasoning_content(self):
|
||||
fake_modules, _ = _build_fake_openai_modules()
|
||||
with patch.dict(sys.modules, fake_modules):
|
||||
|
||||
289
tests/test_security_image_url_log.py
Normal file
289
tests/test_security_image_url_log.py
Normal file
@@ -0,0 +1,289 @@
|
||||
"""
|
||||
覆盖 `SecurityUtils.is_safe_image_url_async` 的阻断分支与日志聚合接线:
|
||||
|
||||
- 各 `UrlSafetyReason` 分支落入正确的 warning 字段;
|
||||
- `NON_GLOBAL_DNS_RESULT` 且未配置允许网段时附 fake-ip 提示;
|
||||
- 签名 URL 校验通过时静默放行;URL 携带签名但失败时附 `invalid_signature` 标记;
|
||||
- 同 (host, reason) 高频拦截只输出首条 warning,窗口结束输出聚合摘要;
|
||||
- 不同 (host, reason) 互不吞并。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import List, Optional
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
from unittest.mock import patch
|
||||
|
||||
from app.utils import security as security_module
|
||||
from app.utils.coalesce import EventCoalescer
|
||||
from app.utils.security import (
|
||||
SecurityUtils,
|
||||
UrlSafetyDiagnosis,
|
||||
UrlSafetyReason,
|
||||
)
|
||||
|
||||
|
||||
# Coalescing window, in seconds, handed to the EventCoalescer built in asyncSetUp.
# Kept short so a test can wait for the window to elapse.
_TEST_WINDOW = 0.05
# How long a test sleeps before asserting on the aggregated summary: four windows.
_TEST_WAIT = _TEST_WINDOW * 4
|
||||
|
||||
|
||||
def _diag(
    reason: UrlSafetyReason,
    *,
    host: Optional[str] = "image.tmdb.org",
    ips: Optional[List[str]] = None,
) -> UrlSafetyDiagnosis:
    """
    Build a blocked (``allowed=False``) ``UrlSafetyDiagnosis`` for use in tests.

    For ``DOMAIN_NOT_ALLOWED`` the host is always forced to ``None``, whatever the
    caller passed, to keep the fixture in line with the field constraint the real
    ``evaluate_url_safety_async`` output is described as having. A missing or empty
    ``ips`` argument becomes an empty list.
    """
    effective_host = None if reason is UrlSafetyReason.DOMAIN_NOT_ALLOWED else host
    resolved_ips = ips if ips else []
    return UrlSafetyDiagnosis(
        allowed=False,
        reason=reason,
        host=effective_host,
        ips=resolved_ips,
    )
|
||||
|
||||
|
||||
class IsSafeImageUrlLogTest(IsolatedAsyncioTestCase):
    """
    Structured-log and aggregation checks for the blocking paths of
    `is_safe_image_url_async`.

    Every test replaces `SecurityUtils.evaluate_url_safety_async`,
    `SecurityUtils.verify_signed_url` and `security_module.logger.warn` with
    test doubles, then asserts on the warning strings that were captured.
    """

    async def asyncSetUp(self) -> None:
        """Install a short-window coalescer in place of the module-level one."""
        # Swap the module-level coalescer for a short-window instance so a test can
        # drive the window to expiry and observe the flush.
        self._original_coalescer = security_module._image_proxy_block_log_coalescer
        self._coalescer = EventCoalescer(
            window_seconds=_TEST_WINDOW,
            on_flush=security_module._log_image_proxy_block_summary,
            source="image_proxy_test",
        )
        security_module._image_proxy_block_log_coalescer = self._coalescer
        self._allowed_domains = {"image.tmdb.org"}

    async def asyncTearDown(self) -> None:
        """Close the test coalescer and restore the original module-level one."""
        await self._coalescer.close()
        security_module._image_proxy_block_log_coalescer = self._original_coalescer

    async def _invoke(
        self,
        diagnosis: UrlSafetyDiagnosis,
        *,
        url: str = "https://image.tmdb.org/t/p/w500/x.jpg",
        signed_clean_url: Optional[str] = None,
        allowed_private_ranges: Optional[List[str]] = None,
    ):
        """
        Drive `is_safe_image_url_async` with a fixed diagnosis and a fixed
        signed-URL verification result, capturing every warning.

        Returns a ``(allowed, warns)`` tuple: the value returned by
        `is_safe_image_url_async` and the list of messages passed to the
        patched `logger.warn`.
        """
        async def fake_evaluate(*_args, **_kwargs):
            return diagnosis

        warns: List[str] = []
        with patch.object(
            SecurityUtils,
            "evaluate_url_safety_async",
            side_effect=fake_evaluate,
        ), patch.object(
            SecurityUtils,
            "verify_signed_url",
            return_value=signed_clean_url,
        ), patch.object(
            security_module.logger,
            "warn",
            side_effect=warns.append,
        ):
            allowed = await SecurityUtils.is_safe_image_url_async(
                url,
                self._allowed_domains,
                allowed_private_ranges=allowed_private_ranges,
            )
        return allowed, warns

    async def test_domain_not_allowed_emits_clean_reason_label(self):
        """
        A plain external URL (no mp_sig) that fails the allowlist yields a warning
        tagged DOMAIN_NOT_ALLOWED, with no fake-ip hint and no signature-failure
        tag, so unsigned callers are not misled into thinking a signature is
        required.
        """
        allowed, warns = await self._invoke(
            _diag(UrlSafetyReason.DOMAIN_NOT_ALLOWED),
        )

        self.assertFalse(allowed)
        self.assertEqual(len(warns), 1)
        self.assertIn("reason=domain_not_allowed", warns[0])
        self.assertIn("Blocked unsafe image URL", warns[0])
        self.assertNotIn("fake-ip", warns[0])
        self.assertNotIn("invalid_signature", warns[0])

    async def test_invalid_signature_tag_only_when_url_signed(self):
        """
        When the URL explicitly carries `#mp_sig=...` but verification fails,
        `invalid_signature` is appended to the reason, which distinguishes an
        invalid signature from a blocked unsigned external URL.
        """
        allowed, warns = await self._invoke(
            _diag(UrlSafetyReason.DOMAIN_NOT_ALLOWED),
            url="https://attacker.example.com/x.jpg#mp_sig=deadbeef&mp_purpose=image-proxy",
        )

        self.assertFalse(allowed)
        self.assertEqual(len(warns), 1)
        self.assertIn(
            "reason=domain_not_allowed+invalid_signature", warns[0]
        )

    async def test_non_global_dns_result_lists_ips_with_hint(self):
        """
        When DNS resolves to non-public addresses and no allowed ranges are
        configured, the warning lists the resolved IPs and adds the fake-ip hint.
        """
        allowed, warns = await self._invoke(
            _diag(
                UrlSafetyReason.NON_GLOBAL_DNS_RESULT,
                ips=["198.18.16.96", "198.18.16.97"],
            ),
        )

        self.assertFalse(allowed)
        self.assertEqual(len(warns), 1)
        warning = warns[0]
        self.assertIn("reason=non_global_dns_result", warning)
        self.assertIn("host=image.tmdb.org", warning)
        self.assertIn("ips=198.18.16.96,198.18.16.97", warning)
        self.assertIn("IMAGE_PROXY_ALLOWED_PRIVATE_RANGES", warning)
        self.assertIn("198.18.0.0/15", warning)

    async def test_configured_ranges_skip_fakeip_hint(self):
        """
        With allowed_private_ranges configured the fake-ip hint is not appended
        again, avoiding duplicate guidance. The warning also lists the ranges in
        effect so operators can cross-check them.
        """
        _, warns = await self._invoke(
            _diag(
                UrlSafetyReason.MIXED_OR_DISALLOWED_PRIVATE_RESULT,
                ips=["10.0.0.8"],
            ),
            allowed_private_ranges=["198.18.0.0/15"],
        )

        self.assertEqual(len(warns), 1)
        warning = warns[0]
        self.assertIn("reason=mixed_or_disallowed_private_result", warning)
        self.assertIn("allowed_private_ranges=198.18.0.0/15", warning)
        # The literal below is the marker text the hint is expected to contain.
        self.assertNotIn("提示", warning)

    async def test_dns_resolution_failed_carries_empty_ips(self):
        """
        The DNS-resolution-failed warning carries an empty ips field so operators
        can go straight to the DNS path.
        """
        _, warns = await self._invoke(
            _diag(UrlSafetyReason.DNS_RESOLUTION_FAILED, ips=[]),
        )

        self.assertEqual(len(warns), 1)
        self.assertIn("reason=dns_resolution_failed", warns[0])
        self.assertIn("ips=,", warns[0])

    async def test_signed_url_success_silently_allows(self):
        """
        When standard validation fails but signed-URL verification succeeds the
        call returns True and emits no warning, so operators do not mistake the
        backend pre-signed path for an abnormal block.
        """
        allowed, warns = await self._invoke(
            _diag(UrlSafetyReason.DOMAIN_NOT_ALLOWED),
            signed_clean_url="https://image.tmdb.org/t/p/w500/x.jpg",
        )

        self.assertTrue(allowed)
        self.assertEqual(warns, [])

    async def test_repeated_block_in_window_emits_only_first_warning(self):
        """
        Repeated hits for the same (host, reason) inside one window emit only the
        first warning; after the window expires one aggregated summary follows,
        whose count equals the total hits in the window and whose sample URL comes
        from the first event.
        """
        diag = _diag(
            UrlSafetyReason.NON_GLOBAL_DNS_RESULT,
            ips=["198.18.16.96"],
        )

        async def fake_evaluate(*_args, **_kwargs):
            return diag

        warns: List[str] = []
        with patch.object(
            SecurityUtils,
            "evaluate_url_safety_async",
            side_effect=fake_evaluate,
        ), patch.object(
            SecurityUtils,
            "verify_signed_url",
            return_value=None,
        ), patch.object(
            security_module.logger,
            "warn",
            side_effect=warns.append,
        ):
            for i in range(5):
                await SecurityUtils.is_safe_image_url_async(
                    f"https://image.tmdb.org/t/p/w500/{i}.jpg",
                    self._allowed_domains,
                )
            self.assertEqual(len(warns), 1)
            self.assertIn("/0.jpg", warns[0])

            # NOTE(review): presumably the sleep has to stay inside the patch context so
            # the flushed summary reaches the patched logger.warn - confirm against the
            # original indentation.
            await asyncio.sleep(_TEST_WAIT)

            self.assertEqual(len(warns), 2)
            summary = warns[1]
            self.assertIn("aggregated", summary)
            self.assertIn("count=5", summary)
            self.assertIn("/0.jpg", summary)
            self.assertNotIn("/1.jpg", summary)
            self.assertNotIn("/4.jpg", summary)
            # The summary carries the first sample's resolved IPs, which helps pin
            # down the network cause behind a burst of blocks.
            self.assertIn("sample_ips=198.18.16.96", summary)

    async def test_different_keys_do_not_collapse(self):
        """
        Different (host, reason) pairs are counted and emitted independently and
        never absorb one another.
        """
        warns: List[str] = []
        sequence = {
            "evil": _diag(UrlSafetyReason.DOMAIN_NOT_ALLOWED, host=None),
            "tmdb": _diag(
                UrlSafetyReason.NON_GLOBAL_DNS_RESULT,
                host="image.tmdb.org",
                ips=["198.18.16.96"],
            ),
        }

        # Pick the diagnosis from the URL text so each call lands on its own key.
        async def fake_evaluate(url, *_args, **_kwargs):
            return sequence["evil"] if "evil" in url else sequence["tmdb"]

        with patch.object(
            SecurityUtils,
            "evaluate_url_safety_async",
            side_effect=fake_evaluate,
        ), patch.object(
            SecurityUtils,
            "verify_signed_url",
            return_value=None,
        ), patch.object(
            security_module.logger,
            "warn",
            side_effect=warns.append,
        ):
            await SecurityUtils.is_safe_image_url_async(
                "https://evil.example.com/x.jpg",
                self._allowed_domains,
            )
            await SecurityUtils.is_safe_image_url_async(
                "https://image.tmdb.org/t/p/w500/a.jpg",
                self._allowed_domains,
            )

        self.assertEqual(len(warns), 2)
        self.assertIn("reason=domain_not_allowed", warns[0])
        self.assertIn("reason=non_global_dns_result", warns[1])
|
||||
@@ -4,6 +4,8 @@ from unittest.mock import patch
|
||||
|
||||
from app.utils.security import (
|
||||
SecurityUtils,
|
||||
UrlSafetyDiagnosis,
|
||||
UrlSafetyReason,
|
||||
_dns_inflight_locks,
|
||||
_dns_negative_cache,
|
||||
_dns_positive_cache,
|
||||
@@ -681,3 +683,171 @@ class SecurityUtilsTest(TestCase):
|
||||
_dns_inflight_locks,
|
||||
"并发等待者全部退出后必须释放 in-flight 锁字典条目",
|
||||
)
|
||||
|
||||
|
||||
class UrlSafetyDiagnosisTest(TestCase):
    """
    Exercise the structured diagnosis returned by `evaluate_url_safety(_async)`
    and check, reason by reason, the fields the log rendering relies on.
    """

    def setUp(self) -> None:
        """Start every test with empty DNS caches and no in-flight locks."""
        for store in (_dns_positive_cache, _dns_negative_cache, _dns_inflight_locks):
            store.clear()

    def test_domain_not_allowed_returns_reason_and_no_host(self):
        """
        A URL that fails the protocol/allowlist check is diagnosed as
        DOMAIN_NOT_ALLOWED and exposes neither host nor ips.
        """
        allowlist = {"image.tmdb.org"}
        diagnosis = SecurityUtils.evaluate_url_safety(
            "https://attacker.example.com/x.jpg",
            allowlist,
        )

        self.assertIsInstance(diagnosis, UrlSafetyDiagnosis)
        self.assertFalse(diagnosis.allowed)
        self.assertIs(diagnosis.reason, UrlSafetyReason.DOMAIN_NOT_ALLOWED)
        self.assertIsNone(diagnosis.host)
        self.assertEqual(diagnosis.ips, [])
        self.assertEqual(diagnosis.matched_private_ranges, [])

    def test_allowed_without_block_private_skips_dns(self):
        """
        Without block_private the URL is allowed straight away: no DNS lookup is
        made and ips stays empty.
        """
        # Any call to getaddrinfo raises, which would fail the test.
        dns_tripwire = patch(
            "app.utils.security.socket.getaddrinfo",
            side_effect=AssertionError("不应触发 DNS 解析"),
        )
        with dns_tripwire:
            diagnosis = SecurityUtils.evaluate_url_safety(
                "https://image.tmdb.org/t/p/w500/x.jpg",
                {"image.tmdb.org"},
            )

        self.assertTrue(diagnosis.allowed)
        self.assertIs(diagnosis.reason, UrlSafetyReason.ALLOWED)
        self.assertEqual(diagnosis.host, "image.tmdb.org")
        self.assertEqual(diagnosis.ips, [])

    def test_dns_resolution_failed_carries_host_without_ips(self):
        """
        With `block_private=True`, a DNS error yields DNS_RESOLUTION_FAILED; the
        host is kept for troubleshooting but no ips are carried.
        """
        failing_dns = patch(
            "app.utils.security.socket.getaddrinfo",
            side_effect=socket.gaierror,
        )
        with failing_dns:
            diagnosis = SecurityUtils.evaluate_url_safety(
                "https://image.tmdb.org/t/p/w500/x.jpg",
                {"image.tmdb.org"},
                block_private=True,
            )

        self.assertFalse(diagnosis.allowed)
        self.assertIs(diagnosis.reason, UrlSafetyReason.DNS_RESOLUTION_FAILED)
        self.assertEqual(diagnosis.host, "image.tmdb.org")
        self.assertEqual(diagnosis.ips, [])

    def test_non_global_dns_result_lists_resolved_ips(self):
        """
        An allowlisted host that resolves to a non-public address, with no allowed
        ranges configured, is diagnosed as NON_GLOBAL_DNS_RESULT and the resolved
        IPs are listed so the log can attach the fake-ip hint.
        """
        resolved = [
            (socket.AF_INET, socket.SOCK_STREAM, 0, "", ("198.18.16.96", 0)),
        ]
        with patch("app.utils.security.socket.getaddrinfo", return_value=resolved):
            diagnosis = SecurityUtils.evaluate_url_safety(
                "https://image.tmdb.org/t/p/w500/x.jpg",
                {"image.tmdb.org"},
                block_private=True,
            )

        self.assertFalse(diagnosis.allowed)
        self.assertIs(diagnosis.reason, UrlSafetyReason.NON_GLOBAL_DNS_RESULT)
        self.assertEqual(diagnosis.host, "image.tmdb.org")
        self.assertEqual(diagnosis.ips, ["198.18.16.96"])
        self.assertEqual(diagnosis.matched_private_ranges, [])

    def test_mixed_private_and_public_with_ranges_reports_mixed_reason(self):
        """
        With allowed_private_ranges configured, a result containing public
        addresses or private addresses outside those ranges must be diagnosed as
        MIXED_OR_DISALLOWED_PRIVATE_RESULT, so it is not confused with the
        "no allowed ranges configured" case.
        """
        resolved = [
            (socket.AF_INET, socket.SOCK_STREAM, 0, "", ("198.18.16.96", 0)),
            (socket.AF_INET, socket.SOCK_STREAM, 0, "", ("10.0.0.8", 0)),
        ]
        with patch("app.utils.security.socket.getaddrinfo", return_value=resolved):
            diagnosis = SecurityUtils.evaluate_url_safety(
                "https://image.tmdb.org/t/p/w500/x.jpg",
                {"image.tmdb.org"},
                block_private=True,
                allowed_private_ranges=["198.18.0.0/15"],
            )

        self.assertFalse(diagnosis.allowed)
        self.assertIs(
            diagnosis.reason,
            UrlSafetyReason.MIXED_OR_DISALLOWED_PRIVATE_RESULT,
        )
        self.assertEqual(diagnosis.ips, ["198.18.16.96", "10.0.0.8"])

    def test_allowed_via_configured_private_range_reports_matched_networks(self):
        """
        A URL let through by allowed_private_ranges is diagnosed as ALLOWED, and
        the matching IPs and ranges are recorded on the diagnosis so the logs can
        show why it was allowed.
        """
        resolved = [
            (socket.AF_INET, socket.SOCK_STREAM, 0, "", ("198.18.16.96", 0)),
        ]
        with patch("app.utils.security.socket.getaddrinfo", return_value=resolved):
            diagnosis = SecurityUtils.evaluate_url_safety(
                "https://image.tmdb.org/t/p/w500/x.jpg",
                {"image.tmdb.org"},
                block_private=True,
                allowed_private_ranges=["198.18.0.0/15"],
            )

        self.assertTrue(diagnosis.allowed)
        self.assertIs(diagnosis.reason, UrlSafetyReason.ALLOWED)
        self.assertEqual(diagnosis.ips, ["198.18.16.96"])
        self.assertEqual(diagnosis.matched_private_ranges, ["198.18.0.0/15"])

    def test_async_evaluation_returns_same_diagnosis(self):
        """
        The async variant, run here with the running loop's `getaddrinfo`
        patched, must produce the same diagnosis as the synchronous one.
        """
        import asyncio

        async def fake_getaddrinfo(host, *_args, **_kwargs):
            return [
                (socket.AF_INET, socket.SOCK_STREAM, 0, "", ("198.18.16.96", 0)),
            ]

        async def evaluate():
            loop_patch = patch.object(
                asyncio.get_running_loop(),
                "getaddrinfo",
                side_effect=fake_getaddrinfo,
                create=True,
            )
            with loop_patch:
                return await SecurityUtils.evaluate_url_safety_async(
                    "https://image.tmdb.org/x.jpg",
                    {"image.tmdb.org"},
                    block_private=True,
                )

        diagnosis = asyncio.run(evaluate())
        self.assertFalse(diagnosis.allowed)
        self.assertIs(diagnosis.reason, UrlSafetyReason.NON_GLOBAL_DNS_RESULT)
        self.assertEqual(diagnosis.ips, ["198.18.16.96"])
|
||||
|
||||
64
tests/test_site_cookie_endpoint.py
Normal file
64
tests/test_site_cookie_endpoint.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from app import schemas
|
||||
from app.api.endpoints import site as site_endpoint
|
||||
|
||||
|
||||
def test_update_cookie_by_body_uses_request_body():
    """
    The POST cookie-update endpoint must take its login parameters from the
    request body and forward them to ``SiteChain.update_cookie``.
    """
    site_stub = SimpleNamespace(id=1, name="TestSite")
    chain_stub = Mock()
    chain_stub.update_cookie.return_value = (True, "ok")
    payload = schemas.SiteCookieUpdate(
        username="user",
        password="password",
        code="123456",
    )

    # Site lookup and chain construction are both replaced with stubs.
    site_lookup = patch.object(site_endpoint.Site, "get", return_value=site_stub)
    chain_factory = patch.object(site_endpoint, "SiteChain", return_value=chain_stub)
    with site_lookup, chain_factory:
        result = site_endpoint.update_cookie_by_body(
            site_id=1,
            site_cookie_update=payload,
            db=Mock(),
            _=Mock(),
        )

    assert result.success is True
    assert result.message == "ok"
    # The body's `code` field is expected to arrive as `two_step_code`.
    chain_stub.update_cookie.assert_called_once_with(
        site_info=site_stub,
        username="user",
        password="password",
        two_step_code="123456",
    )
|
||||
|
||||
|
||||
def test_update_cookie_legacy_get_keeps_query_params():
    """
    The legacy GET entry point must keep accepting query parameters when
    updating a site cookie, and must report the chain's failure result.
    """
    site_stub = SimpleNamespace(id=1, name="TestSite")
    chain_stub = Mock()
    chain_stub.update_cookie.return_value = (False, "failed")

    # Site lookup and chain construction are both replaced with stubs.
    site_lookup = patch.object(site_endpoint.Site, "get", return_value=site_stub)
    chain_factory = patch.object(site_endpoint, "SiteChain", return_value=chain_stub)
    with site_lookup, chain_factory:
        result = site_endpoint.update_cookie(
            site_id=1,
            username="user",
            password="password",
            code=None,
            db=Mock(),
            _=Mock(),
        )

    assert result.success is False
    assert result.message == "failed"
    # A missing `code` is expected to be forwarded as `two_step_code=None`.
    chain_stub.update_cookie.assert_called_once_with(
        site_info=site_stub,
        username="user",
        password="password",
        two_step_code=None,
    )
|
||||
@@ -1,2 +1,2 @@
|
||||
APP_VERSION = 'v2.13.2'
|
||||
FRONTEND_VERSION = 'v2.13.2'
|
||||
APP_VERSION = 'v2.13.3'
|
||||
FRONTEND_VERSION = 'v2.13.3'
|
||||
|
||||
Reference in New Issue
Block a user