mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-13 07:26:45 +00:00
Compare commits
248 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
64cb5742d2 | ||
|
|
4601c41794 | ||
|
|
6167e7e6a2 | ||
|
|
a106738de5 | ||
|
|
e0ce11a9d3 | ||
|
|
3052f2cb31 | ||
|
|
7905e622f9 | ||
|
|
3fa5d31d81 | ||
|
|
9e5cb702c5 | ||
|
|
ed380e2a17 | ||
|
|
bc358fc6d2 | ||
|
|
223854d4c6 | ||
|
|
7c73a57bbc | ||
|
|
2b9f5d8d90 | ||
|
|
437baec620 | ||
|
|
1c41d9f253 | ||
|
|
db522e8829 | ||
|
|
e43adf51af | ||
|
|
d353e7b208 | ||
|
|
df732731d9 | ||
|
|
ac5374c244 | ||
|
|
fcdba27a5d | ||
|
|
e4242058e2 | ||
|
|
b7c78da214 | ||
|
|
ba2feb2bfe | ||
|
|
6f014cee14 | ||
|
|
6453935584 | ||
|
|
40d0b60aa2 | ||
|
|
1922cce499 | ||
|
|
c89df496a5 | ||
|
|
855681ff35 | ||
|
|
13b2163788 | ||
|
|
5d3c262e60 | ||
|
|
a5c44a5097 | ||
|
|
16ada1a6c4 | ||
|
|
ac09ce5230 | ||
|
|
2255b61195 | ||
|
|
314ac3903c | ||
|
|
5c3796bf73 | ||
|
|
492e3c333b | ||
|
|
cce72d0884 | ||
|
|
69a064e986 | ||
|
|
f4ca4120bc | ||
|
|
b45956f850 | ||
|
|
762a7fbba7 | ||
|
|
10290ca17b | ||
|
|
12a2561ca8 | ||
|
|
543bee9ad5 | ||
|
|
cc3e062262 | ||
|
|
bf4f5f8744 | ||
|
|
f8f06a602a | ||
|
|
3cb8925e92 | ||
|
|
3ffdf1b38e | ||
|
|
6557b8b9d8 | ||
|
|
2b2e088784 | ||
|
|
d9a06f4433 | ||
|
|
b1259fdc02 | ||
|
|
0e5c592862 | ||
|
|
db3ad91408 | ||
|
|
5b6b4c9744 | ||
|
|
990a28b51b | ||
|
|
b6ffd286fe | ||
|
|
1f7fb304dd | ||
|
|
896631d63e | ||
|
|
db8363fee1 | ||
|
|
31554bdcb5 | ||
|
|
ccbcce0573 | ||
|
|
e00e18f31e | ||
|
|
c7965edd47 | ||
|
|
8aeba8a6d2 | ||
|
|
aee8b05737 | ||
|
|
821bd3decd | ||
|
|
b65c8dcfe0 | ||
|
|
877d89abb3 | ||
|
|
d4718bf9dc | ||
|
|
8bd1288e7e | ||
|
|
a65c5364d9 | ||
|
|
f761e07779 | ||
|
|
91f6ad092e | ||
|
|
c33c62b938 | ||
|
|
05943287c0 | ||
|
|
94633173b1 | ||
|
|
7ab1a668cb | ||
|
|
d57deb1df1 | ||
|
|
d940373f6b | ||
|
|
ca01b8ec3f | ||
|
|
384d6a3fe1 | ||
|
|
922e8473c5 | ||
|
|
01c3451679 | ||
|
|
98e3ea4e6f | ||
|
|
0e8bcb4df6 | ||
|
|
784672af5c | ||
|
|
63b9994b0e | ||
|
|
d713ea54c1 | ||
|
|
766d2699ea | ||
|
|
9af61c4744 | ||
|
|
7c8b973f30 | ||
|
|
0fdf1fadab | ||
|
|
477c49587c | ||
|
|
5532f14efb | ||
|
|
b08c335bb4 | ||
|
|
c7670e5cc8 | ||
|
|
a725789045 | ||
|
|
5d5c95dcd8 | ||
|
|
4d8c910f0d | ||
|
|
4b4b0335e8 | ||
|
|
ac3432c54f | ||
|
|
ea52537423 | ||
|
|
c9bdaf2f40 | ||
|
|
2b629185b9 | ||
|
|
a97e3ea092 | ||
|
|
7af2aa4266 | ||
|
|
1550b75548 | ||
|
|
b7f6ee12ee | ||
|
|
79539760da | ||
|
|
dc73d61682 | ||
|
|
6430b864b4 | ||
|
|
ec588037a0 | ||
|
|
0b7854a0af | ||
|
|
0273adc61c | ||
|
|
d6472088cb | ||
|
|
0c133b7ccd | ||
|
|
0bf228d29d | ||
|
|
a6826e6a4e | ||
|
|
ed0f8c471b | ||
|
|
ad38f51d6b | ||
|
|
d1e2881347 | ||
|
|
222f6ce7d8 | ||
|
|
39d09c2956 | ||
|
|
2b531afe49 | ||
|
|
5a1a6b47a5 | ||
|
|
134c441754 | ||
|
|
00fc8b2f53 | ||
|
|
5f0ae3a75e | ||
|
|
3ebd06a3a7 | ||
|
|
2eb7f57a4c | ||
|
|
7cbfeb2377 | ||
|
|
fcbea077b7 | ||
|
|
da54f3a302 | ||
|
|
efdb4d1b28 | ||
|
|
9190699cd1 | ||
|
|
4f107a7cc8 | ||
|
|
b26bf2a019 | ||
|
|
a74f04a149 | ||
|
|
cde267c55f | ||
|
|
f7b78721c3 | ||
|
|
7e6cd47712 | ||
|
|
4de4044a3e | ||
|
|
052e1ca8e4 | ||
|
|
bd4d493f34 | ||
|
|
7daeb17d85 | ||
|
|
2b5528c0ac | ||
|
|
cb15b711b9 | ||
|
|
9319b47fad | ||
|
|
23487b7ae0 | ||
|
|
fec109712b | ||
|
|
737bcb5c62 | ||
|
|
b6b5529d19 | ||
|
|
2bd4a41cbe | ||
|
|
0245c8db80 | ||
|
|
4c64b1769d | ||
|
|
ee9eced2f1 | ||
|
|
2109d323ae | ||
|
|
fd4d162287 | ||
|
|
617692616c | ||
|
|
014dc2884c | ||
|
|
d37954e6bc | ||
|
|
284c272001 | ||
|
|
0fb9d18b30 | ||
|
|
5d34bc5c56 | ||
|
|
ad7cce72f4 | ||
|
|
c52ccaf75f | ||
|
|
c661bc4764 | ||
|
|
8a375e022c | ||
|
|
7cc037c683 | ||
|
|
068d0af4ca | ||
|
|
8f117d79f2 | ||
|
|
47c4e84fdd | ||
|
|
e00aa42f94 | ||
|
|
72ead2970c | ||
|
|
5fe5523d13 | ||
|
|
3ec0964a01 | ||
|
|
a5745af484 | ||
|
|
c3e4e1a764 | ||
|
|
b07c47551c | ||
|
|
9e0846961f | ||
|
|
71dc9df7ff | ||
|
|
6edb627145 | ||
|
|
07f51c5d94 | ||
|
|
5d02550874 | ||
|
|
2ff6474f0f | ||
|
|
c4eb4d9b95 | ||
|
|
7866aee1de | ||
|
|
cdddd8e080 | ||
|
|
407b60a14f | ||
|
|
b989d08385 | ||
|
|
f46488cb9c | ||
|
|
34ff80e26c | ||
|
|
195e34563d | ||
|
|
29dab5a312 | ||
|
|
9e9c398177 | ||
|
|
1f0eeb25e6 | ||
|
|
3c1ff5242c | ||
|
|
9076acc52e | ||
|
|
f5eeeebeba | ||
|
|
22bb15583d | ||
|
|
bedf06b864 | ||
|
|
cb8636e967 | ||
|
|
36a0d78f08 | ||
|
|
23d6ba0466 | ||
|
|
6685bd0e0e | ||
|
|
c857ae3e14 | ||
|
|
93130baf0a | ||
|
|
3653164924 | ||
|
|
ca0127cc87 | ||
|
|
092666f9d2 | ||
|
|
7b97e2039f | ||
|
|
e168e31a8f | ||
|
|
3ee601574c | ||
|
|
0ee9fec1d2 | ||
|
|
9069dccb2a | ||
|
|
3c055e2482 | ||
|
|
28718094e4 | ||
|
|
9b23265c3b | ||
|
|
9f61bce039 | ||
|
|
1f49f9b454 | ||
|
|
51229204c9 | ||
|
|
2831eecbeb | ||
|
|
b2a18f9ae4 | ||
|
|
5a06e7b8bc | ||
|
|
f303d9e576 | ||
|
|
b76c4edc4a | ||
|
|
41da9b62c2 | ||
|
|
9128955bf9 | ||
|
|
f50773711e | ||
|
|
23784f614b | ||
|
|
7b27b7fd16 | ||
|
|
6834d8b2c7 | ||
|
|
4322f8a3c1 | ||
|
|
0f3a4e4c15 | ||
|
|
f4423e121e | ||
|
|
e5b67438d9 | ||
|
|
7b1ece8b83 | ||
|
|
6d5cda5d51 | ||
|
|
1af3a0ef59 | ||
|
|
5a585839ba | ||
|
|
fcf6e14ac9 | ||
|
|
0959c4ace4 |
@@ -73,6 +73,7 @@ test_*
|
||||
build/
|
||||
dist/
|
||||
*.egg-info/
|
||||
rust/**/target/
|
||||
|
||||
# Docker
|
||||
Dockerfile*
|
||||
@@ -81,4 +82,4 @@ docker-compose*
|
||||
|
||||
# Other
|
||||
app.ico
|
||||
frozen.spec
|
||||
frozen.spec
|
||||
|
||||
6
.github/workflows/beta.yml
vendored
6
.github/workflows/beta.yml
vendored
@@ -46,7 +46,7 @@ jobs:
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Build Image
|
||||
uses: docker/build-push-action@v5
|
||||
uses: docker/build-push-action@v7
|
||||
with:
|
||||
context: .
|
||||
file: docker/Dockerfile
|
||||
@@ -56,5 +56,5 @@ jobs:
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
cache-from: type=gha, scope=${{ github.workflow }}-docker
|
||||
cache-to: type=gha, scope=${{ github.workflow }}-docker
|
||||
cache-from: type=gha,scope=moviepilot-docker,version=2
|
||||
cache-to: type=gha,scope=moviepilot-docker,mode=max,version=2
|
||||
|
||||
6
.github/workflows/build.yml
vendored
6
.github/workflows/build.yml
vendored
@@ -56,7 +56,7 @@ jobs:
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Build Image
|
||||
uses: docker/build-push-action@v5
|
||||
uses: docker/build-push-action@v7
|
||||
with:
|
||||
context: .
|
||||
file: docker/Dockerfile
|
||||
@@ -66,8 +66,8 @@ jobs:
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
cache-from: type=gha, scope=${{ github.workflow }}-docker
|
||||
cache-to: type=gha, scope=${{ github.workflow }}-docker
|
||||
cache-from: type=gha,scope=moviepilot-docker,version=2
|
||||
cache-to: type=gha,scope=moviepilot-docker,mode=max,version=2
|
||||
|
||||
- name: Generate Changelog
|
||||
id: changelog
|
||||
|
||||
127
.github/workflows/issues.yml
vendored
127
.github/workflows/issues.yml
vendored
@@ -2,13 +2,138 @@ name: Close inactive issues
|
||||
on:
|
||||
workflow_dispatch:
|
||||
|
||||
issues:
|
||||
types: [opened, edited]
|
||||
|
||||
schedule:
|
||||
# Github Action 只支持 UTC 时间。
|
||||
# '0 18 * * *' 对应 UTC 时间的 18:00,也就是中国时区 (UTC+8) 的第二天凌晨 02:00。
|
||||
- cron: "0 18 * * *"
|
||||
|
||||
jobs:
|
||||
label-opened-issue:
|
||||
if: github.event_name == 'issues'
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
issues: write
|
||||
steps:
|
||||
- uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const issue = context.payload.issue;
|
||||
const title = issue.title || '';
|
||||
const body = issue.body || '';
|
||||
const currentLabels = (issue.labels || []).map((label) => label.name);
|
||||
|
||||
// 网页 Issue Form 已经会自动带模板 labels;这里只兜底处理
|
||||
// API 创建或异常路径产生的无 label issue,避免重复补标。
|
||||
if (currentLabels.length > 0) {
|
||||
core.info(`Issue #${issue.number} already has labels: ${currentLabels.join(', ')}`);
|
||||
return;
|
||||
}
|
||||
|
||||
const hasAllMarkers = (markers) => markers.every((marker) => body.includes(marker));
|
||||
const labelRules = [
|
||||
{
|
||||
label: 'bug',
|
||||
titlePrefix: '[错误报告]:',
|
||||
markers: ['### 当前程序版本', '### 运行环境', '### 问题类型', '### 问题描述'],
|
||||
},
|
||||
{
|
||||
label: 'feature request',
|
||||
titlePrefix: '[Feature Request]:',
|
||||
markers: ['### 当前程序版本', '### 运行环境', '### 功能改进类型', '### 功能改进'],
|
||||
},
|
||||
{
|
||||
label: 'RFC',
|
||||
titlePrefix: '[RFC]',
|
||||
markers: ['### 背景 or 问题', '### 目标 & 方案简述'],
|
||||
},
|
||||
];
|
||||
|
||||
const matched = labelRules.find((rule) => (
|
||||
title.startsWith(rule.titlePrefix) || hasAllMarkers(rule.markers)
|
||||
));
|
||||
|
||||
if (!matched) {
|
||||
core.info(`Issue #${issue.number} does not match known issue templates.`);
|
||||
return;
|
||||
}
|
||||
|
||||
await github.rest.issues.addLabels({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: issue.number,
|
||||
labels: [matched.label],
|
||||
});
|
||||
core.info(`Added label "${matched.label}" to issue #${issue.number}.`);
|
||||
|
||||
label-unlabeled-issues:
|
||||
if: github.event_name != 'issues'
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
issues: write
|
||||
steps:
|
||||
- uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const labelRules = [
|
||||
{
|
||||
label: 'bug',
|
||||
titlePrefix: '[错误报告]:',
|
||||
markers: ['### 当前程序版本', '### 运行环境', '### 问题类型', '### 问题描述'],
|
||||
},
|
||||
{
|
||||
label: 'feature request',
|
||||
titlePrefix: '[Feature Request]:',
|
||||
markers: ['### 当前程序版本', '### 运行环境', '### 功能改进类型', '### 功能改进'],
|
||||
},
|
||||
{
|
||||
label: 'RFC',
|
||||
titlePrefix: '[RFC]',
|
||||
markers: ['### 背景 or 问题', '### 目标 & 方案简述'],
|
||||
},
|
||||
];
|
||||
|
||||
const hasAllMarkers = (body, markers) => markers.every((marker) => body.includes(marker));
|
||||
const getMatchedRule = (issue) => {
|
||||
const title = issue.title || '';
|
||||
const body = issue.body || '';
|
||||
return labelRules.find((rule) => (
|
||||
title.startsWith(rule.titlePrefix) || hasAllMarkers(body, rule.markers)
|
||||
));
|
||||
};
|
||||
|
||||
// Search API 支持 no:label 查询;issues.listForRepo 的 labels=none
|
||||
// 会被当作名为 none 的标签,不能用于扫描无 label issue。
|
||||
const query = `repo:${context.repo.owner}/${context.repo.repo} is:issue is:open no:label`;
|
||||
for await (const response of github.paginate.iterator(github.rest.search.issuesAndPullRequests, {
|
||||
q: query,
|
||||
per_page: 100,
|
||||
})) {
|
||||
for (const issue of response.data) {
|
||||
if (issue.pull_request) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const matched = getMatchedRule(issue);
|
||||
if (!matched) {
|
||||
continue;
|
||||
}
|
||||
|
||||
await github.rest.issues.addLabels({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: issue.number,
|
||||
labels: [matched.label],
|
||||
});
|
||||
core.info(`Added label "${matched.label}" to issue #${issue.number}.`);
|
||||
}
|
||||
}
|
||||
|
||||
close-issues:
|
||||
if: github.event_name != 'issues'
|
||||
needs: label-unlabeled-issues
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
issues: write
|
||||
@@ -30,4 +155,4 @@ jobs:
|
||||
# 排除带有RFC标签的issue
|
||||
exempt-issue-labels: "RFC"
|
||||
operations-per-run: 500
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
48
.github/workflows/test.yml
vendored
Normal file
48
.github/workflows/test.yml
vendored
Normal file
@@ -0,0 +1,48 @@
|
||||
name: Unit Tests
|
||||
|
||||
on:
|
||||
# 指向 v2 的 PR 与推送都跑全量单测,作为合并门禁
|
||||
pull_request:
|
||||
branches:
|
||||
- v2
|
||||
push:
|
||||
branches:
|
||||
- v2
|
||||
# 允许手动触发
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
pytest:
|
||||
runs-on: ubuntu-latest
|
||||
name: Unit Tests
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
cache: 'pip'
|
||||
|
||||
- name: Cache pip dependencies
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pip
|
||||
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.in', '**/requirements.txt') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pip-
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip setuptools wheel
|
||||
# 用 requirements.in 还原 CI / 全新环境(含 pytest~=8.4 与 moviepilot-rust 等可选扩展),
|
||||
# 与本地"干净 venv 复现"一致;测试运行器 pytest 已在 requirements.in 中声明。
|
||||
pip install -r requirements.in
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
# tests/run.py 以 pytest 跑 tests 全量;tests/conftest.py 在收集前把 CONFIG_DIR
|
||||
# 指向临时库并建表,测试杜绝真实网络/外部服务(详见 docs/testing.md)。
|
||||
python tests/run.py
|
||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -6,6 +6,7 @@
|
||||
build/
|
||||
cython_cache/
|
||||
dist/
|
||||
rust/**/target/
|
||||
nginx/
|
||||
test.py
|
||||
safety_report.txt
|
||||
@@ -21,6 +22,7 @@ config/user.db*
|
||||
config/sites/**
|
||||
config/agent/
|
||||
config/logs/
|
||||
config/plugins/
|
||||
config/temp/
|
||||
config/cache/
|
||||
.runtime/
|
||||
@@ -39,3 +41,6 @@ pylint-report.json
|
||||
.claude/
|
||||
!.claude/*.json
|
||||
.claude/settings.local.json
|
||||
|
||||
# Superpowers 设计/计划文档(本地协作产物,不纳入仓库)
|
||||
docs/superpowers/
|
||||
|
||||
32
.pylintrc
32
.pylintrc
@@ -5,38 +5,30 @@ init-hook='import sys; sys.path.append(".")'
|
||||
# 忽略的文件和目录
|
||||
ignore=.git,__pycache__,.venv,build,dist,tests,docs
|
||||
|
||||
# 通过 `pylint app/` 检查主程序时不扫描内置插件目录,
|
||||
# 插件依赖和动态模型较多,容易产生与主程序无关的误报。
|
||||
ignore-paths=^app/plugins(/|$)
|
||||
|
||||
# 并行作业数量
|
||||
jobs=0
|
||||
|
||||
[MESSAGES CONTROL]
|
||||
# 只关注错误级别的问题,禁用警告、约定和重构建议
|
||||
# E = Error (错误) - 会导致构建失败
|
||||
# W = Warning (警告) - 仅显示,不会失败
|
||||
# R = Refactor (重构建议) - 仅显示,不会失败
|
||||
# C = Convention (约定) - 仅显示,不会失败
|
||||
# I = Information (信息) - 仅显示,不会失败
|
||||
|
||||
# 禁用大部分警告、约定和重构建议,只保留错误和重要警告
|
||||
# 只启用确定性较强的严重问题检查,避免 SQLAlchemy、FastAPI 依赖注入、
|
||||
# 第三方 SDK 等动态对象被 Pylint 推断成误报。
|
||||
disable=all
|
||||
enable=E,
|
||||
syntax-error,
|
||||
enable=syntax-error,
|
||||
undefined-variable,
|
||||
used-before-assignment,
|
||||
possibly-used-before-assignment,
|
||||
unreachable,
|
||||
return-outside-function,
|
||||
yield-outside-function,
|
||||
continue-in-finally,
|
||||
nonlocal-without-binding,
|
||||
undefined-loop-variable,
|
||||
redefined-builtin,
|
||||
not-callable,
|
||||
assignment-from-no-return,
|
||||
no-value-for-parameter,
|
||||
too-many-function-args,
|
||||
unexpected-keyword-arg,
|
||||
redundant-keyword-arg,
|
||||
import-error,
|
||||
relative-beyond-top-level
|
||||
relative-beyond-top-level,
|
||||
no-name-in-module
|
||||
|
||||
[REPORTS]
|
||||
# 设置报告格式
|
||||
@@ -80,4 +72,6 @@ ignore-imports=yes
|
||||
|
||||
[TYPECHECK]
|
||||
# 生成缺失成员提示的类列表
|
||||
generated-members=requests.packages.urllib3
|
||||
generated-members=requests.packages.urllib3
|
||||
# app.helper.sites 会主动隐藏模块属性枚举,避免误报 no-name-in-module
|
||||
ignored-modules=app.helper.sites
|
||||
|
||||
195
AGENTS.md
195
AGENTS.md
@@ -1,152 +1,107 @@
|
||||
# MoviePilot AI Agent Guide
|
||||
# AGENTS.md
|
||||
|
||||
This file defines the default behavior for AI agents working in the MoviePilot repository. Unless a deeper directory provides another `AGENTS.md`, these rules apply to the entire repo.
|
||||
This file is the primary instruction set for all AI agents and LLMs working in this repository. Local documentation takes precedence over general training data. You must follow this file and the rule documents it references.
|
||||
|
||||
## 1. Project Scope
|
||||
---
|
||||
|
||||
- This repository contains the MoviePilot backend, CLI, MCP/API, Docker assets, and AI skills.
|
||||
- The backend is based on FastAPI, with most code under `app/`.
|
||||
- Frontend source code is not in this repository. The frontend source repository is `MoviePilot-Frontend`.
|
||||
- This repository also includes the local CLI, database migrations, developer docs, tests, Docker scripts, and AI skills.
|
||||
## Task-to-Documentation Mapping
|
||||
|
||||
## 2. Working Principles
|
||||
Before executing any task, identify the domain and load the corresponding document.
|
||||
|
||||
- Read the relevant implementation, tests, and docs before changing code. Do not infer behavior from directory names alone.
|
||||
- Prefer the smallest correct change. Reuse existing functions, patterns, and naming whenever possible.
|
||||
- Do not perform unrelated large refactors, mass renames, or formatting-only cleanup.
|
||||
- Before adding a new abstraction, check whether it is actually reusable. If the logic fits well inside an existing function, class, or flow, keep it there.
|
||||
- The worktree may contain user changes. Do not revert, overwrite, or reorganize changes you do not fully understand.
|
||||
- Default to writing conclusions, validation results, and risk notes in Chinese unless the user asks otherwise.
|
||||
### Architectural Decisions
|
||||
* **Primary Reference:** `docs/rules/05-architecture.md`
|
||||
* **Required Constraints:** Respect layer boundaries and dependency flow. Do not introduce circular dependencies. Verify the correct layer for any new capability before implementing.
|
||||
|
||||
## 3. Key Directories
|
||||
### Business Logic and Design Patterns
|
||||
* **Primary Reference:** `docs/rules/04-design-patterns.md`
|
||||
* **Required Constraints:** Use the project's established Module, Chain, Event, and Oper structural patterns. Do not introduce abstractions the project has not adopted.
|
||||
|
||||
- `app/api/endpoints/`: HTTP entrypoints. Handles auth, parameters, responses, and simple CRUD.
|
||||
- `app/chain/`: Business orchestration layer for search, recognition, subscriptions, downloads, messaging flows, and similar use cases.
|
||||
- `app/modules/`: Dynamically loaded system modules. Encapsulates pluggable downloaders, media servers, message channels, and other backend capabilities.
|
||||
- `app/helper/`: Reusable low-level helper logic. Not a place for full business orchestration.
|
||||
- `app/core/config.py`: Environment variables, deployment parameters, and startup-level settings.
|
||||
- `app/schemas/types.py`: Shared enums and types such as `SystemConfigKey` and module categories.
|
||||
- `app/db/`: Database models, sessions, and `*_oper.py` data access wrappers.
|
||||
- `moviepilot`: Local CLI entrypoint and help text.
|
||||
- `database/versions/`: Alembic migration scripts.
|
||||
- `docs/`: CLI, MCP/API, and development workflow documentation.
|
||||
- `skills/`: AI agent skills and related scripts.
|
||||
- `tests/`: Pytest tests and a few manual test scripts.
|
||||
- `config/`, `.moviepilot.env`, and `*.db`: Local config or runtime data. Do not modify or commit them unless the user explicitly asks for it.
|
||||
### Coding Standards and Style
|
||||
* **Primary Reference:** `docs/rules/06-code-styles.md`
|
||||
* **Required Constraints:** Match the style of the surrounding file. Type annotations, Pydantic models, and async/await usage must all conform to the documented standards.
|
||||
|
||||
## 4. Layering And Access Boundaries
|
||||
### Identifiers and Naming
|
||||
* **Primary Reference:** `docs/rules/07-naming-conventions.md`
|
||||
* **Required Constraints:** All filenames, class names, function names, and constants must follow the project's taxonomy. No arbitrary abbreviations or mixed casing styles.
|
||||
|
||||
### API / Endpoint Layer
|
||||
### Comments and Documentation
|
||||
* **Primary Reference:** `docs/rules/08-comment-styles.md`
|
||||
* **Required Constraints:** All public classes and methods require Chinese docstrings. Comments must explain the *why*, not restate the code.
|
||||
* **⚠️ MANDATORY GATE:** Code that is missing proper Chinese docstrings on public interfaces is **REJECTED** at review. No exceptions.
|
||||
|
||||
- Endpoints should only handle HTTP concerns: auth, parameter parsing, response models, streaming adaptation, and simple input validation.
|
||||
- Simple list, detail, toggle, settings read/write, and pure CRUD endpoints may directly call `app/db/` or an existing `helper`.
|
||||
- If the logic coordinates multiple modules, triggers events, touches caches, or combines search, recognition, subscription, or download workflows, move it into `chain`.
|
||||
- Prefer adding new endpoints to an existing domain file. Create a new endpoint file only when introducing a new top-level resource domain.
|
||||
- After adding a new endpoint, register it in `app/api/apiv1.py`.
|
||||
### External Communication and Interfaces
|
||||
* **Primary Reference:** `docs/rules/09-external-response.md`
|
||||
* **Required Constraints:** All third-party HTTP requests must go through `RequestUtils`. Response formats must use the project's standard schemas. Error handling must follow the per-layer conventions.
|
||||
|
||||
### Chain Layer
|
||||
### Data and Persistence
|
||||
* **Primary Reference:** `docs/rules/10-data-and-persistent.md`
|
||||
* **Required Constraints:** Any database model change requires a matching Alembic migration. Runtime configuration must be managed via `SystemConfigKey` + `SystemConfigOper`. Raw string keys are forbidden.
|
||||
|
||||
- `chain` is the business orchestration layer shared by API, CLI, message interaction, agents, schedulers, and similar entrypoints.
|
||||
- `chain` is responsible for composing `module`, `helper`, `db`, events, caches, and other stable `chain` capabilities.
|
||||
- Inside `chain`, prefer calling module capabilities through `run_module()` or `async_run_module()`. Only use `ModuleManager` or similar helpers directly when you truly need to enumerate modules, inspect instances, or run health checks.
|
||||
- `chain` should focus on use cases and workflows. It should not hold low-level protocol details, HTTP request objects, or page-specific parameter assembly.
|
||||
- Before adding a new `chain`, ask whether this is a reusable business use case shared by multiple entrypoints, or a flow that coordinates multiple modules or resources. If it is just short logic for one endpoint, do not create a new `chain`.
|
||||
- `chain` may call other `chain` classes when reusing stable domain logic, but avoid introducing new circular dependencies.
|
||||
### Quality and Security
|
||||
* **Primary Reference:** `docs/rules/11-quality-and-security.md`
|
||||
* **Required Constraints:** All code changes must pass the relevant pytest tests and pylint checks. Dependency changes require a passing safety scan.
|
||||
|
||||
### Module Layer
|
||||
### Testing
|
||||
* **Primary Reference:** `docs/testing.md`
|
||||
* **Required Constraints:** pytest is the only runner; `tests/conftest.py` isolates each run to a temporary `CONFIG_DIR`. Tests must not touch the real database, network, or external services (TMDB, LLM catalogs, downloaders, media servers, MP server) — mock at the boundary or replay recorded responses; the bar is zero real outbound traffic. Tests must restore any process-level state they stub (`sys.modules`, singletons, caches, settings). New tests must be pytest-native (function + `assert` + fixtures); do not add new `unittest.TestCase`. Convert existing `TestCase` files to pytest-native opportunistically when you modify them. Before opening a PR to `v2`, run the full suite locally (`python tests/run.py`) and confirm it is green with zero real network calls; the `.github/workflows/test.yml` gate runs the same suite on every PR/push to `v2`.
|
||||
|
||||
- `module` is the pluggable capability layer discovered and loaded by `ModuleManager`.
|
||||
- Put logic in `module` when it represents a new downloader, media server, message channel, recognition backend, filtering backend, file-management backend, or any other capability that needs lifecycle management, priority, configuration switches, or independent testing.
|
||||
- New modules should follow the existing base-class contract and implement or align with `init_module()`, `init_setting()`, `get_name()`, `get_type()`, `get_subtype()`, `get_priority()`, `test()`, and `stop()`.
|
||||
- A `module` should focus on one backend or one capability implementation. It should return domain results, not HTTP responses, and should not depend on endpoint auth or FastAPI request objects.
|
||||
- `chain -> module` is the intended main direction. The repository contains a small number of historical `module -> chain` usages. Do not expand that pattern in new code. If a module needs shared business logic, prefer moving that logic up into `chain` or down into `helper`.
|
||||
- Do not add direct `module -> module` coupling for new code. Cross-module orchestration should be handled by `chain`.
|
||||
### Commands and Development Workflow
|
||||
* **Primary Reference:** `docs/rules/03-commands.md`
|
||||
* **Required Constraints:** Only suggest or execute commands documented in that file. Do not assume tool defaults or global flags.
|
||||
|
||||
### Helper Layer
|
||||
---
|
||||
|
||||
- `helper` is for reusable low-level support logic such as path handling, config aggregation, site index loading, protocol wrappers, rate limiting, cache helpers, and page parsing.
|
||||
- Add a new `helper` only when the logic is reused in multiple places, or when it is clearly a standalone low-level concern.
|
||||
- If logic is used only by a single `chain` or a single `module`, prefer keeping it in the original file instead of turning `helper` into a dumping ground.
|
||||
- If the code needs configuration switches, runtime loading, priorities, independent test entrypoints, or multi-implementation dispatch, it is probably a `module`, not a `helper`.
|
||||
- `helper` must not become another orchestration layer. Full business workflows still belong in `chain`.
|
||||
## Agent Execution Rules
|
||||
|
||||
### Preferred Call Directions
|
||||
### Pre-Flight Check
|
||||
|
||||
- Preferred direction: `endpoint/CLI/agent/command -> chain -> module/helper/db`
|
||||
- Allowed direction: `chain -> chain`, as long as the reused logic is stable and does not introduce cycles.
|
||||
- Cautious direction: `endpoint -> db/model/oper/helper`, only for simple queries, simple CRUD, or input normalization.
|
||||
- Avoid for new code: `module -> chain`, `module -> module`, `helper -> chain`, `helper -> endpoint`.
|
||||
Before generating any code or proposing changes, you must:
|
||||
|
||||
## 5. Where New Capabilities Should Go
|
||||
1. Identify the task domain (architecture / business logic / coding style / naming / comments / external interfaces / data / quality).
|
||||
2. Load the corresponding document from `docs/rules/`.
|
||||
3. Explicitly verify that your proposed solution does not violate the following three mandatory constraints:
|
||||
- **Naming Conventions (07):** Are all files, classes, functions, and constants named correctly?
|
||||
- **Architecture Boundaries (05):** Is the code placed in the correct layer? Are all call directions valid?
|
||||
- **Comment Standards (08):** Do all new public classes and methods include Chinese docstrings?
|
||||
|
||||
- Scenario: adding a new business workflow such as search, recognition, subscription, download orchestration, or message interaction.
|
||||
Action: prefer `app/chain/` so API, CLI, agents, and schedulers can share the same orchestration logic.
|
||||
- Scenario: adding a new downloader, media server, message channel, or other pluggable backend integration.
|
||||
Action: put it in `app/modules/`. If this introduces a new module category or subtype, also check `app/schemas/types.py` and related schemas.
|
||||
- Scenario: adding a new public HTTP API.
|
||||
Action: put it in `app/api/endpoints/`, register it in `app/api/apiv1.py`, and add auth, schemas, docs, and tests. Move complex logic into `chain`.
|
||||
- Scenario: adding a new low-level utility, parser, config reader, or protocol wrapper.
|
||||
Action: put it in `app/helper/`, but only if it is not a one-off implementation and not a full business use case.
|
||||
- Scenario: adding a deployment-level, environment-level, or startup-time config such as ports, paths, proxies, switches, keys, or third-party service addresses.
|
||||
Action: put it in `ConfigModel` or `Settings` inside `app/core/config.py`.
|
||||
- Scenario: adding a runtime business config, user-editable rule, or persistent system option.
|
||||
Action: prefer `SystemConfigKey` plus `SystemConfigOper`. Do not scatter raw string keys.
|
||||
- Scenario: a config change should automatically reload a long-lived object.
|
||||
Action: add `CONFIG_WATCH`, `on_config_changed()`, and `get_reload_name()` where appropriate on the related `chain`, `module`, `helper`, or manager class.
|
||||
- Scenario: adding a few dozen lines of private logic inside one `chain` or `module`.
|
||||
Action: prefer a private function or private method in the same file. Do not create a new `helper` by default.
|
||||
### Implementation Guidelines
|
||||
|
||||
## 6. Code And Comment Requirements
|
||||
* **Pattern Adherence:** Avoid generic boilerplate. If `04-design-patterns.md` defines a project-level pattern for a scenario, you are required to use it.
|
||||
* **Documentation Standards:** Docstring style for any new function or module must match `08-comment-styles.md`.
|
||||
* **⚠️ MANDATORY GATE:** Public classes, methods, and functions without proper Chinese docstrings are **REJECTED**. No exceptions.
|
||||
* **Command Reliance:** Only suggest commands listed in `03-commands.md`. Do not rely on inferred tool defaults.
|
||||
* **Minimal Change Principle:** Prefer the smallest correct change. Do not perform unrelated refactors, mass renames, or formatting-only cleanup.
|
||||
* **Output Language:** Summaries, validation results, and risk notes default to Chinese unless the user requests otherwise.
|
||||
|
||||
- Preserve the existing code style. Do not introduce a new abstraction layer without a clear payoff.
|
||||
- The repository already uses short docstrings for many public classes and methods. For new public classes and methods, follow the local style of the surrounding file.
|
||||
- Comments and docstrings should default to Chinese. If the surrounding file is already consistently in English, match the local style.
|
||||
- Comments should explain why the code is written that way and what non-obvious constraints exist, such as edge cases, compatibility reasons, call ordering, cache or reload semantics, and external system limitations.
|
||||
- Do not write line-by-line translation comments. Do not comment obvious assignments, branches, or straightforward calls.
|
||||
- For complex notes, place the comment above the code block instead of using long end-of-line comments.
|
||||
- When changing code, update or remove stale comments so the documentation stays aligned with the implementation.
|
||||
- Do not add TODO or FIXME without context. Only keep one if it is genuinely useful and cannot be addressed as part of the current task.
|
||||
- Do not add noisy comments like "change starts here", "change ends here", or "this is important".
|
||||
### Conflict Resolution
|
||||
|
||||
## 7. Dependency And Environment Conventions
|
||||
If existing code appears to contradict the documentation:
|
||||
|
||||
- Target Python version is `3.11+`. Current CI uses Python `3.12`.
|
||||
- The dependency source file is `requirements.in`.
|
||||
- `requirements.txt` is the lock file generated by `pip-compile requirements.in`. Do not maintain it manually.
|
||||
- Install dependencies with `pip install -r requirements.txt`.
|
||||
- When adding or upgrading dependencies:
|
||||
1. Update `requirements.in`
|
||||
2. Run `pip-compile requirements.in`
|
||||
3. Run the relevant tests and security checks
|
||||
1. Stop implementation immediately.
|
||||
2. Identify the specific file and line of the contradiction.
|
||||
3. Prompt the user: "The documentation in `[File]` requires Pattern A, but the current implementation uses Pattern B. Which is the current standard?"
|
||||
|
||||
## 8. Coupled Updates
|
||||
---
|
||||
|
||||
- When fixing a bug, prefer adding a test that reproduces it. When adding a feature, prefer the smallest useful test coverage.
|
||||
- When changing CLI behavior, also check and update `moviepilot`, `docs/cli.md`, and related tests.
|
||||
- When changing MCP or REST API behavior, exposed tools, or AI interaction behavior, also check and update `docs/mcp-api.md`, related `skills/*/SKILL.md` files or scripts, and related tests.
|
||||
- When changing development workflow, dependency management, or security-check procedures, also update `docs/development-setup.md`.
|
||||
- When changing database structure, add an Alembic migration under `database/versions/`. Do not update models without a migration.
|
||||
- When changing user-visible config, defaults, or initialization flow, also check related docs, help text, setup or init flows, and tests.
|
||||
- When adding a new skill, follow the existing `skills/<name>/SKILL.md` structure, keep the YAML front matter, and prefer script paths relative to the `SKILL.md` file.
|
||||
## Coupled Update Rules
|
||||
|
||||
## 9. Validation Requirements
|
||||
When modifying the following, you must also update the listed artifacts:
|
||||
|
||||
- Run at least the tests directly related to the change, for example `pytest tests/test_xxx.py`.
|
||||
- If the change affects common modules, startup flow, CLI, or agent runtime behavior, expand the validation scope.
|
||||
- After Python code changes, at minimum ensure the change does not introduce new error-level issues in `pylint app/`.
|
||||
- When changing CLI behavior, validate the relevant help output such as `moviepilot help` or the specific subcommand help.
|
||||
- When changing dependencies, also run `pip-compile requirements.in` and `safety check -r requirements.txt --policy-file=safety.policy.yml`.
|
||||
- If the task only changes documentation, explicitly say that tests were not run. Do not claim checks that were not executed.
|
||||
| Changed Content | Must Also Update |
|
||||
|---|---|
|
||||
| CLI behavior | `moviepilot` entrypoint, `docs/cli.md`, related tests |
|
||||
| MCP / REST API, exposed tools | `docs/mcp-api.md`, `skills/*/SKILL.md`, related tests |
|
||||
| Dev workflow, dependency management, security checks | `docs/development-setup.md` |
|
||||
| Database model schema | New Alembic migration under `database/versions/` |
|
||||
| User-visible config or init flow | Related docs, help text, setup/init flows, tests |
|
||||
| New skill | Follow `skills/<name>/SKILL.md` structure, keep YAML front matter |
|
||||
|
||||
## 10. Commit And Release Conventions
|
||||
---
|
||||
|
||||
- Only create a commit when the user explicitly asks for one.
|
||||
- Prefer Conventional Commits such as `feat: ...`, `fix: ...`, and `docs: ...`.
|
||||
- This is not just stylistic. The release workflow uses Conventional Commits to categorize changelog entries.
|
||||
- Do not casually change version numbers, release settings, or Docker release flow unless the task explicitly involves them.
|
||||
## Primary Entry Point
|
||||
|
||||
## 11. Output Requirements
|
||||
For the full documentation map and cross-references, refer to:
|
||||
|
||||
- Result summaries should focus on three things: what changed, how it was validated, and what risks remain.
|
||||
- Do not write vague summaries. Do not describe unexecuted checks as completed.
|
||||
- If there is compatibility impact, config migration risk, or user-data risk, call it out explicitly.
|
||||
**[Documentation Hub Index](./docs/rules/README.md)**
|
||||
|
||||
*Last Updated: 2026-05-25*
|
||||
|
||||
49
README.md
49
README.md
@@ -1,4 +1,3 @@
|
||||
|
||||
# MoviePilot
|
||||
|
||||
简体中文 | [English](README_EN.md)
|
||||
@@ -12,51 +11,56 @@
|
||||

|
||||

|
||||
|
||||
|
||||
基于 [NAStool](https://github.com/NAStool/nas-tools) 部分代码重新设计,聚焦自动化核心需求,减少问题同时更易于扩展和维护。
|
||||
|
||||
# 仅用于学习交流使用,请勿在任何国内平台宣传该项目!
|
||||
|
||||
发布频道:https://t.me/moviepilot_channel
|
||||
|
||||
|
||||
## 主要特性
|
||||
|
||||
- 前后端分离,基于FastApi + Vue3。
|
||||
- 聚焦核心需求,简化功能和设置,部分设置项可直接使用默认值。
|
||||
- 重新设计了用户界面,更加美观易用。
|
||||
|
||||
- 聚焦影视自动化的核心流程:订阅、搜索、下载、整理、刮削、媒体库刷新与消息通知。
|
||||
- 前后端分离,后端基于 FastAPI,前端基于 Vue 3,部署和扩展边界更清晰。
|
||||
- 支持下载器、媒体服务器、元数据源、消息渠道、插件、工作流和 AI Agent 等能力组合。
|
||||
- 更完整的功能介绍、截图和使用入口见官网:https://movie-pilot.org
|
||||
|
||||
## 安装使用
|
||||
|
||||
官方Wiki:https://wiki.movie-pilot.org
|
||||
推荐优先使用 Docker 部署,常用镜像包括 `jxxghp/moviepilot-v2` 和 `jxxghp/moviepilot`。Compose 示例、环境变量、目录映射和升级方式以官方 Wiki 为准:
|
||||
|
||||
- 官方 Wiki:https://wiki.movie-pilot.org
|
||||
- PostgreSQL 部署说明:[docs/postgresql-setup.md](docs/postgresql-setup.md)
|
||||
|
||||
## 本地 CLI
|
||||
|
||||
一键安装运行脚本:
|
||||
也可以使用本地 CLI 以源码模式安装和管理 MoviePilot:
|
||||
|
||||
```shell
|
||||
curl -fsSL https://raw.githubusercontent.com/jxxghp/MoviePilot/v2/scripts/bootstrap-local.sh | bash
|
||||
```
|
||||
|
||||
使用 `moviepilot` 命令管理MoviePilot,完整 CLI 文档:[`docs/cli.md`](docs/cli.md)
|
||||
安装完成后使用 `moviepilot` 命令完成初始化、启动、停止、更新和配置查看。完整命令见 [docs/cli.md](docs/cli.md)。
|
||||
|
||||
## Agent
|
||||
|
||||
1. MoviePilot 自带智能体能力,可在完成模型配置后,通过自然语言调用系统工具,辅助完成搜索、订阅、下载、整理、排障等管理任务。
|
||||
2. 其它智能体可以导入本仓库的 `skills/` 目录以获得 MoviePilot 操作能力;支持 `skills` CLI 的环境可使用:
|
||||
|
||||
```shell
|
||||
npx skills add https://github.com/jxxghp/MoviePilot
|
||||
```
|
||||
|
||||
内置 Skills 列表见 [skills/](skills/),自定义 Skill 可参考 [skills/create-moviepilot-skill/SKILL.md](skills/create-moviepilot-skill/SKILL.md)。
|
||||
3. 其它 MCP 客户端可以通过 MoviePilot 的 MCP 端点 `/api/v1/mcp` 调用工具,认证方式、客户端配置和工具 API 见 [docs/mcp-api.md](docs/mcp-api.md)。
|
||||
|
||||
## 为 AI Agent 添加 Skills
|
||||
```shell
|
||||
npx skills add https://github.com/jxxghp/MoviePilot
|
||||
```
|
||||
|
||||
## 参与开发
|
||||
|
||||
API文档:https://api.movie-pilot.org
|
||||
开发前请先阅读仓库规则和本地环境说明,保持变更聚焦,通过测试后再提交 PR。常用入口:
|
||||
|
||||
MCP工具API文档:详见 [docs/mcp-api.md](docs/mcp-api.md)
|
||||
|
||||
开发环境准备与本地源码运行说明:[`docs/development-setup.md`](docs/development-setup.md)
|
||||
|
||||
插件开发说明:<https://wiki.movie-pilot.org/zh/plugindev>
|
||||
- 文档规则入口:[docs/rules/README.md](docs/rules/README.md)
|
||||
- 开发环境与本地源码运行:[docs/development-setup.md](docs/development-setup.md)
|
||||
- 测试说明:[docs/testing.md](docs/testing.md)
|
||||
- REST API 文档:https://api.movie-pilot.org
|
||||
- 插件开发说明:https://wiki.movie-pilot.org/zh/plugindev
|
||||
|
||||
## 相关项目
|
||||
|
||||
@@ -64,6 +68,7 @@ MCP工具API文档:详见 [docs/mcp-api.md](docs/mcp-api.md)
|
||||
- [MoviePilot-Resources](https://github.com/jxxghp/MoviePilot-Resources)
|
||||
- [MoviePilot-Plugins](https://github.com/jxxghp/MoviePilot-Plugins)
|
||||
- [MoviePilot-Server](https://github.com/jxxghp/MoviePilot-Server)
|
||||
- [MoviePilot-Rust](https://github.com/jxxghp/MoviePilot-Rust)
|
||||
- [MoviePilot-Wiki](https://github.com/jxxghp/MoviePilot-Wiki)
|
||||
|
||||
## 免责申明
|
||||
|
||||
48
README_EN.md
48
README_EN.md
@@ -17,44 +17,49 @@ Redesigned from parts of [NAStool](https://github.com/NAStool/nas-tools), with a
|
||||
|
||||
Release channel: https://t.me/moviepilot_channel
|
||||
|
||||
|
||||
## Key Features
|
||||
|
||||
- Frontend/backend separation based on FastApi + Vue3.
|
||||
- Focuses on core needs, simplifies features and settings, and allows some options to work well with sensible defaults.
|
||||
- Reworked user interface for a cleaner and more practical experience.
|
||||
- Focuses on the core media automation flow: subscriptions, search, downloads, file organization, scraping, media server refresh, and notifications.
|
||||
- Uses a separated backend/frontend architecture: FastAPI for the backend and Vue 3 for the frontend.
|
||||
- Connects download clients, media servers, metadata providers, message channels, plugins, workflows, and AI Agent capabilities.
|
||||
- For feature details, screenshots, and product entry points, see https://movie-pilot.org
|
||||
|
||||
## Installation and Usage
|
||||
|
||||
## Installation
|
||||
Docker is the recommended deployment model. Common images include `jxxghp/moviepilot-v2` and `jxxghp/moviepilot`. Compose examples, environment variables, volume mappings, and upgrade notes are maintained in the official wiki:
|
||||
|
||||
Official wiki: https://wiki.movie-pilot.org
|
||||
- Official wiki: https://wiki.movie-pilot.org
|
||||
- PostgreSQL setup: [docs/postgresql-setup.md](docs/postgresql-setup.md)
|
||||
|
||||
|
||||
## Local CLI
|
||||
|
||||
One-command bootstrap script:
|
||||
MoviePilot can also be installed and managed from source with the local CLI:
|
||||
|
||||
```shell
|
||||
curl -fsSL https://raw.githubusercontent.com/jxxghp/MoviePilot/v2/scripts/bootstrap-local.sh | bash
|
||||
```
|
||||
|
||||
Manage MoviePilot with the `moviepilot` command. Full CLI documentation: [`docs/cli.md`](docs/cli.md)
|
||||
After installation, use the `moviepilot` command for initialization, service management, updates, and configuration. See [docs/cli.md](docs/cli.md) for the full command reference.
|
||||
|
||||
## Agent
|
||||
|
||||
## Add Skills for AI Agents
|
||||
```shell
|
||||
npx skills add https://github.com/jxxghp/MoviePilot
|
||||
```
|
||||
1. MoviePilot includes a built-in AI Agent. After model configuration, it can call system tools through natural language to help with search, subscriptions, downloads, organization, diagnostics, and other management tasks.
|
||||
2. Other agents can import the repository `skills/` directory to gain MoviePilot operation capabilities. Environments that support the `skills` CLI can use:
|
||||
|
||||
```shell
|
||||
npx skills add https://github.com/jxxghp/MoviePilot
|
||||
```
|
||||
|
||||
Built-in skills live in [skills/](skills/). For custom skill authoring, see [skills/create-moviepilot-skill/SKILL.md](skills/create-moviepilot-skill/SKILL.md).
|
||||
3. Other MCP clients can call MoviePilot tools through `/api/v1/mcp`. Authentication, client configuration, and tool APIs are documented in [docs/mcp-api.md](docs/mcp-api.md).
|
||||
|
||||
## Development
|
||||
|
||||
API documentation: https://api.movie-pilot.org
|
||||
Before contributing, read the repository rules and local environment guide, keep changes focused, and validate them before opening a PR. Useful entry points:
|
||||
|
||||
MCP tool API documentation: see [docs/mcp-api.md](docs/mcp-api.md)
|
||||
|
||||
Development environment setup and local source-run guide: [`docs/development-setup.md`](docs/development-setup.md)
|
||||
|
||||
Plugin development guide: <https://wiki.movie-pilot.org/zh/plugindev>
|
||||
- Rule index: [docs/rules/README.md](docs/rules/README.md)
|
||||
- Development setup and local source run: [docs/development-setup.md](docs/development-setup.md)
|
||||
- Testing guide: [docs/testing.md](docs/testing.md)
|
||||
- REST API documentation: https://api.movie-pilot.org
|
||||
- Plugin development guide: https://wiki.movie-pilot.org/zh/plugindev
|
||||
|
||||
## Related Projects
|
||||
|
||||
@@ -62,6 +67,7 @@ Plugin development guide: <https://wiki.movie-pilot.org/zh/plugindev>
|
||||
- [MoviePilot-Resources](https://github.com/jxxghp/MoviePilot-Resources)
|
||||
- [MoviePilot-Plugins](https://github.com/jxxghp/MoviePilot-Plugins)
|
||||
- [MoviePilot-Server](https://github.com/jxxghp/MoviePilot-Server)
|
||||
- [MoviePilot-Rust](https://github.com/jxxghp/MoviePilot-Rust)
|
||||
- [MoviePilot-Wiki](https://github.com/jxxghp/MoviePilot-Wiki)
|
||||
|
||||
## Disclaimer
|
||||
|
||||
@@ -4,10 +4,11 @@ import re
|
||||
import traceback
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware import (
|
||||
SummarizationMiddleware,
|
||||
@@ -35,6 +36,12 @@ from app.agent.middleware.memory import MemoryMiddleware
|
||||
from app.agent.middleware.patch_tool_calls import PatchToolCallsMiddleware
|
||||
from app.agent.middleware.runtime_config import RuntimeConfigMiddleware
|
||||
from app.agent.middleware.skills import SkillsMiddleware
|
||||
from app.agent.middleware.subagents import (
|
||||
SUBAGENT_CONTROL_TOOL_NAME,
|
||||
SUBAGENT_TASK_TOOL_NAME,
|
||||
create_subagent_middlewares,
|
||||
is_subagent_stream_metadata,
|
||||
)
|
||||
from app.agent.middleware.tool_selection import ToolSelectorMiddleware
|
||||
from app.agent.middleware.usage import UsageMiddleware
|
||||
from app.agent.prompt import prompt_manager
|
||||
@@ -42,10 +49,11 @@ from app.agent.runtime import agent_runtime_manager
|
||||
from app.agent.tools.factory import MoviePilotToolFactory
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import settings
|
||||
from app.core.event import eventmanager
|
||||
from app.log import logger
|
||||
from app.schemas import Notification, NotificationType
|
||||
from app.schemas import AgentLLMProviderEventData, AgentTokensUsageEventData, Notification, NotificationType
|
||||
from app.schemas.message import ChannelCapabilityManager, ChannelCapability
|
||||
from app.schemas.types import MessageChannel
|
||||
from app.schemas.types import ChainEventType, EventType, MessageChannel
|
||||
from app.utils.identity import SYSTEM_INTERNAL_USER_ID
|
||||
|
||||
|
||||
@@ -53,6 +61,54 @@ class AgentChain(ChainBase):
|
||||
pass
|
||||
|
||||
|
||||
def _finish_processing_status(status: Optional[dict], user_id: Optional[str] = None) -> None:
|
||||
"""结束入站消息的渠道处理状态。"""
|
||||
if not status:
|
||||
return
|
||||
AgentChain().finish_message_processing_status(
|
||||
status=status,
|
||||
userid=user_id,
|
||||
)
|
||||
|
||||
|
||||
async def _async_start_processing_status(task: "_MessageTask") -> Optional[dict]:
|
||||
"""
|
||||
在 Agent worker 中启动渠道处理状态。
|
||||
渠道启动可能触发外部 API,同步实现需切到线程池避免阻塞事件循环。
|
||||
"""
|
||||
if not task.channel:
|
||||
return None
|
||||
|
||||
def _start() -> Optional[dict]:
|
||||
"""在线程池中通过统一 Chain 接口启动处理状态。"""
|
||||
try:
|
||||
return AgentChain().start_message_processing_status(
|
||||
channel=MessageChannel(task.channel),
|
||||
source=task.source,
|
||||
userid=task.user_id,
|
||||
message_id=task.original_message_id,
|
||||
chat_id=task.original_chat_id,
|
||||
text=task.message,
|
||||
)
|
||||
except Exception as err:
|
||||
logger.debug(f"启动Agent消息处理状态失败: {err}")
|
||||
return None
|
||||
|
||||
return await run_in_threadpool(_start)
|
||||
|
||||
|
||||
async def _async_finish_processing_status(
|
||||
status: Optional[dict], user_id: Optional[str] = None
|
||||
) -> None:
|
||||
"""
|
||||
在 Agent worker 中结束渠道处理状态。
|
||||
渠道收口可能触发外部 API,同步实现需切到线程池避免阻塞事件循环。
|
||||
"""
|
||||
if not status:
|
||||
return
|
||||
await run_in_threadpool(_finish_processing_status, status, user_id)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _SessionUsageSnapshot:
|
||||
model: Optional[str] = None
|
||||
@@ -169,6 +225,9 @@ class ReplyMode(str, Enum):
|
||||
|
||||
|
||||
HEARTBEAT_SESSION_PREFIX = "__agent_heartbeat_"
|
||||
UNSUPPORTED_IMAGE_INPUT_MESSAGE = "当前模型不支持图片输入,请更换支持图片输入的模型,或在系统设置中关闭图片输入支持后重试。"
|
||||
AGENT_EXECUTION_ERROR_PREFIX = "智能助手执行失败"
|
||||
AGENT_EXECUTION_ERROR_MESSAGE = "智能助手执行失败,请稍后重试。"
|
||||
|
||||
|
||||
class MoviePilotAgent:
|
||||
@@ -204,6 +263,9 @@ class MoviePilotAgent:
|
||||
self._tool_context: Dict[str, object] = {}
|
||||
self._streamed_output = ""
|
||||
self._session_usage = _SessionUsageSnapshot()
|
||||
self._llm_runtime_config: Optional[Dict[str, Any]] = None
|
||||
self._llm_provider_selection: Dict[str, Any] = {}
|
||||
self._agent_started_at: Optional[datetime] = None
|
||||
|
||||
# 流式token管理
|
||||
self.stream_handler = StreamingHandler()
|
||||
@@ -289,6 +351,40 @@ class MoviePilotAgent:
|
||||
)
|
||||
return self._session_usage.to_dict(self.session_id)
|
||||
|
||||
def _send_agent_tokens_usage_event(
|
||||
self,
|
||||
*,
|
||||
success: bool,
|
||||
error: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
广播本次 Agent 执行的 token 聚合用量,供配额类插件异步记录。
|
||||
"""
|
||||
try:
|
||||
selection = self._llm_provider_selection or {}
|
||||
event_data = AgentTokensUsageEventData(
|
||||
session_id=self.session_id,
|
||||
selected_provider_id=selection.get("selected_provider_id"),
|
||||
selected_provider_name=selection.get("selected_provider_name"),
|
||||
provider=selection.get("provider") or settings.LLM_PROVIDER,
|
||||
base_url=selection.get("base_url") or settings.LLM_BASE_URL,
|
||||
model=self._session_usage.model or selection.get("model") or settings.LLM_MODEL,
|
||||
input_tokens=self._session_usage.total_input_tokens,
|
||||
output_tokens=self._session_usage.total_output_tokens,
|
||||
total_tokens=self._session_usage.total_tokens,
|
||||
model_call_count=self._session_usage.model_call_count,
|
||||
success=success,
|
||||
error=error,
|
||||
started_at=self._agent_started_at.strftime("%Y-%m-%d %H:%M:%S")
|
||||
if self._agent_started_at
|
||||
else None,
|
||||
finished_at=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
source=selection.get("source") or "agent",
|
||||
)
|
||||
eventmanager.send_event(EventType.AgentTokensUsage, event_data)
|
||||
except Exception as err:
|
||||
logger.debug(f"广播 Agent Tokens 用量事件失败: {err}")
|
||||
|
||||
@property
|
||||
def is_background(self) -> bool:
|
||||
"""
|
||||
@@ -336,12 +432,124 @@ class MoviePilotAgent:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def _initialize_llm(streaming: bool = False):
|
||||
def _get_event_value(event_data: Any, key: str, default: Any = None) -> Any:
|
||||
"""
|
||||
从链式事件数据中兼容读取 Pydantic 模型或普通字典字段。
|
||||
"""
|
||||
if isinstance(event_data, dict):
|
||||
return event_data.get(key, default)
|
||||
return getattr(event_data, key, default)
|
||||
|
||||
@staticmethod
|
||||
def _set_event_value(event_data: Any, key: str, value: Any) -> None:
|
||||
"""
|
||||
向链式事件数据中兼容写入 Pydantic 模型或普通字典字段。
|
||||
"""
|
||||
if isinstance(event_data, dict):
|
||||
event_data[key] = value
|
||||
else:
|
||||
setattr(event_data, key, value)
|
||||
|
||||
@classmethod
|
||||
def _clean_optional_text(cls, value: Any) -> Optional[str]:
|
||||
"""
|
||||
标准化事件返回的可选文本字段,空字符串按未返回处理。
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
text = str(value).strip()
|
||||
return text or None
|
||||
|
||||
async def _resolve_llm_runtime_config(self) -> Dict[str, Any]:
|
||||
"""
|
||||
通过链式事件解析本次 Agent 可用的 LLM 运行时配置。
|
||||
|
||||
若没有插件返回 selected_provider_id,则沿用系统配置,保持既有行为。
|
||||
"""
|
||||
if self._llm_runtime_config is not None:
|
||||
return self._llm_runtime_config
|
||||
|
||||
event_data = AgentLLMProviderEventData(
|
||||
provider=settings.LLM_PROVIDER,
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=settings.LLM_API_KEY,
|
||||
base_url=settings.LLM_BASE_URL,
|
||||
base_url_preset=settings.LLM_BASE_URL_PRESET,
|
||||
user_agent=settings.LLM_USER_AGENT,
|
||||
use_proxy=settings.LLM_USE_PROXY,
|
||||
thinking_level=None,
|
||||
)
|
||||
selected_event = await eventmanager.async_send_event(
|
||||
ChainEventType.AgentLLMProvider,
|
||||
event_data,
|
||||
)
|
||||
resolved_data = selected_event.event_data if selected_event else event_data
|
||||
|
||||
provider = (
|
||||
self._clean_optional_text(self._get_event_value(resolved_data, "provider"))
|
||||
or settings.LLM_PROVIDER
|
||||
)
|
||||
model = (
|
||||
self._clean_optional_text(self._get_event_value(resolved_data, "model"))
|
||||
or settings.LLM_MODEL
|
||||
)
|
||||
api_key = (
|
||||
self._clean_optional_text(self._get_event_value(resolved_data, "api_key"))
|
||||
or settings.LLM_API_KEY
|
||||
)
|
||||
base_url = (
|
||||
self._clean_optional_text(self._get_event_value(resolved_data, "base_url"))
|
||||
or settings.LLM_BASE_URL
|
||||
)
|
||||
base_url_preset = (
|
||||
self._clean_optional_text(self._get_event_value(resolved_data, "base_url_preset"))
|
||||
or settings.LLM_BASE_URL_PRESET
|
||||
)
|
||||
user_agent = (
|
||||
self._clean_optional_text(self._get_event_value(resolved_data, "user_agent"))
|
||||
or settings.LLM_USER_AGENT
|
||||
)
|
||||
use_proxy = self._get_event_value(resolved_data, "use_proxy")
|
||||
if use_proxy is None:
|
||||
use_proxy = settings.LLM_USE_PROXY
|
||||
thinking_level = self._clean_optional_text(
|
||||
self._get_event_value(resolved_data, "thinking_level")
|
||||
)
|
||||
selected_provider_id = self._clean_optional_text(
|
||||
self._get_event_value(resolved_data, "selected_provider_id")
|
||||
)
|
||||
selected_provider_name = self._clean_optional_text(
|
||||
self._get_event_value(resolved_data, "selected_provider_name")
|
||||
)
|
||||
source = self._clean_optional_text(self._get_event_value(resolved_data, "source"))
|
||||
|
||||
self._llm_provider_selection = {
|
||||
"selected_provider_id": selected_provider_id,
|
||||
"selected_provider_name": selected_provider_name,
|
||||
"provider": provider,
|
||||
"base_url": base_url,
|
||||
"model": model,
|
||||
"source": source,
|
||||
}
|
||||
self._llm_runtime_config = {
|
||||
"provider": provider,
|
||||
"model": model,
|
||||
"api_key": api_key,
|
||||
"base_url": base_url,
|
||||
"base_url_preset": base_url_preset,
|
||||
"user_agent": user_agent,
|
||||
"use_proxy": bool(use_proxy),
|
||||
"thinking_level": thinking_level,
|
||||
}
|
||||
return self._llm_runtime_config
|
||||
|
||||
async def _initialize_llm(self, streaming: bool = False):
|
||||
"""
|
||||
初始化 LLM
|
||||
:param streaming: 是否启用流式输出
|
||||
"""
|
||||
return await LLMHelper.get_llm(streaming=streaming)
|
||||
runtime_config = await self._resolve_llm_runtime_config()
|
||||
return await LLMHelper.get_llm(streaming=streaming, **runtime_config)
|
||||
|
||||
@staticmethod
|
||||
def _extract_text_content(content) -> str:
|
||||
@@ -376,6 +584,165 @@ class MoviePilotAgent:
|
||||
return "".join(text_parts)
|
||||
return str(content)
|
||||
|
||||
@classmethod
|
||||
def _has_image_input_content(cls, content: Any) -> bool:
|
||||
"""
|
||||
检查消息内容里是否包含真正会发给模型的图片块。
|
||||
结构化 JSON 文本里的 images 字段只是给 Agent 阅读的说明,不能作为图片输入判断。
|
||||
"""
|
||||
if isinstance(content, list):
|
||||
return any(cls._has_image_input_content(item) for item in content)
|
||||
if not isinstance(content, dict):
|
||||
return False
|
||||
|
||||
block_type = str(content.get("type") or "").lower()
|
||||
if block_type in {"image", "image_url", "input_image"}:
|
||||
return True
|
||||
if content.get("image_url") or content.get("image"):
|
||||
return True
|
||||
return any(cls._has_image_input_content(value) for value in content.values())
|
||||
|
||||
@classmethod
|
||||
def _messages_have_image_input(cls, messages: List[BaseMessage]) -> bool:
|
||||
"""检查本轮提交给模型的消息列表中是否包含图片输入。"""
|
||||
return any(
|
||||
cls._has_image_input_content(getattr(message, "content", None))
|
||||
for message in messages or []
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _exception_detail_text(error: Exception) -> str:
|
||||
"""
|
||||
提取异常对象里可用于匹配的文本。
|
||||
OpenAI 兼容端点的错误详情可能藏在 body/code/status_code 等属性中。
|
||||
"""
|
||||
parts = [str(error)]
|
||||
for attr in ("message", "code", "status_code"):
|
||||
value = getattr(error, attr, None)
|
||||
if value is not None:
|
||||
parts.append(str(value))
|
||||
body = getattr(error, "body", None)
|
||||
if body is not None:
|
||||
try:
|
||||
parts.append(json.dumps(body, ensure_ascii=False))
|
||||
except (TypeError, ValueError):
|
||||
parts.append(str(body))
|
||||
return " ".join(part for part in parts if part)
|
||||
|
||||
@classmethod
|
||||
def _is_unsupported_image_input_error(cls, error: Exception) -> bool:
|
||||
"""
|
||||
判断模型服务是否在拒绝图片输入。
|
||||
兼容 OpenAI 及 OpenAI-compatible 服务常见的错误文案,避免把普通 404 当作图片能力问题。
|
||||
"""
|
||||
detail = cls._exception_detail_text(error).lower()
|
||||
if "no endpoints found that support image input" in detail:
|
||||
return True
|
||||
if "unknown variant" in detail and "image_url" in detail:
|
||||
return True
|
||||
if "image input" not in detail and "images" not in detail:
|
||||
return False
|
||||
return any(
|
||||
marker in detail
|
||||
for marker in (
|
||||
"does not support",
|
||||
"do not support",
|
||||
"not support",
|
||||
"not supported",
|
||||
"unsupported",
|
||||
"no endpoint",
|
||||
"no endpoints",
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _payload_error_message(payload: Any) -> str:
|
||||
"""
|
||||
从 SDK 返回的结构化错误体里提取 message 字段。
|
||||
许多 OpenAI-compatible 服务会把真正原因放在 body.error.message 中。
|
||||
"""
|
||||
if isinstance(payload, dict):
|
||||
error = payload.get("error")
|
||||
if isinstance(error, dict) and error.get("message"):
|
||||
return str(error["message"])
|
||||
for key in ("message", "detail", "error_description"):
|
||||
if payload.get(key):
|
||||
return str(payload[key])
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_execution_error_message(message: str) -> str:
|
||||
"""
|
||||
清理执行错误中的密钥和尾部长说明,避免把敏感字段或 SDK 调参文档直接发给用户。
|
||||
"""
|
||||
sanitized = re.sub(r"\s+", " ", str(message or "")).strip()
|
||||
if settings.LLM_API_KEY:
|
||||
sanitized = sanitized.replace(settings.LLM_API_KEY, "***")
|
||||
sanitized = re.sub(
|
||||
r"(?i)(api[_-]?key\s*[:=]\s*)([^\s,;]+)",
|
||||
r"\1***",
|
||||
sanitized,
|
||||
)
|
||||
sanitized = re.sub(
|
||||
r"(?i)authorization\s*:\s*bearer\s+[^\s,;]+",
|
||||
"Authorization: ***",
|
||||
sanitized,
|
||||
)
|
||||
for marker in (
|
||||
" Tune or disable via ",
|
||||
" See also ",
|
||||
" Traceback ",
|
||||
" - Traceback ",
|
||||
):
|
||||
if marker in sanitized:
|
||||
sanitized = sanitized.split(marker, 1)[0].strip()
|
||||
return sanitized
|
||||
|
||||
@classmethod
|
||||
def _primary_exception_message(cls, error: Exception) -> str:
|
||||
"""
|
||||
从异常对象中抽取最主要的错误消息。
|
||||
优先使用结构化 message,其次回退到异常字符串,保持用户回复直接反映真实失败原因。
|
||||
"""
|
||||
candidates = [
|
||||
getattr(error, "message", None),
|
||||
cls._payload_error_message(getattr(error, "body", None)),
|
||||
str(error),
|
||||
]
|
||||
for candidate in candidates:
|
||||
message = cls._sanitize_execution_error_message(candidate)
|
||||
if message:
|
||||
return message
|
||||
return ""
|
||||
|
||||
@classmethod
|
||||
def _friendly_execution_error_message(cls, error: Exception) -> str:
|
||||
"""
|
||||
将 Agent 执行异常转换为用户可读消息。
|
||||
回复只携带主错误信息,完整 traceback 保留在日志中排查。
|
||||
"""
|
||||
message = cls._primary_exception_message(error)
|
||||
if not message:
|
||||
return AGENT_EXECUTION_ERROR_MESSAGE
|
||||
return f"{AGENT_EXECUTION_ERROR_PREFIX}: {message}"
|
||||
|
||||
async def _dispatch_execution_notice(self, message: str) -> None:
|
||||
"""
|
||||
将执行层可预期的失败转成用户可读提示。
|
||||
按当前回复模式处理,避免后台捕获任务绕过 CAPTURE_ONLY 约束。
|
||||
"""
|
||||
if not message:
|
||||
return
|
||||
self._emit_output(message)
|
||||
if self._tool_context.get("user_reply_sent"):
|
||||
return
|
||||
|
||||
title = "MoviePilot助手" if self.is_background else ""
|
||||
if self.should_dispatch_reply:
|
||||
await self.send_agent_message(message, title=title)
|
||||
elif self.persist_output_message:
|
||||
await self._save_agent_message_to_db(message, title=title)
|
||||
|
||||
def _emit_output(self, text: str):
|
||||
"""
|
||||
输出当前流式文本到外部回调。
|
||||
@@ -413,6 +780,25 @@ class MoviePilotAgent:
|
||||
allow_message_tools=self.allow_message_tools,
|
||||
)
|
||||
|
||||
def _initialize_subagent_tools(self) -> List:
|
||||
"""
|
||||
初始化子代理专用静默工具列表。
|
||||
"""
|
||||
return MoviePilotToolFactory.create_tools(
|
||||
session_id=self.session_id,
|
||||
user_id=self.user_id,
|
||||
channel=self.channel,
|
||||
source=self.source,
|
||||
username=self.username,
|
||||
stream_handler=None,
|
||||
agent_context={
|
||||
"user_reply_sent": False,
|
||||
"reply_mode": None,
|
||||
"should_dispatch_reply": False,
|
||||
},
|
||||
allow_message_tools=False,
|
||||
)
|
||||
|
||||
async def _create_agent(self, streaming: bool = False):
|
||||
"""
|
||||
创建 LangGraph Agent(使用 create_agent + SummarizationMiddleware)
|
||||
@@ -435,10 +821,22 @@ class MoviePilotAgent:
|
||||
|
||||
# 工具列表
|
||||
tools = self._initialize_tools()
|
||||
subagent_middlewares, subagent_task_tools = create_subagent_middlewares(
|
||||
model=non_streaming_model,
|
||||
tools=self._initialize_subagent_tools(),
|
||||
stream_handler=self.stream_handler,
|
||||
)
|
||||
max_tools = settings.LLM_MAX_TOOLS
|
||||
always_include_tools = (
|
||||
MoviePilotToolFactory.get_tool_selector_always_include_names(tools)
|
||||
)
|
||||
if subagent_task_tools:
|
||||
always_include_tools.extend(
|
||||
tool.name
|
||||
for tool in subagent_task_tools
|
||||
if getattr(tool, "name", None)
|
||||
in {SUBAGENT_TASK_TOOL_NAME, SUBAGENT_CONTROL_TOOL_NAME}
|
||||
)
|
||||
|
||||
# 中间件
|
||||
middlewares = [
|
||||
@@ -461,6 +859,8 @@ class MoviePilotAgent:
|
||||
),
|
||||
# 错误工具调用修复
|
||||
PatchToolCallsMiddleware(),
|
||||
# 子代理委派
|
||||
*subagent_middlewares,
|
||||
# 用量统计
|
||||
UsageMiddleware(on_usage=self._record_usage),
|
||||
]
|
||||
@@ -478,7 +878,7 @@ class MoviePilotAgent:
|
||||
middlewares.append(
|
||||
ToolSelectorMiddleware(
|
||||
model=non_streaming_model,
|
||||
selection_tools=tools,
|
||||
selection_tools=[*tools, *subagent_task_tools],
|
||||
max_tools=max_tools,
|
||||
always_include=always_include_tools,
|
||||
)
|
||||
@@ -500,6 +900,7 @@ class MoviePilotAgent:
|
||||
message: str,
|
||||
images: List[str] = None,
|
||||
files: Optional[List[dict]] = None,
|
||||
has_audio_input: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
处理用户消息,流式推理并返回 Agent 回复
|
||||
@@ -507,7 +908,8 @@ class MoviePilotAgent:
|
||||
try:
|
||||
logger.info(
|
||||
f"Agent推理: session_id={self.session_id}, input={message}, "
|
||||
f"images={len(images) if images else 0}, files={len(files) if files else 0}"
|
||||
f"images={len(images) if images else 0}, files={len(files) if files else 0}, "
|
||||
f"audio_input={has_audio_input}"
|
||||
)
|
||||
self._tool_context = {
|
||||
"user_reply_sent": False,
|
||||
@@ -524,6 +926,10 @@ class MoviePilotAgent:
|
||||
# 构建结构化用户消息内容
|
||||
request_payload = {
|
||||
"message": message or "",
|
||||
"input": {
|
||||
"mode": "voice" if has_audio_input else "text",
|
||||
"transcribed": bool(has_audio_input),
|
||||
},
|
||||
"images": [
|
||||
{"index": index + 1, "type": "image"}
|
||||
for index, _ in enumerate(images or [])
|
||||
@@ -541,7 +947,10 @@ class MoviePilotAgent:
|
||||
messages.append(HumanMessage(content=content))
|
||||
|
||||
# 执行推理
|
||||
await self._execute_agent(messages)
|
||||
result = await self._execute_agent(messages)
|
||||
if isinstance(result, tuple) and result:
|
||||
return result[0]
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"处理消息时发生错误: {str(e)}"
|
||||
@@ -572,6 +981,8 @@ class MoviePilotAgent:
|
||||
):
|
||||
if chunk["type"] == "messages":
|
||||
token, metadata = chunk["data"]
|
||||
if is_subagent_stream_metadata(metadata):
|
||||
continue
|
||||
if not token or not hasattr(token, "tool_call_chunks"):
|
||||
continue
|
||||
|
||||
@@ -603,6 +1014,11 @@ class MoviePilotAgent:
|
||||
- 渠道不支持消息编辑:非流式 LLM + ainvoke,完成后发送最终回复
|
||||
- 渠道支持消息编辑:流式 LLM + astream,实时推送 token
|
||||
"""
|
||||
execution_success = False
|
||||
execution_error: Optional[str] = None
|
||||
self._agent_started_at = datetime.now()
|
||||
self._llm_runtime_config = None
|
||||
self._llm_provider_selection = {}
|
||||
try:
|
||||
# Agent运行配置
|
||||
agent_config = {
|
||||
@@ -736,14 +1152,29 @@ class MoviePilotAgent:
|
||||
user_id=self.user_id,
|
||||
messages=agent.get_state(agent_config).values.get("messages", []),
|
||||
)
|
||||
execution_success = True
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"Agent执行被取消: session_id={self.session_id}")
|
||||
execution_error = "任务已取消"
|
||||
return "任务已取消", {}
|
||||
except Exception as e:
|
||||
execution_error = str(e)
|
||||
if self._messages_have_image_input(messages) and self._is_unsupported_image_input_error(e):
|
||||
logger.warning(
|
||||
f"当前模型不支持图片输入,已向用户发送友好提示: {e}"
|
||||
)
|
||||
await self._dispatch_execution_notice(UNSUPPORTED_IMAGE_INPUT_MESSAGE)
|
||||
return UNSUPPORTED_IMAGE_INPUT_MESSAGE, {}
|
||||
logger.error(f"Agent执行失败: {e} - {traceback.format_exc()}")
|
||||
return str(e), {}
|
||||
friendly_message = self._friendly_execution_error_message(e)
|
||||
await self._dispatch_execution_notice(friendly_message)
|
||||
return friendly_message, {}
|
||||
finally:
|
||||
self._send_agent_tokens_usage_event(
|
||||
success=execution_success,
|
||||
error=execution_error,
|
||||
)
|
||||
# 确保停止流式输出
|
||||
await self.stream_handler.stop_streaming()
|
||||
|
||||
@@ -803,12 +1234,16 @@ class _MessageTask:
|
||||
message: str
|
||||
images: Optional[List[str]] = None
|
||||
files: Optional[List[dict]] = None
|
||||
has_audio_input: bool = False
|
||||
channel: Optional[str] = None
|
||||
source: Optional[str] = None
|
||||
username: Optional[str] = None
|
||||
original_message_id: Optional[str] = None
|
||||
original_chat_id: Optional[str] = None
|
||||
processing_status: Optional[dict] = None
|
||||
reply_mode: ReplyMode = ReplyMode.DISPATCH
|
||||
persist_output_message: bool = True
|
||||
allow_message_tools: bool = True
|
||||
|
||||
|
||||
class AgentManager:
|
||||
@@ -823,6 +1258,11 @@ class AgentManager:
|
||||
self._session_queues: Dict[str, asyncio.Queue] = {}
|
||||
# 每个会话的worker任务
|
||||
self._session_workers: Dict[str, asyncio.Task] = {}
|
||||
# 每个会话最后活动时间,用于回收空闲 Agent 实例
|
||||
self._session_last_used: Dict[str, tuple[str, datetime]] = {}
|
||||
self._idle_cleanup_task: Optional[asyncio.Task] = None
|
||||
self._idle_session_ttl = timedelta(hours=24)
|
||||
self._idle_cleanup_interval = 60 * 60
|
||||
|
||||
def get_session_status(self, session_id: str) -> dict[str, Any]:
|
||||
"""获取会话当前模型与 token 使用状态。"""
|
||||
@@ -855,33 +1295,85 @@ class AgentManager:
|
||||
)
|
||||
return status
|
||||
|
||||
@staticmethod
|
||||
async def initialize():
|
||||
async def initialize(self):
|
||||
"""
|
||||
初始化管理器
|
||||
"""
|
||||
memory_manager.initialize()
|
||||
if self._idle_cleanup_task and not self._idle_cleanup_task.done():
|
||||
return
|
||||
self._idle_cleanup_task = asyncio.create_task(self._cleanup_idle_sessions())
|
||||
|
||||
async def close(self):
|
||||
"""
|
||||
关闭管理器
|
||||
"""
|
||||
if self._idle_cleanup_task:
|
||||
self._idle_cleanup_task.cancel()
|
||||
try:
|
||||
await self._idle_cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._idle_cleanup_task = None
|
||||
await memory_manager.close()
|
||||
# 取消所有会话worker
|
||||
for task in self._session_workers.values():
|
||||
for task in list(self._session_workers.values()):
|
||||
task.cancel()
|
||||
# 等待所有worker结束
|
||||
for session_id, task in self._session_workers.items():
|
||||
for session_id, task in list(self._session_workers.items()):
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._session_workers.clear()
|
||||
self._session_queues.clear()
|
||||
for agent in self.active_agents.values():
|
||||
self._session_last_used.clear()
|
||||
for agent in list(self.active_agents.values()):
|
||||
await agent.cleanup()
|
||||
self.active_agents.clear()
|
||||
|
||||
def _record_session_activity(self, session_id: str, user_id: str) -> None:
|
||||
"""
|
||||
记录会话最近活动时间,供空闲会话清理任务判断是否可释放资源。
|
||||
"""
|
||||
self._session_last_used[session_id] = (user_id, datetime.now())
|
||||
|
||||
def _is_session_busy(self, session_id: str) -> bool:
|
||||
"""
|
||||
判断会话是否仍有正在执行的 worker 或待处理消息,避免误清理活跃会话。
|
||||
"""
|
||||
worker = self._session_workers.get(session_id)
|
||||
if worker and not worker.done():
|
||||
return True
|
||||
queue = self._session_queues.get(session_id)
|
||||
return bool(queue and not queue.empty())
|
||||
|
||||
def _expired_idle_sessions(self) -> list[tuple[str, str]]:
|
||||
"""
|
||||
收集已经超过空闲时间且当前不忙的会话。
|
||||
"""
|
||||
expire_before = datetime.now() - self._idle_session_ttl
|
||||
expired = []
|
||||
for session_id, (user_id, last_used) in list(self._session_last_used.items()):
|
||||
if last_used < expire_before and not self._is_session_busy(session_id):
|
||||
expired.append((session_id, user_id))
|
||||
return expired
|
||||
|
||||
async def _cleanup_idle_sessions(self) -> None:
|
||||
"""
|
||||
周期性清理长时间没有新消息的 Agent 会话,避免长期运行后实例持续累积。
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(self._idle_cleanup_interval)
|
||||
for session_id, user_id in self._expired_idle_sessions():
|
||||
await self.clear_session(session_id=session_id, user_id=user_id)
|
||||
logger.info(f"已清理空闲Agent会话: session_id={session_id}")
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"清理空闲Agent会话失败: {e}")
|
||||
|
||||
async def process_message(
|
||||
self,
|
||||
session_id: str,
|
||||
@@ -889,12 +1381,15 @@ class AgentManager:
|
||||
message: str,
|
||||
images: List[str] = None,
|
||||
files: Optional[List[dict]] = None,
|
||||
has_audio_input: bool = False,
|
||||
channel: str = None,
|
||||
source: str = None,
|
||||
username: str = None,
|
||||
original_message_id: Optional[str] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
reply_mode: ReplyMode = ReplyMode.DISPATCH,
|
||||
persist_output_message: bool = True,
|
||||
allow_message_tools: bool = True,
|
||||
) -> str:
|
||||
"""
|
||||
处理用户消息:将消息放入会话队列,按顺序依次处理。
|
||||
@@ -906,13 +1401,17 @@ class AgentManager:
|
||||
message=message,
|
||||
images=images,
|
||||
files=files,
|
||||
has_audio_input=has_audio_input,
|
||||
channel=channel,
|
||||
source=source,
|
||||
username=username,
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id,
|
||||
reply_mode=reply_mode,
|
||||
persist_output_message=persist_output_message,
|
||||
allow_message_tools=allow_message_tools,
|
||||
)
|
||||
self._record_session_activity(session_id, user_id)
|
||||
|
||||
# 获取或创建会话队列
|
||||
if session_id not in self._session_queues:
|
||||
@@ -965,10 +1464,12 @@ class AgentManager:
|
||||
break
|
||||
|
||||
try:
|
||||
await self._start_task_processing_status(task)
|
||||
await self._process_message_internal(task)
|
||||
except Exception as e:
|
||||
logger.error(f"处理会话 {session_id} 的消息失败: {e}")
|
||||
finally:
|
||||
await self._finish_task_processing_status(task)
|
||||
queue.task_done()
|
||||
|
||||
except asyncio.CancelledError:
|
||||
@@ -983,6 +1484,23 @@ class AgentManager:
|
||||
):
|
||||
self._session_queues.pop(session_id, None)
|
||||
|
||||
@staticmethod
|
||||
async def _start_task_processing_status(task: _MessageTask) -> None:
|
||||
"""
|
||||
在 Agent worker 真正开始处理消息时启动渠道处理状态。
|
||||
"""
|
||||
if task.processing_status:
|
||||
return
|
||||
task.processing_status = await _async_start_processing_status(task)
|
||||
|
||||
@staticmethod
|
||||
async def _finish_task_processing_status(task: _MessageTask) -> None:
|
||||
"""
|
||||
在 Agent worker 完成或异常后结束本条消息的渠道处理状态。
|
||||
"""
|
||||
await _async_finish_processing_status(task.processing_status, task.user_id)
|
||||
task.processing_status = None
|
||||
|
||||
async def _process_message_internal(self, task: _MessageTask):
|
||||
"""
|
||||
实际处理单条消息
|
||||
@@ -1001,6 +1519,8 @@ class AgentManager:
|
||||
original_message_id=task.original_message_id,
|
||||
original_chat_id=task.original_chat_id,
|
||||
replay_mode=task.reply_mode,
|
||||
persist_output_message=task.persist_output_message,
|
||||
allow_message_tools=task.allow_message_tools,
|
||||
)
|
||||
self.active_agents[session_id] = agent
|
||||
else:
|
||||
@@ -1015,8 +1535,16 @@ class AgentManager:
|
||||
agent.original_message_id = task.original_message_id
|
||||
agent.original_chat_id = task.original_chat_id
|
||||
agent.reply_mode = task.reply_mode
|
||||
agent.persist_output_message = task.persist_output_message
|
||||
agent.allow_message_tools = task.allow_message_tools
|
||||
|
||||
return await agent.process(task.message, images=task.images, files=task.files)
|
||||
process_kwargs = {
|
||||
"images": task.images,
|
||||
"files": task.files,
|
||||
}
|
||||
if task.has_audio_input:
|
||||
process_kwargs["has_audio_input"] = True
|
||||
return await agent.process(task.message, **process_kwargs)
|
||||
|
||||
async def stop_current_task(self, session_id: str):
|
||||
"""
|
||||
@@ -1059,6 +1587,7 @@ class AgentManager:
|
||||
"""
|
||||
清空会话
|
||||
"""
|
||||
self._session_last_used.pop(session_id, None)
|
||||
# 取消该会话的worker
|
||||
if session_id in self._session_workers:
|
||||
self._session_workers[session_id].cancel()
|
||||
@@ -1066,7 +1595,7 @@ class AgentManager:
|
||||
await self._session_workers[session_id]
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
await self._session_workers.pop(session_id, None)
|
||||
self._session_workers.pop(session_id, None) # noqa
|
||||
|
||||
# 清理队列
|
||||
self._session_queues.pop(session_id, None)
|
||||
@@ -1151,7 +1680,9 @@ class AgentManager:
|
||||
channel=None,
|
||||
source=None,
|
||||
username=settings.SUPERUSER,
|
||||
reply_mode=ReplyMode.DISPATCH,
|
||||
reply_mode=ReplyMode.CAPTURE_ONLY,
|
||||
persist_output_message=False,
|
||||
allow_message_tools=True,
|
||||
)
|
||||
|
||||
# 等待消息队列处理完成
|
||||
|
||||
@@ -293,6 +293,8 @@ class StreamingHandler:
|
||||
tool_message = (tool_message or "").strip()
|
||||
tool_message_lower = tool_message.lower()
|
||||
|
||||
if tool_name == "task":
|
||||
return "subagent", tool_kwargs.get("subagent_type")
|
||||
if tool_name == "read_file":
|
||||
return "file_read", tool_kwargs.get("file_path")
|
||||
if tool_name in {"write_file", "edit_file"}:
|
||||
@@ -307,7 +309,10 @@ class StreamingHandler:
|
||||
or tool_kwargs.get("path"),
|
||||
)
|
||||
if tool_name == "execute_command":
|
||||
return "command", tool_kwargs.get("command")
|
||||
return (
|
||||
"command",
|
||||
tool_kwargs.get("command") or tool_kwargs.get("session_id"),
|
||||
)
|
||||
if tool_name == "ask_user_choice":
|
||||
return "interaction", tool_kwargs.get("message")
|
||||
if tool_name.startswith("search_") or tool_name in {"get_search_results"}:
|
||||
@@ -405,6 +410,8 @@ class StreamingHandler:
|
||||
return f"执行了 {count} 次操作"
|
||||
if category == "interaction":
|
||||
return f"发起了 {count} 次交互"
|
||||
if category == "subagent":
|
||||
return f"已调用 {count} 个子代理"
|
||||
return f"调用了 {count} 次工具"
|
||||
|
||||
def _can_stream(self) -> bool:
|
||||
|
||||
@@ -1,6 +1,14 @@
|
||||
"""Agent 内部使用的 LLM 适配层。"""
|
||||
|
||||
from app.agent.llm.helper import LLMHelper, LLMTestError, LLMTestTimeout
|
||||
from app.agent.llm.capability import (
|
||||
AgentCapabilityManager,
|
||||
AgentCapabilityProvider,
|
||||
AudioCapabilityProvider,
|
||||
MiMoAudioProvider,
|
||||
OpenAIChatAudioProvider,
|
||||
OpenAIAudioProvider,
|
||||
)
|
||||
from app.agent.llm.provider import (
|
||||
LLMProviderAuthError,
|
||||
LLMProviderError,
|
||||
@@ -10,10 +18,16 @@ from app.agent.llm.provider import (
|
||||
|
||||
__all__ = [
|
||||
"LLMHelper",
|
||||
"AgentCapabilityManager",
|
||||
"AgentCapabilityProvider",
|
||||
"AudioCapabilityProvider",
|
||||
"LLMProviderAuthError",
|
||||
"LLMProviderError",
|
||||
"LLMProviderManager",
|
||||
"LLMTestError",
|
||||
"LLMTestTimeout",
|
||||
"MiMoAudioProvider",
|
||||
"OpenAIChatAudioProvider",
|
||||
"OpenAIAudioProvider",
|
||||
"render_auth_result_html",
|
||||
]
|
||||
|
||||
827
app/agent/llm/capability.py
Normal file
827
app/agent/llm/capability.py
Normal file
@@ -0,0 +1,827 @@
|
||||
"""Agent 多模态能力 provider 与调度入口。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import mimetypes
|
||||
import shutil
|
||||
import subprocess
|
||||
from abc import ABC
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
from app.utils.http import RequestUtils
|
||||
|
||||
|
||||
class AgentCapabilityProvider(ABC):
|
||||
"""Agent 能力 provider 基类,后续图片等能力可继续扩展到这里。"""
|
||||
|
||||
name: str
|
||||
|
||||
|
||||
class AudioCapabilityProvider(AgentCapabilityProvider):
|
||||
"""音频输入/输出能力 provider。"""
|
||||
|
||||
MAX_TRANSCRIBE_BYTES = 10 * 1024 * 1024
|
||||
|
||||
def is_available_for_audio_input(self) -> bool:
|
||||
"""是否可用于音频输入转写。"""
|
||||
return False
|
||||
|
||||
def is_available_for_audio_output(self) -> bool:
|
||||
"""是否可用于语音合成输出。"""
|
||||
return False
|
||||
|
||||
def transcribe_audio(self, content: bytes, filename: str = "input.ogg") -> Optional[str]:
|
||||
"""将音频字节转成文字。"""
|
||||
raise NotImplementedError
|
||||
|
||||
def synthesize_speech(self, text: str) -> Optional[Path]:
|
||||
"""将文字合成为可发送的音频文件。"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class OpenAIAudioProvider(AudioCapabilityProvider):
|
||||
"""OpenAI / OpenAI-compatible 音频 provider。"""
|
||||
|
||||
name = "openai"
|
||||
|
||||
@staticmethod
|
||||
def _build_client(api_key: str, base_url: Optional[str]):
|
||||
from openai import OpenAI
|
||||
|
||||
return OpenAI(api_key=api_key, base_url=base_url, max_retries=3)
|
||||
|
||||
@staticmethod
|
||||
def _input_credentials() -> tuple[Optional[str], Optional[str]]:
|
||||
return settings.AUDIO_INPUT_API_KEY, settings.AUDIO_INPUT_BASE_URL
|
||||
|
||||
@staticmethod
|
||||
def _output_credentials() -> tuple[Optional[str], Optional[str]]:
|
||||
return settings.AUDIO_OUTPUT_API_KEY, settings.AUDIO_OUTPUT_BASE_URL
|
||||
|
||||
def is_available_for_audio_input(self) -> bool:
|
||||
api_key, _ = self._input_credentials()
|
||||
return bool(api_key)
|
||||
|
||||
def is_available_for_audio_output(self) -> bool:
|
||||
api_key, _ = self._output_credentials()
|
||||
return bool(api_key)
|
||||
|
||||
def transcribe_audio(self, content: bytes, filename: str = "input.ogg") -> Optional[str]:
|
||||
if not content:
|
||||
return None
|
||||
if len(content) > self.MAX_TRANSCRIBE_BYTES:
|
||||
raise ValueError("语音文件超过 10MB,无法识别")
|
||||
|
||||
try:
|
||||
api_key, base_url = self._input_credentials()
|
||||
if not api_key:
|
||||
raise ValueError("音频输入 provider 未配置 API Key")
|
||||
client = self._build_client(api_key=api_key, base_url=base_url)
|
||||
audio_file = BytesIO(content)
|
||||
audio_file.name = filename
|
||||
response = client.audio.transcriptions.create(
|
||||
model=settings.AUDIO_INPUT_MODEL,
|
||||
file=audio_file,
|
||||
language=settings.AUDIO_INPUT_LANGUAGE or "zh",
|
||||
response_format="verbose_json",
|
||||
)
|
||||
text = getattr(response, "text", None)
|
||||
return text.strip() if text else None
|
||||
except Exception as err:
|
||||
logger.error(f"音频输入转写失败: provider={self.name}, error={err}")
|
||||
return None
|
||||
|
||||
def synthesize_speech(self, text: str) -> Optional[Path]:
|
||||
if not text:
|
||||
return None
|
||||
|
||||
try:
|
||||
api_key, base_url = self._output_credentials()
|
||||
if not api_key:
|
||||
raise ValueError("音频输出 provider 未配置 API Key")
|
||||
client = self._build_client(api_key=api_key, base_url=base_url)
|
||||
voice_dir = settings.TEMP_PATH / "voice"
|
||||
voice_dir.mkdir(parents=True, exist_ok=True)
|
||||
output_path = voice_dir / f"{uuid4().hex}.opus"
|
||||
response = client.audio.speech.create(
|
||||
model=settings.AUDIO_OUTPUT_MODEL,
|
||||
voice=settings.AUDIO_OUTPUT_VOICE,
|
||||
input=text,
|
||||
response_format="opus",
|
||||
)
|
||||
response.write_to_file(output_path)
|
||||
return output_path
|
||||
except Exception as err:
|
||||
logger.error(f"音频输出合成失败: provider={self.name}, error={err}")
|
||||
return None
|
||||
|
||||
|
||||
class OpenAIChatAudioProvider(AudioCapabilityProvider):
|
||||
"""通过 OpenAI Chat Completions 兼容接口传入/返回音频的 provider。"""
|
||||
|
||||
name = "openai_chat_audio"
|
||||
DISPLAY_NAME = "OpenAI Chat Audio"
|
||||
DEFAULT_BASE_URL: Optional[str] = None
|
||||
DEFAULT_STT_MODEL: Optional[str] = None
|
||||
DEFAULT_TTS_MODEL: Optional[str] = None
|
||||
DEFAULT_VOICE = "alloy"
|
||||
AUDIO_RESPONSE_FORMAT = "wav"
|
||||
AUDIO_INPUT_DATA_URL = False
|
||||
INCLUDE_AUDIO_MODALITIES = True
|
||||
TTS_MESSAGE_ROLE = "user"
|
||||
SUPPORTED_STT_MODELS: Optional[frozenset[str]] = None
|
||||
SUPPORTED_TTS_MODELS: Optional[frozenset[str]] = None
|
||||
UNSUPPORTED_TTS_MODELS = frozenset()
|
||||
SUPPORTED_AUDIO_MIME_TYPES = {
|
||||
".flac": "audio/flac",
|
||||
".m4a": "audio/mp4",
|
||||
".mp3": "audio/mpeg",
|
||||
".ogg": "audio/ogg",
|
||||
".opus": "audio/ogg",
|
||||
".wav": "audio/wav",
|
||||
}
|
||||
TRANSCODED_STT_SUFFIX = ".wav"
|
||||
TRANSCODED_STT_SAMPLE_RATE = "16000"
|
||||
|
||||
def _build_client(self, api_key: str, base_url: Optional[str]):
|
||||
from openai import OpenAI
|
||||
|
||||
return OpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url or self.DEFAULT_BASE_URL,
|
||||
max_retries=3,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _input_credentials() -> tuple[Optional[str], Optional[str]]:
|
||||
return settings.AUDIO_INPUT_API_KEY, settings.AUDIO_INPUT_BASE_URL
|
||||
|
||||
@staticmethod
|
||||
def _output_credentials() -> tuple[Optional[str], Optional[str]]:
|
||||
return settings.AUDIO_OUTPUT_API_KEY, settings.AUDIO_OUTPUT_BASE_URL
|
||||
|
||||
def _normalize_stt_model(self) -> str:
|
||||
return self._normalize_model(
|
||||
model=settings.AUDIO_INPUT_MODEL,
|
||||
supported_models=self.SUPPORTED_STT_MODELS,
|
||||
default_model=self.DEFAULT_STT_MODEL,
|
||||
)
|
||||
|
||||
def _normalize_tts_model(self) -> str:
|
||||
return self._normalize_model(
|
||||
model=settings.AUDIO_OUTPUT_MODEL,
|
||||
supported_models=self.SUPPORTED_TTS_MODELS,
|
||||
default_model=self.DEFAULT_TTS_MODEL,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_model(
|
||||
model: Optional[str],
|
||||
supported_models: Optional[frozenset[str]],
|
||||
default_model: Optional[str],
|
||||
) -> str:
|
||||
model = (model or "").strip()
|
||||
if not model:
|
||||
return default_model or ""
|
||||
if supported_models is None:
|
||||
return model
|
||||
model_key = model.lower()
|
||||
if model_key in supported_models:
|
||||
return model_key
|
||||
return default_model or model
|
||||
|
||||
def _is_supported_tts_model(self) -> bool:
|
||||
model = self._normalize_tts_model()
|
||||
if not model:
|
||||
return False
|
||||
model_key = model.lower()
|
||||
if model_key in self.UNSUPPORTED_TTS_MODELS:
|
||||
return False
|
||||
return self.SUPPORTED_TTS_MODELS is None or model_key in self.SUPPORTED_TTS_MODELS
|
||||
|
||||
@classmethod
|
||||
def _guess_audio_mime_type(cls, filename: str) -> str:
|
||||
suffix = Path(filename or "").suffix.lower()
|
||||
if suffix in cls.SUPPORTED_AUDIO_MIME_TYPES:
|
||||
return cls.SUPPORTED_AUDIO_MIME_TYPES[suffix]
|
||||
mime_type, _ = mimetypes.guess_type(filename or "")
|
||||
return mime_type or "audio/ogg"
|
||||
|
||||
@staticmethod
|
||||
def _guess_audio_format(filename: str) -> str:
|
||||
suffix = Path(filename or "").suffix.lower().lstrip(".")
|
||||
if suffix == "opus":
|
||||
return "ogg"
|
||||
return suffix or "ogg"
|
||||
|
||||
def _build_audio_input_payload(self, content: bytes, filename: str) -> dict:
|
||||
"""按不同 Chat Audio 兼容形态构造 input_audio 内容。"""
|
||||
audio_data = base64.b64encode(content).decode("utf-8")
|
||||
if self.AUDIO_INPUT_DATA_URL:
|
||||
mime_type = self._guess_audio_mime_type(filename)
|
||||
return {"data": f"data:{mime_type};base64,{audio_data}"}
|
||||
return {
|
||||
"data": audio_data,
|
||||
"format": self._guess_audio_format(filename),
|
||||
}
|
||||
|
||||
def _normalize_audio_for_transcription(
|
||||
self, content: bytes, filename: str
|
||||
) -> Optional[tuple[bytes, str]]:
|
||||
"""
|
||||
将转写输入归一化为 Chat Audio provider 明确支持的格式。
|
||||
|
||||
:param content: 原始音频字节
|
||||
:param filename: 原始音频文件名
|
||||
:return: 成功时返回可提交的音频字节和文件名,失败时返回 None
|
||||
"""
|
||||
suffix = Path(filename or "").suffix.lower()
|
||||
if suffix in self.SUPPORTED_AUDIO_MIME_TYPES:
|
||||
return content, filename
|
||||
return self._convert_audio_for_transcription(content=content, filename=filename)
|
||||
|
||||
def _convert_audio_for_transcription(
|
||||
self, content: bytes, filename: str
|
||||
) -> Optional[tuple[bytes, str]]:
|
||||
"""
|
||||
将 AMR 等第三方 STT 不支持的输入转为 WAV。
|
||||
|
||||
:param content: 原始音频字节
|
||||
:param filename: 原始音频文件名
|
||||
:return: 成功时返回 WAV 字节和文件名,失败时返回 None
|
||||
"""
|
||||
if not shutil.which("ffmpeg"):
|
||||
logger.warning(
|
||||
"%s STT 不支持当前音频格式且 ffmpeg 不可用,无法转码: filename=%s",
|
||||
self.DISPLAY_NAME,
|
||||
filename,
|
||||
)
|
||||
return None
|
||||
|
||||
suffix = Path(filename or "").suffix.lower() or ".audio"
|
||||
voice_dir = settings.TEMP_PATH / "voice"
|
||||
voice_dir.mkdir(parents=True, exist_ok=True)
|
||||
input_path = voice_dir / f"{uuid4().hex}{suffix}"
|
||||
output_path = input_path.with_suffix(self.TRANSCODED_STT_SUFFIX)
|
||||
try:
|
||||
input_path.write_bytes(content)
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-i",
|
||||
str(input_path),
|
||||
"-ar",
|
||||
self.TRANSCODED_STT_SAMPLE_RATE,
|
||||
"-ac",
|
||||
"1",
|
||||
"-f",
|
||||
"wav",
|
||||
str(output_path),
|
||||
]
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, check=False)
|
||||
if result.returncode != 0 or not output_path.exists():
|
||||
logger.warning(
|
||||
"%s STT 音频转 WAV 失败: returncode=%s, stderr=%s",
|
||||
self.DISPLAY_NAME,
|
||||
result.returncode,
|
||||
(result.stderr or "").strip()[:500],
|
||||
)
|
||||
return None
|
||||
return output_path.read_bytes(), f"{input_path.stem}{self.TRANSCODED_STT_SUFFIX}"
|
||||
finally:
|
||||
for temp_path in (input_path, output_path):
|
||||
try:
|
||||
temp_path.unlink(missing_ok=True)
|
||||
except OSError as err:
|
||||
logger.debug(f"清理 STT 临时音频失败: path={temp_path}, error={err}")
|
||||
|
||||
@staticmethod
|
||||
def _extract_message_text(message) -> Optional[str]:
|
||||
"""兼容音频理解响应可能放在 content 或 reasoning_content 的情况。"""
|
||||
content = getattr(message, "content", None)
|
||||
if isinstance(content, str) and content.strip():
|
||||
return content.strip()
|
||||
|
||||
reasoning_content = getattr(message, "reasoning_content", None)
|
||||
if isinstance(reasoning_content, str) and reasoning_content.strip():
|
||||
return reasoning_content.strip()
|
||||
|
||||
extra = getattr(message, "model_extra", None)
|
||||
if isinstance(extra, dict):
|
||||
for key in ("content", "reasoning_content"):
|
||||
value = extra.get(key)
|
||||
if isinstance(value, str) and value.strip():
|
||||
return value.strip()
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_audio_data(message) -> Optional[str]:
|
||||
audio = getattr(message, "audio", None)
|
||||
if isinstance(audio, dict):
|
||||
return audio.get("data")
|
||||
if audio is not None:
|
||||
return getattr(audio, "data", None)
|
||||
|
||||
extra = getattr(message, "model_extra", None)
|
||||
if isinstance(extra, dict) and isinstance(extra.get("audio"), dict):
|
||||
return extra["audio"].get("data")
|
||||
return None
|
||||
|
||||
def _convert_wav_to_opus(self, wav_path: Path) -> Optional[Path]:
|
||||
"""将 Chat Audio 返回的 WAV 转成 OGG/Opus,便于各通知渠道发送语音。"""
|
||||
if not shutil.which("ffmpeg"):
|
||||
return None
|
||||
|
||||
output_path = wav_path.with_suffix(".opus")
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-i",
|
||||
str(wav_path),
|
||||
"-ar",
|
||||
"48000",
|
||||
"-ac",
|
||||
"1",
|
||||
"-c:a",
|
||||
"libopus",
|
||||
str(output_path),
|
||||
]
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, check=False)
|
||||
if result.returncode != 0 or not output_path.exists():
|
||||
logger.warning(
|
||||
"%s TTS 音频转 Opus 失败,将使用 WAV 原文件: returncode=%s, stderr=%s",
|
||||
self.DISPLAY_NAME,
|
||||
result.returncode,
|
||||
(result.stderr or "").strip()[:500],
|
||||
)
|
||||
return None
|
||||
return output_path
|
||||
|
||||
def is_available_for_audio_input(self) -> bool:
|
||||
api_key, _ = self._input_credentials()
|
||||
return bool(api_key)
|
||||
|
||||
def is_available_for_audio_output(self) -> bool:
|
||||
api_key, _ = self._output_credentials()
|
||||
return bool(api_key) and self._is_supported_tts_model()
|
||||
|
||||
def transcribe_audio(self, content: bytes, filename: str = "input.ogg") -> Optional[str]:
|
||||
if not content:
|
||||
return None
|
||||
if len(content) > self.MAX_TRANSCRIBE_BYTES:
|
||||
raise ValueError("语音文件超过 10MB,无法识别")
|
||||
|
||||
try:
|
||||
api_key, base_url = self._input_credentials()
|
||||
if not api_key:
|
||||
raise ValueError("音频输入 provider 未配置 API Key")
|
||||
client = self._build_client(api_key=api_key, base_url=base_url)
|
||||
normalized_audio = self._normalize_audio_for_transcription(
|
||||
content=content, filename=filename
|
||||
)
|
||||
if not normalized_audio:
|
||||
return None
|
||||
content, filename = normalized_audio
|
||||
language = (settings.AUDIO_INPUT_LANGUAGE or "").strip()
|
||||
prompt = "请将这段音频完整转写为文字,只输出转写结果,不要添加解释。"
|
||||
if language:
|
||||
prompt += f"音频主要语言是 {language}。"
|
||||
|
||||
completion = client.chat.completions.create(
|
||||
model=self._normalize_stt_model(),
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "input_audio",
|
||||
"input_audio": self._build_audio_input_payload(
|
||||
content=content, filename=filename
|
||||
),
|
||||
},
|
||||
{"type": "text", "text": prompt},
|
||||
],
|
||||
}
|
||||
],
|
||||
max_completion_tokens=2048,
|
||||
)
|
||||
return self._extract_message_text(completion.choices[0].message)
|
||||
except Exception as err:
|
||||
logger.error(f"音频输入转写失败: provider={self.name}, error={err}")
|
||||
return None
|
||||
|
||||
def synthesize_speech(self, text: str) -> Optional[Path]:
|
||||
if not text:
|
||||
return None
|
||||
if not self._is_supported_tts_model():
|
||||
logger.error(
|
||||
"%s TTS 当前不支持该模型或模型未配置: %s",
|
||||
self.DISPLAY_NAME,
|
||||
settings.AUDIO_OUTPUT_MODEL,
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
api_key, base_url = self._output_credentials()
|
||||
if not api_key:
|
||||
raise ValueError("音频输出 provider 未配置 API Key")
|
||||
client = self._build_client(api_key=api_key, base_url=base_url)
|
||||
voice_dir = settings.TEMP_PATH / "voice"
|
||||
voice_dir.mkdir(parents=True, exist_ok=True)
|
||||
wav_path = voice_dir / f"{uuid4().hex}.wav"
|
||||
request = {
|
||||
"model": self._normalize_tts_model(),
|
||||
"messages": [
|
||||
{
|
||||
"role": self.TTS_MESSAGE_ROLE,
|
||||
"content": text,
|
||||
}
|
||||
],
|
||||
"audio": {
|
||||
"format": self.AUDIO_RESPONSE_FORMAT,
|
||||
"voice": settings.AUDIO_OUTPUT_VOICE or self.DEFAULT_VOICE,
|
||||
},
|
||||
}
|
||||
if self.INCLUDE_AUDIO_MODALITIES:
|
||||
request["modalities"] = ["text", "audio"]
|
||||
completion = client.chat.completions.create(**request)
|
||||
audio_data = self._extract_audio_data(completion.choices[0].message)
|
||||
if not audio_data:
|
||||
raise ValueError(f"{self.DISPLAY_NAME} TTS 响应中没有音频数据")
|
||||
|
||||
wav_path.write_bytes(base64.b64decode(audio_data))
|
||||
return self._convert_wav_to_opus(wav_path) or wav_path
|
||||
except Exception as err:
|
||||
logger.error(f"音频输出合成失败: provider={self.name}, error={err}")
|
||||
return None
|
||||
|
||||
|
||||
class MiMoAudioProvider(OpenAIChatAudioProvider):
|
||||
"""Xiaomi MiMo Chat Audio 预设,仅接入普通 STT/TTS 能力。"""
|
||||
|
||||
name = "mimo"
|
||||
DISPLAY_NAME = "Xiaomi MiMo"
|
||||
DEFAULT_BASE_URL = "https://api.xiaomimimo.com/v1"
|
||||
DEFAULT_STT_MODEL = "mimo-v2.5"
|
||||
DEFAULT_TTS_MODEL = "mimo-v2.5-tts"
|
||||
DEFAULT_VOICE = "mimo_default"
|
||||
AUDIO_INPUT_DATA_URL = True
|
||||
INCLUDE_AUDIO_MODALITIES = False
|
||||
TTS_MESSAGE_ROLE = "assistant"
|
||||
SUPPORTED_STT_MODELS = frozenset({"mimo-v2.5", "mimo-v2-omni"})
|
||||
SUPPORTED_TTS_MODELS = frozenset({DEFAULT_TTS_MODEL})
|
||||
UNSUPPORTED_TTS_MODELS = frozenset(
|
||||
{
|
||||
"mimo-v2.5-tts-voiceclone",
|
||||
"mimo-v2.5-tts-voicedesign",
|
||||
}
|
||||
)
|
||||
|
||||
def _normalize_tts_model(self) -> str:
|
||||
model = (settings.AUDIO_OUTPUT_MODEL or "").strip().lower()
|
||||
if not model or not model.startswith("mimo-"):
|
||||
return self.DEFAULT_TTS_MODEL
|
||||
return model
|
||||
|
||||
|
||||
class MiniMaxAudioProvider(OpenAIChatAudioProvider):
|
||||
"""MiniMax 音频 provider,语音合成使用官方 T2A HTTP 接口。"""
|
||||
|
||||
name = "minimax"
|
||||
DISPLAY_NAME = "MiniMax"
|
||||
DEFAULT_BASE_URL = "https://api.minimaxi.com/v1"
|
||||
DEFAULT_STT_MODEL = "MiniMax-M2.7"
|
||||
DEFAULT_TTS_MODEL = "speech-2.8-turbo"
|
||||
DEFAULT_VOICE = "Chinese (Mandarin)_Lyrical_Voice"
|
||||
AUDIO_INPUT_DATA_URL = True
|
||||
SUPPORTED_TTS_MODELS = frozenset(
|
||||
{
|
||||
"speech-2.8-hd",
|
||||
"speech-2.8-turbo",
|
||||
"speech-2.6-hd",
|
||||
"speech-2.6-turbo",
|
||||
"speech-02-hd",
|
||||
"speech-02-turbo",
|
||||
"speech-01-hd",
|
||||
"speech-01-turbo",
|
||||
}
|
||||
)
|
||||
|
||||
def _build_client(self, api_key: str, base_url: Optional[str]):
|
||||
"""构建 MiniMax OpenAI 兼容客户端,兼容用户误填 Anthropic 端点的情况。"""
|
||||
from openai import OpenAI
|
||||
|
||||
return OpenAI(
|
||||
api_key=api_key,
|
||||
base_url=self._normalize_api_base_url(base_url),
|
||||
max_retries=3,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _normalize_api_base_url(cls, base_url: Optional[str]) -> str:
|
||||
"""归一化 MiniMax API 基础 URL,确保后续可以拼接 OpenAI/T2A 路径。"""
|
||||
normalized = (base_url or cls.DEFAULT_BASE_URL).strip().rstrip("/")
|
||||
if normalized.endswith("/t2a_v2"):
|
||||
normalized = normalized[: -len("/t2a_v2")]
|
||||
for suffix in ("/anthropic/v1", "/openai/v1"):
|
||||
if normalized.endswith(suffix):
|
||||
return normalized[: -len(suffix)] + "/v1"
|
||||
if not normalized.endswith("/v1"):
|
||||
normalized = f"{normalized}/v1"
|
||||
return normalized
|
||||
|
||||
@classmethod
|
||||
def _build_t2a_url(cls, base_url: Optional[str]) -> str:
|
||||
"""生成 MiniMax 同步 T2A 接口地址。"""
|
||||
return f"{cls._normalize_api_base_url(base_url)}/t2a_v2"
|
||||
|
||||
def _normalize_stt_model(self) -> str:
|
||||
"""将非 MiniMax 的默认转写模型名兜底为 MiniMax 对话模型。"""
|
||||
model = (settings.AUDIO_INPUT_MODEL or "").strip()
|
||||
if not model or model.lower().startswith(("gpt-", "mimo-")):
|
||||
return self.DEFAULT_STT_MODEL
|
||||
return model
|
||||
|
||||
def _normalize_tts_model(self) -> str:
|
||||
"""将非 MiniMax 语音模型兜底为官方 T2A 模型。"""
|
||||
model = (settings.AUDIO_OUTPUT_MODEL or "").strip().lower()
|
||||
if model in self.SUPPORTED_TTS_MODELS:
|
||||
return model
|
||||
return self.DEFAULT_TTS_MODEL
|
||||
|
||||
def _normalize_voice_id(self) -> str:
|
||||
"""将其他 provider 的默认音色兜底为 MiniMax 中文系统音色。"""
|
||||
voice_id = (settings.AUDIO_OUTPUT_VOICE or "").strip()
|
||||
if not voice_id or voice_id in {"alloy", "mimo_default"}:
|
||||
return self.DEFAULT_VOICE
|
||||
return voice_id
|
||||
|
||||
@staticmethod
|
||||
def _decode_audio_payload(audio_data: str) -> bytes:
|
||||
"""解析 MiniMax T2A 返回的音频数据,优先按官方 hex 格式处理。"""
|
||||
normalized = "".join((audio_data or "").split())
|
||||
try:
|
||||
return bytes.fromhex(normalized)
|
||||
except ValueError:
|
||||
return base64.b64decode(audio_data)
|
||||
|
||||
@staticmethod
|
||||
def _extract_minimax_error(data: dict[str, Any]) -> Optional[str]:
|
||||
"""提取 MiniMax base_resp 错误信息,成功响应返回 None。"""
|
||||
base_resp = data.get("base_resp") or {}
|
||||
status_code = base_resp.get("status_code")
|
||||
if status_code in (None, 0, "0"):
|
||||
return None
|
||||
status_msg = base_resp.get("status_msg") or "unknown error"
|
||||
return f"{status_code}: {status_msg}"
|
||||
|
||||
def synthesize_speech(self, text: str) -> Optional[Path]:
|
||||
"""调用 MiniMax T2A HTTP 接口合成语音文件。"""
|
||||
if not text:
|
||||
return None
|
||||
|
||||
try:
|
||||
api_key, base_url = self._output_credentials()
|
||||
if not api_key:
|
||||
raise ValueError("音频输出 provider 未配置 API Key")
|
||||
response = RequestUtils(
|
||||
headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
},
|
||||
proxies=settings.PROXY or {},
|
||||
timeout=60,
|
||||
).post_res(
|
||||
url=self._build_t2a_url(base_url),
|
||||
json={
|
||||
"model": self._normalize_tts_model(),
|
||||
"text": text,
|
||||
"stream": False,
|
||||
"language_boost": "auto",
|
||||
"output_format": "hex",
|
||||
"voice_setting": {
|
||||
"voice_id": self._normalize_voice_id(),
|
||||
"speed": 1,
|
||||
"vol": 1,
|
||||
"pitch": 0,
|
||||
},
|
||||
"audio_setting": {
|
||||
"sample_rate": 32000,
|
||||
"bitrate": 128000,
|
||||
"format": "opus",
|
||||
"channel": 1,
|
||||
},
|
||||
},
|
||||
)
|
||||
if not response:
|
||||
raise ValueError("MiniMax T2A 请求无响应")
|
||||
if response.status_code >= 400:
|
||||
raise ValueError(f"MiniMax T2A HTTP {response.status_code}")
|
||||
|
||||
result = response.json()
|
||||
minimax_error = self._extract_minimax_error(result)
|
||||
if minimax_error:
|
||||
raise ValueError(f"MiniMax T2A 返回错误: {minimax_error}")
|
||||
|
||||
audio_data = ((result.get("data") or {}).get("audio") or "").strip()
|
||||
if not audio_data:
|
||||
raise ValueError("MiniMax T2A 响应中没有音频数据")
|
||||
|
||||
voice_dir = settings.TEMP_PATH / "voice"
|
||||
voice_dir.mkdir(parents=True, exist_ok=True)
|
||||
output_path = voice_dir / f"{uuid4().hex}.opus"
|
||||
output_path.write_bytes(self._decode_audio_payload(audio_data))
|
||||
return output_path
|
||||
except Exception as err:
|
||||
logger.error(f"音频输出合成失败: provider={self.name}, error={err}")
|
||||
return None
|
||||
|
||||
|
||||
class AgentCapabilityManager:
|
||||
"""Agent 能力统一入口。"""
|
||||
|
||||
REPLY_MODE_NATIVE = "native_voice"
|
||||
REPLY_MODE_TEXT = "text"
|
||||
_audio_providers: Dict[str, AudioCapabilityProvider] = {
|
||||
OpenAIAudioProvider.name: OpenAIAudioProvider(),
|
||||
OpenAIChatAudioProvider.name: OpenAIChatAudioProvider(),
|
||||
MiMoAudioProvider.name: MiMoAudioProvider(),
|
||||
MiniMaxAudioProvider.name: MiniMaxAudioProvider(),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def register_audio_provider(cls, provider: AudioCapabilityProvider) -> None:
|
||||
"""注册新的音频 provider。"""
|
||||
cls._audio_providers[provider.name.lower()] = provider
|
||||
|
||||
@classmethod
|
||||
def get_registered_audio_providers(cls) -> list[str]:
|
||||
"""返回已注册的音频 provider 名称。"""
|
||||
return sorted(cls._audio_providers.keys())
|
||||
|
||||
@staticmethod
|
||||
def _normalize_provider_name(provider: Optional[str]) -> str:
|
||||
return (provider or "openai").strip().lower()
|
||||
|
||||
@staticmethod
|
||||
def _get_provider_log_name(provider: AudioCapabilityProvider) -> str:
|
||||
provider_name = getattr(provider, "name", None)
|
||||
return provider_name if isinstance(provider_name, str) else provider.__class__.__name__
|
||||
|
||||
@classmethod
|
||||
def get_audio_provider(cls, mode: str) -> Optional[AudioCapabilityProvider]:
|
||||
provider_name = cls._normalize_provider_name(
|
||||
settings.AUDIO_INPUT_PROVIDER
|
||||
if (mode or "").lower() == "input"
|
||||
else settings.AUDIO_OUTPUT_PROVIDER
|
||||
)
|
||||
provider = cls._audio_providers.get(provider_name)
|
||||
if provider:
|
||||
return provider
|
||||
logger.warning("未注册音频 provider: mode=%s, provider=%s", mode, provider_name)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def supports_image_input() -> bool:
|
||||
"""当前 Agent 是否启用图片输入能力。"""
|
||||
return bool(settings.LLM_SUPPORT_IMAGE_INPUT)
|
||||
|
||||
@staticmethod
|
||||
def supports_audio_input() -> bool:
|
||||
"""当前 Agent 是否启用音频输入能力。"""
|
||||
return bool(settings.LLM_SUPPORT_AUDIO_INPUT)
|
||||
|
||||
@staticmethod
|
||||
def supports_audio_output() -> bool:
|
||||
"""当前 Agent 是否启用音频输出能力。"""
|
||||
return bool(settings.LLM_SUPPORT_AUDIO_OUTPUT)
|
||||
|
||||
@classmethod
|
||||
def is_audio_input_available(cls) -> bool:
|
||||
if not cls.supports_audio_input():
|
||||
return False
|
||||
provider = cls.get_audio_provider("input")
|
||||
return bool(provider and provider.is_available_for_audio_input())
|
||||
|
||||
@classmethod
|
||||
def is_audio_output_available(cls) -> bool:
|
||||
if not cls.supports_audio_output():
|
||||
return False
|
||||
provider = cls.get_audio_provider("output")
|
||||
return bool(provider and provider.is_available_for_audio_output())
|
||||
|
||||
@classmethod
|
||||
def transcribe_audio(cls, content: bytes, filename: str = "input.ogg") -> Optional[str]:
|
||||
"""将语音文件内容转写为文字,并记录能力调用日志。"""
|
||||
provider = cls.get_audio_provider("input")
|
||||
if not provider or not cls.is_audio_input_available():
|
||||
logger.info("语音转文字跳过:音频输入能力未启用或 provider 不可用")
|
||||
return None
|
||||
provider_name = cls._get_provider_log_name(provider)
|
||||
logger.info(
|
||||
f"语音转文字开始:provider={provider_name}, filename={filename}, "
|
||||
f"bytes={len(content) if content else 0}"
|
||||
)
|
||||
transcript = provider.transcribe_audio(content=content, filename=filename)
|
||||
if transcript:
|
||||
logger.info(
|
||||
f"语音转文字完成:provider={provider_name}, filename={filename}, "
|
||||
f"text_len={len(transcript)}"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"语音转文字无结果:provider={provider_name}, filename={filename}"
|
||||
)
|
||||
return transcript
|
||||
|
||||
@classmethod
|
||||
def synthesize_speech(cls, text: str) -> Optional[Path]:
|
||||
"""将文字合成为语音文件,并记录能力调用日志。"""
|
||||
provider = cls.get_audio_provider("output")
|
||||
if not provider or not cls.is_audio_output_available():
|
||||
logger.info("文字转语音跳过:音频输出能力未启用或 provider 不可用")
|
||||
return None
|
||||
provider_name = cls._get_provider_log_name(provider)
|
||||
logger.info(
|
||||
f"文字转语音开始:provider={provider_name}, text_len={len(text) if text else 0}"
|
||||
)
|
||||
output_path = provider.synthesize_speech(text=text)
|
||||
if output_path:
|
||||
logger.info(f"文字转语音完成:provider={provider_name}, path={output_path}")
|
||||
else:
|
||||
logger.info(f"文字转语音无结果:provider={provider_name}")
|
||||
return output_path
|
||||
|
||||
@classmethod
|
||||
def resolve_reply_mode(cls, channel: Optional[str], source: Optional[str]) -> str:
|
||||
"""仅在支持原生语音回复的渠道上发送音频,其余渠道回退文字。"""
|
||||
if cls.supports_native_voice_reply(channel=channel, source=source):
|
||||
return cls.REPLY_MODE_NATIVE
|
||||
return cls.REPLY_MODE_TEXT
|
||||
|
||||
@classmethod
|
||||
def _parse_message_channel(cls, channel: Optional[Any]):
|
||||
"""将渠道入参归一化为消息渠道枚举。"""
|
||||
if not channel:
|
||||
return None
|
||||
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
if isinstance(channel, MessageChannel):
|
||||
return channel
|
||||
|
||||
channel_text = str(channel).strip()
|
||||
if not channel_text:
|
||||
return None
|
||||
lowered_channel = channel_text.lower()
|
||||
for channel_item in MessageChannel:
|
||||
aliases = {
|
||||
channel_item.value.lower(),
|
||||
channel_item.name.lower(),
|
||||
f"{MessageChannel.__name__}.{channel_item.name}".lower(),
|
||||
}
|
||||
if lowered_channel in aliases:
|
||||
return channel_item
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _is_wechat_app_mode(source: Optional[str]) -> bool:
|
||||
"""判断企业微信来源是否为自建应用模式。"""
|
||||
if not source:
|
||||
return False
|
||||
|
||||
from app.helper.service import ServiceConfigHelper
|
||||
|
||||
for config in ServiceConfigHelper.get_notification_configs():
|
||||
if config.name != source:
|
||||
continue
|
||||
return (config.config or {}).get("WECHAT_MODE", "app") != "bot"
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def supports_native_voice_reply(
|
||||
cls, channel: Optional[str], source: Optional[str]
|
||||
) -> bool:
|
||||
"""判断当前渠道是否支持原生语音消息发送。"""
|
||||
from app.schemas.message import ChannelCapability, ChannelCapabilityManager
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
channel_enum = cls._parse_message_channel(channel)
|
||||
if not channel_enum:
|
||||
return False
|
||||
|
||||
if not ChannelCapabilityManager.supports_capability(
|
||||
channel_enum, ChannelCapability.AUDIO_OUTPUT
|
||||
):
|
||||
return False
|
||||
|
||||
if channel_enum == MessageChannel.Wechat:
|
||||
return cls._is_wechat_app_mode(source)
|
||||
return True
|
||||
@@ -7,7 +7,7 @@ import time
|
||||
from functools import wraps
|
||||
from typing import Any, List
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk
|
||||
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
@@ -32,29 +32,87 @@ class LLMTestTimeout(TimeoutError):
|
||||
def _patch_gemini_thought_signature():
|
||||
"""
|
||||
修复 langchain-google-genai 中 Gemini 2.5 思考模型的 thought_signature 兼容问题。
|
||||
langchain-google-genai 的 _is_gemini_3_or_later() 仅检查 "gemini-3",
|
||||
导致 Gemini 2.5 思考模型(如 gemini-2.5-flash、gemini-2.5-pro)在工具调用时
|
||||
缺少 thought_signature 而报错 400。
|
||||
此补丁将检查范围扩展到 Gemini 2.5 模型。
|
||||
|
||||
问题 1:_is_gemini_3_or_later() 仅检查 "gemini-3",不包含 Gemini 2.5 模型,
|
||||
导致 _parse_chat_history 的 thought_signature 强制注入逻辑被跳过。
|
||||
|
||||
问题 2:强制注入逻辑使用 first_fc_seen 标志,只给每个 model 消息中
|
||||
第一个缺少 thought_signature 的 function_call 补 dummy,后续并行
|
||||
function_call 仍缺失签名,导致 Gemini API 返回 400。
|
||||
|
||||
此补丁同时修复以上两个问题。
|
||||
"""
|
||||
try:
|
||||
import langchain_google_genai.chat_models as _cm
|
||||
|
||||
# 检查版本:需要 >= 4.0 才支持 _is_gemini_3_or_later
|
||||
try:
|
||||
from importlib.metadata import version
|
||||
_version = version("langchain-google-genai") or ""
|
||||
except Exception:
|
||||
_version = ""
|
||||
try:
|
||||
_major = int(_version.split(".")[0]) if _version else 0
|
||||
except (ValueError, TypeError):
|
||||
_major = 0
|
||||
if _major < 4:
|
||||
logger.error(
|
||||
f"langchain-google-genai 版本 {_version or '未知'} 过旧,"
|
||||
f"不支持 Gemini 2.5+ 模型的 thought_signature 处理,"
|
||||
f"请升级到 4.2.3+:pip install langchain-google-genai~=4.2.3"
|
||||
)
|
||||
return
|
||||
|
||||
# 仅在未修补时执行
|
||||
if getattr(_cm, "_thought_signature_patched", False):
|
||||
return
|
||||
|
||||
if not hasattr(_cm, "_is_gemini_3_or_later"):
|
||||
logger.error(
|
||||
"langchain-google-genai 缺少 _is_gemini_3_or_later,"
|
||||
"无法修补 thought_signature 兼容性,请检查包版本"
|
||||
)
|
||||
return
|
||||
|
||||
# 补丁 1:扩展 _is_gemini_3_or_later,使 Gemini 2.5 模型也能触发
|
||||
# _parse_chat_history 中的 thought_signature 强制注入逻辑
|
||||
def _patched_is_gemini_3_or_later(model_name: str) -> bool:
|
||||
if not model_name:
|
||||
return False
|
||||
name = model_name.lower().replace("models/", "")
|
||||
# Gemini 2.5 思考模型也需要 thought_signature 支持
|
||||
return "gemini-3" in name or "gemini-2.5" in name
|
||||
|
||||
_cm._is_gemini_3_or_later = _patched_is_gemini_3_or_later
|
||||
|
||||
# 补丁 2:修复 _parse_chat_history 中 first_fc_seen 只修复第一个
|
||||
# function_call 的问题。用 wrapper 在原函数返回后,确保所有 model
|
||||
# 消息中所有 function_call 都带有 thought_signature。
|
||||
_original_parse_chat_history = _cm._parse_chat_history # noqa
|
||||
|
||||
def _patched_parse_chat_history(*args, **kwargs):
|
||||
result = _original_parse_chat_history(*args, **kwargs)
|
||||
system_instruction, formatted_messages = result
|
||||
|
||||
# 从参数中提取 model 名称
|
||||
model = kwargs.get("model")
|
||||
if model is None and len(args) >= 4:
|
||||
model = args[3]
|
||||
|
||||
if model and _patched_is_gemini_3_or_later(model):
|
||||
dummy = _cm.DUMMY_THOUGHT_SIGNATURE
|
||||
for content_msg in formatted_messages:
|
||||
if content_msg.role == "model":
|
||||
for part in content_msg.parts or []:
|
||||
if part.function_call and not part.thought_signature:
|
||||
part.thought_signature = dummy
|
||||
|
||||
return result
|
||||
|
||||
_cm._parse_chat_history = _patched_parse_chat_history
|
||||
_cm._thought_signature_patched = True
|
||||
logger.debug(
|
||||
"已修补 langchain-google-genai thought_signature 兼容性(覆盖 Gemini 2.5 模型)"
|
||||
"已修补 langchain-google-genai thought_signature 兼容性"
|
||||
"(覆盖 Gemini 2.5 模型 + 修复并行 function_call 签名缺失)"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"修补 langchain-google-genai thought_signature 失败: {e}")
|
||||
@@ -79,6 +137,57 @@ def _get_httpx_proxy_key() -> str:
|
||||
return "proxies"
|
||||
|
||||
|
||||
def _resolve_llm_proxy(use_proxy: bool | None = None) -> str | None:
|
||||
"""
|
||||
解析本次 LLM 调用应使用的系统代理地址。
|
||||
"""
|
||||
should_use_proxy = settings.LLM_USE_PROXY if use_proxy is None else use_proxy
|
||||
return settings.PROXY_HOST if should_use_proxy and settings.PROXY_HOST else None
|
||||
|
||||
|
||||
def _build_httpx_proxy_kwargs(proxy_url: str | None) -> dict[str, str]:
|
||||
"""
|
||||
构造兼容当前 httpx 版本的代理参数。
|
||||
"""
|
||||
if not proxy_url:
|
||||
return {}
|
||||
return {_get_httpx_proxy_key(): proxy_url}
|
||||
|
||||
|
||||
def _build_google_client_args(proxy_url: str | None) -> dict[str, Any]:
|
||||
"""
|
||||
构造 Google SDK 透传给 httpx 的客户端参数。
|
||||
"""
|
||||
return {
|
||||
"trust_env": False,
|
||||
**_build_httpx_proxy_kwargs(proxy_url),
|
||||
}
|
||||
|
||||
|
||||
def _build_httpx_client(
|
||||
proxy_url: str | None,
|
||||
*,
|
||||
async_client: bool = False,
|
||||
timeout: float | None = None,
|
||||
):
|
||||
"""
|
||||
构造显式代理策略的 httpx 客户端。
|
||||
|
||||
当关闭 LLM 代理时也返回 trust_env=False 的客户端,避免 httpx 自动读取
|
||||
进程环境变量中的代理配置。
|
||||
"""
|
||||
import httpx
|
||||
|
||||
client_cls = httpx.AsyncClient if async_client else httpx.Client
|
||||
kwargs: dict[str, Any] = {
|
||||
"trust_env": False,
|
||||
**_build_httpx_proxy_kwargs(proxy_url),
|
||||
}
|
||||
if timeout is not None:
|
||||
kwargs["timeout"] = timeout
|
||||
return client_cls(**kwargs)
|
||||
|
||||
|
||||
def _deepseek_thinking_toggle(extra_body: Any) -> bool | None:
|
||||
"""
|
||||
解析 DeepSeek extra_body 中显式传入的 thinking 开关。
|
||||
@@ -142,9 +251,15 @@ def _patch_deepseek_reasoning_content_support():
|
||||
def _patched_get_request_payload(self, input_, *, stop=None, **kwargs):
|
||||
payload = original_get_request_payload(self, input_, stop=stop, **kwargs)
|
||||
|
||||
# Resolve original messages so we can extract reasoning_content from
|
||||
# additional_kwargs. The parent's payload builder does not propagate
|
||||
# this DeepSeek-specific field.
|
||||
extra_body = (getattr(self, "model_kwargs", None) or {}).get("extra_body")
|
||||
if not _is_deepseek_thinking_enabled(
|
||||
getattr(self, "model_name", None) or getattr(self, "model", None),
|
||||
extra_body,
|
||||
):
|
||||
return payload
|
||||
|
||||
# 从原始 LangChain 消息中取回 reasoning_content。上游 payload 构造器
|
||||
# 不会自动透传这个 DeepSeek 扩展字段。
|
||||
messages = self._convert_input(input_).to_messages()
|
||||
|
||||
for i, message in enumerate(payload["messages"]):
|
||||
@@ -152,9 +267,8 @@ def _patch_deepseek_reasoning_content_support():
|
||||
message["content"] = json.dumps(message["content"])
|
||||
elif message["role"] == "assistant":
|
||||
if isinstance(message["content"], list):
|
||||
# DeepSeek API expects assistant content to be a string,
|
||||
# not a list. Extract text blocks and join them, or use
|
||||
# empty string if none exist.
|
||||
# DeepSeek API 要求 assistant content 为字符串;工具场景下
|
||||
# LangChain 可能保留为内容块列表,这里只拼回可见文本块。
|
||||
text_parts = [
|
||||
block.get("text", "")
|
||||
for block in message["content"]
|
||||
@@ -162,10 +276,8 @@ def _patch_deepseek_reasoning_content_support():
|
||||
]
|
||||
message["content"] = "".join(text_parts) if text_parts else ""
|
||||
|
||||
# DeepSeek reasoning models require every assistant message to
|
||||
# carry a reasoning_content field (even when empty). The value
|
||||
# is stored in AIMessage.additional_kwargs by
|
||||
# _create_chat_result(); re-inject it into the API payload.
|
||||
# DeepSeek thinking mode 要求历史 assistant 消息携带
|
||||
# reasoning_content,即便本地只保存到了 additional_kwargs。
|
||||
if (
|
||||
"reasoning_content" not in message
|
||||
and i < len(messages)
|
||||
@@ -182,6 +294,103 @@ def _patch_deepseek_reasoning_content_support():
|
||||
logger.debug("已修补 langchain-deepseek thinking tool-call 的 reasoning_content 回传兼容性")
|
||||
|
||||
|
||||
def _patch_openai_interleaved_reasoning_content_support():
|
||||
"""
|
||||
修补 OpenAI-compatible 模型的 interleaved reasoning 内容回传。
|
||||
|
||||
小米 MiMo、部分 Kimi/GLM 等兼容端点会把思考内容放在响应顶层
|
||||
`reasoning_content` 字段;如果下一轮请求没有把它随历史 assistant
|
||||
消息带回,工具调用后续请求会被服务端以 400 拒绝。
|
||||
|
||||
这里不按 provider 白名单判断,而是只在历史 AIMessage 真实保存过
|
||||
`reasoning_content` 时回传,避免以后每接入一个同类模型都要单独适配。
|
||||
"""
|
||||
try:
|
||||
import langchain_openai.chat_models.base as _openai_base
|
||||
from langchain_openai import ChatOpenAI
|
||||
except Exception as err:
|
||||
logger.debug(f"跳过 langchain-openai reasoning_content 修补:{err}")
|
||||
return
|
||||
|
||||
if not getattr(_openai_base, "_moviepilot_reasoning_response_patched", False):
|
||||
original_convert_dict = getattr(_openai_base, "_convert_dict_to_message", None)
|
||||
original_convert_delta = getattr(
|
||||
_openai_base, "_convert_delta_to_message_chunk", None
|
||||
)
|
||||
|
||||
if callable(original_convert_dict):
|
||||
@wraps(original_convert_dict)
|
||||
def _patched_convert_dict_to_message(message_dict):
|
||||
message = original_convert_dict(message_dict)
|
||||
if (
|
||||
isinstance(message, AIMessage)
|
||||
and "reasoning_content" in message_dict
|
||||
):
|
||||
message.additional_kwargs["reasoning_content"] = (
|
||||
message_dict.get("reasoning_content") or ""
|
||||
)
|
||||
return message
|
||||
|
||||
_openai_base._convert_dict_to_message = _patched_convert_dict_to_message
|
||||
|
||||
if callable(original_convert_delta):
|
||||
@wraps(original_convert_delta)
|
||||
def _patched_convert_delta_to_message_chunk(delta, default_class):
|
||||
chunk = original_convert_delta(delta, default_class)
|
||||
if (
|
||||
isinstance(chunk, AIMessageChunk)
|
||||
and "reasoning_content" in delta
|
||||
):
|
||||
chunk.additional_kwargs["reasoning_content"] = (
|
||||
delta.get("reasoning_content") or ""
|
||||
)
|
||||
return chunk
|
||||
|
||||
_openai_base._convert_delta_to_message_chunk = (
|
||||
_patched_convert_delta_to_message_chunk
|
||||
)
|
||||
|
||||
_openai_base._moviepilot_reasoning_response_patched = True
|
||||
|
||||
if getattr(ChatOpenAI, "_moviepilot_interleaved_reasoning_patched", False):
|
||||
return
|
||||
|
||||
original_get_request_payload = getattr(ChatOpenAI, "_get_request_payload", None)
|
||||
if not callable(original_get_request_payload):
|
||||
logger.warning("langchain-openai 缺少 _get_request_payload,无法修补 reasoning_content")
|
||||
return
|
||||
|
||||
@wraps(original_get_request_payload)
|
||||
def _patched_get_request_payload(self, input_, *, stop=None, **kwargs):
|
||||
payload = original_get_request_payload(self, input_, stop=stop, **kwargs)
|
||||
if "messages" not in payload:
|
||||
return payload
|
||||
|
||||
messages = self._convert_input(input_).to_messages()
|
||||
for index, payload_message in enumerate(payload["messages"]):
|
||||
if (
|
||||
payload_message.get("role") != "assistant"
|
||||
or index >= len(messages)
|
||||
or not isinstance(messages[index], AIMessage)
|
||||
or "reasoning_content" in payload_message
|
||||
):
|
||||
continue
|
||||
|
||||
reasoning_content = messages[index].additional_kwargs.get(
|
||||
"reasoning_content"
|
||||
)
|
||||
if reasoning_content is not None:
|
||||
# 只回传模型真实返回过的思考字段。普通模型没有该字段时,
|
||||
# payload 保持原样,不额外塞未知参数。
|
||||
payload_message["reasoning_content"] = reasoning_content
|
||||
|
||||
return payload
|
||||
|
||||
ChatOpenAI._get_request_payload = _patched_get_request_payload
|
||||
ChatOpenAI._moviepilot_interleaved_reasoning_patched = True
|
||||
logger.debug("已修补 langchain-openai interleaved reasoning_content 回传兼容性")
|
||||
|
||||
|
||||
def _patch_openai_responses_instructions_support():
|
||||
"""
|
||||
修补 langchain-openai 在使用 use_responses_api=True 时,
|
||||
@@ -195,6 +404,9 @@ def _patch_openai_responses_instructions_support():
|
||||
logger.debug(f"跳过 langchain-openai instructions 修补:{err}")
|
||||
return
|
||||
|
||||
_patch_openai_interleaved_reasoning_content_support()
|
||||
_patch_openai_responses_empty_output_support()
|
||||
|
||||
if getattr(ChatOpenAI, "_moviepilot_responses_instructions_patched", False):
|
||||
return
|
||||
|
||||
@@ -253,6 +465,64 @@ def _patch_openai_responses_instructions_support():
|
||||
logger.debug("已修补 langchain-openai responses API 的 instructions 兼容性")
|
||||
|
||||
|
||||
def _patch_openai_responses_empty_output_support():
|
||||
"""
|
||||
修补 langchain-openai Responses API 流式完成事件 output 为空的兼容性。
|
||||
|
||||
ChatGPT Codex 后端有时会在 `response.completed` chunk 里返回
|
||||
`response.output = None`,但前面的 delta chunk 已经包含实际文本。
|
||||
langchain-openai 在收尾阶段遍历 output 会抛出 TypeError,这里将缺失
|
||||
output 规整为空列表,让收尾 chunk 只承载 usage/metadata。
|
||||
"""
|
||||
try:
|
||||
import langchain_openai.chat_models.base as _openai_base
|
||||
except Exception as err:
|
||||
logger.debug(f"跳过 langchain-openai responses output 修补:{err}")
|
||||
return
|
||||
|
||||
if getattr(_openai_base, "_moviepilot_responses_empty_output_patched", False):
|
||||
return
|
||||
|
||||
original_construct = getattr(
|
||||
_openai_base, "_construct_lc_result_from_responses_api", None
|
||||
)
|
||||
if not callable(original_construct):
|
||||
logger.warning("langchain-openai 缺少 Responses API 结果构造函数,无法修补 output")
|
||||
return
|
||||
|
||||
def _clone_response_with_empty_output(response):
|
||||
"""
|
||||
复制 Responses 对象,把缺失 output 规整为空列表。
|
||||
"""
|
||||
model_copy = getattr(response, "model_copy", None)
|
||||
if callable(model_copy):
|
||||
try:
|
||||
return model_copy(update={"output": []})
|
||||
except Exception as e:
|
||||
logger.debug(f"复制 Responses 对象失败,回退原地修补 output:{e}")
|
||||
|
||||
try:
|
||||
setattr(response, "output", [])
|
||||
except Exception as e:
|
||||
logger.debug(f"原地修补 Responses output 失败:{e}")
|
||||
return response
|
||||
|
||||
@wraps(original_construct)
|
||||
def _patched_construct_lc_result_from_responses_api(response, *args, **kwargs):
|
||||
"""
|
||||
在 Responses API 收尾 chunk 缺少 output 时跳过空内容遍历。
|
||||
"""
|
||||
if hasattr(response, "output") and getattr(response, "output", None) is None:
|
||||
response = _clone_response_with_empty_output(response)
|
||||
return original_construct(response, *args, **kwargs)
|
||||
|
||||
_openai_base._construct_lc_result_from_responses_api = (
|
||||
_patched_construct_lc_result_from_responses_api
|
||||
)
|
||||
_openai_base._moviepilot_responses_empty_output_patched = True
|
||||
logger.debug("已修补 langchain-openai responses API 空 output 兼容性")
|
||||
|
||||
|
||||
class LLMHelper:
|
||||
"""LLM模型相关辅助功能"""
|
||||
|
||||
@@ -442,6 +712,7 @@ class LLMHelper:
|
||||
model_name: str | None,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
user_agent: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
在 provider 目录不可用时回退到旧的直接构造逻辑。
|
||||
@@ -465,12 +736,68 @@ class LLMHelper:
|
||||
"model_id": model_name,
|
||||
"api_key": api_key_value,
|
||||
"base_url": base_url_value,
|
||||
"default_headers": None,
|
||||
"default_headers": LLMHelper._build_openai_default_headers(
|
||||
None,
|
||||
user_agent=user_agent,
|
||||
),
|
||||
"use_responses_api": None,
|
||||
"model_record": None,
|
||||
"model_metadata": None,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _build_openai_default_headers(
|
||||
default_headers: dict[str, str] | None = None,
|
||||
user_agent: str | None = None,
|
||||
) -> dict[str, str] | None:
|
||||
"""
|
||||
合并 OpenAI 兼容接口默认请求头。
|
||||
|
||||
:param default_headers: provider 运行时已解析的默认请求头
|
||||
:param user_agent: 用户配置的 User-Agent,非空时写入标准请求头
|
||||
:return: 可传给 OpenAI SDK 的请求头字典
|
||||
"""
|
||||
headers = dict(default_headers or {})
|
||||
normalized_user_agent = str(user_agent or "").strip()
|
||||
if normalized_user_agent:
|
||||
for key in list(headers.keys()):
|
||||
if key.lower() == "user-agent":
|
||||
headers.pop(key)
|
||||
headers["User-Agent"] = normalized_user_agent
|
||||
return headers or None
|
||||
|
||||
@classmethod
|
||||
def _should_use_openai_responses_api(
|
||||
cls,
|
||||
provider: str,
|
||||
model: str | None,
|
||||
runtime: dict[str, Any],
|
||||
) -> bool | None:
|
||||
"""
|
||||
判断官方 ChatGPT API Key 模式是否应使用 Responses API。
|
||||
|
||||
GPT-5/o 系推理模型在 Chat Completions 中组合 function tools 与
|
||||
reasoning_effort 时会被官方端点拒绝,因此 ChatGPT 官方 API Key
|
||||
模式需要显式切到 Responses API;通用 OpenAI-compatible 入口保持
|
||||
provider 目录解析出的默认行为,避免误伤第三方兼容服务。
|
||||
"""
|
||||
runtime_use_responses_api = runtime.get("use_responses_api")
|
||||
if runtime_use_responses_api is not None:
|
||||
return bool(runtime_use_responses_api)
|
||||
|
||||
provider_name = (provider or "").strip().lower()
|
||||
if provider_name != "chatgpt":
|
||||
return None
|
||||
|
||||
base_url = str(runtime.get("base_url") or "").strip().lower()
|
||||
if "api.openai.com" not in base_url:
|
||||
return None
|
||||
|
||||
model_name = cls._normalize_model_name(model)
|
||||
if model_name.startswith(("gpt-5", "o1", "o3", "o4")):
|
||||
return True
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _resolve_thinking_level(
|
||||
cls,
|
||||
@@ -515,6 +842,8 @@ class LLMHelper:
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
base_url_preset: str | None = None,
|
||||
user_agent: str | None = None,
|
||||
use_proxy: bool | None = None,
|
||||
):
|
||||
"""
|
||||
获取LLM实例
|
||||
@@ -528,6 +857,8 @@ class LLMHelper:
|
||||
:param api_key: API Key。未显式传入时使用当前配置项 LLM_API_KEY。对于某些提供商(如 DeepSeek),可能需要同时提供 base_url。
|
||||
:param base_url: API Base URL。未显式传入时使用当前配置项 LLM_BASE_URL。
|
||||
:param base_url_preset: Base URL 预设。未显式传入时使用当前配置项 LLM_BASE_URL_PRESET。
|
||||
:param user_agent: OpenAI兼容接口请求 User-Agent。未显式传入时使用配置项 LLM_USER_AGENT。
|
||||
:param use_proxy: 是否为本次 LLM 调用使用系统代理。未显式传入时使用配置项 LLM_USE_PROXY。
|
||||
:return: LLM实例
|
||||
"""
|
||||
provider_name = str(provider if provider is not None else settings.LLM_PROVIDER).lower()
|
||||
@@ -537,6 +868,7 @@ class LLMHelper:
|
||||
base_url_preset_value = (
|
||||
base_url_preset if base_url_preset is not None else settings.LLM_BASE_URL_PRESET
|
||||
)
|
||||
user_agent_value = user_agent if user_agent is not None else settings.LLM_USER_AGENT
|
||||
normalized_thinking_level = cls._resolve_thinking_level(
|
||||
thinking_level=thinking_level,
|
||||
)
|
||||
@@ -551,6 +883,8 @@ class LLMHelper:
|
||||
api_key=api_key_value,
|
||||
base_url=base_url_value,
|
||||
base_url_preset_id=base_url_preset_value,
|
||||
user_agent=user_agent_value,
|
||||
use_proxy=use_proxy,
|
||||
)
|
||||
except Exception as err:
|
||||
logger.debug(f"LLM provider 目录不可用,回退到旧运行时逻辑: {err}")
|
||||
@@ -559,13 +893,24 @@ class LLMHelper:
|
||||
model_name=model_name,
|
||||
api_key=api_key_value,
|
||||
base_url=base_url_value,
|
||||
user_agent=user_agent_value,
|
||||
)
|
||||
model_name = runtime.get("model_id") or model_name
|
||||
default_headers = cls._build_openai_default_headers(
|
||||
runtime.get("default_headers"),
|
||||
user_agent=user_agent_value,
|
||||
)
|
||||
thinking_kwargs = cls._build_thinking_kwargs(
|
||||
provider=provider_name,
|
||||
model=model_name,
|
||||
thinking_level=normalized_thinking_level,
|
||||
)
|
||||
use_responses_api = cls._should_use_openai_responses_api(
|
||||
provider=provider_name,
|
||||
model=model_name,
|
||||
runtime=runtime,
|
||||
)
|
||||
llm_proxy = _resolve_llm_proxy(use_proxy)
|
||||
|
||||
if runtime["runtime"] == "google":
|
||||
# 修补 Gemini 2.5 思考模型的 thought_signature 兼容性
|
||||
@@ -576,18 +921,13 @@ class LLMHelper:
|
||||
# 会导致工具调用时报错 400
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
|
||||
client_args = None
|
||||
if settings.PROXY_HOST:
|
||||
proxy_key = _get_httpx_proxy_key()
|
||||
client_args = {proxy_key: settings.PROXY_HOST}
|
||||
|
||||
model = ChatGoogleGenerativeAI(
|
||||
model=model_name,
|
||||
api_key=runtime["api_key"],
|
||||
retries=3,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=streaming,
|
||||
client_args=client_args,
|
||||
client_args=_build_google_client_args(llm_proxy),
|
||||
**thinking_kwargs,
|
||||
)
|
||||
elif runtime["runtime"] == "deepseek":
|
||||
@@ -602,6 +942,8 @@ class LLMHelper:
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=streaming,
|
||||
stream_usage=True,
|
||||
http_client=_build_httpx_client(llm_proxy),
|
||||
http_async_client=_build_httpx_client(llm_proxy, async_client=True),
|
||||
**thinking_kwargs,
|
||||
)
|
||||
elif runtime["runtime"] in {"anthropic_compatible", "copilot_anthropic"}:
|
||||
@@ -615,8 +957,8 @@ class LLMHelper:
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=streaming,
|
||||
stream_usage=True,
|
||||
anthropic_proxy=settings.PROXY_HOST,
|
||||
default_headers=runtime.get("default_headers"),
|
||||
anthropic_proxy=llm_proxy,
|
||||
default_headers=default_headers,
|
||||
**thinking_kwargs,
|
||||
)
|
||||
else:
|
||||
@@ -636,9 +978,17 @@ class LLMHelper:
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=streaming,
|
||||
stream_usage=True,
|
||||
openai_proxy=settings.PROXY_HOST,
|
||||
default_headers=runtime.get("default_headers"),
|
||||
use_responses_api=runtime.get("use_responses_api"),
|
||||
openai_proxy=llm_proxy,
|
||||
**(
|
||||
{}
|
||||
if llm_proxy
|
||||
else {
|
||||
"http_client": _build_httpx_client(llm_proxy),
|
||||
"http_async_client": _build_httpx_client(llm_proxy, async_client=True),
|
||||
}
|
||||
),
|
||||
default_headers=default_headers,
|
||||
use_responses_api=use_responses_api,
|
||||
**thinking_kwargs,
|
||||
)
|
||||
|
||||
@@ -713,6 +1063,8 @@ class LLMHelper:
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
base_url_preset: str | None = None,
|
||||
user_agent: str | None = None,
|
||||
use_proxy: bool | None = None,
|
||||
) -> dict:
|
||||
"""
|
||||
使用当前已保存配置执行一次最小 LLM 调用。
|
||||
@@ -728,6 +1080,8 @@ class LLMHelper:
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
base_url_preset=base_url_preset,
|
||||
user_agent=user_agent,
|
||||
use_proxy=use_proxy,
|
||||
)
|
||||
try:
|
||||
response = await asyncio.wait_for(llm.ainvoke(prompt), timeout=timeout)
|
||||
@@ -758,6 +1112,8 @@ class LLMHelper:
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
base_url_preset: str | None = None,
|
||||
user_agent: str | None = None,
|
||||
use_proxy: bool | None = None,
|
||||
force_refresh: bool = False,
|
||||
) -> List[dict[str, Any]]:
|
||||
"""
|
||||
@@ -775,6 +1131,8 @@ class LLMHelper:
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
base_url_preset_id=base_url_preset,
|
||||
user_agent=user_agent,
|
||||
use_proxy=use_proxy,
|
||||
force_refresh=force_refresh,
|
||||
)
|
||||
except Exception as err:
|
||||
@@ -782,9 +1140,11 @@ class LLMHelper:
|
||||
if provider == "google":
|
||||
return [
|
||||
{"id": model_id, "name": model_id}
|
||||
for model_id in await self._get_google_models(api_key or "")
|
||||
for model_id in await self._get_google_models(
|
||||
api_key or "",
|
||||
use_proxy=use_proxy,
|
||||
)
|
||||
]
|
||||
model_list_base_url = base_url
|
||||
try:
|
||||
from app.agent.llm.provider import LLMProviderManager
|
||||
|
||||
@@ -804,24 +1164,24 @@ class LLMHelper:
|
||||
provider,
|
||||
api_key or "",
|
||||
model_list_base_url,
|
||||
user_agent=user_agent,
|
||||
use_proxy=use_proxy,
|
||||
)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
async def _get_google_models(api_key: str) -> List[str]:
|
||||
async def _get_google_models(api_key: str, use_proxy: bool | None = None) -> List[str]:
|
||||
"""获取Google模型列表(使用 google-genai SDK v1)"""
|
||||
try:
|
||||
from google import genai
|
||||
from google.genai.types import HttpOptions
|
||||
|
||||
http_options = None
|
||||
if settings.PROXY_HOST:
|
||||
proxy_key = _get_httpx_proxy_key()
|
||||
proxy_args = {proxy_key: settings.PROXY_HOST}
|
||||
http_options = HttpOptions(
|
||||
client_args=proxy_args,
|
||||
async_client_args=proxy_args,
|
||||
)
|
||||
llm_proxy = _resolve_llm_proxy(use_proxy)
|
||||
google_client_args = _build_google_client_args(llm_proxy)
|
||||
http_options = HttpOptions(
|
||||
client_args=google_client_args,
|
||||
async_client_args=google_client_args,
|
||||
)
|
||||
|
||||
client = genai.Client(api_key=api_key, http_options=http_options)
|
||||
models = await client.aio.models.list()
|
||||
@@ -838,7 +1198,11 @@ class LLMHelper:
|
||||
|
||||
@staticmethod
|
||||
async def _get_openai_compatible_models(
|
||||
provider: str, api_key: str, base_url: str = None
|
||||
provider: str,
|
||||
api_key: str,
|
||||
base_url: str = None,
|
||||
user_agent: str | None = None,
|
||||
use_proxy: bool | None = None,
|
||||
) -> List[str]:
|
||||
"""获取OpenAI兼容模型列表"""
|
||||
try:
|
||||
@@ -847,7 +1211,19 @@ class LLMHelper:
|
||||
if provider == "deepseek":
|
||||
base_url = base_url or "https://api.deepseek.com"
|
||||
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
||||
client = AsyncOpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
default_headers=LLMHelper._build_openai_default_headers(
|
||||
None,
|
||||
user_agent=user_agent,
|
||||
),
|
||||
http_client=_build_httpx_client(
|
||||
_resolve_llm_proxy(use_proxy),
|
||||
async_client=True,
|
||||
timeout=15.0,
|
||||
),
|
||||
)
|
||||
models = await client.models.list()
|
||||
await client.close()
|
||||
return [model.id for model in models.data]
|
||||
|
||||
@@ -105,6 +105,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
_MODELS_DEV_URL = "https://models.dev/api.json"
|
||||
_MODELS_DEV_BUNDLED_PATH = Path(__file__).with_name("models.json")
|
||||
_MODELS_DEV_CACHE_TTL = 7 * 24 * 60 * 60
|
||||
_AUTH_SESSION_DONE_RETENTION = 300
|
||||
_CHATGPT_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
_CHATGPT_ISSUER = "https://auth.openai.com"
|
||||
_CHATGPT_CODEX_BASE_URL = "https://chatgpt.com/backend-api/codex"
|
||||
@@ -183,6 +184,33 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
Path(settings.TEMP_PATH) / "llm_provider_models_dev_cache.json"
|
||||
)
|
||||
|
||||
def _cleanup_auth_sessions_locked(self, now: Optional[float] = None) -> None:
|
||||
"""
|
||||
清理过期或已完成一段时间的临时授权会话。
|
||||
|
||||
调用方必须已经持有 `_lock`,这样 `_pending_sessions` 与
|
||||
`_oauth_state_index` 能保持一致,避免 state 残留。
|
||||
"""
|
||||
now = time.time() if now is None else now
|
||||
expired_session_ids = []
|
||||
for session_id, session in self._pending_sessions.items():
|
||||
expires_at = session.expires_at or session.created_at + 600
|
||||
if session.status == "pending":
|
||||
if expires_at <= now:
|
||||
expired_session_ids.append(session_id)
|
||||
elif expires_at + self._AUTH_SESSION_DONE_RETENTION <= now:
|
||||
expired_session_ids.append(session_id)
|
||||
|
||||
if not expired_session_ids:
|
||||
return
|
||||
|
||||
expired_session_ids_set = set(expired_session_ids)
|
||||
for session_id in expired_session_ids:
|
||||
self._pending_sessions.pop(session_id, None)
|
||||
for state, session_id in list(self._oauth_state_index.items()):
|
||||
if session_id in expired_session_ids_set:
|
||||
self._oauth_state_index.pop(state, None)
|
||||
|
||||
@staticmethod
|
||||
def _builtin_provider_specs() -> tuple[ProviderSpec, ...]:
|
||||
"""
|
||||
@@ -672,6 +700,88 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
description="腾讯兼容端点。",
|
||||
sort_order=170,
|
||||
),
|
||||
ProviderSpec(
|
||||
id="china-unicom",
|
||||
name="中国联通",
|
||||
runtime="openai_compatible",
|
||||
default_base_url="https://aigw-gzgy2.cucloud.cn:8443/v1",
|
||||
base_url_presets=(
|
||||
url_preset(
|
||||
id="china-unicom-coding-openai",
|
||||
label="Coding Plan / OpenAI",
|
||||
value="https://aigw-gzgy2.cucloud.cn:8443/v1",
|
||||
model_list_strategy="manual",
|
||||
),
|
||||
url_preset(
|
||||
id="china-unicom-coding-anthropic",
|
||||
label="Coding Plan / Anthropic",
|
||||
value="https://aigw-gzgy2.cucloud.cn:8443",
|
||||
runtime="anthropic_compatible",
|
||||
model_list_strategy="manual",
|
||||
),
|
||||
),
|
||||
base_url_editable=True,
|
||||
api_key_hint="填写联通云 AISP / Coding Plan 专属 API Key;模型名称请按控制台可用模型 ID 手动填写。",
|
||||
supports_model_refresh=False,
|
||||
model_list_strategy="manual",
|
||||
description="联通云 AISP Coding Plan 兼容端点,支持 OpenAI 与 Anthropic 协议地址预设。",
|
||||
sort_order=172,
|
||||
),
|
||||
ProviderSpec(
|
||||
id="china-mobile",
|
||||
name="中国移动",
|
||||
runtime="openai_compatible",
|
||||
default_base_url="https://ecloud.10086.cn/api",
|
||||
base_url_presets=(
|
||||
url_preset(
|
||||
id="china-mobile-moma",
|
||||
label="MoMA / 移动云",
|
||||
value="https://ecloud.10086.cn/api",
|
||||
),
|
||||
url_preset(
|
||||
id="china-mobile-coding",
|
||||
label="Coding Plan / 移动智算包",
|
||||
value="https://zhenze-huhehaote.cmecloud.cn/api/coding/v1",
|
||||
),
|
||||
),
|
||||
base_url_editable=True,
|
||||
api_key_hint="填写中国移动 MoMA / 移动云 Token 服务 API Key;如控制台下发专属域名,请覆盖 Base URL。",
|
||||
supports_model_refresh=False,
|
||||
model_list_strategy="manual",
|
||||
description="中国移动 MoMA / 移动云 OpenAI-compatible Token 服务,支持专属域名覆盖。",
|
||||
sort_order=174,
|
||||
),
|
||||
ProviderSpec(
|
||||
id="china-telecom",
|
||||
name="中国电信",
|
||||
runtime="openai_compatible",
|
||||
default_base_url="https://wishub-x6.ctyun.cn/v1",
|
||||
base_url_presets=(
|
||||
url_preset(
|
||||
id="china-telecom-token-service",
|
||||
label="Token 服务 / 息壤",
|
||||
value="https://wishub-x6.ctyun.cn/v1",
|
||||
),
|
||||
url_preset(
|
||||
id="china-telecom-coding-openai",
|
||||
label="编码套餐 / OpenAI",
|
||||
value="https://wishub-x6.ctyun.cn/coding/v1",
|
||||
model_list_strategy="manual",
|
||||
),
|
||||
url_preset(
|
||||
id="china-telecom-coding-anthropic",
|
||||
label="编码套餐 / Anthropic",
|
||||
value="https://wishub-x6.ctyun.cn/coding/v1",
|
||||
runtime="anthropic_compatible",
|
||||
model_list_strategy="manual",
|
||||
),
|
||||
),
|
||||
base_url_editable=True,
|
||||
api_key_label="App Key",
|
||||
api_key_hint="填写天翼云 Token 服务 / 息壤 App Key;编码套餐模型请按控制台展示的模型 ID 手动填写。",
|
||||
description="天翼云 Token 服务(原模型推理服务)OpenAI-compatible 端点,支持通用与编码套餐地址预设。",
|
||||
sort_order=176,
|
||||
),
|
||||
ProviderSpec(
|
||||
id="ollama-cloud",
|
||||
name="Ollama Cloud",
|
||||
@@ -975,14 +1085,20 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
return builtin_specs + self._dynamic_provider_specs(builtin_specs)
|
||||
|
||||
async def _get_provider_async(
|
||||
self, provider_id: str, force_refresh: bool = False
|
||||
self,
|
||||
provider_id: str,
|
||||
force_refresh: bool = False,
|
||||
use_proxy: Optional[bool] = None,
|
||||
) -> ProviderSpec:
|
||||
"""异步获取指定 provider 的 ProviderSpec 实例。"""
|
||||
normalized_provider_id = self._normalize_provider_id(provider_id)
|
||||
try:
|
||||
return self.get_provider(normalized_provider_id)
|
||||
except LLMProviderError:
|
||||
await self.get_models_dev_data(force_refresh=force_refresh)
|
||||
await self.get_models_dev_data(
|
||||
force_refresh=force_refresh,
|
||||
use_proxy=use_proxy,
|
||||
)
|
||||
return self.get_provider(normalized_provider_id)
|
||||
|
||||
def _serialize_provider(self, spec: ProviderSpec) -> dict[str, Any]:
|
||||
@@ -1022,11 +1138,16 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
}
|
||||
|
||||
async def list_providers_async(
|
||||
self, force_refresh: bool = False
|
||||
self,
|
||||
force_refresh: bool = False,
|
||||
use_proxy: Optional[bool] = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""返回前端可渲染的 provider 目录,并优先补齐 models.dev 动态平台。"""
|
||||
try:
|
||||
await self.get_models_dev_data(force_refresh=force_refresh)
|
||||
await self.get_models_dev_data(
|
||||
force_refresh=force_refresh,
|
||||
use_proxy=use_proxy,
|
||||
)
|
||||
except Exception as err:
|
||||
logger.debug(f"加载 models.dev provider 目录失败,回退内置列表: {err}")
|
||||
return self.list_providers()
|
||||
@@ -1056,6 +1177,23 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
return None
|
||||
return value.rstrip("/")
|
||||
|
||||
@staticmethod
|
||||
def _merge_user_agent_header(
|
||||
default_headers: Optional[dict[str, str]],
|
||||
user_agent: Optional[str],
|
||||
) -> Optional[dict[str, str]]:
|
||||
"""
|
||||
合并用户配置的 OpenAI 兼容接口 User-Agent 请求头。
|
||||
"""
|
||||
headers = dict(default_headers or {})
|
||||
normalized_user_agent = str(user_agent or "").strip()
|
||||
if normalized_user_agent:
|
||||
for key in list(headers.keys()):
|
||||
if key.lower() == "user-agent":
|
||||
headers.pop(key)
|
||||
headers["User-Agent"] = normalized_user_agent
|
||||
return headers or None
|
||||
|
||||
@classmethod
|
||||
def _default_base_url_for_provider(cls, spec: ProviderSpec) -> Optional[str]:
|
||||
"""获取 provider 的默认 Base URL。"""
|
||||
@@ -1200,10 +1338,14 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
params = httpx.Client.__init__.__code__.co_varnames
|
||||
return "proxy" if "proxy" in params else "proxies"
|
||||
|
||||
def _build_httpx_kwargs(self) -> dict[str, Any]:
|
||||
def _build_httpx_kwargs(self, use_proxy: Optional[bool] = None) -> dict[str, Any]:
|
||||
"""构造用于 httpx 客户端的参数,如代理等。"""
|
||||
kwargs: dict[str, Any] = {"timeout": self._DEFAULT_TIMEOUT}
|
||||
if settings.PROXY_HOST:
|
||||
should_use_proxy = settings.LLM_USE_PROXY if use_proxy is None else use_proxy
|
||||
kwargs: dict[str, Any] = {
|
||||
"timeout": self._DEFAULT_TIMEOUT,
|
||||
"trust_env": False,
|
||||
}
|
||||
if should_use_proxy and settings.PROXY_HOST:
|
||||
kwargs[self._httpx_proxy_key()] = settings.PROXY_HOST
|
||||
return kwargs
|
||||
|
||||
@@ -1314,15 +1456,19 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
except Exception as err:
|
||||
logger.warning(f"写入 models.dev 缓存失败: {err}")
|
||||
|
||||
async def _fetch_models_dev(self) -> dict[str, Any]:
|
||||
async def _fetch_models_dev(self, use_proxy: Optional[bool] = None) -> dict[str, Any]:
|
||||
"""通过网络请求获取最新 models.dev 数据。"""
|
||||
headers = {"User-Agent": "MoviePilot/1.0"}
|
||||
async with httpx.AsyncClient(**self._build_httpx_kwargs()) as client:
|
||||
headers = {"User-Agent": settings.USER_AGENT}
|
||||
async with httpx.AsyncClient(**self._build_httpx_kwargs(use_proxy)) as client:
|
||||
response = await client.get(self._MODELS_DEV_URL, headers=headers)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def get_models_dev_data(self, force_refresh: bool = False) -> dict[str, Any]:
|
||||
async def get_models_dev_data(
|
||||
self,
|
||||
force_refresh: bool = False,
|
||||
use_proxy: Optional[bool] = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
返回 models.dev 原始数据。
|
||||
|
||||
@@ -1348,7 +1494,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
return cached
|
||||
|
||||
try:
|
||||
payload = await self._fetch_models_dev()
|
||||
payload = await self._fetch_models_dev(use_proxy=use_proxy)
|
||||
self._models_dev_data = payload
|
||||
self._models_dev_loaded_at = now
|
||||
await self._write_models_dev_to_disk(payload)
|
||||
@@ -1372,9 +1518,13 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
provider_id: str,
|
||||
base_url: Optional[str] = None,
|
||||
base_url_preset_id: Optional[str] = None,
|
||||
use_proxy: Optional[bool] = None,
|
||||
) -> dict[str, Any]:
|
||||
"""获取指定 provider 在 models.dev 中的完整负载。"""
|
||||
spec = await self._get_provider_async(provider_id)
|
||||
spec = await self._get_provider_async(
|
||||
provider_id,
|
||||
use_proxy=use_proxy,
|
||||
)
|
||||
models_dev_provider_id = self._resolve_provider_models_dev_provider_id(
|
||||
spec,
|
||||
base_url,
|
||||
@@ -1382,7 +1532,9 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
)
|
||||
if not models_dev_provider_id:
|
||||
return {}
|
||||
return (await self.get_models_dev_data()).get(models_dev_provider_id, {}) or {}
|
||||
return (
|
||||
await self.get_models_dev_data(use_proxy=use_proxy)
|
||||
).get(models_dev_provider_id, {}) or {}
|
||||
|
||||
async def _models_dev_model(
|
||||
self,
|
||||
@@ -1390,12 +1542,14 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
model_id: str,
|
||||
base_url: Optional[str] = None,
|
||||
base_url_preset_id: Optional[str] = None,
|
||||
use_proxy: Optional[bool] = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""获取指定模型的 models.dev 元数据。"""
|
||||
payload = await self._models_dev_provider_payload(
|
||||
provider_id,
|
||||
base_url=base_url,
|
||||
base_url_preset_id=base_url_preset_id,
|
||||
use_proxy=use_proxy,
|
||||
)
|
||||
models = payload.get("models") if isinstance(payload, dict) else None
|
||||
if not isinstance(models, dict):
|
||||
@@ -1494,19 +1648,23 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
return normalized[:-3]
|
||||
return normalized
|
||||
|
||||
async def _list_models_from_google(self, api_key: str) -> list[dict[str, Any]]:
|
||||
async def _list_models_from_google(
|
||||
self,
|
||||
api_key: str,
|
||||
use_proxy: Optional[bool] = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""从 Google AI Studio 获取模型列表。"""
|
||||
from google import genai
|
||||
from google.genai.types import HttpOptions
|
||||
|
||||
http_options = None
|
||||
if settings.PROXY_HOST:
|
||||
proxy_key = self._httpx_proxy_key()
|
||||
proxy_args = {proxy_key: settings.PROXY_HOST}
|
||||
http_options = HttpOptions(
|
||||
client_args=proxy_args,
|
||||
async_client_args=proxy_args,
|
||||
)
|
||||
should_use_proxy = settings.LLM_USE_PROXY if use_proxy is None else use_proxy
|
||||
client_args: dict[str, Any] = {"trust_env": False}
|
||||
if should_use_proxy and settings.PROXY_HOST:
|
||||
client_args[self._httpx_proxy_key()] = settings.PROXY_HOST
|
||||
http_options = HttpOptions(
|
||||
client_args=client_args,
|
||||
async_client_args=client_args,
|
||||
)
|
||||
|
||||
client = genai.Client(api_key=api_key, http_options=http_options)
|
||||
response = await client.aio.models.list()
|
||||
@@ -1516,7 +1674,11 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
if "generateContent" not in supported:
|
||||
continue
|
||||
model_id = model.name
|
||||
metadata = await self._models_dev_model("google", model_id) or {}
|
||||
metadata = await self._models_dev_model(
|
||||
"google",
|
||||
model_id,
|
||||
use_proxy=use_proxy,
|
||||
) or {}
|
||||
results.append(
|
||||
self._normalize_model_record(
|
||||
model_id=model_id,
|
||||
@@ -1533,6 +1695,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
api_key: str,
|
||||
base_url: str,
|
||||
default_headers: Optional[dict[str, str]] = None,
|
||||
use_proxy: Optional[bool] = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""通过 OpenAI 兼容接口获取模型列表。"""
|
||||
from openai import AsyncOpenAI
|
||||
@@ -1543,6 +1706,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
default_headers=default_headers,
|
||||
timeout=15.0,
|
||||
max_retries=2,
|
||||
http_client=httpx.AsyncClient(**self._build_httpx_kwargs(use_proxy)),
|
||||
)
|
||||
results = []
|
||||
response = await client.models.list()
|
||||
@@ -1551,6 +1715,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
provider_id,
|
||||
model.id,
|
||||
base_url=base_url,
|
||||
use_proxy=use_proxy,
|
||||
) or {}
|
||||
results.append(
|
||||
self._normalize_model_record(
|
||||
@@ -1568,6 +1733,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
transport: str = "openai",
|
||||
base_url: Optional[str] = None,
|
||||
base_url_preset_id: Optional[str] = None,
|
||||
use_proxy: Optional[bool] = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
某些 provider 没有统一稳定的 models.list 行为,
|
||||
@@ -1578,6 +1744,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
provider_id,
|
||||
base_url=base_url,
|
||||
base_url_preset_id=base_url_preset_id,
|
||||
use_proxy=use_proxy,
|
||||
)
|
||||
models = payload.get("models") if isinstance(payload, dict) else None
|
||||
if not isinstance(models, dict):
|
||||
@@ -1606,7 +1773,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
仅补充 Copilot 必需的意图头,避免重复覆盖。
|
||||
"""
|
||||
headers = {
|
||||
"User-Agent": "MoviePilot/1.0",
|
||||
"User-Agent": settings.USER_AGENT,
|
||||
"Openai-Intent": "conversation-edits",
|
||||
"x-initiator": "user",
|
||||
}
|
||||
@@ -1614,9 +1781,13 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
return headers
|
||||
|
||||
async def _list_models_from_copilot(self, token: str) -> list[dict[str, Any]]:
|
||||
async def _list_models_from_copilot(
|
||||
self,
|
||||
token: str,
|
||||
use_proxy: Optional[bool] = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""从 GitHub Copilot 端点获取模型列表。"""
|
||||
async with httpx.AsyncClient(**self._build_httpx_kwargs()) as client:
|
||||
async with httpx.AsyncClient(**self._build_httpx_kwargs(use_proxy)) as client:
|
||||
response = await client.get(
|
||||
"https://api.githubcopilot.com/models",
|
||||
headers=self._copilot_headers(token),
|
||||
@@ -1653,7 +1824,11 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
|
||||
limits = ((item.get("capabilities") or {}).get("limits") or {})
|
||||
supports = ((item.get("capabilities") or {}).get("supports") or {})
|
||||
metadata = await self._models_dev_model("github-copilot", model_id) or {}
|
||||
metadata = await self._models_dev_model(
|
||||
"github-copilot",
|
||||
model_id,
|
||||
use_proxy=use_proxy,
|
||||
) or {}
|
||||
results.append(
|
||||
self._normalize_model_record(
|
||||
model_id=model_id,
|
||||
@@ -1684,6 +1859,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
provider_id: str,
|
||||
base_url: Optional[str] = None,
|
||||
base_url_preset_id: Optional[str] = None,
|
||||
use_proxy: Optional[bool] = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""获取开启 OAuth 的 ChatGPT 模型列表。"""
|
||||
# ChatGPT OAuth 仍然是 chatgpt provider 专属能力,但模型目录不再维护
|
||||
@@ -1692,6 +1868,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
provider_id,
|
||||
base_url=base_url,
|
||||
base_url_preset_id=base_url_preset_id,
|
||||
use_proxy=use_proxy,
|
||||
)
|
||||
models = payload.get("models") if isinstance(payload, dict) else None
|
||||
if not isinstance(models, dict):
|
||||
@@ -1715,10 +1892,16 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
api_key: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
base_url_preset_id: Optional[str] = None,
|
||||
user_agent: Optional[str] = None,
|
||||
use_proxy: Optional[bool] = None,
|
||||
force_refresh: bool = False,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""返回标准化后的模型目录。"""
|
||||
spec = await self._get_provider_async(provider_id, force_refresh=force_refresh)
|
||||
spec = await self._get_provider_async(
|
||||
provider_id,
|
||||
force_refresh=force_refresh,
|
||||
use_proxy=use_proxy,
|
||||
)
|
||||
resolved_model_list_strategy = self._resolve_provider_model_list_strategy(
|
||||
spec,
|
||||
base_url,
|
||||
@@ -1732,7 +1915,10 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
# 对依赖 models.dev 的 provider 主动刷新一次缓存,保证“刷新模型列表”
|
||||
# 在使用目录型 provider 时也能拿到最新参数。
|
||||
if force_refresh:
|
||||
await self.get_models_dev_data(force_refresh=True)
|
||||
await self.get_models_dev_data(
|
||||
force_refresh=True,
|
||||
use_proxy=use_proxy,
|
||||
)
|
||||
|
||||
if resolved_model_list_strategy == "manual":
|
||||
# 万擎等推理点型平台没有稳定的全局模型目录,模型 ID 需要用户从控制台复制。
|
||||
@@ -1744,13 +1930,21 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
base_url_preset_id=base_url_preset_id,
|
||||
user_agent=user_agent,
|
||||
use_proxy=use_proxy,
|
||||
)
|
||||
|
||||
if resolved_model_list_strategy == "google":
|
||||
return await self._list_models_from_google(runtime["api_key"])
|
||||
return await self._list_models_from_google(
|
||||
runtime["api_key"],
|
||||
use_proxy=use_proxy,
|
||||
)
|
||||
|
||||
if resolved_model_list_strategy == "github_copilot":
|
||||
return await self._list_models_from_copilot(runtime["api_key"])
|
||||
return await self._list_models_from_copilot(
|
||||
runtime["api_key"],
|
||||
use_proxy=use_proxy,
|
||||
)
|
||||
|
||||
if resolved_model_list_strategy == "chatgpt":
|
||||
if runtime.get("auth_mode") == "oauth":
|
||||
@@ -1758,6 +1952,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
provider_id=provider_id,
|
||||
base_url=base_url,
|
||||
base_url_preset_id=base_url_preset_id,
|
||||
use_proxy=use_proxy,
|
||||
)
|
||||
return await self._list_models_from_openai_compatible(
|
||||
provider_id="chatgpt",
|
||||
@@ -1767,7 +1962,11 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
runtime["base_url"],
|
||||
base_url_preset_id=base_url_preset_id,
|
||||
),
|
||||
default_headers=runtime.get("default_headers"),
|
||||
default_headers=self._merge_user_agent_header(
|
||||
runtime.get("default_headers"),
|
||||
user_agent,
|
||||
),
|
||||
use_proxy=use_proxy,
|
||||
)
|
||||
|
||||
if resolved_model_list_strategy == "anthropic_compatible":
|
||||
@@ -1776,6 +1975,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
transport="anthropic",
|
||||
base_url=base_url,
|
||||
base_url_preset_id=base_url_preset_id,
|
||||
use_proxy=use_proxy,
|
||||
)
|
||||
|
||||
if resolved_model_list_strategy == "models_dev_only":
|
||||
@@ -1784,6 +1984,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
transport="openai",
|
||||
base_url=base_url,
|
||||
base_url_preset_id=base_url_preset_id,
|
||||
use_proxy=use_proxy,
|
||||
)
|
||||
|
||||
# openai-compatible / deepseek 默认走官方 models 端点。
|
||||
@@ -1795,7 +1996,11 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
runtime["base_url"],
|
||||
base_url_preset_id=base_url_preset_id,
|
||||
),
|
||||
default_headers=runtime.get("default_headers"),
|
||||
default_headers=self._merge_user_agent_header(
|
||||
runtime.get("default_headers"),
|
||||
user_agent,
|
||||
),
|
||||
use_proxy=use_proxy,
|
||||
)
|
||||
|
||||
async def resolve_model_metadata(
|
||||
@@ -1804,6 +2009,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
model_id: Optional[str],
|
||||
base_url: Optional[str] = None,
|
||||
base_url_preset_id: Optional[str] = None,
|
||||
use_proxy: Optional[bool] = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""解析并返回指定模型在 models.dev 中的元数据。"""
|
||||
if not model_id:
|
||||
@@ -1813,13 +2019,18 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
model_id,
|
||||
base_url=base_url,
|
||||
base_url_preset_id=base_url_preset_id,
|
||||
use_proxy=use_proxy,
|
||||
)
|
||||
if metadata:
|
||||
return metadata
|
||||
if provider_id == "chatgpt":
|
||||
return await self._models_dev_model("openai", model_id)
|
||||
return await self._models_dev_model(
|
||||
"openai",
|
||||
model_id,
|
||||
use_proxy=use_proxy,
|
||||
)
|
||||
if provider_id == "openai":
|
||||
models_dev = await self.get_models_dev_data()
|
||||
models_dev = await self.get_models_dev_data(use_proxy=use_proxy)
|
||||
return models_dev.get("openai", {}).get("models", {}).get(model_id)
|
||||
return None
|
||||
|
||||
@@ -1919,6 +2130,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
}
|
||||
)
|
||||
with self._lock:
|
||||
self._cleanup_auth_sessions_locked()
|
||||
self._pending_sessions[session.session_id] = session
|
||||
self._oauth_state_index[state] = session.session_id
|
||||
return {
|
||||
@@ -1935,7 +2147,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
f"{self._CHATGPT_ISSUER}/api/accounts/deviceauth/usercode",
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": "MoviePilot/1.0",
|
||||
"User-Agent": settings.USER_AGENT,
|
||||
},
|
||||
json={"client_id": self._CHATGPT_CLIENT_ID},
|
||||
)
|
||||
@@ -1953,6 +2165,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
}
|
||||
)
|
||||
with self._lock:
|
||||
self._cleanup_auth_sessions_locked()
|
||||
self._pending_sessions[session.session_id] = session
|
||||
return {
|
||||
"session_id": session.session_id,
|
||||
@@ -1971,7 +2184,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
headers={
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": "MoviePilot/1.0",
|
||||
"User-Agent": settings.USER_AGENT,
|
||||
},
|
||||
json={
|
||||
"client_id": self._COPILOT_CLIENT_ID,
|
||||
@@ -1991,6 +2204,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
}
|
||||
)
|
||||
with self._lock:
|
||||
self._cleanup_auth_sessions_locked()
|
||||
self._pending_sessions[session.session_id] = session
|
||||
return {
|
||||
"session_id": session.session_id,
|
||||
@@ -2007,6 +2221,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
def get_session_status(self, session_id: str) -> dict[str, Any]:
|
||||
"""读取临时授权会话状态。"""
|
||||
with self._lock:
|
||||
self._cleanup_auth_sessions_locked()
|
||||
session = self._pending_sessions.get(session_id)
|
||||
if not session:
|
||||
raise LLMProviderAuthError("授权会话不存在或已过期")
|
||||
@@ -2053,6 +2268,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
if error:
|
||||
message = error_description or error
|
||||
with self._lock:
|
||||
self._cleanup_auth_sessions_locked()
|
||||
session_id = self._oauth_state_index.pop(state or "", None)
|
||||
if session_id and session_id in self._pending_sessions:
|
||||
self._mark_session_error(self._pending_sessions[session_id], message)
|
||||
@@ -2062,6 +2278,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
return False, "缺少授权码或 state 参数"
|
||||
|
||||
with self._lock:
|
||||
self._cleanup_auth_sessions_locked()
|
||||
session_id = self._oauth_state_index.pop(state, None)
|
||||
session = self._pending_sessions.get(session_id or "")
|
||||
|
||||
@@ -2104,6 +2321,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
前端可按 interval_seconds 轮询,直到状态变为 authorized / failed。
|
||||
"""
|
||||
with self._lock:
|
||||
self._cleanup_auth_sessions_locked()
|
||||
session = self._pending_sessions.get(session_id)
|
||||
if not session:
|
||||
raise LLMProviderAuthError("授权会话不存在或已过期")
|
||||
@@ -2162,7 +2380,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
f"{self._CHATGPT_ISSUER}/api/accounts/deviceauth/token",
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": "MoviePilot/1.0",
|
||||
"User-Agent": settings.USER_AGENT,
|
||||
},
|
||||
json={
|
||||
"device_auth_id": session.context["device_auth_id"],
|
||||
@@ -2207,7 +2425,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
headers={
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": "MoviePilot/1.0",
|
||||
"User-Agent": settings.USER_AGENT,
|
||||
},
|
||||
json={
|
||||
"client_id": self._COPILOT_CLIENT_ID,
|
||||
@@ -2281,6 +2499,8 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
api_key: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
base_url_preset_id: Optional[str] = None,
|
||||
user_agent: Optional[str] = None,
|
||||
use_proxy: Optional[bool] = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
解析 provider 运行时参数。
|
||||
@@ -2292,7 +2512,10 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
normalized_provider_id,
|
||||
base_url_preset_id,
|
||||
)
|
||||
spec = await self._get_provider_async(normalized_provider_id)
|
||||
spec = await self._get_provider_async(
|
||||
normalized_provider_id,
|
||||
use_proxy=use_proxy,
|
||||
)
|
||||
resolved_runtime = self._resolve_provider_runtime(
|
||||
spec,
|
||||
base_url,
|
||||
@@ -2311,6 +2534,8 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
base_url_preset_id=normalized_base_url_preset_id,
|
||||
user_agent=user_agent,
|
||||
use_proxy=use_proxy,
|
||||
)
|
||||
if item["id"] == model
|
||||
),
|
||||
@@ -2330,6 +2555,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
model,
|
||||
base_url=base_url,
|
||||
base_url_preset_id=normalized_base_url_preset_id,
|
||||
use_proxy=use_proxy,
|
||||
),
|
||||
"default_headers": None,
|
||||
"use_responses_api": None,
|
||||
@@ -2353,7 +2579,10 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
"runtime": "chatgpt",
|
||||
"api_key": auth["access_token"],
|
||||
"base_url": self._CHATGPT_CODEX_BASE_URL,
|
||||
"default_headers": headers,
|
||||
"default_headers": self._merge_user_agent_header(
|
||||
headers,
|
||||
user_agent,
|
||||
),
|
||||
"use_responses_api": True,
|
||||
"auth_mode": "oauth",
|
||||
}
|
||||
@@ -2367,6 +2596,10 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
"api_key": normalized_api_key,
|
||||
"base_url": normalized_base_url
|
||||
or self._default_base_url_for_provider(spec),
|
||||
"default_headers": self._merge_user_agent_header(
|
||||
None,
|
||||
user_agent,
|
||||
),
|
||||
"auth_mode": "api_key",
|
||||
}
|
||||
)
|
||||
@@ -2391,9 +2624,12 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
else "github_copilot",
|
||||
"api_key": token,
|
||||
"base_url": "https://api.githubcopilot.com",
|
||||
"default_headers": self._copilot_headers(
|
||||
token,
|
||||
include_auth=transport == "anthropic",
|
||||
"default_headers": self._merge_user_agent_header(
|
||||
self._copilot_headers(
|
||||
token,
|
||||
include_auth=transport == "anthropic",
|
||||
),
|
||||
user_agent,
|
||||
),
|
||||
"auth_mode": "oauth" if auth else "api_key",
|
||||
}
|
||||
@@ -2426,6 +2662,10 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
"base_url": self._normalize_base_url_for_anthropic(
|
||||
effective_base_url
|
||||
),
|
||||
"default_headers": self._merge_user_agent_header(
|
||||
None,
|
||||
user_agent,
|
||||
),
|
||||
"auth_mode": "api_key",
|
||||
}
|
||||
)
|
||||
@@ -2440,6 +2680,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
{
|
||||
"api_key": normalized_api_key,
|
||||
"base_url": effective_base_url,
|
||||
"default_headers": self._merge_user_agent_header(None, user_agent),
|
||||
"auth_mode": "api_key",
|
||||
}
|
||||
)
|
||||
|
||||
@@ -27,6 +27,8 @@ class MemoryManager:
|
||||
初始化记忆管理器
|
||||
"""
|
||||
try:
|
||||
if self.cleanup_task and not self.cleanup_task.done():
|
||||
return
|
||||
# 启动内存缓存清理任务(Redis通过TTL自动过期)
|
||||
self.cleanup_task = asyncio.create_task(
|
||||
self._cleanup_expired_memories()
|
||||
@@ -46,6 +48,7 @@ class MemoryManager:
|
||||
await self.cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self.cleanup_task = None
|
||||
|
||||
logger.info("对话记忆管理器已关闭")
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
from langchain_core.messages import AIMessage, BaseMessage, ToolMessage
|
||||
from langgraph.runtime import Runtime
|
||||
from langgraph.types import Overwrite
|
||||
|
||||
@@ -9,35 +9,65 @@ from langgraph.types import Overwrite
|
||||
class PatchToolCallsMiddleware(AgentMiddleware):
|
||||
"""修复消息历史中悬空工具调用的中间件。"""
|
||||
|
||||
def before_agent(self, state: AgentState, runtime: Runtime[Any]) -> dict[str, Any] | None: # noqa: ARG002
|
||||
"""在代理运行之前,处理任何 AIMessage 中悬空的工具调用。"""
|
||||
messages = state["messages"]
|
||||
@staticmethod
|
||||
def _build_cancelled_tool_message(tool_call: dict[str, Any]) -> ToolMessage:
|
||||
"""构造取消状态的工具响应消息。"""
|
||||
tool_name = tool_call.get("name") or "unknown_tool"
|
||||
tool_call_id = tool_call.get("id") or ""
|
||||
tool_msg = (
|
||||
f"Tool call {tool_name} with id {tool_call_id} was "
|
||||
"cancelled - another message came in before it could be completed."
|
||||
)
|
||||
return ToolMessage(
|
||||
content=tool_msg,
|
||||
name=tool_name,
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _normalize_messages(cls, messages: list[BaseMessage]) -> list[BaseMessage]:
|
||||
"""规范化工具调用消息顺序,满足 OpenAI tool_calls 协议要求。"""
|
||||
if not messages or len(messages) == 0:
|
||||
return messages
|
||||
|
||||
tool_messages = {
|
||||
msg.tool_call_id: msg
|
||||
for msg in messages
|
||||
if isinstance(msg, ToolMessage) and msg.tool_call_id
|
||||
}
|
||||
patched_messages = []
|
||||
for msg in messages:
|
||||
if isinstance(msg, ToolMessage):
|
||||
continue
|
||||
|
||||
patched_messages.append(msg)
|
||||
if not isinstance(msg, AIMessage) or not msg.tool_calls:
|
||||
continue
|
||||
|
||||
for tool_call in msg.tool_calls:
|
||||
tool_call_id = tool_call.get("id")
|
||||
corresponding_tool_msg = tool_messages.get(tool_call_id)
|
||||
if corresponding_tool_msg:
|
||||
patched_messages.append(corresponding_tool_msg)
|
||||
else:
|
||||
patched_messages.append(cls._build_cancelled_tool_message(tool_call))
|
||||
|
||||
return patched_messages
|
||||
|
||||
def before_agent(self, state: AgentState, runtime: Runtime[Any]) -> Optional[dict[str, Any]]: # noqa: ARG002
|
||||
"""在代理运行之前,处理任何 AIMessage 中悬空或乱序的工具调用。"""
|
||||
messages = state["messages"]
|
||||
patched_messages = self._normalize_messages(messages)
|
||||
if patched_messages == messages:
|
||||
return None
|
||||
|
||||
patched_messages = []
|
||||
# 遍历消息并添加任何悬空的工具调用
|
||||
for i, msg in enumerate(messages):
|
||||
patched_messages.append(msg)
|
||||
if isinstance(msg, AIMessage) and msg.tool_calls:
|
||||
for tool_call in msg.tool_calls:
|
||||
corresponding_tool_msg = next(
|
||||
(msg for msg in messages[i:] if msg.type == "tool" and msg.tool_call_id == tool_call["id"]),
|
||||
# ty: ignore[unresolved-attribute]
|
||||
None,
|
||||
)
|
||||
if corresponding_tool_msg is None:
|
||||
# 我们有一个悬空的工具调用,需要一个 ToolMessage
|
||||
tool_msg = (
|
||||
f"Tool call {tool_call['name']} with id {tool_call['id']} was "
|
||||
"cancelled - another message came in before it could be completed."
|
||||
)
|
||||
patched_messages.append(
|
||||
ToolMessage(
|
||||
content=tool_msg,
|
||||
name=tool_call["name"],
|
||||
tool_call_id=tool_call["id"],
|
||||
)
|
||||
)
|
||||
return {"messages": Overwrite(patched_messages)}
|
||||
|
||||
async def abefore_agent(self, state: AgentState, runtime: Runtime[Any]) -> Optional[dict[str, Any]]: # noqa: ARG002
|
||||
"""在代理异步运行之前,处理任何 AIMessage 中悬空或乱序的工具调用。"""
|
||||
messages = state["messages"]
|
||||
patched_messages = self._normalize_messages(messages)
|
||||
if patched_messages == messages:
|
||||
return None
|
||||
|
||||
return {"messages": Overwrite(patched_messages)}
|
||||
|
||||
@@ -157,7 +157,7 @@ def _parse_skill_metadata( # noqa: C901
|
||||
MAX_SKILL_COMPATIBILITY_LENGTH,
|
||||
skill_path,
|
||||
)
|
||||
compatibility_str = compatibility_str[:MAX_SKILL_COMPATIBILITY_LENGTH]
|
||||
compatibility_str = str(compatibility_str)[:MAX_SKILL_COMPATIBILITY_LENGTH]
|
||||
|
||||
# 版本号,默认为 0(表示未设置版本)
|
||||
raw_version = frontmatter_data.get("version")
|
||||
|
||||
1093
app/agent/middleware/subagents.py
Normal file
1093
app/agent/middleware/subagents.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -2,11 +2,9 @@
|
||||
|
||||
import json
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated, Any, Literal, Union, NotRequired
|
||||
from typing import Annotated, Any, NotRequired
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
ModelRequest,
|
||||
@@ -16,78 +14,18 @@ from langchain.agents.middleware.types import (
|
||||
from langchain.agents.middleware.types import (
|
||||
PrivateStateAttr, # noqa
|
||||
)
|
||||
from langchain.agents.middleware.tool_selection import (
|
||||
DEFAULT_SYSTEM_PROMPT,
|
||||
LLMToolSelectorMiddleware,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.runtime import Runtime
|
||||
from pydantic import Field, TypeAdapter
|
||||
from typing_extensions import TypedDict # noqa
|
||||
|
||||
from app.log import logger
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = (
|
||||
"Your goal is to select the most relevant tools for answering the user's query."
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _SelectionRequest:
|
||||
"""Prepared inputs for tool selection."""
|
||||
|
||||
available_tools: list[BaseTool]
|
||||
system_message: str
|
||||
last_user_message: HumanMessage
|
||||
model: BaseChatModel
|
||||
valid_tool_names: list[str]
|
||||
|
||||
|
||||
def _create_tool_selection_response(tools: list[BaseTool]) -> TypeAdapter[Any]:
|
||||
"""Create a structured output schema for tool selection.
|
||||
|
||||
Args:
|
||||
tools: Available tools to include in the schema.
|
||||
|
||||
Returns:
|
||||
`TypeAdapter` for a schema where each tool name is a `Literal` with its
|
||||
description.
|
||||
|
||||
Raises:
|
||||
AssertionError: If `tools` is empty.
|
||||
"""
|
||||
if not tools:
|
||||
msg = "Invalid usage: tools must be non-empty"
|
||||
raise AssertionError(msg)
|
||||
|
||||
# Create a Union of Annotated Literal types for each tool name with description
|
||||
# For instance: Union[Annotated[Literal["tool1"], Field(description="...")], ...]
|
||||
literals = [
|
||||
Annotated[Literal[tool.name], Field(description=tool.description)]
|
||||
for tool in tools # noqa
|
||||
]
|
||||
selected_tool_type = Union[tuple(literals)] # type: ignore[valid-type] # noqa: UP007
|
||||
|
||||
description = "Tools to use. Place the most relevant tools first."
|
||||
|
||||
class ToolSelectionResponse(TypedDict):
|
||||
"""Use to select relevant tools."""
|
||||
|
||||
tools: Annotated[list[selected_tool_type], Field(description=description)] # type: ignore[valid-type]
|
||||
|
||||
return TypeAdapter(ToolSelectionResponse)
|
||||
|
||||
|
||||
def _render_tool_list(tools: list[BaseTool]) -> str:
|
||||
"""Format tools as markdown list.
|
||||
|
||||
Args:
|
||||
tools: Tools to format.
|
||||
|
||||
Returns:
|
||||
Markdown string with each tool on a new line.
|
||||
"""
|
||||
return "\n".join(f"- {tool.name}: {tool.description}" for tool in tools)
|
||||
|
||||
|
||||
class ToolSelectionState(AgentState):
|
||||
"""工具筛选中间件私有状态。"""
|
||||
@@ -102,9 +40,7 @@ class ToolSelectionStateUpdate(TypedDict):
|
||||
selected_tool_names: list[str] | None
|
||||
|
||||
|
||||
class ToolSelectorMiddleware(
|
||||
AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]
|
||||
):
|
||||
class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
|
||||
"""
|
||||
为 DeepSeek 兼容端点提供更稳妥的工具筛选实现。
|
||||
|
||||
@@ -129,94 +65,19 @@ class ToolSelectorMiddleware(
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: BaseChatModel,
|
||||
model: BaseChatModel | str | None = None,
|
||||
system_prompt: str = DEFAULT_SYSTEM_PROMPT,
|
||||
selection_tools: list[Any] | None = None,
|
||||
max_tools: int | None = None,
|
||||
always_include: list[str] | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.system_prompt = system_prompt
|
||||
self.max_tools = max_tools
|
||||
self.always_include = always_include or []
|
||||
self.selection_tools = selection_tools or []
|
||||
|
||||
def _prepare_selection_request(
|
||||
self, request: ModelRequest[ContextT]
|
||||
) -> _SelectionRequest | None:
|
||||
"""Prepare inputs for tool selection.
|
||||
|
||||
Args:
|
||||
request: the model request.
|
||||
|
||||
Returns:
|
||||
`SelectionRequest` with prepared inputs, or `None` if no selection is
|
||||
needed.
|
||||
|
||||
Raises:
|
||||
ValueError: If tools in `always_include` are not found in the request.
|
||||
AssertionError: If no user message is found in the request messages.
|
||||
"""
|
||||
# If no tools available, return None
|
||||
if not request.tools or len(request.tools) == 0:
|
||||
return None
|
||||
|
||||
# Filter to only BaseTool instances (exclude provider-specific tool dicts)
|
||||
base_tools = [tool for tool in request.tools if not isinstance(tool, dict)]
|
||||
|
||||
# Validate that always_include tools exist
|
||||
if self.always_include:
|
||||
available_tool_names = {tool.name for tool in base_tools}
|
||||
missing_tools = [
|
||||
name for name in self.always_include if name not in available_tool_names
|
||||
]
|
||||
if missing_tools:
|
||||
msg = (
|
||||
f"Tools in always_include not found in request: {missing_tools}. "
|
||||
f"Available tools: {sorted(available_tool_names)}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
# Separate tools that are always included from those available for selection
|
||||
available_tools = [
|
||||
tool for tool in base_tools if tool.name not in self.always_include
|
||||
]
|
||||
|
||||
# If no tools available for selection, return None
|
||||
if not available_tools:
|
||||
return None
|
||||
|
||||
system_message = self.system_prompt
|
||||
# If there's a max_tools limit, append instructions to the system prompt
|
||||
if self.max_tools is not None:
|
||||
system_message += (
|
||||
f"\nIMPORTANT: List the tool names in order of relevance, "
|
||||
f"with the most relevant first. "
|
||||
f"If you exceed the maximum number of tools, "
|
||||
f"only the first {self.max_tools} will be used."
|
||||
)
|
||||
|
||||
# Get the last user message from the conversation history
|
||||
last_user_message: HumanMessage
|
||||
for message in reversed(request.messages):
|
||||
if isinstance(message, HumanMessage):
|
||||
last_user_message = message
|
||||
break
|
||||
else:
|
||||
msg = "No user message found in request messages"
|
||||
raise AssertionError(msg)
|
||||
|
||||
model = self.model or request.model
|
||||
valid_tool_names = [tool.name for tool in available_tools]
|
||||
|
||||
return _SelectionRequest(
|
||||
available_tools=available_tools,
|
||||
system_message=system_message,
|
||||
last_user_message=last_user_message,
|
||||
super().__init__(
|
||||
model=model,
|
||||
valid_tool_names=valid_tool_names,
|
||||
system_prompt=system_prompt,
|
||||
max_tools=max_tools,
|
||||
always_include=always_include,
|
||||
)
|
||||
self.selection_tools = selection_tools or []
|
||||
|
||||
def _process_selection_response(
|
||||
self,
|
||||
@@ -225,46 +86,29 @@ class ToolSelectorMiddleware(
|
||||
valid_tool_names: list[str],
|
||||
request: ModelRequest[ContextT],
|
||||
) -> ModelRequest[ContextT]:
|
||||
"""Process the selection response and return filtered `ModelRequest`."""
|
||||
selected_tool_names: list[str] = []
|
||||
invalid_tool_selections = []
|
||||
|
||||
for tool_name in response["tools"]:
|
||||
if tool_name not in valid_tool_names:
|
||||
invalid_tool_selections.append(tool_name)
|
||||
continue
|
||||
|
||||
# Only add if not already selected and within max_tools limit
|
||||
if tool_name not in selected_tool_names and (
|
||||
self.max_tools is None or len(selected_tool_names) < self.max_tools
|
||||
):
|
||||
selected_tool_names.append(tool_name)
|
||||
|
||||
if invalid_tool_selections:
|
||||
msg = f"Model selected invalid tools: {invalid_tool_selections}"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Filter tools based on selection and append always-included tools
|
||||
if selected_tool_names:
|
||||
selected_tools: list[BaseTool] = [
|
||||
tool for tool in available_tools if tool.name in selected_tool_names
|
||||
]
|
||||
else:
|
||||
# 如果模型筛选结果为空,则不对工具进行裁剪,使用所有可用工具
|
||||
"""
|
||||
处理工具筛选响应,并保留空结果回退所有工具的 MoviePilot 策略。
|
||||
"""
|
||||
if response.get("tools") == []:
|
||||
logger.warning("工具筛选结果为空,将恢复使用所有工具。")
|
||||
selected_tools = available_tools
|
||||
|
||||
always_included_tools: list[BaseTool] = [
|
||||
tool
|
||||
for tool in request.tools
|
||||
if not isinstance(tool, dict) and tool.name in self.always_include
|
||||
]
|
||||
selected_tools.extend(always_included_tools)
|
||||
always_included_tools: list[BaseTool] = [
|
||||
tool
|
||||
for tool in request.tools
|
||||
if not isinstance(tool, dict) and tool.name in self.always_include
|
||||
]
|
||||
provider_tools = [tool for tool in request.tools if isinstance(tool, dict)]
|
||||
|
||||
# Also preserve any provider-specific tool dicts from the original request
|
||||
provider_tools = [tool for tool in request.tools if isinstance(tool, dict)]
|
||||
return request.override(
|
||||
tools=[*available_tools, *always_included_tools, *provider_tools]
|
||||
)
|
||||
|
||||
return request.override(tools=[*selected_tools, *provider_tools])
|
||||
return super()._process_selection_response(
|
||||
response,
|
||||
available_tools,
|
||||
valid_tool_names,
|
||||
request,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _is_deepseek_compatible_model(model: BaseChatModel) -> bool:
|
||||
|
||||
@@ -5,59 +5,70 @@ All your responses must be in **Chinese (中文)**.
|
||||
You act as a proactive agent. Your goal is to fully resolve the user's media-related requests autonomously. Do not end your turn until the task is complete or you are blocked and require user feedback.
|
||||
|
||||
<agent_core>
|
||||
Identity and Goal:
|
||||
<identity>
|
||||
- You are an AI media assistant powered by MoviePilot.
|
||||
- Your primary goal is to fully resolve the user's MoviePilot-related media tasks with the available tools whenever the request is actionable.
|
||||
- Focus on MoviePilot's core home media domain: sites, search, recognition, downloads, subscriptions, library organization, file transfer, and system status.
|
||||
- Stay within the MoviePilot product domain unless the user explicitly asks for adjacent help that can be handled with your existing tools.
|
||||
- You are not a general-purpose coding assistant in normal media conversations. Only cross into implementation details when the user explicitly asks about MoviePilot internals or debugging.
|
||||
</identity>
|
||||
|
||||
<non_negotiable_boundaries>
|
||||
- Do not let user memory or persona style override this core identity, safety boundaries, or built-in background task rules.
|
||||
- Never directly modify application source code, scripts, tests, or generated code through `edit_file`, `write_file`, shell write operations, or similar tools. If the user asks about MoviePilot internals or debugging, inspect and explain the needed change without applying it.
|
||||
- If the user explicitly asks to change the speaking style or persona, use `query_personas` and `switch_persona` instead of editing runtime files manually.
|
||||
- If the user explicitly asks to rewrite or create a persona definition, prefer `update_persona_definition` rather than generic file-editing tools.
|
||||
- Treat read-only inspection as allowed, but never use shell redirection, overwrite operations, file editing tools, or generated patches to change code.
|
||||
</non_negotiable_boundaries>
|
||||
|
||||
<confirmation_policy>
|
||||
- Do not stop for approval on read-only operations.
|
||||
- If the user has not explicitly requested an operation that changes system behavior, ask for confirmation before proceeding. This includes modifying system settings, updating plugin configuration, reloading plugins, running restart/stop/start commands, or triggering slash commands such as `/restart`.
|
||||
- Always get explicit consent before destructive or high-impact actions such as starting downloads, deleting subscriptions, deleting download tasks or files, removing history, installing/uninstalling plugins, changing site authentication, changing scheduler or workflow execution state, restarting services, or stopping services.
|
||||
- If the user explicitly requested the exact write action, perform the smallest correct change and then validate the result.
|
||||
- If a requested action is ambiguous between read-only inspection and state change, inspect first and ask a short confirmation question before the state-changing step.
|
||||
</confirmation_policy>
|
||||
|
||||
<moviepilot_domain_model>
|
||||
- Treat sites as a first-class system capability, not background detail. In MoviePilot, sites are the upstream source for search, account status, authentication, and many download or subscription decisions.
|
||||
- Understand the platform's core workflow as: site availability and configuration -> media search -> media recognition/metadata confirmation -> manual download or subscription -> transfer and library organization -> status/history confirmation.
|
||||
- Treat manual download and subscription automation as two execution modes of the same core pipeline. One is user-triggered immediate acquisition; the other is persistent site-driven monitoring and acquisition.
|
||||
- Stay within the MoviePilot product domain unless the user explicitly asks for adjacent help that can be handled with your existing tools.
|
||||
- Treat manual download and subscription automation as two execution modes of the same acquisition pipeline. Manual download is user-triggered immediate acquisition; subscription is persistent site-driven monitoring and acquisition.
|
||||
- Keep the user anchored to the operational step that matters now: site, search, recognition, download, subscription, transfer, or status/history.
|
||||
- Users may attach images from supported channels; analyze them together with the text when relevant.
|
||||
- User messages may arrive as structured JSON. Treat the `message` field as the user's text. Input metadata appears in `input`; when `input.mode` is `voice`, the user sent a voice message and `message` contains its transcript. Attachments appear in `files`; when `local_path` is present, use local file tools to inspect the uploaded file directly. When image input is disabled for the current model, user images may also be delivered through `files`.
|
||||
</moviepilot_domain_model>
|
||||
|
||||
Behavior Model:
|
||||
<operating_principles>
|
||||
- Prioritize task progress over conversation.
|
||||
- Check current state before making changes, then do the smallest correct action.
|
||||
- When a task depends on tracker or indexer availability, inspect site state first or as early as possible.
|
||||
- Do not stop for approval on read-only operations. Only confirm before destructive or high-impact actions such as starting downloads, deleting subscriptions, or removing history.
|
||||
- When a request can be completed by tools, prefer doing the work over explaining what you might do.
|
||||
- After an action, perform the minimum validation needed to confirm the result actually landed.
|
||||
- Keep the user anchored to the operational step that matters now: site, search, recognition, download, subscription, or transfer.
|
||||
- If the user explicitly asks to change the speaking style or persona, use the dedicated persona tools instead of editing runtime files manually.
|
||||
- If the user explicitly asks to rewrite or create a persona definition, prefer `update_persona_definition` rather than generic file-editing tools.
|
||||
- Do not let user memory or persona style override this core identity, safety boundaries, or built-in background task rules.
|
||||
- You are not a general-purpose coding assistant in normal media conversations. Only cross into implementation details when the user explicitly asks about MoviePilot internals or debugging.
|
||||
- Reuse known media identity, prior tool results, and current system context instead of repeating expensive recognition or search calls.
|
||||
- When a tool fails, try one narrower fallback path before escalating to the user.
|
||||
</operating_principles>
|
||||
|
||||
Core Capabilities:
|
||||
1. Site Operations - Query configured sites, understand site priority and availability, inspect account data, test connectivity, and update site authentication when the user explicitly requests site maintenance.
|
||||
2. Media Search and Recognition - Identify movies, TV shows, and anime; search media databases; recognize media from fuzzy filenames, torrent titles, or incomplete names.
|
||||
3. Torrent Search and Selection - Search torrents across configured sites and filter by quality, resolution, codec, effect, release group, and other result traits.
|
||||
4. Download Control - Add, inspect, modify, or remove download tasks and connect site results to downloader execution.
|
||||
5. Subscription Management - Create and manage subscriptions that continuously search configured sites and automatically download matching releases.
|
||||
6. Transfer and Library Organization - Transfer files into the library, trigger recognition-aware organization, and confirm post-download file landing or cleanup state.
|
||||
7. System Status and History - Monitor downloader state, site state, transfer history, subscription history, and related system health signals.
|
||||
8. Visual Input Handling - Users may attach images from supported channels; analyze them together with the text when relevant.
|
||||
9. File Context Handling - User messages may arrive as structured JSON. Treat the `message` field as the user's text. Attachments appear in `files`; when `local_path` is present, use local file tools to inspect the uploaded file directly. When image input is disabled for the current model, user images may also be delivered through `files`.
|
||||
10. Persona Management - If the user explicitly asks to change the speaking style or persona, prefer `query_personas` and `switch_persona`; if the user asks to rewrite or create a persona definition, prefer `update_persona_definition` instead of editing runtime files manually.
|
||||
|
||||
Core Workflow:
|
||||
<core_workflow>
|
||||
1. Site and Context Check: Determine whether site status, site scope, library state, existing subscriptions, or prior download/transfer history can affect the task.
|
||||
2. Media Identity Resolution: Confirm exact media identity such as TMDB ID, title, year, type, season, or episode using `search_media`, `query_media_detail`, or `recognize_media` as needed.
|
||||
3. Resource Discovery: Use the appropriate search path for the task. For manual acquisition, search site resources and inspect result quality. For automation, prepare subscription conditions that will search sites continuously.
|
||||
4. Action Execution: Perform the requested task, typically one of: test/query site, search torrents, add download, add or modify subscription, or transfer and organize files.
|
||||
5. Final Confirmation: State the outcome briefly, including the key media facts, chosen site or resource scope when relevant, and the next blocker if the task could not be completed.
|
||||
</core_workflow>
|
||||
|
||||
Tool Calling Strategy:
|
||||
- Call independent tools in parallel whenever possible.
|
||||
<tool_strategy>
|
||||
- Use parallel tool calls by default for independent read-only or diagnostic work. In one assistant turn, issue all tool calls that can run without waiting for each other's results, such as checking enabled sites, library existence, recent history, downloader status, and scheduler or configuration state.
|
||||
- Keep tools sequential only when later arguments depend on earlier output, when a tool mutates state, when confirmation is required, or when concurrent writes could conflict.
|
||||
- When planning a multi-step investigation, group the first wave of safe state-gathering calls together, then continue with dependent actions after those results return.
|
||||
- Prefer site-aware tool paths when the task is about torrents, subscriptions, or download failures. `query_sites`, `test_site`, and `query_site_userdata` are part of the main operating flow, not edge-case tools.
|
||||
- If search results are ambiguous, use `query_media_detail` or `recognize_media` to clarify before proceeding.
|
||||
- For fuzzy torrent names, filenames, or manually provided paths, prefer `recognize_media` before asking the user for a cleaner title.
|
||||
- If `search_media` fails, fall back to `search_web` or `recognize_media`. Only ask the user when automated paths are exhausted.
|
||||
- If torrent search yields no useful result, check site scope, site health, and recognition quality before concluding that the resource is unavailable.
|
||||
- Reuse the latest torrent search cache for `get_search_results` and `add_download` instead of re-running the same search unnecessarily.
|
||||
- Reuse known media identity, prior tool results, and current system context instead of repeating expensive recognition or search calls.
|
||||
- When a tool fails, try one narrower fallback path before escalating to the user.
|
||||
- Use `execute_command` only for diagnostics, read-only inspection, or commands the user explicitly asked to run. Its default `action=start` starts a managed background session and returns `session_id`, `status`, `last_seq`, and `output_until_seq`; call the same tool again with `action=read`, `action=wait`, `action=write`, or `action=kill` to poll output, wait in short segments, send stdin, or stop the process.
|
||||
</tool_strategy>
|
||||
|
||||
Media Management Rules:
|
||||
<media_rules>
|
||||
1. Site Awareness: When search, download, or subscription behavior depends on sites, prefer checking enabled sites, selected site IDs, priority, or site health before changing user expectations.
|
||||
2. Download Safety: Present found torrents with size, seeds, and quality, then get explicit consent before downloading.
|
||||
3. Search vs Recognition: `search_media` is for database lookup, `recognize_media` is for parsing titles or paths, and `search_torrents` is for site resource lookup. Do not confuse these roles.
|
||||
@@ -66,6 +77,7 @@ Media Management Rules:
|
||||
6. Transfer Awareness: If the user asks about downloaded files landing in the library, include transfer or organization state in the reasoning, not just download completion.
|
||||
7. Error Handling: If a tool or site fails, briefly explain what went wrong and suggest an alternative or the next best operational step.
|
||||
8. TV Subscription Rule: When calling `add_subscribe` for a TV show, omitting `season` means subscribe to season 1 only. To subscribe multiple seasons or the full series, call `add_subscribe` separately for each season.
|
||||
</media_rules>
|
||||
</agent_core>
|
||||
|
||||
<communication_runtime>
|
||||
|
||||
@@ -14,7 +14,11 @@ task_types:
|
||||
- "For 'recurring' jobs, check 'last_run' to determine if it's time to run again."
|
||||
- "For 'once' jobs with status 'pending', execute them now."
|
||||
- "After executing each job, update its status, 'last_run' time, and execution log in the JOB.md file."
|
||||
- "If any job was executed, use the `send_message` tool to send a concise execution report to the user through configured notification channels."
|
||||
empty_result: "If no jobs were executed, output nothing."
|
||||
task_rules:
|
||||
- "After sending the execution report with `send_message`, do not repeat the report in your final response."
|
||||
- "Your final response for heartbeat must be empty; reporting is handled only through the `send_message` tool."
|
||||
health_check:
|
||||
header: "[System Health Check]"
|
||||
objective: "Verify that the agent execution pipeline is alive."
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""提示词管理器"""
|
||||
|
||||
import shutil
|
||||
import socket
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
@@ -9,6 +10,7 @@ from typing import Any, Dict, Optional
|
||||
|
||||
import yaml
|
||||
|
||||
from app.agent.llm.capability import AgentCapabilityManager
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
from app.schemas import (
|
||||
@@ -21,6 +23,31 @@ from app.utils.system import SystemUtils
|
||||
|
||||
SYSTEM_TASKS_FILE = "System Tasks.yaml"
|
||||
SYSTEM_TASKS_SCHEMA_VERSION = 2
|
||||
COMMON_SHELL_COMMANDS = (
|
||||
# 只探测会明显改变 Agent 执行策略的可选能力。基础命令、语言运行时、
|
||||
# 包管理器、服务管理器和数据库客户端默认不做启动探测,减少 which 扫描量。
|
||||
"ssh",
|
||||
"scp",
|
||||
"sftp",
|
||||
"git",
|
||||
"gh",
|
||||
"rg",
|
||||
"fd",
|
||||
"jq",
|
||||
"yq",
|
||||
"curl",
|
||||
"wget",
|
||||
"docker",
|
||||
"docker-compose",
|
||||
"python",
|
||||
"python3",
|
||||
"ffmpeg",
|
||||
"ffprobe",
|
||||
"mediainfo",
|
||||
"rclone",
|
||||
"aria2c",
|
||||
"yt-dlp",
|
||||
)
|
||||
|
||||
|
||||
class PromptConfigError(ValueError):
|
||||
@@ -64,6 +91,7 @@ class PromptManager:
|
||||
self.prompts_cache: Dict[str, str] = {}
|
||||
self._system_tasks_cache: Optional[SystemTasksDefinition] = None
|
||||
self._system_tasks_signature: Optional[tuple[int, int]] = None
|
||||
self._available_shell_commands_cache: Optional[list[tuple[str, str]]] = None
|
||||
|
||||
def load_prompt(self, prompt_name: str) -> str:
|
||||
"""
|
||||
@@ -251,8 +279,7 @@ class PromptManager:
|
||||
sections.append(self._format_numbered_rules("IMPORTANT", rules))
|
||||
return "\n\n".join(section for section in sections if section).strip()
|
||||
|
||||
@staticmethod
|
||||
def _get_moviepilot_info() -> str:
|
||||
def _get_moviepilot_info(self) -> str:
|
||||
"""
|
||||
获取MoviePilot系统信息,用于注入到系统提示词中
|
||||
"""
|
||||
@@ -302,10 +329,47 @@ class PromptManager:
|
||||
f"- 配置文件目录: {config_path}",
|
||||
f"- 日志文件目录: {log_path}",
|
||||
f"- 系统安装目录: {settings.ROOT_PATH}",
|
||||
f"- 插件安装目录: {settings.ROOT_PATH / 'app' / 'plugins'}",
|
||||
]
|
||||
|
||||
available_commands = self._get_available_shell_commands()
|
||||
if available_commands:
|
||||
info_lines.append("- 可用系统命令(可通过 `execute_command` 调用):")
|
||||
info_lines.extend(
|
||||
f" - {command}: {path}" for command, path in available_commands
|
||||
)
|
||||
# `rg` 同时覆盖文件枚举和文本检索,且比通用 shell 查找更适合
|
||||
# Agent 的代码阅读与定位场景;只有在它不可用或不适合时才退回其他工具。
|
||||
if any(command == "rg" for command, _ in available_commands):
|
||||
info_lines.append(
|
||||
"- When searching files or text, prefer `rg` / `rg --files`. Only fall back to other search tools when `rg` is unavailable or unsuitable."
|
||||
)
|
||||
|
||||
return "\n".join(info_lines)
|
||||
|
||||
def _get_available_shell_commands(self) -> list[tuple[str, str]]:
|
||||
"""
|
||||
探测 PATH 中已经安装的常用命令。
|
||||
|
||||
这里只使用 shutil.which 做无副作用查找,不实际执行命令;执行权限、
|
||||
高风险操作确认和输出限制仍由 execute_command 工具负责。探测结果
|
||||
在进程内缓存,避免每次组装提示词都重复扫描 PATH。
|
||||
"""
|
||||
if self._available_shell_commands_cache is not None:
|
||||
return self._available_shell_commands_cache
|
||||
|
||||
available_commands: list[tuple[str, str]] = []
|
||||
for command in COMMON_SHELL_COMMANDS:
|
||||
command_path = shutil.which(command)
|
||||
if command_path:
|
||||
available_commands.append((command, command_path))
|
||||
self._available_shell_commands_cache = available_commands
|
||||
return available_commands
|
||||
|
||||
def clear_available_shell_commands_cache(self) -> None:
|
||||
"""清理可用系统命令缓存,供测试或运行时手动刷新使用。"""
|
||||
self._available_shell_commands_cache = None
|
||||
|
||||
@staticmethod
|
||||
def _generate_formatting_instructions(caps: ChannelCapabilities) -> str:
|
||||
"""
|
||||
@@ -327,10 +391,17 @@ class PromptManager:
|
||||
|
||||
@staticmethod
|
||||
def _generate_voice_reply_instructions() -> str:
|
||||
if not AgentCapabilityManager.supports_audio_output():
|
||||
return "Audio output is disabled; do not call `send_voice_message`."
|
||||
return (
|
||||
"- Voice replies: Use normal text replies by default. "
|
||||
"Only call `send_voice_message` when the user explicitly asks for a voice reply "
|
||||
"or spoken playback is clearly better than plain text."
|
||||
"Use normal text replies by default. Only call `send_voice_message` "
|
||||
"when the user explicitly asks for a voice reply or spoken playback "
|
||||
"is clearly better than plain text. `send_voice_message` is a terminal "
|
||||
"response tool: put the complete user-facing reply in its `message` "
|
||||
"argument, then stop the turn. Do not also call `send_message`, do not "
|
||||
"write a final text reply after it, and do not repeat the same content "
|
||||
"as plain text. If native voice is unavailable, the tool sends the same "
|
||||
"content as a text fallback and still completes the reply."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -344,9 +415,11 @@ class PromptManager:
|
||||
):
|
||||
return (
|
||||
"- User questions: If you need the user to choose from a few clear options, "
|
||||
"call `ask_user_choice` to send button options. After the user clicks a button, "
|
||||
"the selected value will come back as the user's next message. After calling this tool, "
|
||||
"wait for the user's selection instead of repeating the question in plain text."
|
||||
"call `ask_user_choice` to send button options. `ask_user_choice` is a terminal "
|
||||
"interaction tool: put the full question and all options in the tool call, then "
|
||||
"stop the turn and wait for the user's selection. The selected value will come back "
|
||||
"as the user's next message. Do not also call `send_message`, do not write a final "
|
||||
"text reply after it, and do not repeat the question in plain text."
|
||||
)
|
||||
return "- User questions: When you truly need user input, ask briefly in plain text."
|
||||
|
||||
|
||||
@@ -10,13 +10,14 @@ from langchain_core.tools import BaseTool
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from app.agent import StreamingHandler
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import settings
|
||||
from app.db.user_oper import UserOper
|
||||
from app.helper.service import ServiceConfigHelper
|
||||
from app.log import logger
|
||||
from app.schemas import Notification
|
||||
from app.schemas.types import MessageChannel
|
||||
from app.schemas.types import MessageChannel, NotificationType
|
||||
|
||||
|
||||
class ToolChain(ChainBase):
|
||||
@@ -131,7 +132,31 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
super().__init__(**kwargs)
|
||||
self._session_id = session_id
|
||||
self._user_id = user_id
|
||||
self._require_admin = getattr(self.__class__, "require_admin", False)
|
||||
# require_admin 在各工具子类以 pydantic 字段声明,pydantic v2 不在类对象上暴露字段值
|
||||
# (getattr(cls, ...) 取不到),必须经实例读取——super().__init__() 已按字段默认填充实例;
|
||||
# getattr 兜底兼容未声明该字段的工具,缺省按非管理员(False)处理。
|
||||
self._require_admin = getattr(self, "require_admin", False)
|
||||
self.tags = self._build_tool_tags()
|
||||
|
||||
@staticmethod
|
||||
def _normalize_tag_values(tags: Optional[Any]) -> set[str]:
|
||||
"""规范化 LangChain 工具标签。"""
|
||||
if not tags:
|
||||
return set()
|
||||
if isinstance(tags, (str, ToolTag)):
|
||||
tags = [tags]
|
||||
normalized_tags = set()
|
||||
for tag in tags:
|
||||
if isinstance(tag, ToolTag):
|
||||
normalized_tags.add(tag.value)
|
||||
elif tag:
|
||||
normalized_tags.add(str(tag))
|
||||
return normalized_tags
|
||||
|
||||
def _build_tool_tags(self) -> list[str]:
|
||||
"""规范化工具实现中显式声明的标签。"""
|
||||
explicit_tags = self._normalize_tag_values(getattr(self, "tags", None))
|
||||
return sorted(explicit_tags | {ToolTag.AgentTool.value})
|
||||
|
||||
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||
raise NotImplementedError("MoviePilotTool 只支持异步调用,请使用 _arun")
|
||||
@@ -157,8 +182,8 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
if explanation:
|
||||
tool_message = explanation
|
||||
|
||||
# 发送工具执行过程消息
|
||||
if self._stream_handler and self._stream_handler.is_streaming:
|
||||
# 发送工具执行过程消息(流式传输且非最后终结工具时)
|
||||
if self._stream_handler and self._stream_handler.is_streaming and not self.return_direct:
|
||||
if settings.AI_AGENT_VERBOSE:
|
||||
if self._stream_handler.is_auto_flushing:
|
||||
# 渠道支持编辑:工具消息追加到 buffer,由定时刷新推送
|
||||
@@ -212,8 +237,15 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
# 执行具体工具逻辑
|
||||
try:
|
||||
result = await self.run(**kwargs)
|
||||
result_len = len(str(result)) if result is not None else 0
|
||||
logger.debug(f"Tool {self.name} executed, raw result length: {result_len}")
|
||||
|
||||
# 记录工具执行结果摘要日志
|
||||
str_result = serialize_tool_result_for_agent(result)
|
||||
if len(str_result) > 500:
|
||||
summary = str_result[:500] + f"...(已截断,总长度: {len(str_result)})"
|
||||
else:
|
||||
summary = str_result
|
||||
logger.info(f"Agent工具 {self.name} 执行完成,结果摘要: {summary}")
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"工具执行异常 ({type(e).__name__}): {str(e)}"
|
||||
logger.error(f"Tool {self.name} execution failed: {e}", exc_info=True)
|
||||
@@ -236,7 +268,8 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
Returns:
|
||||
str: 友好的提示消息,如果返回 None 或空字符串则使用 explanation
|
||||
"""
|
||||
return None
|
||||
explanation = kwargs.get("explanation")
|
||||
return str(explanation) if explanation else None
|
||||
|
||||
@abstractmethod
|
||||
async def run(self, **kwargs) -> str:
|
||||
@@ -278,7 +311,9 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
"""
|
||||
设置与当前 Agent 共享的上下文。
|
||||
"""
|
||||
self._agent_context = agent_context or {}
|
||||
# 空 dict 也是合法共享上下文;不能用 ``or {}``,否则每个工具会拿到
|
||||
# 独立的新 dict,跨工具状态(例如质量门槛拒绝标记)无法传播。
|
||||
self._agent_context = {} if agent_context is None else agent_context
|
||||
|
||||
async def _check_permission(self) -> Optional[str]:
|
||||
"""
|
||||
@@ -397,7 +432,7 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
|
||||
async def send_tool_message(
|
||||
self, message: str, title: str = "", image: Optional[str] = None
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
发送工具消息
|
||||
"""
|
||||
@@ -405,6 +440,7 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
Notification(
|
||||
channel=self._channel,
|
||||
source=self._source,
|
||||
mtype=NotificationType.Agent,
|
||||
userid=self._user_id,
|
||||
username=self._username,
|
||||
title=title,
|
||||
|
||||
@@ -77,6 +77,7 @@ from app.agent.tools.impl.query_custom_identifiers import QueryCustomIdentifiers
|
||||
from app.agent.tools.impl.update_custom_identifiers import UpdateCustomIdentifiersTool
|
||||
from app.agent.tools.impl.query_system_settings import QuerySystemSettingsTool
|
||||
from app.agent.tools.impl.update_system_settings import UpdateSystemSettingsTool
|
||||
from app.agent.llm.capability import AgentCapabilityManager
|
||||
from app.core.plugin import PluginManager
|
||||
from app.log import logger
|
||||
from app.schemas.message import ChannelCapabilityManager
|
||||
@@ -90,7 +91,7 @@ class MoviePilotToolFactory:
|
||||
"""
|
||||
|
||||
# 这些通用工具需要始终保留,避免大工具集裁剪后让 Agent 丢失基础的
|
||||
# 文件系统、命令执行或交互确认能力。AskUserChoiceTool 仅在支持按钮
|
||||
# 文件系统、命令执行、主动消息发送或交互确认能力。AskUserChoiceTool 仅在支持按钮
|
||||
# 的渠道中才会实际注入,因此后续会再按已加载工具做一次求交集。
|
||||
TOOL_SELECTOR_ALWAYS_INCLUDE_NAMES = (
|
||||
"list_directory",
|
||||
@@ -98,6 +99,7 @@ class MoviePilotToolFactory:
|
||||
"read_file",
|
||||
"edit_file",
|
||||
"execute_command",
|
||||
"send_message",
|
||||
"ask_user_choice",
|
||||
)
|
||||
|
||||
@@ -225,12 +227,9 @@ class MoviePilotToolFactory:
|
||||
]
|
||||
if MoviePilotToolFactory._should_enable_choice_tool(channel):
|
||||
tool_definitions.append(AskUserChoiceTool)
|
||||
tool_definitions.extend(
|
||||
[
|
||||
SendLocalFileTool,
|
||||
SendVoiceMessageTool,
|
||||
]
|
||||
)
|
||||
tool_definitions.append(SendLocalFileTool)
|
||||
if AgentCapabilityManager.supports_audio_output():
|
||||
tool_definitions.append(SendVoiceMessageTool)
|
||||
# 创建内置工具
|
||||
for ToolClass in tool_definitions:
|
||||
tool = ToolClass(session_id=session_id, user_id=user_id)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""插件 Agent 工具共享辅助方法"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import shutil
|
||||
from typing import Any, Optional
|
||||
@@ -7,6 +8,7 @@ from typing import Any, Optional
|
||||
from app.core.config import settings
|
||||
from app.core.plugin import PluginManager
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.helper.server import MoviePilotServerHelper
|
||||
from app.helper.plugin import PluginHelper
|
||||
from app.schemas.types import SystemConfigKey
|
||||
|
||||
@@ -93,6 +95,9 @@ def summarize_plugin(plugin: Any) -> dict[str, Any]:
|
||||
"plugin_author": getattr(plugin, "plugin_author", None),
|
||||
"installed": bool(getattr(plugin, "installed", False)),
|
||||
"has_update": bool(getattr(plugin, "has_update", False)),
|
||||
"system_version_compatible": getattr(plugin, "system_version_compatible", True) is not False,
|
||||
"system_version": getattr(plugin, "system_version", None),
|
||||
"system_version_message": getattr(plugin, "system_version_message", None),
|
||||
"state": bool(getattr(plugin, "state", False)),
|
||||
"repo_url": repo_url,
|
||||
"source": "local_repo" if PluginHelper.is_local_repo_url(repo_url) else "market",
|
||||
@@ -226,7 +231,7 @@ async def install_plugin_runtime(
|
||||
refreshed_only = False
|
||||
if not force and plugin_id in plugin_manager.get_plugin_ids():
|
||||
refreshed_only = True
|
||||
await plugin_helper.async_install_reg(pid=plugin_id, repo_url=repo_url)
|
||||
await MoviePilotServerHelper.async_install_plugin_reg(plugin_id=plugin_id, repo_url=repo_url)
|
||||
message = "插件已存在,已刷新加载"
|
||||
else:
|
||||
if not repo_url:
|
||||
@@ -238,6 +243,7 @@ async def install_plugin_runtime(
|
||||
)
|
||||
if not state:
|
||||
return False, message, False
|
||||
await MoviePilotServerHelper.async_install_plugin_reg(plugin_id=plugin_id, repo_url=repo_url)
|
||||
|
||||
if plugin_id not in install_plugins:
|
||||
install_plugins.append(plugin_id)
|
||||
@@ -245,7 +251,7 @@ async def install_plugin_runtime(
|
||||
SystemConfigKey.UserInstalledPlugins, install_plugins
|
||||
)
|
||||
|
||||
reload_plugin_runtime(plugin_id)
|
||||
await asyncio.to_thread(reload_plugin_runtime, plugin_id)
|
||||
return True, message or "插件安装成功", refreshed_only
|
||||
|
||||
|
||||
|
||||
@@ -62,6 +62,10 @@ SYSTEMCONFIG_SETTING_METADATA = {
|
||||
"group": "custom_identifiers",
|
||||
"label": "自定义识别词",
|
||||
},
|
||||
SystemConfigKey.EpisodeFormatRuleTable.value: {
|
||||
"group": "transfer",
|
||||
"label": "集数定位规则词表",
|
||||
},
|
||||
SystemConfigKey.CustomReleaseGroups.value: {
|
||||
"group": "customization",
|
||||
"label": "自定义制作组/字幕组",
|
||||
|
||||
630
app/agent/tools/impl/_terminal_session.py
Normal file
630
app/agent/tools/impl/_terminal_session.py
Normal file
@@ -0,0 +1,630 @@
|
||||
"""Agent 终端会话管理器。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import errno
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
|
||||
if os.name == "posix":
|
||||
import fcntl as _fcntl
|
||||
import pty as _pty
|
||||
else:
|
||||
_fcntl = None
|
||||
_pty = None
|
||||
|
||||
|
||||
TERMINAL_CONCURRENCY_LIMIT = 4
|
||||
TERMINAL_RETENTION_SECONDS = 30 * 60
|
||||
TERMINAL_MAX_RETAINED_BYTES = 1024 * 1024
|
||||
TERMINAL_DEFAULT_READ_BYTES = 10 * 1024
|
||||
TERMINAL_MAX_READ_BYTES = 64 * 1024
|
||||
TERMINAL_READ_CHUNK_SIZE = 4096
|
||||
TERMINAL_PTY_POLL_INTERVAL = 0.05
|
||||
TERMINAL_WAIT_DEFAULT_MS = 1000
|
||||
TERMINAL_WAIT_MAX_MS = 60 * 1000
|
||||
TERMINAL_KILL_GRACE_SECONDS = 3
|
||||
TERMINAL_FORBIDDEN_KEYWORDS = (
|
||||
"rm -rf /",
|
||||
":(){ :|:& };:",
|
||||
"dd if=/dev/zero",
|
||||
"mkfs",
|
||||
"reboot",
|
||||
"shutdown",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _TerminalChunk:
|
||||
"""记录终端输出分片,供增量读取时按 seq 过滤。"""
|
||||
|
||||
seq: int
|
||||
stream: str
|
||||
text: str
|
||||
byte_size: int
|
||||
created_at: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class _TerminalSession:
|
||||
"""保存一个后台命令会话的进程、输出和状态。"""
|
||||
|
||||
session_id: str
|
||||
command: str
|
||||
cwd: str
|
||||
pid: int
|
||||
use_pty: bool
|
||||
created_at: float = field(default_factory=time.time)
|
||||
updated_at: float = field(default_factory=time.time)
|
||||
status: str = "running"
|
||||
exit_code: Optional[int] = None
|
||||
process: Optional[asyncio.subprocess.Process] = None
|
||||
master_fd: Optional[int] = None
|
||||
chunks: list[_TerminalChunk] = field(default_factory=list)
|
||||
next_seq: int = 1
|
||||
retained_from_seq: int = 1
|
||||
retained_bytes: int = 0
|
||||
kill_requested: bool = False
|
||||
error: Optional[str] = None
|
||||
reader_tasks: list[asyncio.Task] = field(default_factory=list)
|
||||
wait_task: Optional[asyncio.Task] = None
|
||||
|
||||
def append_output(self, stream: str, data: bytes) -> None:
|
||||
"""追加输出并按容量上限丢弃最旧分片,避免长任务撑爆内存。"""
|
||||
if not data:
|
||||
return
|
||||
|
||||
text = data.decode("utf-8", errors="replace")
|
||||
chunk = _TerminalChunk(
|
||||
seq=self.next_seq,
|
||||
stream=stream,
|
||||
text=text,
|
||||
byte_size=len(data),
|
||||
created_at=time.time(),
|
||||
)
|
||||
self.next_seq += 1
|
||||
self.chunks.append(chunk)
|
||||
self.retained_bytes += chunk.byte_size
|
||||
self.updated_at = chunk.created_at
|
||||
self._trim_output()
|
||||
|
||||
def _trim_output(self) -> None:
|
||||
"""移除超出保留上限的旧输出分片。"""
|
||||
while self.retained_bytes > TERMINAL_MAX_RETAINED_BYTES and self.chunks:
|
||||
removed = self.chunks.pop(0)
|
||||
self.retained_bytes -= removed.byte_size
|
||||
self.retained_from_seq = removed.seq + 1
|
||||
|
||||
def mark_finished(self, exit_code: Optional[int]) -> None:
|
||||
"""标记进程已经结束,并记录退出码。"""
|
||||
self.exit_code = exit_code
|
||||
self.status = "killed" if self.kill_requested else "exited"
|
||||
self.updated_at = time.time()
|
||||
|
||||
def mark_error(self, message: str) -> None:
|
||||
"""标记会话异常,保留错误信息供后续读取。"""
|
||||
self.error = message
|
||||
self.status = "error"
|
||||
self.updated_at = time.time()
|
||||
|
||||
def close_pty(self) -> None:
|
||||
"""关闭父进程持有的 PTY master fd。"""
|
||||
if self.master_fd is None:
|
||||
return
|
||||
try:
|
||||
os.close(self.master_fd)
|
||||
except OSError:
|
||||
pass
|
||||
self.master_fd = None
|
||||
|
||||
|
||||
class _TerminalSessionManager:
|
||||
"""管理 Agent 后台终端会话的生命周期。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化会话表和并发保护锁。"""
|
||||
self._sessions: dict[str, _TerminalSession] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
@staticmethod
|
||||
def _normalize_bool(value: Any, default: bool = True) -> bool:
|
||||
"""兼容 LLM 或 HTTP 传入的 bool/string/int 布尔值。"""
|
||||
if value is None:
|
||||
return default
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return value.strip().lower() not in {"false", "0", "no", "off"}
|
||||
return bool(value)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_cwd(cwd: Optional[str]) -> str:
|
||||
"""解析工作目录,未传入时默认使用 MoviePilot 项目根目录。"""
|
||||
if not cwd:
|
||||
return str(settings.ROOT_PATH)
|
||||
path = Path(cwd).expanduser()
|
||||
if not path.is_absolute():
|
||||
path = (settings.ROOT_PATH / path).resolve()
|
||||
else:
|
||||
path = path.resolve()
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"工作目录不存在: {path}")
|
||||
if not path.is_dir():
|
||||
raise NotADirectoryError(f"工作目录不是目录: {path}")
|
||||
return str(path)
|
||||
|
||||
@staticmethod
|
||||
def _build_env(env: Optional[dict[str, Any]]) -> dict[str, str]:
|
||||
"""合并环境变量,并把值稳定转换为字符串。"""
|
||||
merged_env = os.environ.copy()
|
||||
if not env:
|
||||
return merged_env
|
||||
for key, value in env.items():
|
||||
if value is None:
|
||||
continue
|
||||
merged_env[str(key)] = str(value)
|
||||
return merged_env
|
||||
|
||||
@staticmethod
|
||||
def _validate_command(command: str) -> None:
|
||||
"""拒绝明显危险或空白命令。"""
|
||||
if not command or not command.strip():
|
||||
raise ValueError("命令不能为空")
|
||||
for keyword in TERMINAL_FORBIDDEN_KEYWORDS:
|
||||
if keyword in command:
|
||||
raise ValueError(f"命令包含禁止使用的关键字 '{keyword}'")
|
||||
|
||||
@staticmethod
|
||||
def _set_nonblocking(fd: int) -> None:
|
||||
"""将 PTY master fd 设置为非阻塞,避免后台读取任务卡住事件循环。"""
|
||||
if _fcntl is None:
|
||||
raise RuntimeError("当前平台不支持 PTY 非阻塞设置")
|
||||
flags = _fcntl.fcntl(fd, _fcntl.F_GETFL)
|
||||
_fcntl.fcntl(fd, _fcntl.F_SETFL, flags | os.O_NONBLOCK)
|
||||
|
||||
@staticmethod
|
||||
def _pipe_subprocess_kwargs() -> dict[str, Any]:
|
||||
"""生成普通管道模式的子进程参数。"""
|
||||
kwargs: dict[str, Any] = {
|
||||
"stdin": asyncio.subprocess.PIPE,
|
||||
"stdout": asyncio.subprocess.PIPE,
|
||||
"stderr": asyncio.subprocess.PIPE,
|
||||
}
|
||||
if os.name == "posix":
|
||||
kwargs["start_new_session"] = True
|
||||
elif os.name == "nt":
|
||||
kwargs["creationflags"] = subprocess.CREATE_NEW_PROCESS_GROUP
|
||||
return kwargs
|
||||
|
||||
async def start(
|
||||
self,
|
||||
*,
|
||||
command: str,
|
||||
cwd: Optional[str] = None,
|
||||
env: Optional[dict[str, Any]] = None,
|
||||
use_pty: Any = True,
|
||||
) -> dict[str, Any]:
|
||||
"""启动后台命令并立即返回会话 ID。"""
|
||||
self._validate_command(command)
|
||||
normalized_cwd = self._normalize_cwd(cwd)
|
||||
normalized_env = self._build_env(env)
|
||||
should_use_pty = self._normalize_bool(use_pty, default=True) and os.name == "posix"
|
||||
|
||||
async with self._lock:
|
||||
self._cleanup_finished_sessions_locked()
|
||||
if self._active_session_count_locked() >= TERMINAL_CONCURRENCY_LIMIT:
|
||||
raise RuntimeError(
|
||||
f"后台终端会话数已达到上限 {TERMINAL_CONCURRENCY_LIMIT}"
|
||||
)
|
||||
|
||||
session = (
|
||||
await self._start_pty_session(command, normalized_cwd, normalized_env)
|
||||
if should_use_pty
|
||||
else await self._start_pipe_session(command, normalized_cwd, normalized_env)
|
||||
)
|
||||
|
||||
async with self._lock:
|
||||
self._sessions[session.session_id] = session
|
||||
|
||||
logger.info(
|
||||
"启动后台终端会话: session_id=%s, pid=%s, use_pty=%s, command=%s",
|
||||
session.session_id,
|
||||
session.pid,
|
||||
session.use_pty,
|
||||
command,
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
return self._session_payload(session, output="", output_truncated=False)
|
||||
|
||||
async def _start_pty_session(
|
||||
self, command: str, cwd: str, env: dict[str, str]
|
||||
) -> _TerminalSession:
|
||||
"""通过 PTY fork 启动交互式命令会话。"""
|
||||
if _pty is None:
|
||||
raise RuntimeError("当前平台不支持 PTY 会话")
|
||||
pid, master_fd = _pty.fork()
|
||||
if pid == 0:
|
||||
os.chdir(cwd)
|
||||
os.environ.clear()
|
||||
os.environ.update(env)
|
||||
shell = os.environ.get("SHELL") or "/bin/sh"
|
||||
os.execl(shell, shell, "-lc", command)
|
||||
|
||||
self._set_nonblocking(master_fd)
|
||||
session = _TerminalSession(
|
||||
session_id=f"term_{uuid.uuid4().hex[:12]}",
|
||||
command=command,
|
||||
cwd=cwd,
|
||||
pid=pid,
|
||||
use_pty=True,
|
||||
master_fd=master_fd,
|
||||
)
|
||||
session.reader_tasks.append(asyncio.create_task(self._read_pty(session)))
|
||||
session.wait_task = asyncio.create_task(self._wait_pty_process(session))
|
||||
return session
|
||||
|
||||
async def _start_pipe_session(
|
||||
self, command: str, cwd: str, env: dict[str, str]
|
||||
) -> _TerminalSession:
|
||||
"""通过普通 stdin/stdout/stderr 管道启动命令会话。"""
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
command,
|
||||
cwd=cwd,
|
||||
env=env,
|
||||
**self._pipe_subprocess_kwargs(),
|
||||
)
|
||||
session = _TerminalSession(
|
||||
session_id=f"term_{uuid.uuid4().hex[:12]}",
|
||||
command=command,
|
||||
cwd=cwd,
|
||||
pid=process.pid or 0,
|
||||
use_pty=False,
|
||||
process=process,
|
||||
)
|
||||
if process.stdout:
|
||||
session.reader_tasks.append(
|
||||
asyncio.create_task(self._read_pipe(session, process.stdout, "stdout"))
|
||||
)
|
||||
if process.stderr:
|
||||
session.reader_tasks.append(
|
||||
asyncio.create_task(self._read_pipe(session, process.stderr, "stderr"))
|
||||
)
|
||||
session.wait_task = asyncio.create_task(self._wait_pipe_process(session))
|
||||
return session
|
||||
|
||||
@staticmethod
|
||||
async def _read_pty(session: _TerminalSession) -> None:
|
||||
"""持续从 PTY 读取增量输出。"""
|
||||
while session.master_fd is not None:
|
||||
try:
|
||||
data = os.read(session.master_fd, TERMINAL_READ_CHUNK_SIZE)
|
||||
except BlockingIOError:
|
||||
await asyncio.sleep(TERMINAL_PTY_POLL_INTERVAL)
|
||||
continue
|
||||
except OSError as err:
|
||||
if err.errno not in {errno.EIO, errno.EBADF}:
|
||||
logger.debug("PTY 输出读取异常: session_id=%s, error=%s", session.session_id, err)
|
||||
break
|
||||
|
||||
if not data:
|
||||
break
|
||||
session.append_output("pty", data)
|
||||
|
||||
@staticmethod
|
||||
async def _read_pipe(
|
||||
session: _TerminalSession,
|
||||
stream: asyncio.StreamReader,
|
||||
stream_name: str,
|
||||
) -> None:
|
||||
"""持续从普通管道读取增量输出。"""
|
||||
while True:
|
||||
data = await stream.read(TERMINAL_READ_CHUNK_SIZE)
|
||||
if not data:
|
||||
break
|
||||
session.append_output(stream_name, data)
|
||||
|
||||
async def _wait_pty_process(self, session: _TerminalSession) -> None:
|
||||
"""等待 PTY 子进程结束并完成输出读取任务收尾。"""
|
||||
try:
|
||||
_, status = await asyncio.to_thread(os.waitpid, session.pid, 0)
|
||||
exit_code = os.waitstatus_to_exitcode(status)
|
||||
session.mark_finished(exit_code)
|
||||
except ChildProcessError:
|
||||
session.mark_finished(session.exit_code)
|
||||
except Exception as err:
|
||||
session.mark_error(str(err))
|
||||
logger.warning("等待 PTY 进程失败: session_id=%s, error=%s", session.session_id, err)
|
||||
finally:
|
||||
await self._finish_reader_tasks(session)
|
||||
session.close_pty()
|
||||
|
||||
async def _wait_pipe_process(self, session: _TerminalSession) -> None:
|
||||
"""等待普通管道子进程结束并完成输出读取任务收尾。"""
|
||||
try:
|
||||
if not session.process:
|
||||
session.mark_error("进程对象不存在")
|
||||
return
|
||||
exit_code = await session.process.wait()
|
||||
session.mark_finished(exit_code)
|
||||
except Exception as err:
|
||||
session.mark_error(str(err))
|
||||
logger.warning("等待管道进程失败: session_id=%s, error=%s", session.session_id, err)
|
||||
finally:
|
||||
await self._finish_reader_tasks(session)
|
||||
|
||||
@staticmethod
|
||||
async def _finish_reader_tasks(session: _TerminalSession) -> None:
|
||||
"""等待输出读取任务退出,超时后取消残留任务。"""
|
||||
if not session.reader_tasks:
|
||||
return
|
||||
done, pending = await asyncio.wait(session.reader_tasks, timeout=1)
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
await asyncio.gather(*done, *pending, return_exceptions=True)
|
||||
|
||||
async def read(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
since_seq: Optional[int] = None,
|
||||
max_bytes: Optional[int] = TERMINAL_DEFAULT_READ_BYTES,
|
||||
) -> dict[str, Any]:
|
||||
"""读取会话当前保留的增量输出。"""
|
||||
session = self.get_session(session_id)
|
||||
output, output_truncated, output_until_seq = self._collect_output(
|
||||
session,
|
||||
since_seq=since_seq,
|
||||
max_bytes=max_bytes,
|
||||
)
|
||||
return self._session_payload(
|
||||
session,
|
||||
output=output,
|
||||
output_truncated=output_truncated,
|
||||
output_until_seq=output_until_seq,
|
||||
)
|
||||
|
||||
async def wait(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
timeout_ms: Optional[int] = TERMINAL_WAIT_DEFAULT_MS,
|
||||
since_seq: Optional[int] = None,
|
||||
max_bytes: Optional[int] = TERMINAL_DEFAULT_READ_BYTES,
|
||||
) -> dict[str, Any]:
|
||||
"""短暂等待会话结束,并返回等待期间可见的增量输出。"""
|
||||
session = self.get_session(session_id)
|
||||
normalized_timeout = self._normalize_wait_timeout(timeout_ms)
|
||||
if session.wait_task and not session.wait_task.done():
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
asyncio.shield(session.wait_task),
|
||||
timeout=normalized_timeout / 1000,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
output, output_truncated, output_until_seq = self._collect_output(
|
||||
session,
|
||||
since_seq=since_seq,
|
||||
max_bytes=max_bytes,
|
||||
)
|
||||
payload = self._session_payload(
|
||||
session,
|
||||
output=output,
|
||||
output_truncated=output_truncated,
|
||||
output_until_seq=output_until_seq,
|
||||
)
|
||||
payload["wait_timeout_ms"] = normalized_timeout
|
||||
return payload
|
||||
|
||||
async def write(self, *, session_id: str, input_text: str) -> dict[str, Any]:
|
||||
"""向会话 stdin 写入文本,PTY 模式下写入 master fd。"""
|
||||
session = self.get_session(session_id)
|
||||
if session.status != "running":
|
||||
raise RuntimeError(f"会话已结束,当前状态: {session.status}")
|
||||
|
||||
data = (input_text or "").encode("utf-8")
|
||||
if session.use_pty:
|
||||
if session.master_fd is None:
|
||||
raise RuntimeError("PTY 已关闭")
|
||||
await asyncio.to_thread(os.write, session.master_fd, data)
|
||||
else:
|
||||
if not session.process or not session.process.stdin:
|
||||
raise RuntimeError("进程 stdin 不可写")
|
||||
session.process.stdin.write(data)
|
||||
await session.process.stdin.drain()
|
||||
|
||||
session.updated_at = time.time()
|
||||
payload = self._session_payload(session, output="", output_truncated=False)
|
||||
payload["written_bytes"] = len(data)
|
||||
return payload
|
||||
|
||||
async def kill(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
sig: Optional[str | int] = "TERM",
|
||||
) -> dict[str, Any]:
|
||||
"""向会话进程组发送信号并等待短暂清理。"""
|
||||
session = self.get_session(session_id)
|
||||
if session.status != "running":
|
||||
return self._session_payload(session, output="", output_truncated=False)
|
||||
|
||||
session.kill_requested = True
|
||||
signal_number = self._resolve_signal(sig)
|
||||
self._send_signal(session, signal_number)
|
||||
|
||||
if session.wait_task and not session.wait_task.done():
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
asyncio.shield(session.wait_task),
|
||||
timeout=TERMINAL_KILL_GRACE_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
force_signal = getattr(signal, "SIGKILL", signal.SIGTERM)
|
||||
self._send_signal(session, force_signal)
|
||||
|
||||
return self._session_payload(session, output="", output_truncated=False)
|
||||
|
||||
def get_session(self, session_id: str) -> _TerminalSession:
|
||||
"""按 ID 获取会话,不存在时抛出清晰错误。"""
|
||||
session = self._sessions.get(session_id)
|
||||
if not session:
|
||||
raise KeyError(f"终端会话不存在: {session_id}")
|
||||
return session
|
||||
|
||||
@staticmethod
|
||||
def _normalize_wait_timeout(timeout_ms: Optional[int]) -> int:
|
||||
"""限制 wait 单次等待时间,避免工具调用长时间占用模型回合。"""
|
||||
try:
|
||||
normalized = int(timeout_ms or TERMINAL_WAIT_DEFAULT_MS)
|
||||
except (TypeError, ValueError):
|
||||
normalized = TERMINAL_WAIT_DEFAULT_MS
|
||||
if normalized < 0:
|
||||
return 0
|
||||
return min(normalized, TERMINAL_WAIT_MAX_MS)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_read_limit(max_bytes: Optional[int]) -> int:
|
||||
"""限制单次读取返回的输出大小。"""
|
||||
try:
|
||||
normalized = int(max_bytes or TERMINAL_DEFAULT_READ_BYTES)
|
||||
except (TypeError, ValueError):
|
||||
normalized = TERMINAL_DEFAULT_READ_BYTES
|
||||
if normalized <= 0:
|
||||
return TERMINAL_DEFAULT_READ_BYTES
|
||||
return min(normalized, TERMINAL_MAX_READ_BYTES)
|
||||
|
||||
def _collect_output(
|
||||
self,
|
||||
session: _TerminalSession,
|
||||
*,
|
||||
since_seq: Optional[int],
|
||||
max_bytes: Optional[int],
|
||||
) -> tuple[str, bool, int]:
|
||||
"""按 seq 和大小限制收集输出文本。"""
|
||||
read_limit = self._normalize_read_limit(max_bytes)
|
||||
selected_chunks = [
|
||||
chunk
|
||||
for chunk in session.chunks
|
||||
if since_seq is None or chunk.seq > since_seq
|
||||
]
|
||||
output_parts: list[str] = []
|
||||
output_bytes = 0
|
||||
output_truncated = False
|
||||
last_stream: Optional[str] = None
|
||||
output_until_seq = since_seq or session.retained_from_seq - 1
|
||||
|
||||
for chunk in selected_chunks:
|
||||
prefix = self._stream_prefix(chunk.stream, last_stream, session.use_pty)
|
||||
text = f"{prefix}{chunk.text}" if prefix else chunk.text
|
||||
encoded = text.encode("utf-8")
|
||||
remaining = read_limit - output_bytes
|
||||
if len(encoded) > remaining:
|
||||
if remaining > 0:
|
||||
output_parts.append(
|
||||
encoded[:remaining].decode("utf-8", errors="ignore")
|
||||
)
|
||||
output_truncated = True
|
||||
break
|
||||
output_parts.append(text)
|
||||
output_bytes += len(encoded)
|
||||
last_stream = chunk.stream
|
||||
output_until_seq = chunk.seq
|
||||
|
||||
if since_seq is not None and since_seq < session.retained_from_seq - 1:
|
||||
output_truncated = True
|
||||
if not output_truncated:
|
||||
output_until_seq = session.next_seq - 1
|
||||
return "".join(output_parts), output_truncated, output_until_seq
|
||||
|
||||
@staticmethod
|
||||
def _stream_prefix(stream: str, last_stream: Optional[str], use_pty: bool) -> str:
|
||||
"""为普通管道输出增加 stdout/stderr 分段标识。"""
|
||||
if use_pty or stream == last_stream:
|
||||
return ""
|
||||
title = "标准输出" if stream == "stdout" else "错误输出"
|
||||
return f"\n[{title}]\n"
|
||||
|
||||
@staticmethod
|
||||
def _resolve_signal(sig: Optional[str | int]) -> int:
|
||||
"""解析字符串或数字形式的信号名。"""
|
||||
if isinstance(sig, int):
|
||||
return sig
|
||||
signal_name = str(sig or "TERM").strip().upper()
|
||||
if signal_name.isdigit():
|
||||
return int(signal_name)
|
||||
if not signal_name.startswith("SIG"):
|
||||
signal_name = f"SIG{signal_name}"
|
||||
return int(getattr(signal, signal_name, signal.SIGTERM))
|
||||
|
||||
@staticmethod
|
||||
def _send_signal(session: _TerminalSession, sig: int) -> None:
|
||||
"""优先向进程组发信号,失败时回退到单进程。"""
|
||||
try:
|
||||
if os.name == "posix":
|
||||
os.killpg(session.pid, sig)
|
||||
elif session.process:
|
||||
if sig == getattr(signal, "SIGKILL", None):
|
||||
session.process.kill()
|
||||
else:
|
||||
session.process.terminate()
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
|
||||
def _active_session_count_locked(self) -> int:
|
||||
"""统计仍在运行的会话数量。"""
|
||||
return sum(1 for session in self._sessions.values() if session.status == "running")
|
||||
|
||||
def _cleanup_finished_sessions_locked(self) -> None:
|
||||
"""清理已经结束且超过保留时间的会话。"""
|
||||
now = time.time()
|
||||
expired_ids = [
|
||||
session_id
|
||||
for session_id, session in self._sessions.items()
|
||||
if session.status != "running"
|
||||
and now - session.updated_at > TERMINAL_RETENTION_SECONDS
|
||||
]
|
||||
for session_id in expired_ids:
|
||||
session = self._sessions.pop(session_id)
|
||||
session.close_pty()
|
||||
|
||||
@staticmethod
|
||||
def _session_payload(
|
||||
session: _TerminalSession,
|
||||
*,
|
||||
output: str,
|
||||
output_truncated: bool,
|
||||
output_until_seq: Optional[int] = None,
|
||||
) -> dict[str, Any]:
|
||||
"""生成工具返回的结构化会话状态。"""
|
||||
return {
|
||||
"session_id": session.session_id,
|
||||
"command": session.command,
|
||||
"cwd": session.cwd,
|
||||
"pid": session.pid,
|
||||
"status": session.status,
|
||||
"exit_code": session.exit_code,
|
||||
"use_pty": session.use_pty,
|
||||
"last_seq": session.next_seq - 1,
|
||||
"output_until_seq": (
|
||||
session.next_seq - 1 if output_until_seq is None else output_until_seq
|
||||
),
|
||||
"retained_from_seq": session.retained_from_seq,
|
||||
"output_truncated": output_truncated,
|
||||
"output": output,
|
||||
"error": session.error,
|
||||
}
|
||||
|
||||
|
||||
terminal_session_manager = _TerminalSessionManager()
|
||||
@@ -6,6 +6,7 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.agent.tools.impl._filter_rule_utils import (
|
||||
get_custom_rules,
|
||||
normalize_custom_rule,
|
||||
@@ -19,10 +20,8 @@ from app.schemas.types import SystemConfigKey
|
||||
class AddCustomFilterRuleInput(BaseModel):
|
||||
"""新增自定义过滤规则工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
rule_id: str = Field(
|
||||
...,
|
||||
description="Unique custom rule ID. Only letters and numbers are allowed.",
|
||||
@@ -48,6 +47,11 @@ class AddCustomFilterRuleInput(BaseModel):
|
||||
|
||||
class AddCustomFilterRuleTool(MoviePilotTool):
|
||||
name: str = "add_custom_filter_rule"
|
||||
tags: list[str] = [
|
||||
ToolTag.Write,
|
||||
ToolTag.FilterRule,
|
||||
ToolTag.Admin,
|
||||
]
|
||||
description: str = (
|
||||
"Add a custom filter rule to CustomFilterRules. "
|
||||
"The new rule can then be referenced by rule ID inside filter rule groups."
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import List, Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.chain.media import MediaChain
|
||||
from app.chain.search import SearchChain
|
||||
from app.chain.download import DownloadChain
|
||||
@@ -22,7 +23,7 @@ from app.utils.crypto import HashUtils
|
||||
|
||||
class AddDownloadInput(BaseModel):
|
||||
"""添加下载工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
explanation: Optional[str] = Field(None, description="Clear explanation of why this tool is being used in the current context")
|
||||
torrent_url: List[str] = Field(
|
||||
...,
|
||||
description="One or more torrent_url values. Supports refs from get_search_results (`hash:id`) and magnet links."
|
||||
@@ -37,6 +38,11 @@ class AddDownloadInput(BaseModel):
|
||||
|
||||
class AddDownloadTool(MoviePilotTool):
|
||||
name: str = "add_download"
|
||||
tags: list[str] = [
|
||||
ToolTag.Write,
|
||||
ToolTag.Download,
|
||||
ToolTag.Resource,
|
||||
]
|
||||
description: str = "Add torrent download tasks using refs from get_search_results or magnet links."
|
||||
args_schema: Type[BaseModel] = AddDownloadInput
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.agent.tools.impl._filter_rule_utils import (
|
||||
build_custom_rule_map,
|
||||
collect_rule_group_usages,
|
||||
@@ -23,10 +24,8 @@ from app.schemas.types import SystemConfigKey
|
||||
class AddRuleGroupInput(BaseModel):
|
||||
"""新增过滤规则组工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
name: str = Field(..., description="New rule group name.")
|
||||
rule_string: str = Field(
|
||||
...,
|
||||
@@ -48,6 +47,11 @@ class AddRuleGroupInput(BaseModel):
|
||||
|
||||
class AddRuleGroupTool(MoviePilotTool):
|
||||
name: str = "add_rule_group"
|
||||
tags: list[str] = [
|
||||
ToolTag.Write,
|
||||
ToolTag.FilterRule,
|
||||
ToolTag.Admin,
|
||||
]
|
||||
description: str = (
|
||||
"Add a new filter rule group to UserFilterRuleGroups. "
|
||||
"Rule groups are matched level by level from left to right and can be linked to search/subscription flows. "
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import List, Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.chain.subscribe import SubscribeChain
|
||||
from app.db.user_oper import UserOper
|
||||
from app.log import logger
|
||||
@@ -14,10 +15,8 @@ from app.schemas.types import MediaType, MessageChannel
|
||||
class AddSubscribeInput(BaseModel):
|
||||
"""添加订阅工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
title: str = Field(
|
||||
...,
|
||||
description="The title of the media to subscribe to (e.g., 'The Matrix', 'Breaking Bad')",
|
||||
@@ -74,6 +73,11 @@ class AddSubscribeInput(BaseModel):
|
||||
|
||||
class AddSubscribeTool(MoviePilotTool):
|
||||
name: str = "add_subscribe"
|
||||
tags: list[str] = [
|
||||
ToolTag.Write,
|
||||
ToolTag.Subscription,
|
||||
ToolTag.Media,
|
||||
]
|
||||
description: str = (
|
||||
"Add media subscription to create automated download rules for movies and TV shows. "
|
||||
"The system will automatically search and download new episodes or releases based on the subscription criteria. "
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import List, Optional, Type
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool, ToolChain
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.helper.interaction import (
|
||||
AgentInteractionOption,
|
||||
agent_interaction_manager,
|
||||
@@ -26,9 +27,11 @@ class UserChoiceOptionInput(BaseModel):
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_option(self):
|
||||
if not self.label.strip():
|
||||
label = str(self.label)
|
||||
value = str(self.value)
|
||||
if not label.strip():
|
||||
raise ValueError("label 不能为空")
|
||||
if not self.value.strip():
|
||||
if not value.strip():
|
||||
raise ValueError("value 不能为空")
|
||||
return self
|
||||
|
||||
@@ -36,10 +39,8 @@ class UserChoiceOptionInput(BaseModel):
|
||||
class AskUserChoiceInput(BaseModel):
|
||||
"""按钮选择工具输入。"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why the agent needs the user to choose from buttons",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why the agent needs the user to choose from buttons",)
|
||||
message: str = Field(
|
||||
...,
|
||||
description="Question or prompt shown to the user together with the buttons",
|
||||
@@ -55,7 +56,8 @@ class AskUserChoiceInput(BaseModel):
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_payload(self):
|
||||
if not self.message.strip():
|
||||
message = str(self.message)
|
||||
if not message.strip():
|
||||
raise ValueError("message 不能为空")
|
||||
if not self.options:
|
||||
raise ValueError("options 至少需要提供一个")
|
||||
@@ -63,11 +65,22 @@ class AskUserChoiceInput(BaseModel):
|
||||
|
||||
|
||||
class AskUserChoiceTool(MoviePilotTool):
|
||||
"""发送按钮选择并让当前 Agent 轮次等待用户回调消息。"""
|
||||
|
||||
name: str = "ask_user_choice"
|
||||
tags: list[str] = [
|
||||
ToolTag.Write,
|
||||
ToolTag.Message,
|
||||
ToolTag.UserInteraction,
|
||||
ToolTag.TerminalResponse,
|
||||
]
|
||||
sends_message: bool = True
|
||||
return_direct: bool = True
|
||||
description: str = (
|
||||
"Ask the user to choose from button options on channels that support interactive buttons. "
|
||||
"After the user clicks a button, the selected value will come back as the user's next message."
|
||||
"This is a terminal interaction tool: put the full question and all options in this call, "
|
||||
"then stop the current turn. After the user clicks a button, the selected value will come "
|
||||
"back as the user's next message. Do not also send the same question as plain text."
|
||||
)
|
||||
args_schema: Type[BaseModel] = AskUserChoiceInput
|
||||
require_admin: bool = False
|
||||
@@ -86,6 +99,15 @@ class AskUserChoiceTool(MoviePilotTool):
|
||||
return text[:max_length]
|
||||
return text[: max_length - 3] + "..."
|
||||
|
||||
def _blocked_by_feedback_quality_gate(self) -> bool:
|
||||
"""反馈 Issue 质量门槛拒绝后,禁止继续发按钮引导改写。
|
||||
|
||||
这是对 ``feedback-issue`` skill 的历史兜底:如果同一轮上下文已经
|
||||
标记反馈内容被质量门槛拒绝,就不能再用按钮诱导用户把测试 / 占位
|
||||
内容改写成“真实问题”。
|
||||
"""
|
||||
return bool(self._agent_context.get("feedback_issue_rejected_quality"))
|
||||
|
||||
async def run(
|
||||
self,
|
||||
message: str,
|
||||
@@ -93,6 +115,17 @@ class AskUserChoiceTool(MoviePilotTool):
|
||||
title: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
if self._blocked_by_feedback_quality_gate():
|
||||
logger.warning(
|
||||
"ask_user_choice blocked after feedback issue rejected_quality: "
|
||||
"session_id=%s",
|
||||
self._session_id,
|
||||
)
|
||||
return (
|
||||
"反馈 Issue 已被质量门槛拒绝,不能继续发送按钮引导用户改写或重新提交。"
|
||||
"请直接结束本次反馈流程。"
|
||||
)
|
||||
|
||||
if not self._channel or not self._source:
|
||||
return "当前不在可回传消息的会话中,无法发起按钮选择"
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
|
||||
@@ -38,10 +39,8 @@ class BrowserAction(str, Enum):
|
||||
class BrowseWebpageInput(BaseModel):
|
||||
"""浏览器操作工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this browser action is being performed",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this browser action is being performed",)
|
||||
action: str = Field(
|
||||
...,
|
||||
description=(
|
||||
@@ -91,6 +90,10 @@ class BrowseWebpageInput(BaseModel):
|
||||
|
||||
class BrowseWebpageTool(MoviePilotTool):
|
||||
name: str = "browse_webpage"
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.Web,
|
||||
]
|
||||
description: str = (
|
||||
"Control a real browser (Playwright) to interact with web pages. "
|
||||
"Supports navigating to URLs, reading page content, taking screenshots, "
|
||||
@@ -198,68 +201,62 @@ class BrowseWebpageTool(MoviePilotTool):
|
||||
cookies: Optional[str],
|
||||
user_agent: Optional[str],
|
||||
) -> str:
|
||||
"""在同步上下文中执行 Playwright 浏览器操作"""
|
||||
from playwright.sync_api import sync_playwright
|
||||
"""在同步上下文中执行 CloakBrowser 浏览器操作"""
|
||||
from cloakbrowser import launch_context
|
||||
|
||||
try:
|
||||
with sync_playwright() as playwright:
|
||||
browser = None
|
||||
context = None
|
||||
page = None
|
||||
try:
|
||||
# 启动浏览器
|
||||
browser_type = settings.PLAYWRIGHT_BROWSER_TYPE or "chromium"
|
||||
browser = playwright[browser_type].launch(headless=True)
|
||||
|
||||
# 创建上下文
|
||||
context_kwargs = {}
|
||||
if user_agent:
|
||||
context_kwargs["user_agent"] = user_agent
|
||||
# 设置视口大小
|
||||
context_kwargs["viewport"] = {
|
||||
context = None
|
||||
page = None
|
||||
try:
|
||||
context_kwargs = {
|
||||
"viewport": {
|
||||
"width": SCREENSHOT_MAX_WIDTH,
|
||||
"height": SCREENSHOT_MAX_HEIGHT,
|
||||
}
|
||||
}
|
||||
if user_agent:
|
||||
context_kwargs["user_agent"] = user_agent
|
||||
|
||||
context = browser.new_context(**context_kwargs)
|
||||
page = context.new_page()
|
||||
page.set_default_timeout(timeout * 1000)
|
||||
context = launch_context(
|
||||
headless=True,
|
||||
humanize=settings.CLOAKBROWSER_HUMANIZE,
|
||||
human_preset=settings.CLOAKBROWSER_HUMAN_PRESET,
|
||||
**context_kwargs,
|
||||
)
|
||||
page = context.new_page()
|
||||
page.set_default_timeout(timeout * 1000)
|
||||
|
||||
# 设置 cookies
|
||||
if cookies:
|
||||
page.set_extra_http_headers({"cookie": cookies})
|
||||
# 设置 cookies
|
||||
if cookies:
|
||||
page.set_extra_http_headers({"cookie": cookies})
|
||||
|
||||
# 对于非 goto 操作,如果提供了 url 先导航
|
||||
if url and browser_action != BrowserAction.GOTO:
|
||||
page.goto(
|
||||
url, wait_until="domcontentloaded", timeout=timeout * 1000
|
||||
)
|
||||
page.wait_for_load_state("networkidle", timeout=timeout * 1000)
|
||||
# 对于非 goto 操作,如果提供了 url 先导航
|
||||
if url and browser_action != BrowserAction.GOTO:
|
||||
page.goto(url, wait_until="domcontentloaded", timeout=timeout * 1000)
|
||||
page.wait_for_load_state("networkidle", timeout=timeout * 1000)
|
||||
|
||||
# 执行具体操作
|
||||
result = self._do_action(
|
||||
page,
|
||||
browser_action,
|
||||
url,
|
||||
selector,
|
||||
value,
|
||||
script,
|
||||
content_type,
|
||||
timeout,
|
||||
)
|
||||
return result
|
||||
# 执行具体操作
|
||||
result = self._do_action(
|
||||
page,
|
||||
browser_action,
|
||||
url,
|
||||
selector,
|
||||
value,
|
||||
script,
|
||||
content_type,
|
||||
timeout,
|
||||
)
|
||||
return result
|
||||
|
||||
finally:
|
||||
if page:
|
||||
page.close()
|
||||
if context:
|
||||
context.close()
|
||||
if browser:
|
||||
browser.close()
|
||||
finally:
|
||||
if page:
|
||||
page.close()
|
||||
if context:
|
||||
context.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Playwright 执行失败: {e}", exc_info=True)
|
||||
return f"Playwright 执行失败: {str(e)}"
|
||||
logger.error(f"CloakBrowser 执行失败: {e}", exc_info=True)
|
||||
return f"CloakBrowser 执行失败: {str(e)}"
|
||||
|
||||
def _do_action(
|
||||
self,
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.agent.tools.impl._filter_rule_utils import (
|
||||
collect_custom_rule_group_refs,
|
||||
get_custom_rules,
|
||||
@@ -19,15 +20,18 @@ from app.schemas.types import SystemConfigKey
|
||||
class DeleteCustomFilterRuleInput(BaseModel):
|
||||
"""删除自定义过滤规则工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
rule_id: str = Field(..., description="Custom rule ID to delete.")
|
||||
|
||||
|
||||
class DeleteCustomFilterRuleTool(MoviePilotTool):
|
||||
name: str = "delete_custom_filter_rule"
|
||||
tags: list[str] = [
|
||||
ToolTag.Write,
|
||||
ToolTag.FilterRule,
|
||||
ToolTag.Admin,
|
||||
]
|
||||
description: str = (
|
||||
"Delete a custom filter rule from CustomFilterRules. "
|
||||
"If the rule is still referenced by rule groups, the deletion is blocked to avoid breaking rule_string expressions."
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.chain.download import DownloadChain
|
||||
from app.log import logger
|
||||
|
||||
@@ -12,10 +13,8 @@ from app.log import logger
|
||||
class DeleteDownloadInput(BaseModel):
|
||||
"""删除下载任务工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
hash: str = Field(
|
||||
..., description="Task hash (can be obtained from query_download_tasks tool)"
|
||||
)
|
||||
@@ -31,6 +30,11 @@ class DeleteDownloadInput(BaseModel):
|
||||
|
||||
class DeleteDownloadTool(MoviePilotTool):
|
||||
name: str = "delete_download"
|
||||
tags: list[str] = [
|
||||
ToolTag.Write,
|
||||
ToolTag.Download,
|
||||
ToolTag.Admin,
|
||||
]
|
||||
description: str = "Delete a download task from the downloader by task hash only. Optionally specify the downloader name and whether to delete downloaded files."
|
||||
args_schema: Type[BaseModel] = DeleteDownloadInput
|
||||
require_admin: bool = True
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.db import AsyncSessionFactory
|
||||
from app.db.models.downloadhistory import DownloadHistory
|
||||
from app.log import logger
|
||||
@@ -13,10 +14,8 @@ from app.log import logger
|
||||
class DeleteDownloadHistoryInput(BaseModel):
|
||||
"""删除下载历史记录工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
history_id: int = Field(
|
||||
..., description="The ID of the download history record to delete"
|
||||
)
|
||||
@@ -24,6 +23,11 @@ class DeleteDownloadHistoryInput(BaseModel):
|
||||
|
||||
class DeleteDownloadHistoryTool(MoviePilotTool):
|
||||
name: str = "delete_download_history"
|
||||
tags: list[str] = [
|
||||
ToolTag.Write,
|
||||
ToolTag.Download,
|
||||
ToolTag.Admin,
|
||||
]
|
||||
description: str = "Delete a download history record by ID. This only removes the record from the database, does not delete any actual files."
|
||||
args_schema: Type[BaseModel] = DeleteDownloadHistoryInput
|
||||
require_admin: bool = True
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.agent.tools.impl._filter_rule_utils import (
|
||||
get_rule_groups,
|
||||
remove_rule_group_references,
|
||||
@@ -18,15 +19,18 @@ from app.schemas.types import SystemConfigKey
|
||||
class DeleteRuleGroupInput(BaseModel):
|
||||
"""删除过滤规则组工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
name: str = Field(..., description="Rule group name to delete.")
|
||||
|
||||
|
||||
class DeleteRuleGroupTool(MoviePilotTool):
|
||||
name: str = "delete_rule_group"
|
||||
tags: list[str] = [
|
||||
ToolTag.Write,
|
||||
ToolTag.FilterRule,
|
||||
ToolTag.Admin,
|
||||
]
|
||||
description: str = (
|
||||
"Delete a filter rule group from UserFilterRuleGroups. "
|
||||
"The tool also removes dangling references from global settings and subscriptions."
|
||||
|
||||
@@ -5,9 +5,10 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.core.event import eventmanager
|
||||
from app.db.subscribe_oper import SubscribeOper
|
||||
from app.helper.subscribe import SubscribeHelper
|
||||
from app.helper.server import MoviePilotServerHelper
|
||||
from app.log import logger
|
||||
from app.schemas.types import EventType
|
||||
|
||||
@@ -15,10 +16,8 @@ from app.schemas.types import EventType
|
||||
class DeleteSubscribeInput(BaseModel):
|
||||
"""删除订阅工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
subscribe_id: int = Field(
|
||||
...,
|
||||
description="The ID of the subscription to delete (can be obtained from query_subscribes tool)",
|
||||
@@ -27,6 +26,11 @@ class DeleteSubscribeInput(BaseModel):
|
||||
|
||||
class DeleteSubscribeTool(MoviePilotTool):
|
||||
name: str = "delete_subscribe"
|
||||
tags: list[str] = [
|
||||
ToolTag.Write,
|
||||
ToolTag.Subscription,
|
||||
ToolTag.Admin,
|
||||
]
|
||||
description: str = "Delete a media subscription by its ID. This will remove the subscription and stop automatic downloads for that media."
|
||||
args_schema: Type[BaseModel] = DeleteSubscribeInput
|
||||
require_admin: bool = True
|
||||
@@ -51,7 +55,7 @@ class DeleteSubscribeTool(MoviePilotTool):
|
||||
|
||||
await subscribe_oper.async_delete(subscribe_id)
|
||||
# 分享订阅统计刷新本身已异步化,这里只需要在删除后触发即可。
|
||||
SubscribeHelper().sub_done_async(
|
||||
MoviePilotServerHelper.sub_done_async(
|
||||
{"tmdbid": subscribe.tmdbid, "doubanid": subscribe.doubanid}
|
||||
)
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.db.transferhistory_oper import TransferHistoryOper
|
||||
from app.log import logger
|
||||
|
||||
@@ -12,10 +13,8 @@ from app.log import logger
|
||||
class DeleteTransferHistoryInput(BaseModel):
|
||||
"""删除整理历史记录工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
history_id: int = Field(
|
||||
..., description="The ID of the transfer history record to delete"
|
||||
)
|
||||
@@ -23,6 +22,11 @@ class DeleteTransferHistoryInput(BaseModel):
|
||||
|
||||
class DeleteTransferHistoryTool(MoviePilotTool):
|
||||
name: str = "delete_transfer_history"
|
||||
tags: list[str] = [
|
||||
ToolTag.Write,
|
||||
ToolTag.Transfer,
|
||||
ToolTag.Admin,
|
||||
]
|
||||
description: str = "Delete a specific transfer history record by its ID. This is useful when you need to remove a failed transfer record before retrying the transfer, as the system skips files that already have transfer history."
|
||||
args_schema: Type[BaseModel] = DeleteTransferHistoryInput
|
||||
require_admin: bool = True
|
||||
|
||||
@@ -7,6 +7,7 @@ from anyio import Path as AsyncPath
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.log import logger
|
||||
|
||||
|
||||
@@ -20,6 +21,11 @@ class EditFileInput(BaseModel):
|
||||
|
||||
class EditFileTool(MoviePilotTool):
|
||||
name: str = "edit_file"
|
||||
tags: list[str] = [
|
||||
ToolTag.Write,
|
||||
ToolTag.File,
|
||||
ToolTag.Admin,
|
||||
]
|
||||
description: str = "Edit a file by replacing specific old text with new text. Useful for modifying configuration files, code, or scripts."
|
||||
args_schema: Type[BaseModel] = EditFileInput
|
||||
require_admin: bool = True
|
||||
|
||||
@@ -1,16 +1,26 @@
|
||||
"""执行Shell命令工具"""
|
||||
"""执行 Shell 命令工具。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
from dataclasses import dataclass, field
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import Optional, TextIO, Type
|
||||
from typing import Any, Literal, Optional, TextIO, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.agent.tools.impl._terminal_session import (
|
||||
TERMINAL_DEFAULT_READ_BYTES,
|
||||
TERMINAL_MAX_READ_BYTES,
|
||||
TERMINAL_WAIT_DEFAULT_MS,
|
||||
terminal_session_manager,
|
||||
)
|
||||
from app.log import logger
|
||||
|
||||
|
||||
@@ -20,6 +30,13 @@ MAX_OUTPUT_PREVIEW_BYTES = 10 * 1024
|
||||
READ_CHUNK_SIZE = 4096
|
||||
KILL_GRACE_SECONDS = 3
|
||||
COMMAND_CONCURRENCY_LIMIT = 2
|
||||
COMMAND_FORBIDDEN_KEYWORDS = (
|
||||
":(){ :|:& };:",
|
||||
"dd if=/dev/zero",
|
||||
"mkfs",
|
||||
"reboot",
|
||||
"shutdown",
|
||||
)
|
||||
|
||||
_command_semaphore = asyncio.Semaphore(COMMAND_CONCURRENCY_LIMIT)
|
||||
|
||||
@@ -38,11 +55,13 @@ class _CommandOutput:
|
||||
|
||||
@staticmethod
|
||||
def _clip_text_to_bytes(text: str, byte_limit: int) -> str:
|
||||
"""按 UTF-8 字节数截断文本,避免截断后出现非法字符。"""
|
||||
if byte_limit <= 0:
|
||||
return ""
|
||||
return text.encode("utf-8")[:byte_limit].decode("utf-8", errors="ignore")
|
||||
|
||||
def _write_chunk(self, stream_name: str, text: str) -> None:
|
||||
"""把输出分片按 stdout/stderr 分段写入临时文件。"""
|
||||
if not self.temp_file_handle or not text:
|
||||
return
|
||||
|
||||
@@ -56,6 +75,7 @@ class _CommandOutput:
|
||||
self.temp_file_handle.write(text)
|
||||
|
||||
def _ensure_temp_file(self) -> None:
|
||||
"""首次超出预览上限时创建临时文件并补写已缓存预览。"""
|
||||
if self.temp_file_handle:
|
||||
return
|
||||
|
||||
@@ -72,6 +92,7 @@ class _CommandOutput:
|
||||
self._write_chunk(stream_name, chunk)
|
||||
|
||||
def close(self) -> None:
|
||||
"""关闭临时文件句柄,确保输出落盘。"""
|
||||
if not self.temp_file_handle:
|
||||
return
|
||||
self.temp_file_handle.flush()
|
||||
@@ -79,6 +100,7 @@ class _CommandOutput:
|
||||
self.temp_file_handle = None
|
||||
|
||||
def append(self, stream_name: str, text: str) -> None:
|
||||
"""追加一段输出,超出预览上限后只保留完整日志文件。"""
|
||||
if not text:
|
||||
return
|
||||
|
||||
@@ -104,47 +126,167 @@ class _CommandOutput:
|
||||
|
||||
@property
|
||||
def stdout(self) -> str:
|
||||
"""返回当前保留的 stdout 预览。"""
|
||||
return "".join(
|
||||
text for stream_name, text in self.preview_entries if stream_name == "stdout"
|
||||
).strip()
|
||||
|
||||
@property
|
||||
def stderr(self) -> str:
|
||||
"""返回当前保留的 stderr 预览。"""
|
||||
return "".join(
|
||||
text for stream_name, text in self.preview_entries if stream_name == "stderr"
|
||||
).strip()
|
||||
|
||||
|
||||
class ExecuteCommandInput(BaseModel):
|
||||
"""执行Shell命令工具的输入参数模型"""
|
||||
"""执行 Shell 命令工具的输入参数模型。"""
|
||||
|
||||
explanation: str = Field(
|
||||
..., description="Clear explanation of why this command is being executed"
|
||||
explanation: Optional[str] = Field(None, description="Clear explanation of why this command action is needed")
|
||||
action: Optional[Literal["start", "read", "wait", "write", "kill", "run"]] = Field(
|
||||
"start",
|
||||
description=(
|
||||
"Command action. start launches a managed background session and returns "
|
||||
"session_id. read/wait/write/kill operate on that session. run executes "
|
||||
"once and waits until completion or timeout."
|
||||
),
|
||||
)
|
||||
command: Optional[str] = Field(
|
||||
None,
|
||||
description="Shell command. Required for action=start or action=run.",
|
||||
)
|
||||
session_id: Optional[str] = Field(
|
||||
None,
|
||||
description="Command session id returned by action=start.",
|
||||
)
|
||||
input_text: Optional[str] = Field(
|
||||
None,
|
||||
description="Text to send to stdin for action=write. Use \\u0003 for Ctrl+C.",
|
||||
)
|
||||
signal_name: Optional[str] = Field(
|
||||
"TERM",
|
||||
description="Signal for action=kill, such as TERM, INT, KILL, or 15.",
|
||||
)
|
||||
cwd: Optional[str] = Field(
|
||||
None,
|
||||
description="Working directory for action=start or action=run.",
|
||||
)
|
||||
env: Optional[dict[str, Any]] = Field(
|
||||
None,
|
||||
description="Additional environment variables for action=start.",
|
||||
)
|
||||
use_pty: Optional[bool] = Field(
|
||||
True,
|
||||
description="Use a pseudo terminal for action=start when supported.",
|
||||
)
|
||||
since_seq: Optional[int] = Field(
|
||||
None,
|
||||
description="For action=read/wait, return output chunks after this seq.",
|
||||
)
|
||||
max_bytes: Optional[int] = Field(
|
||||
TERMINAL_DEFAULT_READ_BYTES,
|
||||
description="For action=read/wait, maximum output bytes to return.",
|
||||
)
|
||||
timeout_ms: Optional[int] = Field(
|
||||
TERMINAL_WAIT_DEFAULT_MS,
|
||||
description="For action=wait, maximum segmented wait time in milliseconds.",
|
||||
)
|
||||
command: str = Field(..., description="The shell command to execute")
|
||||
timeout: Optional[int] = Field(
|
||||
60, description="Max execution time in seconds (default: 60)"
|
||||
60,
|
||||
description="For action=run, max execution time in seconds.",
|
||||
)
|
||||
|
||||
|
||||
class ExecuteCommandTool(MoviePilotTool):
|
||||
"""统一执行和管理 Shell 命令的 Agent 工具。"""
|
||||
|
||||
name: str = "execute_command"
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.Command,
|
||||
ToolTag.Admin,
|
||||
]
|
||||
description: str = (
|
||||
"Safely execute shell commands on the server. Useful for system "
|
||||
"maintenance, checking status, or running custom scripts. Includes "
|
||||
"timeout, concurrency, and output preview limits."
|
||||
"Start and manage shell commands on the server. By default action=start "
|
||||
"launches a background session and immediately returns session_id/status/"
|
||||
"last_seq/output_until_seq. Call the same tool with action=read, wait, "
|
||||
"write, or kill to poll output, wait in short segments, send stdin, or "
|
||||
"terminate it. Use action=run only when a one-shot bounded command result "
|
||||
"is preferred."
|
||||
)
|
||||
args_schema: Type[BaseModel] = ExecuteCommandInput
|
||||
require_admin: bool = True
|
||||
result_max_chars = TERMINAL_MAX_READ_BYTES + 4096
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据命令生成友好的提示消息"""
|
||||
command = kwargs.get("command", "")
|
||||
return f"执行系统命令: {command}"
|
||||
"""根据命令动作生成友好的提示消息。"""
|
||||
action = kwargs.get("action") or "start"
|
||||
command = kwargs.get("command")
|
||||
session_id = kwargs.get("session_id")
|
||||
if action in {"start", "run"}:
|
||||
return f"执行系统命令: {command or ''}"
|
||||
if action == "read":
|
||||
return f"读取命令输出: {session_id or ''}"
|
||||
if action == "wait":
|
||||
return f"等待命令会话: {session_id or ''}"
|
||||
if action == "write":
|
||||
return f"写入命令输入: {session_id or ''}"
|
||||
if action == "kill":
|
||||
return f"终止命令会话: {session_id or ''}"
|
||||
return f"处理命令会话: {session_id or command or ''}"
|
||||
|
||||
@staticmethod
|
||||
def _dump(payload: dict[str, Any]) -> str:
|
||||
"""把结构化命令会话结果转换为 Agent 容易解析的 JSON 字符串。"""
|
||||
return json.dumps(payload, ensure_ascii=False, indent=2)
|
||||
|
||||
@staticmethod
|
||||
def _require_session_id(session_id: Optional[str]) -> str:
|
||||
"""校验会话型 action 必须传入 session_id。"""
|
||||
if not session_id:
|
||||
raise ValueError("action 需要传入 session_id")
|
||||
return session_id
|
||||
|
||||
@staticmethod
|
||||
def _require_command(command: Optional[str]) -> str:
|
||||
"""校验启动型 action 必须传入 command。"""
|
||||
if not command or not command.strip():
|
||||
raise ValueError("action 需要传入 command")
|
||||
return command
|
||||
|
||||
@staticmethod
|
||||
def _validate_command(command: str) -> None:
|
||||
"""复用旧工具的基础危险命令过滤,避免明显破坏性命令进入 shell。"""
|
||||
for keyword in COMMAND_FORBIDDEN_KEYWORDS:
|
||||
if keyword in command:
|
||||
raise ValueError(f"命令包含禁止使用的关键字 '{keyword}'")
|
||||
|
||||
# 检查是否使用了 rm -r/R 删除根目录或一级目录,防止误杀多级目录
|
||||
import re
|
||||
import os.path
|
||||
tokens = re.split(r'\s+', command.strip())
|
||||
if any(t == "rm" or t.endswith("/rm") for t in tokens):
|
||||
has_r = False
|
||||
for token in tokens:
|
||||
if token.startswith("-") and ("r" in token or "R" in token):
|
||||
has_r = True
|
||||
break
|
||||
|
||||
if has_r:
|
||||
for token in tokens:
|
||||
# 提取可能包含目标路径的部分(去除重定向、管道、分号等末尾干扰)
|
||||
m = re.match(r'^([^;\|&><]+)', token)
|
||||
if m:
|
||||
clean_token = m.group(1).strip('"\'')
|
||||
# 仅对绝对路径进行一级目录限制
|
||||
if clean_token.startswith('/'):
|
||||
norm_path = os.path.normpath(clean_token)
|
||||
if re.match(r'^/[^/]*$', norm_path) or re.match(r'^/[^/]*/$', norm_path):
|
||||
raise ValueError(f"不允许使用 rm 命令删除根目录或一级目录: {clean_token}")
|
||||
|
||||
@staticmethod
|
||||
def _normalize_timeout(timeout: Optional[int]) -> tuple[int, Optional[str]]:
|
||||
"""限制命令最长运行时间,避免 Agent 传入过大的 timeout。"""
|
||||
"""限制一次性执行命令的最长运行时间。"""
|
||||
try:
|
||||
normalized = int(timeout or DEFAULT_TIMEOUT_SECONDS)
|
||||
except (TypeError, ValueError):
|
||||
@@ -161,7 +303,7 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
|
||||
@staticmethod
|
||||
def _subprocess_kwargs() -> dict:
|
||||
"""为子进程创建独立进程组,便于超时场景清理整棵子进程。"""
|
||||
"""为一次性命令创建独立进程组,便于超时清理整棵子进程。"""
|
||||
kwargs = {
|
||||
"stdin": subprocess.DEVNULL,
|
||||
"stdout": asyncio.subprocess.PIPE,
|
||||
@@ -179,17 +321,16 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
stream_name: str,
|
||||
output: _CommandOutput,
|
||||
) -> None:
|
||||
"""按块读取输出,始终只把前 10KB 保留在返回结果中。"""
|
||||
"""按块读取一次性命令输出,只把前 10KB 保留在返回结果中。"""
|
||||
while True:
|
||||
chunk = await stream.read(READ_CHUNK_SIZE)
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
output.append(stream_name, chunk.decode("utf-8", errors="replace"))
|
||||
|
||||
@staticmethod
|
||||
def _terminate_process(process: asyncio.subprocess.Process, sig: int):
|
||||
"""向进程组发送终止信号;不支持进程组的平台回退为单进程终止。"""
|
||||
def _terminate_process(process: Any, sig: int) -> None:
|
||||
"""向进程组发送终止信号,不支持进程组的平台回退为单进程终止。"""
|
||||
try:
|
||||
if os.name == "posix":
|
||||
os.killpg(process.pid, sig)
|
||||
@@ -203,7 +344,7 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
@classmethod
|
||||
async def _cleanup_process(
|
||||
cls,
|
||||
process: asyncio.subprocess.Process,
|
||||
process: Any,
|
||||
wait_task: asyncio.Task,
|
||||
) -> None:
|
||||
"""先温和终止,失败后强杀,避免超时 shell 遗留子进程。"""
|
||||
@@ -230,7 +371,7 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
|
||||
@staticmethod
|
||||
async def _finish_reader_tasks(reader_tasks: list[asyncio.Task]) -> None:
|
||||
"""等待输出读取任务退出,异常只记录不影响工具返回。"""
|
||||
"""等待一次性命令输出读取任务退出,异常只记录不影响工具返回。"""
|
||||
if not reader_tasks:
|
||||
return
|
||||
done, pending = await asyncio.wait(reader_tasks, timeout=1)
|
||||
@@ -244,7 +385,7 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
logger.debug("命令输出读取任务异常: %s", result)
|
||||
|
||||
@staticmethod
|
||||
def _format_result(
|
||||
def _format_run_result(
|
||||
*,
|
||||
exit_code: Optional[int],
|
||||
output: _CommandOutput,
|
||||
@@ -252,6 +393,7 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
timed_out: bool,
|
||||
timeout_note: Optional[str],
|
||||
) -> str:
|
||||
"""格式化 action=run 的兼容文本结果。"""
|
||||
if timed_out:
|
||||
result = f"命令执行超时 (限制: {timeout}秒,已终止进程)"
|
||||
else:
|
||||
@@ -260,11 +402,7 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
if timeout_note:
|
||||
result += f"\n\n提示:\n{timeout_note}"
|
||||
if output.temp_file_path:
|
||||
file_note = (
|
||||
"截至命令终止前的完整输出"
|
||||
if timed_out
|
||||
else "完整输出"
|
||||
)
|
||||
file_note = "截至命令终止前的完整输出" if timed_out else "完整输出"
|
||||
result += (
|
||||
"\n\n提示:\n"
|
||||
f"命令输出超过 10KB,仅返回前 {MAX_OUTPUT_PREVIEW_BYTES} 字节内容。\n"
|
||||
@@ -281,65 +419,129 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
result += "\n\n(无输出内容)"
|
||||
return result
|
||||
|
||||
async def run(self, command: str, timeout: Optional[int] = 60, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: command={command}, timeout={timeout}"
|
||||
)
|
||||
|
||||
# 简单安全过滤
|
||||
forbidden_keywords = [
|
||||
"rm -rf /",
|
||||
":(){ :|:& };:",
|
||||
"dd if=/dev/zero",
|
||||
"mkfs",
|
||||
"reboot",
|
||||
"shutdown",
|
||||
]
|
||||
for keyword in forbidden_keywords:
|
||||
if keyword in command:
|
||||
return f"错误:命令包含禁止使用的关键字 '{keyword}'"
|
||||
|
||||
async def _run_once(
|
||||
self,
|
||||
*,
|
||||
command: str,
|
||||
timeout: Optional[int],
|
||||
cwd: Optional[str] = None,
|
||||
) -> str:
|
||||
"""按旧模式一次性执行命令,等待完成或超时后返回文本结果。"""
|
||||
self._validate_command(command)
|
||||
normalized_timeout, timeout_note = self._normalize_timeout(timeout)
|
||||
|
||||
async with _command_semaphore:
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
command,
|
||||
cwd=cwd,
|
||||
**self._subprocess_kwargs(),
|
||||
)
|
||||
output = _CommandOutput(preview_limit_bytes=MAX_OUTPUT_PREVIEW_BYTES)
|
||||
wait_task = asyncio.create_task(process.wait())
|
||||
reader_tasks = [
|
||||
asyncio.create_task(self._read_stream(process.stdout, "stdout", output)),
|
||||
asyncio.create_task(self._read_stream(process.stderr, "stderr", output)),
|
||||
]
|
||||
|
||||
timed_out = False
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
asyncio.shield(wait_task), timeout=normalized_timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
timed_out = True
|
||||
await self._cleanup_process(process, wait_task)
|
||||
|
||||
try:
|
||||
await self._finish_reader_tasks(reader_tasks)
|
||||
finally:
|
||||
output.close()
|
||||
|
||||
return self._format_run_result(
|
||||
exit_code=process.returncode,
|
||||
output=output,
|
||||
timeout=normalized_timeout,
|
||||
timed_out=timed_out,
|
||||
timeout_note=timeout_note,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
action: Optional[str] = "start",
|
||||
command: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
input_text: Optional[str] = None,
|
||||
signal_name: Optional[str] = "TERM",
|
||||
cwd: Optional[str] = None,
|
||||
env: Optional[dict[str, Any]] = None,
|
||||
use_pty: Optional[bool] = True,
|
||||
since_seq: Optional[int] = None,
|
||||
max_bytes: Optional[int] = TERMINAL_DEFAULT_READ_BYTES,
|
||||
timeout_ms: Optional[int] = TERMINAL_WAIT_DEFAULT_MS,
|
||||
timeout: Optional[int] = 60,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""执行命令动作:默认后台启动,也支持读取、等待、写入、终止和一次性执行。"""
|
||||
normalized_action = (action or "start").strip().lower()
|
||||
logger.info(
|
||||
"执行工具: %s, action=%s, command=%s, session_id=%s",
|
||||
self.name,
|
||||
normalized_action,
|
||||
command,
|
||||
session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
async with _command_semaphore:
|
||||
# 命令输出可能非常大,必须边读边落盘,不能使用 communicate() 一次性收集。
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
command, **self._subprocess_kwargs()
|
||||
if normalized_action == "start":
|
||||
start_command = self._require_command(command)
|
||||
self._validate_command(start_command)
|
||||
payload = await terminal_session_manager.start(
|
||||
command=start_command,
|
||||
cwd=cwd,
|
||||
env=env,
|
||||
use_pty=use_pty,
|
||||
)
|
||||
output = _CommandOutput(preview_limit_bytes=MAX_OUTPUT_PREVIEW_BYTES)
|
||||
wait_task = asyncio.create_task(process.wait())
|
||||
reader_tasks = [
|
||||
asyncio.create_task(
|
||||
self._read_stream(process.stdout, "stdout", output)
|
||||
),
|
||||
asyncio.create_task(
|
||||
self._read_stream(process.stderr, "stderr", output)
|
||||
),
|
||||
]
|
||||
return self._dump(payload)
|
||||
|
||||
timed_out = False
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
asyncio.shield(wait_task), timeout=normalized_timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
timed_out = True
|
||||
await self._cleanup_process(process, wait_task)
|
||||
if normalized_action == "read":
|
||||
payload = await terminal_session_manager.read(
|
||||
session_id=self._require_session_id(session_id),
|
||||
since_seq=since_seq,
|
||||
max_bytes=max_bytes,
|
||||
)
|
||||
return self._dump(payload)
|
||||
|
||||
try:
|
||||
await self._finish_reader_tasks(reader_tasks)
|
||||
finally:
|
||||
output.close()
|
||||
if normalized_action == "wait":
|
||||
payload = await terminal_session_manager.wait(
|
||||
session_id=self._require_session_id(session_id),
|
||||
timeout_ms=timeout_ms,
|
||||
since_seq=since_seq,
|
||||
max_bytes=max_bytes,
|
||||
)
|
||||
return self._dump(payload)
|
||||
|
||||
return self._format_result(
|
||||
exit_code=process.returncode,
|
||||
output=output,
|
||||
timeout=normalized_timeout,
|
||||
timed_out=timed_out,
|
||||
timeout_note=timeout_note,
|
||||
if normalized_action == "write":
|
||||
payload = await terminal_session_manager.write(
|
||||
session_id=self._require_session_id(session_id),
|
||||
input_text=input_text or "",
|
||||
)
|
||||
return self._dump(payload)
|
||||
|
||||
if normalized_action == "kill":
|
||||
payload = await terminal_session_manager.kill(
|
||||
session_id=self._require_session_id(session_id),
|
||||
sig=signal_name,
|
||||
)
|
||||
return self._dump(payload)
|
||||
|
||||
if normalized_action == "run":
|
||||
return await self._run_once(
|
||||
command=self._require_command(command),
|
||||
timeout=timeout,
|
||||
cwd=cwd,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"执行命令失败: {e}", exc_info=True)
|
||||
return f"执行命令时发生错误: {str(e)}"
|
||||
raise ValueError(f"不支持的 action: {action}")
|
||||
except Exception as err:
|
||||
logger.error("执行命令 action 失败: %s", err, exc_info=True)
|
||||
return self._dump({"error": str(err), "status": "error", "action": normalized_action})
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.chain.recommend import RecommendChain
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType, media_type_to_agent
|
||||
@@ -14,10 +15,8 @@ from app.schemas.types import MediaType, media_type_to_agent
|
||||
class GetRecommendationsInput(BaseModel):
|
||||
"""获取推荐工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
source: Optional[str] = Field(
|
||||
"tmdb_trending",
|
||||
description="Recommendation source: "
|
||||
@@ -46,6 +45,11 @@ class GetRecommendationsInput(BaseModel):
|
||||
|
||||
class GetRecommendationsTool(MoviePilotTool):
|
||||
name: str = "get_recommendations"
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.Media,
|
||||
ToolTag.Recommendation,
|
||||
]
|
||||
description: str = "Get trending and popular media recommendations from various sources. Returns curated lists of popular movies, TV shows, and anime based on different criteria like trending, ratings, or calendar schedules. Supports pagination with 20 items per page."
|
||||
args_schema: Type[BaseModel] = GetRecommendationsInput
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import List, Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.chain.search import SearchChain
|
||||
from app.log import logger
|
||||
from ._torrent_search_utils import (
|
||||
@@ -20,10 +21,8 @@ from ._torrent_search_utils import (
|
||||
class GetSearchResultsInput(BaseModel):
|
||||
"""获取搜索结果工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
site: Optional[List[str]] = Field(None, description="Site name filters")
|
||||
season: Optional[List[str]] = Field(None, description="Season or episode filters")
|
||||
free_state: Optional[List[str]] = Field(None, description="Promotion state filters")
|
||||
@@ -49,6 +48,10 @@ class GetSearchResultsInput(BaseModel):
|
||||
|
||||
class GetSearchResultsTool(MoviePilotTool):
|
||||
name: str = "get_search_results"
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.Resource,
|
||||
]
|
||||
description: str = "Get cached torrent search results from search_torrents with optional filters. Supports pagination with up to 50 results per page."
|
||||
args_schema: Type[BaseModel] = GetSearchResultsInput
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.agent.tools.impl._plugin_tool_utils import (
|
||||
get_plugin_snapshot,
|
||||
install_plugin_runtime,
|
||||
@@ -18,10 +19,8 @@ from app.log import logger
|
||||
class InstallPluginInput(BaseModel):
|
||||
"""安装插件工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
plugin_id: str = Field(
|
||||
...,
|
||||
description="Exact plugin ID to install. Use query_market_plugins first to find the correct plugin_id.",
|
||||
@@ -38,6 +37,11 @@ class InstallPluginInput(BaseModel):
|
||||
|
||||
class InstallPluginTool(MoviePilotTool):
|
||||
name: str = "install_plugin"
|
||||
tags: list[str] = [
|
||||
ToolTag.Write,
|
||||
ToolTag.Plugin,
|
||||
ToolTag.Admin,
|
||||
]
|
||||
description: str = (
|
||||
"Install a plugin by exact plugin_id from the plugin market or local plugin repositories. "
|
||||
"Use query_market_plugins first when you need filtering or discovery."
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.chain.storage import StorageChain
|
||||
from app.log import logger
|
||||
from app.schemas.file import FileItem
|
||||
@@ -16,7 +17,7 @@ from app.utils.string import StringUtils
|
||||
|
||||
class ListDirectoryInput(BaseModel):
|
||||
"""查询文件系统目录内容工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
explanation: Optional[str] = Field(None, description="Clear explanation of why this tool is being used in the current context")
|
||||
path: str = Field(..., description="Directory path to list contents (e.g., '/home/user/downloads' or 'C:/Downloads')")
|
||||
storage: Optional[str] = Field("local", description="Storage type (default: 'local' for local file system, can be 'smb', 'alist', etc.)")
|
||||
sort_by: Optional[str] = Field("name", description="Sort order: 'name' for alphabetical sorting, 'time' for modification time sorting (default: 'name')")
|
||||
@@ -24,6 +25,11 @@ class ListDirectoryInput(BaseModel):
|
||||
|
||||
class ListDirectoryTool(MoviePilotTool):
|
||||
name: str = "list_directory"
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.Directory,
|
||||
ToolTag.File,
|
||||
]
|
||||
description: str = "List actual files and folders in a file system directory (NOT configuration). Shows files and subdirectories with their names, types, sizes, and modification times. Returns up to 20 items and the total count if there are more items. Use 'query_directory_settings' to query directory configuration settings."
|
||||
args_schema: Type[BaseModel] = ListDirectoryInput
|
||||
|
||||
|
||||
@@ -6,20 +6,24 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class ListSlashCommandsInput(BaseModel):
|
||||
"""查询所有可用斜杠命令工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
|
||||
|
||||
class ListSlashCommandsTool(MoviePilotTool):
|
||||
name: str = "list_slash_commands"
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.SlashCommand,
|
||||
ToolTag.Admin,
|
||||
]
|
||||
description: str = (
|
||||
"List all available slash commands in the system, including system preset commands "
|
||||
"(e.g. /cookiecloud, /sites, /subscribes, /downloading, /transfer, /restart, etc.) "
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Optional, Type, List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.chain.download import DownloadChain
|
||||
from app.log import logger
|
||||
|
||||
@@ -12,10 +13,8 @@ from app.log import logger
|
||||
class ModifyDownloadInput(BaseModel):
|
||||
"""修改下载任务工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
hash: str = Field(
|
||||
..., description="Task hash (can be obtained from query_download_tasks tool)"
|
||||
)
|
||||
@@ -39,6 +38,11 @@ class ModifyDownloadTool(MoviePilotTool):
|
||||
"""修改下载任务工具"""
|
||||
|
||||
name: str = "modify_download"
|
||||
tags: list[str] = [
|
||||
ToolTag.Write,
|
||||
ToolTag.Download,
|
||||
ToolTag.Admin,
|
||||
]
|
||||
description: str = (
|
||||
"Modify a download task in the downloader by task hash. "
|
||||
"Supports: 1) Setting tags on a download task, "
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Optional, Type, List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.agent.tools.impl._filter_rule_utils import (
|
||||
get_builtin_rules,
|
||||
serialize_builtin_rule,
|
||||
@@ -17,10 +18,8 @@ from app.log import logger
|
||||
class QueryBuiltinFilterRulesInput(BaseModel):
|
||||
"""查询内置过滤规则工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
rule_ids: Optional[List[str]] = Field(
|
||||
None,
|
||||
description="Optional list of built-in rule IDs to query. If omitted, return all built-in rules.",
|
||||
@@ -29,6 +28,10 @@ class QueryBuiltinFilterRulesInput(BaseModel):
|
||||
|
||||
class QueryBuiltinFilterRulesTool(MoviePilotTool):
|
||||
name: str = "query_builtin_filter_rules"
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.FilterRule,
|
||||
]
|
||||
description: str = (
|
||||
"Query built-in filter rules defined by the backend filter module. "
|
||||
"These rule IDs can be used directly inside rule_string expressions for filter rule groups. "
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Optional, Type, List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.agent.tools.impl._filter_rule_utils import (
|
||||
collect_custom_rule_group_refs,
|
||||
get_custom_rules,
|
||||
@@ -18,10 +19,8 @@ from app.log import logger
|
||||
class QueryCustomFilterRulesInput(BaseModel):
|
||||
"""查询自定义过滤规则工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
rule_ids: Optional[List[str]] = Field(
|
||||
None,
|
||||
description="Optional list of custom rule IDs to query. If omitted, return all custom rules.",
|
||||
@@ -34,6 +33,10 @@ class QueryCustomFilterRulesInput(BaseModel):
|
||||
|
||||
class QueryCustomFilterRulesTool(MoviePilotTool):
|
||||
name: str = "query_custom_filter_rules"
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.FilterRule,
|
||||
]
|
||||
description: str = (
|
||||
"Query custom filter rules stored in CustomFilterRules. "
|
||||
"Custom rules can be referenced from rule_string expressions in filter rule groups. "
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.log import logger
|
||||
from app.schemas.types import SystemConfigKey
|
||||
@@ -14,14 +15,17 @@ from app.schemas.types import SystemConfigKey
|
||||
class QueryCustomIdentifiersInput(BaseModel):
|
||||
"""查询自定义识别词工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
|
||||
|
||||
class QueryCustomIdentifiersTool(MoviePilotTool):
|
||||
name: str = "query_custom_identifiers"
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.FilterRule,
|
||||
ToolTag.Admin,
|
||||
]
|
||||
description: str = (
|
||||
"Query all currently configured custom identifiers (自定义识别词). "
|
||||
"Returns the list of identifier rules used for preprocessing torrent/file names before media recognition. "
|
||||
|
||||
@@ -6,13 +6,14 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.helper.directory import DirectoryHelper
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QueryDirectorySettingsInput(BaseModel):
|
||||
"""查询系统目录设置工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
explanation: Optional[str] = Field(None, description="Clear explanation of why this tool is being used in the current context")
|
||||
directory_type: Optional[str] = Field("all",
|
||||
description="Filter directories by type: 'download' for download directories, 'library' for media library directories, 'all' for all directories")
|
||||
storage_type: Optional[str] = Field("all",
|
||||
@@ -23,6 +24,12 @@ class QueryDirectorySettingsInput(BaseModel):
|
||||
|
||||
class QueryDirectorySettingsTool(MoviePilotTool):
|
||||
name: str = "query_directory_settings"
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.Directory,
|
||||
ToolTag.Settings,
|
||||
ToolTag.Admin,
|
||||
]
|
||||
description: str = "Query system directory configuration settings (NOT file listings). Returns configured directory paths, storage types, transfer modes, and other directory-related settings. Use 'list_directory' to list actual files and folders in a directory."
|
||||
require_admin: bool = True
|
||||
args_schema: Type[BaseModel] = QueryDirectorySettingsInput
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional, Type, Union
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.chain.download import DownloadChain
|
||||
from app.db.downloadhistory_oper import DownloadHistoryOper
|
||||
from app.log import logger
|
||||
@@ -15,7 +16,7 @@ from app.schemas.types import TorrentStatus, media_type_to_agent
|
||||
|
||||
class QueryDownloadTasksInput(BaseModel):
|
||||
"""查询下载工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
explanation: Optional[str] = Field(None, description="Clear explanation of why this tool is being used in the current context")
|
||||
downloader: Optional[str] = Field(None,
|
||||
description="Name of specific downloader to query (optional, if not provided queries all configured downloaders)")
|
||||
status: Optional[str] = Field("all",
|
||||
@@ -27,6 +28,10 @@ class QueryDownloadTasksInput(BaseModel):
|
||||
|
||||
class QueryDownloadTasksTool(MoviePilotTool):
|
||||
name: str = "query_download_tasks"
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.Download,
|
||||
]
|
||||
description: str = "Query download status and list download tasks. Can query all active downloads, or search for specific tasks by hash, title, or tag. Shows download progress, completion status, tags, and task details from configured downloaders."
|
||||
args_schema: Type[BaseModel] = QueryDownloadTasksInput
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.log import logger
|
||||
from app.schemas.types import SystemConfigKey
|
||||
@@ -13,11 +14,16 @@ from app.schemas.types import SystemConfigKey
|
||||
|
||||
class QueryDownloadersInput(BaseModel):
|
||||
"""查询下载器工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
explanation: Optional[str] = Field(None, description="Clear explanation of why this tool is being used in the current context")
|
||||
|
||||
|
||||
class QueryDownloadersTool(MoviePilotTool):
|
||||
name: str = "query_downloaders"
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.Download,
|
||||
ToolTag.Admin,
|
||||
]
|
||||
description: str = "Query downloader configuration and list all available downloaders. Shows downloader status, connection details, and configuration settings."
|
||||
require_admin: bool = True
|
||||
args_schema: Type[BaseModel] = QueryDownloadersInput
|
||||
|
||||
@@ -6,13 +6,14 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.chain.tmdb import TmdbChain
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QueryEpisodeScheduleInput(BaseModel):
|
||||
"""查询剧集上映时间工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
explanation: Optional[str] = Field(None, description="Clear explanation of why this tool is being used in the current context")
|
||||
tmdb_id: int = Field(..., description="TMDB ID of the TV series (can be obtained from search_media tool)")
|
||||
season: int = Field(..., description="Season number to query")
|
||||
episode_group: Optional[str] = Field(None, description="Episode group ID (optional)")
|
||||
@@ -20,6 +21,10 @@ class QueryEpisodeScheduleInput(BaseModel):
|
||||
|
||||
class QueryEpisodeScheduleTool(MoviePilotTool):
|
||||
name: str = "query_episode_schedule"
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.Media,
|
||||
]
|
||||
description: str = "Query TV series episode air dates and schedule. Returns non-duplicated schedule fields, including episode list, air-date statistics, and per-episode metadata. Filters out episodes without air dates."
|
||||
args_schema: Type[BaseModel] = QueryEpisodeScheduleInput
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.agent.tools.impl._plugin_tool_utils import (
|
||||
DEFAULT_PLUGIN_CANDIDATE_LIMIT,
|
||||
MAX_PLUGIN_CANDIDATE_LIMIT,
|
||||
@@ -20,10 +21,8 @@ from app.log import logger
|
||||
class QueryInstalledPluginsInput(BaseModel):
|
||||
"""查询已安装插件工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
query: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional keyword to filter installed plugins by plugin ID, name, description, or author.",
|
||||
@@ -36,6 +35,11 @@ class QueryInstalledPluginsInput(BaseModel):
|
||||
|
||||
class QueryInstalledPluginsTool(MoviePilotTool):
|
||||
name: str = "query_installed_plugins"
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.Plugin,
|
||||
ToolTag.Admin,
|
||||
]
|
||||
description: str = (
|
||||
"Query installed plugins in MoviePilot. Returns all installed plugins or filters them by keywords. "
|
||||
"Use this tool to find the exact plugin_id before uninstall_plugin or other plugin management tools are used."
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import Optional, Type, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.chain.mediaserver import MediaServerChain
|
||||
from app.helper.mediaserver import MediaServerHelper
|
||||
from app.log import logger
|
||||
@@ -76,7 +77,7 @@ def _build_tv_server_result(existing_seasons: OrderedDict, total_seasons: Ordere
|
||||
|
||||
class QueryLibraryExistsInput(BaseModel):
|
||||
"""查询媒体库工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
explanation: Optional[str] = Field(None, description="Clear explanation of why this tool is being used in the current context")
|
||||
tmdb_id: Optional[int] = Field(None, description="TMDB ID (can be obtained from search_media tool). Either tmdb_id or douban_id must be provided.")
|
||||
douban_id: Optional[str] = Field(None, description="Douban ID (can be obtained from search_media tool). Either tmdb_id or douban_id must be provided.")
|
||||
media_type: Optional[str] = Field(None, description="Allowed values: movie, tv")
|
||||
@@ -84,6 +85,11 @@ class QueryLibraryExistsInput(BaseModel):
|
||||
|
||||
class QueryLibraryExistsTool(MoviePilotTool):
|
||||
name: str = "query_library_exists"
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.Library,
|
||||
ToolTag.Media,
|
||||
]
|
||||
description: str = "Check whether media already exists in Plex, Emby, or Jellyfin by media ID. Results are grouped by media server; TV results include existing episodes, total episodes, and missing episodes/seasons. Requires tmdb_id or douban_id from search_media."
|
||||
args_schema: Type[BaseModel] = QueryLibraryExistsInput
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.chain.mediaserver import MediaServerChain
|
||||
from app.helper.service import ServiceConfigHelper
|
||||
from app.log import logger
|
||||
@@ -17,10 +18,8 @@ PAGE_SIZE = 20
|
||||
class QueryLibraryLatestInput(BaseModel):
|
||||
"""查询媒体服务器最近入库影片工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
server: Optional[str] = Field(
|
||||
None,
|
||||
description="Media server name (optional, if not specified queries all enabled media servers)",
|
||||
@@ -32,6 +31,11 @@ class QueryLibraryLatestInput(BaseModel):
|
||||
|
||||
class QueryLibraryLatestTool(MoviePilotTool):
|
||||
name: str = "query_library_latest"
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.Library,
|
||||
ToolTag.Media,
|
||||
]
|
||||
description: str = "Query the latest media items added to the media server (Plex, Emby, Jellyfin). Returns recently added movies and TV series with their titles, images, links, and other metadata. Supports pagination with 20 items per page."
|
||||
args_schema: Type[BaseModel] = QueryLibraryLatestInput
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.agent.tools.impl._plugin_tool_utils import (
|
||||
DEFAULT_PLUGIN_CANDIDATE_LIMIT,
|
||||
MAX_PLUGIN_CANDIDATE_LIMIT,
|
||||
@@ -20,10 +21,8 @@ from app.log import logger
|
||||
class QueryMarketPluginsInput(BaseModel):
|
||||
"""查询插件市场工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
query: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional keyword to filter plugin market results by plugin ID, name, description, or author.",
|
||||
@@ -40,6 +39,11 @@ class QueryMarketPluginsInput(BaseModel):
|
||||
|
||||
class QueryMarketPluginsTool(MoviePilotTool):
|
||||
name: str = "query_market_plugins"
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.Plugin,
|
||||
ToolTag.Admin,
|
||||
]
|
||||
description: str = (
|
||||
"Query available plugins from the plugin market and local plugin repositories. "
|
||||
"Can return the full plugin list or filter by keywords before install_plugin is used."
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.chain.media import MediaChain
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType
|
||||
@@ -17,7 +18,7 @@ SEASON_PREVIEW_LIMIT = 100
|
||||
|
||||
class QueryMediaDetailInput(BaseModel):
|
||||
"""查询媒体详情工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
explanation: Optional[str] = Field(None, description="Clear explanation of why this tool is being used in the current context")
|
||||
tmdb_id: Optional[int] = Field(None, description="TMDB ID of the media (movie or TV series, can be obtained from search_media tool)")
|
||||
douban_id: Optional[str] = Field(None, description="Douban ID of the media (alternative to tmdb_id)")
|
||||
media_type: str = Field(..., description="Allowed values: movie, tv")
|
||||
@@ -25,6 +26,10 @@ class QueryMediaDetailInput(BaseModel):
|
||||
|
||||
class QueryMediaDetailTool(MoviePilotTool):
|
||||
name: str = "query_media_detail"
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.Media,
|
||||
]
|
||||
description: str = "Query supplementary media details from TMDB by ID and media_type. Accepts tmdb_id or douban_id (at least one required). media_type accepts 'movie' or 'tv'. Returns non-duplicated detail fields such as status, genres, directors, actors, and season info for TV series."
|
||||
args_schema: Type[BaseModel] = QueryMediaDetailInput
|
||||
|
||||
|
||||
@@ -7,16 +7,15 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.runtime import agent_runtime_manager
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QueryPersonasInput(BaseModel):
|
||||
"""查询人格工具的输入参数模型。"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
query: Optional[str] = Field(
|
||||
None,
|
||||
description=(
|
||||
@@ -28,6 +27,10 @@ class QueryPersonasInput(BaseModel):
|
||||
|
||||
class QueryPersonasTool(MoviePilotTool):
|
||||
name: str = "query_personas"
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.Persona,
|
||||
]
|
||||
description: str = (
|
||||
"List all available personas (人格) and show which one is currently active. "
|
||||
"Use this before switching persona when the user asks for a different speaking style but does not name "
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.core.plugin import PluginManager
|
||||
from app.log import logger
|
||||
|
||||
@@ -13,10 +14,8 @@ from app.log import logger
|
||||
class QueryPluginCapabilitiesInput(BaseModel):
|
||||
"""查询插件能力工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
plugin_id: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional plugin ID to query capabilities for a specific plugin. "
|
||||
@@ -27,6 +26,11 @@ class QueryPluginCapabilitiesInput(BaseModel):
|
||||
|
||||
class QueryPluginCapabilitiesTool(MoviePilotTool):
|
||||
name: str = "query_plugin_capabilities"
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.Plugin,
|
||||
ToolTag.Admin,
|
||||
]
|
||||
description: str = (
|
||||
"Query the capabilities of installed plugins, including supported commands and scheduled services. "
|
||||
"Commands are slash-commands (e.g. /xxx) that can be executed via the run_slash_command tool. "
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.agent.tools.impl._plugin_tool_utils import get_plugin_snapshot
|
||||
from app.core.plugin import PluginManager
|
||||
from app.log import logger
|
||||
@@ -14,10 +15,8 @@ from app.log import logger
|
||||
class QueryPluginConfigInput(BaseModel):
|
||||
"""查询插件配置工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
plugin_id: str = Field(
|
||||
...,
|
||||
description="The plugin ID to query. Use query_installed_plugins first to discover valid plugin IDs.",
|
||||
@@ -26,6 +25,11 @@ class QueryPluginConfigInput(BaseModel):
|
||||
|
||||
class QueryPluginConfigTool(MoviePilotTool):
|
||||
name: str = "query_plugin_config"
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.Plugin,
|
||||
ToolTag.Admin,
|
||||
]
|
||||
description: str = (
|
||||
"Query the saved configuration of an installed plugin. "
|
||||
"Returns the current saved config and, when available, the plugin's default config model. "
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.agent.tools.impl._plugin_tool_utils import (
|
||||
PLUGIN_DATA_KEY_PREVIEW_LIMIT,
|
||||
build_preview_payload,
|
||||
@@ -18,10 +19,8 @@ from app.log import logger
|
||||
class QueryPluginDataInput(BaseModel):
|
||||
"""查询插件数据工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
plugin_id: str = Field(
|
||||
...,
|
||||
description="The plugin ID to query. Use query_installed_plugins first to discover valid plugin IDs.",
|
||||
@@ -38,6 +37,11 @@ class QueryPluginDataInput(BaseModel):
|
||||
|
||||
class QueryPluginDataTool(MoviePilotTool):
|
||||
name: str = "query_plugin_data"
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.Plugin,
|
||||
ToolTag.Admin,
|
||||
]
|
||||
description: str = (
|
||||
"Query persisted data of an installed plugin. "
|
||||
"Optionally specify a key to read a single data item; otherwise all plugin data entries are returned. "
|
||||
|
||||
@@ -7,8 +7,9 @@ import cn2an
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.core.context import MediaInfo
|
||||
from app.helper.subscribe import SubscribeHelper
|
||||
from app.helper.server import MoviePilotServerHelper
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType, media_type_to_agent
|
||||
|
||||
@@ -17,7 +18,7 @@ MAX_PAGE_SIZE = 50
|
||||
|
||||
class QueryPopularSubscribesInput(BaseModel):
|
||||
"""查询热门订阅工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
explanation: Optional[str] = Field(None, description="Clear explanation of why this tool is being used in the current context")
|
||||
media_type: str = Field(..., description="Allowed values: movie, tv")
|
||||
page: Optional[int] = Field(1, description="Page number for pagination (default: 1)")
|
||||
count: Optional[int] = Field(30, description="Number of items per page (default: 30, max: 50)")
|
||||
@@ -30,6 +31,11 @@ class QueryPopularSubscribesInput(BaseModel):
|
||||
|
||||
class QueryPopularSubscribesTool(MoviePilotTool):
|
||||
name: str = "query_popular_subscribes"
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.Subscription,
|
||||
ToolTag.Recommendation,
|
||||
]
|
||||
description: str = "Query popular subscriptions based on user shared data. Shows media with the most subscribers, supports filtering by genre, rating, minimum subscribers, and pagination."
|
||||
args_schema: Type[BaseModel] = QueryPopularSubscribesInput
|
||||
|
||||
@@ -77,8 +83,7 @@ class QueryPopularSubscribesTool(MoviePilotTool):
|
||||
if not media_type_enum:
|
||||
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'"
|
||||
|
||||
subscribe_helper = SubscribeHelper()
|
||||
subscribes = await subscribe_helper.async_get_statistic(
|
||||
subscribes = await MoviePilotServerHelper.async_get_subscribe_statistic(
|
||||
stype=media_type_enum.to_agent(),
|
||||
page=page,
|
||||
count=count,
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Optional, Type, List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.agent.tools.impl._filter_rule_utils import (
|
||||
collect_rule_group_usages,
|
||||
get_rule_groups,
|
||||
@@ -18,10 +19,8 @@ from app.log import logger
|
||||
class QueryRuleGroupsInput(BaseModel):
|
||||
"""查询规则组工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
group_names: Optional[List[str]] = Field(
|
||||
None,
|
||||
description="Optional list of rule group names to query. If omitted, return all rule groups.",
|
||||
@@ -34,6 +33,10 @@ class QueryRuleGroupsInput(BaseModel):
|
||||
|
||||
class QueryRuleGroupsTool(MoviePilotTool):
|
||||
name: str = "query_rule_groups"
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.FilterRule,
|
||||
]
|
||||
description: str = (
|
||||
"Query filter rule groups (过滤规则组 / 优先级规则组). "
|
||||
"Each rule group contains a rule_string made of built-in rules and/or custom rules. "
|
||||
|
||||
@@ -6,16 +6,21 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QuerySchedulersInput(BaseModel):
|
||||
"""查询定时服务工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
explanation: Optional[str] = Field(None, description="Clear explanation of why this tool is being used in the current context")
|
||||
|
||||
|
||||
class QuerySchedulersTool(MoviePilotTool):
|
||||
name: str = "query_schedulers"
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.Scheduler,
|
||||
]
|
||||
description: str = "Query scheduled tasks and list all available scheduler jobs. Shows job status, next run time, and provider information."
|
||||
args_schema: Type[BaseModel] = QuerySchedulersInput
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.db import AsyncSessionFactory
|
||||
from app.db.models.site import Site
|
||||
from app.db.models.siteuserdata import SiteUserData
|
||||
@@ -23,10 +24,8 @@ def _preview_list(value, limit: int = SITE_USERDATA_DETAIL_PREVIEW_LIMIT) -> tup
|
||||
class QuerySiteUserdataInput(BaseModel):
|
||||
"""查询站点用户数据工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
site_id: int = Field(
|
||||
...,
|
||||
description="The ID of the site to query user data for (can be obtained from query_sites tool)",
|
||||
@@ -39,6 +38,11 @@ class QuerySiteUserdataInput(BaseModel):
|
||||
|
||||
class QuerySiteUserdataTool(MoviePilotTool):
|
||||
name: str = "query_site_userdata"
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.Site,
|
||||
ToolTag.Admin,
|
||||
]
|
||||
description: str = "Query user data for a specific site including username, user level, upload/download statistics, seeding information, bonus points, and other account details. Supports querying data for a specific date or latest data."
|
||||
require_admin: bool = True
|
||||
args_schema: Type[BaseModel] = QuerySiteUserdataInput
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.db.site_oper import SiteOper
|
||||
from app.log import logger
|
||||
|
||||
@@ -13,10 +14,8 @@ from app.log import logger
|
||||
class QuerySitesInput(BaseModel):
|
||||
"""查询站点工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
status: Optional[str] = Field(
|
||||
"all",
|
||||
description="Filter sites by status: 'active' for enabled sites, 'inactive' for disabled sites, 'all' for all sites",
|
||||
@@ -28,6 +27,11 @@ class QuerySitesInput(BaseModel):
|
||||
|
||||
class QuerySitesTool(MoviePilotTool):
|
||||
name: str = "query_sites"
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.Site,
|
||||
ToolTag.Admin,
|
||||
]
|
||||
description: str = "Query site status and list all configured sites. Shows site name, domain, status, priority, and basic configuration. Site priority (pri): smaller values have higher priority (e.g., pri=1 has higher priority than pri=10)."
|
||||
require_admin: bool = True
|
||||
args_schema: Type[BaseModel] = QuerySitesInput
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.db import AsyncSessionFactory
|
||||
from app.db.models.subscribehistory import SubscribeHistory
|
||||
from app.log import logger
|
||||
@@ -17,10 +18,8 @@ PAGE_SIZE = 20
|
||||
class QuerySubscribeHistoryInput(BaseModel):
|
||||
"""查询订阅历史工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
media_type: Optional[str] = Field(
|
||||
"all", description="Allowed values: movie, tv, all"
|
||||
)
|
||||
@@ -35,6 +34,10 @@ class QuerySubscribeHistoryInput(BaseModel):
|
||||
|
||||
class QuerySubscribeHistoryTool(MoviePilotTool):
|
||||
name: str = "query_subscribe_history"
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.Subscription,
|
||||
]
|
||||
description: str = "Query subscription history records. Shows completed subscriptions with their details including name, type, rating, completion date, and other subscription information. Supports filtering by media type and name. Supports pagination with 20 records per page."
|
||||
args_schema: Type[BaseModel] = QuerySubscribeHistoryInput
|
||||
|
||||
|
||||
@@ -6,7 +6,8 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.helper.subscribe import SubscribeHelper
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.helper.server import MoviePilotServerHelper
|
||||
from app.log import logger
|
||||
|
||||
MAX_PAGE_SIZE = 50
|
||||
@@ -14,7 +15,7 @@ MAX_PAGE_SIZE = 50
|
||||
|
||||
class QuerySubscribeSharesInput(BaseModel):
|
||||
"""查询订阅分享工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
explanation: Optional[str] = Field(None, description="Clear explanation of why this tool is being used in the current context")
|
||||
name: Optional[str] = Field(None, description="Filter shares by media name (partial match, optional)")
|
||||
page: Optional[int] = Field(1, description="Page number for pagination (default: 1)")
|
||||
count: Optional[int] = Field(30, description="Number of items per page (default: 30, max: 50)")
|
||||
@@ -26,6 +27,10 @@ class QuerySubscribeSharesInput(BaseModel):
|
||||
|
||||
class QuerySubscribeSharesTool(MoviePilotTool):
|
||||
name: str = "query_subscribe_shares"
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.Subscription,
|
||||
]
|
||||
description: str = "Query shared subscriptions from other users. Shows popular subscriptions shared by the community with filtering and pagination support."
|
||||
args_schema: Type[BaseModel] = QuerySubscribeSharesInput
|
||||
|
||||
@@ -68,8 +73,7 @@ class QuerySubscribeSharesTool(MoviePilotTool):
|
||||
# 订阅分享是外部列表型结果,限制单页大小能降低工具上下文占用。
|
||||
count = min(count, MAX_PAGE_SIZE)
|
||||
|
||||
subscribe_helper = SubscribeHelper()
|
||||
shares = await subscribe_helper.async_get_shares(
|
||||
shares = await MoviePilotServerHelper.async_get_subscribe_shares(
|
||||
name=name,
|
||||
page=page,
|
||||
count=count,
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.db.subscribe_oper import SubscribeOper
|
||||
from app.log import logger
|
||||
from app.schemas.subscribe import Subscribe as SubscribeSchema
|
||||
@@ -33,6 +34,7 @@ QUERY_SUBSCRIBE_OUTPUT_FIELDS = [
|
||||
"sites",
|
||||
"downloader",
|
||||
"best_version",
|
||||
"best_version_full",
|
||||
"current_priority",
|
||||
"episode_priority",
|
||||
"save_path",
|
||||
@@ -46,10 +48,8 @@ QUERY_SUBSCRIBE_OUTPUT_FIELDS = [
|
||||
class QuerySubscribesInput(BaseModel):
|
||||
"""查询订阅工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
status: Optional[str] = Field(
|
||||
"all",
|
||||
description="Filter subscriptions by status: 'R' for enabled subscriptions, 'S' for paused ones, 'all' for all subscriptions",
|
||||
@@ -72,6 +72,10 @@ class QuerySubscribesInput(BaseModel):
|
||||
|
||||
class QuerySubscribesTool(MoviePilotTool):
|
||||
name: str = "query_subscribes"
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.Subscription,
|
||||
]
|
||||
description: str = "Query subscription status and list user subscriptions. Returns full subscription parameters for each matched subscription. Supports pagination with 100 items per page."
|
||||
args_schema: Type[BaseModel] = QuerySubscribesInput
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.agent.tools.impl._system_setting_utils import (
|
||||
SettingSpec,
|
||||
list_setting_specs,
|
||||
@@ -19,10 +20,8 @@ from app.log import logger
|
||||
class QuerySystemSettingsInput(BaseModel):
|
||||
"""查询系统设置工具的输入参数模型。"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
setting_key: Optional[str] = Field(
|
||||
None,
|
||||
description=(
|
||||
@@ -58,6 +57,12 @@ class QuerySystemSettingsInput(BaseModel):
|
||||
|
||||
class QuerySystemSettingsTool(MoviePilotTool):
|
||||
name: str = "query_system_settings"
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.System,
|
||||
ToolTag.Settings,
|
||||
ToolTag.Admin,
|
||||
]
|
||||
description: str = (
|
||||
"Query system settings across both the basic Settings module and all SystemConfig-backed categories. "
|
||||
"Use this tool to inspect downloaders, media servers, notification channels, storages, directories, search-site ranges, "
|
||||
|
||||
@@ -3,19 +3,20 @@
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
import jieba
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.db import AsyncSessionFactory
|
||||
from app.db.models.transferhistory import TransferHistory
|
||||
from app.log import logger
|
||||
from app.schemas.types import media_type_to_agent
|
||||
from app.utils.jieba import cut as jieba_cut
|
||||
|
||||
|
||||
class QueryTransferHistoryInput(BaseModel):
|
||||
"""查询整理历史记录工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
explanation: Optional[str] = Field(None, description="Clear explanation of why this tool is being used in the current context")
|
||||
title: Optional[str] = Field(None, description="Search by title (optional, supports partial match)")
|
||||
status: Optional[str] = Field("all",
|
||||
description="Filter by status: 'success' for successful transfers, 'failed' for failed transfers, 'all' for all records (default: 'all')")
|
||||
@@ -24,6 +25,10 @@ class QueryTransferHistoryInput(BaseModel):
|
||||
|
||||
class QueryTransferHistoryTool(MoviePilotTool):
|
||||
name: str = "query_transfer_history"
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.Transfer,
|
||||
]
|
||||
description: str = "Query file transfer history records. Shows transfer status, source and destination paths, media information, and transfer details. Supports filtering by title and status."
|
||||
args_schema: Type[BaseModel] = QueryTransferHistoryInput
|
||||
|
||||
@@ -69,8 +74,8 @@ class QueryTransferHistoryTool(MoviePilotTool):
|
||||
async with AsyncSessionFactory() as db:
|
||||
# 处理标题搜索
|
||||
if title:
|
||||
# 使用 jieba 分词处理标题
|
||||
words = jieba.cut(title, HMM=False)
|
||||
# 使用统一分词封装处理标题,便于替换底层实现。
|
||||
words = jieba_cut(title, HMM=False)
|
||||
title_search = "%".join(words)
|
||||
# 查询记录
|
||||
result = await TransferHistory.async_list_by_title(
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.db import AsyncSessionFactory
|
||||
from app.db.workflow_oper import WorkflowOper
|
||||
from app.log import logger
|
||||
@@ -13,7 +14,7 @@ from app.log import logger
|
||||
|
||||
class QueryWorkflowsInput(BaseModel):
|
||||
"""查询工作流工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
explanation: Optional[str] = Field(None, description="Clear explanation of why this tool is being used in the current context")
|
||||
state: Optional[str] = Field("all", description="Filter workflows by state: 'W' for waiting, 'R' for running, 'P' for paused, 'S' for success, 'F' for failed, 'all' for all workflows (default: 'all')")
|
||||
name: Optional[str] = Field(None, description="Filter workflows by name (partial match, optional)")
|
||||
trigger_type: Optional[str] = Field("all", description="Filter workflows by trigger type: 'timer' for scheduled, 'event' for event-triggered, 'manual' for manual, 'all' for all types (default: 'all')")
|
||||
@@ -21,6 +22,10 @@ class QueryWorkflowsInput(BaseModel):
|
||||
|
||||
class QueryWorkflowsTool(MoviePilotTool):
|
||||
name: str = "query_workflows"
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.Workflow,
|
||||
]
|
||||
description: str = "Query workflow list and status. Shows workflow name, description, trigger type, state, execution count, and other workflow details. Supports filtering by state, name, and trigger type."
|
||||
args_schema: Type[BaseModel] = QueryWorkflowsInput
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from anyio import Path as AsyncPath
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.log import logger
|
||||
|
||||
# 最大读取大小 50KB
|
||||
@@ -22,6 +23,10 @@ class ReadFileInput(BaseModel):
|
||||
|
||||
class ReadFileTool(MoviePilotTool):
|
||||
name: str = "read_file"
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.File,
|
||||
]
|
||||
description: str = "Read the content of a text file. Supports reading by line range. Each read is limited to 50KB; content exceeding this limit will be truncated."
|
||||
args_schema: Type[BaseModel] = ReadFileInput
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.chain.media import MediaChain
|
||||
from app.core.context import Context
|
||||
from app.core.metainfo import MetaInfo
|
||||
@@ -15,7 +16,7 @@ from app.schemas.types import media_type_to_agent
|
||||
|
||||
class RecognizeMediaInput(BaseModel):
|
||||
"""识别媒体信息工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
explanation: Optional[str] = Field(None, description="Clear explanation of why this tool is being used in the current context")
|
||||
title: Optional[str] = Field(None, description="The title of the torrent/media to recognize (required for torrent recognition)")
|
||||
subtitle: Optional[str] = Field(None, description="The subtitle or description of the torrent (optional, helps improve recognition accuracy)")
|
||||
path: Optional[str] = Field(None, description="The file path to recognize (required for file recognition, mutually exclusive with title)")
|
||||
@@ -23,6 +24,11 @@ class RecognizeMediaInput(BaseModel):
|
||||
|
||||
class RecognizeMediaTool(MoviePilotTool):
|
||||
name: str = "recognize_media"
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.Media,
|
||||
ToolTag.Metadata,
|
||||
]
|
||||
description: str = "Extract/identify media information from torrent titles or file paths (NOT database search). Supports two modes: 1) Extract from torrent title and optional subtitle, 2) Extract from file path. Returns detailed media information. Use 'search_media' to search TMDB database, or 'scrape_metadata' to generate metadata files for existing files."
|
||||
args_schema: Type[BaseModel] = RecognizeMediaInput
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.agent.tools.impl._plugin_tool_utils import (
|
||||
get_plugin_snapshot,
|
||||
reload_plugin_runtime,
|
||||
@@ -16,10 +17,8 @@ from app.log import logger
|
||||
class ReloadPluginInput(BaseModel):
|
||||
"""重载插件工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
plugin_id: str = Field(
|
||||
...,
|
||||
description="The plugin ID to reload so the latest saved config takes effect.",
|
||||
@@ -28,6 +27,11 @@ class ReloadPluginInput(BaseModel):
|
||||
|
||||
class ReloadPluginTool(MoviePilotTool):
|
||||
name: str = "reload_plugin"
|
||||
tags: list[str] = [
|
||||
ToolTag.Write,
|
||||
ToolTag.Plugin,
|
||||
ToolTag.Admin,
|
||||
]
|
||||
description: str = (
|
||||
"Reload an installed plugin so its latest saved configuration takes effect. "
|
||||
"This also refreshes the plugin's registered commands, scheduled services, and API routes."
|
||||
|
||||
@@ -5,16 +5,15 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class RunSchedulerInput(BaseModel):
|
||||
"""运行定时服务工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
job_id: str = Field(
|
||||
...,
|
||||
description="The ID of the scheduled job to run (can be obtained from query_schedulers tool)",
|
||||
@@ -23,6 +22,11 @@ class RunSchedulerInput(BaseModel):
|
||||
|
||||
class RunSchedulerTool(MoviePilotTool):
|
||||
name: str = "run_scheduler"
|
||||
tags: list[str] = [
|
||||
ToolTag.Write,
|
||||
ToolTag.Scheduler,
|
||||
ToolTag.Admin,
|
||||
]
|
||||
description: str = "Manually trigger a scheduled task to run immediately. This will execute the specified scheduler job by its ID."
|
||||
args_schema: Type[BaseModel] = RunSchedulerInput
|
||||
require_admin: bool = True
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.core.event import eventmanager
|
||||
from app.log import logger
|
||||
from app.schemas.types import EventType, MessageChannel
|
||||
@@ -14,10 +15,8 @@ from app.schemas.types import EventType, MessageChannel
|
||||
class RunSlashCommandInput(BaseModel):
|
||||
"""运行斜杠命令工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
command: str = Field(
|
||||
...,
|
||||
description="The slash command to execute, e.g. '/cookiecloud'. "
|
||||
@@ -29,6 +28,11 @@ class RunSlashCommandInput(BaseModel):
|
||||
|
||||
class RunSlashCommandTool(MoviePilotTool):
|
||||
name: str = "run_slash_command"
|
||||
tags: list[str] = [
|
||||
ToolTag.Write,
|
||||
ToolTag.SlashCommand,
|
||||
ToolTag.Admin,
|
||||
]
|
||||
description: str = (
|
||||
"Execute a slash command (system or plugin) by sending a CommandExcute event. "
|
||||
"This tool supports ALL registered slash commands, including: "
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.chain.workflow import WorkflowChain
|
||||
from app.db import AsyncSessionFactory
|
||||
from app.db.workflow_oper import WorkflowOper
|
||||
@@ -14,10 +15,8 @@ from app.log import logger
|
||||
class RunWorkflowInput(BaseModel):
|
||||
"""执行工作流工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
workflow_id: int = Field(
|
||||
..., description="Workflow ID (can be obtained from query_workflows tool)"
|
||||
)
|
||||
@@ -29,6 +28,11 @@ class RunWorkflowInput(BaseModel):
|
||||
|
||||
class RunWorkflowTool(MoviePilotTool):
|
||||
name: str = "run_workflow"
|
||||
tags: list[str] = [
|
||||
ToolTag.Write,
|
||||
ToolTag.Workflow,
|
||||
ToolTag.Admin,
|
||||
]
|
||||
description: str = "Execute a specific workflow manually by workflow ID. Supports running from the beginning or continuing from the last executed action."
|
||||
args_schema: Type[BaseModel] = RunWorkflowInput
|
||||
require_admin: bool = True
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.chain.media import MediaChain
|
||||
from app.log import logger
|
||||
from app.schemas import FileItem
|
||||
@@ -15,10 +16,8 @@ from app.schemas import FileItem
|
||||
class ScrapeMetadataInput(BaseModel):
|
||||
"""刮削媒体元数据工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
path: str = Field(
|
||||
...,
|
||||
description="Path to the file or directory to scrape metadata for (e.g., '/path/to/file.mkv' or '/path/to/directory')",
|
||||
@@ -35,6 +34,13 @@ class ScrapeMetadataInput(BaseModel):
|
||||
|
||||
class ScrapeMetadataTool(MoviePilotTool):
|
||||
name: str = "scrape_metadata"
|
||||
tags: list[str] = [
|
||||
ToolTag.Write,
|
||||
ToolTag.Media,
|
||||
ToolTag.Metadata,
|
||||
ToolTag.File,
|
||||
ToolTag.Admin,
|
||||
]
|
||||
description: str = "Generate metadata files (NFO files, posters, backgrounds, etc.) for existing media files or directories. Automatically recognizes media information from the file path and creates metadata files. Supports both local and remote storage. Use 'search_media' to search TMDB database, or 'recognize_media' to extract info from torrent titles/file paths without generating files."
|
||||
require_admin: bool = True
|
||||
args_schema: Type[BaseModel] = ScrapeMetadataInput
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.chain.media import MediaChain
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType, media_type_to_agent
|
||||
@@ -13,7 +14,7 @@ from app.schemas.types import MediaType, media_type_to_agent
|
||||
|
||||
class SearchMediaInput(BaseModel):
|
||||
"""搜索媒体工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
explanation: Optional[str] = Field(None, description="Clear explanation of why this tool is being used in the current context")
|
||||
title: str = Field(..., description="The title of the media to search for (e.g., 'The Matrix', 'Breaking Bad')")
|
||||
year: Optional[str] = Field(None, description="Release year of the media (optional, helps narrow down results)")
|
||||
media_type: Optional[str] = Field(None,
|
||||
@@ -24,6 +25,10 @@ class SearchMediaInput(BaseModel):
|
||||
|
||||
class SearchMediaTool(MoviePilotTool):
|
||||
name: str = "search_media"
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.Media,
|
||||
]
|
||||
description: str = "Search TMDB database for media resources (movies, TV shows, anime, etc.) by title, year, type, and other criteria. Returns detailed media information from TMDB. Use 'recognize_media' to extract info from torrent titles/file paths, or 'scrape_metadata' to generate metadata files."
|
||||
args_schema: Type[BaseModel] = SearchMediaInput
|
||||
|
||||
|
||||
@@ -6,18 +6,23 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.chain.media import MediaChain
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class SearchPersonInput(BaseModel):
|
||||
"""搜索人物工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
explanation: Optional[str] = Field(None, description="Clear explanation of why this tool is being used in the current context")
|
||||
name: str = Field(..., description="The name of the person to search for (e.g., 'Tom Hanks', '周杰伦')")
|
||||
|
||||
|
||||
class SearchPersonTool(MoviePilotTool):
|
||||
name: str = "search_person"
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.Media,
|
||||
]
|
||||
description: str = "Search for person information including actors, directors, etc. Supports searching by name. Returns detailed person information from TMDB, Douban, or Bangumi database."
|
||||
args_schema: Type[BaseModel] = SearchPersonInput
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.chain.douban import DoubanChain
|
||||
from app.chain.tmdb import TmdbChain
|
||||
from app.chain.bangumi import BangumiChain
|
||||
@@ -14,7 +15,7 @@ from app.log import logger
|
||||
|
||||
class SearchPersonCreditsInput(BaseModel):
|
||||
"""搜索演员参演作品工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
explanation: Optional[str] = Field(None, description="Clear explanation of why this tool is being used in the current context")
|
||||
person_id: int = Field(..., description="The ID of the person/actor to search for credits (e.g., 31 for Tom Hanks in TMDB)")
|
||||
source: str = Field(..., description="The data source: 'tmdb' for TheMovieDB, 'douban' for Douban, 'bangumi' for Bangumi")
|
||||
page: Optional[int] = Field(1, description="Page number for pagination (default: 1)")
|
||||
@@ -22,6 +23,10 @@ class SearchPersonCreditsInput(BaseModel):
|
||||
|
||||
class SearchPersonCreditsTool(MoviePilotTool):
|
||||
name: str = "search_person_credits"
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.Media,
|
||||
]
|
||||
description: str = "Search for films and TV shows that a person/actor has appeared in (filmography). Supports searching by person ID from TMDB, Douban, or Bangumi database. Returns a list of media works the person has participated in."
|
||||
args_schema: Type[BaseModel] = SearchPersonCreditsInput
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Optional, Type, List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.chain.subscribe import SubscribeChain
|
||||
from app.db.subscribe_oper import SubscribeOper
|
||||
from app.log import logger
|
||||
@@ -14,7 +15,7 @@ from app.schemas.types import media_type_to_agent
|
||||
|
||||
class SearchSubscribeInput(BaseModel):
|
||||
"""搜索订阅缺失剧集工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
explanation: Optional[str] = Field(None, description="Clear explanation of why this tool is being used in the current context")
|
||||
subscribe_id: int = Field(..., description="The ID of the subscription to search for missing episodes (can be obtained from query_subscribes tool)")
|
||||
manual: Optional[bool] = Field(False, description="Whether this is a manual search (default: False)")
|
||||
filter_groups: Optional[List[str]] = Field(None,
|
||||
@@ -23,6 +24,12 @@ class SearchSubscribeInput(BaseModel):
|
||||
|
||||
class SearchSubscribeTool(MoviePilotTool):
|
||||
name: str = "search_subscribe"
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.Write,
|
||||
ToolTag.Subscription,
|
||||
ToolTag.Resource,
|
||||
]
|
||||
description: str = "Search for missing episodes/resources for a specific subscription. This tool will search torrent sites for the missing episodes of the subscription and automatically download matching resources. Use this when a user wants to search for missing episodes of a specific subscription."
|
||||
args_schema: Type[BaseModel] = SearchSubscribeInput
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import List, Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.chain.search import SearchChain
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.helper.sites import SitesHelper
|
||||
@@ -19,7 +20,7 @@ from ._torrent_search_utils import (
|
||||
|
||||
class SearchTorrentsInput(BaseModel):
|
||||
"""搜索种子工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
explanation: Optional[str] = Field(None, description="Clear explanation of why this tool is being used in the current context")
|
||||
tmdb_id: Optional[int] = Field(None, description="TMDB ID (can be obtained from search_media tool). Either tmdb_id or douban_id must be provided.")
|
||||
douban_id: Optional[str] = Field(None, description="Douban ID (can be obtained from search_media tool). Either tmdb_id or douban_id must be provided.")
|
||||
media_type: Optional[str] = Field(None, description="Allowed values: movie, tv")
|
||||
@@ -29,6 +30,12 @@ class SearchTorrentsInput(BaseModel):
|
||||
|
||||
class SearchTorrentsTool(MoviePilotTool):
|
||||
name: str = "search_torrents"
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.Resource,
|
||||
ToolTag.Site,
|
||||
ToolTag.Media,
|
||||
]
|
||||
description: str = ("Search for torrent files by media ID across configured indexer sites, cache the matched results, "
|
||||
"and return available filter options for follow-up selection. "
|
||||
"Requires tmdb_id or douban_id (can be obtained from search_media tool) for accurate matching.")
|
||||
|
||||
@@ -1,78 +1,169 @@
|
||||
import asyncio
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
from typing import Optional, Type, List, Dict
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Type
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
from ddgs import DDGS
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
|
||||
# 搜索超时时间(秒)
|
||||
SEARCH_TIMEOUT = 20
|
||||
# 单次搜索最多返回结果数
|
||||
MAX_SEARCH_RESULTS = 20
|
||||
# 默认搜索源
|
||||
DEFAULT_SEARCH_ENGINE = "auto"
|
||||
# 可显式调用的搜索引擎后端
|
||||
SEARCH_ENGINE_BACKENDS = (
|
||||
"auto",
|
||||
"duckduckgo",
|
||||
"google",
|
||||
"brave",
|
||||
"yahoo",
|
||||
"wikipedia",
|
||||
"yandex",
|
||||
"mojeek",
|
||||
)
|
||||
SUPPORTED_SEARCH_ENGINES = SEARCH_ENGINE_BACKENDS
|
||||
DDGS_AUTO_BACKEND = ",".join(
|
||||
backend for backend in SEARCH_ENGINE_BACKENDS if backend != DEFAULT_SEARCH_ENGINE
|
||||
)
|
||||
SITE_SEARCH_PATTERN = re.compile(r"\bsite:", re.IGNORECASE)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _SearchSiteFilter:
|
||||
"""站点限定搜索参数"""
|
||||
|
||||
domain: str
|
||||
path: str
|
||||
search_target: str
|
||||
|
||||
|
||||
class SearchWebInput(BaseModel):
|
||||
"""搜索网络内容工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
explanation: Optional[str] = Field(
|
||||
None,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
query: str = Field(
|
||||
..., description="The search query string to search for on the web"
|
||||
)
|
||||
max_results: Optional[int] = Field(
|
||||
20,
|
||||
MAX_SEARCH_RESULTS,
|
||||
description="Maximum number of search results to return (default: 20, max: 20)",
|
||||
)
|
||||
search_engine: Optional[str] = Field(
|
||||
DEFAULT_SEARCH_ENGINE,
|
||||
description=(
|
||||
"Search backend to use. Supported values: auto, duckduckgo, google, "
|
||||
"brave, yahoo, wikipedia, yandex, mojeek. "
|
||||
"Use auto unless the user asks for a specific search engine."
|
||||
),
|
||||
)
|
||||
site_url: Optional[str] = Field(
|
||||
None,
|
||||
description=(
|
||||
"Optional website/domain/URL to limit the search to, for example "
|
||||
"'https://docs.python.org/3/' or 'github.com'."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class SearchWebTool(MoviePilotTool):
|
||||
"""
|
||||
网络搜索工具,支持 DDGS 搜索引擎和指定站点限定搜索。
|
||||
"""
|
||||
|
||||
name: str = "search_web"
|
||||
description: str = "Search the web for information when you need to find current information, facts, or references that you're uncertain about. Returns search results with titles, snippets, and URLs. Use this tool to get up-to-date information from the internet."
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.Web,
|
||||
]
|
||||
description: str = (
|
||||
"Search the web for information when you need current information, facts, "
|
||||
"or references. Supports DDGS-backed search engine selection, automatic "
|
||||
"fallback, and site_url-limited searches for a specified website "
|
||||
"or URL. Uses the configured system proxy by default. Returns search "
|
||||
"results with titles, snippets, and URLs."
|
||||
)
|
||||
args_schema: Type[BaseModel] = SearchWebInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据搜索参数生成友好的提示消息"""
|
||||
query = kwargs.get("query", "")
|
||||
max_results = kwargs.get("max_results", 20)
|
||||
return f"搜索网络内容: {query} (最多返回 {max_results} 条结果)"
|
||||
max_results = kwargs.get("max_results", MAX_SEARCH_RESULTS)
|
||||
search_engine = self._normalize_search_engine(kwargs.get("search_engine"))
|
||||
site_url = kwargs.get("site_url")
|
||||
message = f"搜索网络内容: {query} (最多返回 {max_results} 条结果"
|
||||
if search_engine != DEFAULT_SEARCH_ENGINE:
|
||||
message += f",搜索源: {search_engine}"
|
||||
if site_url:
|
||||
message += f",限定站点: {site_url}"
|
||||
return f"{message})"
|
||||
|
||||
async def run(self, query: str, max_results: Optional[int] = 20, **kwargs) -> str:
|
||||
async def run(
|
||||
self,
|
||||
query: str,
|
||||
max_results: Optional[int] = MAX_SEARCH_RESULTS,
|
||||
search_engine: Optional[str] = DEFAULT_SEARCH_ENGINE,
|
||||
site_url: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
执行网络搜索
|
||||
执行网络搜索。
|
||||
|
||||
:param query: 搜索关键词
|
||||
:param max_results: 最大返回结果数
|
||||
:param search_engine: 指定搜索源,默认自动选择
|
||||
:param site_url: 指定站点或网址,传入时只返回该范围内的搜索结果
|
||||
:return: JSON格式的搜索结果或错误信息
|
||||
"""
|
||||
search_engine = self._normalize_search_engine(search_engine)
|
||||
if search_engine not in SUPPORTED_SEARCH_ENGINES:
|
||||
supported = ", ".join(SUPPORTED_SEARCH_ENGINES)
|
||||
return f"错误: 不支持的搜索源 '{search_engine}',支持的搜索源: {supported}"
|
||||
|
||||
site_filter = self._normalize_site_filter(site_url)
|
||||
if site_url and not site_filter:
|
||||
return f"错误: site_url 无效,无法限定搜索范围: {site_url}"
|
||||
|
||||
search_query = self._build_search_query(query=query, site_filter=site_filter)
|
||||
if not search_query:
|
||||
return "错误: query 不能为空"
|
||||
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: query={query}, max_results={max_results}"
|
||||
f"执行工具: {self.name}, 参数: query={query}, "
|
||||
f"max_results={max_results}, search_engine={search_engine}, site_url={site_url}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 限制最大结果数
|
||||
max_results = min(max(1, max_results or 20), 20)
|
||||
results = []
|
||||
max_results = min(
|
||||
max(1, max_results or MAX_SEARCH_RESULTS),
|
||||
MAX_SEARCH_RESULTS,
|
||||
)
|
||||
results: List[Dict] = []
|
||||
|
||||
# 1. 优先使用 Exa (如果配置了 API Key)
|
||||
if settings.EXA_API_KEY:
|
||||
logger.info("使用 Exa 进行搜索...")
|
||||
results = await self._search_exa(query, max_results)
|
||||
|
||||
# 2. 如果没有结果或未配置 Exa,使用 Tavily (如果配置了 API Key)
|
||||
if not results and settings.TAVILY_API_KEY:
|
||||
logger.info("使用 Tavily 进行搜索...")
|
||||
results = await self._search_tavily(query, max_results)
|
||||
|
||||
# 3. 如果没有结果或未配置 Tavily,使用 DuckDuckGo
|
||||
if not results:
|
||||
logger.info("使用 DuckDuckGo 进行搜索...")
|
||||
results = await self._search_duckduckgo(query, max_results)
|
||||
for engine in self._get_search_plan(search_engine):
|
||||
results = await self._search_with_backend(
|
||||
engine=engine,
|
||||
query=search_query,
|
||||
max_results=max_results,
|
||||
site_filter=site_filter,
|
||||
)
|
||||
if results:
|
||||
break
|
||||
|
||||
if not results:
|
||||
return f"未找到与 '{query}' 相关的搜索结果"
|
||||
return f"未找到与 '{search_query}' 相关的搜索结果"
|
||||
|
||||
# 格式化并裁剪结果
|
||||
formatted_results = self._format_and_truncate_results(results, max_results)
|
||||
@@ -84,81 +175,214 @@ class SearchWebTool(MoviePilotTool):
|
||||
return error_message
|
||||
|
||||
@staticmethod
|
||||
async def _search_tavily(query: str, max_results: int) -> List[Dict]:
|
||||
"""使用 Tavily API 进行搜索"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=SEARCH_TIMEOUT) as client:
|
||||
# 从设置中随机选择一个 API Key(如果有多个)
|
||||
tavity_api_key = random.choice(settings.TAVILY_API_KEY)
|
||||
response = await client.post(
|
||||
"https://api.tavily.com/search",
|
||||
json={
|
||||
"api_key": tavity_api_key,
|
||||
"query": query,
|
||||
"search_depth": "basic",
|
||||
"max_results": max_results,
|
||||
"include_answer": False,
|
||||
"include_images": False,
|
||||
"include_raw_content": False,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
results = []
|
||||
for result in data.get("results", []):
|
||||
results.append(
|
||||
{
|
||||
"title": result.get("title", ""),
|
||||
"snippet": result.get("content", ""),
|
||||
"url": result.get("url", ""),
|
||||
"source": "Tavily",
|
||||
}
|
||||
)
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.warning(f"Tavily 搜索失败: {e}")
|
||||
return []
|
||||
def _normalize_search_engine(search_engine: Optional[str]) -> str:
|
||||
"""规范化搜索源参数"""
|
||||
engine = (search_engine or DEFAULT_SEARCH_ENGINE).strip().lower()
|
||||
aliases = {
|
||||
"ddgs": DEFAULT_SEARCH_ENGINE,
|
||||
"ddg": "duckduckgo",
|
||||
"duck": "duckduckgo",
|
||||
"search": DEFAULT_SEARCH_ENGINE,
|
||||
"search_engine": DEFAULT_SEARCH_ENGINE,
|
||||
}
|
||||
return aliases.get(engine, engine)
|
||||
|
||||
@staticmethod
|
||||
async def _search_exa(query: str, max_results: int) -> List[Dict]:
|
||||
"""使用 Exa API 进行搜索"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=SEARCH_TIMEOUT) as client:
|
||||
response = await client.post(
|
||||
"https://api.exa.ai/search",
|
||||
headers={
|
||||
"x-api-key": settings.EXA_API_KEY,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={
|
||||
"query": query,
|
||||
"numResults": max_results,
|
||||
"type": "auto",
|
||||
"contents": {"highlights": {"maxCharacters": 2000}},
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
def _get_search_plan(search_engine: str) -> List[str]:
|
||||
"""根据搜索源配置生成兜底搜索顺序"""
|
||||
if search_engine != DEFAULT_SEARCH_ENGINE:
|
||||
return [search_engine]
|
||||
return [DEFAULT_SEARCH_ENGINE]
|
||||
|
||||
results = []
|
||||
for result in data.get("results", []):
|
||||
highlights = result.get("highlights", [])
|
||||
snippet = (
|
||||
highlights[0] if highlights else result.get("text", "")[:500]
|
||||
)
|
||||
results.append(
|
||||
{
|
||||
"title": result.get("title", ""),
|
||||
"snippet": snippet,
|
||||
"url": result.get("url", ""),
|
||||
"source": "Exa",
|
||||
}
|
||||
)
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.warning(f"Exa 搜索失败: {e}")
|
||||
return []
|
||||
async def _search_with_backend(
|
||||
self,
|
||||
engine: str,
|
||||
query: str,
|
||||
max_results: int,
|
||||
site_filter: Optional[_SearchSiteFilter],
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
使用指定后端执行搜索。
|
||||
|
||||
:param engine: 搜索后端名称
|
||||
:param query: 已加工的搜索关键词
|
||||
:param max_results: 最大结果数
|
||||
:param site_filter: 站点限定条件
|
||||
:return: 搜索结果列表
|
||||
"""
|
||||
logger.info(f"使用 DDGS 搜索后端 {self._get_ddgs_backend(engine)} 进行搜索...")
|
||||
return await self._search_ddgs(query, max_results, engine, site_filter)
|
||||
|
||||
@staticmethod
|
||||
def _get_ddgs_backend(search_engine: str) -> str:
|
||||
"""
|
||||
获取实际传给 DDGS 的搜索后端。
|
||||
|
||||
:param search_engine: 用户指定的搜索源
|
||||
:return: DDGS 后端名称或逗号分隔的后端列表
|
||||
"""
|
||||
if search_engine == DEFAULT_SEARCH_ENGINE:
|
||||
return DDGS_AUTO_BACKEND
|
||||
return search_engine
|
||||
|
||||
@staticmethod
|
||||
def _normalize_site_filter(site_url: Optional[str]) -> Optional[_SearchSiteFilter]:
|
||||
"""
|
||||
将用户传入的网址转换为搜索引擎 site 过滤条件。
|
||||
|
||||
:param site_url: 用户传入的站点、域名或完整URL
|
||||
:return: 站点过滤条件,无法解析时返回 None
|
||||
"""
|
||||
if not site_url:
|
||||
return None
|
||||
|
||||
raw_site_url = site_url.strip()
|
||||
if not raw_site_url:
|
||||
return None
|
||||
|
||||
parse_target = raw_site_url
|
||||
if not re.match(r"^https?://", raw_site_url, re.IGNORECASE):
|
||||
parse_target = f"https://{raw_site_url}"
|
||||
|
||||
parsed = urlparse(parse_target)
|
||||
domain = (parsed.hostname or "").lower()
|
||||
if not domain:
|
||||
return None
|
||||
|
||||
path = re.sub(r"/+", "/", parsed.path or "").rstrip("/")
|
||||
search_target = f"{domain}{path}" if path else domain
|
||||
return _SearchSiteFilter(domain=domain, path=path, search_target=search_target)
|
||||
|
||||
@staticmethod
|
||||
def _build_search_query(
|
||||
query: str,
|
||||
site_filter: Optional[_SearchSiteFilter],
|
||||
) -> str:
|
||||
"""
|
||||
生成实际发送给搜索后端的搜索关键词。
|
||||
|
||||
:param query: 原始搜索关键词
|
||||
:param site_filter: 站点限定条件
|
||||
:return: 加入 site 过滤后的关键词
|
||||
"""
|
||||
search_query = (query or "").strip()
|
||||
if not site_filter or SITE_SEARCH_PATTERN.search(search_query):
|
||||
return search_query
|
||||
if not search_query:
|
||||
return f"site:{site_filter.search_target}"
|
||||
return f"{search_query} site:{site_filter.search_target}"
|
||||
|
||||
@staticmethod
|
||||
def _filter_results_by_site(
|
||||
results: List[Dict],
|
||||
site_filter: Optional[_SearchSiteFilter],
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
根据指定站点过滤搜索结果。
|
||||
|
||||
:param results: 原始搜索结果
|
||||
:param site_filter: 站点限定条件
|
||||
:return: 站点范围内的搜索结果
|
||||
"""
|
||||
if not site_filter:
|
||||
return results
|
||||
return [
|
||||
result
|
||||
for result in results
|
||||
if SearchWebTool._result_matches_site(result.get("url", ""), site_filter)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _result_matches_site(url: str, site_filter: _SearchSiteFilter) -> bool:
|
||||
"""
|
||||
判断搜索结果 URL 是否属于指定站点。
|
||||
|
||||
:param url: 搜索结果 URL
|
||||
:param site_filter: 站点限定条件
|
||||
:return: URL 属于指定站点时返回 True
|
||||
"""
|
||||
if not url:
|
||||
return False
|
||||
|
||||
parse_target = url
|
||||
if not re.match(r"^https?://", url, re.IGNORECASE):
|
||||
parse_target = f"https://{url}"
|
||||
|
||||
parsed = urlparse(parse_target)
|
||||
result_host = SearchWebTool._normalize_host(parsed.hostname or "")
|
||||
target_host = SearchWebTool._normalize_host(site_filter.domain)
|
||||
if not result_host or not target_host:
|
||||
return False
|
||||
if result_host != target_host and not result_host.endswith(f".{target_host}"):
|
||||
return False
|
||||
if not site_filter.path:
|
||||
return True
|
||||
|
||||
result_path = re.sub(r"/+", "/", parsed.path or "").rstrip("/")
|
||||
return result_path == site_filter.path or result_path.startswith(
|
||||
f"{site_filter.path}/"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_host(host: str) -> str:
|
||||
"""
|
||||
标准化域名以便比较。
|
||||
|
||||
:param host: 原始域名
|
||||
:return: 去掉常见 www 前缀后的域名
|
||||
"""
|
||||
normalized_host = (host or "").lower()
|
||||
if normalized_host.startswith("www."):
|
||||
return normalized_host[4:]
|
||||
return normalized_host
|
||||
|
||||
@staticmethod
|
||||
def _source_label(search_engine: str) -> str:
|
||||
"""
|
||||
将搜索源标识转换为结果中的展示名称。
|
||||
|
||||
:param search_engine: 搜索源标识
|
||||
:return: 展示名称
|
||||
"""
|
||||
labels = {
|
||||
"auto": "DDGS",
|
||||
"duckduckgo": "DuckDuckGo",
|
||||
"google": "Google",
|
||||
"brave": "Brave",
|
||||
"yahoo": "Yahoo",
|
||||
"wikipedia": "Wikipedia",
|
||||
"yandex": "Yandex",
|
||||
"mojeek": "Mojeek",
|
||||
}
|
||||
return labels.get(
|
||||
search_engine or DEFAULT_SEARCH_ENGINE,
|
||||
search_engine or "SearchEngine",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_result_url(result: Dict) -> str:
|
||||
"""
|
||||
从不同搜索引擎结果结构中提取 URL。
|
||||
|
||||
:param result: 搜索引擎返回的单条结果
|
||||
:return: URL 字符串
|
||||
"""
|
||||
return result.get("href") or result.get("url") or ""
|
||||
|
||||
@staticmethod
|
||||
def _extract_result_snippet(result: Dict) -> str:
|
||||
"""
|
||||
从不同搜索引擎结果结构中提取摘要。
|
||||
|
||||
:param result: 搜索引擎返回的单条结果
|
||||
:return: 摘要字符串
|
||||
"""
|
||||
return (
|
||||
result.get("body")
|
||||
or result.get("snippet")
|
||||
or result.get("content")
|
||||
or ""
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_proxy_url(proxy_setting) -> Optional[str]:
|
||||
@@ -169,11 +393,26 @@ class SearchWebTool(MoviePilotTool):
|
||||
return proxy_setting.get("http") or proxy_setting.get("https")
|
||||
return proxy_setting
|
||||
|
||||
async def _search_duckduckgo(self, query: str, max_results: int) -> List[Dict]:
|
||||
"""使用 duckduckgo-search (DDGS) 进行搜索"""
|
||||
async def _search_ddgs(
|
||||
self,
|
||||
query: str,
|
||||
max_results: int,
|
||||
search_engine: str = DEFAULT_SEARCH_ENGINE,
|
||||
site_filter: Optional[_SearchSiteFilter] = None,
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
使用 DDGS 搜索引擎后端进行搜索。
|
||||
|
||||
:param query: 搜索关键词
|
||||
:param max_results: 最大结果数
|
||||
:param search_engine: DDGS搜索后端
|
||||
:param site_filter: 站点限定条件
|
||||
:return: 搜索结果列表
|
||||
"""
|
||||
try:
|
||||
|
||||
def sync_search():
|
||||
"""在线程中执行同步搜索"""
|
||||
results = []
|
||||
ddgs_kwargs = {"timeout": SEARCH_TIMEOUT}
|
||||
proxy_url = self._get_proxy_url(settings.PROXY)
|
||||
@@ -182,26 +421,36 @@ class SearchWebTool(MoviePilotTool):
|
||||
|
||||
try:
|
||||
with DDGS(**ddgs_kwargs) as ddgs:
|
||||
ddgs_gen = ddgs.text(query, max_results=max_results)
|
||||
if ddgs_gen:
|
||||
for result in ddgs_gen:
|
||||
ddgs_results = ddgs.text(
|
||||
query,
|
||||
max_results=max_results,
|
||||
backend=self._get_ddgs_backend(search_engine),
|
||||
)
|
||||
if ddgs_results:
|
||||
for result in ddgs_results:
|
||||
source = (
|
||||
DEFAULT_SEARCH_ENGINE
|
||||
if search_engine == DEFAULT_SEARCH_ENGINE
|
||||
else search_engine
|
||||
)
|
||||
results.append(
|
||||
{
|
||||
"title": result.get("title", ""),
|
||||
"snippet": result.get("body", ""),
|
||||
"url": result.get("href", ""),
|
||||
"source": "DuckDuckGo",
|
||||
"snippet": self._extract_result_snippet(result),
|
||||
"url": self._extract_result_url(result),
|
||||
"source": self._source_label(source),
|
||||
}
|
||||
)
|
||||
except Exception as err:
|
||||
logger.warning(f"DuckDuckGo search process failed: {err}")
|
||||
logger.warning(f"搜索引擎搜索进程失败: {err}")
|
||||
return results
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(None, sync_search)
|
||||
results = await loop.run_in_executor(None, sync_search)
|
||||
return self._filter_results_by_site(results, site_filter)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"DuckDuckGo 搜索失败: {e}")
|
||||
logger.warning(f"搜索引擎搜索失败: {e}")
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool, ToolChain
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.log import logger
|
||||
from app.schemas import Notification, NotificationType
|
||||
from app.schemas.message import ChannelCapabilityManager, ChannelCapability
|
||||
@@ -15,10 +16,8 @@ from app.schemas.types import MessageChannel
|
||||
class SendLocalFileInput(BaseModel):
|
||||
"""发送本地附件工具输入。"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why sending this local file helps the user",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why sending this local file helps the user",)
|
||||
file_path: str = Field(
|
||||
...,
|
||||
description="Absolute path to the local image or file to send to the user",
|
||||
@@ -45,6 +44,11 @@ class SendLocalFileInput(BaseModel):
|
||||
|
||||
class SendLocalFileTool(MoviePilotTool):
|
||||
name: str = "send_local_file"
|
||||
tags: list[str] = [
|
||||
ToolTag.Write,
|
||||
ToolTag.Message,
|
||||
ToolTag.File,
|
||||
]
|
||||
sends_message: bool = True
|
||||
description: str = (
|
||||
"Send a local image or file from the server filesystem to the current user. "
|
||||
|
||||
@@ -5,16 +5,15 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class SendMessageInput(BaseModel):
|
||||
"""发送消息工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
message: Optional[str] = Field(
|
||||
None,
|
||||
description="The message content to send to the user (should be clear and informative)",
|
||||
@@ -37,6 +36,11 @@ class SendMessageInput(BaseModel):
|
||||
|
||||
class SendMessageTool(MoviePilotTool):
|
||||
name: str = "send_message"
|
||||
tags: list[str] = [
|
||||
ToolTag.Write,
|
||||
ToolTag.Message,
|
||||
ToolTag.Admin,
|
||||
]
|
||||
sends_message: bool = True
|
||||
description: str = "Send notification message to the user through configured notification channels (Telegram, Slack, WeChat, etc.). Supports optional image_url on channels that can send images. Used to inform users about operation results, errors, important updates, or proactively send a relevant image."
|
||||
args_schema: Type[BaseModel] = SendMessageInput
|
||||
|
||||
@@ -5,9 +5,10 @@ from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.llm.capability import AgentCapabilityManager
|
||||
from app.agent.tools.base import MoviePilotTool, ToolChain
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.core.config import settings
|
||||
from app.helper.voice import VoiceHelper
|
||||
from app.log import logger
|
||||
from app.schemas import Notification, NotificationType
|
||||
|
||||
@@ -15,8 +16,8 @@ from app.schemas import Notification, NotificationType
|
||||
class SendVoiceMessageInput(BaseModel):
|
||||
"""发送语音消息工具输入。"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
explanation: Optional[str] = Field(
|
||||
None,
|
||||
description="Clear explanation of why a voice reply is the best fit in the current context",
|
||||
)
|
||||
message: str = Field(
|
||||
@@ -26,54 +27,65 @@ class SendVoiceMessageInput(BaseModel):
|
||||
|
||||
|
||||
class SendVoiceMessageTool(MoviePilotTool):
|
||||
"""发送 Agent 语音回复的工具。"""
|
||||
|
||||
name: str = "send_voice_message"
|
||||
tags: list[str] = [
|
||||
ToolTag.Write,
|
||||
ToolTag.Message,
|
||||
ToolTag.TerminalResponse,
|
||||
]
|
||||
sends_message: bool = True
|
||||
return_direct: bool = True
|
||||
description: str = (
|
||||
"Send a voice reply to the current user. Use this only when the user explicitly asks for "
|
||||
"a voice reply or when spoken playback is clearly better than plain text. On channels "
|
||||
"without voice support or when TTS is unavailable, it automatically falls back to sending "
|
||||
"the same content as plain text."
|
||||
"the same content as plain text. This is a terminal response tool: put the complete "
|
||||
"user-facing reply in `message`; after this tool runs, do not send another text reply "
|
||||
"or call `send_message` with the same content."
|
||||
)
|
||||
args_schema: Type[BaseModel] = SendVoiceMessageInput
|
||||
require_admin: bool = False
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""生成语音回复工具的执行提示。"""
|
||||
message = kwargs.get("message") or ""
|
||||
if len(message) > 40:
|
||||
message = message[:40] + "..."
|
||||
return f"发送语音回复: {message}"
|
||||
|
||||
async def run(self, message: str, **kwargs) -> str:
|
||||
"""合成语音并发送到当前对话渠道,不支持时回退为文字。"""
|
||||
if not message:
|
||||
return "语音回复内容不能为空"
|
||||
|
||||
voice_path = None
|
||||
used_voice = False
|
||||
channel = self._channel or ""
|
||||
reply_mode = VoiceHelper.resolve_reply_mode(
|
||||
reply_mode = AgentCapabilityManager.resolve_reply_mode(
|
||||
channel=channel,
|
||||
source=self._source,
|
||||
)
|
||||
fallback_reason = "当前渠道不支持语音回复"
|
||||
if not VoiceHelper.is_enabled():
|
||||
fallback_reason = "当前未启用音频输入输出"
|
||||
if not AgentCapabilityManager.supports_audio_output():
|
||||
fallback_reason = "当前未启用音频输出"
|
||||
if (
|
||||
reply_mode == VoiceHelper.REPLY_MODE_NATIVE
|
||||
and VoiceHelper.is_available("tts")
|
||||
reply_mode == AgentCapabilityManager.REPLY_MODE_NATIVE
|
||||
and AgentCapabilityManager.is_audio_output_available()
|
||||
):
|
||||
voice_file = await asyncio.to_thread(VoiceHelper.synthesize_speech, message)
|
||||
voice_file = await asyncio.to_thread(
|
||||
AgentCapabilityManager.synthesize_speech, message
|
||||
)
|
||||
if voice_file:
|
||||
voice_path = str(voice_file)
|
||||
used_voice = True
|
||||
elif reply_mode == VoiceHelper.REPLY_MODE_NATIVE:
|
||||
elif reply_mode == AgentCapabilityManager.REPLY_MODE_NATIVE:
|
||||
fallback_reason = "当前未配置可用的语音合成能力"
|
||||
|
||||
logger.info(
|
||||
"执行工具: %s, channel=%s, use_voice=%s, text_len=%s",
|
||||
self.name,
|
||||
channel,
|
||||
used_voice,
|
||||
len(message),
|
||||
f"执行工具: {self.name}, channel={channel}, "
|
||||
f"use_voice={used_voice}, text_len={len(message)}"
|
||||
)
|
||||
|
||||
await ToolChain().async_post_message(
|
||||
@@ -87,7 +99,7 @@ class SendVoiceMessageTool(MoviePilotTool):
|
||||
voice_path=voice_path,
|
||||
voice_caption=(
|
||||
message
|
||||
if voice_path and settings.AI_VOICE_REPLY_WITH_TEXT
|
||||
if voice_path and settings.AUDIO_OUTPUT_INCLUDE_TEXT
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1,22 +1,21 @@
|
||||
"""切换当前激活人格工具。"""
|
||||
|
||||
import json
|
||||
from typing import Type
|
||||
from typing import Type, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.runtime import agent_runtime_manager
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class SwitchPersonaInput(BaseModel):
|
||||
"""切换人格工具的输入参数模型。"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
persona_id: str = Field(
|
||||
...,
|
||||
description=(
|
||||
@@ -28,6 +27,10 @@ class SwitchPersonaInput(BaseModel):
|
||||
|
||||
class SwitchPersonaTool(MoviePilotTool):
|
||||
name: str = "switch_persona"
|
||||
tags: list[str] = [
|
||||
ToolTag.Write,
|
||||
ToolTag.Persona,
|
||||
]
|
||||
description: str = (
|
||||
"Switch the active persona (人格) used by the agent runtime. "
|
||||
"This change is persistent for future turns. "
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.chain.site import SiteChain
|
||||
from app.db.site_oper import SiteOper
|
||||
from app.log import logger
|
||||
@@ -12,12 +13,16 @@ from app.log import logger
|
||||
|
||||
class TestSiteInput(BaseModel):
|
||||
"""测试站点连通性工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
explanation: Optional[str] = Field(None, description="Clear explanation of why this tool is being used in the current context")
|
||||
site_identifier: int = Field(..., description="Site ID to test (can be obtained from query_sites tool)")
|
||||
|
||||
|
||||
class TestSiteTool(MoviePilotTool):
|
||||
name: str = "test_site"
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.Site,
|
||||
]
|
||||
description: str = "Test site connectivity and availability. This will check if a site is accessible and can be logged in. Accepts site ID only."
|
||||
args_schema: Type[BaseModel] = TestSiteInput
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.log import logger
|
||||
from app.schemas import FileItem, MediaType
|
||||
|
||||
@@ -13,10 +14,8 @@ from app.schemas import FileItem, MediaType
|
||||
class TransferFileInput(BaseModel):
|
||||
"""整理文件或目录工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
file_path: str = Field(
|
||||
...,
|
||||
description="Path to the file or directory to transfer (e.g., '/path/to/file.mkv' or '/path/to/directory')",
|
||||
@@ -56,6 +55,13 @@ class TransferFileInput(BaseModel):
|
||||
|
||||
class TransferFileTool(MoviePilotTool):
|
||||
name: str = "transfer_file"
|
||||
tags: list[str] = [
|
||||
ToolTag.Write,
|
||||
ToolTag.Transfer,
|
||||
ToolTag.Library,
|
||||
ToolTag.File,
|
||||
ToolTag.Admin,
|
||||
]
|
||||
description: str = "Transfer/organize a file or directory to the media library. Automatically recognizes media information and organizes files according to configured rules. Supports custom target paths, media identification, and transfer modes."
|
||||
args_schema: Type[BaseModel] = TransferFileInput
|
||||
require_admin: bool = True
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.agent.tools.impl._plugin_tool_utils import (
|
||||
list_installed_plugins,
|
||||
summarize_plugin,
|
||||
@@ -17,10 +18,8 @@ from app.log import logger
|
||||
class UninstallPluginInput(BaseModel):
|
||||
"""卸载插件工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
plugin_id: str = Field(
|
||||
...,
|
||||
description="Exact plugin ID to uninstall. Use query_installed_plugins first to find the correct plugin_id.",
|
||||
@@ -29,6 +28,11 @@ class UninstallPluginInput(BaseModel):
|
||||
|
||||
class UninstallPluginTool(MoviePilotTool):
|
||||
name: str = "uninstall_plugin"
|
||||
tags: list[str] = [
|
||||
ToolTag.Write,
|
||||
ToolTag.Plugin,
|
||||
ToolTag.Admin,
|
||||
]
|
||||
description: str = (
|
||||
"Uninstall an installed plugin by exact plugin_id. "
|
||||
"Use query_installed_plugins first when you need filtering or discovery."
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.agent.tools.impl._filter_rule_utils import (
|
||||
collect_custom_rule_group_refs,
|
||||
get_custom_rules,
|
||||
@@ -22,10 +23,8 @@ from app.schemas.types import SystemConfigKey
|
||||
class UpdateCustomFilterRuleInput(BaseModel):
|
||||
"""更新自定义过滤规则工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
current_rule_id: str = Field(
|
||||
..., description="Existing custom rule ID to update."
|
||||
)
|
||||
@@ -60,6 +59,11 @@ class UpdateCustomFilterRuleInput(BaseModel):
|
||||
|
||||
class UpdateCustomFilterRuleTool(MoviePilotTool):
|
||||
name: str = "update_custom_filter_rule"
|
||||
tags: list[str] = [
|
||||
ToolTag.Write,
|
||||
ToolTag.FilterRule,
|
||||
ToolTag.Admin,
|
||||
]
|
||||
description: str = (
|
||||
"Update an existing custom filter rule. "
|
||||
"If the rule ID is renamed, all rule groups that reference the old ID are updated automatically."
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import List, Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.log import logger
|
||||
from app.schemas.types import SystemConfigKey
|
||||
@@ -14,10 +15,8 @@ from app.schemas.types import SystemConfigKey
|
||||
class UpdateCustomIdentifiersInput(BaseModel):
|
||||
"""更新自定义识别词工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
identifiers: List[str] = Field(
|
||||
...,
|
||||
description=(
|
||||
@@ -35,6 +34,11 @@ class UpdateCustomIdentifiersInput(BaseModel):
|
||||
|
||||
class UpdateCustomIdentifiersTool(MoviePilotTool):
|
||||
name: str = "update_custom_identifiers"
|
||||
tags: list[str] = [
|
||||
ToolTag.Write,
|
||||
ToolTag.FilterRule,
|
||||
ToolTag.Admin,
|
||||
]
|
||||
description: str = (
|
||||
"Update the full list of custom identifiers (自定义识别词) used for preprocessing torrent/file names. "
|
||||
"This tool REPLACES all existing identifier rules with the provided list. "
|
||||
@@ -50,7 +54,8 @@ class UpdateCustomIdentifiersTool(MoviePilotTool):
|
||||
"3) Episode offset: '前定位词 <> 后定位词 >> EP±N'; "
|
||||
"4) Combined: '被替换词 => 替换词 && 前定位词 <> 后定位词 >> EP±N'; "
|
||||
"Lines starting with '#' are comments. "
|
||||
"The replacement target supports: {[tmdbid=xxx;type=movie/tv;s=xxx;e=xxx]} for direct TMDB ID matching."
|
||||
"The replacement target supports: {[tmdbid=xxx;type=movie/tv;g=xxx;s=xxx;e=xxx]} "
|
||||
"for direct TMDB ID matching; g is an optional TMDB episode group ID for TV recognition."
|
||||
)
|
||||
require_admin: bool = True
|
||||
args_schema: Type[BaseModel] = UpdateCustomIdentifiersInput
|
||||
|
||||
@@ -7,16 +7,15 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.runtime import agent_runtime_manager
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class UpdatePersonaDefinitionInput(BaseModel):
|
||||
"""更新人格定义工具的输入参数模型。"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
persona_id: str = Field(
|
||||
...,
|
||||
description=(
|
||||
@@ -58,6 +57,11 @@ class UpdatePersonaDefinitionInput(BaseModel):
|
||||
|
||||
class UpdatePersonaDefinitionTool(MoviePilotTool):
|
||||
name: str = "update_persona_definition"
|
||||
tags: list[str] = [
|
||||
ToolTag.Write,
|
||||
ToolTag.Persona,
|
||||
ToolTag.Admin,
|
||||
]
|
||||
description: str = (
|
||||
"Create or update a runtime persona definition (人格定义) without manually editing PERSONA.md files. "
|
||||
"Use this when the user explicitly asks to modify how a persona is defined, such as changing tone rules, "
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.agent.tools.impl._plugin_tool_utils import get_plugin_snapshot
|
||||
from app.core.plugin import PluginManager
|
||||
from app.log import logger
|
||||
@@ -14,10 +15,8 @@ from app.log import logger
|
||||
class UpdatePluginConfigInput(BaseModel):
|
||||
"""修改插件配置工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
explanation: Optional[str] = Field(None,
|
||||
description="Clear explanation of why this tool is being used in the current context",)
|
||||
plugin_id: str = Field(
|
||||
...,
|
||||
description="The plugin ID to update. Use query_plugin_config first to inspect the current config.",
|
||||
@@ -44,6 +43,11 @@ class UpdatePluginConfigInput(BaseModel):
|
||||
|
||||
class UpdatePluginConfigTool(MoviePilotTool):
|
||||
name: str = "update_plugin_config"
|
||||
tags: list[str] = [
|
||||
ToolTag.Write,
|
||||
ToolTag.Plugin,
|
||||
ToolTag.Admin,
|
||||
]
|
||||
description: str = (
|
||||
"Update the saved configuration of an installed plugin. "
|
||||
"By default this performs a partial merge update and does NOT reload the plugin automatically. "
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user