chore: update text processing dependencies

This commit is contained in:
jxxghp
2026-05-23 11:51:57 +08:00
parent 5f0ae3a75e
commit 00fc8b2f53
12 changed files with 87 additions and 215 deletions

View File

@@ -2,11 +2,9 @@
import json
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import Annotated, Any, Literal, Union, NotRequired
from typing import Annotated, Any, NotRequired
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ContextT,
ModelRequest,
@@ -16,78 +14,18 @@ from langchain.agents.middleware.types import (
from langchain.agents.middleware.types import (
PrivateStateAttr, # noqa
)
from langchain.agents.middleware.tool_selection import (
DEFAULT_SYSTEM_PROMPT,
LLMToolSelectorMiddleware,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import BaseTool
from langgraph.runtime import Runtime
from pydantic import Field, TypeAdapter
from typing_extensions import TypedDict # noqa
from app.log import logger
DEFAULT_SYSTEM_PROMPT = (
"Your goal is to select the most relevant tools for answering the user's query."
)
@dataclass
class _SelectionRequest:
    """Bundle of everything a single tool-selection call needs."""

    # Candidate tools the selector model may choose from.
    available_tools: list[BaseTool]
    # System prompt text sent with the selection request.
    system_message: str
    # Most recent human message in the conversation.
    last_user_message: HumanMessage
    # Chat model that performs the selection.
    model: BaseChatModel
    # Names of the candidate tools; model selections are checked against them.
    valid_tool_names: list[str]
def _create_tool_selection_response(tools: list[BaseTool]) -> TypeAdapter[Any]:
    """Build the structured-output schema used to ask a model for tool names.

    Args:
        tools: Tools that may be selected; must contain at least one entry.

    Returns:
        A `TypeAdapter` over a `TypedDict` with a single `tools` list whose
        items are restricted to the given tool names, each carrying the
        tool's description.

    Raises:
        AssertionError: If `tools` is empty.
    """
    if not tools:
        raise AssertionError("Invalid usage: tools must be non-empty")

    # One Annotated[Literal[name], Field(description=...)] per tool, combined
    # into a Union so the schema only admits known tool names.
    name_literals = tuple(
        Annotated[Literal[tool.name], Field(description=tool.description)]
        for tool in tools
    )
    selected_tool_type = Union[name_literals]  # type: ignore[valid-type]  # noqa: UP007
    description = "Tools to use. Place the most relevant tools first."

    class ToolSelectionResponse(TypedDict):
        """Use to select relevant tools."""

        tools: Annotated[list[selected_tool_type], Field(description=description)]  # type: ignore[valid-type]

    return TypeAdapter(ToolSelectionResponse)
def _render_tool_list(tools: list[BaseTool]) -> str:
    """Render tools as a markdown bullet list.

    Args:
        tools: Tools to render.

    Returns:
        One `- name: description` line per tool, joined with newlines.
    """
    bullet_lines = [f"- {tool.name}: {tool.description}" for tool in tools]
    return "\n".join(bullet_lines)
class ToolSelectionState(AgentState):
"""工具筛选中间件私有状态。"""
@@ -102,9 +40,7 @@ class ToolSelectionStateUpdate(TypedDict):
selected_tool_names: list[str] | None
class ToolSelectorMiddleware(
AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]
):
class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
"""
为 DeepSeek 兼容端点提供更稳妥的工具筛选实现。
@@ -129,94 +65,19 @@ class ToolSelectorMiddleware(
def __init__(
self,
model: BaseChatModel,
model: BaseChatModel | str | None = None,
system_prompt: str = DEFAULT_SYSTEM_PROMPT,
selection_tools: list[Any] | None = None,
max_tools: int | None = None,
always_include: list[str] | None = None,
) -> None:
super().__init__()
self.model = model
self.system_prompt = system_prompt
self.max_tools = max_tools
self.always_include = always_include or []
self.selection_tools = selection_tools or []
def _prepare_selection_request(
self, request: ModelRequest[ContextT]
) -> _SelectionRequest | None:
"""Prepare inputs for tool selection.
Args:
request: the model request.
Returns:
`SelectionRequest` with prepared inputs, or `None` if no selection is
needed.
Raises:
ValueError: If tools in `always_include` are not found in the request.
AssertionError: If no user message is found in the request messages.
"""
# If no tools available, return None
if not request.tools or len(request.tools) == 0:
return None
# Filter to only BaseTool instances (exclude provider-specific tool dicts)
base_tools = [tool for tool in request.tools if not isinstance(tool, dict)]
# Validate that always_include tools exist
if self.always_include:
available_tool_names = {tool.name for tool in base_tools}
missing_tools = [
name for name in self.always_include if name not in available_tool_names
]
if missing_tools:
msg = (
f"Tools in always_include not found in request: {missing_tools}. "
f"Available tools: {sorted(available_tool_names)}"
)
raise ValueError(msg)
# Separate tools that are always included from those available for selection
available_tools = [
tool for tool in base_tools if tool.name not in self.always_include
]
# If no tools available for selection, return None
if not available_tools:
return None
system_message = self.system_prompt
# If there's a max_tools limit, append instructions to the system prompt
if self.max_tools is not None:
system_message += (
f"\nIMPORTANT: List the tool names in order of relevance, "
f"with the most relevant first. "
f"If you exceed the maximum number of tools, "
f"only the first {self.max_tools} will be used."
)
# Get the last user message from the conversation history
last_user_message: HumanMessage
for message in reversed(request.messages):
if isinstance(message, HumanMessage):
last_user_message = message
break
else:
msg = "No user message found in request messages"
raise AssertionError(msg)
model = self.model or request.model
valid_tool_names = [tool.name for tool in available_tools]
return _SelectionRequest(
available_tools=available_tools,
system_message=system_message,
last_user_message=last_user_message,
super().__init__(
model=model,
valid_tool_names=valid_tool_names,
system_prompt=system_prompt,
max_tools=max_tools,
always_include=always_include,
)
self.selection_tools = selection_tools or []
def _process_selection_response(
self,
@@ -225,46 +86,29 @@ class ToolSelectorMiddleware(
valid_tool_names: list[str],
request: ModelRequest[ContextT],
) -> ModelRequest[ContextT]:
"""Process the selection response and return filtered `ModelRequest`."""
selected_tool_names: list[str] = []
invalid_tool_selections = []
for tool_name in response["tools"]:
if tool_name not in valid_tool_names:
invalid_tool_selections.append(tool_name)
continue
# Only add if not already selected and within max_tools limit
if tool_name not in selected_tool_names and (
self.max_tools is None or len(selected_tool_names) < self.max_tools
):
selected_tool_names.append(tool_name)
if invalid_tool_selections:
msg = f"Model selected invalid tools: {invalid_tool_selections}"
raise ValueError(msg)
# Filter tools based on selection and append always-included tools
if selected_tool_names:
selected_tools: list[BaseTool] = [
tool for tool in available_tools if tool.name in selected_tool_names
]
else:
# 如果模型筛选结果为空,则不对工具进行裁剪,使用所有可用工具
"""
处理工具筛选响应,并保留空结果回退所有工具的 MoviePilot 策略。
"""
if response.get("tools") == []:
logger.warning("工具筛选结果为空,将恢复使用所有工具。")
selected_tools = available_tools
always_included_tools: list[BaseTool] = [
tool
for tool in request.tools
if not isinstance(tool, dict) and tool.name in self.always_include
]
selected_tools.extend(always_included_tools)
always_included_tools: list[BaseTool] = [
tool
for tool in request.tools
if not isinstance(tool, dict) and tool.name in self.always_include
]
provider_tools = [tool for tool in request.tools if isinstance(tool, dict)]
# Also preserve any provider-specific tool dicts from the original request
provider_tools = [tool for tool in request.tools if isinstance(tool, dict)]
return request.override(
tools=[*available_tools, *always_included_tools, *provider_tools]
)
return request.override(tools=[*selected_tools, *provider_tools])
return super()._process_selection_response(
response,
available_tools,
valid_tool_names,
request,
)
@staticmethod
def _is_deepseek_compatible_model(model: BaseChatModel) -> bool:

View File

@@ -3,7 +3,6 @@
import json
from typing import Optional, Type
import jieba
from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
@@ -11,6 +10,7 @@ from app.db import AsyncSessionFactory
from app.db.models.transferhistory import TransferHistory
from app.log import logger
from app.schemas.types import media_type_to_agent
from app.utils.jieba import cut as jieba_cut
class QueryTransferHistoryInput(BaseModel):
@@ -69,8 +69,8 @@ class QueryTransferHistoryTool(MoviePilotTool):
async with AsyncSessionFactory() as db:
# 处理标题搜索
if title:
# 使用 jieba 分词处理标题
words = jieba.cut(title, HMM=False)
# 使用 fast-jieba 分词处理标题
words = jieba_cut(title, HMM=False)
title_search = "%".join(words)
# 查询记录
result = await TransferHistory.async_list_by_title(

View File

@@ -3,7 +3,6 @@ import time
from pathlib import Path
from typing import List, Any, Optional
import jieba
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
@@ -24,6 +23,7 @@ from app.db.user_oper import (
)
from app.helper.progress import ProgressHelper
from app.schemas.types import EventType
from app.utils.jieba import cut as jieba_cut
router = APIRouter()
@@ -272,7 +272,7 @@ async def transfer_history(
db, title=like_pattern, page=page, count=count, status=status, wildcard=True
)
else:
words = jieba.cut(title, HMM=False)
words = jieba_cut(title, HMM=False)
like_pattern = "%".join(words)
total = await TransferHistory.async_count_by_title(
db, title=like_pattern, status=status

View File

@@ -1,13 +1,13 @@
import re
import traceback
import zhconv
import anitopy
from app.core.meta.customization import CustomizationMatcher
from app.core.meta.metabase import MetaBase
from app.core.meta.releasegroup import ReleaseGroupsMatcher
from app.log import logger
from app.utils.string import StringUtils
from app.utils.zhconv import convert as zhconv_convert
from app.schemas.types import MediaType
@@ -219,7 +219,7 @@ class MetaAnime(MetaBase):
# 截掉分类
first_item = title.split(']')[0]
if first_item and re.search(r"[动漫画纪录片电影视连续剧集日美韩中港台海外亚洲华语大陆综艺原盘高清]{2,}|TV|Animation|Movie|Documentar|Anime",
zhconv.convert(first_item, "zh-hans"),
zhconv_convert(first_item, "zh-hans"),
re.IGNORECASE):
title = re.sub(r"^[^]]*]", "", title).strip()
# 去掉大小

View File

@@ -2,7 +2,6 @@ import re
from typing import List, Optional, Tuple, Union
import cn2an
import zhconv
from app import schemas
from app.core.config import settings
@@ -19,6 +18,7 @@ from app.schemas.types import MediaType, ModuleType, MediaRecognizeType
from app.utils.common import retry
from app.utils.http import RequestUtils
from app.utils.limit import rate_limit_exponential
from app.utils.zhconv import convert as zhconv_convert
class DoubanModule(_ModuleBase):
@@ -77,7 +77,7 @@ class DoubanModule(_ModuleBase):
准备搜索名称列表,保留中英文名称分别识别且按顺序去重的历史行为。
"""
# 简体名称
zh_name = zhconv.convert(meta.cn_name, "zh-hans") if meta.cn_name else None
zh_name = zhconv_convert(meta.cn_name, "zh-hans") if meta.cn_name else None
# 使用中英文名分别识别,去重去空,但要保持顺序
return list(dict.fromkeys([k for k in [meta.cn_name, zh_name, meta.en_name] if k]))

View File

@@ -2,7 +2,6 @@ import re
from typing import Optional, List, Tuple, Union, Dict
import cn2an
import zhconv
from app import schemas
from app.core.config import settings
@@ -17,6 +16,7 @@ from app.modules.themoviedb.tmdbapi import TmdbApi
from app.schemas.category import CategoryConfig
from app.schemas.types import MediaType, MediaImageType, ModuleType, MediaRecognizeType
from app.utils.http import RequestUtils
from app.utils.zhconv import convert as zhconv_convert
_DATE_RE = re.compile(r"^\d{4}-\d{2}-\d{2}$")
@@ -116,7 +116,7 @@ class TheMovieDbModule(_ModuleBase):
准备搜索名称列表
"""
# 简体名称
zh_name = zhconv.convert(meta.cn_name, "zh-hans") if meta.cn_name else None
zh_name = zhconv_convert(meta.cn_name, "zh-hans") if meta.cn_name else None
# 使用中英文名分别识别,去重去空,但要保持顺序
return list(dict.fromkeys([k for k in [meta.cn_name, zh_name, meta.en_name] if k]))

View File

@@ -2,12 +2,11 @@ import re
import traceback
from typing import Optional, List
import zhconv
from app.core.config import settings
from app.log import logger
from app.schemas.types import MediaType
from app.utils.string import StringUtils
from app.utils.zhconv import convert as zhconv_convert
from .tmdbv3api import TMDb, Search, Movie, TV, Season, Episode, Discover, Trending, Person, Collection
from .tmdbv3api.exceptions import TMDbException
@@ -726,7 +725,7 @@ class TmdbApi:
if iso_3166_1 == "CN":
title = alternative_title.get("title")
if title and StringUtils.is_chinese(title) \
and zhconv.convert(title, "zh-hans") == title:
and zhconv_convert(title, "zh-hans") == title:
return title
return tmdbinfo.get("title") if tmdbinfo.get("media_type") == MediaType.MOVIE else tmdbinfo.get("name")

10
app/utils/jieba.py Normal file
View File

@@ -0,0 +1,10 @@
"""中文分词工具。"""
from fast_jieba import cut as fast_jieba_cut
def cut(text: str, HMM: bool = True, cut_all: bool = False) -> list[str]:
    """
    Segment Chinese text with fast-jieba.

    Keeps the upper-case ``HMM`` keyword used by existing ``jieba.cut`` call
    sites and forwards it to fast-jieba as ``hmm``.
    """
    segments = fast_jieba_cut(text, hmm=HMM, cut_all=cut_all)
    return segments

10
app/utils/zhconv.py Normal file
View File

@@ -0,0 +1,10 @@
"""中文简繁转换工具。"""
from zhconv_rs import zhconv as _zhconv # pylint: disable=no-name-in-module
def convert(text: str, target: str) -> str:
    """
    Convert Chinese text to the requested variant using zhconv-rs.

    Wrapping the call keeps the third-party function name out of callers.
    """
    converted = _zhconv(text, target)
    return converted

View File

@@ -14,10 +14,10 @@ alembic~=1.16.2
anyio~=4.10.0
bcrypt~=4.0.1
regex~=2024.11.6
cn2an~=0.5.19
cn2an~=0.5.24
dateparser~=1.2.2
python-dateutil~=2.8.2
zhconv~=1.4.3
zhconv-rs~=0.4.1
anitopy~=2.1.1
requests[socks]~=2.32.4
urllib3~=2.5.0
@@ -41,6 +41,7 @@ pyTelegramBotAPI~=4.27.0
telegramify-markdown~=0.5.2
cloakbrowser~=0.3.28
torrentool~=1.2.0
fast-bencode~=1.1.7
slack-bolt~=1.23.0
slack-sdk~=3.35.0
discord.py==2.6.4
@@ -63,7 +64,7 @@ pywebpush~=2.0.3
aiosqlite~=0.21.0
psycopg2-binary~=2.9.10
asyncpg~=0.30.0
jieba~=0.42.1
fast-jieba~=0.4.0
rsa~=4.9
redis~=6.2.0
async_timeout~=5.0.1; python_full_version < "3.11.3"
@@ -75,17 +76,17 @@ pympler~=1.1
smbprotocol~=1.15.0
setproctitle~=1.3.6
httpx[socks,http2]~=0.28.1
langchain~=1.2.15
langchain-core~=1.3.2
langchain-community~=0.4.1
langchain-anthropic~=1.4.2
langchain-openai~=1.2.1
langchain-google-genai~=4.2.2
langchain~=1.3.1
langchain-core~=1.4.0
langchain-community~=0.4.2
langchain-anthropic~=1.4.3
langchain-openai~=1.2.2
langchain-google-genai~=4.2.3
langchain-deepseek~=1.0.1
langgraph~=1.1.9
anthropic>=0.57,<1
openai~=2.33.0
google-genai~=1.74.0
langgraph~=1.2.1
anthropic~=0.104.1
openai~=2.38.0
google-genai~=1.75.0
ddgs~=9.10.0
websocket-client~=1.8.0
lark-oapi~=1.4.23

View File

@@ -0,0 +1,9 @@
from app.utils.jieba import cut
def test_cut_accepts_legacy_hmm_argument():
    """Check that the wrapper still takes jieba.cut's legacy ``HMM`` keyword."""
    sample = "台湾后台测试"
    segments = cut(sample, HMM=False)

    # Segmentation must be lossless and keep "后台" as a single token.
    assert "".join(segments) == sample
    assert "后台" in segments

View File

@@ -10,7 +10,6 @@ from unittest.mock import ANY, MagicMock, patch
sys.modules.setdefault("psutil", ModuleType("psutil"))
sys.modules.setdefault("cn2an", ModuleType("cn2an"))
sys.modules.setdefault("dateparser", ModuleType("dateparser"))
sys.modules.setdefault("zhconv", ModuleType("zhconv"))
if "Pinyin2Hanzi" not in sys.modules:
pinyin_module = ModuleType("Pinyin2Hanzi")