Compare commits

...

164 Commits

Author SHA1 Message Date
jxxghp
48b1ac28de v2.8.2
- 新增 `MoviePilot助手` 智能体,支持自然语言对话完成任务(Beta版本,设置中打开开关并配置好大模型参数,通过 `/ai` 发送聊天内容),支持通过插件完善智能体能力
- 适配了新版本的飞牛影视
- 其它问题修复与细节改进

注意:基础组件升级,个别插件可能会有兼容性问题,需要插件适配。
2025-11-17 14:00:39 +08:00
jxxghp
6e329b17a9 Enhance Telegram message formatting: add detailed guidelines for MarkdownV2 usage, including support for strikethrough, headings, and lists. Implement smart escaping for Markdown to preserve formatting while avoiding API errors. 2025-11-17 13:49:56 +08:00
jxxghp
6a492198a8 fix post_message 2025-11-17 13:33:01 +08:00
jxxghp
8bf9b6e7cb feat:Agent插件工具发现 2025-11-17 13:00:23 +08:00
jxxghp
42e23ef564 Refactor agent workflows: streamline subscription and download processes, enhance query status workflow, and improve tool usage guidelines for better user interaction. 2025-11-17 12:49:03 +08:00
jxxghp
c6806ee648 fix agent tools 2025-11-17 12:34:20 +08:00
jxxghp
076fae696c fix 2025-11-17 11:57:46 +08:00
jxxghp
ed294d3ea4 Revert "fix schemas"
This reverts commit a5e7483870.
2025-11-17 11:48:18 +08:00
jxxghp
043be409d0 Enhance agent workflows and tools: unify subscription and download processes, add site querying functionality, and improve error handling in download operations. 2025-11-17 11:39:08 +08:00
jxxghp
a5e7483870 fix schemas 2025-11-17 10:58:24 +08:00
jxxghp
365335be46 rollback 2025-11-17 10:51:16 +08:00
jxxghp
62543dd171 fix:优化Agent消息发送格式 2025-11-17 10:43:16 +08:00
jxxghp
e2eef8ff21 fix agent message title 2025-11-17 10:18:05 +08:00
jxxghp
3acf937d56 fix add_download tool 2025-11-17 10:16:54 +08:00
jxxghp
d572e523ba 优化Agent上下文大小 2025-11-17 09:57:12 +08:00
jxxghp
82113abe88 fix agent sendmsg 2025-11-17 09:42:27 +08:00
jxxghp
b7d121c58f fix agent tools 2025-11-17 09:28:18 +08:00
jxxghp
6d5a85b144 fix search tools 2025-11-17 09:14:36 +08:00
jxxghp
78121917c6 Merge pull request #5112 from wikrin/fix_tests 2025-11-12 20:39:38 +08:00
jxxghp
a0913f0e32 Merge pull request #5109 from jiongjiongJOJO/dev 2025-11-12 20:39:10 +08:00
jxxghp
e96e284715 Merge pull request #5107 from wikrin/fix 2025-11-12 20:38:40 +08:00
Attente
c572a1b607 fix(tests): 修正 restype, 测试用例不使用识别词 2025-11-12 14:13:05 +08:00
囧囧JOJO
1845311f98 fix: 修复Docker编译时版本不兼容导致的报错问题
参考三楼回复:
https://stackoverflow.com/questions/76717537/valueerror-requirement-object-has-no-field-use-pep517-when-installing-pytho
2025-11-11 17:46:34 +08:00
Attente
4f806db8b7 fix: 修复变更默认下载器不生效的问题
- 配置模块迁移到 `SettingsConfigDict` 以支持 Pydantic v2 的配置方式
- 在 `MediaInfo` 中新增 `release_dates` 字段,用于存储多地区发行日期信息
- 修改 `MetaVideo` 类中的 token 传递逻辑,以修复搜索站点资源序列化错误的问题
2025-11-11 10:44:45 +08:00
jxxghp
22858cc1e9 Merge pull request #5100 from Seed680/v2 2025-11-06 18:43:41 +08:00
noone
a0329a3eb0 feat(rss): 支持自定义User-Agent获取RSS。目前有些站点没有配置UA时会不能正确获取RSS内容
- 在RSS方法中新增ua参数用于指定User-Agent
- 更新RequestUtils调用以传递自定义User-Agent
- 修改torrents链中RSS解析逻辑以支持站点配置的ua字段
- 设置默认超时时间为30秒以增强稳定性
2025-11-06 16:32:01 +08:00
jxxghp
b3e92088ee Merge pull request #5097 from wumode/refector-check_method 2025-11-05 23:15:24 +08:00
jxxghp
46db1c20f1 Merge pull request #5096 from cddjr/fix_trimemedia_cookies 2025-11-05 23:14:18 +08:00
wumode
9d182e53b2 fix: type hints 2025-11-05 15:41:31 +08:00
景大侠
1205fc7fdb 避免不必要的图片cookies查询 2025-11-05 15:22:02 +08:00
wumode
ff2826a448 feat(utils): Refactor check_method to use ast
- 使用 AST 解析函数源码,相比基于字符串的方法更稳定,能够正确处理具有多行 def 语句的函数
- 为 check_method 添加了单元测试
2025-11-05 14:16:37 +08:00
大虾
ee750115ec Update
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-11-04 13:29:45 +08:00
景大侠
0e13d22c97 fix 适配新版飞牛影视 2025-11-04 13:25:18 +08:00
jxxghp
8e7d040ac4 Merge pull request #5091 from wikrin/cached 2025-11-03 09:51:37 +08:00
Attente
6755202958 feat(cache): 使用 fresh 和 async_fresh 统一缓存控制方式
- 修复因缓存导致的插件更新后仍有更新提示的问题
- 统一使用 fresh/async_fresh 控制缓存行为
- 调整 TMDb 模块缓存策略,优化异步请求缓存清除机制
- 移除冗余的缓存方法封装,减少调用层级
- 简化 PluginHelper 中的缓存方法结构,移除 force 参数
2025-11-03 07:41:42 +08:00
jxxghp
8b7374a687 Merge pull request #5090 from wikrin/fix 2025-11-02 07:35:00 +08:00
Attente
c17cca2365 fix(update_setting): 修复设置保存错误的问题
- adapt to Pydantic V2
2025-11-01 23:51:59 +08:00
jxxghp
8016a9539a fix agent 2025-11-01 19:08:05 +08:00
jxxghp
e885fb15a0 Merge pull request #5089 from wikrin/fix 2025-11-01 18:27:35 +08:00
Attente
c7f098771b feat: adapt to Pydantic V2 2025-11-01 17:56:37 +08:00
Attente
fcd0908032 fix(transfer): 修复指定part不生效的问题 2025-11-01 17:56:23 +08:00
jxxghp
7ff1285084 fix agent tools 2025-11-01 12:07:17 +08:00
jxxghp
b45b603b97 fix agent tools 2025-11-01 12:01:48 +08:00
jxxghp
247208b8a9 fix agent 2025-11-01 11:41:22 +08:00
jxxghp
182c46037b fix agent 2025-11-01 10:40:45 +08:00
jxxghp
438d3210bc fix agent 2025-11-01 10:39:08 +08:00
jxxghp
d523c7c916 fix pydantic 2025-11-01 09:51:23 +08:00
jxxghp
09a19e94d5 fix config 2025-11-01 09:23:52 +08:00
jxxghp
3971c145df refactor: streamline data serialization in tool implementations
- Replaced model_dump and to_dict methods with direct calls to dict for improved consistency and performance in JSON serialization across multiple tools.
- Updated ConversationMemoryManager, GetRecommendationsTool, QueryDownloadsTool, and QueryMediaLibraryTool to enhance data handling.
2025-10-31 11:36:50 +08:00
jxxghp
055117d83d refactor: enhance tool message handling and improve error logging
- Updated _send_tool_message to accept a title parameter for better message context.
- Modified various tool implementations to utilize the new title parameter for clearer messaging.
- Improved error logging across multiple tools to include exception details for better debugging.
2025-10-31 09:16:53 +08:00
jxxghp
c6baf43986 Merge pull request #5085 from wumode/fix-event-handler-params 2025-10-29 07:50:47 +08:00
wumode
4ff16af3a7 fix: __invoke_plugin_method_async 中 __handle_event_error 参数传递错误 2025-10-28 20:09:44 +08:00
jxxghp
17a1bd352b Merge pull request #5071 from wikrin/optimize-file-size 2025-10-23 22:54:43 +08:00
Attente
7421ca09cc fix(transfer): 修复部分情况下无法正确统计已完成任务总大小的问题
- get_directory_size 使用 os.scandir 递归遍历提升性能
- 当任务文件项存储类型为 local 时,若其大小为空,则通过 SystemUtils 获取目录大小以确保
完成任务的准确统计。

fix(cache): 修改 fresh 和 async_fresh 默认参数为 True

refactor(filemanager): 移除整理后总大小计算逻辑

- 删除 TransHandler 中对整理目录总大小的冗余计算,提升性能并简化流程。

perf(system): 使用 scandir 优化文件扫描性能

- 重构 SystemUtils 中的文件扫描方法(list_files、exists_file、list_sub_files),
- 采用 os.scandir 替代 glob 实现,并预编译正则表达式以提升目录遍历与文件匹配性能。
2025-10-23 19:21:24 +08:00
jxxghp
9797e696e5 Merge pull request #5073 from WAY29/v2 2025-10-23 13:05:10 +08:00
jxxghp
c36d6d8b2d Merge pull request #5072 from wumode/fix_retry 2025-10-23 06:52:29 +08:00
wumode
3873786b99 fix: retry 2025-10-23 00:58:34 +08:00
WAY29
76fdba7f09 feat(endpoints): /download/add allow tmdbid/doubanid/bangumiid 2025-10-22 22:02:33 +08:00
jxxghp
72799e9638 Merge pull request #5068 from little6neko/v2 2025-10-22 06:18:28 +08:00
小六妞儿
2e77d03fe9 目录监控添加异常处理避免程序意外退出 2025-10-22 00:21:31 +08:00
jxxghp
0c58eae5e7 Merge pull request #5060 from wikrin/cached 2025-10-19 22:37:33 +08:00
Attente
b609567c38 feat(cache): 引入 fresh 和 async_fresh 以控制缓存行为
- 新增 `fresh` 和 `async_fresh` 用于在同步和异步函数中
临时禁用缓存。
- 通过 `_fresh` 这一 contextvars 变量实现上下文感知的
缓存刷新机制
- 修改了 `cached` 装饰器逻辑,在 `is_fresh()` 为 True
时跳过缓存读取。

