mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-03 07:26:51 +00:00
chore: update text processing dependencies
This commit is contained in:
@@ -2,11 +2,9 @@
|
||||
|
||||
import json
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated, Any, Literal, Union, NotRequired
|
||||
from typing import Annotated, Any, NotRequired
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
ModelRequest,
|
||||
@@ -16,78 +14,18 @@ from langchain.agents.middleware.types import (
|
||||
from langchain.agents.middleware.types import (
|
||||
PrivateStateAttr, # noqa
|
||||
)
|
||||
from langchain.agents.middleware.tool_selection import (
|
||||
DEFAULT_SYSTEM_PROMPT,
|
||||
LLMToolSelectorMiddleware,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.runtime import Runtime
|
||||
from pydantic import Field, TypeAdapter
|
||||
from typing_extensions import TypedDict # noqa
|
||||
|
||||
from app.log import logger
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = (
|
||||
"Your goal is to select the most relevant tools for answering the user's query."
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _SelectionRequest:
|
||||
"""Prepared inputs for tool selection."""
|
||||
|
||||
available_tools: list[BaseTool]
|
||||
system_message: str
|
||||
last_user_message: HumanMessage
|
||||
model: BaseChatModel
|
||||
valid_tool_names: list[str]
|
||||
|
||||
|
||||
def _create_tool_selection_response(tools: list[BaseTool]) -> TypeAdapter[Any]:
|
||||
"""Create a structured output schema for tool selection.
|
||||
|
||||
Args:
|
||||
tools: Available tools to include in the schema.
|
||||
|
||||
Returns:
|
||||
`TypeAdapter` for a schema where each tool name is a `Literal` with its
|
||||
description.
|
||||
|
||||
Raises:
|
||||
AssertionError: If `tools` is empty.
|
||||
"""
|
||||
if not tools:
|
||||
msg = "Invalid usage: tools must be non-empty"
|
||||
raise AssertionError(msg)
|
||||
|
||||
# Create a Union of Annotated Literal types for each tool name with description
|
||||
# For instance: Union[Annotated[Literal["tool1"], Field(description="...")], ...]
|
||||
literals = [
|
||||
Annotated[Literal[tool.name], Field(description=tool.description)]
|
||||
for tool in tools # noqa
|
||||
]
|
||||
selected_tool_type = Union[tuple(literals)] # type: ignore[valid-type] # noqa: UP007
|
||||
|
||||
description = "Tools to use. Place the most relevant tools first."
|
||||
|
||||
class ToolSelectionResponse(TypedDict):
|
||||
"""Use to select relevant tools."""
|
||||
|
||||
tools: Annotated[list[selected_tool_type], Field(description=description)] # type: ignore[valid-type]
|
||||
|
||||
return TypeAdapter(ToolSelectionResponse)
|
||||
|
||||
|
||||
def _render_tool_list(tools: list[BaseTool]) -> str:
|
||||
"""Format tools as markdown list.
|
||||
|
||||
Args:
|
||||
tools: Tools to format.
|
||||
|
||||
Returns:
|
||||
Markdown string with each tool on a new line.
|
||||
"""
|
||||
return "\n".join(f"- {tool.name}: {tool.description}" for tool in tools)
|
||||
|
||||
|
||||
class ToolSelectionState(AgentState):
|
||||
"""工具筛选中间件私有状态。"""
|
||||
@@ -102,9 +40,7 @@ class ToolSelectionStateUpdate(TypedDict):
|
||||
selected_tool_names: list[str] | None
|
||||
|
||||
|
||||
class ToolSelectorMiddleware(
|
||||
AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]
|
||||
):
|
||||
class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
|
||||
"""
|
||||
为 DeepSeek 兼容端点提供更稳妥的工具筛选实现。
|
||||
|
||||
@@ -129,94 +65,19 @@ class ToolSelectorMiddleware(
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: BaseChatModel,
|
||||
model: BaseChatModel | str | None = None,
|
||||
system_prompt: str = DEFAULT_SYSTEM_PROMPT,
|
||||
selection_tools: list[Any] | None = None,
|
||||
max_tools: int | None = None,
|
||||
always_include: list[str] | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.system_prompt = system_prompt
|
||||
self.max_tools = max_tools
|
||||
self.always_include = always_include or []
|
||||
self.selection_tools = selection_tools or []
|
||||
|
||||
def _prepare_selection_request(
|
||||
self, request: ModelRequest[ContextT]
|
||||
) -> _SelectionRequest | None:
|
||||
"""Prepare inputs for tool selection.
|
||||
|
||||
Args:
|
||||
request: the model request.
|
||||
|
||||
Returns:
|
||||
`SelectionRequest` with prepared inputs, or `None` if no selection is
|
||||
needed.
|
||||
|
||||
Raises:
|
||||
ValueError: If tools in `always_include` are not found in the request.
|
||||
AssertionError: If no user message is found in the request messages.
|
||||
"""
|
||||
# If no tools available, return None
|
||||
if not request.tools or len(request.tools) == 0:
|
||||
return None
|
||||
|
||||
# Filter to only BaseTool instances (exclude provider-specific tool dicts)
|
||||
base_tools = [tool for tool in request.tools if not isinstance(tool, dict)]
|
||||
|
||||
# Validate that always_include tools exist
|
||||
if self.always_include:
|
||||
available_tool_names = {tool.name for tool in base_tools}
|
||||
missing_tools = [
|
||||
name for name in self.always_include if name not in available_tool_names
|
||||
]
|
||||
if missing_tools:
|
||||
msg = (
|
||||
f"Tools in always_include not found in request: {missing_tools}. "
|
||||
f"Available tools: {sorted(available_tool_names)}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
# Separate tools that are always included from those available for selection
|
||||
available_tools = [
|
||||
tool for tool in base_tools if tool.name not in self.always_include
|
||||
]
|
||||
|
||||
# If no tools available for selection, return None
|
||||
if not available_tools:
|
||||
return None
|
||||
|
||||
system_message = self.system_prompt
|
||||
# If there's a max_tools limit, append instructions to the system prompt
|
||||
if self.max_tools is not None:
|
||||
system_message += (
|
||||
f"\nIMPORTANT: List the tool names in order of relevance, "
|
||||
f"with the most relevant first. "
|
||||
f"If you exceed the maximum number of tools, "
|
||||
f"only the first {self.max_tools} will be used."
|
||||
)
|
||||
|
||||
# Get the last user message from the conversation history
|
||||
last_user_message: HumanMessage
|
||||
for message in reversed(request.messages):
|
||||
if isinstance(message, HumanMessage):
|
||||
last_user_message = message
|
||||
break
|
||||
else:
|
||||
msg = "No user message found in request messages"
|
||||
raise AssertionError(msg)
|
||||
|
||||
model = self.model or request.model
|
||||
valid_tool_names = [tool.name for tool in available_tools]
|
||||
|
||||
return _SelectionRequest(
|
||||
available_tools=available_tools,
|
||||
system_message=system_message,
|
||||
last_user_message=last_user_message,
|
||||
super().__init__(
|
||||
model=model,
|
||||
valid_tool_names=valid_tool_names,
|
||||
system_prompt=system_prompt,
|
||||
max_tools=max_tools,
|
||||
always_include=always_include,
|
||||
)
|
||||
self.selection_tools = selection_tools or []
|
||||
|
||||
def _process_selection_response(
|
||||
self,
|
||||
@@ -225,46 +86,29 @@ class ToolSelectorMiddleware(
|
||||
valid_tool_names: list[str],
|
||||
request: ModelRequest[ContextT],
|
||||
) -> ModelRequest[ContextT]:
|
||||
"""Process the selection response and return filtered `ModelRequest`."""
|
||||
selected_tool_names: list[str] = []
|
||||
invalid_tool_selections = []
|
||||
|
||||
for tool_name in response["tools"]:
|
||||
if tool_name not in valid_tool_names:
|
||||
invalid_tool_selections.append(tool_name)
|
||||
continue
|
||||
|
||||
# Only add if not already selected and within max_tools limit
|
||||
if tool_name not in selected_tool_names and (
|
||||
self.max_tools is None or len(selected_tool_names) < self.max_tools
|
||||
):
|
||||
selected_tool_names.append(tool_name)
|
||||
|
||||
if invalid_tool_selections:
|
||||
msg = f"Model selected invalid tools: {invalid_tool_selections}"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Filter tools based on selection and append always-included tools
|
||||
if selected_tool_names:
|
||||
selected_tools: list[BaseTool] = [
|
||||
tool for tool in available_tools if tool.name in selected_tool_names
|
||||
]
|
||||
else:
|
||||
# 如果模型筛选结果为空,则不对工具进行裁剪,使用所有可用工具
|
||||
"""
|
||||
处理工具筛选响应,并保留空结果回退所有工具的 MoviePilot 策略。
|
||||
"""
|
||||
if response.get("tools") == []:
|
||||
logger.warning("工具筛选结果为空,将恢复使用所有工具。")
|
||||
selected_tools = available_tools
|
||||
|
||||
always_included_tools: list[BaseTool] = [
|
||||
tool
|
||||
for tool in request.tools
|
||||
if not isinstance(tool, dict) and tool.name in self.always_include
|
||||
]
|
||||
selected_tools.extend(always_included_tools)
|
||||
always_included_tools: list[BaseTool] = [
|
||||
tool
|
||||
for tool in request.tools
|
||||
if not isinstance(tool, dict) and tool.name in self.always_include
|
||||
]
|
||||
provider_tools = [tool for tool in request.tools if isinstance(tool, dict)]
|
||||
|
||||
# Also preserve any provider-specific tool dicts from the original request
|
||||
provider_tools = [tool for tool in request.tools if isinstance(tool, dict)]
|
||||
return request.override(
|
||||
tools=[*available_tools, *always_included_tools, *provider_tools]
|
||||
)
|
||||
|
||||
return request.override(tools=[*selected_tools, *provider_tools])
|
||||
return super()._process_selection_response(
|
||||
response,
|
||||
available_tools,
|
||||
valid_tool_names,
|
||||
request,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _is_deepseek_compatible_model(model: BaseChatModel) -> bool:
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
import jieba
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
@@ -11,6 +10,7 @@ from app.db import AsyncSessionFactory
|
||||
from app.db.models.transferhistory import TransferHistory
|
||||
from app.log import logger
|
||||
from app.schemas.types import media_type_to_agent
|
||||
from app.utils.jieba import cut as jieba_cut
|
||||
|
||||
|
||||
class QueryTransferHistoryInput(BaseModel):
|
||||
@@ -69,8 +69,8 @@ class QueryTransferHistoryTool(MoviePilotTool):
|
||||
async with AsyncSessionFactory() as db:
|
||||
# 处理标题搜索
|
||||
if title:
|
||||
# 使用 jieba 分词处理标题
|
||||
words = jieba.cut(title, HMM=False)
|
||||
# 使用 fast-jieba 分词处理标题。
|
||||
words = jieba_cut(title, HMM=False)
|
||||
title_search = "%".join(words)
|
||||
# 查询记录
|
||||
result = await TransferHistory.async_list_by_title(
|
||||
|
||||
@@ -3,7 +3,6 @@ import time
|
||||
from pathlib import Path
|
||||
from typing import List, Any, Optional
|
||||
|
||||
import jieba
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -24,6 +23,7 @@ from app.db.user_oper import (
|
||||
)
|
||||
from app.helper.progress import ProgressHelper
|
||||
from app.schemas.types import EventType
|
||||
from app.utils.jieba import cut as jieba_cut
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -272,7 +272,7 @@ async def transfer_history(
|
||||
db, title=like_pattern, page=page, count=count, status=status, wildcard=True
|
||||
)
|
||||
else:
|
||||
words = jieba.cut(title, HMM=False)
|
||||
words = jieba_cut(title, HMM=False)
|
||||
like_pattern = "%".join(words)
|
||||
total = await TransferHistory.async_count_by_title(
|
||||
db, title=like_pattern, status=status
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
import re
|
||||
import traceback
|
||||
|
||||
import zhconv
|
||||
import anitopy
|
||||
from app.core.meta.customization import CustomizationMatcher
|
||||
from app.core.meta.metabase import MetaBase
|
||||
from app.core.meta.releasegroup import ReleaseGroupsMatcher
|
||||
from app.log import logger
|
||||
from app.utils.string import StringUtils
|
||||
from app.utils.zhconv import convert as zhconv_convert
|
||||
from app.schemas.types import MediaType
|
||||
|
||||
|
||||
@@ -219,7 +219,7 @@ class MetaAnime(MetaBase):
|
||||
# 截掉分类
|
||||
first_item = title.split(']')[0]
|
||||
if first_item and re.search(r"[动漫画纪录片电影视连续剧集日美韩中港台海外亚洲华语大陆综艺原盘高清]{2,}|TV|Animation|Movie|Documentar|Anime",
|
||||
zhconv.convert(first_item, "zh-hans"),
|
||||
zhconv_convert(first_item, "zh-hans"),
|
||||
re.IGNORECASE):
|
||||
title = re.sub(r"^[^]]*]", "", title).strip()
|
||||
# 去掉大小
|
||||
|
||||
@@ -2,7 +2,6 @@ import re
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import cn2an
|
||||
import zhconv
|
||||
|
||||
from app import schemas
|
||||
from app.core.config import settings
|
||||
@@ -19,6 +18,7 @@ from app.schemas.types import MediaType, ModuleType, MediaRecognizeType
|
||||
from app.utils.common import retry
|
||||
from app.utils.http import RequestUtils
|
||||
from app.utils.limit import rate_limit_exponential
|
||||
from app.utils.zhconv import convert as zhconv_convert
|
||||
|
||||
|
||||
class DoubanModule(_ModuleBase):
|
||||
@@ -77,7 +77,7 @@ class DoubanModule(_ModuleBase):
|
||||
准备搜索名称列表,保留中英文名称分别识别且按顺序去重的历史行为。
|
||||
"""
|
||||
# 简体名称
|
||||
zh_name = zhconv.convert(meta.cn_name, "zh-hans") if meta.cn_name else None
|
||||
zh_name = zhconv_convert(meta.cn_name, "zh-hans") if meta.cn_name else None
|
||||
# 使用中英文名分别识别,去重去空,但要保持顺序
|
||||
return list(dict.fromkeys([k for k in [meta.cn_name, zh_name, meta.en_name] if k]))
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ import re
|
||||
from typing import Optional, List, Tuple, Union, Dict
|
||||
|
||||
import cn2an
|
||||
import zhconv
|
||||
|
||||
from app import schemas
|
||||
from app.core.config import settings
|
||||
@@ -17,6 +16,7 @@ from app.modules.themoviedb.tmdbapi import TmdbApi
|
||||
from app.schemas.category import CategoryConfig
|
||||
from app.schemas.types import MediaType, MediaImageType, ModuleType, MediaRecognizeType
|
||||
from app.utils.http import RequestUtils
|
||||
from app.utils.zhconv import convert as zhconv_convert
|
||||
|
||||
|
||||
_DATE_RE = re.compile(r"^\d{4}-\d{2}-\d{2}$")
|
||||
@@ -116,7 +116,7 @@ class TheMovieDbModule(_ModuleBase):
|
||||
准备搜索名称列表
|
||||
"""
|
||||
# 简体名称
|
||||
zh_name = zhconv.convert(meta.cn_name, "zh-hans") if meta.cn_name else None
|
||||
zh_name = zhconv_convert(meta.cn_name, "zh-hans") if meta.cn_name else None
|
||||
# 使用中英文名分别识别,去重去空,但要保持顺序
|
||||
return list(dict.fromkeys([k for k in [meta.cn_name, zh_name, meta.en_name] if k]))
|
||||
|
||||
|
||||
@@ -2,12 +2,11 @@ import re
|
||||
import traceback
|
||||
from typing import Optional, List
|
||||
|
||||
import zhconv
|
||||
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType
|
||||
from app.utils.string import StringUtils
|
||||
from app.utils.zhconv import convert as zhconv_convert
|
||||
from .tmdbv3api import TMDb, Search, Movie, TV, Season, Episode, Discover, Trending, Person, Collection
|
||||
from .tmdbv3api.exceptions import TMDbException
|
||||
|
||||
@@ -726,7 +725,7 @@ class TmdbApi:
|
||||
if iso_3166_1 == "CN":
|
||||
title = alternative_title.get("title")
|
||||
if title and StringUtils.is_chinese(title) \
|
||||
and zhconv.convert(title, "zh-hans") == title:
|
||||
and zhconv_convert(title, "zh-hans") == title:
|
||||
return title
|
||||
return tmdbinfo.get("title") if tmdbinfo.get("media_type") == MediaType.MOVIE else tmdbinfo.get("name")
|
||||
|
||||
|
||||
10
app/utils/jieba.py
Normal file
10
app/utils/jieba.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""中文分词工具。"""
|
||||
|
||||
from fast_jieba import cut as fast_jieba_cut
|
||||
|
||||
|
||||
def cut(text: str, HMM: bool = True, cut_all: bool = False) -> list[str]:
|
||||
"""
|
||||
使用 fast-jieba 执行中文分词,并兼容 jieba.cut 的常用参数名。
|
||||
"""
|
||||
return fast_jieba_cut(text, hmm=HMM, cut_all=cut_all)
|
||||
10
app/utils/zhconv.py
Normal file
10
app/utils/zhconv.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""中文简繁转换工具。"""
|
||||
|
||||
from zhconv_rs import zhconv as _zhconv # pylint: disable=no-name-in-module
|
||||
|
||||
|
||||
def convert(text: str, target: str) -> str:
|
||||
"""
|
||||
使用 zhconv-rs 执行中文简繁转换,并隔离第三方包的函数名差异。
|
||||
"""
|
||||
return _zhconv(text, target)
|
||||
@@ -14,10 +14,10 @@ alembic~=1.16.2
|
||||
anyio~=4.10.0
|
||||
bcrypt~=4.0.1
|
||||
regex~=2024.11.6
|
||||
cn2an~=0.5.19
|
||||
cn2an~=0.5.24
|
||||
dateparser~=1.2.2
|
||||
python-dateutil~=2.8.2
|
||||
zhconv~=1.4.3
|
||||
zhconv-rs~=0.4.1
|
||||
anitopy~=2.1.1
|
||||
requests[socks]~=2.32.4
|
||||
urllib3~=2.5.0
|
||||
@@ -41,6 +41,7 @@ pyTelegramBotAPI~=4.27.0
|
||||
telegramify-markdown~=0.5.2
|
||||
cloakbrowser~=0.3.28
|
||||
torrentool~=1.2.0
|
||||
fast-bencode~=1.1.7
|
||||
slack-bolt~=1.23.0
|
||||
slack-sdk~=3.35.0
|
||||
discord.py==2.6.4
|
||||
@@ -63,7 +64,7 @@ pywebpush~=2.0.3
|
||||
aiosqlite~=0.21.0
|
||||
psycopg2-binary~=2.9.10
|
||||
asyncpg~=0.30.0
|
||||
jieba~=0.42.1
|
||||
fast-jieba~=0.4.0
|
||||
rsa~=4.9
|
||||
redis~=6.2.0
|
||||
async_timeout~=5.0.1; python_full_version < "3.11.3"
|
||||
@@ -75,17 +76,17 @@ pympler~=1.1
|
||||
smbprotocol~=1.15.0
|
||||
setproctitle~=1.3.6
|
||||
httpx[socks,http2]~=0.28.1
|
||||
langchain~=1.2.15
|
||||
langchain-core~=1.3.2
|
||||
langchain-community~=0.4.1
|
||||
langchain-anthropic~=1.4.2
|
||||
langchain-openai~=1.2.1
|
||||
langchain-google-genai~=4.2.2
|
||||
langchain~=1.3.1
|
||||
langchain-core~=1.4.0
|
||||
langchain-community~=0.4.2
|
||||
langchain-anthropic~=1.4.3
|
||||
langchain-openai~=1.2.2
|
||||
langchain-google-genai~=4.2.3
|
||||
langchain-deepseek~=1.0.1
|
||||
langgraph~=1.1.9
|
||||
anthropic>=0.57,<1
|
||||
openai~=2.33.0
|
||||
google-genai~=1.74.0
|
||||
langgraph~=1.2.1
|
||||
anthropic~=0.104.1
|
||||
openai~=2.38.0
|
||||
google-genai~=1.75.0
|
||||
ddgs~=9.10.0
|
||||
websocket-client~=1.8.0
|
||||
lark-oapi~=1.4.23
|
||||
|
||||
9
tests/test_fast_jieba_utils.py
Normal file
9
tests/test_fast_jieba_utils.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from app.utils.jieba import cut
|
||||
|
||||
|
||||
def test_cut_accepts_legacy_hmm_argument():
|
||||
"""验证兼容封装仍支持旧 jieba.cut 的 HMM 参数名。"""
|
||||
words = cut("台湾后台测试", HMM=False)
|
||||
|
||||
assert "".join(words) == "台湾后台测试"
|
||||
assert "后台" in words
|
||||
@@ -10,7 +10,6 @@ from unittest.mock import ANY, MagicMock, patch
|
||||
sys.modules.setdefault("psutil", ModuleType("psutil"))
|
||||
sys.modules.setdefault("cn2an", ModuleType("cn2an"))
|
||||
sys.modules.setdefault("dateparser", ModuleType("dateparser"))
|
||||
sys.modules.setdefault("zhconv", ModuleType("zhconv"))
|
||||
|
||||
if "Pinyin2Hanzi" not in sys.modules:
|
||||
pinyin_module = ModuleType("Pinyin2Hanzi")
|
||||
|
||||
Reference in New Issue
Block a user