- 修复 download 模块中路径处理问题,使用 `Path.as_posix()` 确保跨平台兼容性。
2025-10-19 22:31:50 +08:00
jxxghp
7ecfa44fa0 Merge pull request #5057 from xiaoQQya/v2 2025-10-19 06:49:23 +08:00
xiaoQQya
a685b1dc3b fix(douban): 修复 imdbid 匹配豆瓣信息成功错误返回 None 的问题 2025-10-18 22:34:55 +08:00
jxxghp
63ce49a17c Merge pull request #5056 from xiaoQQya/v2 2025-10-18 22:16:03 +08:00
xiaoQQya
820fbe4076 fix(douban): 修复使用 imdbid 未匹配到豆瓣信息时回退到使用名称匹配豆瓣信息失败的问题 2025-10-18 22:07:54 +08:00
jxxghp
efa05b7775 Update media tool descriptions for clarity and detail in JSON configuration 2025-10-18 22:00:24 +08:00
jxxghp
003781e903 add MoviePilot AI agent implementation and workflow manager 2025-10-18 21:55:31 +08:00
jxxghp
ee71bafc96 fix 2025-10-18 21:32:46 +08:00
jxxghp
bdd5f1231e add ai agent 2025-10-18 21:26:51 +08:00
jxxghp
6fee532c96 add ai agent 2025-10-18 21:26:36 +08:00
jxxghp
78aaad7b59 Merge pull request #5028 from ThedoRap/v2 2025-10-07 23:00:21 +08:00
Reaper
b128b0ede2 修复知行 极速之星 框架解析 做种信息 2025-10-02 20:43:06 +08:00
Reaper
737d2f3bc6 优化知行 极速之星 框架解析 2025-10-02 20:03:28 +08:00
jxxghp
179be53a65 Merge pull request #5025 from ThedoRap/v2 2025-10-02 06:50:37 +08:00
Reaper
1867f5e7c2 增加知行 极速之星 框架解析 2025-10-02 04:27:35 +08:00
jxxghp
6662d24565 Merge pull request #5019 from xiaoQQya/develop 2025-10-01 11:53:24 +08:00
jxxghp
5880566a99 Merge pull request #5018 from Aqr-K/fix-plugin 2025-10-01 11:52:43 +08:00
xiaoQQya
5d05b32711 fix: 修复 README 本地运行提示 No module named 'app' 的问题 2025-09-30 23:45:25 +08:00
Aqr-K
fa2b720e92 Refactor(plugins): Use pathlib.relative_to for robust plugin path resolution 2025-09-30 20:03:08 +08:00
jxxghp
d381238f83 Merge pull request #5017 from Aqr-K/fix-plugin 2025-09-30 19:55:46 +08:00
Aqr-K
751d627ead Merge branch 'fix-plugin' of https://github.com/aqr-k/MoviePilot into fix-plugin 2025-09-30 19:48:46 +08:00
Aqr-K
3e66a8de9b Rollback cache 2025-09-30 19:35:50 +08:00
Aqr-K
266052b12b Update app/core/plugin.py
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-09-30 19:26:33 +08:00
Aqr-K
803f4328f4 fix(plugins): Improve hot-reload robustness for multi-inheritance plugins 2025-09-30 18:34:26 +08:00
Aqr-K
8e95568e11 refactor(plugins): Improve hot-reloading with watchfiles 2025-09-30 18:01:02 +08:00
jxxghp
ab09ee4819 Merge pull request #4998 from Seed680/v2 2025-09-24 15:02:26 +08:00
noone
41f94a172f fix:对telegram发送标题进行转义 2025-09-24 14:29:42 +08:00
noone
566e597994 fix:撤销不必要转义 2025-09-24 14:26:09 +08:00
noone
765fb9c05f fix:更新Telegram解析模式为MarkdownV2;Telegram发送的内容按 Telegram V2 规则转义特殊字符 2025-09-24 14:14:11 +08:00
jxxghp
b6720a19f7 更新 plugin.py 2025-09-22 17:56:12 +08:00
jxxghp
3b130651c4 Merge pull request #4987 from Aqr-K/refactor/plugin-monitor 2025-09-22 11:41:13 +08:00
jxxghp
3f6c35dabe Merge pull request #4986 from Aqr-K/fix-plugin-reload 2025-09-22 07:08:33 +08:00
Aqr-K
db2a952bca refactor(plugin): Enhance hot reload with debounce and subdirectory support 2025-09-22 02:48:49 +08:00
Aqr-K
0ea9770bc3 Create debounce.py 2025-09-22 02:38:15 +08:00
Aqr-K
0b20956c90 fix 2025-09-21 22:42:18 +08:00
jxxghp
9f73b47d54 Merge pull request #4977 from jxxghp/cursor/fix-moviepilot-issue-4975-ff74 2025-09-19 18:15:08 +08:00
Cursor Agent
ce9c99af71 Refactor: Use copy instead of move for file operations
Co-authored-by: jxxghp <jxxghp@qq.com>
2025-09-19 09:54:44 +00:00
jxxghp
784024fb5d 更新 version.py 2025-09-19 08:50:33 +08:00
jxxghp
1145b32299 fix plugin install 2025-09-18 22:32:04 +08:00
jxxghp
ab71df0011 Merge pull request #4971 from cddjr/fix_glitch 2025-09-18 21:00:00 +08:00
jxxghp
fb137252a9 fix plugin id lower case 2025-09-18 18:00:15 +08:00
jxxghp
f57a680306 插件安装支持传递 repo_url 参数 2025-09-18 17:42:12 +08:00
景大侠
8bb3eaa320 fix 获取上次搜索结果时产生的NoneType异常
glitchtip#14
2025-09-18 17:23:20 +08:00
景大侠
9489730a44 fix u115刷新access_token失败会产生NoneType异常
glitchtip#49549
2025-09-18 17:23:20 +08:00
景大侠
d4795bb897 fix u115重试请求时报错unexpected keyword argument
glitchtip#136696
2025-09-18 17:23:19 +08:00
景大侠
63775872c7 fix TMDB因连接失败产生的NoneType错误
glitchtip#11
2025-09-18 17:05:09 +08:00
jxxghp
beff508a1f Merge pull request #4970 from cddjr/fix_trimemedia 2025-09-18 15:55:46 +08:00
景大侠
deaae8a2c6 fix 2025-09-18 15:39:10 +08:00
景大侠
46a27bd50c fix: 飞牛影视 2025-09-18 15:27:02 +08:00
jxxghp
24f2993433 Merge pull request #4958 from cddjr/fix_browse_mteam 2025-09-17 07:04:59 +08:00
景大侠
c80bfbfac5 fix: 浏览馒头报错NoneType 2025-09-17 01:59:28 +08:00
jxxghp
06abfc45c7 更新 version.py 2025-09-16 20:30:38 +08:00
jxxghp
440a773081 fix 2025-09-16 17:56:44 +08:00
jxxghp
0797bcb38b fix 2025-09-16 13:10:31 +08:00
jxxghp
d463b5bf0d Merge pull request #4955 from jxxghp/cursor/add-sort-type-to-subscription-queries-af67 2025-09-16 11:41:08 +08:00
Cursor Agent
0733c8edcc Add sort_type parameter to subscribe endpoints
Co-authored-by: jxxghp <jxxghp@qq.com>
2025-09-16 03:29:28 +00:00
jxxghp
86c7c05cb1 feat: 在获取订阅分享数据的接口中添加可选参数 2025-09-16 07:38:56 +08:00
jxxghp
18ff7ce753 feat: 在订阅统计中添加可选参数 2025-09-16 07:37:14 +08:00
jxxghp
8f2ed1004d Merge pull request #4952 from cddjr/fix_file_perm 2025-09-16 07:00:45 +08:00
景大侠
14961323c3 fix umask 2025-09-15 22:01:00 +08:00
景大侠
f8c682b183 fix: 修复刮削的文件权限只有0600的问题 2025-09-15 21:49:37 +08:00
jxxghp
dd92708f60 Merge pull request #4947 from pluto0x0/fix/4941-mttorent-imdb-search 2025-09-15 14:23:17 +08:00
Zifan Ying
4d9eeccefa fix: mtorrent搜索imdb时提供完整链接
fix: mtorrent搜索imdb时需要提供完整链接(例如https://www.imdb.com/title/tt3058674)
keyword为imdb条目时添加链接前缀
参考 https://wiki.m-team.cc/zh-tw/imdbtosearch
 
issue: https://github.com/jxxghp/MoviePilot/issues/4941
2025-09-15 00:31:45 -05:00
jxxghp
cd7b251031 Merge pull request #4946 from developer-wlj/wlj0914 2025-09-14 17:30:11 +08:00
developer-wlj
db614180b9 Revert "refactor: 优化临时文件的创建和上传逻辑"
This reverts commit 77c0f8f39e.
2025-09-14 17:14:52 +08:00
jxxghp
b6e527e5f4 Merge pull request #4945 from developer-wlj/wlj0914 2025-09-14 16:54:37 +08:00
developer-wlj
77c0f8f39e refactor: 优化临时文件的创建和上传逻辑
- 使用 with 语句自动管理临时文件的创建和关闭,提高代码的可读性和安全性
- 优化了代码结构,减少了嵌套的 try 语句,使代码更加清晰
2025-09-14 16:46:27 +08:00
jxxghp
58816d73c8 Merge pull request #4944 from developer-wlj/wlj0914 2025-09-14 16:42:37 +08:00
developer-wlj
3b194d282e fix: 修复在windows下因临时文件被占用,导致刮削失败
- 修改了两个函数中的临时文件创建和删除逻辑
- 使用手动删除代替自动删除,确保临时文件被正确清理
- 添加了异常处理,记录临时文件删除失败的情况
2025-09-14 16:28:24 +08:00
jxxghp
397f66433d v2.8.0 2025-09-13 15:58:00 +08:00
jxxghp
04a4ed1d0e fix delete_media_file 2025-09-13 14:10:15 +08:00
jxxghp
625850d4e7 fix 2025-09-13 13:35:51 +08:00
jxxghp
6c572baca5 rollback 2025-09-13 13:32:48 +08:00
jxxghp
ee0406a13f Handle smb protocol key error during disconnect (#4938)
* Refactor: Improve SMB connection handling and add signal handling

Co-authored-by: jxxghp <jxxghp@qq.com>

* Remove test_smb_fix.py

Co-authored-by: jxxghp <jxxghp@qq.com>

---------

Co-authored-by: Cursor Agent <cursoragent@cursor.com>
Co-authored-by: jxxghp <jxxghp@qq.com>
2025-09-13 11:25:29 +08:00
jxxghp
608a049ba3 fix smb delete 2025-09-13 11:05:21 +08:00
jxxghp
4d9b5198e2 增强SMB存储的删除功能 2025-09-13 10:56:45 +08:00
jxxghp
24b6c970aa feat:emby用户名 2025-09-13 10:34:41 +08:00
jxxghp
239c47f469 fix #4917 2025-09-13 10:13:33 +08:00
jxxghp
f0fc64c517 fix #4917 2025-09-13 10:12:40 +08:00
jxxghp
8481fd38ce fix #4933 2025-09-13 09:54:28 +08:00
jxxghp
5f425129d5 fix #4934 2025-09-13 09:46:04 +08:00
jxxghp
92955b1315 fix:在fork进程中执行文件整理 2025-09-13 08:56:05 +08:00
jxxghp
a3872d5bb5 fix:在fork进程中执行文件整理 2025-09-13 08:50:20 +08:00
jxxghp
a123ff2c04 feat:在fork进程中执行文件整理 2025-09-13 08:32:31 +08:00
jxxghp
188de34306 mini chunk size 2025-09-12 21:45:26 +08:00
jxxghp
3d43750e9b fix async event 2025-09-10 17:33:12 +08:00
jxxghp
fea228c68d add SUPERUSER_PASSWORD 2025-09-10 15:42:17 +08:00
jxxghp
a71a28e563 更新 config.py 2025-09-10 07:00:10 +08:00
jxxghp
3b5d4982b5 add wizard flag 2025-09-09 13:50:11 +08:00
jxxghp
b201e9ab8c Revert "feat:在子进程中操作文件"
This reverts commit 4f304a70b7.
2025-09-08 17:23:25 +08:00
jxxghp
d30b9282fd fix alipan u115 error log 2025-09-08 17:13:01 +08:00
jxxghp
4f304a70b7 feat:在子进程中操作文件 2025-09-08 16:59:29 +08:00
jxxghp
59a54d4f04 fix plugin cache 2025-09-08 13:27:32 +08:00
jxxghp
1e94d794ed fix log 2025-09-08 12:12:00 +08:00
jxxghp
5bd210406b Merge pull request #4918 from cddjr/fix_4853 2025-09-08 11:36:41 +08:00
景大侠
e00514d36d fix: 将RSS中的发布日期转为本地时区 2025-09-08 11:28:08 +08:00
jxxghp
f013bf1931 fix 2025-09-08 10:59:28 +08:00
jxxghp
107cbbad1d fix 2025-09-08 10:54:45 +08:00
jxxghp
481f1f9d30 add full gc scheduler 2025-09-08 10:49:09 +08:00
jxxghp
704364061c fix redis test 2025-09-08 09:59:11 +08:00
jxxghp
c1bd2d6cf1 fix:优化下载 2025-09-08 09:50:08 +08:00
jxxghp
a018e1228c Merge pull request #4904 from DDS-Derek/fix_gosu 2025-09-05 21:40:41 +08:00
DDSRem
d962d9c7f6 feat(docker): add START_NOGOSU mode
fix https://github.com/jxxghp/MoviePilot/issues/4889
2025-09-05 21:30:59 +08:00
143 changed files with 4526 additions and 917 deletions

View File

@@ -40,10 +40,11 @@ git clone https://github.com/jxxghp/MoviePilot
```shell
git clone https://github.com/jxxghp/MoviePilot-Resources
```
- 安装后端依赖,设置`app`为源代码根目录,运行 `main.py` 启动后端服务,默认监听端口:`3001`API文档地址`http://localhost:3001/docs`
- 安装后端依赖,运行 `main.py` 启动后端服务,默认监听端口:`3001`API文档地址`http://localhost:3001/docs`
```shell
cd MoviePilot
pip install -r requirements.txt
python3 main.py
python3 -m app.main
```
- 克隆前端项目 [MoviePilot-Frontend](https://github.com/jxxghp/MoviePilot-Frontend)
```shell

355
app/agent/__init__.py Normal file
View File

@@ -0,0 +1,355 @@
"""MoviePilot AI智能体实现"""
import asyncio
from typing import Dict, List, Any
from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_community.callbacks import get_openai_callback
from langchain_core.chat_history import InMemoryChatMessageHistory
from langchain_core.messages import HumanMessage, AIMessage, ToolCall
from langchain_core.runnables.history import RunnableWithMessageHistory
from app.agent.callback import StreamingCallbackHandler
from app.agent.memory import ConversationMemoryManager
from app.agent.prompt import PromptManager
from app.agent.tools import MoviePilotToolFactory
from app.chain import ChainBase
from app.core.config import settings
from app.helper.message import MessageHelper
from app.log import logger
from app.schemas import Notification
class AgentChain(ChainBase):
pass
class MoviePilotAgent:
"""MoviePilot AI智能体"""
def __init__(self, session_id: str, user_id: str = None,
channel: str = None, source: str = None, username: str = None):
self.session_id = session_id
self.user_id = user_id
self.channel = channel # 消息渠道
self.source = source # 消息来源
self.username = username # 用户名
# 消息助手
self.message_helper = MessageHelper()
# 记忆管理器
self.memory_manager = ConversationMemoryManager()
# 提示词管理器
self.prompt_manager = PromptManager()
# 回调处理器
self.callback_handler = StreamingCallbackHandler(
session_id=session_id
)
# LLM模型
self.llm = self._initialize_llm()
# 工具
self.tools = self._initialize_tools()
# 会话存储
self.session_store = self._initialize_session_store()
# 提示词模板
self.prompt = self._initialize_prompt()
# Agent执行器
self.agent_executor = self._create_agent_executor()
def _initialize_llm(self):
"""初始化LLM模型"""
provider = settings.LLM_PROVIDER.lower()
api_key = settings.LLM_API_KEY
if not api_key:
raise ValueError("未配置 LLM_API_KEY")
if provider == "google":
from langchain_google_genai import ChatGoogleGenerativeAI
return ChatGoogleGenerativeAI(
model=settings.LLM_MODEL,
google_api_key=api_key,
max_retries=3,
temperature=settings.LLM_TEMPERATURE,
streaming=True,
callbacks=[self.callback_handler]
)
elif provider == "deepseek":
from langchain_deepseek import ChatDeepSeek
return ChatDeepSeek(
model=settings.LLM_MODEL,
api_key=api_key,
max_retries=3,
temperature=settings.LLM_TEMPERATURE,
streaming=True,
callbacks=[self.callback_handler],
stream_usage=True
)
else:
from langchain_openai import ChatOpenAI
return ChatOpenAI(
model=settings.LLM_MODEL,
api_key=api_key,
max_retries=3,
base_url=settings.LLM_BASE_URL,
temperature=settings.LLM_TEMPERATURE,
streaming=True,
callbacks=[self.callback_handler],
stream_usage=True
)
def _initialize_tools(self) -> List:
"""初始化工具列表"""
return MoviePilotToolFactory.create_tools(
session_id=self.session_id,
user_id=self.user_id,
channel=self.channel,
source=self.source,
username=self.username,
callback_handler=self.callback_handler
)
@staticmethod
def _initialize_session_store() -> Dict[str, InMemoryChatMessageHistory]:
"""初始化内存存储"""
return {}
def get_session_history(self, session_id: str) -> InMemoryChatMessageHistory:
"""获取会话历史"""
if session_id not in self.session_store:
chat_history = InMemoryChatMessageHistory()
messages: List[dict] = self.memory_manager.get_recent_messages_for_agent(
session_id=session_id,
user_id=self.user_id
)
if messages:
for msg in messages:
if msg.get("role") == "user":
chat_history.add_user_message(HumanMessage(content=msg.get("content", "")))
elif msg.get("role") == "agent":
chat_history.add_ai_message(AIMessage(content=msg.get("content", "")))
elif msg.get("role") == "tool_call":
metadata = msg.get("metadata", {})
chat_history.add_ai_message(AIMessage(
content=msg.get("content", ""),
tool_calls=[ToolCall(
id=metadata.get("call_id"),
name=metadata.get("tool_name"),
args=metadata.get("parameters"),
)]
))
elif msg.get("role") == "tool_result":
chat_history.add_ai_message(AIMessage(content=msg.get("content", "")))
elif msg.get("role") == "system":
chat_history.add_ai_message(AIMessage(content=msg.get("content", "")))
self.session_store[session_id] = chat_history
return self.session_store[session_id]
@staticmethod
def _initialize_prompt() -> ChatPromptTemplate:
"""初始化提示词模板"""
try:
prompt_template = ChatPromptTemplate.from_messages([
("system", "{system_prompt}"),
MessagesPlaceholder(variable_name="chat_history"),
("user", "{input}"),
MessagesPlaceholder(variable_name="agent_scratchpad"),
])
logger.info("LangChain提示词模板初始化成功")
return prompt_template
except Exception as e:
logger.error(f"初始化提示词失败: {e}")
raise e
def _create_agent_executor(self) -> RunnableWithMessageHistory:
"""创建Agent执行器"""
try:
agent = create_openai_tools_agent(
llm=self.llm,
tools=self.tools,
prompt=self.prompt
)
executor = AgentExecutor(
agent=agent,
tools=self.tools,
verbose=settings.LLM_VERBOSE,
max_iterations=settings.LLM_MAX_ITERATIONS,
return_intermediate_steps=True,
handle_parsing_errors=True,
early_stopping_method="force"
)
return RunnableWithMessageHistory(
executor,
self.get_session_history,
input_messages_key="input",
history_messages_key="chat_history"
)
except Exception as e:
logger.error(f"创建Agent执行器失败: {e}")
raise e
async def process_message(self, message: str) -> str:
"""处理用户消息"""
try:
# 添加用户消息到记忆
await self.memory_manager.add_memory(
self.session_id,
user_id=self.user_id,
role="user",
content=message
)
# 构建输入上下文
input_context = {
"system_prompt": self.prompt_manager.get_agent_prompt(channel=self.channel),
"input": message
}
# 执行Agent
logger.info(f"Agent执行推理: session_id={self.session_id}, input={message}")
await self._execute_agent(input_context)
# 获取Agent回复
agent_message = await self.callback_handler.get_message()
# 发送Agent回复给用户通过原渠道
await self.send_agent_message(agent_message)
# 添加Agent回复到记忆
await self.memory_manager.add_memory(
session_id=self.session_id,
user_id=self.user_id,
role="agent",
content=agent_message
)
return agent_message
except Exception as e:
error_message = f"处理消息时发生错误: {str(e)}"
logger.error(error_message)
# 发送错误消息给用户(通过原渠道)
await self.send_agent_message(error_message)
return error_message
async def _execute_agent(self, input_context: Dict[str, Any]) -> Dict[str, Any]:
"""执行LangChain Agent"""
try:
with get_openai_callback() as cb:
result = await self.agent_executor.ainvoke(
input_context,
config={"configurable": {"session_id": self.session_id}},
callbacks=[self.callback_handler]
)
logger.info(f"LLM调用消耗: \n{cb}")
if cb.total_tokens > 0:
result["token_usage"] = {
"prompt_tokens": cb.prompt_tokens,
"completion_tokens": cb.completion_tokens,
"total_tokens": cb.total_tokens
}
return result
except asyncio.CancelledError:
logger.info(f"Agent执行被取消: session_id={self.session_id}")
return {
"output": "任务已取消",
"intermediate_steps": [],
"token_usage": {}
}
except Exception as e:
logger.error(f"Agent执行失败: {e}")
return {
"output": f"执行过程中发生错误: {str(e)}",
"intermediate_steps": [],
"token_usage": {}
}
async def send_agent_message(self, message: str, title: str = "MoviePilot助手"):
"""通过原渠道发送消息给用户"""
await AgentChain().async_post_message(
Notification(
channel=self.channel,
source=self.source,
userid=self.user_id,
username=self.username,
title=title,
text=message
)
)
async def cleanup(self):
"""清理智能体资源"""
if self.session_id in self.session_store:
del self.session_store[self.session_id]
logger.info(f"MoviePilot智能体已清理: session_id={self.session_id}")
class AgentManager:
"""AI智能体管理器"""
def __init__(self):
self.active_agents: Dict[str, MoviePilotAgent] = {}
self.memory_manager = ConversationMemoryManager()
async def initialize(self):
"""初始化管理器"""
await self.memory_manager.initialize()
async def close(self):
"""关闭管理器"""
await self.memory_manager.close()
# 清理所有活跃的智能体
for agent in self.active_agents.values():
await agent.cleanup()
self.active_agents.clear()
async def process_message(self, session_id: str, user_id: str, message: str,
channel: str = None, source: str = None, username: str = None) -> str:
"""处理用户消息"""
# 获取或创建Agent实例
if session_id not in self.active_agents:
logger.info(f"创建新的AI智能体实例session_id: {session_id}, user_id: {user_id}")
agent = MoviePilotAgent(
session_id=session_id,
user_id=user_id,
channel=channel,
source=source,
username=username
)
agent.memory_manager = self.memory_manager
self.active_agents[session_id] = agent
else:
agent = self.active_agents[session_id]
agent.user_id = user_id # 确保user_id是最新的
# 更新渠道信息
if channel:
agent.channel = channel
if source:
agent.source = source
if username:
agent.username = username
# 处理消息
return await agent.process_message(message)
async def clear_session(self, session_id: str, user_id: str):
"""清空会话"""
if session_id in self.active_agents:
agent = self.active_agents[session_id]
await agent.cleanup()
del self.active_agents[session_id]
await self.memory_manager.clear_memory(session_id, user_id)
logger.info(f"会话 {session_id} 的记忆已清空")
# 全局智能体管理器实例
agent_manager = AgentManager()

View File

@@ -0,0 +1,33 @@
import threading
from langchain_core.callbacks import AsyncCallbackHandler
from app.log import logger
class StreamingCallbackHandler(AsyncCallbackHandler):
"""流式输出回调处理器"""
def __init__(self, session_id: str):
self._lock = threading.Lock()
self.session_id = session_id
self.current_message = ""
async def get_message(self):
"""获取当前消息内容,获取后清空"""
with self._lock:
if not self.current_message:
return ""
msg = self.current_message
logger.info(f"Agent消息: {msg}")
self.current_message = ""
return msg
async def on_llm_new_token(self, token: str, **kwargs):
"""处理新的token"""
if not token:
return
with self._lock:
# 缓存当前消息
self.current_message += token

View File

@@ -0,0 +1,280 @@
"""对话记忆管理器"""
import asyncio
import json
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any
from app.core.config import settings
from app.helper.redis import AsyncRedisHelper
from app.log import logger
from app.schemas.agent import ConversationMemory
class ConversationMemoryManager:
"""对话记忆管理器"""
def __init__(self):
# 内存中的会话记忆缓存
self.memory_cache: Dict[str, ConversationMemory] = {}
# 使用现有的Redis助手
self.redis_helper = AsyncRedisHelper()
# 内存缓存清理任务Redis通过TTL自动过期
self.cleanup_task: Optional[asyncio.Task] = None
async def initialize(self):
"""初始化记忆管理器"""
try:
# 启动内存缓存清理任务Redis通过TTL自动过期
self.cleanup_task = asyncio.create_task(self._cleanup_expired_memories())
logger.info("对话记忆管理器初始化完成")
except Exception as e:
logger.warning(f"Redis连接失败将使用内存存储: {e}")
async def close(self):
"""关闭记忆管理器"""
if self.cleanup_task:
self.cleanup_task.cancel()
try:
await self.cleanup_task
except asyncio.CancelledError:
pass
await self.redis_helper.close()
logger.info("对话记忆管理器已关闭")
async def get_memory(self, session_id: str, user_id: str) -> ConversationMemory:
"""获取会话记忆"""
# 首先检查缓存
cache_key = f"{user_id}:{session_id}" if user_id else session_id
if cache_key in self.memory_cache:
return self.memory_cache[cache_key]
# 尝试从Redis加载
if settings.CACHE_BACKEND_TYPE == "redis":
try:
redis_key = f"agent_memory:{user_id}:{session_id}" if user_id else f"agent_memory:{session_id}"
memory_data = await self.redis_helper.get(redis_key, region="AI_AGENT")
if memory_data:
memory_dict = json.loads(memory_data) if isinstance(memory_data, str) else memory_data
memory = ConversationMemory(**memory_dict)
self.memory_cache[cache_key] = memory
return memory
except Exception as e:
logger.warning(f"从Redis加载记忆失败: {e}")
# 创建新的记忆
memory = ConversationMemory(session_id=session_id, user_id=user_id)
self.memory_cache[cache_key] = memory
await self._save_memory(memory)
return memory
async def set_title(self, session_id: str, user_id: str, title: str):
"""设置会话标题"""
memory = await self.get_memory(session_id=session_id, user_id=user_id)
memory.title = title
memory.updated_at = datetime.now()
await self._save_memory(memory)
async def get_title(self, session_id: str, user_id: str) -> Optional[str]:
"""获取会话标题"""
memory = await self.get_memory(session_id=session_id, user_id=user_id)
return memory.title
async def list_sessions(self, user_id: str, limit: int = 100) -> List[Dict[str, Any]]:
"""列出历史会话摘要(按更新时间倒序)
- 当启用Redis时遍历 `agent_memory:*` 键并读取摘要
- 当未启用Redis时基于内存缓存返回
"""
sessions: List[ConversationMemory] = []
# 从Redis遍历
if settings.CACHE_BACKEND_TYPE == "redis":
try:
# 使用Redis助手的items方法遍历所有键
async for key, value in self.redis_helper.items(region="AI_AGENT"):
if key.startswith("agent_memory:"):
try:
# 解析键名获取user_id和session_id
key_parts = key.split(":")
if len(key_parts) >= 3:
key_user_id = key_parts[2] if len(key_parts) > 3 else None
if not user_id or key_user_id == user_id:
data = value if isinstance(value, dict) else json.loads(value)
memory = ConversationMemory(**data)
sessions.append(memory)
except Exception as err:
logger.warning(f"解析Redis记忆数据失败: {err}")
continue
except Exception as e:
logger.warning(f"遍历Redis会话失败: {e}")
# 合并内存缓存(确保包含近期的会话)
for cache_key, memory in self.memory_cache.items():
# 如果指定了user_id只返回该用户的会话
if not user_id or memory.user_id == user_id:
sessions.append(memory)
# 去重(以 session_id 为键取最近updated
uniq: Dict[str, ConversationMemory] = {}
for mem in sessions:
existed = uniq.get(mem.session_id)
if (not existed) or (mem.updated_at > existed.updated_at):
uniq[mem.session_id] = mem
# 排序并裁剪
sorted_list = sorted(uniq.values(), key=lambda m: m.updated_at, reverse=True)[:limit]
return [
{
"session_id": m.session_id,
"title": m.title or "新会话",
"message_count": len(m.messages),
"created_at": m.created_at.isoformat(),
"updated_at": m.updated_at.isoformat(),
}
for m in sorted_list
]
async def add_memory(
self,
session_id: str,
user_id: str,
role: str,
content: str,
metadata: Optional[Dict[str, Any]] = None
):
"""添加消息到记忆"""
memory = await self.get_memory(session_id=session_id, user_id=user_id)
message = {
"role": role,
"content": content,
"timestamp": datetime.now().isoformat(),
"metadata": metadata or {}
}
memory.messages.append(message)
memory.updated_at = datetime.now()
# 限制消息数量,避免记忆过大
max_messages = settings.LLM_MAX_MEMORY_MESSAGES
if len(memory.messages) > max_messages:
# 保留最近的消息,但保留第一条系统消息
system_messages = [msg for msg in memory.messages if msg["role"] == "system"]
recent_messages = memory.messages[-(max_messages - len(system_messages)):]
memory.messages = system_messages + recent_messages
await self._save_memory(memory)
logger.debug(f"消息已添加到记忆: session_id={session_id}, user_id={user_id}, role={role}")
def get_recent_messages_for_agent(
self,
session_id: str,
user_id: str
) -> List[Dict[str, Any]]:
"""为Agent获取最近的消息仅内存缓存
如果消息Token数量超过模型最大上下文长度的阀值会自动进行摘要裁剪
"""
cache_key = f"{user_id}:{session_id}" if user_id else session_id
memory = self.memory_cache.get(cache_key)
if not memory:
return []
# 获取所有消息
messages = memory.messages
return messages
async def get_recent_messages(
self,
session_id: str,
user_id: str,
limit: int = 10,
role_filter: Optional[list] = None
) -> List[Dict[str, Any]]:
"""获取最近的消息"""
memory = await self.get_memory(session_id=session_id, user_id=user_id)
messages = memory.messages
if role_filter:
messages = [msg for msg in messages if msg["role"] in role_filter]
return messages[-limit:] if messages else []
async def get_context(self, session_id: str, user_id: str) -> Dict[str, Any]:
"""获取会话上下文"""
memory = await self.get_memory(session_id=session_id, user_id=user_id)
return memory.context
async def clear_memory(self, session_id: str, user_id: str):
"""清空会话记忆"""
cache_key = f"{user_id}:{session_id}" if user_id else session_id
if cache_key in self.memory_cache:
del self.memory_cache[cache_key]
if settings.CACHE_BACKEND_TYPE == "redis":
redis_key = f"agent_memory:{user_id}:{session_id}" if user_id else f"agent_memory:{session_id}"
await self.redis_helper.delete(redis_key, region="AI_AGENT")
logger.info(f"会话记忆已清空: session_id={session_id}, user_id={user_id}")
async def _save_memory(self, memory: ConversationMemory):
"""保存记忆到存储
Redis中的记忆会自动通过TTL机制过期无需手动清理
"""
# 更新内存缓存
cache_key = f"{memory.user_id}:{memory.session_id}" if memory.user_id else memory.session_id
self.memory_cache[cache_key] = memory
# 保存到Redis设置TTL自动过期
if settings.CACHE_BACKEND_TYPE == "redis":
try:
memory_dict = memory.model_dump()
redis_key = f"agent_memory:{memory.user_id}:{memory.session_id}" if memory.user_id else f"agent_memory:{memory.session_id}"
ttl = int(timedelta(days=settings.LLM_REDIS_MEMORY_RETENTION_DAYS).total_seconds())
await self.redis_helper.set(
redis_key,
memory_dict,
ttl=ttl,
region="AI_AGENT"
)
except Exception as e:
logger.warning(f"保存记忆到Redis失败: {e}")
async def _cleanup_expired_memories(self):
"""清理内存中过期记忆的后台任务
注意Redis中的记忆通过TTL机制自动过期这里只清理内存缓存
"""
while True:
try:
# 每小时清理一次
await asyncio.sleep(3600)
current_time = datetime.now()
expired_sessions = []
# 只检查内存缓存中的过期记忆
# Redis中的记忆会通过TTL自动过期无需手动处理
for cache_key, memory in self.memory_cache.items():
if (current_time - memory.updated_at).days > settings.LLM_MEMORY_RETENTION_DAYS:
expired_sessions.append(cache_key)
# 只清理内存缓存不删除Redis中的键Redis会自动过期
for cache_key in expired_sessions:
if cache_key in self.memory_cache:
del self.memory_cache[cache_key]
if expired_sessions:
logger.info(f"清理了{len(expired_sessions)}个过期内存会话记忆")
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"清理记忆时发生错误: {e}")

View File

@@ -0,0 +1,70 @@
You are MoviePilot's AI assistant, specialized in helping users manage media resources including subscriptions, searching, downloading, and organization.
## Your Identity and Capabilities
You are an AI agent for the MoviePilot media management system with the following core capabilities:
### Media Management Capabilities
- **Search Media Resources**: Search for movies, TV shows, anime, and other media content based on user requirements
- **Add Subscriptions**: Create subscription rules for media content that users are interested in
- **Manage Downloads**: Search and add torrent resources to downloaders
- **Query Status**: Check subscription status, download progress, and media library status
### Intelligent Interaction Capabilities
- **Natural Language Understanding**: Understand user requests in natural language (Chinese/English)
- **Context Memory**: Remember conversation history and user preferences
- **Smart Recommendations**: Recommend related media content based on user preferences
- **Task Execution**: Automatically execute complex media management tasks
## Working Principles
1. **Always respond in Chinese**: All responses must be in Chinese
2. **Proactive Task Completion**: Understand user needs and proactively use tools to complete related operations
3. **Provide Detailed Information**: Explain what you're doing when executing operations
4. **Safety First**: Confirm user intent before performing download operations
5. **Continuous Learning**: Remember user preferences and habits to provide personalized service
## Common Operation Workflows
### Add Subscription Workflow
1. Understand the media content the user wants to subscribe to
2. Search for related media information
3. Create subscription rules
4. Confirm successful subscription
### Search and Download Workflow
1. Understand user requirements (movie names, TV show names, etc.)
2. Search for related media information
3. Search for related torrent resources by media info
4. Filter suitable resources
5. Add to downloader
### Query Status Workflow
1. Understand what information the user wants to know
2. Query related data
3. Organize and present results
## Tool Usage Guidelines
### Tool Usage Principles
- Use tools proactively to complete user requests
- Always explain what you're doing when using tools
- Provide detailed results and explanations
- Handle errors gracefully and suggest alternatives
- Confirm user intent before performing download operations
### Response Format
- Always respond in Chinese
- Use clear and friendly language
- Provide structured information when appropriate
- Include relevant details about media content (title, year, type, etc.)
- Explain the results of tool operations clearly
## Important Notes
- Always confirm user intent before performing download operations
- If search results are not ideal, proactively adjust search strategies
- Maintain a friendly and professional tone
- Seek solutions proactively when encountering problems
- Remember user preferences and provide personalized recommendations
- Handle errors gracefully and provide helpful suggestions

View File

@@ -0,0 +1,118 @@
"""提示词管理器"""
from pathlib import Path
from typing import Dict
from app.log import logger
class PromptManager:
"""提示词管理器"""
def __init__(self, prompts_dir: str = None):
if prompts_dir is None:
self.prompts_dir = Path(__file__).parent
else:
self.prompts_dir = Path(prompts_dir)
self.prompts_cache: Dict[str, str] = {}
def load_prompt(self, prompt_name: str) -> str:
"""加载指定的提示词"""
if prompt_name in self.prompts_cache:
return self.prompts_cache[prompt_name]
prompt_file = self.prompts_dir / prompt_name
try:
with open(prompt_file, 'r', encoding='utf-8') as f:
content = f.read().strip()
# 缓存提示词
self.prompts_cache[prompt_name] = content
logger.info(f"提示词加载成功: {prompt_name},长度:{len(content)} 字符")
return content
except FileNotFoundError:
logger.error(f"提示词文件不存在: {prompt_file}")
raise
except Exception as e:
logger.error(f"加载提示词失败: {prompt_name}, 错误: {e}")
raise
def get_agent_prompt(self, channel: str = None) -> str:
"""
获取智能体提示词
:param channel: 消息渠道Telegram、微信、Slack等
:return: 提示词内容
"""
base_prompt = self.load_prompt("Agent Prompt.txt")
# 根据渠道添加特定的格式说明
if channel:
channel_format_info = self._get_channel_format_info(channel)
if channel_format_info:
base_prompt += f"\n\n## Current Message Channel Format Requirements\n\n{channel_format_info}"
return base_prompt
@staticmethod
def _get_channel_format_info(channel: str) -> str:
"""
获取渠道特定的格式说明
:param channel: 消息渠道
:return: 格式说明文本
"""
channel_lower = channel.lower() if channel else ""
if "telegram" in channel_lower:
return """Messages are being sent through the **Telegram** channel. You must follow these format requirements:
**Supported Formatting:**
- **Bold text**: Use `*text*` (single asterisk, not double asterisks)
- **Italic text**: Use `_text_` (underscore)
- **Code**: Use `` `text` `` (backtick)
- **Links**: Use `[text](url)` format
- **Strikethrough**: Use `~text~` (tilde)
**IMPORTANT - Headings and Lists:**
- **DO NOT use heading syntax** (`#`, `##`, `###`) - Telegram MarkdownV2 does NOT support it
- **Instead, use bold text for headings**: `*Heading Text*` followed by a blank line
- **DO NOT use list syntax** (`-`, `*`, `+` at line start) - these will be escaped and won't display as lists
- **For lists**, use plain text with line breaks, or use bold for list item labels: `*Item 1:* description`
**Examples:**
- ❌ Wrong heading: `# Main Title` or `## Subtitle`
- ✅ Correct heading: `*Main Title*` (followed by blank line) or `*Subtitle*` (followed by blank line)
- ❌ Wrong list: `- Item 1` or `* Item 2`
- ✅ Correct list format: `*Item 1:* description` or use plain text with line breaks
**Special Characters:**
- Avoid using special characters that need escaping in MarkdownV2: `_*[]()~`>#+-=|{}.!` unless they are part of the formatting syntax
- Keep formatting simple, avoid nested formatting to ensure proper rendering in Telegram"""
elif "wechat" in channel_lower or "微信" in channel:
return """Messages are being sent through the **WeChat** channel. Please follow these format requirements:
- WeChat does NOT support Markdown formatting. Use plain text format only.
- Do NOT use any Markdown syntax (such as `**bold**`, `*italic*`, `` `code` `` etc.)
- Use plain text descriptions. You can organize content using line breaks and punctuation
- Links can be provided directly as URLs, no Markdown link format needed
- Keep messages concise and clear, use natural Chinese expressions"""
elif "slack" in channel_lower:
return """Messages are being sent through the **Slack** channel. Please follow these format requirements:
- Slack supports Markdown formatting
- Use `*text*` for bold
- Use `_text_` for italic
- Use `` `text` `` for code
- Link format: `<url|text>` or `[text](url)`"""
# 其他渠道使用标准Markdown
return None
def clear_cache(self):
"""清空缓存"""
self.prompts_cache.clear()
logger.info("提示词缓存已清空")

View File

@@ -0,0 +1,31 @@
"""MoviePilot工具模块"""
from .base import MoviePilotTool
from app.agent.tools.impl.search_media import SearchMediaTool
from app.agent.tools.impl.add_subscribe import AddSubscribeTool
from app.agent.tools.impl.search_torrents import SearchTorrentsTool
from app.agent.tools.impl.add_download import AddDownloadTool
from app.agent.tools.impl.query_subscribes import QuerySubscribesTool
from app.agent.tools.impl.query_downloads import QueryDownloadsTool
from app.agent.tools.impl.query_downloaders import QueryDownloadersTool
from app.agent.tools.impl.query_sites import QuerySitesTool
from app.agent.tools.impl.get_recommendations import GetRecommendationsTool
from app.agent.tools.impl.query_media_library import QueryMediaLibraryTool
from app.agent.tools.impl.send_message import SendMessageTool
from .factory import MoviePilotToolFactory
__all__ = [
"MoviePilotTool",
"SearchMediaTool",
"AddSubscribeTool",
"SearchTorrentsTool",
"AddDownloadTool",
"QuerySubscribesTool",
"QueryDownloadsTool",
"QueryDownloadersTool",
"QuerySitesTool",
"GetRecommendationsTool",
"QueryMediaLibraryTool",
"SendMessageTool",
"MoviePilotToolFactory"
]

73
app/agent/tools/base.py Normal file
View File

@@ -0,0 +1,73 @@
"""MoviePilot工具基类"""
from abc import ABCMeta, abstractmethod
from typing import Callable, Any
from langchain.tools import BaseTool
from pydantic import PrivateAttr
from app.agent import StreamingCallbackHandler
from app.chain import ChainBase
from app.schemas import Notification
class ToolChain(ChainBase):
pass
class MoviePilotTool(BaseTool, metaclass=ABCMeta):
"""MoviePilot专用工具基类"""
_session_id: str = PrivateAttr()
_user_id: str = PrivateAttr()
_channel: str = PrivateAttr(default=None)
_source: str = PrivateAttr(default=None)
_username: str = PrivateAttr(default=None)
_callback_handler: StreamingCallbackHandler = PrivateAttr(default=None)
def __init__(self, session_id: str, user_id: str, **kwargs):
super().__init__(**kwargs)
self._session_id = session_id
self._user_id = user_id
def _run(self, *args: Any, **kwargs: Any) -> Any:
pass
async def _arun(self, **kwargs) -> str:
"""异步运行工具"""
# 发送运行工具前的消息
agent_message = await self._callback_handler.get_message()
if agent_message:
await self.send_tool_message(agent_message, title="MoviePilot助手")
# 发送执行工具说明
explanation = kwargs.get("explanation")
if explanation:
await self.send_tool_message(f"▶️️{explanation}")
return await self.run(**kwargs)
@abstractmethod
async def run(self, **kwargs) -> str:
raise NotImplementedError
def set_message_attr(self, channel: str, source: str, username: str):
"""设置消息属性"""
self._channel = channel
self._source = source
self._username = username
def set_callback_handler(self, callback_handler: StreamingCallbackHandler):
"""设置回调处理器"""
self._callback_handler = callback_handler
async def send_tool_message(self, message: str, title: str = ""):
"""发送工具消息"""
await ToolChain().async_post_message(
Notification(
channel=self._channel,
source=self._source,
userid=self._user_id,
username=self._username,
title=title,
text=message
),
escape_markdown=False
)

View File

@@ -0,0 +1,84 @@
"""MoviePilot工具工厂"""
from typing import List, Callable
from app.agent.tools.impl.add_download import AddDownloadTool
from app.agent.tools.impl.add_subscribe import AddSubscribeTool
from app.agent.tools.impl.get_recommendations import GetRecommendationsTool
from app.agent.tools.impl.query_downloaders import QueryDownloadersTool
from app.agent.tools.impl.query_downloads import QueryDownloadsTool
from app.agent.tools.impl.query_media_library import QueryMediaLibraryTool
from app.agent.tools.impl.query_sites import QuerySitesTool
from app.agent.tools.impl.query_subscribes import QuerySubscribesTool
from app.agent.tools.impl.search_media import SearchMediaTool
from app.agent.tools.impl.search_torrents import SearchTorrentsTool
from app.agent.tools.impl.send_message import SendMessageTool
from app.core.plugin import PluginManager
from app.log import logger
from .base import MoviePilotTool
class MoviePilotToolFactory:
"""MoviePilot工具工厂"""
@staticmethod
def create_tools(session_id: str, user_id: str,
channel: str = None, source: str = None, username: str = None,
callback_handler: Callable = None) -> List[MoviePilotTool]:
"""创建MoviePilot工具列表"""
tools = []
tool_definitions = [
SearchMediaTool,
AddSubscribeTool,
SearchTorrentsTool,
AddDownloadTool,
QuerySubscribesTool,
QueryDownloadsTool,
QueryDownloadersTool,
QuerySitesTool,
GetRecommendationsTool,
QueryMediaLibraryTool,
SendMessageTool
]
# 创建内置工具
for ToolClass in tool_definitions:
tool = ToolClass(
session_id=session_id,
user_id=user_id
)
tool.set_message_attr(channel=channel, source=source, username=username)
tool.set_callback_handler(callback_handler=callback_handler)
tools.append(tool)
# 加载插件提供的工具
plugin_tools_count = 0
plugin_tools_info = PluginManager().get_plugin_agent_tools()
for plugin_info in plugin_tools_info:
plugin_id = plugin_info.get("plugin_id")
plugin_name = plugin_info.get("plugin_name")
tool_classes = plugin_info.get("tools", [])
for ToolClass in tool_classes:
try:
# 验证工具类是否继承自 MoviePilotTool
if not issubclass(ToolClass, MoviePilotTool):
logger.warning(f"插件 {plugin_name}({plugin_id}) 提供的工具类 {ToolClass.__name__} 未继承自 MoviePilotTool已跳过")
continue
# 创建工具实例
tool = ToolClass(
session_id=session_id,
user_id=user_id
)
tool.set_message_attr(channel=channel, source=source, username=username)
tool.set_callback_handler(callback_handler=callback_handler)
tools.append(tool)
plugin_tools_count += 1
logger.debug(f"成功加载插件 {plugin_name}({plugin_id}) 的工具: {ToolClass.__name__}")
except Exception as e:
logger.error(f"加载插件 {plugin_name}({plugin_id}) 的工具 {ToolClass.__name__} 失败: {str(e)}")
builtin_tools_count = len(tool_definitions)
if plugin_tools_count > 0:
logger.info(f"成功创建 {len(tools)} 个MoviePilot工具内置工具: {builtin_tools_count} 个,插件工具: {plugin_tools_count} 个)")
else:
logger.info(f"成功创建 {len(tools)} 个MoviePilot工具")
return tools

View File

View File

@@ -0,0 +1,92 @@
"""添加下载工具"""
from typing import Optional, Type
from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool, ToolChain
from app.chain.download import DownloadChain
from app.core.context import Context
from app.core.metainfo import MetaInfo
from app.db.site_oper import SiteOper
from app.log import logger
from app.schemas import TorrentInfo
class AddDownloadInput(BaseModel):
"""添加下载工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
site_name: str = Field(..., description="Name of the torrent site/source (e.g., 'The Pirate Bay')")
torrent_title: str = Field(...,
description="The display name/title of the torrent (e.g., 'The.Matrix.1999.1080p.BluRay.x264')")
torrent_url: str = Field(..., description="Direct URL to the torrent file (.torrent) or magnet link")
torrent_description: Optional[str] = Field(None,
description="Brief description of the torrent content (optional)")
downloader: Optional[str] = Field(None,
description="Name of the downloader to use (optional, uses default if not specified)")
save_path: Optional[str] = Field(None,
description="Directory path where the downloaded files should be saved (optional, uses default path if not specified)")
labels: Optional[str] = Field(None,
description="Comma-separated list of labels/tags to assign to the download (optional, e.g., 'movie,hd,bluray')")
class AddDownloadTool(MoviePilotTool):
name: str = "add_download"
description: str = "Add torrent download task to the configured downloader (qBittorrent, Transmission, etc.). Downloads the torrent file and starts the download process with specified settings."
args_schema: Type[BaseModel] = AddDownloadInput
async def run(self, site_name: str, torrent_title: str, torrent_url: str, torrent_description: Optional[str] = None,
downloader: Optional[str] = None, save_path: Optional[str] = None,
labels: Optional[str] = None, **kwargs) -> str:
logger.info(
f"执行工具: {self.name}, 参数: site_name={site_name}, torrent_title={torrent_title}, torrent_url={torrent_url}, downloader={downloader}, save_path={save_path}, labels={labels}")
try:
if not torrent_title or not torrent_url:
return "错误:必须提供种子标题和下载链接"
# 使用DownloadChain添加下载
download_chain = DownloadChain()
# 根据站点名称查询站点cookie
if not site_name:
return "错误:必须提供站点名称,请从搜索资源结果信息中获取"
siteinfo = await SiteOper().async_get_by_name(site_name)
if not siteinfo:
return f"错误:未找到站点信息:{site_name}"
# 创建下载上下文
torrent_info = TorrentInfo(
title=torrent_title,
description=torrent_description,
enclosure=torrent_url,
site_name=site_name,
site_ua=siteinfo.ua,
site_cookie=siteinfo.cookie,
site_proxy=siteinfo.proxy,
site_order=siteinfo.pri,
site_downloader=siteinfo.downloader
)
meta_info = MetaInfo(title=torrent_title, subtitle=torrent_description)
media_info = await ToolChain().async_recognize_media(meta=meta_info)
if not media_info:
return "错误:无法识别媒体信息,无法添加下载任务"
context = Context(
torrent_info=torrent_info,
meta_info=meta_info,
media_info=media_info
)
did = download_chain.download_single(
context=context,
downloader=downloader,
save_path=save_path,
label=labels
)
if did:
return f"成功添加下载任务:{torrent_title}"
else:
return "添加下载任务失败"
except Exception as e:
logger.error(f"添加下载任务失败: {e}", exc_info=True)
return f"添加下载任务时发生错误: {str(e)}"

View File

@@ -0,0 +1,60 @@
"""添加订阅工具"""
from typing import Optional, Type
from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.chain.subscribe import SubscribeChain
from app.log import logger
from app.schemas.types import MediaType
class AddSubscribeInput(BaseModel):
"""添加订阅工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
title: str = Field(..., description="The title of the media to subscribe to (e.g., 'The Matrix', 'Breaking Bad')")
year: str = Field(..., description="Release year of the media (required for accurate identification)")
media_type: str = Field(...,
description="Type of media content: '电影' for films, '电视剧' for television series or anime series")
season: Optional[int] = Field(None,
description="Season number for TV shows (optional, if not specified will subscribe to all seasons)")
tmdb_id: Optional[str] = Field(None,
description="TMDB database ID for precise media identification (optional but recommended for accuracy)")
class AddSubscribeTool(MoviePilotTool):
name: str = "add_subscribe"
description: str = "Add media subscription to create automated download rules for movies and TV shows. The system will automatically search and download new episodes or releases based on the subscription criteria."
args_schema: Type[BaseModel] = AddSubscribeInput
async def run(self, title: str, year: str, media_type: str,
season: Optional[int] = None, tmdb_id: Optional[str] = None, **kwargs) -> str:
logger.info(
f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}, tmdb_id={tmdb_id}")
try:
subscribe_chain = SubscribeChain()
# 转换 tmdb_id 为整数
tmdbid_int = None
if tmdb_id:
try:
tmdbid_int = int(tmdb_id)
except (ValueError, TypeError):
logger.warning(f"无效的 tmdb_id: {tmdb_id},将忽略")
sid, message = await subscribe_chain.async_add(
mtype=MediaType(media_type),
title=title,
year=year,
tmdbid=tmdbid_int,
season=season,
username=self._user_id
)
if sid:
return f"成功添加订阅:{title} ({year})"
else:
return f"添加订阅失败:{message}"
except Exception as e:
logger.error(f"添加订阅失败: {e}", exc_info=True)
return f"添加订阅时发生错误: {str(e)}"

View File

@@ -0,0 +1,84 @@
"""获取推荐工具"""
import json
from typing import Optional, Type
from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.chain.recommend import RecommendChain
from app.log import logger
class GetRecommendationsInput(BaseModel):
"""获取推荐工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
source: Optional[str] = Field("tmdb_trending",
description="Recommendation source: 'tmdb_trending' for TMDB trending content, 'douban_hot' for Douban popular content, 'bangumi_calendar' for Bangumi anime calendar")
media_type: Optional[str] = Field("all",
description="Type of media content: '电影' for films, '电视剧' for television series or anime series, 'all' for all types")
limit: Optional[int] = Field(20,
description="Maximum number of recommendations to return (default: 20, maximum: 100)")
class GetRecommendationsTool(MoviePilotTool):
name: str = "get_recommendations"
description: str = "Get trending and popular media recommendations from various sources. Returns curated lists of popular movies, TV shows, and anime based on different criteria like trending, ratings, or calendar schedules."
args_schema: Type[BaseModel] = GetRecommendationsInput
async def run(self, source: Optional[str] = "tmdb_trending",
media_type: Optional[str] = "all", limit: Optional[int] = 20, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: source={source}, media_type={media_type}, limit={limit}")
try:
name_dicts = {
"tmdb_trending": "TMDB 热门推荐",
"douban_hot": "豆瓣热门推荐",
"bangumi_calendar": "番组计划推荐"
}
recommend_chain = RecommendChain()
results = []
if source == "tmdb_trending":
results = await recommend_chain.async_tmdb_trending(limit=limit)
elif source == "douban_hot":
if media_type == "movie":
results = await recommend_chain.async_douban_movie_hot(limit=limit)
elif media_type == "tv":
results = await recommend_chain.async_douban_tv_hot(limit=limit)
else: # all
results.extend(await recommend_chain.async_douban_movie_hot(limit=limit))
results.extend(await recommend_chain.async_douban_tv_hot(limit=limit))
elif source == "bangumi_calendar":
results = await recommend_chain.async_bangumi_calendar(limit=limit)
if results:
# 限制最多20条结果
total_count = len(results)
limited_results = results[:20]
# 精简字段,只保留关键信息
simplified_results = []
for r in limited_results:
# r 已经是字典格式to_dict的结果
simplified = {
"title": r.get("title"),
"en_title": r.get("en_title"),
"year": r.get("year"),
"type": r.get("type"),
"season": r.get("season"),
"tmdb_id": r.get("tmdb_id"),
"imdb_id": r.get("imdb_id"),
"douban_id": r.get("douban_id"),
"overview": r.get("overview", "")[:200] + "..." if r.get("overview") and len(r.get("overview", "")) > 200 else r.get("overview"),
"vote_average": r.get("vote_average"),
"poster_path": r.get("poster_path"),
"detail_link": r.get("detail_link")
}
simplified_results.append(simplified)
result_json = json.dumps(simplified_results, ensure_ascii=False, indent=2)
# 如果结果被裁剪,添加提示信息
if total_count > 20:
return f"注意:推荐结果共找到 {total_count} 条,为节省上下文空间,仅显示前 20 条结果。\n\n{result_json}"
return result_json
return "未找到推荐内容。"
except Exception as e:
logger.error(f"获取推荐失败: {e}", exc_info=True)
return f"获取推荐时发生错误: {str(e)}"

View File

@@ -0,0 +1,34 @@
"""查询下载器工具"""
import json
from typing import Type
from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.db.systemconfig_oper import SystemConfigOper
from app.log import logger
from app.schemas.types import SystemConfigKey
class QueryDownloadersInput(BaseModel):
"""查询下载器工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
class QueryDownloadersTool(MoviePilotTool):
name: str = "query_downloaders"
description: str = "Query downloader configuration and list all available downloaders. Shows downloader status, connection details, and configuration settings."
args_schema: Type[BaseModel] = QueryDownloadersInput
async def run(self, **kwargs) -> str:
logger.info(f"执行工具: {self.name}")
try:
system_config_oper = SystemConfigOper()
downloaders_config = system_config_oper.get(SystemConfigKey.Downloaders)
if downloaders_config:
return json.dumps(downloaders_config, ensure_ascii=False, indent=2)
return "未配置下载器。"
except Exception as e:
logger.error(f"查询下载器失败: {e}")
return f"查询下载器时发生错误: {str(e)}"

View File

@@ -0,0 +1,80 @@
"""查询下载工具"""
import json
from typing import Optional, Type
from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.chain.download import DownloadChain
from app.log import logger
class QueryDownloadsInput(BaseModel):
"""查询下载工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
downloader: Optional[str] = Field(None,
description="Name of specific downloader to query (optional, if not provided queries all configured downloaders)")
status: Optional[str] = Field("all",
description="Filter downloads by status: 'downloading' for active downloads, 'completed' for finished downloads, 'paused' for paused downloads, 'all' for all downloads")
class QueryDownloadsTool(MoviePilotTool):
name: str = "query_downloads"
description: str = "Query download status and list all active download tasks. Shows download progress, completion status, and task details from configured downloaders."
args_schema: Type[BaseModel] = QueryDownloadsInput
async def run(self, downloader: Optional[str] = None,
status: Optional[str] = "all", **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: downloader={downloader}, status={status}")
try:
download_chain = DownloadChain()
# 使用 DownloadChain.downloading 方法获取正在下载的任务
downloads = download_chain.downloading(name=downloader)
filtered_downloads = []
for dl in downloads:
if downloader and dl.downloader != downloader:
continue
if status != "all" and dl.status != status:
continue
filtered_downloads.append(dl)
if filtered_downloads:
# 限制最多20条结果
total_count = len(filtered_downloads)
limited_downloads = filtered_downloads[:20]
# 精简字段,只保留关键信息
simplified_downloads = []
for d in limited_downloads:
simplified = {
"downloader": d.downloader,
"hash": d.hash,
"title": d.title,
"name": d.name,
"year": d.year,
"season_episode": d.season_episode,
"size": d.size,
"progress": d.progress,
"state": d.state,
"upspeed": d.upspeed,
"dlspeed": d.dlspeed,
"left_time": d.left_time
}
# 精简 media 字段
if d.media:
simplified["media"] = {
"tmdbid": d.media.get("tmdbid"),
"type": d.media.get("type"),
"title": d.media.get("title"),
"season": d.media.get("season"),
"episode": d.media.get("episode")
}
simplified_downloads.append(simplified)
result_json = json.dumps(simplified_downloads, ensure_ascii=False, indent=2)
# 如果结果被裁剪,添加提示信息
if total_count > 20:
return f"注意:查询结果共找到 {total_count} 条,为节省上下文空间,仅显示前 20 条结果。\n\n{result_json}"
return result_json
return "未找到相关下载任务"
except Exception as e:
logger.error(f"查询下载失败: {e}", exc_info=True)
return f"查询下载时发生错误: {str(e)}"

View File

@@ -0,0 +1,41 @@
"""查询媒体库工具"""
import json
from typing import Optional, List, Type
from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.db.mediaserver_oper import MediaServerOper
from app.log import logger
from app.schemas import MediaServerItem
class QueryMediaLibraryInput(BaseModel):
"""查询媒体库工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
media_type: Optional[str] = Field("all",
description="Type of media content: '电影' for films, '电视剧' for television series or anime series, 'all' for all types")
title: Optional[str] = Field(None,
description="Specific media title to check if it exists in the media library (optional, if provided checks for that specific media)")
year: Optional[str] = Field(None,
description="Release year of the media (optional, helps narrow down search results)")
class QueryMediaLibraryTool(MoviePilotTool):
name: str = "query_media_library"
description: str = "Check if a specific media resource already exists in the media library (Plex, Emby, Jellyfin). Use this tool to verify whether a movie or TV series has been successfully processed and added to the media server before performing operations like downloading or subscribing."
args_schema: Type[BaseModel] = QueryMediaLibraryInput
async def run(self, media_type: Optional[str] = "all",
title: Optional[str] = None, year: Optional[str] = None, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: media_type={media_type}, title={title}")
try:
media_server_oper = MediaServerOper()
filtered_medias: List[MediaServerItem] = await media_server_oper.async_exists(title=title, year=year, mtype=media_type)
if filtered_medias:
return json.dumps([m.to_dict() for m in filtered_medias])
return "媒体库中未找到相关媒体"
except Exception as e:
logger.error(f"查询媒体库失败: {e}", exc_info=True)
return f"查询媒体库时发生错误: {str(e)}"

View File

@@ -0,0 +1,66 @@
"""查询站点工具"""
import json
from typing import Optional, Type
from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.db.site_oper import SiteOper
from app.log import logger
class QuerySitesInput(BaseModel):
"""查询站点工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
status: Optional[str] = Field("all",
description="Filter sites by status: 'active' for enabled sites, 'inactive' for disabled sites, 'all' for all sites")
name: Optional[str] = Field(None,
description="Filter sites by name (partial match, optional)")
class QuerySitesTool(MoviePilotTool):
name: str = "query_sites"
description: str = "Query site status and list all configured sites. Shows site name, domain, status, priority, and basic configuration."
args_schema: Type[BaseModel] = QuerySitesInput
async def run(self, status: Optional[str] = "all", name: Optional[str] = None, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: status={status}, name={name}")
try:
site_oper = SiteOper()
# 获取所有站点(按优先级排序)
sites = await site_oper.async_list()
filtered_sites = []
for site in sites:
# 按状态过滤
if status == "active" and not site.is_active:
continue
if status == "inactive" and site.is_active:
continue
# 按名称过滤(部分匹配)
if name and name.lower() not in (site.name or "").lower():
continue
filtered_sites.append(site)
if filtered_sites:
# 精简字段,只保留关键信息
simplified_sites = []
for s in filtered_sites:
simplified = {
"id": s.id,
"name": s.name,
"domain": s.domain,
"url": s.url,
"pri": s.pri,
"is_active": s.is_active,
"downloader": s.downloader,
"proxy": s.proxy,
"timeout": s.timeout
}
simplified_sites.append(simplified)
result_json = json.dumps(simplified_sites, ensure_ascii=False, indent=2)
return result_json
return "未找到相关站点"
except Exception as e:
logger.error(f"查询站点失败: {e}", exc_info=True)
return f"查询站点时发生错误: {str(e)}"

View File

@@ -0,0 +1,73 @@
"""查询订阅工具"""
import json
from typing import Optional, Type
from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.db.subscribe_oper import SubscribeOper
from app.log import logger
class QuerySubscribesInput(BaseModel):
"""查询订阅工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
status: Optional[str] = Field("all",
description="Filter subscriptions by status: 'R' for enabled subscriptions, 'P' for disabled ones, 'all' for all subscriptions")
media_type: Optional[str] = Field("all",
description="Filter by media type: 'movie' for films, 'tv' for television series, 'all' for all types")
class QuerySubscribesTool(MoviePilotTool):
name: str = "query_subscribes"
description: str = "Query subscription status and list all user subscriptions. Shows active subscriptions, their download status, and configuration details."
args_schema: Type[BaseModel] = QuerySubscribesInput
async def run(self, status: Optional[str] = "all", media_type: Optional[str] = "all", **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: status={status}, media_type={media_type}")
try:
subscribe_oper = SubscribeOper()
subscribes = await subscribe_oper.async_list()
filtered_subscribes = []
for sub in subscribes:
if status != "all" and sub.state != status:
continue
if media_type != "all" and sub.type != media_type:
continue
filtered_subscribes.append(sub)
if filtered_subscribes:
# 限制最多20条结果
total_count = len(filtered_subscribes)
limited_subscribes = filtered_subscribes[:20]
# 精简字段,只保留关键信息
simplified_subscribes = []
for s in limited_subscribes:
simplified = {
"id": s.id,
"name": s.name,
"year": s.year,
"type": s.type,
"season": s.season,
"tmdbid": s.tmdbid,
"doubanid": s.doubanid,
"bangumiid": s.bangumiid,
"poster": s.poster,
"vote": s.vote,
"description": s.description[:200] + "..." if s.description and len(s.description) > 200 else s.description,
"state": s.state,
"total_episode": s.total_episode,
"lack_episode": s.lack_episode,
"last_update": s.last_update,
"username": s.username
}
simplified_subscribes.append(simplified)
result_json = json.dumps(simplified_subscribes, ensure_ascii=False, indent=2)
# 如果结果被裁剪,添加提示信息
if total_count > 20:
return f"注意:查询结果共找到 {total_count} 条,为节省上下文空间,仅显示前 20 条结果。\n\n{result_json}"
return result_json
return "未找到相关订阅"
except Exception as e:
logger.error(f"查询订阅失败: {e}", exc_info=True)
return f"查询订阅时发生错误: {str(e)}"

View File

@@ -0,0 +1,96 @@
"""搜索媒体工具"""
import json
from typing import Optional, Type
from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.chain.media import MediaChain
from app.log import logger
from app.schemas.types import MediaType
class SearchMediaInput(BaseModel):
"""搜索媒体工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
title: str = Field(..., description="The title of the media to search for (e.g., 'The Matrix', 'Breaking Bad')")
year: Optional[str] = Field(None, description="Release year of the media (optional, helps narrow down results)")
media_type: Optional[str] = Field(None,
description="Type of media content: '电影' for films, '电视剧' for television series or anime series")
season: Optional[int] = Field(None,
description="Season number for TV shows and anime (optional, only applicable for series)")
class SearchMediaTool(MoviePilotTool):
name: str = "search_media"
description: str = "Search for media resources including movies, TV shows, anime, etc. Supports searching by title, year, type, and other criteria. Returns detailed media information from TMDB database."
args_schema: Type[BaseModel] = SearchMediaInput
async def run(self, title: str, year: Optional[str] = None,
media_type: Optional[str] = None, season: Optional[int] = None, **kwargs) -> str:
logger.info(
f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}")
try:
media_chain = MediaChain()
# 构建搜索标题
search_title = title
if year:
search_title = f"{title} {year}"
if media_type:
search_title = f"{search_title} {media_type}"
if season:
search_title = f"{search_title} S{season:02d}"
# 使用 MediaChain.search 方法
meta, results = await media_chain.async_search(title=search_title)
# 过滤结果
if results:
filtered_results = []
for result in results:
if year and result.year != year:
continue
if media_type:
if result.type != MediaType(media_type):
continue
if season and result.season != season:
continue
filtered_results.append(result)
if filtered_results:
# 限制最多20条结果
total_count = len(filtered_results)
limited_results = filtered_results[:20]
# 精简字段,只保留关键信息
simplified_results = []
for r in limited_results:
simplified = {
"title": r.title,
"en_title": r.en_title,
"year": r.year,
"type": r.type.value if r.type else None,
"season": r.season,
"tmdb_id": r.tmdb_id,
"imdb_id": r.imdb_id,
"douban_id": r.douban_id,
"overview": r.overview[:200] + "..." if r.overview and len(r.overview) > 200 else r.overview,
"vote_average": r.vote_average,
"poster_path": r.poster_path,
"detail_link": r.detail_link
}
simplified_results.append(simplified)
result_json = json.dumps(simplified_results, ensure_ascii=False, indent=2)
# 如果结果被裁剪,添加提示信息
if total_count > 20:
return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 20 条结果。\n\n{result_json}"
return result_json
else:
return f"未找到符合条件的媒体资源: {title}"
else:
return f"未找到相关媒体资源: {title}"
except Exception as e:
error_message = f"搜索媒体失败: {str(e)}"
logger.error(f"搜索媒体失败: {e}", exc_info=True)
return error_message

View File

@@ -0,0 +1,122 @@
"""搜索种子工具"""
import json
import re
from typing import List, Optional, Type
from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.chain.search import SearchChain
from app.log import logger
from app.schemas.types import MediaType
class SearchTorrentsInput(BaseModel):
"""搜索种子工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
title: str = Field(...,
description="The title of the media resource to search for (e.g., 'The Matrix 1999', 'Breaking Bad S01E01')")
year: Optional[str] = Field(None,
description="Release year of the media (optional, helps narrow down search results)")
media_type: Optional[str] = Field(None,
description="Type of media content: '电影' for films, '电视剧' for television series or anime series")
season: Optional[int] = Field(None, description="Season number for TV shows (optional, only applicable for series)")
sites: Optional[List[int]] = Field(None,
description="Array of specific site IDs to search on (optional, if not provided searches all configured sites)")
filter_pattern: Optional[str] = Field(None,
description="Regular expression pattern to filter torrent titles by resolution, quality, or other keywords (e.g., '4K|2160p|UHD' for 4K content, '1080p|BluRay' for 1080p BluRay)")
class SearchTorrentsTool(MoviePilotTool):
name: str = "search_torrents"
description: str = "Search for torrent files across configured indexer sites based on media information. Returns available torrent downloads with details like file size, quality, and download links."
args_schema: Type[BaseModel] = SearchTorrentsInput
async def run(self, title: str, year: Optional[str] = None,
media_type: Optional[str] = None, season: Optional[int] = None,
sites: Optional[List[int]] = None, filter_pattern: Optional[str] = None, **kwargs) -> str:
logger.info(
f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}, sites={sites}, filter_pattern={filter_pattern}")
try:
search_chain = SearchChain()
torrents = await search_chain.async_search_by_title(title=title, sites=sites)
filtered_torrents = []
# 编译正则表达式(如果提供)
regex_pattern = None
if filter_pattern:
try:
regex_pattern = re.compile(filter_pattern, re.IGNORECASE)
except re.error as e:
logger.warning(f"正则表达式编译失败: {filter_pattern}, 错误: {e}")
return f"正则表达式格式错误: {str(e)}"
for torrent in torrents:
# torrent 是 Context 对象,需要通过 meta_info 和 media_info 访问属性
if year and torrent.meta_info and torrent.meta_info.year != year:
continue
if media_type and torrent.media_info:
if torrent.media_info.type != MediaType(media_type):
continue
if season and torrent.meta_info and torrent.meta_info.begin_season != season:
continue
# 使用正则表达式过滤标题(分辨率、质量等关键字)
if regex_pattern and torrent.torrent_info and torrent.torrent_info.title:
if not regex_pattern.search(torrent.torrent_info.title):
continue
filtered_torrents.append(torrent)
if filtered_torrents:
# 限制最多50条结果
total_count = len(filtered_torrents)
limited_torrents = filtered_torrents[:50]
# 精简字段,只保留关键信息
simplified_torrents = []
for t in limited_torrents:
simplified = {}
# 精简 torrent_info
if t.torrent_info:
simplified["torrent_info"] = {
"title": t.torrent_info.title,
"size": t.torrent_info.size,
"seeders": t.torrent_info.seeders,
"peers": t.torrent_info.peers,
"site_name": t.torrent_info.site_name,
"enclosure": t.torrent_info.enclosure,
"page_url": t.torrent_info.page_url,
"volume_factor": t.torrent_info.volume_factor,
"pubdate": t.torrent_info.pubdate
}
# 精简 media_info
if t.media_info:
simplified["media_info"] = {
"title": t.media_info.title,
"en_title": t.media_info.en_title,
"year": t.media_info.year,
"type": t.media_info.type.value if t.media_info.type else None,
"season": t.media_info.season,
"tmdb_id": t.media_info.tmdb_id
}
# 精简 meta_info
if t.meta_info:
simplified["meta_info"] = {
"name": t.meta_info.name,
"cn_name": t.meta_info.cn_name,
"en_name": t.meta_info.en_name,
"year": t.meta_info.year,
"type": t.meta_info.type.value if t.meta_info.type else None,
"begin_season": t.meta_info.begin_season
}
simplified_torrents.append(simplified)
result_json = json.dumps(simplified_torrents, ensure_ascii=False, indent=2)
# 如果结果被裁剪,添加提示信息
if total_count > 50:
return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 50 条结果。\n\n{result_json}"
return result_json
else:
return f"未找到相关种子资源: {title}"
except Exception as e:
error_message = f"搜索种子时发生错误: {str(e)}"
logger.error(f"搜索种子失败: {e}", exc_info=True)
return error_message

View File

@@ -0,0 +1,31 @@
"""发送消息工具"""
from typing import Optional, Type
from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.log import logger
class SendMessageInput(BaseModel):
"""发送消息工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
message: str = Field(..., description="The message content to send to the user (should be clear and informative)")
message_type: Optional[str] = Field("info",
description="Type of message: 'info' for general information, 'success' for successful operations, 'warning' for warnings, 'error' for error messages")
class SendMessageTool(MoviePilotTool):
name: str = "send_message"
description: str = "Send notification message to the user through configured notification channels (Telegram, Slack, WeChat, etc.). Used to inform users about operation results, errors, or important updates."
args_schema: Type[BaseModel] = SendMessageInput
async def run(self, message: str, message_type: Optional[str] = None, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: message={message}, message_type={message_type}")
try:
await self.send_tool_message(message, title=message_type)
return "消息已发送"
except Exception as e:
logger.error(f"发送消息失败: {e}")
return f"发送消息时发生错误: {str(e)}"

View File

@@ -137,7 +137,7 @@ async def transfer(days: Optional[int] = 7,
return [stat[1] for stat in transfer_stat]
@router.get("/cpu", summary="获取当前CPU使用率", response_model=int)
@router.get("/cpu", summary="获取当前CPU使用率", response_model=float)
def cpu(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
获取当前CPU使用率
@@ -145,7 +145,7 @@ def cpu(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
return SystemUtils.cpu_usage()
@router.get("/cpu2", summary="获取当前CPU使用率API_TOKEN", response_model=int)
@router.get("/cpu2", summary="获取当前CPU使用率API_TOKEN", response_model=float)
def cpu2(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
"""
获取当前CPU使用率 API_TOKEN认证?token=xxx

View File

@@ -40,10 +40,10 @@ def download(
metainfo = MetaInfo(title=torrent_in.title, subtitle=torrent_in.description)
# 媒体信息
mediainfo = MediaInfo()
mediainfo.from_dict(media_in.dict())
mediainfo.from_dict(media_in.model_dump())
# 种子信息
torrentinfo = TorrentInfo()
torrentinfo.from_dict(torrent_in.dict())
torrentinfo.from_dict(torrent_in.model_dump())
# 手动下载始终使用选择的下载器
torrentinfo.site_downloader = downloader
# 上下文
@@ -64,6 +64,9 @@ def download(
@router.post("/add", summary="添加下载(不含媒体信息)", response_model=schemas.Response)
def add(
torrent_in: schemas.TorrentInfo,
tmdbid: Annotated[int | None, Body()] = None,
doubanid: Annotated[str | None, Body()] = None,
bangumiid: Annotated[int | None, Body()] = None,
downloader: Annotated[str | None, Body()] = None,
save_path: Annotated[str | None, Body()] = None,
current_user: User = Depends(get_current_active_user)) -> Any:
@@ -73,12 +76,12 @@ def add(
# 元数据
metainfo = MetaInfo(title=torrent_in.title, subtitle=torrent_in.description)
# 媒体信息
mediainfo = MediaChain().recognize_media(meta=metainfo)
mediainfo = MediaChain().recognize_media(meta=metainfo, tmdbid=tmdbid, doubanid=doubanid, bangumiid=bangumiid)
if not mediainfo:
return schemas.Response(success=False, message="无法识别媒体信息")
# 种子信息
torrentinfo = TorrentInfo()
torrentinfo.from_dict(torrent_in.dict())
torrentinfo.from_dict(torrent_in.model_dump())
# 上下文
context = Context(
meta_info=metainfo,

View File

@@ -14,7 +14,7 @@ from app.db.models import User
from app.db.models.downloadhistory import DownloadHistory
from app.db.models.transferhistory import TransferHistory
from app.db.user_oper import get_current_active_superuser_async, get_current_active_superuser
from app.schemas.types import EventType, MediaType
from app.schemas.types import EventType
router = APIRouter()
@@ -70,7 +70,7 @@ async def transfer_history(title: Optional[str] = None,
return schemas.Response(success=True,
data={
"list": result,
"list": [item.to_dict() for item in result],
"total": total,
})

View File

@@ -8,8 +8,10 @@ from app import schemas
from app.chain.user import UserChain
from app.core import security
from app.core.config import settings
from app.db.systemconfig_oper import SystemConfigOper
from app.helper.sites import SitesHelper # noqa
from app.helper.wallpaper import WallpaperHelper
from app.schemas.types import SystemConfigKey
router = APIRouter()
@@ -29,7 +31,10 @@ def login_access_token(
if not success:
raise HTTPException(status_code=401, detail=user_or_message)
# 用户等级
level = SitesHelper().auth_level
# 是否显示配置向导
show_wizard = not SystemConfigOper().get(SystemConfigKey.SetupWizardState) and not settings.ADVANCED_MODE
return schemas.Token(
access_token=security.create_access_token(
userid=user_or_message.id,
@@ -45,6 +50,7 @@ def login_access_token(
avatar=user_or_message.avatar,
level=level,
permissions=user_or_message.permissions or {},
widzard=show_wizard
)

View File

@@ -79,7 +79,7 @@ def exists(media_in: schemas.MediaInfo,
"""
# 转化为媒体信息对象
mediainfo = MediaInfo()
mediainfo.from_dict(media_in.dict())
mediainfo.from_dict(media_in.model_dump())
existsinfo: schemas.ExistMediaInfo = MediaServerChain().media_exists(mediainfo=mediainfo)
if not existsinfo:
return []
@@ -108,7 +108,7 @@ def not_exists(media_in: schemas.MediaInfo,
meta.year = media_in.year
# 转化为媒体信息对象
mediainfo = MediaInfo()
mediainfo.from_dict(media_in.dict())
mediainfo.from_dict(media_in.model_dump())
exist_flag, no_exists = DownloadChain().get_no_exists_info(meta=meta, mediainfo=mediainfo)
mediakey = mediainfo.tmdb_id or mediainfo.douban_id
if mediainfo.type == MediaType.MOVIE:

View File

@@ -132,7 +132,7 @@ async def subscribe(subscription: schemas.Subscription, _: schemas.TokenPayload
"""
客户端webpush通知订阅
"""
subinfo = subscription.dict()
subinfo = subscription.model_dump()
if subinfo not in global_vars.get_subscriptions():
global_vars.push_subscription(subinfo)
logger.debug(f"通知订阅成功: {subinfo}")
@@ -148,7 +148,7 @@ def send_notification(payload: schemas.SubscriptionMessage, _: schemas.TokenPayl
try:
webpush(
subscription_info=sub,
data=json.dumps(payload.dict()),
data=json.dumps(payload.model_dump()),
vapid_private_key=settings.VAPID.get("privateKey"),
vapid_claims={
"sub": settings.VAPID.get("subject")

View File

@@ -13,7 +13,7 @@ from app import schemas
from app.command import Command
from app.core.config import settings
from app.core.plugin import PluginManager
from app.core.security import verify_apikey, verify_token, verify_apitoken
from app.core.security import verify_apikey, verify_token
from app.db.models import User
from app.db.systemconfig_oper import SystemConfigOper
from app.db.user_oper import get_current_active_superuser, get_current_active_superuser_async
@@ -21,7 +21,6 @@ from app.factory import app
from app.helper.plugin import PluginHelper
from app.log import logger
from app.scheduler import Scheduler
from app.schemas.plugin import PluginMemoryInfo
from app.schemas.types import SystemConfigKey
PROTECTED_ROUTES = {"/api/v1/openapi.json", "/docs", "/docs/oauth2-redirect", "/redoc"}
@@ -494,57 +493,6 @@ def clone_plugin(plugin_id: str,
return schemas.Response(success=False, message=f"创建插件分身失败:{str(e)}")
@router.get("/memory", summary="插件内存使用统计", response_model=List[PluginMemoryInfo])
def plugin_memory_stats(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
"""
获取所有插件的内存使用统计信息
"""
try:
plugin_manager = PluginManager()
memory_stats = plugin_manager.get_plugin_memory_stats()
return memory_stats
except Exception as e:
logger.error(f"获取插件内存统计失败:{str(e)}")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"获取插件内存统计失败:{str(e)}")
@router.get("/memory/{plugin_id}", summary="单个插件内存使用统计", response_model=PluginMemoryInfo)
def plugin_memory_stat(plugin_id: str, _: Annotated[str, Depends(verify_apitoken)]) -> Any:
"""
获取指定插件的内存使用统计信息
"""
try:
plugin_manager = PluginManager()
memory_stats = plugin_manager.get_plugin_memory_stats(plugin_id)
if not memory_stats:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND,
detail=f"插件 {plugin_id} 不存在或未运行")
return memory_stats[0]
except HTTPException:
raise
except Exception as e:
logger.error(f"获取插件 {plugin_id} 内存统计失败:{str(e)}")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"获取插件内存统计失败:{str(e)}")
@router.delete("/memory/cache", summary="清除插件内存统计缓存")
def clear_plugin_memory_cache(_: Annotated[str, Depends(verify_apitoken)],
plugin_id: Optional[str] = None) -> Any:
"""
清除插件内存统计缓存
"""
try:
plugin_manager = PluginManager()
plugin_manager.clear_plugin_memory_cache(plugin_id)
message = f"已清除插件 {plugin_id} 的内存统计缓存" if plugin_id else "已清除所有插件的内存统计缓存"
return schemas.Response(success=True, message=message)
except Exception as e:
logger.error(f"清除插件内存统计缓存失败:{str(e)}")
return schemas.Response(success=False, message=f"清除缓存失败:{str(e)}")
@router.get("/{plugin_id}", summary="获取插件配置")
async def plugin_config(plugin_id: str,
_: User = Depends(get_current_active_superuser_async)) -> dict:

View File

@@ -20,7 +20,7 @@ async def search_latest(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询搜索结果
"""
torrents = await SearchChain().async_last_search_results()
torrents = await SearchChain().async_last_search_results() or []
return [torrent.to_dict() for torrent in torrents]

View File

@@ -67,7 +67,7 @@ async def add_site(
site_in.name = site_info.get("name")
site_in.id = None
site_in.public = 1 if site_info.get("public") else 0
site = Site(**site_in.dict())
site = Site(**site_in.model_dump())
site.create(db)
# 通知站点更新
await eventmanager.async_send_event(EventType.SiteUpdated, {
@@ -92,7 +92,7 @@ async def update_site(
# 校正地址格式
_scheme, _netloc = StringUtils.get_url_netloc(site_in.url)
site_in.url = f"{_scheme}://{_netloc}/"
await site.async_update(db, site_in.dict())
await site.async_update(db, site_in.model_dump())
# 通知站点更新
await eventmanager.async_send_event(EventType.SiteUpdated, {
"domain": site_in.domain
@@ -399,7 +399,7 @@ def auth_site(
if not auth_info or not auth_info.site or not auth_info.params:
return schemas.Response(success=False, message="请输入认证站点和认证参数")
status, msg = SitesHelper().check_user(auth_info.site, auth_info.params)
SystemConfigOper().set(SystemConfigKey.UserSiteAuthParams, auth_info.dict())
SystemConfigOper().set(SystemConfigKey.UserSiteAuthParams, auth_info.model_dump())
# 认证成功后,重新初始化插件
PluginManager().init_config()
Scheduler().init_plugin_jobs()

View File

@@ -79,7 +79,7 @@ async def create_subscribe(
# 订阅用户
subscribe_in.username = current_user.name
# 转化为字典
subscribe_dict = subscribe_in.dict()
subscribe_dict = subscribe_in.model_dump()
if subscribe_in.id:
subscribe_dict.pop("id", None)
sid, message = await SubscribeChain().async_add(mtype=mtype,
@@ -106,7 +106,7 @@ async def update_subscribe(
return schemas.Response(success=False, message="订阅不存在")
# 避免更新缺失集数
old_subscribe_dict = subscribe.to_dict()
subscribe_dict = subscribe_in.dict()
subscribe_dict = subscribe_in.model_dump()
if not subscribe_in.lack_episode:
# 没有缺失集数时缺失集数清空避免更新为0
subscribe_dict.pop("lack_episode")
@@ -421,11 +421,23 @@ async def popular_subscribes(
page: Optional[int] = 1,
count: Optional[int] = 30,
min_sub: Optional[int] = None,
genre_id: Optional[int] = None,
min_rating: Optional[float] = None,
max_rating: Optional[float] = None,
sort_type: Optional[str] = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询热门订阅
"""
subscribes = await SubscribeHelper().async_get_statistic(stype=stype, page=page, count=count)
subscribes = await SubscribeHelper().async_get_statistic(
stype=stype,
page=page,
count=count,
genre_id=genre_id,
min_rating=min_rating,
max_rating=max_rating,
sort_type=sort_type
)
if subscribes:
ret_medias = []
for sub in subscribes:
@@ -517,7 +529,7 @@ async def subscribe_fork(
"""
复用订阅
"""
sub_dict = sub.dict()
sub_dict = sub.model_dump()
sub_dict.pop("id")
for key in list(sub_dict.keys()):
if not hasattr(schemas.Subscribe(), key):
@@ -570,11 +582,23 @@ async def popular_subscribes(
name: Optional[str] = None,
page: Optional[int] = 1,
count: Optional[int] = 30,
genre_id: Optional[int] = None,
min_rating: Optional[float] = None,
max_rating: Optional[float] = None,
sort_type: Optional[str] = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询分享的订阅
"""
return await SubscribeHelper().async_get_shares(name=name, page=page, count=count)
return await SubscribeHelper().async_get_shares(
name=name,
page=page,
count=count,
genre_id=genre_id,
min_rating=min_rating,
max_rating=max_rating,
sort_type=sort_type
)
@router.get("/share/statistics", summary="查询订阅分享统计", response_model=List[schemas.SubscribeShareStatistics])

View File

@@ -11,10 +11,12 @@ import aiofiles
import pillow_avif # noqa 用于自动注册AVIF支持
from PIL import Image
from anyio import Path as AsyncPath
from app.helper.sites import SitesHelper # noqa # noqa
from fastapi import APIRouter, Body, Depends, HTTPException, Header, Request, Response
from fastapi.responses import StreamingResponse
from app import schemas
from app.chain.mediaserver import MediaServerChain
from app.chain.search import SearchChain
from app.chain.system import SystemChain
from app.core.cache import AsyncFileCache
@@ -31,7 +33,6 @@ from app.helper.mediaserver import MediaServerHelper
from app.helper.message import MessageHelper
from app.helper.progress import ProgressHelper
from app.helper.rule import RuleHelper
from app.helper.sites import SitesHelper # noqa # noqa
from app.helper.subscribe import SubscribeHelper
from app.helper.system import SystemHelper
from app.log import logger
@@ -52,19 +53,21 @@ async def fetch_image(
proxy: bool = False,
use_cache: bool = False,
if_none_match: Optional[str] = None,
allowed_domains: Optional[set[str]] = None) -> Response:
cookies: Optional[str | dict] = None,
allowed_domains: Optional[set[str]] = None) -> Optional[Response]:
"""
处理图片缓存逻辑支持HTTP缓存和磁盘缓存
"""
if not url:
raise HTTPException(status_code=404, detail="URL not provided")
return None
if allowed_domains is None:
allowed_domains = set(settings.SECURITY_IMAGE_DOMAINS)
# 验证URL安全性
if not SecurityUtils.is_safe_url(url, allowed_domains):
raise HTTPException(status_code=404, detail="Unsafe URL")
logger.warn(f"Blocked unsafe image URL: {url}")
return None
# 缓存路径
sanitized_path = SecurityUtils.sanitize_url_path(url)
@@ -95,18 +98,24 @@ async def fetch_image(
# 请求远程图片
referer = "https://movie.douban.com/" if "doubanio.com" in url else None
proxies = settings.PROXY if proxy else None
response = await AsyncRequestUtils(ua=settings.NORMAL_USER_AGENT, proxies=proxies, referer=referer,
accept_type="image/avif,image/webp,image/apng,*/*").get_res(url=url)
response = await AsyncRequestUtils(
ua=settings.NORMAL_USER_AGENT,
proxies=proxies,
referer=referer,
cookies=cookies,
accept_type="image/avif,image/webp,image/apng,*/*",
).get_res(url=url)
if not response:
raise HTTPException(status_code=502, detail="Failed to fetch the image from the remote server")
logger.warn(f"Failed to fetch image from URL: {url}")
return None
# 验证下载的内容是否为有效图片
try:
content = response.content
Image.open(io.BytesIO(content)).verify()
except Exception as e:
logger.debug(f"Invalid image format for URL {url}: {e}")
raise HTTPException(status_code=502, detail="Invalid image format")
logger.warn(f"Invalid image format for URL {url}: {e}")
return None
# 获取请求响应头
response_headers = response.headers
@@ -138,6 +147,7 @@ async def proxy_img(
imgurl: str,
proxy: bool = False,
cache: bool = False,
use_cookies: bool = False,
if_none_match: Annotated[str | None, Header()] = None,
_: schemas.TokenPayload = Depends(verify_resource_token)
) -> Response:
@@ -148,7 +158,12 @@ async def proxy_img(
hosts = [config.config.get("host") for config in MediaServerHelper().get_configs().values() if
config and config.config and config.config.get("host")]
allowed_domains = set(settings.SECURITY_IMAGE_DOMAINS) | set(hosts)
return await fetch_image(url=imgurl, proxy=proxy, use_cache=cache,
cookies = (
MediaServerChain().get_image_cookies(server=None, image_url=imgurl)
if use_cookies
else None
)
return await fetch_image(url=imgurl, proxy=proxy, use_cache=cache, cookies=cookies,
if_none_match=if_none_match, allowed_domains=allowed_domains)
@@ -176,7 +191,7 @@ def get_global_setting(token: str):
raise HTTPException(status_code=403, detail="Forbidden")
# FIXME: 新增敏感配置项时要在此处添加排除项
info = settings.dict(
info = settings.model_dump(
exclude={"SECRET_KEY", "RESOURCE_SECRET_KEY", "API_TOKEN", "TMDB_API_KEY", "TVDB_API_KEY", "FANART_API_KEY",
"COOKIECLOUD_KEY", "COOKIECLOUD_PASSWORD", "GITHUB_TOKEN", "REPO_GITHUB_TOKEN", "U115_APP_ID",
"ALIPAN_APP_ID", "TVDB_V4_API_KEY", "TVDB_V4_API_PIN"}
@@ -197,7 +212,7 @@ async def get_env_setting(_: User = Depends(get_current_active_user_async)):
"""
查询系统环境变量,包括当前版本号(仅管理员)
"""
info = settings.dict(
info = settings.model_dump(
exclude={"SECRET_KEY", "RESOURCE_SECRET_KEY"}
)
info.update({

View File

@@ -41,7 +41,7 @@ async def create_user(
user = await current_user.async_get_by_name(db, name=user_in.name)
if user:
return schemas.Response(success=False, message="用户已存在")
user_info = user_in.dict()
user_info = user_in.model_dump()
if user_info.get("password"):
user_info["hashed_password"] = get_password_hash(user_info["password"])
user_info.pop("password")
@@ -59,7 +59,7 @@ async def update_user(
"""
更新用户
"""
user_info = user_in.dict()
user_info = user_in.model_dump()
if user_info.get("password"):
# 正则表达式匹配密码包含字母、数字、特殊字符中的至少两项
pattern = r'^(?![a-zA-Z]+$)(?!\d+$)(?![^\da-zA-Z\s]+$).{6,50}$'

View File

@@ -11,7 +11,7 @@ from app.chain.workflow import WorkflowChain
from app.core.config import global_vars
from app.core.plugin import PluginManager
from app.core.security import verify_token
from app.core.workflow import WorkFlowManager
from app.workflow import WorkFlowManager
from app.db import get_async_db, get_db
from app.db.models import Workflow
from app.db.systemconfig_oper import SystemConfigOper
@@ -47,7 +47,7 @@ async def create_workflow(workflow: schemas.Workflow,
workflow.state = "P"
if not workflow.trigger_type:
workflow.trigger_type = "timer"
workflow_obj = Workflow(**workflow.dict())
workflow_obj = Workflow(**workflow.model_dump())
await workflow_obj.async_create(db)
return schemas.Response(success=True, message="创建工作流成功")
@@ -277,7 +277,7 @@ def update_workflow(workflow: schemas.Workflow,
return schemas.Response(success=False, message="工作流不存在")
if not wf.trigger_type:
workflow.trigger_type = "timer"
wf.update(db, workflow.dict())
wf.update(db, workflow.model_dump())
# 更新后的工作流对象
updated_workflow = workflow_oper.get(workflow.id)
# 更新定时任务

View File

@@ -11,7 +11,7 @@ from fastapi.concurrency import run_in_threadpool
from qbittorrentapi import TorrentFilesList
from transmission_rpc import File
from app.core.cache import FileCache, AsyncFileCache
from app.core.cache import FileCache, AsyncFileCache, fresh, async_fresh
from app.core.config import settings
from app.core.context import Context, MediaInfo, TorrentInfo
from app.core.event import EventManager
@@ -358,9 +358,10 @@ class ChainBase(metaclass=ABCMeta):
if tmdbid:
doubanid = None
bangumiid = None
return self.run_module("recognize_media", meta=meta, mtype=mtype,
tmdbid=tmdbid, doubanid=doubanid, bangumiid=bangumiid,
episode_group=episode_group, cache=cache)
with fresh(not cache):
return self.run_module("recognize_media", meta=meta, mtype=mtype,
tmdbid=tmdbid, doubanid=doubanid, bangumiid=bangumiid,
episode_group=episode_group, cache=cache)
async def async_recognize_media(self, meta: MetaBase = None,
mtype: Optional[MediaType] = None,
@@ -391,9 +392,10 @@ class ChainBase(metaclass=ABCMeta):
if tmdbid:
doubanid = None
bangumiid = None
return await self.async_run_module("async_recognize_media", meta=meta, mtype=mtype,
tmdbid=tmdbid, doubanid=doubanid, bangumiid=bangumiid,
episode_group=episode_group, cache=cache)
async with async_fresh(not cache):
return await self.async_run_module("async_recognize_media", meta=meta, mtype=mtype,
tmdbid=tmdbid, doubanid=doubanid, bangumiid=bangumiid,
episode_group=episode_group, cache=cache)
def match_doubaninfo(self, name: str, imdbid: Optional[str] = None,
mtype: Optional[MediaType] = None, year: Optional[str] = None, season: Optional[int] = None,
@@ -850,9 +852,13 @@ class ChainBase(metaclass=ABCMeta):
# 渲染消息
message = MessageTemplateHelper.render(message=message, meta=meta, mediainfo=mediainfo,
torrentinfo=torrentinfo, transferinfo=transferinfo, **kwargs)
# 检查消息是否有效
if not message:
logger.warning("消息为空,跳过发送")
return
# 保存消息
self.messagehelper.put(message, role="user", title=message.title)
self.messageoper.add(**message.dict())
self.messageoper.add(**message.model_dump())
# 发送消息按设置隔离
if not message.userid and message.mtype:
# 消息隔离设置
@@ -899,15 +905,15 @@ class ChainBase(metaclass=ABCMeta):
break
# 按设定发送
self.eventmanager.send_event(etype=EventType.NoticeMessage,
data={**send_message.dict(), "type": send_message.mtype})
self.messagequeue.send_message("post_message", message=send_message)
data={**send_message.model_dump(), "type": send_message.mtype})
self.messagequeue.send_message("post_message", message=send_message, **kwargs)
if not send_orignal:
return
# 发送消息事件
self.eventmanager.send_event(etype=EventType.NoticeMessage, data={**message.dict(), "type": message.mtype})
self.eventmanager.send_event(etype=EventType.NoticeMessage, data={**message.model_dump(), "type": message.mtype})
# 按原消息发送
self.messagequeue.send_message("post_message", message=message,
immediately=True if message.userid else False)
immediately=True if message.userid else False, **kwargs)
async def async_post_message(self,
message: Optional[Notification] = None,
@@ -929,9 +935,13 @@ class ChainBase(metaclass=ABCMeta):
# 渲染消息
message = MessageTemplateHelper.render(message=message, meta=meta, mediainfo=mediainfo,
torrentinfo=torrentinfo, transferinfo=transferinfo, **kwargs)
# 检查消息是否有效
if not message:
logger.warning("消息为空,跳过发送")
return
# 保存消息
self.messagehelper.put(message, role="user", title=message.title)
await self.messageoper.async_add(**message.dict())
await self.messageoper.async_add(**message.model_dump())
# 发送消息按设置隔离
if not message.userid and message.mtype:
# 消息隔离设置
@@ -978,16 +988,16 @@ class ChainBase(metaclass=ABCMeta):
break
# 按设定发送
await self.eventmanager.async_send_event(etype=EventType.NoticeMessage,
data={**send_message.dict(), "type": send_message.mtype})
await self.messagequeue.async_send_message("post_message", message=send_message)
data={**send_message.model_dump(), "type": send_message.mtype})
await self.messagequeue.async_send_message("post_message", message=send_message, **kwargs)
if not send_orignal:
return
# 发送消息事件
await self.eventmanager.async_send_event(etype=EventType.NoticeMessage,
data={**message.dict(), "type": message.mtype})
data={**message.model_dump(), "type": message.mtype})
# 按原消息发送
await self.messagequeue.async_send_message("post_message", message=message,
immediately=True if message.userid else False)
immediately=True if message.userid else False, **kwargs)
def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None:
"""
@@ -998,7 +1008,7 @@ class ChainBase(metaclass=ABCMeta):
"""
note_list = [media.to_dict() for media in medias]
self.messagehelper.put(message, role="user", note=note_list, title=message.title)
self.messageoper.add(**message.dict(), note=note_list)
self.messageoper.add(**message.model_dump(), note=note_list)
return self.messagequeue.send_message("post_medias_message", message=message, medias=medias,
immediately=True if message.userid else False)
@@ -1011,7 +1021,7 @@ class ChainBase(metaclass=ABCMeta):
"""
note_list = [torrent.torrent_info.to_dict() for torrent in torrents]
self.messagehelper.put(message, role="user", note=note_list, title=message.title)
self.messageoper.add(**message.dict(), note=note_list)
self.messageoper.add(**message.model_dump(), note=note_list)
return self.messagequeue.send_message("post_torrents_message", message=message, torrents=torrents,
immediately=True if message.userid else False)

View File

@@ -290,7 +290,7 @@ class DownloadChain(ChainBase):
# 登记下载记录
downloadhis = DownloadHistoryOper()
downloadhis.add(
path=str(download_path),
path=download_path.as_posix(),
type=_media.type.value,
title=_media.title,
year=_media.year,
@@ -331,8 +331,8 @@ class DownloadChain(ChainBase):
files_to_add.append({
"download_hash": _hash,
"downloader": _downloader,
"fullpath": str(_save_path / file),
"savepath": str(_save_path),
"fullpath": (_save_path / file).as_posix(),
"savepath": _save_path.as_posix(),
"filepath": file,
"torrentname": _meta.org_string,
})
@@ -994,7 +994,7 @@ class DownloadChain(ChainBase):
# 发出下载任务删除事件,如需处理辅种,可监听该事件
self.eventmanager.send_event(EventType.DownloadDeleted, {
"hash": hash_str,
"torrents": [torrent.dict() for torrent in torrents]
"torrents": [torrent.model_dump() for torrent in torrents]
})
else:
logger.info(f"没有在下载器中查询到 {hash_str} 对应的下载任务")

View File

@@ -1,4 +1,6 @@
import os
from pathlib import Path
from tempfile import NamedTemporaryFile
from threading import Lock
from typing import Optional, List, Tuple, Union
@@ -20,6 +22,9 @@ from app.utils.string import StringUtils
recognize_lock = Lock()
scraping_lock = Lock()
current_umask = os.umask(0)
os.umask(current_umask)
class MediaChain(ChainBase):
"""
@@ -456,36 +461,65 @@ class MediaChain(ChainBase):
"""
if not _fileitem or not _content or not _path:
return
# 保存文件到临时目录
tmp_dir = settings.TEMP_PATH / StringUtils.generate_random_str(10)
tmp_dir.mkdir(parents=True, exist_ok=True)
tmp_file = tmp_dir / _path.name
tmp_file.write_bytes(_content)
# 获取文件的父目录
try:
item = storagechain.upload_file(fileitem=_fileitem, path=tmp_file, new_name=_path.name)
# 使用tempfile创建临时文件自动删除
with NamedTemporaryFile(delete=True, delete_on_close=False, suffix=_path.suffix) as tmp_file:
tmp_file_path = Path(tmp_file.name)
# 写入内容
if isinstance(_content, bytes):
tmp_file.write(_content)
else:
tmp_file.write(_content.encode('utf-8'))
tmp_file.flush()
tmp_file.close() # 关闭文件句柄
# 刮削文件只需要读写权限
tmp_file_path.chmod(0o666 & ~current_umask)
# 上传文件
item = storagechain.upload_file(fileitem=_fileitem, path=tmp_file_path, new_name=_path.name)
if item:
logger.info(f"已保存文件:{item.path}")
else:
logger.warn(f"文件保存失败:{_path}")
finally:
if tmp_file.exists():
tmp_file.unlink()
def __download_image(_url: str) -> Optional[bytes]:
def __download_and_save_image(_fileitem: schemas.FileItem, _path: Path, _url: str):
"""
下载图片并保存
流式下载图片并直接保存到文件(减少内存占用)
:param _fileitem: 关联的媒体文件项
:param _path: 图片文件路径
:param _url: 图片下载URL
"""
if not _fileitem or not _url or not _path:
return
try:
logger.info(f"正在下载图片:{_url} ...")
r = RequestUtils(proxies=settings.PROXY, ua=settings.NORMAL_USER_AGENT).get_res(url=_url)
if r:
return r.content
else:
logger.info(f"{_url} 图片下载失败,请检查网络连通性!")
request_utils = RequestUtils(proxies=settings.PROXY, ua=settings.NORMAL_USER_AGENT)
with request_utils.get_stream(url=_url) as r:
if r and r.status_code == 200:
# 使用tempfile创建临时文件自动删除
with NamedTemporaryFile(delete=True, delete_on_close=False, suffix=_path.suffix) as tmp_file:
tmp_file_path = Path(tmp_file.name)
# 流式写入文件
for chunk in r.iter_content(chunk_size=8192):
if chunk:
tmp_file.write(chunk)
tmp_file.flush()
tmp_file.close() # 关闭文件句柄
# 刮削的图片只需要读写权限
tmp_file_path.chmod(0o666 & ~current_umask)
# 上传文件
item = storagechain.upload_file(fileitem=_fileitem, path=tmp_file_path,
new_name=_path.name)
if item:
logger.info(f"已保存图片:{item.path}")
else:
logger.warn(f"图片保存失败:{_path}")
else:
logger.info(f"{_url} 图片下载失败")
except Exception as err:
logger.error(f"{_url} 图片下载失败:{str(err)}")
return None
if not fileitem:
return
@@ -587,11 +621,8 @@ class MediaChain(ChainBase):
image_path = filepath.with_name(image_name)
if overwrite or not storagechain.get_file_item(storage=fileitem.storage,
path=image_path):
# 下载图片
content = __download_image(image_url)
# 写入图片到当前目录
if content:
__save_file(_fileitem=fileitem, _path=image_path, _content=content)
# 流式下载图片并直接保存
__download_and_save_image(_fileitem=fileitem, _path=image_path, _url=image_url)
else:
logger.info(f"已存在图片文件:{image_path}")
else:
@@ -637,13 +668,10 @@ class MediaChain(ChainBase):
for episode, image_url in image_dict.items():
image_path = filepath.with_suffix(Path(image_url).suffix)
if overwrite or not storagechain.get_file_item(storage=fileitem.storage, path=image_path):
# 下载图片
content = __download_image(image_url)
# 保存图片文件到当前目录
if content:
if not parent:
parent = storagechain.get_parent_item(fileitem)
__save_file(_fileitem=parent, _path=image_path, _content=content)
# 流式下载图片并直接保存
if not parent:
parent = storagechain.get_parent_item(fileitem)
__download_and_save_image(_fileitem=parent, _path=image_path, _url=image_url)
else:
logger.info(f"已存在图片文件:{image_path}")
else:
@@ -694,13 +722,10 @@ class MediaChain(ChainBase):
image_path = filepath.with_name(image_name)
if overwrite or not storagechain.get_file_item(storage=fileitem.storage,
path=image_path):
# 下载图片
content = __download_image(image_url)
# 保存图片文件到剧集目录
if content:
if not parent:
parent = storagechain.get_parent_item(fileitem)
__save_file(_fileitem=parent, _path=image_path, _content=content)
# 流式下载图片并直接保存
if not parent:
parent = storagechain.get_parent_item(fileitem)
__download_and_save_image(_fileitem=parent, _path=image_path, _url=image_url)
else:
logger.info(f"已存在图片文件:{image_path}")
else:
@@ -730,13 +755,11 @@ class MediaChain(ChainBase):
continue
if overwrite or not storagechain.get_file_item(storage=fileitem.storage,
path=image_path):
# 下载图片
content = __download_image(image_url)
# 保存图片文件到当前目录
if content:
if not parent:
parent = storagechain.get_parent_item(fileitem)
__save_file(_fileitem=parent, _path=image_path, _content=content)
# 流式下载图片并直接保存
if not parent:
parent = storagechain.get_parent_item(fileitem)
__download_and_save_image(_fileitem=parent, _path=image_path,
_url=image_url)
else:
logger.info(f"已存在图片文件:{image_path}")
else:
@@ -786,11 +809,8 @@ class MediaChain(ChainBase):
image_path = filepath / image_name
if overwrite or not storagechain.get_file_item(storage=fileitem.storage,
path=image_path):
# 下载图片
content = __download_image(image_url)
# 保存图片文件到当前目录
if content:
__save_file(_fileitem=fileitem, _path=image_path, _content=content)
# 流式下载图片并直接保存
__download_and_save_image(_fileitem=fileitem, _path=image_path, _url=image_url)
else:
logger.info(f"已存在图片文件:{image_path}")
else:

View File

@@ -113,6 +113,16 @@ class MediaServerChain(ChainBase):
"""
return self.run_module("mediaserver_play_url", server=server, item_id=item_id)
def get_image_cookies(
self, server: Optional[str], image_url: str
) -> Optional[str | dict]:
"""
获取图片的Cookies
"""
return self.run_module(
"mediaserver_image_cookies", server=server, image_url=image_url
)
def sync(self):
"""
同步媒体库所有数据到本地数据库
@@ -167,7 +177,7 @@ class MediaServerChain(ChainBase):
for episode in espisodes_info:
seasoninfo[episode.season] = episode.episodes
# 插入数据
item_dict = item.dict()
item_dict = item.model_dump()
item_dict["seasoninfo"] = seasoninfo
item_dict["item_type"] = item_type
dboper.add(**item_dict)

View File

@@ -1,6 +1,8 @@
import asyncio
import re
from typing import Any, Optional, Dict, Union, List
from app.agent import agent_manager
from app.chain import ChainBase
from app.chain.download import DownloadChain
from app.chain.media import MediaChain
@@ -163,6 +165,10 @@ class MessageChain(ChainBase):
original_message_id=original_message_id, original_chat_id=original_chat_id)
else:
logger.warning(f"渠道 {channel.value} 不支持回调,但收到了回调消息:{text}")
elif text.startswith('/ai') or text.startswith('/AI'):
# AI智能体处理
self._handle_ai_message(text=text, channel=channel, source=source,
userid=userid, username=username)
elif text.startswith('/'):
# 执行命令
self.eventmanager.send_event(
@@ -815,3 +821,86 @@ class MessageChain(ChainBase):
buttons.append(page_buttons)
return buttons
def _handle_ai_message(self, text: str, channel: MessageChannel, source: str,
userid: Union[str, int], username: str) -> None:
"""
处理AI智能体消息
"""
try:
# 检查AI智能体是否启用
if not settings.AI_AGENT_ENABLE:
self.post_message(Notification(
channel=channel,
source=source,
userid=userid,
username=username,
title="MoviePilot智能助手未启用请在系统设置中启用"
))
return
# 检查LLM配置
if not settings.LLM_API_KEY:
self.post_message(Notification(
channel=channel,
source=source,
userid=userid,
username=username,
title="MoviePilot智能助未配置请在系统设置中配置"
))
return
# 提取用户消息
user_message = text[3:].strip() # 移除 "/ai" 前缀
if not user_message:
self.post_message(Notification(
channel=channel,
source=source,
userid=userid,
username=username,
title="请输入您的问题或需求"
))
return
# 发送处理中消息
self.post_message(Notification(
channel=channel,
source=source,
userid=userid,
username=username,
title="MoviePilot助手已收到您的请求请稍候..."
))
# 生成会话ID
session_id = f"user_{userid}_{hash(user_message) % 10000}"
# 在事件循环中处理
try:
loop = asyncio.get_event_loop()
loop.run_until_complete(
agent_manager.process_message(
session_id=session_id,
user_id=str(userid),
message=user_message,
channel=channel.value if channel else None,
source=source,
username=username
)
)
except RuntimeError:
# 如果没有事件循环,创建新的
asyncio.run(
agent_manager.process_message(
session_id=session_id,
user_id=str(userid),
message=user_message,
channel=channel.value if channel else None,
source=source,
username=username
)
)
except Exception as e:
logger.error(f"处理AI智能体消息失败: {e}")
self.messagehelper.put(f"AI智能体处理失败: {str(e)}", role="system", title="MoviePilot助手")

View File

@@ -6,6 +6,7 @@ from datetime import datetime
from typing import Dict, Tuple
from typing import List, Optional
from app.helper.sites import SitesHelper # noqa
from fastapi.concurrency import run_in_threadpool
from app.chain import ChainBase
@@ -16,7 +17,6 @@ from app.core.event import eventmanager, Event
from app.core.metainfo import MetaInfo
from app.db.systemconfig_oper import SystemConfigOper
from app.helper.progress import ProgressHelper
from app.helper.sites import SitesHelper # noqa
from app.helper.torrent import TorrentHelper
from app.log import logger
from app.schemas import NotExistMediaInfo
@@ -86,13 +86,13 @@ class SearchChain(ChainBase):
self.save_cache(contexts, self.__result_temp_file)
return contexts
def last_search_results(self) -> List[Context]:
def last_search_results(self) -> Optional[List[Context]]:
"""
获取上次搜索结果
"""
return self.load_cache(self.__result_temp_file)
async def async_last_search_results(self) -> List[Context]:
async def async_last_search_results(self) -> Optional[List[Context]]:
"""
异步获取上次搜索结果
"""
@@ -324,9 +324,6 @@ class SearchChain(ChainBase):
:param _torrents: 种子列表
:return: 去重后的种子列表
"""
if not settings.SEARCH_MULTIPLE_NAME:
return _torrents
# 通过encosure去重
return list({f"{t.torrent_info.site_name}_{t.torrent_info.title}_{t.torrent_info.description}": t
for t in _torrents}.values())
@@ -384,16 +381,23 @@ class SearchChain(ChainBase):
if search_count > 0:
logger.info(f"已搜索 {search_count} 次,强制休眠 1-10 秒 ...")
time.sleep(random.randint(1, 10))
# 搜索站点
torrents.extend(
self.__search_all_sites(
mediainfo=mediainfo,
keyword=search_word,
sites=sites,
area=area
) or []
)
results = self.__search_all_sites(
mediainfo=mediainfo,
keyword=search_word,
sites=sites,
area=area
) or []
# 合并结果
search_count += 1
torrents.extend(results)
# 有结果则停止
if not settings.SEARCH_MULTIPLE_NAME and torrents:
logger.info(f"共搜索到 {len(torrents)} 个资源,停止搜索")
break
# 处理结果
return self.__parse_result(

View File

@@ -56,7 +56,7 @@ class SiteChain(ChainBase):
if userdata:
SiteOper().update_userdata(domain=StringUtils.get_url_domain(site.get("domain")),
name=site.get("name"),
payload=userdata.dict())
payload=userdata.model_dump())
# 发送事件
eventmanager.send_event(EventType.SiteRefreshed, {
"site_id": site.get("id")

View File

@@ -173,7 +173,7 @@ class StorageChain(ChainBase):
dir_item = fileitem if fileitem.type == "dir" else self.get_parent_item(fileitem)
if not dir_item:
logger.warn(f"{fileitem.storage}{fileitem.path} 上级目录不存在")
return False
return True
# 查找操作文件项匹配的配置目录(资源目录、媒体库目录)
associated_dir = max(

View File

@@ -150,7 +150,7 @@ class TorrentsChain(ChainBase):
return []
# 解析RSS
rss_items = RssHelper().parse(site.get("rss"), True if site.get("proxy") else False,
timeout=int(site.get("timeout") or 30))
timeout=int(site.get("timeout") or 30), ua=site.get("ua") if site.get("ua") else None)
if rss_items is None:
# rss过期尝试保留原配置生成新的rss
self.__renew_rss_url(domain=domain, site=site)

View File

@@ -33,6 +33,7 @@ from app.schemas.types import TorrentStatus, EventType, MediaType, ProgressKey,
SystemConfigKey, ChainEventType, ContentType
from app.utils.singleton import Singleton
from app.utils.string import StringUtils
from app.utils.system import SystemUtils
downloader_lock = threading.Lock()
job_lock = threading.Lock()
@@ -329,8 +330,12 @@ class JobManager:
# 计算状态为完成的任务数
if __mediaid__ not in self._job_view:
return 0
return sum([task.fileitem.size for task in self._job_view[__mediaid__].tasks if
task.state == "completed" and task.fileitem.size is not None])
return sum([
task.fileitem.size if task.fileitem.size is not None
else (SystemUtils.get_directory_size(Path(task.fileitem.path)) if task.fileitem.storage == "local" else 0)
for task in self._job_view[__mediaid__].tasks
if task.state == "completed"
])
def total(self) -> int:
"""
@@ -1111,6 +1116,7 @@ class TransferChain(ChainBase, metaclass=Singleton):
file_meta=file_meta)
if begin_ep is not None:
file_meta.begin_episode = begin_ep
if part is not None:
file_meta.part = part
if end_ep is not None:
file_meta.end_episode = end_ep
@@ -1120,10 +1126,10 @@ class TransferChain(ChainBase, metaclass=Singleton):
downloadhis = DownloadHistoryOper()
if bluray_dir:
# 蓝光原盘,按目录名查询
download_history = downloadhis.get_by_path(str(file_path))
download_history = downloadhis.get_by_path(file_path.as_posix())
else:
# 按文件全路径查询
download_file = downloadhis.get_file_by_fullpath(str(file_path))
download_file = downloadhis.get_file_by_fullpath(file_path.as_posix())
if download_file:
download_history = downloadhis.get_by_hash(download_file.download_hash)
@@ -1436,7 +1442,7 @@ class TransferChain(ChainBase, metaclass=Singleton):
for keyword in exclude_words:
if keyword and re.search(r"%s" % keyword, file_path, re.IGNORECASE):
logger.debug(f"{file_path} 命中屏蔽词 {keyword}")
logger.warn(f"{file_path} 命中屏蔽词 {keyword}")
return True
return False
@@ -1472,7 +1478,7 @@ class TransferChain(ChainBase, metaclass=Singleton):
file_path = save_path / file.name
# 如果存在未被屏蔽的媒体文件,则不删除种子
if (file_path.suffix in self.all_exts
and not self._is_blocked_by_exclude_words(str(file_path), transfer_exclude_words)
and not self._is_blocked_by_exclude_words(file_path.as_posix(), transfer_exclude_words)
and file_path.exists()):
return False

View File

@@ -11,7 +11,7 @@ from pydantic.fields import Callable
from app.chain import ChainBase
from app.core.config import global_vars
from app.core.event import Event, eventmanager
from app.core.workflow import WorkFlowManager
from app.workflow import WorkFlowManager
from app.db.models import Workflow
from app.db.workflow_oper import WorkflowOper
from app.log import logger
@@ -180,7 +180,7 @@ class WorkflowExecutor:
"""
合并上下文
"""
for key, value in context.dict().items():
for key, value in context.model_dump().items():
if not getattr(self.context, key, None):
setattr(self.context, key, value)

View File

@@ -215,7 +215,7 @@ class Command(metaclass=Singleton):
except Exception as e:
logger.error(f"Error occurred during command initialization in background: {e}", exc_info=True)
def __trigger_register_commands_event(self) -> (Optional[Event], dict):
def __trigger_register_commands_event(self) -> tuple[Optional[Event], dict]:
"""
触发事件,允许调整命令数据
"""

View File

@@ -1,8 +1,10 @@
import contextvars
import inspect
import shutil
import tempfile
import threading
from abc import ABC, abstractmethod
from contextlib import contextmanager, asynccontextmanager
from functools import wraps
from pathlib import Path
from typing import Any, Dict, Optional, Generator, AsyncGenerator, Tuple, Literal, Union
@@ -27,6 +29,9 @@ DEFAULT_CACHE_TTL = 365 * 24 * 60 * 60
lock = threading.Lock()
# 上下文变量来控制缓存行为
_fresh = contextvars.ContextVar('fresh', default=False)
class CacheBackend(ABC):
"""
@@ -455,7 +460,7 @@ class MemoryBackend(CacheBackend):
if region_cache:
with lock:
region_cache.clear()
logger.info(f"Cleared cache for region: {region}")
logger.debug(f"Cleared cache for region: {region}")
else:
# 清除所有区域的缓存
for region_cache in self._region_caches.values():
@@ -589,13 +594,13 @@ class AsyncMemoryBackend(AsyncCacheBackend):
if region_cache:
with lock:
region_cache.clear()
logger.info(f"Cleared cache for region: {region}")
logger.debug(f"Cleared cache for region: {region}")
else:
# 清除所有区域的缓存
for region_cache in self._region_caches.values():
with lock:
region_cache.clear()
logger.info("Cleared all cache")
logger.info("All cache cleared")
async def items(self, region: Optional[str] = DEFAULT_CACHE_REGION) -> AsyncGenerator[Tuple[str, Any], None]:
"""
@@ -1010,6 +1015,49 @@ class AsyncFileBackend(AsyncCacheBackend):
pass
@contextmanager
def fresh(fresh: bool = True):
"""
是否获取新数据(不使用缓存的值)
Usage:
with fresh():
result = some_cached_function()
"""
token = _fresh.set(fresh)
logger.debug(f"Setting fresh mode to {fresh}. {id(token):#x}")
try:
yield
finally:
_fresh.reset(token)
logger.debug(f"Reset fresh mode. {id(token):#x}")
@asynccontextmanager
async def async_fresh(fresh: bool = True):
"""
是否获取新数据(不使用缓存的值)
Usage:
async with async_fresh():
result = await some_async_cached_function()
"""
token = _fresh.set(fresh)
logger.debug(f"Setting async_fresh mode to {fresh}. {id(token):#x}")
try:
yield
finally:
_fresh.reset(token)
logger.debug(f"Reset async_fresh mode. {id(token):#x}")
def is_fresh() -> bool:
"""
是否获取新数据
"""
try:
return _fresh.get()
except LookupError:
return False
def FileCache(base: Path = settings.TEMP_PATH, ttl: Optional[int] = None) -> CacheBackend:
"""
获取文件缓存后端实例Redis或文件系统ttl仅在Redis环境中有效
@@ -1084,16 +1132,6 @@ def cached(region: Optional[str] = None, maxsize: Optional[int] = 1024, ttl: Opt
"""
def decorator(func):
# 检查是否为异步函数
is_async = inspect.iscoroutinefunction(func)
# 根据函数类型选择对应的缓存后端没有ttl时默认是 LRU 缓存,否则是 TTL 缓存
if is_async:
# 异步函数使用异步缓存后端
cache_backend = AsyncCache(cache_type="ttl" if ttl else "lru", maxsize=maxsize, ttl=ttl)
else:
# 同步函数使用同步缓存后端
cache_backend = Cache(cache_type="ttl" if ttl else "lru", maxsize=maxsize, ttl=ttl)
def should_cache(value: Any) -> bool:
"""
@@ -1169,16 +1207,20 @@ def cached(region: Optional[str] = None, maxsize: Optional[int] = 1024, ttl: Opt
is_async = inspect.iscoroutinefunction(func)
if is_async:
# 异步函数使用异步缓存后端
cache_backend = AsyncCache(cache_type="ttl" if ttl else "lru", maxsize=maxsize, ttl=ttl)
# 异步函数的缓存装饰器
@wraps(func)
async def async_wrapper(*args, **kwargs):
# 获取缓存键
cache_key = __get_cache_key(args, kwargs)
# 尝试获取缓存
cached_value = await cache_backend.get(cache_key, region=cache_region)
if should_cache(cached_value) and await async_is_valid_cache_value(cache_key, cached_value,
cache_region):
return cached_value
if not is_fresh():
# 尝试获取缓存
cached_value = await cache_backend.get(cache_key, region=cache_region)
if should_cache(cached_value) and await async_is_valid_cache_value(cache_key, cached_value,
cache_region):
return cached_value
# 执行异步函数并缓存结果
result = await func(*args, **kwargs)
# 判断是否需要缓存
@@ -1198,15 +1240,19 @@ def cached(region: Optional[str] = None, maxsize: Optional[int] = 1024, ttl: Opt
async_wrapper.cache_clear = cache_clear
return async_wrapper
else:
# 同步函数使用同步缓存后端
cache_backend = Cache(cache_type="ttl" if ttl else "lru", maxsize=maxsize, ttl=ttl)
# 同步函数的缓存装饰器
@wraps(func)
def wrapper(*args, **kwargs):
# 获取缓存键
cache_key = __get_cache_key(args, kwargs)
# 尝试获取缓存
cached_value = cache_backend.get(cache_key, region=cache_region)
if should_cache(cached_value) and is_valid_cache_value(cache_key, cached_value, cache_region):
return cached_value
if not is_fresh():
# 尝试获取缓存
cached_value = cache_backend.get(cache_key, region=cache_region)
if should_cache(cached_value) and is_valid_cache_value(cache_key, cached_value, cache_region):
return cached_value
# 执行函数并缓存结果
result = func(*args, **kwargs)
# 判断是否需要缓存

View File

@@ -11,7 +11,8 @@ from typing import Any, Dict, List, Optional, Tuple, Type
from urllib.parse import urlparse
from dotenv import set_key
from pydantic import BaseModel, BaseSettings, validator, Field
from pydantic import BaseModel, Field, ConfigDict, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
from app.log import logger, log_settings, LogConfigModel
from app.schemas import MediaType
@@ -49,8 +50,7 @@ class ConfigModel(BaseModel):
Pydantic 配置模型,描述所有配置项及其类型和默认值
"""
class Config:
extra = "ignore" # 忽略未定义的配置项
model_config = ConfigDict(extra="ignore") # 忽略未定义的配置项
# ==================== 基础应用配置 ====================
# 项目名称
@@ -75,6 +75,8 @@ class ConfigModel(BaseModel):
DEBUG: bool = False
# 是否开发模式
DEV: bool = False
# 高级设置模式
ADVANCED_MODE: bool = True
# ==================== 安全认证配置 ====================
# 密钥
@@ -87,8 +89,10 @@ class ConfigModel(BaseModel):
ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8
# RESOURCE_TOKEN过期时间
RESOURCE_ACCESS_TOKEN_EXPIRE_SECONDS: int = 60 * 30
# 超级管理员
# 超级管理员初始用户名
SUPERUSER: str = "admin"
# 超级管理员初始密码
SUPERUSER_PASSWORD: Optional[str] = None
# 辅助认证,允许通过外部服务进行认证、单点登录以及自动创建用户
AUXILIARY_AUTH_ENABLE: bool = False
# API密钥需要更换
@@ -167,7 +171,7 @@ class ConfigModel(BaseModel):
# ==================== 媒体元数据配置 ====================
# 媒体搜索来源 themoviedb/douban/bangumi多个用,分隔
SEARCH_SOURCE: str = "themoviedb,douban,bangumi"
SEARCH_SOURCE: str = "themoviedb"
# 媒体识别来源 themoviedb/douban
RECOGNIZE_SOURCE: str = "themoviedb"
# 刮削来源 themoviedb/douban
@@ -252,7 +256,7 @@ class ConfigModel(BaseModel):
# 订阅搜索时间间隔(小时)
SUBSCRIBE_SEARCH_INTERVAL: int = 24
# 检查本地媒体库是否存在资源开关
LOCAL_EXISTS_SEARCH: bool = False
LOCAL_EXISTS_SEARCH: bool = True
# ==================== 站点配置 ====================
# 站点数据刷新间隔(小时)
@@ -364,6 +368,8 @@ class ConfigModel(BaseModel):
ENCODING_DETECTION_PERFORMANCE_MODE: bool = True
# 编码探测的最低置信度阈值
ENCODING_DETECTION_MIN_CONFIDENCE: float = 0.8
# 主动内存回收时间间隔分钟0为不启用
MEMORY_GC_INTERVAL: int = 30
# ==================== 安全配置 ====================
# 允许的图片缓存域名
@@ -392,24 +398,51 @@ class ConfigModel(BaseModel):
# ==================== 存储配置 ====================
# 对rclone进行快照对比时是否检查文件夹的修改时间
RCLONE_SNAPSHOT_CHECK_FOLDER_MODTIME = True
RCLONE_SNAPSHOT_CHECK_FOLDER_MODTIME: bool = True
# 对OpenList进行快照对比时是否检查文件夹的修改时间
OPENLIST_SNAPSHOT_CHECK_FOLDER_MODTIME = True
OPENLIST_SNAPSHOT_CHECK_FOLDER_MODTIME: bool = True
# ==================== Docker配置 ====================
# Docker Client API地址
DOCKER_CLIENT_API: Optional[str] = "tcp://127.0.0.1:38379"
# ==================== AI智能体配置 ====================
# AI智能体开关
AI_AGENT_ENABLE: bool = False
# LLM提供商 (openai/google/deepseek)
LLM_PROVIDER: str = "deepseek"
# LLM模型名称
LLM_MODEL: str = "deepseek-chat"
# LLM API密钥
LLM_API_KEY: Optional[str] = None
# LLM基础URL用于自定义API端点
LLM_BASE_URL: Optional[str] = "https://api.deepseek.com"
# LLM温度参数
LLM_TEMPERATURE: float = 0.1
# LLM最大迭代次数
LLM_MAX_ITERATIONS: int = 15
# LLM工具调用超时时间
LLM_TOOL_TIMEOUT: int = 300
# 是否启用详细日志
LLM_VERBOSE: bool = False
# 最大记忆消息数量
LLM_MAX_MEMORY_MESSAGES: int = 50
# 记忆保留天数
LLM_MEMORY_RETENTION_DAYS: int = 30
# Redis记忆保留天数如果使用Redis
LLM_REDIS_MEMORY_RETENTION_DAYS: int = 7
class Settings(BaseSettings, ConfigModel, LogConfigModel):
"""
系统配置类
"""
class Config:
case_sensitive = True
env_file = SystemUtils.get_env_path()
env_file_encoding = "utf-8"
model_config = SettingsConfigDict(
case_sensitive=True,
env_file=SystemUtils.get_env_path(),
env_file_encoding="utf-8",
)
def __init__(self, **kwargs):
super().__init__(**kwargs)
@@ -506,33 +539,54 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
f"配置项 '{field_name}' 的值 '{value}' 无法转换成正确的类型,使用默认值 '{default}',错误信息: {e}")
return default, True
@validator('*', pre=True, always=True)
def generic_type_validator(cls, value: Any, field): # noqa
@model_validator(mode='before')
@classmethod
def generic_type_validator(cls, data: Any): # noqa
"""
通用校验器,尝试将配置值转换为期望的类型
"""
if field.name == "API_TOKEN":
converted_value, needs_update = cls.validate_api_token(value, value)
else:
converted_value, needs_update = cls.generic_type_converter(value, value, field.type_, field.default,
field.name)
if needs_update:
cls.update_env_config(field, value, converted_value)
return converted_value
if not isinstance(data, dict):
return data
# 处理 API_TOKEN 特殊验证
if 'API_TOKEN' in data:
converted_value, needs_update = cls.validate_api_token(data['API_TOKEN'], data['API_TOKEN'])
if needs_update:
cls.update_env_config("API_TOKEN", data["API_TOKEN"], converted_value)
data['API_TOKEN'] = converted_value
# 对其他字段进行类型转换
for field_name, field_info in cls.model_fields.items():
if field_name not in data:
continue
value = data[field_name]
if value is None:
continue
field = cls.model_fields.get(field_name)
if field:
converted_value, needs_update = cls.generic_type_converter(
value, value, field.annotation, field.default, field_name
)
if needs_update:
cls.update_env_config(field_name, value, converted_value)
data[field_name] = converted_value
return data
@staticmethod
def update_env_config(field: Any, original_value: Any, converted_value: Any) -> Tuple[bool, str]:
def update_env_config(field_name: str, original_value: Any, converted_value: Any) -> Tuple[bool, str]:
"""
更新 env 配置
"""
message = None
is_converted = original_value is not None and str(original_value) != str(converted_value)
if is_converted:
message = f"配置项 '{field.name}' 的值 '{original_value}' 无效,已替换为 '{converted_value}'"
message = f"配置项 '{field_name}' 的值 '{original_value}' 无效,已替换为 '{converted_value}'"
logger.warning(message)
if field.name in os.environ:
message = f"配置项 '{field.name}' 已在环境变量中设置,请手动更新以保持一致性"
if field_name in os.environ:
message = f"配置项 '{field_name}' 已在环境变量中设置,请手动更新以保持一致性"
logger.warning(message)
return False, message
else:
@@ -542,10 +596,10 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
else:
value_to_write = str(converted_value) if converted_value is not None else ""
set_key(dotenv_path=SystemUtils.get_env_path(), key_to_set=field.name, value_to_set=value_to_write,
set_key(dotenv_path=SystemUtils.get_env_path(), key_to_set=field_name, value_to_set=value_to_write,
quote_mode="always")
if is_converted:
logger.info(f"配置项 '{field.name}' 已自动修正并写入到 'app.env' 文件")
logger.info(f"配置项 '{field_name}' 已自动修正并写入到 'app.env' 文件")
return True, message
def update_setting(self, key: str, value: Any) -> Tuple[Optional[bool], str]:
@@ -559,19 +613,17 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
return False, f"配置项 '{key}' 不存在"
try:
field = self.__fields__[key]
field = Settings.model_fields[key]
original_value = getattr(self, key)
if field.name == "API_TOKEN":
if key == "API_TOKEN":
converted_value, needs_update = self.validate_api_token(value, original_value)
else:
converted_value, needs_update = self.generic_type_converter(value,
original_value,
field.type_,
field.default,
key)
converted_value, needs_update = self.generic_type_converter(
value, original_value, field.annotation, field.default, key
)
# 如果没有抛出异常,则统一使用 converted_value 进行更新
if needs_update or str(value) != str(converted_value):
success, message = self.update_env_config(field, value, converted_value)
success, message = self.update_env_config(key, value, converted_value)
# 仅成功更新配置时,才更新内存
if success:
setattr(self, key, converted_value)

View File

@@ -250,6 +250,8 @@ class MediaInfo:
production_countries: list = field(default_factory=list)
# 语种
spoken_languages: list = field(default_factory=list)
# 所有发行日期
release_dates: list = field(default_factory=list)
# 状态
status: str = None
# 标签
@@ -257,7 +259,7 @@ class MediaInfo:
# 评价数量
vote_count: int = None
# 流行度
popularity: int = None
popularity: float = None
# 时长
runtime: int = None
# 下一集
@@ -433,6 +435,18 @@ class MediaInfo:
if self.release_date:
# 年份
self.year = self.release_date[:4]
# 所有发行日期
self.release_dates = [
{
"date": release_date.get("release_date"),
"iso_code": result.get("iso_3166_1"),
"note": release_date.get("note"),
"type": release_date.get("type"),
}
for result in info.get("release_dates", {}).get("results", [])
for release_date in result.get("release_dates", [])
if release_date.get("release_date")
]
else:
# 电视剧
self.title = info.get('name')

View File

@@ -1,3 +1,4 @@
import asyncio
import importlib
import inspect
import random
@@ -71,15 +72,26 @@ class EventManager(metaclass=Singleton):
"""
def __init__(self):
self.__executor = ThreadHelper() # 动态线程池,用于消费事件
self.__consumer_threads = [] # 用于保存启动的事件消费者线程
self.__event_queue = PriorityQueue() # 优先级队列
self.__broadcast_subscribers: Dict[EventType, Dict[str, Callable]] = {} # 广播事件的订阅者
self.__chain_subscribers: Dict[ChainEventType, Dict[str, tuple[int, Callable]]] = {} # 链式事件的订阅者
self.__disabled_handlers = set() # 禁用的事件处理器集合
self.__disabled_classes = set() # 禁用的事件处理器类集合
self.__lock = threading.Lock() # 线程锁
self.__event = threading.Event() # 退出事件
# 动态线程池,用于消费事件
self.__executor = ThreadHelper()
# 用于保存启动的事件消费者线程
self.__consumer_threads = []
# 优先级队列
self.__event_queue = PriorityQueue()
# 广播事件的订阅者
self.__broadcast_subscribers: Dict[EventType, Dict[str, Callable]] = {}
# 链式事件的订阅者
self.__chain_subscribers: Dict[ChainEventType, Dict[str, tuple[int, Callable]]] = {}
# 禁用的事件处理器集合
self.__disabled_handlers = set()
# 禁用的事件处理器类集合
self.__disabled_classes = set()
# 线程锁
self.__lock = threading.Lock()
# 退出事件
self.__event = threading.Event()
# 当前事件循环
self.loop = asyncio.get_event_loop()
def start(self):
"""
@@ -438,7 +450,15 @@ class EventManager(metaclass=Singleton):
isolated_event = Event(event_type=event.event_type,
event_data=event_data_copy,
priority=event.priority)
self.__executor.submit(self.__safe_invoke_handler, handler, isolated_event)
if inspect.iscoroutinefunction(handler):
# 对于异步函数,直接在事件循环中运行
asyncio.run_coroutine_threadsafe(
self.__safe_invoke_handler_async(handler, isolated_event),
self.loop
)
else:
# 对于同步函数,在线程池中运行
self.__executor.submit(self.__safe_invoke_handler, handler, isolated_event)
def __safe_invoke_handler(self, handler: Callable, event: Event):
"""
@@ -566,7 +586,8 @@ class EventManager(metaclass=Singleton):
# 插件同步函数在异步环境中运行,避免阻塞
await run_in_threadpool(method, event)
except Exception as e:
self.__handle_event_error(event=event, handler=handler, e=e, module_name=plugin.name)
self.__handle_event_error(event=event, module_name=plugin.name,
class_name=class_name, method_name=method_name, e=e)
async def __invoke_module_method_async(self, handler: Any, class_name: str, method_name: str, event: Event):
"""

View File

@@ -94,7 +94,6 @@ class MetaVideo(MetaBase):
title = re.sub(r'\d{4}[\s._-]\d{1,2}[\s._-]\d{1,2}', "", title)
# 拆分tokens
tokens = Tokens(title)
self.tokens = tokens
# 实例化StreamingPlatforms对象
streaming_platforms = StreamingPlatforms()
# 解析名称、年份、季、集、资源类型、分辨率等
@@ -102,7 +101,7 @@ class MetaVideo(MetaBase):
while token:
self._index += 1 # 更新当前处理的token索引
# Part
self.__init_part(token)
self.__init_part(token, tokens)
# 标题
if self._continue_flag:
self.__init_name(token)
@@ -123,7 +122,7 @@ class MetaVideo(MetaBase):
self.__init_resource_type(token)
# 流媒体平台
if self._continue_flag:
self.__init_web_source(token, streaming_platforms)
self.__init_web_source(token, tokens, streaming_platforms)
# 视频编码
if self._continue_flag:
self.__init_video_encode(token)
@@ -311,7 +310,7 @@ class MetaVideo(MetaBase):
self.en_name = token
self._last_token_type = "enname"
def __init_part(self, token: str):
def __init_part(self, token: str, tokens: Tokens):
"""
识别Part
"""
@@ -327,12 +326,12 @@ class MetaVideo(MetaBase):
if re_res:
if not self.part:
self.part = re_res.group(1)
nextv = self.tokens.cur()
nextv = tokens.cur()
if nextv \
and ((nextv.isdigit() and (len(nextv) == 1 or len(nextv) == 2 and nextv.startswith('0')))
or nextv.upper() in ['A', 'B', 'C', 'I', 'II', 'III']):
self.part = "%s%s" % (self.part, nextv)
self.tokens.get_next()
tokens.get_next()
self._last_token_type = "part"
self._continue_flag = False
# self._stop_name_flag = False
@@ -582,7 +581,7 @@ class MetaVideo(MetaBase):
self._effect.append(effect)
self._last_token = effect.upper()
def __init_web_source(self, token: str, streaming_platforms: StreamingPlatforms):
def __init_web_source(self, token: str, tokens: Tokens, streaming_platforms: StreamingPlatforms):
"""
识别流媒体平台
"""
@@ -594,10 +593,10 @@ class MetaVideo(MetaBase):
prev_token = None
prev_idx = self._index - 2
if 0 <= prev_idx < len(self.tokens.tokens):
prev_token = self.tokens.tokens[prev_idx]
if 0 <= prev_idx < len(tokens.tokens):
prev_token = tokens.tokens[prev_idx]
next_token = self.tokens.peek()
next_token = tokens.peek()
if streaming_platforms.is_streaming_platform(token):
platform_name = streaming_platforms.get_streaming_platform_name(token)
@@ -616,7 +615,7 @@ class MetaVideo(MetaBase):
platform_name = streaming_platforms.get_streaming_platform_name(combined_token)
query_range = 2
if is_next:
self.tokens.get_next()
tokens.get_next()
break
if not platform_name:
@@ -626,8 +625,8 @@ class MetaVideo(MetaBase):
match_start_idx = self._index - query_range
match_end_idx = self._index - 1
start_index = max(0, match_start_idx - query_range)
end_index = min(len(self.tokens.tokens), match_end_idx + 1 + query_range)
tokens_to_check = self.tokens.tokens[start_index:end_index]
end_index = min(len(tokens.tokens), match_end_idx + 1 + query_range)
tokens_to_check = tokens.tokens[start_index:end_index]
if any(tok and tok.upper() in web_tokens for tok in tokens_to_check):
self.web_source = platform_name

View File

@@ -1,3 +1,4 @@
import ast
import asyncio
import concurrent
import concurrent.futures
@@ -9,14 +10,15 @@ import time
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
import threading
from typing import Any, Dict, List, Optional, Type, Union, Callable, Tuple
from fastapi import HTTPException
from starlette import status
from watchdog.events import FileSystemEventHandler
from watchdog.observers import Observer
from watchfiles import watch
from app import schemas
from app.core.cache import fresh, async_fresh
from app.core.config import settings
from app.core.event import eventmanager, Event
from app.db.plugindata_oper import PluginDataOper
@@ -26,64 +28,12 @@ from app.helper.sites import SitesHelper # noqa
from app.log import logger
from app.schemas.types import EventType, SystemConfigKey
from app.utils.crypto import RSAUtils
from app.utils.limit import rate_limit_window
from app.utils.object import ObjectUtils
from app.utils.singleton import Singleton
from app.utils.string import StringUtils
from app.utils.system import SystemUtils
class PluginMonitorHandler(FileSystemEventHandler):
def on_modified(self, event):
"""
插件文件修改后重载
"""
if event.is_directory:
return
# 使用 pathlib 处理文件路径,跳过非 .py 文件以及 pycache 目录中的文件
event_path = Path(event.src_path)
if not event_path.name.endswith(".py") or "pycache" in event_path.parts:
return
# 读取插件根目录下的__init__.py文件读取class XXXX(_PluginBase)的类名
try:
plugins_root = settings.ROOT_PATH / "app" / "plugins"
# 确保修改的文件在 plugins 目录下
if plugins_root not in event_path.parents:
return
# 获取插件目录路径没有找到__init__.py时说明不是有效包跳过插件重载
# 插件重载目前没有支持app/plugins/plugin/package/__init__.py的场景这里也不做支持
plugin_dir = event_path.parent
init_file = plugin_dir / "__init__.py"
if not init_file.exists():
logger.debug(f"{plugin_dir} 下没有找到 __init__.py跳过插件重载")
return
with open(init_file, "r", encoding="utf-8") as f:
lines = f.readlines()
pid = None
for line in lines:
if line.startswith("class") and "(_PluginBase)" in line:
pid = line.split("class ")[1].split("(_PluginBase)")[0].strip()
if pid:
self.__reload_plugin(pid)
except Exception as e:
logger.error(f"插件文件修改后重载出错:{str(e)}")
@staticmethod
@rate_limit_window(max_calls=1, window_seconds=2, source="PluginMonitor", enable_logging=False)
def __reload_plugin(pid):
"""
重新加载插件
"""
try:
logger.info(f"插件 {pid} 文件修改,重新加载...")
PluginManager().reload_plugin(pid)
except Exception as e:
logger.error(f"插件文件修改后重载出错:{str(e)}")
class PluginManager(metaclass=Singleton):
"""
插件管理器
@@ -96,8 +46,10 @@ class PluginManager(metaclass=Singleton):
self._running_plugins: dict = {}
# 配置Key
self._config_key: str = "plugin.%s"
# 监听器
self._observer: Observer = None
# 监控线程
self._monitor_thread: Optional[threading.Thread] = None
# 监控停止事件
self._stop_monitor_event = threading.Event()
# 开发者模式监测插件修改
if settings.DEV or settings.PLUGIN_AUTO_RELOAD:
self.__start_monitor()
@@ -264,7 +216,6 @@ class PluginManager(metaclass=Singleton):
# 导入模块
module = importlib.import_module(module_name)
importlib.reload(module)
# 检查模块中的类
for name, obj in module.__dict__.items():
@@ -318,10 +269,9 @@ class PluginManager(metaclass=Singleton):
重新加载插件文件修改监测
"""
if settings.DEV or settings.PLUGIN_AUTO_RELOAD:
if self._observer and self._observer.is_alive():
logger.info("插件文件修改监测已经在运行中...")
else:
self.__start_monitor()
# 先关闭已有监测,再重新启动
self.stop_monitor()
self.__start_monitor()
else:
self.stop_monitor()
@@ -329,25 +279,123 @@ class PluginManager(metaclass=Singleton):
"""
启用监测插件文件修改监测
"""
if self._monitor_thread and self._monitor_thread.is_alive():
logger.info("插件文件修改监测已经在运行中...")
return
logger.info("开始监测插件文件修改...")
monitor_handler = PluginMonitorHandler()
self._observer = Observer()
self._observer.schedule(monitor_handler, str(settings.ROOT_PATH / "app" / "plugins"), recursive=True)
self._observer.start()
# 在启动新线程之前,确保停止事件是清除状态
self._stop_monitor_event.clear()
# 创建并启动监控线程
self._monitor_thread = threading.Thread(
target=self._run_file_watcher,
daemon=True
)
self._monitor_thread.start()
def stop_monitor(self):
"""
停止监测插件文件修改监测
"""
# 停止监测
if self._observer and self._observer.is_alive():
if self._monitor_thread and self._monitor_thread.is_alive():
logger.info("正在停止插件文件修改监测...")
self._observer.stop()
self._observer.join()
self._stop_monitor_event.set()
self._monitor_thread.join(timeout=5)
if self._monitor_thread.is_alive():
logger.warning("插件文件修改监测线程在5秒内未能正常停止。")
self._monitor_thread = None
logger.info("插件文件修改监测停止完成")
else:
logger.info("未启用插件文件修改监测,无需停止")
def _run_file_watcher(self):
"""
运行 watchfiles 监视器的主循环。
"""
# 监视插件目录
plugins_path = str(settings.ROOT_PATH / "app" / "plugins")
logger.info(">>> 监控线程已启动准备进入watch循环...")
# 使用 watchfiles 监视目录变化,并响应变化事件
# Todo: yield_on_timeout = True 时,每秒检查停止事件,会返回空集合;后续可以考虑用来做心跳之类的功能?
for changes in watch(plugins_path, stop_event=self._stop_monitor_event, rust_timeout=1000,
yield_on_timeout=True):
# 如果收到停止事件,退出循环
if not changes:
continue
# 处理变化事件
plugins_to_reload = set()
for _change_type, path_str in changes:
event_path = Path(path_str)
# 跳过非 .py 文件以及 pycache 目录中的文件
if not event_path.name.endswith(".py") or "__pycache__" in event_path.parts:
continue
# 解析插件ID
pid = self._get_plugin_id_from_path(event_path)
# 跳过无效插件文件
if pid:
# 收集需要重载的插件ID自动去重避免重复重载
plugins_to_reload.add(pid)
# 触发重载
if plugins_to_reload:
logger.info(f"检测到插件文件变化,准备重载: {list(plugins_to_reload)}")
for pid in plugins_to_reload:
try:
self.reload_plugin(pid)
except Exception as e:
logger.error(f"插件 {pid} 热重载失败: {e}", exc_info=True)
@staticmethod
def _get_plugin_id_from_path(event_path: Path) -> Optional[str]:
"""
根据文件路径解析出插件的ID。
:param event_path: 被修改文件的 Path 对象。
:return: 插件ID字符串如果不是有效插件文件则返回 None。
"""
try:
plugins_root = settings.ROOT_PATH / "app" / "plugins"
# 确保修改的文件在 plugins 目录下
if not event_path.is_relative_to(plugins_root):
return None
try:
plugin_dir_name = event_path.relative_to(plugins_root).parts[0]
plugin_dir = plugins_root / plugin_dir_name
except (ValueError, IndexError):
return None
init_file = plugin_dir / "__init__.py"
if not init_file.exists():
return None
# 读取 __init__.py 文件,查找插件主类名
with open(init_file, "r", encoding="utf-8") as f:
source_code = f.read()
tree = ast.parse(source_code)
# 遍历AST查找继承自 _PluginBase 的类
for node in ast.walk(tree):
# 检查节点是否为类定义
if isinstance(node, ast.ClassDef):
# 遍历该类的所有基类
for base in node.bases:
# 检查基类是否是我们寻找的 _PluginBase
# ast.Name 用于处理简单的基类名
if isinstance(base, ast.Name) and base.id == '_PluginBase':
# 返回这个类的名字
return node.name
return None
except Exception as e:
logger.error(f"从路径解析插件ID时出错: {e}")
return None
@staticmethod
def __stop_plugin(plugin: Any):
"""
@@ -410,6 +458,10 @@ class PluginManager(metaclass=Singleton):
except KeyError:
# 模块可能已经被删除
pass
importlib.invalidate_caches()
logger.debug("已清除查找器的缓存")
if plugin_id:
if modules_to_remove:
logger.info(f"插件 {plugin_id} 共清除 {len(modules_to_remove)} 个模块缓存:{modules_to_remove}")
@@ -693,6 +745,36 @@ class PluginManager(metaclass=Singleton):
logger.error(f"获取插件 {plugin_id} 动作出错:{str(e)}")
return ret_actions
def get_plugin_agent_tools(self, pid: Optional[str] = None) -> List[Dict[str, Any]]:
"""
获取插件智能体工具
[{
"plugin_id": "插件ID",
"plugin_name": "插件名称",
"tools": [ToolClass1, ToolClass2, ...]
}]
"""
ret_tools = []
# 创建字典快照避免并发修改
running_plugins_snapshot = dict(self._running_plugins)
for plugin_id, plugin in running_plugins_snapshot.items():
if pid and pid != plugin_id:
continue
if hasattr(plugin, "get_agent_tools") and ObjectUtils.check_method(plugin.get_agent_tools):
try:
if not plugin.get_state():
continue
tools = plugin.get_agent_tools()
if tools:
ret_tools.append({
"plugin_id": plugin_id,
"plugin_name": plugin.plugin_name,
"tools": tools
})
except Exception as e:
logger.error(f"获取插件 {plugin_id} 智能体工具出错:{str(e)}")
return ret_tools
@staticmethod
def get_plugin_remote_entry(plugin_id: str, dist_path: str) -> str:
"""
@@ -1024,7 +1106,8 @@ class PluginManager(metaclass=Singleton):
# 已安装插件
installed_apps = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or []
# 获取在线插件
online_plugins = PluginHelper().get_plugins(market, package_version, force)
with fresh(force):
online_plugins = PluginHelper().get_plugins(market, package_version)
if online_plugins is None:
logger.warning(
f"获取{package_version if package_version else ''}插件库失败:{market},请检查 GitHub 网络连接")
@@ -1231,7 +1314,8 @@ class PluginManager(metaclass=Singleton):
# 已安装插件
installed_apps = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or []
# 获取在线插件
online_plugins = await PluginHelper().async_get_plugins(market, package_version, force)
async with async_fresh(force):
online_plugins = await PluginHelper().async_get_plugins(market, package_version)
if online_plugins is None:
logger.warning(
f"获取{package_version if package_version else ''}插件库失败:{market},请检查 GitHub 网络连接")

View File

@@ -66,6 +66,12 @@ class Site(Base):
result = await db.execute(select(cls).where(cls.domain == domain))
return result.scalar_one_or_none()
@classmethod
@async_db_query
async def async_get_by_name(cls, db: AsyncSession, name: str):
result = await db.execute(select(cls).where(cls.name == name))
return result.scalar_one_or_none()
@classmethod
@db_query
def get_actives(cls, db: Session):

View File

@@ -85,6 +85,12 @@ class SiteOper(DbOper):
"""
return await Site.async_get_by_domain(self._db, domain)
async def async_get_by_name(self, name: str) -> Site:
"""
异步按名称获取站点
"""
return await Site.async_get_by_name(self._db, name)
def get_domains_by_ids(self, ids: List[int]) -> List[str]:
"""
按ID获取站点域名

View File

@@ -128,10 +128,10 @@ class TransferHistoryOper(DbOper):
self.add_force(
src=fileitem.path,
src_storage=fileitem.storage,
src_fileitem=fileitem.dict(),
src_fileitem=fileitem.model_dump(),
dest=transferinfo.target_item.path if transferinfo.target_item else None,
dest_storage=transferinfo.target_item.storage if transferinfo.target_item else None,
dest_fileitem=transferinfo.target_item.dict() if transferinfo.target_item else None,
dest_fileitem=transferinfo.target_item.model_dump() if transferinfo.target_item else None,
mode=mode,
type=mediainfo.type.value,
category=mediainfo.category,
@@ -159,10 +159,10 @@ class TransferHistoryOper(DbOper):
his = self.add_force(
src=fileitem.path,
src_storage=fileitem.storage,
src_fileitem=fileitem.dict(),
src_fileitem=fileitem.model_dump(),
dest=transferinfo.target_item.path if transferinfo.target_item else None,
dest_storage=transferinfo.target_item.storage if transferinfo.target_item else None,
dest_fileitem=transferinfo.target_item.dict() if transferinfo.target_item else None,
dest_fileitem=transferinfo.target_item.model_dump() if transferinfo.target_item else None,
mode=mode,
type=mediainfo.type.value,
category=mediainfo.category,
@@ -188,7 +188,7 @@ class TransferHistoryOper(DbOper):
year=meta.year,
src=fileitem.path,
src_storage=fileitem.storage,
src_fileitem=fileitem.dict(),
src_fileitem=fileitem.model_dump(),
mode=mode,
seasons=meta.season,
episodes=meta.episode,

View File

@@ -367,7 +367,6 @@ class TemplateHelper(metaclass=SingletonClass):
return rendered
return None
except Exception as e:
logger.error(f"模板处理失败: {str(e)}")
raise ValueError(f"模板处理失败: {str(e)}") from e
@staticmethod
@@ -713,6 +712,7 @@ class MessageQueueManager(metaclass=SingletonClass):
self._running = False
logger.info("正在停止消息队列...")
self.thread.join()
logger.info("消息队列已停止")
class MessageHelper(metaclass=Singleton):

View File

@@ -48,35 +48,13 @@ class PluginHelper(metaclass=WeakSingleton):
if self.install_report():
self.systemconfig.set(SystemConfigKey.PluginInstallReport, "1")
def get_plugins(self, repo_url: str, package_version: Optional[str] = None,
force: bool = False) -> Optional[Dict[str, dict]]:
@cached(maxsize=128, ttl=1800)
def get_plugins(self, repo_url: str,
package_version: Optional[str] = None) -> Optional[Dict[str, dict]]:
"""
获取Github所有最新插件列表
:param repo_url: Github仓库地址
:param package_version: 首选插件版本 (如 "v2", "v3"),如果不指定则获取 v1 版本
:param force: 是否强制刷新,忽略缓存
"""
# 如果强制刷新,直接调用不带缓存的版本
if force:
return self._get_plugins_uncached(repo_url, package_version)
# 正常情况下调用带缓存的版本
return self._get_plugins_cached(repo_url, package_version)
@cached(maxsize=64, ttl=1800)
def _get_plugins_cached(self, repo_url: str, package_version: Optional[str] = None) -> Optional[Dict[str, dict]]:
"""
获取Github所有最新插件列表使用缓存
:param repo_url: Github仓库地址
:param package_version: 首选插件版本 (如 "v2", "v3"),如果不指定则获取 v1 版本
"""
return self._get_plugins_uncached(repo_url, package_version)
def _get_plugins_uncached(self, repo_url: str, package_version: Optional[str] = None) -> Optional[Dict[str, dict]]:
"""
获取Github所有最新插件列表不使用缓存
:param repo_url: Github仓库地址
:param package_version: 首选插件版本 (如 "v2", "v3"),如果不指定则获取 v1 版本
"""
if not repo_url:
return None
@@ -161,7 +139,7 @@ class PluginHelper(metaclass=WeakSingleton):
return res.json()
return {}
def install_reg(self, pid: str) -> bool:
def install_reg(self, pid: str, repo_url: Optional[str] = None) -> bool:
"""
安装插件统计
"""
@@ -170,24 +148,39 @@ class PluginHelper(metaclass=WeakSingleton):
if not pid:
return False
install_reg_url = self._install_reg.format(pid=pid)
res = RequestUtils(proxies=settings.PROXY, timeout=5).get_res(install_reg_url)
res = RequestUtils(
proxies=settings.PROXY,
content_type="application/json",
timeout=5
).post(install_reg_url, json={
"plugin_id": pid,
"repo_url": repo_url
})
if res and res.status_code == 200:
return True
return False
def install_report(self) -> bool:
def install_report(self, items: Optional[List[Tuple[str, Optional[str]]]] = None) -> bool:
"""
上报存量插件安装统计
上报存量插件安装统计(批量)。支持上送 repo_url。
:param items: 可选,形如 [(plugin_id, repo_url), ...];不传则回落到历史配置,仅上送 plugin_id。
"""
if not settings.PLUGIN_STATISTIC_SHARE:
return False
plugins = self.systemconfig.get(SystemConfigKey.UserInstalledPlugins)
if not plugins:
return False
payload_plugins = []
if items:
for pid, repo_url in items:
if pid:
payload_plugins.append({"plugin_id": pid, "repo_url": repo_url})
else:
plugins = self.systemconfig.get(SystemConfigKey.UserInstalledPlugins)
if not plugins:
return False
payload_plugins = [{"plugin_id": plugin, "repo_url": None} for plugin in plugins]
res = RequestUtils(proxies=settings.PROXY,
content_type="application/json",
timeout=5).post(self._install_report,
json={"plugins": [{"plugin_id": plugin} for plugin in plugins]})
json={"plugins": payload_plugins})
return True if res else False
def install(self, pid: str, repo_url: str, package_version: Optional[str] = None, force_install: bool = False) \
@@ -252,16 +245,16 @@ class PluginHelper(metaclass=WeakSingleton):
# 使用 release 进行安装
def prepare_release() -> Tuple[bool, str]:
return self.__install_from_release(
pid.lower(), user_repo, release_tag
pid, user_repo, release_tag
)
return self.__install_flow_sync(pid.lower(), force_install, prepare_release)
return self.__install_flow_sync(pid, force_install, prepare_release, repo_url)
else:
# 如果 release_tag 不存在,说明插件没有发布版本,使用文件列表方式安装
def prepare_filelist() -> Tuple[bool, str]:
return self.__prepare_content_via_filelist_sync(pid.lower(), user_repo, package_version)
return self.__install_flow_sync(pid.lower(), force_install, prepare_filelist)
return self.__install_flow_sync(pid, force_install, prepare_filelist, repo_url)
def __get_file_list(self, pid: str, user_repo: str, package_version: Optional[str] = None) -> \
Tuple[Optional[list], Optional[str]]:
@@ -275,7 +268,7 @@ class PluginHelper(metaclass=WeakSingleton):
# 如果 package_version 存在(如 "v2"),则加上版本号
if package_version:
file_api += f".{package_version}"
file_api += f"/{pid}"
file_api += f"/{pid.lower()}"
res = self.__request_with_fallback(file_api,
headers=settings.REPO_GITHUB_HEADERS(repo=user_repo),
@@ -408,8 +401,8 @@ class PluginHelper(metaclass=WeakSingleton):
:param pid: 插件 ID
:return: 备份目录路径
"""
plugin_dir = PLUGIN_DIR / pid
backup_dir = Path(settings.TEMP_PATH) / "plugin_backup" / pid
plugin_dir = PLUGIN_DIR / pid.lower()
backup_dir = Path(settings.TEMP_PATH) / "plugin_backup" / pid.lower()
if plugin_dir.exists():
# 备份时清理已有的备份目录,防止残留文件影响
@@ -429,7 +422,7 @@ class PluginHelper(metaclass=WeakSingleton):
:param pid: 插件 ID
:param backup_dir: 备份目录路径
"""
plugin_dir = PLUGIN_DIR / pid
plugin_dir = PLUGIN_DIR / pid.lower()
if plugin_dir.exists():
shutil.rmtree(plugin_dir, ignore_errors=True)
logger.debug(f"{pid} 已清理插件目录 {plugin_dir}")
@@ -446,7 +439,7 @@ class PluginHelper(metaclass=WeakSingleton):
删除旧插件
:param pid: 插件 ID
"""
plugin_dir = PLUGIN_DIR / pid
plugin_dir = PLUGIN_DIR / pid.lower()
if plugin_dir.exists():
shutil.rmtree(plugin_dir, ignore_errors=True)
@@ -557,41 +550,42 @@ class PluginHelper(metaclass=WeakSingleton):
logger.error(f"获取插件 {pid} 元数据失败:{e}")
return {}
def __install_flow_sync(self, pid_lower: str, force_install: bool,
prepare_content: Callable[[], Tuple[bool, str]]) -> Tuple[bool, str]:
def __install_flow_sync(self, pid: str, force_install: bool,
prepare_content: Callable[[], Tuple[bool, str]],
repo_url: Optional[str] = None) -> Tuple[bool, str]:
"""
同步安装统一流程:备份→清理→准备内容→安装依赖→上报
prepare_content 负责把插件文件放到 app/plugins/{pid}
"""
backup_dir = None
if not force_install:
backup_dir = self.__backup_plugin(pid_lower)
backup_dir = self.__backup_plugin(pid)
self.__remove_old_plugin(pid_lower)
self.__remove_old_plugin(pid)
success, message = prepare_content()
if not success:
logger.error(f"{pid_lower} 准备插件内容失败:{message}")
logger.error(f"{pid} 准备插件内容失败:{message}")
if backup_dir:
self.__restore_plugin(pid_lower, backup_dir)
logger.warning(f"{pid_lower} 插件安装失败,已还原备份插件")
self.__restore_plugin(pid, backup_dir)
logger.warning(f"{pid} 插件安装失败,已还原备份插件")
else:
self.__remove_old_plugin(pid_lower)
logger.warning(f"{pid_lower} 已清理对应插件目录,请尝试重新安装")
self.__remove_old_plugin(pid)
logger.warning(f"{pid} 已清理对应插件目录,请尝试重新安装")
return False, message
dependencies_exist, dep_ok, dep_msg = self.__install_dependencies_if_required(pid_lower)
dependencies_exist, dep_ok, dep_msg = self.__install_dependencies_if_required(pid)
if dependencies_exist and not dep_ok:
logger.error(f"{pid_lower} 依赖安装失败:{dep_msg}")
logger.error(f"{pid} 依赖安装失败:{dep_msg}")
if backup_dir:
self.__restore_plugin(pid_lower, backup_dir)
logger.warning(f"{pid_lower} 插件安装失败,已还原备份插件")
self.__restore_plugin(pid, backup_dir)
logger.warning(f"{pid} 插件安装失败,已还原备份插件")
else:
self.__remove_old_plugin(pid_lower)
logger.warning(f"{pid_lower} 已清理对应插件目录,请尝试重新安装")
self.__remove_old_plugin(pid)
logger.warning(f"{pid} 已清理对应插件目录,请尝试重新安装")
return False, dep_msg
self.install_reg(pid_lower)
self.install_reg(pid, repo_url)
return True, ""
def __install_from_release(self, pid: str, user_repo: str, release_tag: str) -> Tuple[bool, str]:
@@ -915,35 +909,13 @@ class PluginHelper(metaclass=WeakSingleton):
logger.error(f"[GitHub] 所有策略均请求失败URL: {url},请检查网络连接或 GitHub 配置")
return None
async def async_get_plugins(self, repo_url: str, package_version: Optional[str] = None,
force: bool = False) -> Optional[Dict[str, dict]]:
@cached(maxsize=128, ttl=1800)
async def async_get_plugins(self, repo_url: str,
package_version: Optional[str] = None) -> Optional[Dict[str, dict]]:
"""
异步获取Github所有最新插件列表
:param repo_url: Github仓库地址
:param package_version: 首选插件版本 (如 "v2", "v3"),如果不指定则获取 v1 版本
:param force: 是否强制刷新,忽略缓存
"""
# 异步版本直接调用不带缓存的版本(缓存在异步环境下可能有并发问题)
if force:
await self._async_get_plugins_cached.cache_clear()
return await self._async_get_plugins_cached(repo_url, package_version)
@cached(maxsize=128, ttl=1800)
async def _async_get_plugins_cached(self, repo_url: str,
package_version: Optional[str] = None) -> Optional[Dict[str, dict]]:
"""
获取Github所有最新插件列表使用缓存
:param repo_url: Github仓库地址
:param package_version: 首选插件版本 (如 "v2", "v3"),如果不指定则获取 v1 版本
"""
return await self._async_get_plugins_uncached(repo_url, package_version)
async def _async_get_plugins_uncached(self, repo_url: str,
package_version: Optional[str] = None) -> Optional[Dict[str, dict]]:
"""
异步获取Github所有最新插件列表不使用缓存
:param repo_url: Github仓库地址
:param package_version: 首选插件版本 (如 "v2", "v3"),如果不指定则获取 v1 版本
"""
if not repo_url:
return None
@@ -980,7 +952,7 @@ class PluginHelper(metaclass=WeakSingleton):
return res.json()
return {}
async def async_install_reg(self, pid: str) -> bool:
async def async_install_reg(self, pid: str, repo_url: Optional[str] = None) -> bool:
"""
异步安装插件统计
"""
@@ -989,24 +961,39 @@ class PluginHelper(metaclass=WeakSingleton):
if not pid:
return False
install_reg_url = self._install_reg.format(pid=pid)
res = await AsyncRequestUtils(proxies=settings.PROXY, timeout=5).get_res(install_reg_url)
res = await AsyncRequestUtils(
proxies=settings.PROXY,
content_type="application/json",
timeout=5
).post(install_reg_url, json={
"plugin_id": pid,
"repo_url": repo_url
})
if res and res.status_code == 200:
return True
return False
async def async_install_report(self) -> bool:
async def async_install_report(self, items: Optional[List[Tuple[str, Optional[str]]]] = None) -> bool:
"""
异步上报存量插件安装统计
异步上报存量插件安装统计(批量)。支持上送 repo_url。
:param items: 可选,形如 [(plugin_id, repo_url), ...];不传则回落到历史配置,仅上送 plugin_id。
"""
if not settings.PLUGIN_STATISTIC_SHARE:
return False
plugins = self.systemconfig.get(SystemConfigKey.UserInstalledPlugins)
if not plugins:
return False
payload_plugins = []
if items:
for pid, repo_url in items:
if pid:
payload_plugins.append({"plugin_id": pid, "repo_url": repo_url})
else:
plugins = self.systemconfig.get(SystemConfigKey.UserInstalledPlugins)
if not plugins:
return False
payload_plugins = [{"plugin_id": plugin, "repo_url": None} for plugin in plugins]
res = await AsyncRequestUtils(proxies=settings.PROXY,
content_type="application/json",
timeout=5).post(self._install_report,
json={"plugins": [{"plugin_id": plugin} for plugin in plugins]})
json={"plugins": payload_plugins})
return True if res else False
async def __async_get_file_list(self, pid: str, user_repo: str, package_version: Optional[str] = None) -> \
@@ -1021,7 +1008,7 @@ class PluginHelper(metaclass=WeakSingleton):
# 如果 package_version 存在(如 "v2"),则加上版本号
if package_version:
file_api += f".{package_version}"
file_api += f"/{pid}"
file_api += f"/{pid.lower()}"
res = await self.__async_request_with_fallback(file_api,
headers=settings.REPO_GITHUB_HEADERS(repo=user_repo),
@@ -1133,8 +1120,8 @@ class PluginHelper(metaclass=WeakSingleton):
:param pid: 插件 ID
:return: 备份目录路径
"""
plugin_dir = AsyncPath(PLUGIN_DIR) / pid
backup_dir = AsyncPath(settings.TEMP_PATH) / "plugin_backup" / pid
plugin_dir = AsyncPath(PLUGIN_DIR) / pid.lower()
backup_dir = AsyncPath(settings.TEMP_PATH) / "plugin_backup" / pid.lower()
if await plugin_dir.exists():
# 备份时清理已有的备份目录,防止残留文件影响
@@ -1154,7 +1141,7 @@ class PluginHelper(metaclass=WeakSingleton):
:param pid: 插件 ID
:param backup_dir: 备份目录路径
"""
plugin_dir = AsyncPath(PLUGIN_DIR) / pid
plugin_dir = AsyncPath(PLUGIN_DIR) / pid.lower()
if await plugin_dir.exists():
await aioshutil.rmtree(plugin_dir, ignore_errors=True)
logger.debug(f"{pid} 已清理插件目录 {plugin_dir}")
@@ -1172,7 +1159,7 @@ class PluginHelper(metaclass=WeakSingleton):
异步删除旧插件
:param pid: 插件 ID
"""
plugin_dir = AsyncPath(PLUGIN_DIR) / pid
plugin_dir = AsyncPath(PLUGIN_DIR) / pid.lower()
if await plugin_dir.exists():
await aioshutil.rmtree(plugin_dir, ignore_errors=True)
@@ -1414,16 +1401,16 @@ class PluginHelper(metaclass=WeakSingleton):
# 使用 release 进行安装
async def prepare_release() -> Tuple[bool, str]:
return await self.__async_install_from_release(
pid.lower(), user_repo, release_tag
pid, user_repo, release_tag
)
return await self.__install_flow_async(pid.lower(), force_install, prepare_release)
return await self.__install_flow_async(pid, force_install, prepare_release, repo_url)
else:
# 如果没有 release_tag则使用文件列表安装方式
async def prepare_filelist() -> Tuple[bool, str]:
return await self.__prepare_content_via_filelist_async(pid.lower(), user_repo, package_version)
return await self.__prepare_content_via_filelist_async(pid, user_repo, package_version)
return await self.__install_flow_async(pid.lower(), force_install, prepare_filelist)
return await self.__install_flow_async(pid, force_install, prepare_filelist, repo_url)
async def __async_get_plugin_meta(self, pid: str, repo_url: str,
package_version: Optional[str]) -> dict:
@@ -1438,78 +1425,79 @@ class PluginHelper(metaclass=WeakSingleton):
logger.warn(f"获取插件 {pid} 元数据失败:{e}")
return {}
async def __install_flow_async(self, pid_lower: str, force_install: bool,
prepare_content: Callable[[], Awaitable[Tuple[bool, str]]]) -> Tuple[bool, str]:
async def __install_flow_async(self, pid: str, force_install: bool,
prepare_content: Callable[[], Awaitable[Tuple[bool, str]]],
repo_url: Optional[str] = None) -> Tuple[bool, str]:
"""
异步安装流程,处理插件内容准备、依赖安装和注册
"""
backup_dir = None
if not force_install:
backup_dir = await self.__async_backup_plugin(pid_lower)
backup_dir = await self.__async_backup_plugin(pid)
await self.__async_remove_old_plugin(pid_lower)
await self.__async_remove_old_plugin(pid)
success, message = await prepare_content()
if not success:
logger.error(f"{pid_lower} 准备插件内容失败:{message}")
logger.error(f"{pid} 准备插件内容失败:{message}")
if backup_dir:
await self.__async_restore_plugin(pid_lower, backup_dir)
logger.warning(f"{pid_lower} 插件安装失败,已还原备份插件")
await self.__async_restore_plugin(pid, backup_dir)
logger.warning(f"{pid} 插件安装失败,已还原备份插件")
else:
await self.__async_remove_old_plugin(pid_lower)
logger.warning(f"{pid_lower} 已清理对应插件目录,请尝试重新安装")
await self.__async_remove_old_plugin(pid)
logger.warning(f"{pid} 已清理对应插件目录,请尝试重新安装")
return False, message
dependencies_exist, dep_ok, dep_msg = await self.__async_install_dependencies_if_required(pid_lower)
dependencies_exist, dep_ok, dep_msg = await self.__async_install_dependencies_if_required(pid)
if dependencies_exist and not dep_ok:
logger.error(f"{pid_lower} 依赖安装失败:{dep_msg}")
logger.error(f"{pid} 依赖安装失败:{dep_msg}")
if backup_dir:
await self.__async_restore_plugin(pid_lower, backup_dir)
logger.warning(f"{pid_lower} 插件安装失败,已还原备份插件")
await self.__async_restore_plugin(pid, backup_dir)
logger.warning(f"{pid} 插件安装失败,已还原备份插件")
else:
await self.__async_remove_old_plugin(pid_lower)
logger.warning(f"{pid_lower} 已清理对应插件目录,请尝试重新安装")
await self.__async_remove_old_plugin(pid)
logger.warning(f"{pid} 已清理对应插件目录,请尝试重新安装")
return False, dep_msg
await self.async_install_reg(pid_lower)
await self.async_install_reg(pid, repo_url)
return True, ""
def __prepare_content_via_filelist_sync(self, pid_lower: str, user_repo: str,
def __prepare_content_via_filelist_sync(self, pid: str, user_repo: str,
package_version: Optional[str]) -> Tuple[bool, str]:
"""
同步准备插件内容,通过文件列表获取插件文件和依赖
"""
file_list, msg = self.__get_file_list(pid_lower, user_repo, package_version)
file_list, msg = self.__get_file_list(pid, user_repo, package_version)
if not file_list:
return False, msg
requirements_file_info = next((f for f in file_list if f.get("name") == "requirements.txt"), None)
if requirements_file_info:
ok, m = self.__download_and_install_requirements(requirements_file_info, pid_lower, user_repo)
ok, m = self.__download_and_install_requirements(requirements_file_info, pid, user_repo)
if not ok:
logger.debug(f"{pid_lower} 依赖预安装失败:{m}")
logger.debug(f"{pid} 依赖预安装失败:{m}")
else:
logger.debug(f"{pid_lower} 依赖预安装成功")
ok, m = self.__download_files(pid_lower, file_list, user_repo, package_version, True)
logger.debug(f"{pid} 依赖预安装成功")
ok, m = self.__download_files(pid, file_list, user_repo, package_version, True)
if not ok:
return False, m
return True, ""
async def __prepare_content_via_filelist_async(self, pid_lower: str, user_repo: str,
async def __prepare_content_via_filelist_async(self, pid: str, user_repo: str,
package_version: Optional[str]) -> Tuple[bool, str]:
"""
异步准备插件内容,通过文件列表获取插件文件和依赖
"""
file_list, msg = await self.__async_get_file_list(pid_lower, user_repo, package_version)
file_list, msg = await self.__async_get_file_list(pid, user_repo, package_version)
if not file_list:
return False, msg
requirements_file_info = next((f for f in file_list if f.get("name") == "requirements.txt"), None)
if requirements_file_info:
ok, m = await self.__async_download_and_install_requirements(requirements_file_info, pid_lower, user_repo)
ok, m = await self.__async_download_and_install_requirements(requirements_file_info, pid, user_repo)
if not ok:
logger.debug(f"{pid_lower} 依赖预安装失败:{m}")
logger.debug(f"{pid} 依赖预安装失败:{m}")
else:
logger.debug(f"{pid_lower} 依赖预安装成功")
ok, m = await self.__async_download_files(pid_lower, file_list, user_repo, package_version, True)
logger.debug(f"{pid} 依赖预安装成功")
ok, m = await self.__async_download_files(pid, file_list, user_repo, package_version, True)
if not ok:
return False, m
return True, ""

View File

@@ -258,10 +258,10 @@ class RedisHelper(metaclass=Singleton):
for key in self.client.scan_iter(redis_key):
pipe.delete(key)
pipe.execute()
logger.info(f"Cleared Redis cache for region: {region}")
logger.debug(f"Cleared Redis cache for region: {region}")
else:
self.client.flushdb()
logger.info("Cleared all Redis cache")
logger.info("All Redis cache Cleared")
except Exception as e:
logger.error(f"Failed to clear cache, region: {region}, error: {e}")
@@ -496,7 +496,7 @@ class AsyncRedisHelper(metaclass=Singleton):
async for key in self.client.scan_iter(redis_key):
await pipe.delete(key)
await pipe.execute()
logger.info(f"Cleared Redis cache for region (async): {region}")
logger.debug(f"Cleared Redis cache for region (async): {region}")
else:
await self.client.flushdb()
logger.info("Cleared all Redis cache (async)")

View File

@@ -228,13 +228,14 @@ class RssHelper:
}
def parse(self, url, proxy: bool = False,
timeout: Optional[int] = 15, headers: dict = None) -> Union[List[dict], None, bool]:
timeout: Optional[int] = 15, headers: dict = None, ua: str = None) -> Union[List[dict], None, bool]:
"""
解析RSS订阅URL获取RSS中的种子信息
:param url: RSS地址
:param proxy: 是否使用代理
:param timeout: 请求超时
:param headers: 自定义请求头
:param ua: 自定义User-Agent
:return: 种子信息列表如为None代表Rss过期如果为False则为错误
"""
# 开始处理
@@ -243,8 +244,9 @@ class RssHelper:
return False
try:
ret = RequestUtils(proxies=settings.PROXY if proxy else None,
timeout=timeout, headers=headers).get_res(url)
ret = RequestUtils(ua=ua,
proxies=settings.PROXY if proxy else None,
timeout=timeout or 30, headers=headers).get_res(url)
if not ret:
logger.error(f"获取RSS失败请求返回空值URL: {url}")
return False
@@ -384,6 +386,9 @@ class RssHelper:
pubdate = ""
if pubdate_nodes and pubdate_nodes[0].text:
pubdate = StringUtils.get_time(pubdate_nodes[0].text)
if pubdate is not None:
# 转为本地时区
pubdate = pubdate.astimezone(tz=None)
# 获取豆瓣昵称
nickname_nodes = item.xpath('.//*[local-name()="creator"]')

View File

@@ -47,7 +47,7 @@ class StorageHelper:
if s.type == storage:
s.config = conf
break
SystemConfigOper().set(SystemConfigKey.Storages, [s.dict() for s in storagies])
SystemConfigOper().set(SystemConfigKey.Storages, [s.model_dump() for s in storagies])
def add_storage(self, storage: str, name: str, conf: dict):
"""
@@ -68,7 +68,7 @@ class StorageHelper:
name=name,
config=conf
))
SystemConfigOper().set(SystemConfigKey.Storages, [s.dict() for s in storagies])
SystemConfigOper().set(SystemConfigKey.Storages, [s.model_dump() for s in storagies])
def reset_storage(self, storage: str):
"""
@@ -79,4 +79,4 @@ class StorageHelper:
if s.type == storage:
s.config = {}
break
SystemConfigOper().set(SystemConfigKey.Storages, [s.dict() for s in storagies])
SystemConfigOper().set(SystemConfigKey.Storages, [s.model_dump() for s in storagies])

View File

@@ -131,7 +131,9 @@ class SubscribeHelper(metaclass=WeakSingleton):
return []
@cached(region=_shares_cache_region, maxsize=5, ttl=1800, skip_empty=True)
def get_statistic(self, stype: str, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
def get_statistic(self, stype: str, page: Optional[int] = 1, count: Optional[int] = 30,
genre_id: Optional[int] = None, min_rating: Optional[float] = None,
max_rating: Optional[float] = None, sort_type: Optional[str] = None) -> List[dict]:
"""
获取订阅统计数据
"""
@@ -139,16 +141,30 @@ class SubscribeHelper(metaclass=WeakSingleton):
if not enabled:
return []
res = RequestUtils(proxies=settings.PROXY, timeout=15).get_res(self._sub_statistic, params={
params = {
"stype": stype,
"page": page,
"count": count
})
}
# 添加可选参数
if genre_id is not None:
params["genre_id"] = genre_id
if min_rating is not None:
params["min_rating"] = min_rating
if max_rating is not None:
params["max_rating"] = max_rating
if sort_type is not None:
params["sort_type"] = sort_type
res = RequestUtils(proxies=settings.PROXY, timeout=15).get_res(self._sub_statistic, params=params)
return self._handle_list_response(res)
@cached(region=_shares_cache_region, maxsize=5, ttl=1800, skip_empty=True)
async def async_get_statistic(self, stype: str, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
async def async_get_statistic(self, stype: str, page: Optional[int] = 1, count: Optional[int] = 30,
genre_id: Optional[int] = None, min_rating: Optional[float] = None,
max_rating: Optional[float] = None, sort_type: Optional[str] = None) -> List[dict]:
"""
异步获取订阅统计数据
"""
@@ -156,11 +172,23 @@ class SubscribeHelper(metaclass=WeakSingleton):
if not enabled:
return []
res = await AsyncRequestUtils(proxies=settings.PROXY, timeout=15).get_res(self._sub_statistic, params={
params = {
"stype": stype,
"page": page,
"count": count
})
}
# 添加可选参数
if genre_id is not None:
params["genre_id"] = genre_id
if min_rating is not None:
params["min_rating"] = min_rating
if max_rating is not None:
params["max_rating"] = max_rating
if sort_type is not None:
params["sort_type"] = sort_type
res = await AsyncRequestUtils(proxies=settings.PROXY, timeout=15).get_res(self._sub_statistic, params=params)
return self._handle_list_response(res)
@@ -358,7 +386,9 @@ class SubscribeHelper(metaclass=WeakSingleton):
return self._handle_response(res, clear_cache=False)
@cached(region=_shares_cache_region, maxsize=1, ttl=1800, skip_empty=True)
def get_shares(self, name: Optional[str] = None, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
def get_shares(self, name: Optional[str] = None, page: Optional[int] = 1, count: Optional[int] = 30,
genre_id: Optional[int] = None, min_rating: Optional[float] = None,
max_rating: Optional[float] = None, sort_type: Optional[str] = None) -> List[dict]:
"""
获取订阅分享数据
"""
@@ -366,17 +396,30 @@ class SubscribeHelper(metaclass=WeakSingleton):
if not enabled:
return []
res = RequestUtils(proxies=settings.PROXY, timeout=15).get_res(self._sub_shares, params={
params = {
"name": name,
"page": page,
"count": count
})
}
# 添加可选参数
if genre_id is not None:
params["genre_id"] = genre_id
if min_rating is not None:
params["min_rating"] = min_rating
if max_rating is not None:
params["max_rating"] = max_rating
if sort_type is not None:
params["sort_type"] = sort_type
res = RequestUtils(proxies=settings.PROXY, timeout=15).get_res(self._sub_shares, params=params)
return self._handle_list_response(res)
@cached(region=_shares_cache_region, maxsize=1, ttl=1800, skip_empty=True)
async def async_get_shares(self, name: Optional[str] = None, page: Optional[int] = 1, count: Optional[int] = 30) -> \
List[dict]:
async def async_get_shares(self, name: Optional[str] = None, page: Optional[int] = 1, count: Optional[int] = 30,
genre_id: Optional[int] = None, min_rating: Optional[float] = None,
max_rating: Optional[float] = None, sort_type: Optional[str] = None) -> List[dict]:
"""
异步获取订阅分享数据
"""
@@ -384,11 +427,23 @@ class SubscribeHelper(metaclass=WeakSingleton):
if not enabled:
return []
res = await AsyncRequestUtils(proxies=settings.PROXY, timeout=15).get_res(self._sub_shares, params={
params = {
"name": name,
"page": page,
"count": count
})
}
# 添加可选参数
if genre_id is not None:
params["genre_id"] = genre_id
if min_rating is not None:
params["min_rating"] = min_rating
if max_rating is not None:
params["max_rating"] = max_rating
if sort_type is not None:
params["sort_type"] = sort_type
res = await AsyncRequestUtils(proxies=settings.PROXY, timeout=15).get_res(self._sub_shares, params=params)
return self._handle_list_response(res)

View File

@@ -11,7 +11,8 @@ from pathlib import Path
from typing import Dict, Any, Optional
import click
from pydantic import BaseSettings, BaseModel
from pydantic import BaseModel, ConfigDict
from pydantic_settings import BaseSettings
from app.utils.system import SystemUtils
@@ -21,8 +22,7 @@ class LogConfigModel(BaseModel):
Pydantic 配置模型,描述所有配置项及其类型和默认值
"""
class Config:
extra = "ignore" # 忽略未定义的配置项
model_config = ConfigDict(extra="ignore") # 忽略未定义的配置项
# 配置文件目录
CONFIG_DIR: Optional[str] = None
@@ -71,10 +71,11 @@ class LogSettings(BaseSettings, LogConfigModel):
"""
return self.LOG_MAX_FILE_SIZE * 1024 * 1024
class Config:
case_sensitive = True
env_file = SystemUtils.get_env_path()
env_file_encoding = "utf-8"
model_config = ConfigDict(
case_sensitive=True,
env_file=SystemUtils.get_env_path(),
env_file_encoding="utf-8"
)
# 实例化日志设置

View File

@@ -95,4 +95,4 @@ if __name__ == '__main__':
# 更新数据库
update_db()
# 启动API服务
Server.run()
Server.run()

View File

@@ -232,6 +232,19 @@ class _DownloaderBase(ServiceBase[TService, DownloaderConf]):
super().__init__()
self._default_config_name: Optional[str] = None
def init_service(self, service_name: str,
service_type: Optional[Union[Type[TService], Callable[..., TService]]] = None):
"""
初始化服务,获取配置并实例化对应服务
:param service_name: 服务名称,作为配置匹配的依据
:param service_type: 服务的类型可以是类类型Type[TService]、工厂函数Callable或 None 来跳过实例化
"""
# 重置默认配置名称
self.reset_default_config_name()
# 初始化服务
super().init_service(service_name, service_type)
def get_default_config_name(self) -> Optional[str]:
"""
获取默认服务配置的名称
@@ -263,6 +276,12 @@ class _DownloaderBase(ServiceBase[TService, DownloaderConf]):
return {}
return {conf.name: conf for conf in configs if conf.type == self._service_name and conf.enabled}
def reset_default_config_name(self):
"""
重置默认配置名称
"""
self._default_config_name = None
class _MediaServerBase(ServiceBase[TService, MediaServerConf]):
"""

View File

@@ -984,11 +984,13 @@ class DoubanModule(_ModuleBase):
"""
if result:
doubanid = result.get("id")
if doubanid and not str(doubanid).isdigit():
doubanid = re.search(r"\d+", doubanid).group(0)
result["id"] = doubanid
logger.info(f"{imdbid} 查询到豆瓣信息:{result.get('title')}")
return result
if doubanid:
if not str(doubanid).isdigit():
doubanid = re.search(r"\d+", doubanid).group(0)
result["id"] = doubanid
logger.info(f"{imdbid} 查询到豆瓣信息:{result.get('title')}")
return result
return None
return None
@staticmethod

View File

@@ -10,10 +10,10 @@ from requests import Response
from app import schemas
from app.core.config import settings
from app.log import logger
from app.schemas import MediaServerItem
from app.schemas.types import MediaType
from app.utils.http import RequestUtils
from app.utils.url import UrlUtils
from app.schemas import MediaServerItem
class Emby:
@@ -22,9 +22,10 @@ class Emby:
_apikey: Optional[str] = None
_sync_libraries: List[str] = []
user: Optional[Union[str, int]] = None
_username: Optional[str] = None
def __init__(self, host: Optional[str] = None, apikey: Optional[str] = None, play_host: Optional[str] = None,
sync_libraries: list = None, **kwargs):
username: Optional[str] = None, sync_libraries: list = None, **kwargs):
if not host or not apikey:
logger.error("Emby服务器配置不完整")
return
@@ -35,7 +36,8 @@ class Emby:
if self._playhost:
self._playhost = UrlUtils.standardize_base_url(self._playhost)
self._apikey = apikey
self.user = self.get_user(settings.SUPERUSER)
self._username = username
self.user = self.get_user(username or settings.SUPERUSER)
self.folders = self.get_emby_folders()
self.serverid = self.get_server_id()
self._sync_libraries = sync_libraries or []
@@ -139,7 +141,8 @@ class Emby:
logger.error(f"连接User/Views 出错:" + str(e))
return []
def get_librarys(self, username: Optional[str] = None, hidden: Optional[bool] = False) -> List[schemas.MediaServerLibrary]:
def get_librarys(self, username: Optional[str] = None, hidden: Optional[bool] = False) -> List[
schemas.MediaServerLibrary]:
"""
获取媒体服务器所有媒体库列表
"""
@@ -567,6 +570,7 @@ class Emby:
if library_id != "/":
return self.__refresh_emby_library_by_id(library_id)
logger.info(f"Emby媒体库刷新完成")
return True
def __get_emby_library_id_by_item(self, item: schemas.RefreshMediaItem) -> Optional[str]:
"""
@@ -636,7 +640,7 @@ class Emby:
item_type=item.get("Type"),
title=item.get("Name"),
original_title=item.get("OriginalTitle"),
year=item.get("ProductionYear"),
year=str(item.get("ProductionYear")),
tmdbid=int(tmdbid) if tmdbid else None,
imdbid=item.get("ProviderIds", {}).get("Imdb"),
tvdbid=item.get("ProviderIds", {}).get("Tvdb"),
@@ -706,9 +710,9 @@ class Emby:
yield items
elif item.get("Type") in ["Movie", "Series"]:
yield self.__format_item_info(item)
except Exception as e:
logger.error(f"连接Users/Items出错" + str(e))
return None
def get_webhook_message(self, form: any, args: dict) -> Optional[schemas.WebhookEventInfo]:
"""
@@ -1109,7 +1113,8 @@ class Emby:
return ""
return "%sItems/%s/Images/Primary" % (self._host, item_id)
def get_resume(self, num: Optional[int] = 12, username: Optional[str] = None) -> Optional[List[schemas.MediaServerPlayItem]]:
def get_resume(self, num: Optional[int] = 12, username: Optional[str] = None) -> Optional[
List[schemas.MediaServerPlayItem]]:
"""
获得继续观看
"""
@@ -1146,7 +1151,7 @@ class Emby:
link = self.get_play_url(item.get("Id"))
if item_type == MediaType.MOVIE.value:
title = item.get("Name")
subtitle = item.get("ProductionYear")
subtitle = str(item.get("ProductionYear")) if item.get("ProductionYear") else None
else:
title = f'{item.get("SeriesName")}'
subtitle = f'S{item.get("ParentIndexNumber")}:{item.get("IndexNumber")} - {item.get("Name")}'
@@ -1178,7 +1183,8 @@ class Emby:
logger.error(f"连接Users/Items/Resume出错" + str(e))
return []
def get_latest(self, num: Optional[int] = 20, username: Optional[str] = None) -> Optional[List[schemas.MediaServerPlayItem]]:
def get_latest(self, num: Optional[int] = 20, username: Optional[str] = None) -> Optional[
List[schemas.MediaServerPlayItem]]:
"""
获得最近更新
"""
@@ -1217,7 +1223,7 @@ class Emby:
ret_latest.append(schemas.MediaServerPlayItem(
id=item.get("Id"),
title=item.get("Name"),
subtitle=item.get("ProductionYear"),
subtitle=str(item.get("ProductionYear")) if item.get("ProductionYear") else None,
type=item_type,
image=image,
link=link,

View File

@@ -15,7 +15,7 @@ def transfer_process(path: str) -> Callable[[int | float], None]:
"""
传输进度回调
"""
pbar = tqdm(total=100, desc="整理进度", unit="%")
pbar = tqdm(total=100, desc="进度", unit="%")
progress = ProgressHelper(HashUtils.md5(path))
progress.start()
@@ -23,7 +23,7 @@ def transfer_process(path: str) -> Callable[[int | float], None]:
"""
更新进度百分比
"""
percent_value = int(percent)
percent_value = round(percent, 2) if isinstance(percent, float) else percent
pbar.n = percent_value
# 更新进度
pbar.refresh()

View File

@@ -14,6 +14,7 @@ from app.log import logger
from app.modules.filemanager import StorageBase
from app.modules.filemanager.storages import transfer_process
from app.schemas.types import StorageSchema
from app.utils.http import RequestUtils
from app.utils.singleton import WeakSingleton
from app.utils.string import StringUtils
@@ -251,6 +252,9 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
# 检查会话
self._check_session()
# 错误日志控制
no_error_log = kwargs.pop("no_error_log", False)
try:
resp = self.session.request(
method, f"{self.base_url}{endpoint}",
@@ -273,7 +277,8 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
# 返回数据
ret_data = resp.json()
if ret_data.get("code"):
logger.warn(f"【阿里云盘】{method} {endpoint} 返回:{ret_data.get('code')} {ret_data.get('message')}")
if not no_error_log:
logger.warn(f"【阿里云盘】{method} {endpoint} 返回:{ret_data.get('code')} {ret_data.get('message')}")
if result_key:
return ret_data.get(result_key)
@@ -597,7 +602,7 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
file_size = local_path.stat().st_size
# 1. 创建文件并检查秒传
chunk_size = 100 * 1024 * 1024 # 分片大小 100M
chunk_size = 10 * 1024 * 1024 # 分片大小 10M
create_res = self._create_file(drive_id=target_dir.drive_id,
parent_file_id=target_dir.fileid,
file_name=target_name,
@@ -729,7 +734,25 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
progress_callback = transfer_process(Path(fileitem.path).as_posix())
try:
with requests.get(download_url, stream=True) as r:
# 构建请求头,包含必要的认证信息
headers = {
"User-Agent": settings.NORMAL_USER_AGENT,
"Referer": "https://www.aliyundrive.com/",
"Accept": "*/*",
"Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8",
"Accept-Encoding": "gzip, deflate, br",
"Connection": "keep-alive",
"Sec-Fetch-Dest": "empty",
"Sec-Fetch-Mode": "cors",
"Sec-Fetch-Site": "cross-site"
}
# 如果有access_token添加到请求头
if self.access_token:
headers["Authorization"] = f"Bearer {self.access_token}"
request_utils = RequestUtils(headers=headers)
with request_utils.get_stream(download_url, raise_exception=True) as r:
r.raise_for_status()
downloaded_size = 0
with open(local_path, "wb") as f:
@@ -748,22 +771,13 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
# 完成下载
progress_callback(100)
logger.info(f"【阿里云盘】下载完成: {fileitem.name}")
except requests.exceptions.RequestException as e:
logger.error(f"【阿里云盘】下载网络错误: {fileitem.name} - {str(e)}")
# 删除可能部分下载的文件
if local_path.exists():
local_path.unlink()
return None
return local_path
except Exception as e:
logger.error(f"【阿里云盘】下载失败: {fileitem.name} - {str(e)}")
# 删除可能部分下载的文件
if local_path.exists():
local_path.unlink()
return None
return local_path
def check(self) -> bool:
return self.access_token is not None
@@ -815,7 +829,8 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
json={
"drive_id": drive_id or self._default_drive_id,
"file_path": path.as_posix()
}
},
no_error_log=True
)
if not resp:
return None

View File

@@ -4,8 +4,6 @@ from datetime import datetime
from pathlib import Path
from typing import Optional, List
import requests
from app import schemas
from app.core.cache import cached
from app.core.config import settings, global_vars
@@ -569,18 +567,22 @@ class Alist(StorageBase, metaclass=WeakSingleton):
else:
local_path = path / fileitem.name
with requests.get(download_url, headers=self.__get_header_with_token(), stream=True) as r:
r.raise_for_status()
with open(local_path, "wb") as f:
for chunk in r.iter_content(chunk_size=8192):
if global_vars.is_transfer_stopped(fileitem.path):
logger.info(f"【OpenList】{fileitem.path} 下载已取消!")
return None
f.write(chunk)
request_utils = RequestUtils(headers=self.__get_header_with_token())
try:
with request_utils.get_stream(download_url, raise_exception=True) as r:
r.raise_for_status()
with open(local_path, "wb") as f:
for chunk in r.iter_content(chunk_size=8192):
if global_vars.is_transfer_stopped(fileitem.path):
logger.info(f"【OpenList】{fileitem.path} 下载已取消!")
return None
f.write(chunk)
except Exception as e:
logger.error(f"【OpenList】下载文件 {fileitem.path} 失败:{e}")
if local_path.exists():
return local_path
if local_path.exists():
return local_path
return None
return local_path
def upload(
self, fileitem: schemas.FileItem, path: Path, new_name: Optional[str] = None, task: bool = False

View File

@@ -26,8 +26,8 @@ class LocalStorage(StorageBase):
"softlink": "软链接"
}
# 文件块大小默认100MB
chunk_size = 100 * 1024 * 1024
# 文件块大小默认10MB
chunk_size = 10 * 1024 * 1024
def init_storage(self):
"""

View File

@@ -39,8 +39,8 @@ class SMB(StorageBase, metaclass=WeakSingleton):
"copy": "复制",
}
# 文件块大小默认100MB
chunk_size = 100 * 1024 * 1024
# 文件块大小默认10MB
chunk_size = 10 * 1024 * 1024
def __init__(self):
super().__init__()
@@ -49,6 +49,7 @@ class SMB(StorageBase, metaclass=WeakSingleton):
self._host = None
self._username = None
self._password = None
self._init_connection()
def _init_connection(self):
@@ -380,19 +381,95 @@ class SMB(StorageBase, metaclass=WeakSingleton):
self._check_connection()
smb_path = self._normalize_path(fileitem.path.rstrip("/"))
logger.info(f"【SMB】开始删除: {fileitem.path} (类型: {fileitem.type})")
# 先检查路径是否存在
if not smbclient.path.exists(smb_path):
logger.warn(f"【SMB】路径不存在跳过删除: {fileitem.path}")
return True
if fileitem.type == "dir":
# 删除目录
smbclient.rmdir(smb_path)
# 递归删除目录及其内容
logger.debug(f"【SMB】递归删除目录: {smb_path}")
self._recursive_delete(smb_path)
else:
# 删除文件
logger.debug(f"【SMB】删除文件: {smb_path}")
smbclient.remove(smb_path)
logger.info(f"【SMB】删除成功: {fileitem.path}")
return True
except Exception as e:
logger.error(f"【SMB】删除失败: {e}")
except SMBConnectionError as e:
logger.error(f"【SMB】删除失败 - 连接错误: {fileitem.path} - {e}")
return False
except SMBResponseException as e:
logger.error(f"【SMB】删除失败 - SMB响应错误: {fileitem.path} - {e}")
return False
except SMBException as e:
logger.error(f"【SMB】删除失败 - SMB错误: {fileitem.path} - {e}")
return False
except Exception as e:
logger.error(f"【SMB】删除失败 - 未知错误: {fileitem.path} - {e}")
return False
def _recursive_delete(self, smb_path: str):
"""
递归删除目录及其所有内容
"""
try:
# 检查路径是否存在
if not smbclient.path.exists(smb_path):
logger.debug(f"【SMB】路径不存在跳过删除: {smb_path}")
return
# 如果是文件,直接删除
if smbclient.path.isfile(smb_path):
logger.debug(f"【SMB】删除文件: {smb_path}")
smbclient.remove(smb_path)
return
# 如果是目录,先删除其内容
if smbclient.path.isdir(smb_path):
logger.debug(f"【SMB】开始删除目录内容: {smb_path}")
try:
# 列出目录内容
entries = smbclient.listdir(smb_path)
logger.debug(f"【SMB】目录 {smb_path} 包含 {len(entries)} 个项目")
for entry in entries:
if entry in [".", ".."]:
continue
entry_path = f"{smb_path}\\{entry}"
logger.debug(f"【SMB】递归删除子项: {entry_path}")
# 递归删除子项
self._recursive_delete(entry_path)
# 删除空目录
logger.debug(f"【SMB】删除空目录: {smb_path}")
smbclient.rmdir(smb_path)
logger.debug(f"【SMB】目录删除成功: {smb_path}")
except SMBResponseException as e:
# 如果目录不为空,尝试强制删除
logger.warn(f"【SMB】目录不为空尝试强制删除: {smb_path} - {e}")
# 使用remove方法尝试删除某些SMB服务器支持
try:
smbclient.remove(smb_path)
logger.info(f"【SMB】强制删除目录成功: {smb_path}")
except Exception as remove_error:
# 如果还是失败,记录错误并抛出异常
logger.error(f"【SMB】无法删除非空目录: {smb_path} - {remove_error}")
raise SMBConnectionError(f"无法删除非空目录 {smb_path}: {remove_error}")
except SMBException as e:
logger.error(f"【SMB】SMB操作失败: {smb_path} - {e}")
raise SMBConnectionError(f"SMB操作失败 {smb_path}: {e}")
except SMBConnectionError:
# 重新抛出SMB连接错误
raise
except Exception as e:
logger.error(f"【SMB】递归删除失败: {smb_path} - {e}")
raise SMBConnectionError(f"递归删除失败 {smb_path}: {e}")
def rename(self, fileitem: schemas.FileItem, name: str) -> bool:
"""
@@ -584,8 +661,7 @@ class SMB(StorageBase, metaclass=WeakSingleton):
析构函数,清理连接
"""
try:
# smbclient 自动管理连接池,但我们可以重置缓存
if hasattr(self, '_connected') and self._connected:
if self._connected:
reset_connection_cache()
except Exception as e:
logger.debug(f"【SMB】清理连接失败: {e}")

View File

@@ -91,6 +91,8 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
"refresh_time": int(time.time()),
**tokens
})
else:
return None
access_token = tokens.get("access_token")
if access_token:
self.session.headers.update({"Authorization": f"Bearer {access_token}"})
@@ -209,6 +211,11 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
# 检查会话
self._check_session()
# 错误日志标志
no_error_log = kwargs.pop("no_error_log", False)
# 重试次数
retry_times = kwargs.pop("retry_limit", 5)
try:
resp = self.session.request(
method, f"{self.base_url}{endpoint}",
@@ -222,6 +229,8 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
logger.warn(f"【115】{method} 请求 {endpoint} 失败!")
return None
kwargs["retry_limit"] = retry_times
# 处理速率限制
if resp.status_code == 429:
reset_time = 5 + int(resp.headers.get("X-RateLimit-Reset", 60))
@@ -238,8 +247,8 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
ret_data = resp.json()
if ret_data.get("code") != 0:
error_msg = ret_data.get("message")
logger.warn(f"【115】{method} 请求 {endpoint} 出错:{error_msg}")
retry_times = kwargs.get("retry_limit", 5)
if not no_error_log:
logger.warn(f"【115】{method} 请求 {endpoint} 出错:{error_msg}")
if "已达到当前访问上限" in error_msg:
if retry_times <= 0:
logger.error(f"【115】{method} 请求 {endpoint} 达到访问上限,重试次数用尽!")
@@ -536,8 +545,8 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
security_token=SecurityToken
)
bucket = oss2.Bucket(auth, endpoint, bucket_name) # noqa
# determine_part_size方法用于确定分片大小设置分片大小为 100M
part_size = determine_part_size(file_size, preferred_size=100 * 1024 * 1024)
# determine_part_size方法用于确定分片大小设置分片大小为 10M
part_size = determine_part_size(file_size, preferred_size=10 * 1024 * 1024)
# 初始化进度条
logger.info(f"【115】开始上传: {local_path} -> {target_path},分片大小:{StringUtils.str_filesize(part_size)}")
@@ -718,7 +727,8 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
"data",
data={
"path": path.as_posix()
}
},
no_error_log=True
)
if not resp:
return None

View File

@@ -14,10 +14,10 @@ from app.helper.directory import DirectoryHelper
from app.helper.message import TemplateHelper
from app.log import logger
from app.modules.filemanager.storages import StorageBase
from app.schemas import TransferInfo, TmdbEpisode, TransferDirectoryConf, FileItem, TransferInterceptEventData
from app.schemas import TransferInfo, TmdbEpisode, TransferDirectoryConf, FileItem, TransferInterceptEventData, \
TransferRenameEventData
from app.schemas.types import MediaType, ChainEventType
from app.utils.system import SystemUtils
from app.schemas import TransferRenameEventData
lock = Lock()
@@ -129,7 +129,7 @@ class TransHandler:
transfer_type=transfer_type,
need_notify=need_notify,
)
return self.result.copy()
return self.result.model_copy()
else:
new_path = target_path / fileitem.name
# 整理目录
@@ -147,21 +147,18 @@ class TransHandler:
fileitem=fileitem,
transfer_type=transfer_type,
need_notify=need_notify)
return self.result.copy()
return self.result.model_copy()
logger.info(f"文件夹 {fileitem.path} 整理成功")
# 计算目录下所有文件大小
total_size = sum(file.stat().st_size for file in Path(fileitem.path).rglob('*') if file.is_file())
# 返回整理后的路径
self.__set_result(success=True,
fileitem=fileitem,
target_item=new_diritem,
target_diritem=new_diritem,
total_size=total_size,
need_scrape=need_scrape,
need_notify=need_notify,
transfer_type=transfer_type)
return self.result.copy()
return self.result.model_copy()
else:
# 整理单个文件
if mediainfo.type == MediaType.TV:
@@ -174,7 +171,7 @@ class TransHandler:
fail_list=[fileitem.path],
transfer_type=transfer_type,
need_notify=need_notify)
return self.result.copy()
return self.result.model_copy()
# 文件结束季为空
in_meta.end_season = None
@@ -210,7 +207,7 @@ class TransHandler:
transfer_type=transfer_type,
need_notify=need_notify,
)
return self.result.copy()
return self.result.model_copy()
else:
new_file = target_path / fileitem.name
folder_path = target_path
@@ -227,7 +224,7 @@ class TransHandler:
fail_list=[fileitem.path],
transfer_type=transfer_type,
need_notify=need_notify)
return self.result.copy()
return self.result.model_copy()
# 目标文件
target_item = target_oper.get_item(new_file)
if target_item:
@@ -239,7 +236,8 @@ class TransHandler:
overflag = True
if not overflag:
# 目标文件已存在
logger.info(f"目的文件系统中已经存在同名文件 {target_file},当前整理覆盖模式设置为 {overwrite_mode}")
logger.info(
f"目的文件系统中已经存在同名文件 {target_file},当前整理覆盖模式设置为 {overwrite_mode}")
if overwrite_mode == 'always':
# 总是覆盖同名文件
overflag = True
@@ -257,7 +255,7 @@ class TransHandler:
fail_list=[fileitem.path],
transfer_type=transfer_type,
need_notify=need_notify)
return self.result.copy()
return self.result.model_copy()
elif overwrite_mode == 'never':
# 存在不覆盖
self.__set_result(success=False,
@@ -268,7 +266,7 @@ class TransHandler:
fail_list=[fileitem.path],
transfer_type=transfer_type,
need_notify=need_notify)
return self.result.copy()
return self.result.model_copy()
elif overwrite_mode == 'latest':
# 仅保留最新版本
logger.info(f"当前整理覆盖模式设置为仅保留最新版本,将覆盖:{new_file}")
@@ -295,7 +293,7 @@ class TransHandler:
fail_list=[fileitem.path],
transfer_type=transfer_type,
need_notify=need_notify)
return self.result.copy()
return self.result.model_copy()
logger.info(f"文件 {fileitem.path} 整理成功")
self.__set_result(success=True,
@@ -305,7 +303,7 @@ class TransHandler:
need_scrape=need_scrape,
transfer_type=transfer_type,
need_notify=need_notify)
return self.result.copy()
return self.result.model_copy()
finally:
self.result = None
@@ -424,7 +422,7 @@ class TransHandler:
# 复制文件到新目录
target_fileitem = target_oper.get_folder(target_file.parent)
if target_fileitem:
if source_oper.move(fileitem, Path(target_fileitem.path), target_file.name):
if source_oper.copy(fileitem, Path(target_fileitem.path), target_file.name):
return target_oper.get_item(target_file), ""
else:
return None, f"{target_storage}{fileitem.path} 复制文件失败"

View File

@@ -154,7 +154,7 @@ class FilterModule(_ModuleBase):
custom_rules = self.rulehelper.get_custom_rules()
for rule in custom_rules:
logger.info(f"加载自定义规则 {rule.id} - {rule.name}")
self.rule_set[rule.id] = rule.dict()
self.rule_set[rule.id] = rule.model_dump()
@staticmethod
def get_name() -> str:

View File

@@ -33,6 +33,8 @@ class SiteSchema(Enum):
MTorrent = "MTorrent"
Yema = "Yema"
HDDolby = "HDDolby"
Zhixing = "Zhixing"
Bitpt = "Bitpt"
class SiteParserBase(metaclass=ABCMeta):

View File

@@ -0,0 +1,161 @@
#
# 极速之星 https://bitpt.cn/
# author: ThedoRap
# time: 2025-10-02
#
# -*- coding: utf-8 -*-
import re
from typing import Optional, Tuple
from urllib.parse import urljoin, urlencode
from bs4 import BeautifulSoup
from app.modules.indexer.parser import SiteParserBase, SiteSchema
from app.utils.string import StringUtils
class BitptSiteUserInfo(SiteParserBase):
schema = SiteSchema.Bitpt
def _parse_site_page(self, html_text: str):
self._user_basic_page = "userdetails.php?uid={uid}"
self._user_detail_page = None
self._user_basic_params = {}
self._user_traffic_page = None
self._sys_mail_unread_page = None
self._user_mail_unread_page = None
self._mail_unread_params = {}
self._torrent_seeding_base = "browse.php"
self._torrent_seeding_params = {"t": "myseed", "st": "2", "d": "desc"}
self._torrent_seeding_headers = {}
self._addition_headers = {}
def _parse_logged_in(self, html_text):
soup = BeautifulSoup(html_text, 'html.parser')
return bool(soup.find(id='userinfotop'))
def _parse_user_base_info(self, html_text: str):
if not html_text:
return None
soup = BeautifulSoup(html_text, 'html.parser')
table = soup.find('table', class_='frmtable')
if not table:
return
rows = table.find_all('tr')
info_dict = {}
for row in rows:
cells = row.find_all('td')
if len(cells) == 2:
key = cells[0].text.strip()
value = cells[1].text.strip()
info_dict[key] = value
self.userid = info_dict.get('UID')
self.username = info_dict.get('用户名').split('\xa0')[0] if '用户名' in info_dict else None
self.user_level = info_dict.get('用户级别') if '用户级别' in info_dict else None
self.join_at = StringUtils.unify_datetime_str(info_dict.get('注册时间')) if '注册时间' in info_dict else None
self.upload = StringUtils.num_filesize(info_dict.get('上传流量')) if '上传流量' in info_dict else 0
self.download = StringUtils.num_filesize(info_dict.get('下载流量')) if '下载流量' in info_dict else 0
self.ratio = float(info_dict.get('共享率')) if '共享率' in info_dict else 0
bonus_str = info_dict.get('星辰', '')
self.bonus = float(re.search(r'累计([\d\.]+)', bonus_str).group(1)) if re.search(r'累计([\d\.]+)', bonus_str) else 0
self.message_unread = 0
if hasattr(self, '_torrent_seeding_base') and self._torrent_seeding_base:
self.seeding = 0
self.seeding_size = 0
else:
seeding_info = soup.find('div', style="margin:0 auto;width:90%;font-size:14px;margin-top:10px;margin-bottom:10px;text-align:center;")
if seeding_info:
seeding_link = seeding_info.find_all('a')[1].text if len(seeding_info.find_all('a')) > 1 else ''
match = re.search(r'当前上传的种子\((\d+)个, 共([\d\.]+ [KMGT]B)\)', seeding_link)
if match:
self.seeding = int(match.group(1))
self.seeding_size = StringUtils.num_filesize(match.group(2))
else:
self.seeding = 0
self.seeding_size = 0
def _parse_user_traffic_info(self, html_text: str):
pass
def _parse_user_detail_info(self, html_text: str):
pass
def _parse_user_torrent_seeding_page_info(self, html_text: str) -> Tuple[int, int]:
if not html_text:
return 0, 0
soup = BeautifulSoup(html_text, 'html.parser')
torrent_table = soup.find('table', class_='torrenttable')
if not torrent_table:
return 0, 0
rows = torrent_table.find_all('tr')
if len(rows) <= 1:
return 0, 0
torrents = [row for row in rows[1:] if 'btr' in row.get('class', [])]
page_seeding = 0
page_seeding_size = 0
for torrent in torrents:
size_td = torrent.find('td', class_='r')
if size_td:
size_a = size_td.find('a')
size_text = size_a.text.strip() if size_a else size_td.text.strip()
if size_text:
page_seeding += 1
page_seeding_size += StringUtils.num_filesize(size_text)
return page_seeding, page_seeding_size
def _parse_message_unread_links(self, html_text: str, msg_links: list) -> Optional[str]:
pass
def _parse_message_content(self, html_text) -> Tuple[Optional[str], Optional[str], Optional[str]]:
pass
def _parse_user_torrent_seeding_info(self, html_text: str):
pass
def parse(self):
super().parse()
if self._index_html:
soup = BeautifulSoup(self._index_html, 'html.parser')
user_link = soup.find('a', href=re.compile(r'userdetails\.php\?uid=\d+'))
if user_link:
uid_match = re.search(r'uid=(\d+)', user_link['href'])
if uid_match:
self.userid = uid_match.group(1)
if self.userid and self._user_basic_page:
basic_url = self._user_basic_page.format(uid=self.userid)
basic_html = self._get_page_content(url=urljoin(self._base_url, basic_url))
self._parse_user_base_info(basic_html)
if hasattr(self, '_torrent_seeding_base') and self._torrent_seeding_base:
seeding_base_url = urljoin(self._base_url, self._torrent_seeding_base)
params = self._torrent_seeding_params.copy()
page_num = 1
while True:
params['p'] = page_num
query_string = urlencode(params)
full_url = f"{seeding_base_url}?{query_string}"
seeding_html = self._get_page_content(url=full_url)
page_seeding, page_seeding_size = self._parse_user_torrent_seeding_page_info(seeding_html)
self.seeding += page_seeding
self.seeding_size += page_seeding_size
if page_seeding == 0:
break
page_num += 1
# 🔑 最终对外统一转字符串
self.userid = str(self.userid or "")
self.username = str(self.username or "")
self.user_level = str(self.user_level or "")
self.join_at = str(self.join_at or "")
self.upload = str(self.upload or 0)
self.download = str(self.download or 0)
self.ratio = str(self.ratio or 0)
self.bonus = str(self.bonus or 0.0)
self.message_unread = str(self.message_unread or 0)
self.seeding = str(self.seeding or 0)
self.seeding_size = str(self.seeding_size or 0)

View File

@@ -0,0 +1,184 @@
#
# 知行 http://pt.zhixing.bjtu.edu.cn/
# author: ThedoRap
# time: 2025-10-02
#
# -*- coding: utf-8 -*-
import re
from typing import Optional, Tuple
from app.modules.indexer.parser import SiteParserBase, SiteSchema
from app.utils.string import StringUtils
from bs4 import BeautifulSoup
from urllib.parse import urljoin
class ZhixingSiteUserInfo(SiteParserBase):
schema = SiteSchema.Zhixing
def _parse_site_page(self, html_text: str):
"""
获取站点页面地址
"""
self._user_basic_page = "user/{uid}/"
self._user_detail_page = None
self._user_basic_params = {}
self._user_traffic_page = None
self._sys_mail_unread_page = None
self._user_mail_unread_page = None
self._mail_unread_params = {}
self._torrent_seeding_base = "user/{uid}/seeding"
self._torrent_seeding_params = {}
self._torrent_seeding_headers = {}
self._addition_headers = {}
def _parse_logged_in(self, html_text):
"""
判断是否登录成功, 通过判断是否存在用户信息
"""
soup = BeautifulSoup(html_text, 'html.parser')
return bool(soup.find(id='um'))
def _parse_user_base_info(self, html_text: str):
"""
解析用户基本信息这里把_parse_user_traffic_info和_parse_user_detail_info合并到这里
"""
if not html_text:
return None
soup = BeautifulSoup(html_text, 'html.parser')
details_tabs = soup.find_all('div', class_='user-details-tabs')
info_dict = {}
for tab in details_tabs:
for p in tab.find_all('p'):
text = p.text.strip()
if '' in text:
parts = text.split('', 1)
elif ':' in text:
parts = text.split(':', 1)
else:
continue
if len(parts) == 2:
key = parts[0].strip()
value_text = parts[1].strip()
value = re.split(r'\s*\(', value_text)[0].strip().split('查看')[0].strip()
info_dict[key] = value
self._basic_info = info_dict # Save for fallback
self.userid = info_dict.get('UID')
self.username = info_dict.get('用户名')
self.user_level = info_dict.get('用户组')
self.join_at = StringUtils.unify_datetime_str(info_dict.get('注册时间')) if '注册时间' in info_dict else None
def num_filesize_safe(s: str):
if s:
s = s.strip()
if re.match(r'^\d+(\.\d+)?$', s):
s += ' B'
return StringUtils.num_filesize(s) if s else 0
self.upload = num_filesize_safe(info_dict.get('上传流量')) if '上传流量' in info_dict else 0
self.download = num_filesize_safe(info_dict.get('下载流量')) if '下载流量' in info_dict else 0
self.ratio = float(info_dict.get('共享率')) if '共享率' in info_dict else 0
self.bonus = float(info_dict.get('保种积分')) if '保种积分' in info_dict else 0.0
self.message_unread = 0 # 暂无消息解析
# Temporarily set seeding from basic, will override or fallback later
self.seeding = int(info_dict.get('当前保种数量')) if '当前保种数量' in info_dict else 0
self.seeding_size = num_filesize_safe(info_dict.get('当前保种容量')) if '当前保种容量' in info_dict else 0
def _parse_user_traffic_info(self, html_text: str):
pass
def _parse_user_detail_info(self, html_text: str):
pass
def _parse_user_torrent_seeding_page_info(self, html_text: str) -> Tuple[int, int]:
"""
解析用户做种信息单页,返回本页数量和大小
"""
if not html_text:
return 0, 0
soup = BeautifulSoup(html_text, 'html.parser')
torrents = soup.find_all('tr', id=re.compile(r'^t\d+'))
page_seeding = 0
page_seeding_size = 0
for torrent in torrents:
size_td = torrent.find('td', class_='r')
if size_td:
size_text = size_td.find('a').text if size_td.find('a') else size_td.text.strip()
page_seeding += 1
page_seeding_size += StringUtils.num_filesize(size_text)
return page_seeding, page_seeding_size
def _parse_message_unread_links(self, html_text: str, msg_links: list) -> Optional[str]:
pass
def _parse_message_content(self, html_text) -> Tuple[Optional[str], Optional[str], Optional[str]]:
pass
def _parse_user_torrent_seeding_info(self, html_text: str):
"""
占位,避免抽象类报错
"""
pass
def parse(self):
"""
解析站点信息
"""
super().parse()
# 先从首页解析userid
if self._index_html:
soup = BeautifulSoup(self._index_html, 'html.parser')
user_link = soup.find('a', href=re.compile(r'/user/\d+/'))
if user_link:
uid_match = re.search(r'/user/(\d+)/', user_link['href'])
if uid_match:
self.userid = uid_match.group(1)
# 如果有userid则格式化页面
if self.userid:
if self._user_basic_page:
basic_url = self._user_basic_page.format(uid=self.userid)
basic_html = self._get_page_content(url=urljoin(self._base_url, basic_url))
self._parse_user_base_info(basic_html)
if hasattr(self, '_torrent_seeding_base') and self._torrent_seeding_base:
self.seeding = 0 # Reset to sum from pages
self.seeding_size = 0
seeding_base = self._torrent_seeding_base.format(uid=self.userid)
seeding_base_url = urljoin(self._base_url, seeding_base)
page_num = 1
while True:
seeding_url = f"{seeding_base_url}/p{page_num}"
seeding_html = self._get_page_content(url=seeding_url)
page_seeding, page_seeding_size = self._parse_user_torrent_seeding_page_info(seeding_html)
self.seeding += page_seeding
self.seeding_size += page_seeding_size
if page_seeding == 0:
break
page_num += 1
# Fallback to basic if no seeding found from pages
if self.seeding == 0 and hasattr(self, '_basic_info'):
def num_filesize_safe(s: str):
if s:
s = s.strip()
if re.match(r'^\d+(\.\d+)?$', s):
s += ' B'
return StringUtils.num_filesize(s) if s else 0
self.seeding = int(self._basic_info.get('当前保种数量', 0))
self.seeding_size = num_filesize_safe(self._basic_info.get('当前保种容量', ''))
# 🔑 最终对外统一转字符串,避免 join 报错
self.userid = str(self.userid or "")
self.username = str(self.username or "")
self.user_level = str(self.user_level or "")
self.join_at = str(self.join_at or "")
self.upload = str(self.upload or 0)
self.download = str(self.download or 0)
self.ratio = str(self.ratio or 0)
self.bonus = str(self.bonus or 0.0)
self.message_unread = str(self.message_unread or 0)
self.seeding = str(self.seeding or 0)
self.seeding_size = str(self.seeding_size or 0)

View File

@@ -75,6 +75,9 @@ class MTorrentSpider:
categories = self._tv_category
else:
categories = self._movie_category
# mtorrent搜索imdb需要输入完整imdb链接参见 https://wiki.m-team.cc/zh-tw/imdbtosearch
if keyword and keyword.startswith("tt"):
keyword = f"https://www.imdb.com/title/{keyword}"
return {
"keyword": keyword,
"categories": categories,

View File

@@ -732,7 +732,7 @@ class Jellyfin:
item_type=item.get("Type"),
title=item.get("Name"),
original_title=item.get("OriginalTitle"),
year=item.get("ProductionYear"),
year=str(item.get("ProductionYear")),
tmdbid=int(tmdbid) if tmdbid else None,
imdbid=item.get("ProviderIds", {}).get("Imdb"),
tvdbid=item.get("ProviderIds", {}).get("Tvdb"),
@@ -924,7 +924,7 @@ class Jellyfin:
image = self.generate_image_link(item.get("Id"), "Backdrop", False)
if item_type == MediaType.MOVIE.value:
title = item.get("Name")
subtitle = item.get("ProductionYear")
subtitle = str(item.get("ProductionYear")) if item.get("ProductionYear") else None
else:
title = f'{item.get("SeriesName")}'
subtitle = f'S{item.get("ParentIndexNumber")}:{item.get("IndexNumber")} - {item.get("Name")}'
@@ -984,7 +984,7 @@ class Jellyfin:
ret_latest.append(schemas.MediaServerPlayItem(
id=item.get("Id"),
title=item.get("Name"),
subtitle=item.get("ProductionYear"),
subtitle=str(item.get("ProductionYear")) if item.get("ProductionYear") else None,
type=item_type,
image=image,
link=link,

View File

@@ -437,7 +437,7 @@ class Plex:
@staticmethod
def __get_ids(guids: List[Any]) -> dict:
def parse_tmdb_id(value: str) -> (bool, int):
def parse_tmdb_id(value: str) -> tuple[bool, int]:
"""尝试将TMDB ID字符串转换为整数。如果成功返回(True, int),失败则返回(False, None)。"""
try:
int_value = int(value)
@@ -509,7 +509,7 @@ class Plex:
item_type=item.type,
title=item.title,
original_title=item.originalTitle,
year=item.year,
year=str(item.year),
tmdbid=ids.get("tmdb_id"),
imdbid=ids.get("imdb_id"),
tvdbid=ids.get("tvdb_id"),
@@ -746,7 +746,7 @@ class Plex:
item_type = MediaType.MOVIE.value if item.TYPE == "movie" else MediaType.TV.value
if item_type == MediaType.MOVIE.value:
title = item.title
subtitle = item.year
subtitle = str(item.year) if item.year else None
else:
title = item.grandparentTitle
subtitle = f"S{item.parentIndex}:E{item.index} - {item.title}"
@@ -825,7 +825,7 @@ class Plex:
ret_resume.append(schemas.MediaServerPlayItem(
id=item.key,
title=title,
subtitle=item.year,
subtitle=str(item.year) if item.year else None,
type=item_type,
image=image,
link=link,

View File

@@ -51,10 +51,6 @@ class RedisModule(_ModuleBase):
"""
if settings.CACHE_BACKEND_TYPE != "redis":
return None
redis_helper = RedisHelper()
try:
if redis_helper.test():
return True, ""
return False, "Redis连接失败请检查配置"
finally:
redis_helper.close()
if RedisHelper().test():
return True, ""
return False, "Redis连接失败请检查配置"

View File

@@ -264,7 +264,7 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
userid=userid, username=username, text=text)
return None
def post_message(self, message: Notification) -> None:
def post_message(self, message: Notification, **kwargs) -> None:
"""
发送消息
:param message: 消息

View File

@@ -120,7 +120,7 @@ class SynologyChatModule(_ModuleBase, _MessageBase[SynologyChat]):
logger.debug(f"解析SynologyChat消息失败{str(err)}")
return None
def post_message(self, message: Notification) -> None:
def post_message(self, message: Notification, **kwargs) -> None:
"""
发送消息
:param message: 消息体

View File

@@ -261,7 +261,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
return cleaned
def post_message(self, message: Notification) -> None:
def post_message(self, message: Notification, **kwargs) -> None:
"""
发送消息
:param message: 消息体
@@ -283,7 +283,8 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
image=message.image, userid=userid, link=message.link,
buttons=message.buttons,
original_message_id=message.original_message_id,
original_chat_id=message.original_chat_id)
original_chat_id=message.original_chat_id,
escape_markdown=kwargs.get("escape_markdown"))
def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None:
"""

View File

@@ -31,7 +31,8 @@ class Telegram:
_callback_handlers: Dict[str, Callable] = {} # 存储回调处理器
_user_chat_mapping: Dict[str, str] = {} # userid -> chat_id mapping for reply targeting
_bot_username: Optional[str] = None # Bot username for mention detection
_escape_chars = r'_*[]()~`>#+-=|{}.!' # Telegram MarkdownV2
_markdown_escape_pattern = re.compile(f'([{re.escape(_escape_chars)}])') # Telegram MarkdownV2 规则转义特殊字符正则pattern
def __init__(self, TELEGRAM_TOKEN: Optional[str] = None, TELEGRAM_CHAT_ID: Optional[str] = None, **kwargs):
"""
初始化参数
@@ -52,7 +53,7 @@ class Telegram:
else:
apihelper.proxy = settings.PROXY
# bot
_bot = telebot.TeleBot(self._telegram_token, parse_mode="Markdown")
_bot = telebot.TeleBot(self._telegram_token, parse_mode="MarkdownV2")
# 记录句柄
self._bot = _bot
# 获取并存储bot用户名用于@检测
@@ -215,7 +216,8 @@ class Telegram:
userid: Optional[str] = None, link: Optional[str] = None,
buttons: Optional[List[List[dict]]] = None,
original_message_id: Optional[int] = None,
original_chat_id: Optional[str] = None) -> Optional[bool]:
original_chat_id: Optional[str] = None,
escape_markdown: bool = True) -> Optional[bool]:
"""
发送Telegram消息
:param title: 消息标题
@@ -226,7 +228,8 @@ class Telegram:
:param buttons: 按钮列表,格式:[[{"text": "按钮文本", "callback_data": "回调数据"}]]
:param original_message_id: 原消息ID如果提供则编辑原消息
:param original_chat_id: 原消息的聊天ID编辑消息时需要
:userid: 发送消息的目标用户ID为空则发给管理员
:param escape_markdown: 是否对内容进行Markdown转义
"""
if not self._telegram_token or not self._telegram_chat_id:
return None
@@ -236,10 +239,20 @@ class Telegram:
return False
try:
if title:
# 标题总是转义因为通常标题不包含Markdown格式
title = self.escape_markdown(title)
if text:
# 对text进行Markdown特殊字符转义
text = re.sub(r"([_`])", r"\\\1", text)
caption = f"*{title}*\n{text}"
if escape_markdown:
# 完全转义模式:转义所有特殊字符
text = self.escape_markdown(text)
else:
# 智能转义模式保留Markdown格式只转义普通文本中的特殊字符
text = self.escape_markdown_smart(text)
if title:
caption = f"*{title}*\n{text}"
else:
caption = text
else:
caption = f"*{title}*"
@@ -499,7 +512,7 @@ class Telegram:
if image:
# 如果有图片使用edit_message_media
media = InputMediaPhoto(media=image, caption=text, parse_mode="Markdown")
media = InputMediaPhoto(media=image, caption=text, parse_mode="MarkdownV2")
self._bot.edit_message_media(
chat_id=chat_id,
message_id=message_id,
@@ -512,7 +525,7 @@ class Telegram:
chat_id=chat_id,
message_id=message_id,
text=text,
parse_mode="Markdown",
parse_mode="MarkdownV2",
reply_markup=reply_markup
)
return True
@@ -542,7 +555,7 @@ class Telegram:
ret = self._bot.send_photo(chat_id=userid or self._telegram_chat_id,
photo=photo,
caption=caption,
parse_mode="Markdown",
parse_mode="MarkdownV2",
reply_markup=reply_markup)
if ret is None:
raise RetryException("发送图片消息失败")
@@ -553,12 +566,12 @@ class Telegram:
for i in range(0, len(caption), 4095):
ret = self._bot.send_message(chat_id=userid or self._telegram_chat_id,
text=caption[i:i + 4095],
parse_mode="Markdown",
parse_mode="MarkdownV2",
reply_markup=reply_markup if i == 0 else None)
else:
ret = self._bot.send_message(chat_id=userid or self._telegram_chat_id,
text=caption,
parse_mode="Markdown",
parse_mode="MarkdownV2",
reply_markup=reply_markup)
if ret is None:
raise RetryException("发送文本消息失败")
@@ -597,3 +610,84 @@ class Telegram:
self._bot.stop_polling()
self._polling_thread.join()
logger.info("Telegram消息接收服务已停止")
def escape_markdown(self, text: str) -> str:
# 按 Telegram MarkdownV2 规则转义特殊字符
if not isinstance(text, str):
return str(text) if text is not None else ""
return self._markdown_escape_pattern.sub(r'\\\1', text)
def escape_markdown_smart(self, text: str) -> str:
"""
智能转义Markdown文本只转义不在Markdown标记内的特殊字符
这样可以保留已有的Markdown格式如*粗体*、_斜体_、[链接](url)等),
同时转义普通文本中的特殊字符以避免API错误
注意Telegram MarkdownV2不支持以下语法这些字符会被转义
- 标题语法(#、##、###)会被转义为 \#、\##、\###
- 列表语法(-、*、+)会被转义为 \-、\*、\+
- 引用语法(>)会被转义为 \>
建议使用加粗文本模拟标题:*标题文本*
:param text: 要转义的文本
:return: 转义后的文本
"""
if not isinstance(text, str):
return str(text) if text is not None else ""
# 如果没有特殊字符,直接返回
if not any(char in self._escape_chars for char in text):
return text
# 标记受保护的区域Markdown标记内的内容不转义
protected = [False] * len(text)
# 按优先级匹配Markdown标记从最复杂到最简单
# 1. 链接:[text](url) - 必须最先匹配
link_pattern = r'\[([^\]]*)\]\(([^)]*)\)'
for match in re.finditer(link_pattern, text):
for i in range(match.start(), match.end()):
protected[i] = True
# 2. 粗体:*text*(单个*,不是**
bold_pattern = r'(?<!\*)\*(?!\*)([^*]+?)(?<!\*)\*(?!\*)'
for match in re.finditer(bold_pattern, text):
if not any(protected[match.start():match.end()]):
for i in range(match.start(), match.end()):
protected[i] = True
# 3. 斜体_text_单个_不是__
italic_pattern = r'(?<!_)_(?!_)([^_]+?)(?<!_)_(?!_)'
for match in re.finditer(italic_pattern, text):
if not any(protected[match.start():match.end()]):
for i in range(match.start(), match.end()):
protected[i] = True
# 4. 代码:`text`
code_pattern = r'`([^`]+)`'
for match in re.finditer(code_pattern, text):
if not any(protected[match.start():match.end()]):
for i in range(match.start(), match.end()):
protected[i] = True
# 5. 删除线:~text~
strikethrough_pattern = r'~([^~]+)~'
for match in re.finditer(strikethrough_pattern, text):
if not any(protected[match.start():match.end()]):
for i in range(match.start(), match.end()):
protected[i] = True
# 构建结果:只转义未保护区域的特殊字符
result = []
for i, char in enumerate(text):
if protected[i]:
# 受保护区域Markdown标记内不转义
result.append(char)
elif char in self._escape_chars:
# 未保护区域,转义特殊字符
result.append('\\' + char)
else:
result.append(char)
return ''.join(result)

View File

@@ -747,6 +747,9 @@ class TmdbApi:
logger.info("正在从TheDbMovie网站查询%s ..." % name)
tmdb_url = self._build_tmdb_search_url(name)
res = RequestUtils(timeout=5, ua=settings.NORMAL_USER_AGENT, proxies=settings.PROXY).get_res(url=tmdb_url)
if res is None:
logger.error("无法连接TheDbMovie")
return None
# 响应验证
response_result = self._validate_response(res)
@@ -1857,6 +1860,9 @@ class TmdbApi:
tmdb_url = self._build_tmdb_search_url(name)
res = await AsyncRequestUtils(timeout=5, ua=settings.NORMAL_USER_AGENT, proxies=settings.PROXY).get_res(
url=tmdb_url)
if res is None:
logger.error("无法连接TheDbMovie")
return None
# 响应验证
response_result = self._validate_response(res)

View File

@@ -8,7 +8,7 @@ from datetime import datetime
import requests
import requests.exceptions
from app.core.cache import cached
from app.core.cache import cached, fresh, async_fresh
from app.core.config import settings
from app.utils.http import RequestUtils, AsyncRequestUtils
from .exceptions import TMDbException
@@ -18,14 +18,12 @@ logger = logging.getLogger(__name__)
class TMDb(object):
def __init__(self, obj_cached=True, session=None, language=None):
def __init__(self, session=None, language=None):
self._api_key = settings.TMDB_API_KEY
self._language = language or settings.TMDB_LOCALE or "en-US"
self._session_id = None
self._session = session
self._wait_on_rate_limit = True
self._debug_enabled = False
self._cache_enabled = obj_cached
self._proxies = settings.PROXY
self._domain = settings.TMDB_API_DOMAIN
self._page = None
@@ -41,7 +39,6 @@ class TMDb(object):
self._remaining = 40
self._reset = None
self._timeout = 15
self.obj_cached = obj_cached
self.__clear_async_cache__ = False
@@ -111,36 +108,8 @@ class TMDb(object):
def wait_on_rate_limit(self, wait_on_rate_limit):
self._wait_on_rate_limit = bool(wait_on_rate_limit)
@property
def debug(self):
return self._debug_enabled
@debug.setter
def debug(self, debug):
self._debug_enabled = bool(debug)
@property
def cache(self):
return self._cache_enabled
@cache.setter
def cache(self, cache):
self._cache_enabled = bool(cache)
@cached(maxsize=settings.CONF.tmdb, ttl=settings.CONF.meta, skip_none=True)
def cached_request(self, method, url, data, json,
_ts=datetime.strftime(datetime.now(), '%Y%m%d')):
return self.request(method, url, data, json)
@cached(maxsize=settings.CONF.tmdb, ttl=settings.CONF.meta, skip_none=True)
async def async_cached_request(self, method, url, data, json,
_ts=datetime.strftime(datetime.now(), '%Y%m%d')):
if self.__clear_async_cache__:
self.__clear_async_cache__ = False
await self.async_cached_request.cache_clear()
return await self.async_request(method, url, data, json)
def request(self, method, url, data, json):
def request(self, method, url, data, json, **kwargs):
if method == "GET":
req = self._req.get_res(url, params=data, json=json)
else:
@@ -149,7 +118,8 @@ class TMDb(object):
raise TMDbException("无法连接TheMovieDb请检查网络连接")
return req
async def async_request(self, method, url, data, json):
@cached(maxsize=settings.CONF.tmdb, ttl=settings.CONF.meta, skip_none=True)
async def async_request(self, method, url, data, json, **kwargs):
if method == "GET":
req = await self._async_req.get_res(url, params=data, json=json)
else:
@@ -160,7 +130,7 @@ class TMDb(object):
def cache_clear(self):
self.__clear_async_cache__ = True
return self.cached_request.cache_clear()
return self.request.cache_clear()
def _validate_api_key(self):
if self.api_key is None or self.api_key == "":
@@ -204,13 +174,6 @@ class TMDb(object):
if "total_pages" in json_data:
self._total_pages = json_data["total_pages"]
if self.debug:
logger.info(json_data)
if is_async:
logger.info(self.async_cached_request.cache_info())
else:
logger.info(self.cached_request.cache_info())
@staticmethod
def _handle_errors(json_data):
if "errors" in json_data:
@@ -224,10 +187,9 @@ class TMDb(object):
self._validate_api_key()
url = self._build_url(action, params)
if self.cache and self.obj_cached and call_cached and method != "POST":
req = self.cached_request(method, url, data, json)
else:
req = self.request(method, url, data, json)
with fresh(not call_cached or method == "POST"):
req = self.request(method, url, data, json,
_ts=datetime.strftime(datetime.now(), '%Y%m%d'))
if req is None:
return None
@@ -253,10 +215,13 @@ class TMDb(object):
self._validate_api_key()
url = self._build_url(action, params)
if self.cache and self.obj_cached and call_cached and method != "POST":
req = await self.async_cached_request(method, url, data, json)
else:
req = await self.async_request(method, url, data, json)
if self.__clear_async_cache__:
self.__clear_async_cache__ = False
await self.async_request.cache_clear()
async with async_fresh(not call_cached or method == "POST"):
req = await self.async_request(method, url, data, json,
_ts=datetime.strftime(datetime.now(), '%Y%m%d'))
if req is None:
return None

View File

@@ -154,7 +154,7 @@ class TrimeMediaModule(_ModuleBase, _MediaServerBase[TrimeMedia]):
"""
source = args.get("source")
if source:
server: TrimeMedia = self.get_instance(source)
server: Optional[TrimeMedia] = self.get_instance(source)
if not server:
return None
result = server.get_webhook_message(body)
@@ -247,7 +247,7 @@ class TrimeMediaModule(_ModuleBase, _MediaServerBase[TrimeMedia]):
媒体数量统计
"""
if server:
server_obj: TrimeMedia = self.get_instance(server)
server_obj: Optional[TrimeMedia] = self.get_instance(server)
if not server_obj:
return None
servers = [server_obj]
@@ -268,7 +268,7 @@ class TrimeMediaModule(_ModuleBase, _MediaServerBase[TrimeMedia]):
"""
媒体库列表
"""
server_obj: TrimeMedia = self.get_instance(server)
server_obj: Optional[TrimeMedia] = self.get_instance(server)
if server_obj:
return server_obj.get_librarys(hidden=hidden)
return None
@@ -290,7 +290,7 @@ class TrimeMediaModule(_ModuleBase, _MediaServerBase[TrimeMedia]):
:return: 返回一个生成器对象,用于逐步获取媒体服务器中的项目
"""
server_obj: TrimeMedia = self.get_instance(server)
server_obj: Optional[TrimeMedia] = self.get_instance(server)
if server_obj:
return server_obj.get_items(library_id, start_index, limit)
return None
@@ -301,7 +301,7 @@ class TrimeMediaModule(_ModuleBase, _MediaServerBase[TrimeMedia]):
"""
媒体库项目详情
"""
server_obj: TrimeMedia = self.get_instance(server)
server_obj: Optional[TrimeMedia] = self.get_instance(server)
if server_obj:
return server_obj.get_iteminfo(item_id)
return None
@@ -312,7 +312,9 @@ class TrimeMediaModule(_ModuleBase, _MediaServerBase[TrimeMedia]):
"""
获取剧集信息
"""
server_obj: TrimeMedia = self.get_instance(server)
if not isinstance(item_id, str):
return None
server_obj: Optional[TrimeMedia] = self.get_instance(server)
if not server_obj:
return None
_, seasoninfo = server_obj.get_tv_episodes(item_id=item_id)
@@ -329,10 +331,10 @@ class TrimeMediaModule(_ModuleBase, _MediaServerBase[TrimeMedia]):
"""
获取媒体服务器正在播放信息
"""
server_obj: TrimeMedia = self.get_instance(server)
server_obj: Optional[TrimeMedia] = self.get_instance(server)
if not server_obj:
return []
return server_obj.get_resume(num=count)
return server_obj.get_resume(num=count) or []
def mediaserver_play_url(
self, server: str, item_id: Union[str, int]
@@ -340,7 +342,9 @@ class TrimeMediaModule(_ModuleBase, _MediaServerBase[TrimeMedia]):
"""
获取媒体库播放地址
"""
server_obj: TrimeMedia = self.get_instance(server)
if not isinstance(item_id, str):
return None
server_obj: Optional[TrimeMedia] = self.get_instance(server)
if not server_obj:
return None
return server_obj.get_play_url(item_id)
@@ -354,10 +358,10 @@ class TrimeMediaModule(_ModuleBase, _MediaServerBase[TrimeMedia]):
"""
获取媒体服务器最新入库条目
"""
server_obj: TrimeMedia = self.get_instance(server)
server_obj: Optional[TrimeMedia] = self.get_instance(server)
if not server_obj:
return []
return server_obj.get_latest(num=count)
return server_obj.get_latest(num=count) or []
def mediaserver_latest_images(
self,
@@ -374,7 +378,31 @@ class TrimeMediaModule(_ModuleBase, _MediaServerBase[TrimeMedia]):
:param remote: True为外网链接, False为内网链接
:return: 图片链接列表
"""
server_obj: TrimeMedia = self.get_instance(server)
server_obj: Optional[TrimeMedia] = self.get_instance(server)
if not server_obj:
return []
return server_obj.get_latest_backdrops(num=count, remote=remote)
return server_obj.get_latest_backdrops(num=count, remote=remote) or []
def mediaserver_image_cookies(
self,
server: Optional[str] = None,
image_url: Optional[str] = None,
**kwargs,
) -> Optional[str | dict]:
"""
获取飞牛影视服务器的图片Cookies
:param server: 媒体服务器名称
:param image_url: 图片网址
"""
if not image_url:
return None
if server:
server_obj = self.get_instance(server)
if not server_obj:
return None
return server_obj.get_image_cookies(image_url)
else:
for server_obj in self.get_instances().values():
if cookies := server_obj.get_image_cookies(image_url):
return cookies

View File

@@ -140,13 +140,13 @@ class Api:
self._token: Optional[str] = None
self._version: Optional[Version] = None
self._session = requests.Session()
self._request_utils = RequestUtils(session=self._session)
self._request_utils = RequestUtils(session=self._session, timeout=10)
def sys_version(self) -> Optional[Version]:
"""
飞牛影视版本号
"""
if (res := self.__request_api("/sys/version")) and res.success:
if (res := self.request("/sys/version")) and res.success:
if res.data:
self._version = Version(
frontend=res.data.get("version"),
@@ -162,7 +162,7 @@ class Api:
:return: 成功返回token 否则返回None
"""
if (
res := self.__request_api(
res := self.request(
"/login",
data={
"username": username,
@@ -178,7 +178,9 @@ class Api:
"""
退出账号
"""
if (res := self.__request_api("/user/logout", method="post")) and res.success:
if not self._token:
return True
if (res := self.request("/user/logout", method="post")) and res.success:
if res.data:
self._token = None
return True
@@ -188,7 +190,9 @@ class Api:
"""
用户列表(仅管理员有权访问)
"""
if (res := self.__request_api("/manager/user/list")) and res.success:
if (res := self.request("/manager/user/list")) and res.success:
if not res.data:
return []
return [
User(
guid=info.get("guid"),
@@ -203,7 +207,7 @@ class Api:
"""
当前用户信息
"""
if (res := self.__request_api("/user/info")) and res.success:
if (res := self.request("/user/info")) and res.success:
_user = User("", "")
_user.__dict__.update(res.data)
return _user
@@ -213,7 +217,7 @@ class Api:
"""
媒体数量统计
"""
if (res := self.__request_api("/mediadb/sum")) and res.success:
if (res := self.request("/mediadb/sum")) and res.success:
sums = MediaDbSummary()
sums.__dict__.update(res.data)
return sums
@@ -223,9 +227,9 @@ class Api:
"""
媒体库列表(普通用户)
"""
if (res := self.__request_api("/mediadb/list")) and res.success:
if (res := self.request("/mediadb/list")) and res.success:
_items = []
for info in res.data:
for info in res.data or []:
mdb = MediaDb(
guid=info.get("guid"),
category=Category(info.get("category")),
@@ -250,9 +254,9 @@ class Api:
"""
媒体库列表(管理员)
"""
if (res := self.__request_api("/mdb/list")) and res.success:
if (res := self.request("/mdb/list")) and res.success:
_items = []
for info in res.data:
for info in res.data or []:
mdb = MediaDb(
guid=info.get("guid"),
category=Category(info.get("category")),
@@ -271,7 +275,7 @@ class Api:
"""
扫描所有媒体库
"""
if (res := self.__request_api("/mdb/scanall", method="post")) and res.success:
if (res := self.request("/mdb/scanall", method="post")) and res.success:
if res.data:
return True
return False
@@ -280,9 +284,7 @@ class Api:
"""
扫描指定媒体库
"""
if (
res := self.__request_api(f"/mdb/scan/{mdb.guid}", data={})
) and res.success:
if (res := self.request(f"/mdb/scan/{mdb.guid}", data={})) and res.success:
if res.data:
return True
return False
@@ -291,9 +293,7 @@ class Api:
"""
当前正在运行的任务
"""
if (
res := self.__request_api("/task/running")
) and res.success:
if (res := self.request("/task/running")) and res.success:
if res.data:
# TODO 具体正在运行的任务
return True
@@ -341,7 +341,9 @@ class Api:
if exclude_grouped_video:
post["exclude_grouped_video"] = 1
if (res := self.__request_api("/item/list", data=post)) and res.success:
if (res := self.request("/item/list", data=post)) and res.success:
if not res.data:
return []
return [self.__build_item(info) for info in res.data.get("list", [])]
return None
@@ -350,8 +352,10 @@ class Api:
搜索影片、演员
"""
if (
res := self.__request_api("/search/list", params={"q": keywords})
res := self.request("/search/list", params={"q": keywords})
) and res.success:
if not res.data:
return []
return [self.__build_item(info) for info in res.data]
return None
@@ -359,7 +363,7 @@ class Api:
"""
查询媒体详情
"""
if (res := self.__request_api(f"/item/{guid}")) and res.success:
if (res := self.request(f"/item/{guid}")) and res.success:
return self.__build_item(res.data)
return None
@@ -370,7 +374,7 @@ class Api:
:param delete_file: True删除媒体文件False仅从媒体库移除
"""
if (
res := self.__request_api(
res := self.request(
f"/item/{guid}",
method="delete",
data={"delete_file": 1 if delete_file else 0, "media_guids": []},
@@ -384,7 +388,9 @@ class Api:
"""
查询季列表
"""
if (res := self.__request_api(f"/season/list/{tv_guid}")) and res.success:
if (res := self.request(f"/season/list/{tv_guid}")) and res.success:
if not res.data:
return []
return [self.__build_item(info) for info in res.data]
return None
@@ -392,7 +398,9 @@ class Api:
"""
查询剧集列表
"""
if (res := self.__request_api(f"/episode/list/{season_guid}")) and res.success:
if (res := self.request(f"/episode/list/{season_guid}")) and res.success:
if not res.data:
return []
return [self.__build_item(info) for info in res.data]
return None
@@ -400,7 +408,9 @@ class Api:
"""
继续观看列表
"""
if (res := self.__request_api("/play/list")) and res.success:
if (res := self.request("/play/list")) and res.success:
if not res.data:
return []
return [self.__build_item(info) for info in res.data]
return None
@@ -431,7 +441,7 @@ class Api:
sign = md5.hexdigest()
return f"nonce={nonce}&timestamp={ts}&sign={sign}"
def __request_api(
def request(
self,
api: str,
method: Optional[str] = None,
@@ -482,6 +492,8 @@ class Api:
queries_unquoted = None
headers = {
"User-Agent": settings.USER_AGENT,
"Accept": "application/json",
"Referer": self._host,
"Authorization": self._token,
"authx": self.__get_authx(api_path, json_body or queries_unquoted),
}

View File

@@ -5,6 +5,7 @@ import app.modules.trimemedia.api as fnapi
from app import schemas
from app.log import logger
from app.schemas import MediaType
from app.utils.security import SecurityUtils
from app.utils.url import UrlUtils
@@ -13,12 +14,14 @@ class TrimeMedia:
_password: Optional[str] = None
_userinfo: Optional[fnapi.User] = None
_host: Optional[str] = None
_playhost: Optional[str] = None
_libraries: dict[str, fnapi.MediaDb] = {}
_sync_libraries: List[str] = []
_api: Optional[fnapi.Api] = None
_version: Optional[fnapi.Version] = None
def __init__(
self,
@@ -34,20 +37,19 @@ class TrimeMedia:
return
self._username = username
self._password = password
self._host = host
self._sync_libraries = sync_libraries or []
if (api := self.__create_api(host)) is None:
if not self.reconnect():
logger.error(f"请检查服务端地址 {host}")
return
self._api = api
if play_api := self.__create_api(play_host):
self._playhost = play_api.host
if result := self.__create_api(play_host):
self._playhost = result.api.host
result.api.close()
elif play_host:
logger.warning(f"请检查外网播放地址 {play_host}")
self._playhost = UrlUtils.standardize_base_url(play_host).rstrip("/")
self.reconnect()
@property
def api(self) -> Optional[fnapi.Api]:
"""
@@ -55,14 +57,26 @@ class TrimeMedia:
"""
return self._api
@property
def version(self) -> Optional[fnapi.Version]:
"""
获得飞牛API的版本
"""
return self._version
class _ApiCreateResult:
api: fnapi.Api
version: fnapi.Version
@staticmethod
def __create_api(host: Optional[str]) -> Optional[fnapi.Api]:
def __create_api(host: Optional[str]) -> Optional["TrimeMedia._ApiCreateResult"]:
"""
创建一个飞牛API
:param host: 服务端地址
:return: 如果地址无效、不可访问则返回None
"""
if not host:
return None
api_key = "16CCEB3D-AB42-077D-36A1-F355324E4237"
@@ -70,21 +84,35 @@ class TrimeMedia:
if not host.endswith("/v"):
# 尝试补上结尾的/v 测试能否正常访问
api = fnapi.Api(host + "/v", api_key)
if api.sys_version():
return api
res = TrimeMedia._ApiCreateResult()
res.api = fnapi.Api(host + "/v", api_key)
if fnver := res.api.sys_version():
res.version = fnver
return res
# 测试用户配置的地址
api = fnapi.Api(host, api_key)
return api if api.sys_version() else None
res = TrimeMedia._ApiCreateResult()
res.api = fnapi.Api(host, api_key)
if fnver := res.api.sys_version():
res.version = fnver
return res
return None
def close(self):
self.disconnect()
def is_configured(self) -> bool:
return self._api is not None
return bool(self._host and self._username and self._password)
def is_authenticated(self) -> bool:
return self.is_configured() and self._api.token is not None
"""
是否已登录
"""
return (
self.is_configured()
and self._api is not None
and self._api.token is not None
and self._userinfo is not None
)
def is_inactive(self) -> bool:
"""
@@ -101,10 +129,17 @@ class TrimeMedia:
"""
if not self.is_configured():
return False
if (fnver := self._api.sys_version()) is None:
self.disconnect()
if result := self.__create_api(self._host):
self._api = result.api
self._version = result.version
# 版本号:0.8.53, 服务版本:0.8.23
# 版本号:0.8.56, 服务版本:0.8.23 接口/memory/user/list改为/manager/user/list
logger.debug(
f"版本号:{result.version.frontend}, 服务版本:{result.version.backend}"
)
else:
return False
# 版本号:0.8.36, 服务版本:0.8.19
logger.debug(f"版本号:{fnver.frontend}, 服务版本:{fnver.backend}")
if self._api.login(self._username, self._password) is None:
return False
self._userinfo = self._api.user_info()
@@ -119,9 +154,10 @@ class TrimeMedia:
"""
断开与飞牛的连接
"""
if self.is_authenticated():
if self._api:
self._api.logout()
self._api.close()
self._api = None
self._userinfo = None
logger.debug(f"{self._username} 已断开飞牛影视")
@@ -163,7 +199,8 @@ class TrimeMedia:
for img_path in library.posters or []
],
link=f"{self._playhost or self._api.host}/library/{library.guid}",
server_type='trimemedia'
server_type="trimemedia",
use_cookies=True,
)
)
return libraries
@@ -205,10 +242,12 @@ class TrimeMedia:
return None
if not self.is_configured():
return None
feiniu = fnapi.Api(self._api.host, self._api.apikey)
if token := feiniu.login(username, password):
feiniu.logout()
return token
if result := self.__create_api(self._host):
try:
return result.api.login(username, password)
finally:
result.api.logout()
result.api.close()
def get_movies(
self, title: str, year: Optional[str] = None, tmdb_id: Optional[int] = None
@@ -410,7 +449,7 @@ class TrimeMedia:
item_type=item_type,
title=item.title,
original_title=item.original_title,
year=year,
year=str(year),
tmdbid=item.tmdb_id,
imdbid=item.imdb_id,
user_state=user_state,
@@ -459,7 +498,8 @@ class TrimeMedia:
if item.duration and item.ts is not None
else 0
),
server_type='trimemedia',
server_type="trimemedia",
use_cookies=True,
)
def get_items(
@@ -576,6 +616,7 @@ class TrimeMedia:
if (item_details := self._api.item(item.guid)) is None:
continue
if remote:
# FIXME 新版飞牛的壁纸无法直接在浏览器中访问
img_host = self._playhost or self._api.host
else:
img_host = self._api.host
@@ -604,3 +645,15 @@ class TrimeMedia:
)
else False
)
def get_image_cookies(self, image_url: str):
"""
获得指定图片的Cookies
"""
if not self.is_authenticated():
return None
if not image_url or not SecurityUtils.is_safe_url(
image_url, [self._api.host], strict=True
):
return None
return {"Trim-MC-token": self._api.token}

View File

@@ -139,7 +139,7 @@ class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]):
logger.error(f"VoceChat消息处理发生错误{str(err)}")
return None
def post_message(self, message: Notification) -> None:
def post_message(self, message: Notification, **kwargs) -> None:
"""
发送消息
:param message: 消息内容

View File

@@ -71,7 +71,7 @@ class WebPushModule(_ModuleBase, _MessageBase):
def init_setting(self) -> Tuple[str, Union[str, bool]]:
pass
def post_message(self, message: Notification) -> None:
def post_message(self, message: Notification, **kwargs) -> None:
"""
发送消息
:param message: 消息内容

View File

@@ -184,7 +184,7 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]):
logger.error(f"微信消息处理发生错误:{str(err)}")
return None
def post_message(self, message: Notification) -> None:
def post_message(self, message: Notification, **kwargs) -> None:
"""
发送消息
:param message: 消息内容

View File

@@ -46,12 +46,18 @@ class FileMonitorHandler(FileSystemEventHandler):
self.callback = callback
def on_created(self, event: FileSystemEvent):
self.callback.event_handler(event=event, text="创建", event_path=event.src_path,
file_size=Path(event.src_path).stat().st_size)
try:
self.callback.event_handler(event=event, text="创建", event_path=event.src_path,
file_size=Path(event.src_path).stat().st_size)
except Exception as e:
logger.error(f"on_created 异常: {e}")
def on_moved(self, event: FileSystemMovedEvent):
self.callback.event_handler(event=event, text="移动", event_path=event.dest_path,
file_size=Path(event.dest_path).stat().st_size)
try:
self.callback.event_handler(event=event, text="移动", event_path=event.dest_path,
file_size=Path(event.dest_path).stat().st_size)
except Exception as e:
logger.error(f"on_moved 异常: {e}")
class Monitor(metaclass=SingletonClass):

View File

@@ -1,6 +1,6 @@
from abc import ABCMeta, abstractmethod
from pathlib import Path
from typing import Any, List, Dict, Tuple, Optional
from typing import Any, List, Dict, Tuple, Optional, Type
from app.chain import ChainBase
from app.core.config import settings
@@ -200,6 +200,20 @@ class _PluginBase(metaclass=ABCMeta):
"""
pass
def get_agent_tools(self) -> List[Type]:
"""
获取插件智能体工具
返回工具类列表,每个工具类必须继承自 MoviePilotTool
[ToolClass1, ToolClass2, ...]
对工具类的要求:
1、工具类必须继承自 app.agent.tools.base.MoviePilotTool
2、工具类需要实现 run 方法(异步方法)
3、工具类需要定义 name 和 description 属性
4、工具类可以定义 args_schema 来指定输入参数模型
"""
pass
@abstractmethod
def stop_service(self):
"""

View File

@@ -1,4 +1,5 @@
import asyncio
import gc
import inspect
import multiprocessing
import threading
@@ -30,6 +31,7 @@ from app.helper.wallpaper import WallpaperHelper
from app.log import logger
from app.schemas import Notification, NotificationType, Workflow, ConfigChangeEventData
from app.schemas.types import EventType, SystemConfigKey
from app.utils.gc import get_memory_usage
from app.utils.singleton import SingletonClass
from app.utils.timer import TimerUtils
@@ -181,6 +183,11 @@ class Scheduler(metaclass=SingletonClass):
"name": "订阅日历缓存",
"func": SubscribeChain().cache_calendar,
"running": False
},
"full_gc": {
"name": "主动内存回收",
"func": self.full_gc,
"running": False
}
}
@@ -413,6 +420,19 @@ class Scheduler(metaclass=SingletonClass):
}
)
# 主动内存回收
if settings.MEMORY_GC_INTERVAL:
self._scheduler.add_job(
self.start,
"interval",
id="full_gc",
name="主动内存回收",
minutes=settings.MEMORY_GC_INTERVAL,
kwargs={
'job_id': 'full_gc'
}
)
# 初始化工作流服务
self.init_workflow_jobs()
@@ -747,6 +767,17 @@ class Scheduler(metaclass=SingletonClass):
"""
SchedulerChain().clear_cache()
@staticmethod
def full_gc():
"""
主动内存回收
"""
memory_before = get_memory_usage()
collected = gc.collect()
memory_after = get_memory_usage()
memory_freed = memory_before - memory_after
logger.info(f"主动内存回收完成,回收对象数: {collected},释放内存: {memory_freed:.2f} MB")
def user_auth(self):
"""
用户认证检查

58
app/schemas/agent.py Normal file
View File

@@ -0,0 +1,58 @@
"""AI智能体相关数据模型"""
from datetime import datetime
from typing import Dict, List, Optional, Any
from pydantic import BaseModel, Field, ConfigDict, field_serializer
class ConversationMemory(BaseModel):
"""对话记忆模型"""
session_id: str = Field(description="会话ID")
user_id: Optional[str] = Field(default=None, description="用户ID")
title: Optional[str] = Field(default=None, description="会话标题")
messages: List[Dict[str, Any]] = Field(default_factory=list, description="消息列表")
context: Dict[str, Any] = Field(default_factory=dict, description="会话上下文")
created_at: datetime = Field(default_factory=datetime.now, description="创建时间")
updated_at: datetime = Field(default_factory=datetime.now, description="更新时间")
model_config = ConfigDict()
@field_serializer('created_at', 'updated_at', when_used='json')
def serialize_datetime(self, value: datetime) -> str:
return value.isoformat()
class AgentState(BaseModel):
"""AI智能体状态模型"""
session_id: str = Field(description="会话ID")
current_task: Optional[str] = Field(default=None, description="当前任务")
is_thinking: bool = Field(default=False, description="是否正在思考")
last_activity: datetime = Field(default_factory=datetime.now, description="最后活动时间")
model_config = ConfigDict()
@field_serializer('last_activity', when_used='json')
def serialize_datetime(self, value: datetime) -> str:
return value.isoformat()
class UserMessage(BaseModel):
"""用户消息模型"""
session_id: str = Field(description="会话ID")
content: str = Field(description="消息内容")
user_id: Optional[str] = Field(default=None, description="用户ID")
channel: Optional[str] = Field(default=None, description="消息渠道")
source: Optional[str] = Field(default=None, description="消息来源")
class ToolResult(BaseModel):
"""工具执行结果模型"""
session_id: str = Field(description="会话ID")
call_id: str = Field(description="调用ID")
success: bool = Field(description="是否成功")
result: Optional[str] = Field(default=None, description="执行结果")
error: Optional[str] = Field(default=None, description="错误信息")

View File

@@ -86,7 +86,7 @@ class MediaInfo(BaseModel):
# IMDB ID
imdb_id: Optional[str] = None
# TVDB ID
tvdb_id: Optional[str] = None
tvdb_id: Optional[int] = None
# 豆瓣ID
douban_id: Optional[str] = None
# Bangumi ID
@@ -158,6 +158,8 @@ class MediaInfo(BaseModel):
production_countries: Optional[list] = Field(default_factory=list)
# 语种
spoken_languages: Optional[list] = Field(default_factory=list)
# 所有发行日期
release_dates: list = Field(default_factory=list)
# 状态
status: Optional[str] = None
# 标签
@@ -167,7 +169,7 @@ class MediaInfo(BaseModel):
# 评价数量
vote_count: Optional[int] = 0
# 流行度
popularity: Optional[int] = 0
popularity: Optional[float] = 0.0
# 时长
runtime: Optional[int] = None
# 下一集

View File

@@ -1,7 +1,7 @@
from pathlib import Path
from typing import Optional, Dict, Any, List, Set, Callable
from pydantic import BaseModel, Field, root_validator
from pydantic import BaseModel, Field, model_validator
from app.schemas.message import MessageChannel
from app.schemas.file import FileItem
@@ -68,7 +68,8 @@ class AuthCredentials(ChainEventData):
channel: Optional[str] = Field(default=None, description="认证渠道")
service: Optional[str] = Field(default=None, description="服务名称")
@root_validator(pre=True)
@model_validator(mode='before')
@classmethod
def check_fields_based_on_grant_type(cls, values): # noqa
grant_type = values.get("grant_type")
if not grant_type:

View File

@@ -1,6 +1,6 @@
from typing import Optional, Any
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
class DownloadHistory(BaseModel):
@@ -51,8 +51,7 @@ class DownloadHistory(BaseModel):
# 自定义剧集组
episode_group: Optional[str] = None
class Config:
orm_mode = True
model_config = ConfigDict(from_attributes=True)
class TransferHistory(BaseModel):
@@ -97,5 +96,4 @@ class TransferHistory(BaseModel):
# 日期
date: Optional[str] = None
class Config:
orm_mode = True
model_config = ConfigDict(from_attributes=True)

Some files were not shown because too many files have changed in this diff Show More