mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-13 07:26:45 +00:00
Compare commits
679 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
48b1ac28de | ||
|
|
6e329b17a9 | ||
|
|
6a492198a8 | ||
|
|
8bf9b6e7cb | ||
|
|
42e23ef564 | ||
|
|
c6806ee648 | ||
|
|
076fae696c | ||
|
|
ed294d3ea4 | ||
|
|
043be409d0 | ||
|
|
a5e7483870 | ||
|
|
365335be46 | ||
|
|
62543dd171 | ||
|
|
e2eef8ff21 | ||
|
|
3acf937d56 | ||
|
|
d572e523ba | ||
|
|
82113abe88 | ||
|
|
b7d121c58f | ||
|
|
6d5a85b144 | ||
|
|
78121917c6 | ||
|
|
a0913f0e32 | ||
|
|
e96e284715 | ||
|
|
c572a1b607 | ||
|
|
1845311f98 | ||
|
|
4f806db8b7 | ||
|
|
22858cc1e9 | ||
|
|
a0329a3eb0 | ||
|
|
b3e92088ee | ||
|
|
46db1c20f1 | ||
|
|
9d182e53b2 | ||
|
|
1205fc7fdb | ||
|
|
ff2826a448 | ||
|
|
ee750115ec | ||
|
|
0e13d22c97 | ||
|
|
8e7d040ac4 | ||
|
|
6755202958 | ||
|
|
8b7374a687 | ||
|
|
c17cca2365 | ||
|
|
8016a9539a | ||
|
|
e885fb15a0 | ||
|
|
c7f098771b | ||
|
|
fcd0908032 | ||
|
|
7ff1285084 | ||
|
|
b45b603b97 | ||
|
|
247208b8a9 | ||
|
|
182c46037b | ||
|
|
438d3210bc | ||
|
|
d523c7c916 | ||
|
|
09a19e94d5 | ||
|
|
3971c145df | ||
|
|
055117d83d | ||
|
|
c6baf43986 | ||
|
|
4ff16af3a7 | ||
|
|
17a1bd352b | ||
|
|
7421ca09cc | ||
|
|
9797e696e5 | ||
|
|
c36d6d8b2d | ||
|
|
3873786b99 | ||
|
|
76fdba7f09 | ||
|
|
72799e9638 | ||
|
|
2e77d03fe9 | ||
|
|
0c58eae5e7 | ||
|
|
b609567c38 | ||
|
|
7ecfa44fa0 | ||
|
|
a685b1dc3b | ||
|
|
63ce49a17c | ||
|
|
820fbe4076 | ||
|
|
efa05b7775 | ||
|
|
003781e903 | ||
|
|
ee71bafc96 | ||
|
|
bdd5f1231e | ||
|
|
6fee532c96 | ||
|
|
78aaad7b59 | ||
|
|
b128b0ede2 | ||
|
|
737d2f3bc6 | ||
|
|
179be53a65 | ||
|
|
1867f5e7c2 | ||
|
|
6662d24565 | ||
|
|
5880566a99 | ||
|
|
5d05b32711 | ||
|
|
fa2b720e92 | ||
|
|
d381238f83 | ||
|
|
751d627ead | ||
|
|
3e66a8de9b | ||
|
|
266052b12b | ||
|
|
803f4328f4 | ||
|
|
8e95568e11 | ||
|
|
ab09ee4819 | ||
|
|
41f94a172f | ||
|
|
566e597994 | ||
|
|
765fb9c05f | ||
|
|
b6720a19f7 | ||
|
|
3b130651c4 | ||
|
|
3f6c35dabe | ||
|
|
db2a952bca | ||
|
|
0ea9770bc3 | ||
|
|
0b20956c90 | ||
|
|
9f73b47d54 | ||
|
|
ce9c99af71 | ||
|
|
784024fb5d | ||
|
|
1145b32299 | ||
|
|
ab71df0011 | ||
|
|
fb137252a9 | ||
|
|
f57a680306 | ||
|
|
8bb3eaa320 | ||
|
|
9489730a44 | ||
|
|
d4795bb897 | ||
|
|
63775872c7 | ||
|
|
beff508a1f | ||
|
|
deaae8a2c6 | ||
|
|
46a27bd50c | ||
|
|
24f2993433 | ||
|
|
c80bfbfac5 | ||
|
|
06abfc45c7 | ||
|
|
440a773081 | ||
|
|
0797bcb38b | ||
|
|
d463b5bf0d | ||
|
|
0733c8edcc | ||
|
|
86c7c05cb1 | ||
|
|
18ff7ce753 | ||
|
|
8f2ed1004d | ||
|
|
14961323c3 | ||
|
|
f8c682b183 | ||
|
|
dd92708f60 | ||
|
|
4d9eeccefa | ||
|
|
cd7b251031 | ||
|
|
db614180b9 | ||
|
|
b6e527e5f4 | ||
|
|
77c0f8f39e | ||
|
|
58816d73c8 | ||
|
|
3b194d282e | ||
|
|
397f66433d | ||
|
|
04a4ed1d0e | ||
|
|
625850d4e7 | ||
|
|
6c572baca5 | ||
|
|
ee0406a13f | ||
|
|
608a049ba3 | ||
|
|
4d9b5198e2 | ||
|
|
24b6c970aa | ||
|
|
239c47f469 | ||
|
|
f0fc64c517 | ||
|
|
8481fd38ce | ||
|
|
5f425129d5 | ||
|
|
92955b1315 | ||
|
|
a3872d5bb5 | ||
|
|
a123ff2c04 | ||
|
|
188de34306 | ||
|
|
3d43750e9b | ||
|
|
fea228c68d | ||
|
|
a71a28e563 | ||
|
|
3b5d4982b5 | ||
|
|
b201e9ab8c | ||
|
|
d30b9282fd | ||
|
|
4f304a70b7 | ||
|
|
59a54d4f04 | ||
|
|
1e94d794ed | ||
|
|
5bd210406b | ||
|
|
e00514d36d | ||
|
|
f013bf1931 | ||
|
|
107cbbad1d | ||
|
|
481f1f9d30 | ||
|
|
704364061c | ||
|
|
c1bd2d6cf1 | ||
|
|
a018e1228c | ||
|
|
d962d9c7f6 | ||
|
|
4ea28cbca5 | ||
|
|
1b48b8b4cc | ||
|
|
73df197e33 | ||
|
|
bdc66e55ca | ||
|
|
926343ee86 | ||
|
|
8e6021c5e7 | ||
|
|
ac2b6c76ce | ||
|
|
9e966d0a7f | ||
|
|
6c10defaa1 | ||
|
|
b6a76f6f7c | ||
|
|
84e5b77a5c | ||
|
|
89b0ea0bf1 | ||
|
|
48aeb98bf1 | ||
|
|
8a5d864812 | ||
|
|
ae79e645a6 | ||
|
|
0947deb372 | ||
|
|
69c92911a2 | ||
|
|
b16bb37b75 | ||
|
|
9c9ec8adf2 | ||
|
|
eb0e67fc42 | ||
|
|
9cc50bddab | ||
|
|
d3ba0fa487 | ||
|
|
39f6505a80 | ||
|
|
36a6802439 | ||
|
|
d7e2633a92 | ||
|
|
88049e741e | ||
|
|
ff7fb14087 | ||
|
|
816c64bd48 | ||
|
|
d2756e6f2d | ||
|
|
147e12acbb | ||
|
|
4098018ee9 | ||
|
|
133e7578b9 | ||
|
|
74a2bdbf09 | ||
|
|
f22bc68af4 | ||
|
|
26cc6da650 | ||
|
|
d21f1f1b87 | ||
|
|
7cdaafffe1 | ||
|
|
0265dca197 | ||
|
|
9d68366043 | ||
|
|
c8c671d915 | ||
|
|
142daa9d15 | ||
|
|
2552219991 | ||
|
|
a038b698d7 | ||
|
|
a3b222574e | ||
|
|
e0cd467293 | ||
|
|
9c056030d2 | ||
|
|
19efa9d4cc | ||
|
|
90633a6495 | ||
|
|
edc432fbd8 | ||
|
|
1b7bdbf516 | ||
|
|
8c1be70c85 | ||
|
|
b8e0c0db9e | ||
|
|
7b7fb6cc82 | ||
|
|
62512ba215 | ||
|
|
e1beb64c01 | ||
|
|
c81f26ddad | ||
|
|
340114c2a1 | ||
|
|
cd7767b331 | ||
|
|
25289dad8a | ||
|
|
47c6917129 | ||
|
|
6379cda148 | ||
|
|
91a124ab8f | ||
|
|
2357a7135e | ||
|
|
da0b3b3de9 | ||
|
|
6664fb1716 | ||
|
|
1206f24fa9 | ||
|
|
ffb5823e84 | ||
|
|
d45a7fb262 | ||
|
|
918d192c0f | ||
|
|
f7cd6eac50 | ||
|
|
88f4428ff0 | ||
|
|
069ea22ba2 | ||
|
|
8fac8c5307 | ||
|
|
2285befebb | ||
|
|
1cd0648e4e | ||
|
|
0b7ba285c6 | ||
|
|
30446c4526 | ||
|
|
9b843c9ed2 | ||
|
|
2ce1c3bef8 | ||
|
|
e463094dc7 | ||
|
|
71a9fe10f4 | ||
|
|
ba146e13ef | ||
|
|
c060d7e3e0 | ||
|
|
ba96678822 | ||
|
|
4f6354f383 | ||
|
|
2766e80346 | ||
|
|
7cc3777a60 | ||
|
|
cb1dd9f17d | ||
|
|
31f342fe4f | ||
|
|
e90359eb08 | ||
|
|
58b0768a30 | ||
|
|
3b04506893 | ||
|
|
354165aa0a | ||
|
|
343109836f | ||
|
|
fcadac2adb | ||
|
|
5e7dcdfe97 | ||
|
|
2ec9a57391 | ||
|
|
973c545723 | ||
|
|
fd62eecfef | ||
|
|
b5ca7058c2 | ||
|
|
57a48f099f | ||
|
|
4699f511bf | ||
|
|
cd8f7e72e0 | ||
|
|
78803fa284 | ||
|
|
2e8d75df16 | ||
|
|
7e3bbfd960 | ||
|
|
1734d53b3c | ||
|
|
f37540f4e5 | ||
|
|
addb9d836a | ||
|
|
4184d8c7ac | ||
|
|
724c15a68c | ||
|
|
499bdf9b48 | ||
|
|
41cd1ccda1 | ||
|
|
b9521cb3a9 | ||
|
|
1f40663b90 | ||
|
|
5261ed7c4c | ||
|
|
aa8768b18a | ||
|
|
aad07433f4 | ||
|
|
4a7630079b | ||
|
|
44a6ee1994 | ||
|
|
56bd6e69ed | ||
|
|
d1e04588d0 | ||
|
|
21cdaef6d5 | ||
|
|
a1723d18fb | ||
|
|
9e065138e9 | ||
|
|
1c73c92bfd | ||
|
|
bcd560d74e | ||
|
|
02339562ed | ||
|
|
e5804378c2 | ||
|
|
da1c8a162d | ||
|
|
d457a23a1f | ||
|
|
b6154e58b8 | ||
|
|
5f18776c61 | ||
|
|
68b0b9ec7a | ||
|
|
0f5036972e | ||
|
|
0b199b8421 | ||
|
|
a59730f6eb | ||
|
|
c6c84fe65b | ||
|
|
03c757bba6 | ||
|
|
bfeb8d238a | ||
|
|
daf0c08c4b | ||
|
|
d12c1b9ac4 | ||
|
|
bc242f4fd4 | ||
|
|
a240c1bca9 | ||
|
|
219aa6c574 | ||
|
|
abca1b481a | ||
|
|
db72fd2ef5 | ||
|
|
31cca58943 | ||
|
|
c06a4b759c | ||
|
|
f05a23a490 | ||
|
|
1e0f2ffde0 | ||
|
|
06df42ee3d | ||
|
|
65ee1638f7 | ||
|
|
87eefe7673 | ||
|
|
5c124d3988 | ||
|
|
8c69ce624f | ||
|
|
bb73acdde5 | ||
|
|
993bc3775b | ||
|
|
3d2ff28bcd | ||
|
|
9b78deb802 | ||
|
|
dadc525d0b | ||
|
|
22b2140c94 | ||
|
|
f07496a4a0 | ||
|
|
1b2938cbc8 | ||
|
|
d4d2f58830 | ||
|
|
b3113e13ec | ||
|
|
055c8e26f0 | ||
|
|
2a7a7239d7 | ||
|
|
2fa40dac3f | ||
|
|
6b4fbd7dc2 | ||
|
|
5b0bb19717 | ||
|
|
843dfc430a | ||
|
|
69cb07c527 | ||
|
|
89e8a64734 | ||
|
|
5eb2dec32d | ||
|
|
db0ea7d6c4 | ||
|
|
1eb85003de | ||
|
|
cca170f84a | ||
|
|
c8c016caa8 | ||
|
|
45d5874026 | ||
|
|
69b1ce60ff | ||
|
|
3ff3e4b106 | ||
|
|
dc50a68b01 | ||
|
|
968cfd8654 | ||
|
|
cf28d93be6 | ||
|
|
be08d6ebb5 | ||
|
|
4bc24f3b00 | ||
|
|
15833f94cf | ||
|
|
aeb297efcf | ||
|
|
d48c6b98e8 | ||
|
|
b79ccfafed | ||
|
|
c87ba59552 | ||
|
|
91fd71c858 | ||
|
|
6f64e67538 | ||
|
|
bd7a0b072f | ||
|
|
01ca001c97 | ||
|
|
324ad2a87c | ||
|
|
d9ad2630f0 | ||
|
|
83958a4a48 | ||
|
|
f6a6efdc42 | ||
|
|
1bbe7657b9 | ||
|
|
38189753b5 | ||
|
|
5b0e658617 | ||
|
|
b6cf54d57f | ||
|
|
e8058c8813 | ||
|
|
784868048d | ||
|
|
2bf9779f2f | ||
|
|
d98ceea381 | ||
|
|
1ab2da74b9 | ||
|
|
086b1f1403 | ||
|
|
3723cf8ac2 | ||
|
|
19608fa98e | ||
|
|
b0d17deda1 | ||
|
|
4c979c458e | ||
|
|
c5e93169ad | ||
|
|
1e2ca294de | ||
|
|
7165c4a275 | ||
|
|
cbe81ba33c | ||
|
|
fdbfae953d | ||
|
|
c7ba274877 | ||
|
|
8b15a16ca1 | ||
|
|
9f2c8d3811 | ||
|
|
7343dfbed8 | ||
|
|
90f74d8d2b | ||
|
|
7e3e0e1178 | ||
|
|
d890e38a10 | ||
|
|
e505b5c85f | ||
|
|
6230f55116 | ||
|
|
c8d0c14ebc | ||
|
|
6ac8455c74 | ||
|
|
143b21631f | ||
|
|
d760facad8 | ||
|
|
3a1a4c5cfe | ||
|
|
c3045e2cd4 | ||
|
|
1efb9af7ab | ||
|
|
e03471159a | ||
|
|
a92e493742 | ||
|
|
225d413ed1 | ||
|
|
184e4ba7d5 | ||
|
|
917cae27b1 | ||
|
|
60e0463051 | ||
|
|
c15022c7d5 | ||
|
|
2a84e3a606 | ||
|
|
fddbbd5714 | ||
|
|
51b8f7c713 | ||
|
|
e97c246741 | ||
|
|
9a81f55ac0 | ||
|
|
a38b702acc | ||
|
|
e4e0605e92 | ||
|
|
8875a8f12c | ||
|
|
4dd1deefa5 | ||
|
|
1f6dc93ea3 | ||
|
|
426e920fff | ||
|
|
1f6bbce326 | ||
|
|
41f89a35fa | ||
|
|
099d7874d7 | ||
|
|
e2367103a1 | ||
|
|
37f8ba7d72 | ||
|
|
c20bd84edd | ||
|
|
b4ee0d2487 | ||
|
|
420fa7645f | ||
|
|
5bb1e72760 | ||
|
|
e2a007b62a | ||
|
|
210813367f | ||
|
|
770a50764e | ||
|
|
e339a22aa4 | ||
|
|
913afed378 | ||
|
|
db3efb4452 | ||
|
|
840351acb7 | ||
|
|
da76a7f299 | ||
|
|
cbd999f88d | ||
|
|
2fa8a266c5 | ||
|
|
08aa749a53 | ||
|
|
2379f04d2a | ||
|
|
0e73598d1c | ||
|
|
964e6eb0e8 | ||
|
|
0430e6c6d4 | ||
|
|
db88358eca | ||
|
|
723e9b0018 | ||
|
|
f3db27a8da | ||
|
|
0fb7a73fc9 | ||
|
|
418e6bd085 | ||
|
|
5a5c4ace6b | ||
|
|
c2c8214075 | ||
|
|
e5d2ade6e6 | ||
|
|
e32b6e07b4 | ||
|
|
cc69d3b8d1 | ||
|
|
1dd3af44b5 | ||
|
|
8ab233baef | ||
|
|
104138b9a7 | ||
|
|
0c8fd5121a | ||
|
|
61f26d331b | ||
|
|
97817cd808 | ||
|
|
45bcc63c06 | ||
|
|
00779d0f10 | ||
|
|
d657bf8ed8 | ||
|
|
4fcdd05e6a | ||
|
|
e6916946a9 | ||
|
|
acd7013dc6 | ||
|
|
039d876e3f | ||
|
|
3fc2c7d6cc | ||
|
|
109164b673 | ||
|
|
673a03e656 | ||
|
|
1e976e6d96 | ||
|
|
8efba30adb | ||
|
|
713d44eac3 | ||
|
|
aea44c1d97 | ||
|
|
1e61e60d73 | ||
|
|
a0e4b4a56e | ||
|
|
983f8fcb03 | ||
|
|
6afdde7dc1 | ||
|
|
6873de7243 | ||
|
|
ee4d6d0db3 | ||
|
|
dee1212a76 | ||
|
|
ceda69aedd | ||
|
|
75ea7d7601 | ||
|
|
8b75d2312c | ||
|
|
ca51880798 | ||
|
|
8b708e8939 | ||
|
|
b6ff9f7196 | ||
|
|
67229fd032 | ||
|
|
d382eab355 | ||
|
|
d8f10e9ac4 | ||
|
|
749aaeb003 | ||
|
|
c5a3bbcecf | ||
|
|
27ac41531b | ||
|
|
423c9af786 | ||
|
|
232759829e | ||
|
|
71f7bc7b1b | ||
|
|
ae4f03e272 | ||
|
|
acb5a7e50b | ||
|
|
c8749b3c9c | ||
|
|
49647e3bb5 | ||
|
|
48d353aa90 | ||
|
|
edec18cacb | ||
|
|
cd8661abc1 | ||
|
|
5f6310f5d6 | ||
|
|
42d955b175 | ||
|
|
21541bc468 | ||
|
|
f14f4e1e9b | ||
|
|
6d1de8a2e4 | ||
|
|
0053d31f84 | ||
|
|
f077a9684b | ||
|
|
2428d58e93 | ||
|
|
5340e3a0a7 | ||
|
|
70dd8f0f1d | ||
|
|
8fa76504c3 | ||
|
|
0899cb4e1d | ||
|
|
ee7a2a70a6 | ||
|
|
d57d1ac15e | ||
|
|
68c29d89c9 | ||
|
|
721648ffdf | ||
|
|
8437f39bf6 | ||
|
|
48b15c60e7 | ||
|
|
e350122125 | ||
|
|
0cce97f373 | ||
|
|
d8cacc0811 | ||
|
|
7abaf70bb8 | ||
|
|
232fe4d15e | ||
|
|
d6d12c0335 | ||
|
|
8e4f12804b | ||
|
|
c21ba5c521 | ||
|
|
dfa3d47261 | ||
|
|
924f59afff | ||
|
|
673b282d6c | ||
|
|
1c761f89e5 | ||
|
|
f61cd969b9 | ||
|
|
e39a130306 | ||
|
|
13b6ea985e | ||
|
|
2f1e55fa1e | ||
|
|
776f629771 | ||
|
|
d9e9edb2c4 | ||
|
|
753c074e59 | ||
|
|
d92c82775a | ||
|
|
215cc09c1f | ||
|
|
7f302c13c7 | ||
|
|
de6a094d10 | ||
|
|
a94e1a8314 | ||
|
|
f5efdd665b | ||
|
|
43e25e8717 | ||
|
|
a8026fefc1 | ||
|
|
fdb36957c9 | ||
|
|
ea433ff807 | ||
|
|
8902fb50d6 | ||
|
|
b6aa013eb3 | ||
|
|
034b43bf70 | ||
|
|
59e9032286 | ||
|
|
52a98efd0a | ||
|
|
90cc91aa7f | ||
|
|
1973a26e83 | ||
|
|
6519ad25ca | ||
|
|
cacfde8166 | ||
|
|
df85873726 | ||
|
|
dfea294cc9 | ||
|
|
d35b855404 | ||
|
|
7a1cbf70e3 | ||
|
|
f260990b86 | ||
|
|
6affbe9b55 | ||
|
|
dbe3a10697 | ||
|
|
3c25306a5d | ||
|
|
17f4d49731 | ||
|
|
e213b5cc64 | ||
|
|
65e5dad44b | ||
|
|
62ad38ea5d | ||
|
|
f98f4c1f77 | ||
|
|
e9f02b58b7 | ||
|
|
05495e481d | ||
|
|
5bb2167b78 | ||
|
|
b4e0ed66cf | ||
|
|
70a0563435 | ||
|
|
955912b832 | ||
|
|
b65ee75b3d | ||
|
|
f642493a38 | ||
|
|
7f1bfb1e07 | ||
|
|
8931e2e016 | ||
|
|
0465fa77c2 | ||
|
|
575d503cb9 | ||
|
|
a4fdbdb9ad | ||
|
|
b9cb781a4e | ||
|
|
a3adf867b7 | ||
|
|
d52cbd2f74 | ||
|
|
8d0003db94 | ||
|
|
b775e89e77 | ||
|
|
0e14b097ba | ||
|
|
51848b8d8d | ||
|
|
72658c3e60 | ||
|
|
036cb6f3b0 | ||
|
|
1a86d96bfa | ||
|
|
f67db38a25 | ||
|
|
028d18826a | ||
|
|
29a605f265 | ||
|
|
4b6959470d | ||
|
|
600767d2bf | ||
|
|
3efbd47ffd | ||
|
|
d17e85217b | ||
|
|
e608089805 | ||
|
|
b852acec28 | ||
|
|
2a3ea8315d | ||
|
|
9271ee833c | ||
|
|
570d4ad1a3 | ||
|
|
dccdf3231a | ||
|
|
b8ee777fd2 | ||
|
|
a2fd3a8d90 | ||
|
|
bbffb1420b | ||
|
|
8ea0a32879 | ||
|
|
8c27b8c33e | ||
|
|
5c61b22c2f | ||
|
|
9da9d765a0 | ||
|
|
f64363728e | ||
|
|
378777dc7c | ||
|
|
6156b9a481 | ||
|
|
8c516c5691 | ||
|
|
bf9a149898 | ||
|
|
277cde8db2 | ||
|
|
e06bdaf53e | ||
|
|
da367bd138 | ||
|
|
d336bcbf1f | ||
|
|
a8aedba6ff | ||
|
|
9ede86c6a3 | ||
|
|
1468f2b082 | ||
|
|
e04ae70f89 | ||
|
|
7f7d2c9ba8 | ||
|
|
d73deef8dc | ||
|
|
f93a1540af | ||
|
|
c8bd9cb716 | ||
|
|
2ed13c7e5b | ||
|
|
647c0929c5 | ||
|
|
a61533a131 | ||
|
|
bc5e682308 | ||
|
|
25a481df12 | ||
|
|
764c10fae4 | ||
|
|
d8249d4e38 | ||
|
|
0e3e42b398 | ||
|
|
7d3b64dcf9 | ||
|
|
2c8d525796 | ||
|
|
4869f071ab | ||
|
|
3029eeaf6f | ||
|
|
33fb692aee | ||
|
|
6a075d144f | ||
|
|
aa23315599 | ||
|
|
8d0bb35505 | ||
|
|
32e76bc6ce | ||
|
|
6c02766000 | ||
|
|
52ef390464 | ||
|
|
43a557601e | ||
|
|
82ff7fc090 | ||
|
|
db40b5105b | ||
|
|
b2a379b84b | ||
|
|
97cbd816fe | ||
|
|
7de3bb2a91 | ||
|
|
3a8a2bcab4 | ||
|
|
eb1adbe992 | ||
|
|
b55966d42b | ||
|
|
451ca9cb5a | ||
|
|
1e2c607ced | ||
|
|
5ff7da0d19 | ||
|
|
8e06c6f8e6 | ||
|
|
4497cd3904 | ||
|
|
2945679a94 | ||
|
|
1eaf7e3c85 | ||
|
|
8146b680c6 | ||
|
|
99e667382f | ||
|
|
4c03759d3f | ||
|
|
8593a6cdd0 | ||
|
|
cd18c31618 | ||
|
|
f29c918700 | ||
|
|
0f0c3e660b | ||
|
|
1cf4639db3 | ||
|
|
f5da9b5780 | ||
|
|
e4c87c8a96 | ||
|
|
4b4bf153f0 | ||
|
|
ec227d0d56 | ||
|
|
53c8c50779 | ||
|
|
07b4c8b462 | ||
|
|
f3cfc5b9f0 |
@@ -1,3 +1,84 @@
|
||||
# Ignore git
|
||||
# Git
|
||||
.github
|
||||
.git
|
||||
.git
|
||||
.gitignore
|
||||
|
||||
# Documentation
|
||||
docs/
|
||||
README.md
|
||||
LICENSE
|
||||
|
||||
# Development files
|
||||
.pylintrc
|
||||
*.pyc
|
||||
__pycache__/
|
||||
*.pyo
|
||||
*.pyd
|
||||
.Python
|
||||
*.so
|
||||
.pytest_cache/
|
||||
.coverage
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.hypothesis/
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Virtual environments
|
||||
venv/
|
||||
env/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# IDE
|
||||
.vscode/
|
||||
.idea/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
.DS_Store?
|
||||
._*
|
||||
.Spotlight-V100
|
||||
.Trashes
|
||||
ehthumbs.db
|
||||
Thumbs.db
|
||||
|
||||
# Logs
|
||||
*.log
|
||||
logs/
|
||||
|
||||
# Temporary files
|
||||
*.tmp
|
||||
*.temp
|
||||
tmp/
|
||||
temp/
|
||||
|
||||
# Database
|
||||
*.db
|
||||
*.sqlite
|
||||
*.sqlite3
|
||||
|
||||
# Test files
|
||||
tests/
|
||||
test_*
|
||||
*_test.py
|
||||
|
||||
# Build artifacts
|
||||
build/
|
||||
dist/
|
||||
*.egg-info/
|
||||
|
||||
# Docker
|
||||
Dockerfile*
|
||||
docker-compose*
|
||||
.dockerignore
|
||||
|
||||
# Other
|
||||
app.ico
|
||||
frozen.spec
|
||||
2
.github/ISSUE_TEMPLATE/rfc.yml
vendored
2
.github/ISSUE_TEMPLATE/rfc.yml
vendored
@@ -10,7 +10,7 @@ body:
|
||||
目的是让协作的开发者间清晰的知道「要做什么」和「具体会怎么做」,以及所有的开发者都能公开透明的参与讨论;
|
||||
以便评估和讨论产生的影响 (遗漏的考虑、向后兼容性、与现有功能的冲突),
|
||||
因此提案侧重在对解决问题的 **方案、设计、步骤** 的描述上。
|
||||
|
||||
|
||||
如果仅希望讨论是否添加或改进某功能本身,请使用 -> [Issue: 功能改进](https://github.com/jxxghp/MoviePilot/issues/new?assignees=&labels=feature+request&projects=&template=feature_request.yml&title=%5BFeature+Request%5D%3A+)
|
||||
- type: textarea
|
||||
id: background
|
||||
|
||||
60
.github/workflows/beta.yml
vendored
Normal file
60
.github/workflows/beta.yml
vendored
Normal file
@@ -0,0 +1,60 @@
|
||||
name: MoviePilot Builder Beta
|
||||
on:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
Docker-build:
|
||||
runs-on: ubuntu-latest
|
||||
name: Build Docker Image
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Release version
|
||||
id: release_version
|
||||
run: |
|
||||
app_version=$(cat version.py |sed -ne "s/APP_VERSION\s=\s'v\(.*\)'/\1/gp")
|
||||
echo "app_version=$app_version" >> $GITHUB_ENV
|
||||
|
||||
- name: Docker Meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
${{ secrets.DOCKER_USERNAME }}/moviepilot-v2
|
||||
ghcr.io/${{ github.repository }}
|
||||
tags: |
|
||||
type=raw,value=beta
|
||||
|
||||
- name: Set Up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set Up Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login DockerHub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_PASSWORD }}
|
||||
|
||||
- name: Login GitHub Container Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Build Image
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
file: docker/Dockerfile
|
||||
platforms: |
|
||||
linux/amd64
|
||||
linux/arm64/v8
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
cache-from: type=gha, scope=${{ github.workflow }}-docker
|
||||
cache-to: type=gha, scope=${{ github.workflow }}-docker
|
||||
3
.github/workflows/build.yml
vendored
3
.github/workflows/build.yml
vendored
@@ -27,6 +27,7 @@ jobs:
|
||||
with:
|
||||
images: |
|
||||
${{ secrets.DOCKER_USERNAME }}/moviepilot-v2
|
||||
${{ secrets.DOCKER_USERNAME }}/moviepilot
|
||||
ghcr.io/${{ github.repository }}
|
||||
tags: |
|
||||
type=raw,value=${{ env.app_version }}
|
||||
@@ -92,6 +93,6 @@ jobs:
|
||||
body: ${{ env.RELEASE_BODY }}
|
||||
draft: false
|
||||
prerelease: false
|
||||
make_latest: false
|
||||
make_latest: true
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
24
.github/workflows/pylint.yml
vendored
24
.github/workflows/pylint.yml
vendored
@@ -8,17 +8,17 @@ jobs:
|
||||
pylint:
|
||||
runs-on: ubuntu-latest
|
||||
name: Pylint Code Quality Check
|
||||
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
cache: 'pip'
|
||||
|
||||
|
||||
- name: Cache pip dependencies
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
@@ -26,7 +26,7 @@ jobs:
|
||||
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt', '**/requirements.in') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pip-
|
||||
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip setuptools wheel
|
||||
@@ -41,7 +41,7 @@ jobs:
|
||||
else
|
||||
echo "⚠️ 未找到依赖文件,仅安装 pylint"
|
||||
fi
|
||||
|
||||
|
||||
- name: Verify pylint config
|
||||
run: |
|
||||
# 检查项目中的pylint配置文件是否存在
|
||||
@@ -57,35 +57,35 @@ jobs:
|
||||
run: |
|
||||
# 运行pylint,检查主要的Python文件
|
||||
echo "🚀 运行 Pylint 错误检查..."
|
||||
|
||||
|
||||
# 检查主要目录 - 只关注错误,如果有错误则退出
|
||||
echo "📂 检查 app/ 目录..."
|
||||
pylint app/ --output-format=colorized --reports=yes --score=yes
|
||||
|
||||
|
||||
# 检查根目录的Python文件
|
||||
echo "📂 检查根目录 Python 文件..."
|
||||
for file in $(find . -name "*.py" -not -path "./.*" -not -path "./.venv/*" -not -path "./build/*" -not -path "./dist/*" -not -path "./tests/*" -not -path "./docs/*" -not -path "./__pycache__/*" -maxdepth 1); do
|
||||
echo "检查文件: $file"
|
||||
pylint "$file" --output-format=colorized || exit 1
|
||||
done
|
||||
|
||||
|
||||
# 生成详细报告
|
||||
echo "📊 生成 Pylint 详细报告..."
|
||||
pylint app/ --output-format=json > pylint-report.json || true
|
||||
|
||||
|
||||
# 显示评分(仅供参考)
|
||||
echo "📈 Pylint 评分(仅供参考):"
|
||||
pylint app/ --score=yes --reports=no | tail -2 || true
|
||||
|
||||
|
||||
- name: Upload pylint report
|
||||
uses: actions/upload-artifact@v4
|
||||
if: always()
|
||||
with:
|
||||
name: pylint-report
|
||||
path: pylint-report.json
|
||||
|
||||
|
||||
- name: Summary
|
||||
run: |
|
||||
echo "🎉 Pylint 检查完成!"
|
||||
echo "✅ 没有发现语法错误或严重问题"
|
||||
echo "📊 详细报告已保存为构建工件"
|
||||
echo "📊 详细报告已保存为构建工件"
|
||||
@@ -12,7 +12,7 @@ jobs=0
|
||||
# 只关注错误级别的问题,禁用警告、约定和重构建议
|
||||
# E = Error (错误) - 会导致构建失败
|
||||
# W = Warning (警告) - 仅显示,不会失败
|
||||
# R = Refactor (重构建议) - 仅显示,不会失败
|
||||
# R = Refactor (重构建议) - 仅显示,不会失败
|
||||
# C = Convention (约定) - 仅显示,不会失败
|
||||
# I = Information (信息) - 仅显示,不会失败
|
||||
|
||||
@@ -80,4 +80,4 @@ ignore-imports=yes
|
||||
|
||||
[TYPECHECK]
|
||||
# 生成缺失成员提示的类列表
|
||||
generated-members=requests.packages.urllib3
|
||||
generated-members=requests.packages.urllib3
|
||||
27
README.md
27
README.md
@@ -18,17 +18,19 @@
|
||||
|
||||
## 主要特性
|
||||
|
||||
- 前后端分离,基于FastApi + Vue3,前端项目地址:[MoviePilot-Frontend](https://github.com/jxxghp/MoviePilot-Frontend),API:http://localhost:3001/docs
|
||||
- 前后端分离,基于FastApi + Vue3。
|
||||
- 聚焦核心需求,简化功能和设置,部分设置项可直接使用默认值。
|
||||
- 重新设计了用户界面,更加美观易用。
|
||||
|
||||
## 安装使用
|
||||
|
||||
访问官方Wiki:https://wiki.movie-pilot.org
|
||||
官方Wiki:https://wiki.movie-pilot.org
|
||||
|
||||
## 参与开发
|
||||
|
||||
需要 `Python 3.12`、`Node JS v20.12.1`
|
||||
API文档:https://api.movie-pilot.org
|
||||
|
||||
本地运行需要 `Python 3.12`、`Node JS v20.12.1`
|
||||
|
||||
- 克隆主项目 [MoviePilot](https://github.com/jxxghp/MoviePilot)
|
||||
```shell
|
||||
@@ -38,10 +40,11 @@ git clone https://github.com/jxxghp/MoviePilot
|
||||
```shell
|
||||
git clone https://github.com/jxxghp/MoviePilot-Resources
|
||||
```
|
||||
- 安装后端依赖,设置`app`为源代码根目录,运行 `main.py` 启动后端服务,默认监听端口:`3001`,API文档地址:`http://localhost:3001/docs`
|
||||
- 安装后端依赖,运行 `main.py` 启动后端服务,默认监听端口:`3001`,API文档地址:`http://localhost:3001/docs`
|
||||
```shell
|
||||
cd MoviePilot
|
||||
pip install -r requirements.txt
|
||||
python3 main.py
|
||||
python3 -m app.main
|
||||
```
|
||||
- 克隆前端项目 [MoviePilot-Frontend](https://github.com/jxxghp/MoviePilot-Frontend)
|
||||
```shell
|
||||
@@ -54,6 +57,20 @@ yarn dev
|
||||
```
|
||||
- 参考 [插件开发指引](https://wiki.movie-pilot.org/zh/plugindev) 在 `app/plugins` 目录下开发插件代码
|
||||
|
||||
## 相关项目
|
||||
|
||||
- [MoviePilot-Frontend](https://github.com/jxxghp/MoviePilot-Frontend)
|
||||
- [MoviePilot-Resources](https://github.com/jxxghp/MoviePilot-Resources)
|
||||
- [MoviePilot-Plugins](https://github.com/jxxghp/MoviePilot-Plugins)
|
||||
- [MoviePilot-Server](https://github.com/jxxghp/MoviePilot-Server)
|
||||
- [MoviePilot-Wiki](https://github.com/jxxghp/MoviePilot-Wiki)
|
||||
|
||||
## 免责申明
|
||||
|
||||
- 本软件仅供学习交流使用,任何人不得将本软件用于商业用途,任何人不得将本软件用于违法犯罪活动,软件对用户行为不知情,一切责任由使用者承担。
|
||||
- 本软件代码开源,基于开源代码进行修改,人为去除相关限制导致软件被分发、传播并造成责任事件的,需由代码修改发布者承担全部责任,不建议对用户认证机制进行规避或修改并公开发布。
|
||||
- 本项目不接受捐赠,没有在任何地方发布捐赠信息页面,软件本身不收费也不提供任何收费相关服务,请仔细辨别避免误导。
|
||||
|
||||
## 贡献者
|
||||
|
||||
<a href="https://github.com/jxxghp/MoviePilot/graphs/contributors">
|
||||
|
||||
355
app/agent/__init__.py
Normal file
355
app/agent/__init__.py
Normal file
@@ -0,0 +1,355 @@
|
||||
"""MoviePilot AI智能体实现"""
|
||||
|
||||
import asyncio
|
||||
from typing import Dict, List, Any
|
||||
|
||||
from langchain.agents import AgentExecutor, create_openai_tools_agent
|
||||
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langchain_community.callbacks import get_openai_callback
|
||||
from langchain_core.chat_history import InMemoryChatMessageHistory
|
||||
from langchain_core.messages import HumanMessage, AIMessage, ToolCall
|
||||
from langchain_core.runnables.history import RunnableWithMessageHistory
|
||||
|
||||
from app.agent.callback import StreamingCallbackHandler
|
||||
from app.agent.memory import ConversationMemoryManager
|
||||
from app.agent.prompt import PromptManager
|
||||
from app.agent.tools import MoviePilotToolFactory
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import settings
|
||||
from app.helper.message import MessageHelper
|
||||
from app.log import logger
|
||||
from app.schemas import Notification
|
||||
|
||||
|
||||
class AgentChain(ChainBase):
|
||||
pass
|
||||
|
||||
|
||||
class MoviePilotAgent:
|
||||
"""MoviePilot AI智能体"""
|
||||
|
||||
def __init__(self, session_id: str, user_id: str = None,
|
||||
channel: str = None, source: str = None, username: str = None):
|
||||
self.session_id = session_id
|
||||
self.user_id = user_id
|
||||
self.channel = channel # 消息渠道
|
||||
self.source = source # 消息来源
|
||||
self.username = username # 用户名
|
||||
|
||||
# 消息助手
|
||||
self.message_helper = MessageHelper()
|
||||
|
||||
# 记忆管理器
|
||||
self.memory_manager = ConversationMemoryManager()
|
||||
|
||||
# 提示词管理器
|
||||
self.prompt_manager = PromptManager()
|
||||
|
||||
# 回调处理器
|
||||
self.callback_handler = StreamingCallbackHandler(
|
||||
session_id=session_id
|
||||
)
|
||||
|
||||
# LLM模型
|
||||
self.llm = self._initialize_llm()
|
||||
|
||||
# 工具
|
||||
self.tools = self._initialize_tools()
|
||||
|
||||
# 会话存储
|
||||
self.session_store = self._initialize_session_store()
|
||||
|
||||
# 提示词模板
|
||||
self.prompt = self._initialize_prompt()
|
||||
|
||||
# Agent执行器
|
||||
self.agent_executor = self._create_agent_executor()
|
||||
|
||||
def _initialize_llm(self):
|
||||
"""初始化LLM模型"""
|
||||
provider = settings.LLM_PROVIDER.lower()
|
||||
api_key = settings.LLM_API_KEY
|
||||
if not api_key:
|
||||
raise ValueError("未配置 LLM_API_KEY")
|
||||
|
||||
if provider == "google":
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
return ChatGoogleGenerativeAI(
|
||||
model=settings.LLM_MODEL,
|
||||
google_api_key=api_key,
|
||||
max_retries=3,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=True,
|
||||
callbacks=[self.callback_handler]
|
||||
)
|
||||
elif provider == "deepseek":
|
||||
from langchain_deepseek import ChatDeepSeek
|
||||
return ChatDeepSeek(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
max_retries=3,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=True,
|
||||
callbacks=[self.callback_handler],
|
||||
stream_usage=True
|
||||
)
|
||||
else:
|
||||
from langchain_openai import ChatOpenAI
|
||||
return ChatOpenAI(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
max_retries=3,
|
||||
base_url=settings.LLM_BASE_URL,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=True,
|
||||
callbacks=[self.callback_handler],
|
||||
stream_usage=True
|
||||
)
|
||||
|
||||
def _initialize_tools(self) -> List:
|
||||
"""初始化工具列表"""
|
||||
return MoviePilotToolFactory.create_tools(
|
||||
session_id=self.session_id,
|
||||
user_id=self.user_id,
|
||||
channel=self.channel,
|
||||
source=self.source,
|
||||
username=self.username,
|
||||
callback_handler=self.callback_handler
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _initialize_session_store() -> Dict[str, InMemoryChatMessageHistory]:
|
||||
"""初始化内存存储"""
|
||||
return {}
|
||||
|
||||
def get_session_history(self, session_id: str) -> InMemoryChatMessageHistory:
|
||||
"""获取会话历史"""
|
||||
if session_id not in self.session_store:
|
||||
chat_history = InMemoryChatMessageHistory()
|
||||
messages: List[dict] = self.memory_manager.get_recent_messages_for_agent(
|
||||
session_id=session_id,
|
||||
user_id=self.user_id
|
||||
)
|
||||
if messages:
|
||||
for msg in messages:
|
||||
if msg.get("role") == "user":
|
||||
chat_history.add_user_message(HumanMessage(content=msg.get("content", "")))
|
||||
elif msg.get("role") == "agent":
|
||||
chat_history.add_ai_message(AIMessage(content=msg.get("content", "")))
|
||||
elif msg.get("role") == "tool_call":
|
||||
metadata = msg.get("metadata", {})
|
||||
chat_history.add_ai_message(AIMessage(
|
||||
content=msg.get("content", ""),
|
||||
tool_calls=[ToolCall(
|
||||
id=metadata.get("call_id"),
|
||||
name=metadata.get("tool_name"),
|
||||
args=metadata.get("parameters"),
|
||||
)]
|
||||
))
|
||||
elif msg.get("role") == "tool_result":
|
||||
chat_history.add_ai_message(AIMessage(content=msg.get("content", "")))
|
||||
elif msg.get("role") == "system":
|
||||
chat_history.add_ai_message(AIMessage(content=msg.get("content", "")))
|
||||
self.session_store[session_id] = chat_history
|
||||
return self.session_store[session_id]
|
||||
|
||||
@staticmethod
|
||||
def _initialize_prompt() -> ChatPromptTemplate:
|
||||
"""初始化提示词模板"""
|
||||
try:
|
||||
prompt_template = ChatPromptTemplate.from_messages([
|
||||
("system", "{system_prompt}"),
|
||||
MessagesPlaceholder(variable_name="chat_history"),
|
||||
("user", "{input}"),
|
||||
MessagesPlaceholder(variable_name="agent_scratchpad"),
|
||||
])
|
||||
logger.info("LangChain提示词模板初始化成功")
|
||||
return prompt_template
|
||||
except Exception as e:
|
||||
logger.error(f"初始化提示词失败: {e}")
|
||||
raise e
|
||||
|
||||
def _create_agent_executor(self) -> RunnableWithMessageHistory:
|
||||
"""创建Agent执行器"""
|
||||
try:
|
||||
agent = create_openai_tools_agent(
|
||||
llm=self.llm,
|
||||
tools=self.tools,
|
||||
prompt=self.prompt
|
||||
)
|
||||
executor = AgentExecutor(
|
||||
agent=agent,
|
||||
tools=self.tools,
|
||||
verbose=settings.LLM_VERBOSE,
|
||||
max_iterations=settings.LLM_MAX_ITERATIONS,
|
||||
return_intermediate_steps=True,
|
||||
handle_parsing_errors=True,
|
||||
early_stopping_method="force"
|
||||
)
|
||||
return RunnableWithMessageHistory(
|
||||
executor,
|
||||
self.get_session_history,
|
||||
input_messages_key="input",
|
||||
history_messages_key="chat_history"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"创建Agent执行器失败: {e}")
|
||||
raise e
|
||||
|
||||
async def process_message(self, message: str) -> str:
|
||||
"""处理用户消息"""
|
||||
try:
|
||||
# 添加用户消息到记忆
|
||||
await self.memory_manager.add_memory(
|
||||
self.session_id,
|
||||
user_id=self.user_id,
|
||||
role="user",
|
||||
content=message
|
||||
)
|
||||
|
||||
# 构建输入上下文
|
||||
input_context = {
|
||||
"system_prompt": self.prompt_manager.get_agent_prompt(channel=self.channel),
|
||||
"input": message
|
||||
}
|
||||
|
||||
# 执行Agent
|
||||
logger.info(f"Agent执行推理: session_id={self.session_id}, input={message}")
|
||||
await self._execute_agent(input_context)
|
||||
|
||||
# 获取Agent回复
|
||||
agent_message = await self.callback_handler.get_message()
|
||||
|
||||
# 发送Agent回复给用户(通过原渠道)
|
||||
await self.send_agent_message(agent_message)
|
||||
|
||||
# 添加Agent回复到记忆
|
||||
await self.memory_manager.add_memory(
|
||||
session_id=self.session_id,
|
||||
user_id=self.user_id,
|
||||
role="agent",
|
||||
content=agent_message
|
||||
)
|
||||
|
||||
return agent_message
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"处理消息时发生错误: {str(e)}"
|
||||
logger.error(error_message)
|
||||
# 发送错误消息给用户(通过原渠道)
|
||||
await self.send_agent_message(error_message)
|
||||
return error_message
|
||||
|
||||
async def _execute_agent(self, input_context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行LangChain Agent"""
|
||||
try:
|
||||
with get_openai_callback() as cb:
|
||||
result = await self.agent_executor.ainvoke(
|
||||
input_context,
|
||||
config={"configurable": {"session_id": self.session_id}},
|
||||
callbacks=[self.callback_handler]
|
||||
)
|
||||
logger.info(f"LLM调用消耗: \n{cb}")
|
||||
|
||||
if cb.total_tokens > 0:
|
||||
result["token_usage"] = {
|
||||
"prompt_tokens": cb.prompt_tokens,
|
||||
"completion_tokens": cb.completion_tokens,
|
||||
"total_tokens": cb.total_tokens
|
||||
}
|
||||
return result
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"Agent执行被取消: session_id={self.session_id}")
|
||||
return {
|
||||
"output": "任务已取消",
|
||||
"intermediate_steps": [],
|
||||
"token_usage": {}
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Agent执行失败: {e}")
|
||||
return {
|
||||
"output": f"执行过程中发生错误: {str(e)}",
|
||||
"intermediate_steps": [],
|
||||
"token_usage": {}
|
||||
}
|
||||
|
||||
async def send_agent_message(self, message: str, title: str = "MoviePilot助手"):
|
||||
"""通过原渠道发送消息给用户"""
|
||||
await AgentChain().async_post_message(
|
||||
Notification(
|
||||
channel=self.channel,
|
||||
source=self.source,
|
||||
userid=self.user_id,
|
||||
username=self.username,
|
||||
title=title,
|
||||
text=message
|
||||
)
|
||||
)
|
||||
|
||||
async def cleanup(self):
|
||||
"""清理智能体资源"""
|
||||
if self.session_id in self.session_store:
|
||||
del self.session_store[self.session_id]
|
||||
logger.info(f"MoviePilot智能体已清理: session_id={self.session_id}")
|
||||
|
||||
|
||||
class AgentManager:
|
||||
"""AI智能体管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.active_agents: Dict[str, MoviePilotAgent] = {}
|
||||
self.memory_manager = ConversationMemoryManager()
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化管理器"""
|
||||
await self.memory_manager.initialize()
|
||||
|
||||
async def close(self):
|
||||
"""关闭管理器"""
|
||||
await self.memory_manager.close()
|
||||
# 清理所有活跃的智能体
|
||||
for agent in self.active_agents.values():
|
||||
await agent.cleanup()
|
||||
self.active_agents.clear()
|
||||
|
||||
async def process_message(self, session_id: str, user_id: str, message: str,
|
||||
channel: str = None, source: str = None, username: str = None) -> str:
|
||||
"""处理用户消息"""
|
||||
# 获取或创建Agent实例
|
||||
if session_id not in self.active_agents:
|
||||
logger.info(f"创建新的AI智能体实例,session_id: {session_id}, user_id: {user_id}")
|
||||
agent = MoviePilotAgent(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
channel=channel,
|
||||
source=source,
|
||||
username=username
|
||||
)
|
||||
agent.memory_manager = self.memory_manager
|
||||
self.active_agents[session_id] = agent
|
||||
else:
|
||||
agent = self.active_agents[session_id]
|
||||
agent.user_id = user_id # 确保user_id是最新的
|
||||
# 更新渠道信息
|
||||
if channel:
|
||||
agent.channel = channel
|
||||
if source:
|
||||
agent.source = source
|
||||
if username:
|
||||
agent.username = username
|
||||
|
||||
# 处理消息
|
||||
return await agent.process_message(message)
|
||||
|
||||
async def clear_session(self, session_id: str, user_id: str):
|
||||
"""清空会话"""
|
||||
if session_id in self.active_agents:
|
||||
agent = self.active_agents[session_id]
|
||||
await agent.cleanup()
|
||||
del self.active_agents[session_id]
|
||||
await self.memory_manager.clear_memory(session_id, user_id)
|
||||
logger.info(f"会话 {session_id} 的记忆已清空")
|
||||
|
||||
|
||||
# 全局智能体管理器实例
|
||||
agent_manager = AgentManager()
|
||||
33
app/agent/callback/__init__.py
Normal file
33
app/agent/callback/__init__.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import threading
|
||||
|
||||
from langchain_core.callbacks import AsyncCallbackHandler
|
||||
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class StreamingCallbackHandler(AsyncCallbackHandler):
|
||||
"""流式输出回调处理器"""
|
||||
|
||||
def __init__(self, session_id: str):
|
||||
self._lock = threading.Lock()
|
||||
self.session_id = session_id
|
||||
self.current_message = ""
|
||||
|
||||
async def get_message(self):
|
||||
"""获取当前消息内容,获取后清空"""
|
||||
with self._lock:
|
||||
if not self.current_message:
|
||||
return ""
|
||||
msg = self.current_message
|
||||
logger.info(f"Agent消息: {msg}")
|
||||
self.current_message = ""
|
||||
return msg
|
||||
|
||||
async def on_llm_new_token(self, token: str, **kwargs):
|
||||
"""处理新的token"""
|
||||
if not token:
|
||||
return
|
||||
with self._lock:
|
||||
# 缓存当前消息
|
||||
self.current_message += token
|
||||
|
||||
280
app/agent/memory/__init__.py
Normal file
280
app/agent/memory/__init__.py
Normal file
@@ -0,0 +1,280 @@
|
||||
"""对话记忆管理器"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
from app.core.config import settings
|
||||
from app.helper.redis import AsyncRedisHelper
|
||||
from app.log import logger
|
||||
from app.schemas.agent import ConversationMemory
|
||||
|
||||
|
||||
class ConversationMemoryManager:
|
||||
"""对话记忆管理器"""
|
||||
|
||||
def __init__(self):
|
||||
# 内存中的会话记忆缓存
|
||||
self.memory_cache: Dict[str, ConversationMemory] = {}
|
||||
# 使用现有的Redis助手
|
||||
self.redis_helper = AsyncRedisHelper()
|
||||
# 内存缓存清理任务(Redis通过TTL自动过期)
|
||||
self.cleanup_task: Optional[asyncio.Task] = None
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化记忆管理器"""
|
||||
try:
|
||||
# 启动内存缓存清理任务(Redis通过TTL自动过期)
|
||||
self.cleanup_task = asyncio.create_task(self._cleanup_expired_memories())
|
||||
logger.info("对话记忆管理器初始化完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis连接失败,将使用内存存储: {e}")
|
||||
|
||||
async def close(self):
|
||||
"""关闭记忆管理器"""
|
||||
if self.cleanup_task:
|
||||
self.cleanup_task.cancel()
|
||||
try:
|
||||
await self.cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
await self.redis_helper.close()
|
||||
|
||||
logger.info("对话记忆管理器已关闭")
|
||||
|
||||
async def get_memory(self, session_id: str, user_id: str) -> ConversationMemory:
|
||||
"""获取会话记忆"""
|
||||
# 首先检查缓存
|
||||
cache_key = f"{user_id}:{session_id}" if user_id else session_id
|
||||
if cache_key in self.memory_cache:
|
||||
return self.memory_cache[cache_key]
|
||||
|
||||
# 尝试从Redis加载
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
try:
|
||||
redis_key = f"agent_memory:{user_id}:{session_id}" if user_id else f"agent_memory:{session_id}"
|
||||
memory_data = await self.redis_helper.get(redis_key, region="AI_AGENT")
|
||||
if memory_data:
|
||||
memory_dict = json.loads(memory_data) if isinstance(memory_data, str) else memory_data
|
||||
memory = ConversationMemory(**memory_dict)
|
||||
self.memory_cache[cache_key] = memory
|
||||
return memory
|
||||
except Exception as e:
|
||||
logger.warning(f"从Redis加载记忆失败: {e}")
|
||||
|
||||
# 创建新的记忆
|
||||
memory = ConversationMemory(session_id=session_id, user_id=user_id)
|
||||
self.memory_cache[cache_key] = memory
|
||||
await self._save_memory(memory)
|
||||
|
||||
return memory
|
||||
|
||||
async def set_title(self, session_id: str, user_id: str, title: str):
|
||||
"""设置会话标题"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
memory.title = title
|
||||
memory.updated_at = datetime.now()
|
||||
await self._save_memory(memory)
|
||||
|
||||
async def get_title(self, session_id: str, user_id: str) -> Optional[str]:
|
||||
"""获取会话标题"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
return memory.title
|
||||
|
||||
async def list_sessions(self, user_id: str, limit: int = 100) -> List[Dict[str, Any]]:
|
||||
"""列出历史会话摘要(按更新时间倒序)
|
||||
|
||||
- 当启用Redis时:遍历 `agent_memory:*` 键并读取摘要
|
||||
- 当未启用Redis时:基于内存缓存返回
|
||||
"""
|
||||
sessions: List[ConversationMemory] = []
|
||||
# 从Redis遍历
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
try:
|
||||
# 使用Redis助手的items方法遍历所有键
|
||||
async for key, value in self.redis_helper.items(region="AI_AGENT"):
|
||||
if key.startswith("agent_memory:"):
|
||||
try:
|
||||
# 解析键名获取user_id和session_id
|
||||
key_parts = key.split(":")
|
||||
if len(key_parts) >= 3:
|
||||
key_user_id = key_parts[2] if len(key_parts) > 3 else None
|
||||
if not user_id or key_user_id == user_id:
|
||||
data = value if isinstance(value, dict) else json.loads(value)
|
||||
memory = ConversationMemory(**data)
|
||||
sessions.append(memory)
|
||||
except Exception as err:
|
||||
logger.warning(f"解析Redis记忆数据失败: {err}")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning(f"遍历Redis会话失败: {e}")
|
||||
|
||||
# 合并内存缓存(确保包含近期的会话)
|
||||
for cache_key, memory in self.memory_cache.items():
|
||||
# 如果指定了user_id,只返回该用户的会话
|
||||
if not user_id or memory.user_id == user_id:
|
||||
sessions.append(memory)
|
||||
|
||||
# 去重(以 session_id 为键,取最近updated)
|
||||
uniq: Dict[str, ConversationMemory] = {}
|
||||
for mem in sessions:
|
||||
existed = uniq.get(mem.session_id)
|
||||
if (not existed) or (mem.updated_at > existed.updated_at):
|
||||
uniq[mem.session_id] = mem
|
||||
|
||||
# 排序并裁剪
|
||||
sorted_list = sorted(uniq.values(), key=lambda m: m.updated_at, reverse=True)[:limit]
|
||||
return [
|
||||
{
|
||||
"session_id": m.session_id,
|
||||
"title": m.title or "新会话",
|
||||
"message_count": len(m.messages),
|
||||
"created_at": m.created_at.isoformat(),
|
||||
"updated_at": m.updated_at.isoformat(),
|
||||
}
|
||||
for m in sorted_list
|
||||
]
|
||||
|
||||
async def add_memory(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
role: str,
|
||||
content: str,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""添加消息到记忆"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
|
||||
message = {
|
||||
"role": role,
|
||||
"content": content,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"metadata": metadata or {}
|
||||
}
|
||||
|
||||
memory.messages.append(message)
|
||||
memory.updated_at = datetime.now()
|
||||
|
||||
# 限制消息数量,避免记忆过大
|
||||
max_messages = settings.LLM_MAX_MEMORY_MESSAGES
|
||||
if len(memory.messages) > max_messages:
|
||||
# 保留最近的消息,但保留第一条系统消息
|
||||
system_messages = [msg for msg in memory.messages if msg["role"] == "system"]
|
||||
recent_messages = memory.messages[-(max_messages - len(system_messages)):]
|
||||
memory.messages = system_messages + recent_messages
|
||||
|
||||
await self._save_memory(memory)
|
||||
|
||||
logger.debug(f"消息已添加到记忆: session_id={session_id}, user_id={user_id}, role={role}")
|
||||
|
||||
def get_recent_messages_for_agent(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""为Agent获取最近的消息(仅内存缓存)
|
||||
|
||||
如果消息Token数量超过模型最大上下文长度的阀值,会自动进行摘要裁剪
|
||||
"""
|
||||
cache_key = f"{user_id}:{session_id}" if user_id else session_id
|
||||
memory = self.memory_cache.get(cache_key)
|
||||
if not memory:
|
||||
return []
|
||||
|
||||
# 获取所有消息
|
||||
messages = memory.messages
|
||||
|
||||
return messages
|
||||
|
||||
async def get_recent_messages(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
limit: int = 10,
|
||||
role_filter: Optional[list] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取最近的消息"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
|
||||
messages = memory.messages
|
||||
if role_filter:
|
||||
messages = [msg for msg in messages if msg["role"] in role_filter]
|
||||
|
||||
return messages[-limit:] if messages else []
|
||||
|
||||
async def get_context(self, session_id: str, user_id: str) -> Dict[str, Any]:
|
||||
"""获取会话上下文"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
return memory.context
|
||||
|
||||
async def clear_memory(self, session_id: str, user_id: str):
|
||||
"""清空会话记忆"""
|
||||
cache_key = f"{user_id}:{session_id}" if user_id else session_id
|
||||
if cache_key in self.memory_cache:
|
||||
del self.memory_cache[cache_key]
|
||||
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
redis_key = f"agent_memory:{user_id}:{session_id}" if user_id else f"agent_memory:{session_id}"
|
||||
await self.redis_helper.delete(redis_key, region="AI_AGENT")
|
||||
|
||||
logger.info(f"会话记忆已清空: session_id={session_id}, user_id={user_id}")
|
||||
|
||||
async def _save_memory(self, memory: ConversationMemory):
|
||||
"""保存记忆到存储
|
||||
|
||||
Redis中的记忆会自动通过TTL机制过期,无需手动清理
|
||||
"""
|
||||
# 更新内存缓存
|
||||
cache_key = f"{memory.user_id}:{memory.session_id}" if memory.user_id else memory.session_id
|
||||
self.memory_cache[cache_key] = memory
|
||||
|
||||
# 保存到Redis,设置TTL自动过期
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
try:
|
||||
memory_dict = memory.model_dump()
|
||||
redis_key = f"agent_memory:{memory.user_id}:{memory.session_id}" if memory.user_id else f"agent_memory:{memory.session_id}"
|
||||
ttl = int(timedelta(days=settings.LLM_REDIS_MEMORY_RETENTION_DAYS).total_seconds())
|
||||
await self.redis_helper.set(
|
||||
redis_key,
|
||||
memory_dict,
|
||||
ttl=ttl,
|
||||
region="AI_AGENT"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"保存记忆到Redis失败: {e}")
|
||||
|
||||
async def _cleanup_expired_memories(self):
|
||||
"""清理内存中过期记忆的后台任务
|
||||
|
||||
注意:Redis中的记忆通过TTL机制自动过期,这里只清理内存缓存
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
# 每小时清理一次
|
||||
await asyncio.sleep(3600)
|
||||
|
||||
current_time = datetime.now()
|
||||
expired_sessions = []
|
||||
|
||||
# 只检查内存缓存中的过期记忆
|
||||
# Redis中的记忆会通过TTL自动过期,无需手动处理
|
||||
for cache_key, memory in self.memory_cache.items():
|
||||
if (current_time - memory.updated_at).days > settings.LLM_MEMORY_RETENTION_DAYS:
|
||||
expired_sessions.append(cache_key)
|
||||
|
||||
# 只清理内存缓存,不删除Redis中的键(Redis会自动过期)
|
||||
for cache_key in expired_sessions:
|
||||
if cache_key in self.memory_cache:
|
||||
del self.memory_cache[cache_key]
|
||||
|
||||
if expired_sessions:
|
||||
logger.info(f"清理了{len(expired_sessions)}个过期内存会话记忆")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"清理记忆时发生错误: {e}")
|
||||
70
app/agent/prompt/Agent Prompt.txt
Normal file
70
app/agent/prompt/Agent Prompt.txt
Normal file
@@ -0,0 +1,70 @@
|
||||
You are MoviePilot's AI assistant, specialized in helping users manage media resources including subscriptions, searching, downloading, and organization.
|
||||
|
||||
## Your Identity and Capabilities
|
||||
|
||||
You are an AI agent for the MoviePilot media management system with the following core capabilities:
|
||||
|
||||
### Media Management Capabilities
|
||||
- **Search Media Resources**: Search for movies, TV shows, anime, and other media content based on user requirements
|
||||
- **Add Subscriptions**: Create subscription rules for media content that users are interested in
|
||||
- **Manage Downloads**: Search and add torrent resources to downloaders
|
||||
- **Query Status**: Check subscription status, download progress, and media library status
|
||||
|
||||
### Intelligent Interaction Capabilities
|
||||
- **Natural Language Understanding**: Understand user requests in natural language (Chinese/English)
|
||||
- **Context Memory**: Remember conversation history and user preferences
|
||||
- **Smart Recommendations**: Recommend related media content based on user preferences
|
||||
- **Task Execution**: Automatically execute complex media management tasks
|
||||
|
||||
## Working Principles
|
||||
|
||||
1. **Always respond in Chinese**: All responses must be in Chinese
|
||||
2. **Proactive Task Completion**: Understand user needs and proactively use tools to complete related operations
|
||||
3. **Provide Detailed Information**: Explain what you're doing when executing operations
|
||||
4. **Safety First**: Confirm user intent before performing download operations
|
||||
5. **Continuous Learning**: Remember user preferences and habits to provide personalized service
|
||||
|
||||
## Common Operation Workflows
|
||||
|
||||
### Add Subscription Workflow
|
||||
1. Understand the media content the user wants to subscribe to
|
||||
2. Search for related media information
|
||||
3. Create subscription rules
|
||||
4. Confirm successful subscription
|
||||
|
||||
### Search and Download Workflow
|
||||
1. Understand user requirements (movie names, TV show names, etc.)
|
||||
2. Search for related media information
|
||||
3. Search for related torrent resources by media info
|
||||
4. Filter suitable resources
|
||||
5. Add to downloader
|
||||
|
||||
### Query Status Workflow
|
||||
1. Understand what information the user wants to know
|
||||
2. Query related data
|
||||
3. Organize and present results
|
||||
|
||||
## Tool Usage Guidelines
|
||||
|
||||
### Tool Usage Principles
|
||||
- Use tools proactively to complete user requests
|
||||
- Always explain what you're doing when using tools
|
||||
- Provide detailed results and explanations
|
||||
- Handle errors gracefully and suggest alternatives
|
||||
- Confirm user intent before performing download operations
|
||||
|
||||
### Response Format
|
||||
- Always respond in Chinese
|
||||
- Use clear and friendly language
|
||||
- Provide structured information when appropriate
|
||||
- Include relevant details about media content (title, year, type, etc.)
|
||||
- Explain the results of tool operations clearly
|
||||
|
||||
## Important Notes
|
||||
|
||||
- Always confirm user intent before performing download operations
|
||||
- If search results are not ideal, proactively adjust search strategies
|
||||
- Maintain a friendly and professional tone
|
||||
- Seek solutions proactively when encountering problems
|
||||
- Remember user preferences and provide personalized recommendations
|
||||
- Handle errors gracefully and provide helpful suggestions
|
||||
118
app/agent/prompt/__init__.py
Normal file
118
app/agent/prompt/__init__.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""提示词管理器"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class PromptManager:
|
||||
"""提示词管理器"""
|
||||
|
||||
def __init__(self, prompts_dir: str = None):
|
||||
if prompts_dir is None:
|
||||
self.prompts_dir = Path(__file__).parent
|
||||
else:
|
||||
self.prompts_dir = Path(prompts_dir)
|
||||
self.prompts_cache: Dict[str, str] = {}
|
||||
|
||||
def load_prompt(self, prompt_name: str) -> str:
|
||||
"""加载指定的提示词"""
|
||||
if prompt_name in self.prompts_cache:
|
||||
return self.prompts_cache[prompt_name]
|
||||
|
||||
prompt_file = self.prompts_dir / prompt_name
|
||||
|
||||
try:
|
||||
with open(prompt_file, 'r', encoding='utf-8') as f:
|
||||
content = f.read().strip()
|
||||
|
||||
# 缓存提示词
|
||||
self.prompts_cache[prompt_name] = content
|
||||
|
||||
logger.info(f"提示词加载成功: {prompt_name},长度:{len(content)} 字符")
|
||||
return content
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.error(f"提示词文件不存在: {prompt_file}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"加载提示词失败: {prompt_name}, 错误: {e}")
|
||||
raise
|
||||
|
||||
def get_agent_prompt(self, channel: str = None) -> str:
|
||||
"""
|
||||
获取智能体提示词
|
||||
:param channel: 消息渠道(Telegram、微信、Slack等)
|
||||
:return: 提示词内容
|
||||
"""
|
||||
base_prompt = self.load_prompt("Agent Prompt.txt")
|
||||
|
||||
# 根据渠道添加特定的格式说明
|
||||
if channel:
|
||||
channel_format_info = self._get_channel_format_info(channel)
|
||||
if channel_format_info:
|
||||
base_prompt += f"\n\n## Current Message Channel Format Requirements\n\n{channel_format_info}"
|
||||
|
||||
return base_prompt
|
||||
|
||||
@staticmethod
|
||||
def _get_channel_format_info(channel: str) -> str:
|
||||
"""
|
||||
获取渠道特定的格式说明
|
||||
:param channel: 消息渠道
|
||||
:return: 格式说明文本
|
||||
"""
|
||||
channel_lower = channel.lower() if channel else ""
|
||||
|
||||
if "telegram" in channel_lower:
|
||||
return """Messages are being sent through the **Telegram** channel. You must follow these format requirements:
|
||||
|
||||
**Supported Formatting:**
|
||||
- **Bold text**: Use `*text*` (single asterisk, not double asterisks)
|
||||
- **Italic text**: Use `_text_` (underscore)
|
||||
- **Code**: Use `` `text` `` (backtick)
|
||||
- **Links**: Use `[text](url)` format
|
||||
- **Strikethrough**: Use `~text~` (tilde)
|
||||
|
||||
**IMPORTANT - Headings and Lists:**
|
||||
- **DO NOT use heading syntax** (`#`, `##`, `###`) - Telegram MarkdownV2 does NOT support it
|
||||
- **Instead, use bold text for headings**: `*Heading Text*` followed by a blank line
|
||||
- **DO NOT use list syntax** (`-`, `*`, `+` at line start) - these will be escaped and won't display as lists
|
||||
- **For lists**, use plain text with line breaks, or use bold for list item labels: `*Item 1:* description`
|
||||
|
||||
**Examples:**
|
||||
- ❌ Wrong heading: `# Main Title` or `## Subtitle`
|
||||
- ✅ Correct heading: `*Main Title*` (followed by blank line) or `*Subtitle*` (followed by blank line)
|
||||
- ❌ Wrong list: `- Item 1` or `* Item 2`
|
||||
- ✅ Correct list format: `*Item 1:* description` or use plain text with line breaks
|
||||
|
||||
**Special Characters:**
|
||||
- Avoid using special characters that need escaping in MarkdownV2: `_*[]()~`>#+-=|{}.!` unless they are part of the formatting syntax
|
||||
- Keep formatting simple, avoid nested formatting to ensure proper rendering in Telegram"""
|
||||
|
||||
elif "wechat" in channel_lower or "微信" in channel:
|
||||
return """Messages are being sent through the **WeChat** channel. Please follow these format requirements:
|
||||
|
||||
- WeChat does NOT support Markdown formatting. Use plain text format only.
|
||||
- Do NOT use any Markdown syntax (such as `**bold**`, `*italic*`, `` `code` `` etc.)
|
||||
- Use plain text descriptions. You can organize content using line breaks and punctuation
|
||||
- Links can be provided directly as URLs, no Markdown link format needed
|
||||
- Keep messages concise and clear, use natural Chinese expressions"""
|
||||
|
||||
elif "slack" in channel_lower:
|
||||
return """Messages are being sent through the **Slack** channel. Please follow these format requirements:
|
||||
|
||||
- Slack supports Markdown formatting
|
||||
- Use `*text*` for bold
|
||||
- Use `_text_` for italic
|
||||
- Use `` `text` `` for code
|
||||
- Link format: `<url|text>` or `[text](url)`"""
|
||||
|
||||
# 其他渠道使用标准Markdown
|
||||
return None
|
||||
|
||||
def clear_cache(self):
|
||||
"""清空缓存"""
|
||||
self.prompts_cache.clear()
|
||||
logger.info("提示词缓存已清空")
|
||||
31
app/agent/tools/__init__.py
Normal file
31
app/agent/tools/__init__.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""MoviePilot工具模块"""
|
||||
|
||||
from .base import MoviePilotTool
|
||||
from app.agent.tools.impl.search_media import SearchMediaTool
|
||||
from app.agent.tools.impl.add_subscribe import AddSubscribeTool
|
||||
from app.agent.tools.impl.search_torrents import SearchTorrentsTool
|
||||
from app.agent.tools.impl.add_download import AddDownloadTool
|
||||
from app.agent.tools.impl.query_subscribes import QuerySubscribesTool
|
||||
from app.agent.tools.impl.query_downloads import QueryDownloadsTool
|
||||
from app.agent.tools.impl.query_downloaders import QueryDownloadersTool
|
||||
from app.agent.tools.impl.query_sites import QuerySitesTool
|
||||
from app.agent.tools.impl.get_recommendations import GetRecommendationsTool
|
||||
from app.agent.tools.impl.query_media_library import QueryMediaLibraryTool
|
||||
from app.agent.tools.impl.send_message import SendMessageTool
|
||||
from .factory import MoviePilotToolFactory
|
||||
|
||||
__all__ = [
|
||||
"MoviePilotTool",
|
||||
"SearchMediaTool",
|
||||
"AddSubscribeTool",
|
||||
"SearchTorrentsTool",
|
||||
"AddDownloadTool",
|
||||
"QuerySubscribesTool",
|
||||
"QueryDownloadsTool",
|
||||
"QueryDownloadersTool",
|
||||
"QuerySitesTool",
|
||||
"GetRecommendationsTool",
|
||||
"QueryMediaLibraryTool",
|
||||
"SendMessageTool",
|
||||
"MoviePilotToolFactory"
|
||||
]
|
||||
73
app/agent/tools/base.py
Normal file
73
app/agent/tools/base.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""MoviePilot工具基类"""
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Callable, Any
|
||||
|
||||
from langchain.tools import BaseTool
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from app.agent import StreamingCallbackHandler
|
||||
from app.chain import ChainBase
|
||||
from app.schemas import Notification
|
||||
|
||||
|
||||
class ToolChain(ChainBase):
|
||||
pass
|
||||
|
||||
|
||||
class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
"""MoviePilot专用工具基类"""
|
||||
|
||||
_session_id: str = PrivateAttr()
|
||||
_user_id: str = PrivateAttr()
|
||||
_channel: str = PrivateAttr(default=None)
|
||||
_source: str = PrivateAttr(default=None)
|
||||
_username: str = PrivateAttr(default=None)
|
||||
_callback_handler: StreamingCallbackHandler = PrivateAttr(default=None)
|
||||
|
||||
def __init__(self, session_id: str, user_id: str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._session_id = session_id
|
||||
self._user_id = user_id
|
||||
|
||||
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||
pass
|
||||
|
||||
async def _arun(self, **kwargs) -> str:
|
||||
"""异步运行工具"""
|
||||
# 发送运行工具前的消息
|
||||
agent_message = await self._callback_handler.get_message()
|
||||
if agent_message:
|
||||
await self.send_tool_message(agent_message, title="MoviePilot助手")
|
||||
# 发送执行工具说明
|
||||
explanation = kwargs.get("explanation")
|
||||
if explanation:
|
||||
await self.send_tool_message(f"▶️️{explanation}")
|
||||
return await self.run(**kwargs)
|
||||
|
||||
@abstractmethod
|
||||
async def run(self, **kwargs) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def set_message_attr(self, channel: str, source: str, username: str):
|
||||
"""设置消息属性"""
|
||||
self._channel = channel
|
||||
self._source = source
|
||||
self._username = username
|
||||
|
||||
def set_callback_handler(self, callback_handler: StreamingCallbackHandler):
|
||||
"""设置回调处理器"""
|
||||
self._callback_handler = callback_handler
|
||||
|
||||
async def send_tool_message(self, message: str, title: str = ""):
|
||||
"""发送工具消息"""
|
||||
await ToolChain().async_post_message(
|
||||
Notification(
|
||||
channel=self._channel,
|
||||
source=self._source,
|
||||
userid=self._user_id,
|
||||
username=self._username,
|
||||
title=title,
|
||||
text=message
|
||||
),
|
||||
escape_markdown=False
|
||||
)
|
||||
84
app/agent/tools/factory.py
Normal file
84
app/agent/tools/factory.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""MoviePilot工具工厂"""
|
||||
|
||||
from typing import List, Callable
|
||||
|
||||
from app.agent.tools.impl.add_download import AddDownloadTool
|
||||
from app.agent.tools.impl.add_subscribe import AddSubscribeTool
|
||||
from app.agent.tools.impl.get_recommendations import GetRecommendationsTool
|
||||
from app.agent.tools.impl.query_downloaders import QueryDownloadersTool
|
||||
from app.agent.tools.impl.query_downloads import QueryDownloadsTool
|
||||
from app.agent.tools.impl.query_media_library import QueryMediaLibraryTool
|
||||
from app.agent.tools.impl.query_sites import QuerySitesTool
|
||||
from app.agent.tools.impl.query_subscribes import QuerySubscribesTool
|
||||
from app.agent.tools.impl.search_media import SearchMediaTool
|
||||
from app.agent.tools.impl.search_torrents import SearchTorrentsTool
|
||||
from app.agent.tools.impl.send_message import SendMessageTool
|
||||
from app.core.plugin import PluginManager
|
||||
from app.log import logger
|
||||
from .base import MoviePilotTool
|
||||
|
||||
|
||||
class MoviePilotToolFactory:
|
||||
"""MoviePilot工具工厂"""
|
||||
|
||||
@staticmethod
|
||||
def create_tools(session_id: str, user_id: str,
|
||||
channel: str = None, source: str = None, username: str = None,
|
||||
callback_handler: Callable = None) -> List[MoviePilotTool]:
|
||||
"""创建MoviePilot工具列表"""
|
||||
tools = []
|
||||
tool_definitions = [
|
||||
SearchMediaTool,
|
||||
AddSubscribeTool,
|
||||
SearchTorrentsTool,
|
||||
AddDownloadTool,
|
||||
QuerySubscribesTool,
|
||||
QueryDownloadsTool,
|
||||
QueryDownloadersTool,
|
||||
QuerySitesTool,
|
||||
GetRecommendationsTool,
|
||||
QueryMediaLibraryTool,
|
||||
SendMessageTool
|
||||
]
|
||||
# 创建内置工具
|
||||
for ToolClass in tool_definitions:
|
||||
tool = ToolClass(
|
||||
session_id=session_id,
|
||||
user_id=user_id
|
||||
)
|
||||
tool.set_message_attr(channel=channel, source=source, username=username)
|
||||
tool.set_callback_handler(callback_handler=callback_handler)
|
||||
tools.append(tool)
|
||||
|
||||
# 加载插件提供的工具
|
||||
plugin_tools_count = 0
|
||||
plugin_tools_info = PluginManager().get_plugin_agent_tools()
|
||||
for plugin_info in plugin_tools_info:
|
||||
plugin_id = plugin_info.get("plugin_id")
|
||||
plugin_name = plugin_info.get("plugin_name")
|
||||
tool_classes = plugin_info.get("tools", [])
|
||||
for ToolClass in tool_classes:
|
||||
try:
|
||||
# 验证工具类是否继承自 MoviePilotTool
|
||||
if not issubclass(ToolClass, MoviePilotTool):
|
||||
logger.warning(f"插件 {plugin_name}({plugin_id}) 提供的工具类 {ToolClass.__name__} 未继承自 MoviePilotTool,已跳过")
|
||||
continue
|
||||
# 创建工具实例
|
||||
tool = ToolClass(
|
||||
session_id=session_id,
|
||||
user_id=user_id
|
||||
)
|
||||
tool.set_message_attr(channel=channel, source=source, username=username)
|
||||
tool.set_callback_handler(callback_handler=callback_handler)
|
||||
tools.append(tool)
|
||||
plugin_tools_count += 1
|
||||
logger.debug(f"成功加载插件 {plugin_name}({plugin_id}) 的工具: {ToolClass.__name__}")
|
||||
except Exception as e:
|
||||
logger.error(f"加载插件 {plugin_name}({plugin_id}) 的工具 {ToolClass.__name__} 失败: {str(e)}")
|
||||
|
||||
builtin_tools_count = len(tool_definitions)
|
||||
if plugin_tools_count > 0:
|
||||
logger.info(f"成功创建 {len(tools)} 个MoviePilot工具(内置工具: {builtin_tools_count} 个,插件工具: {plugin_tools_count} 个)")
|
||||
else:
|
||||
logger.info(f"成功创建 {len(tools)} 个MoviePilot工具")
|
||||
return tools
|
||||
0
app/agent/tools/impl/__init__.py
Normal file
0
app/agent/tools/impl/__init__.py
Normal file
92
app/agent/tools/impl/add_download.py
Normal file
92
app/agent/tools/impl/add_download.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""添加下载工具"""
|
||||
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool, ToolChain
|
||||
from app.chain.download import DownloadChain
|
||||
from app.core.context import Context
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.db.site_oper import SiteOper
|
||||
from app.log import logger
|
||||
from app.schemas import TorrentInfo
|
||||
|
||||
|
||||
class AddDownloadInput(BaseModel):
|
||||
"""添加下载工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
site_name: str = Field(..., description="Name of the torrent site/source (e.g., 'The Pirate Bay')")
|
||||
torrent_title: str = Field(...,
|
||||
description="The display name/title of the torrent (e.g., 'The.Matrix.1999.1080p.BluRay.x264')")
|
||||
torrent_url: str = Field(..., description="Direct URL to the torrent file (.torrent) or magnet link")
|
||||
torrent_description: Optional[str] = Field(None,
|
||||
description="Brief description of the torrent content (optional)")
|
||||
downloader: Optional[str] = Field(None,
|
||||
description="Name of the downloader to use (optional, uses default if not specified)")
|
||||
save_path: Optional[str] = Field(None,
|
||||
description="Directory path where the downloaded files should be saved (optional, uses default path if not specified)")
|
||||
labels: Optional[str] = Field(None,
|
||||
description="Comma-separated list of labels/tags to assign to the download (optional, e.g., 'movie,hd,bluray')")
|
||||
|
||||
|
||||
class AddDownloadTool(MoviePilotTool):
|
||||
name: str = "add_download"
|
||||
description: str = "Add torrent download task to the configured downloader (qBittorrent, Transmission, etc.). Downloads the torrent file and starts the download process with specified settings."
|
||||
args_schema: Type[BaseModel] = AddDownloadInput
|
||||
|
||||
async def run(self, site_name: str, torrent_title: str, torrent_url: str, torrent_description: Optional[str] = None,
|
||||
downloader: Optional[str] = None, save_path: Optional[str] = None,
|
||||
labels: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: site_name={site_name}, torrent_title={torrent_title}, torrent_url={torrent_url}, downloader={downloader}, save_path={save_path}, labels={labels}")
|
||||
|
||||
try:
|
||||
if not torrent_title or not torrent_url:
|
||||
return "错误:必须提供种子标题和下载链接"
|
||||
|
||||
# 使用DownloadChain添加下载
|
||||
download_chain = DownloadChain()
|
||||
|
||||
# 根据站点名称查询站点cookie
|
||||
if not site_name:
|
||||
return "错误:必须提供站点名称,请从搜索资源结果信息中获取"
|
||||
siteinfo = await SiteOper().async_get_by_name(site_name)
|
||||
if not siteinfo:
|
||||
return f"错误:未找到站点信息:{site_name}"
|
||||
|
||||
# 创建下载上下文
|
||||
torrent_info = TorrentInfo(
|
||||
title=torrent_title,
|
||||
description=torrent_description,
|
||||
enclosure=torrent_url,
|
||||
site_name=site_name,
|
||||
site_ua=siteinfo.ua,
|
||||
site_cookie=siteinfo.cookie,
|
||||
site_proxy=siteinfo.proxy,
|
||||
site_order=siteinfo.pri,
|
||||
site_downloader=siteinfo.downloader
|
||||
)
|
||||
meta_info = MetaInfo(title=torrent_title, subtitle=torrent_description)
|
||||
media_info = await ToolChain().async_recognize_media(meta=meta_info)
|
||||
if not media_info:
|
||||
return "错误:无法识别媒体信息,无法添加下载任务"
|
||||
context = Context(
|
||||
torrent_info=torrent_info,
|
||||
meta_info=meta_info,
|
||||
media_info=media_info
|
||||
)
|
||||
|
||||
did = download_chain.download_single(
|
||||
context=context,
|
||||
downloader=downloader,
|
||||
save_path=save_path,
|
||||
label=labels
|
||||
)
|
||||
if did:
|
||||
return f"成功添加下载任务:{torrent_title}"
|
||||
else:
|
||||
return "添加下载任务失败"
|
||||
except Exception as e:
|
||||
logger.error(f"添加下载任务失败: {e}", exc_info=True)
|
||||
return f"添加下载任务时发生错误: {str(e)}"
|
||||
60
app/agent/tools/impl/add_subscribe.py
Normal file
60
app/agent/tools/impl/add_subscribe.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""添加订阅工具"""
|
||||
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.subscribe import SubscribeChain
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType
|
||||
|
||||
|
||||
class AddSubscribeInput(BaseModel):
|
||||
"""添加订阅工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
title: str = Field(..., description="The title of the media to subscribe to (e.g., 'The Matrix', 'Breaking Bad')")
|
||||
year: str = Field(..., description="Release year of the media (required for accurate identification)")
|
||||
media_type: str = Field(...,
|
||||
description="Type of media content: '电影' for films, '电视剧' for television series or anime series")
|
||||
season: Optional[int] = Field(None,
|
||||
description="Season number for TV shows (optional, if not specified will subscribe to all seasons)")
|
||||
tmdb_id: Optional[str] = Field(None,
|
||||
description="TMDB database ID for precise media identification (optional but recommended for accuracy)")
|
||||
|
||||
|
||||
class AddSubscribeTool(MoviePilotTool):
|
||||
name: str = "add_subscribe"
|
||||
description: str = "Add media subscription to create automated download rules for movies and TV shows. The system will automatically search and download new episodes or releases based on the subscription criteria."
|
||||
args_schema: Type[BaseModel] = AddSubscribeInput
|
||||
|
||||
async def run(self, title: str, year: str, media_type: str,
|
||||
season: Optional[int] = None, tmdb_id: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}, tmdb_id={tmdb_id}")
|
||||
|
||||
try:
|
||||
subscribe_chain = SubscribeChain()
|
||||
# 转换 tmdb_id 为整数
|
||||
tmdbid_int = None
|
||||
if tmdb_id:
|
||||
try:
|
||||
tmdbid_int = int(tmdb_id)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(f"无效的 tmdb_id: {tmdb_id},将忽略")
|
||||
|
||||
sid, message = await subscribe_chain.async_add(
|
||||
mtype=MediaType(media_type),
|
||||
title=title,
|
||||
year=year,
|
||||
tmdbid=tmdbid_int,
|
||||
season=season,
|
||||
username=self._user_id
|
||||
)
|
||||
if sid:
|
||||
return f"成功添加订阅:{title} ({year})"
|
||||
else:
|
||||
return f"添加订阅失败:{message}"
|
||||
except Exception as e:
|
||||
logger.error(f"添加订阅失败: {e}", exc_info=True)
|
||||
return f"添加订阅时发生错误: {str(e)}"
|
||||
84
app/agent/tools/impl/get_recommendations.py
Normal file
84
app/agent/tools/impl/get_recommendations.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""获取推荐工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.recommend import RecommendChain
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class GetRecommendationsInput(BaseModel):
|
||||
"""获取推荐工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
source: Optional[str] = Field("tmdb_trending",
|
||||
description="Recommendation source: 'tmdb_trending' for TMDB trending content, 'douban_hot' for Douban popular content, 'bangumi_calendar' for Bangumi anime calendar")
|
||||
media_type: Optional[str] = Field("all",
|
||||
description="Type of media content: '电影' for films, '电视剧' for television series or anime series, 'all' for all types")
|
||||
limit: Optional[int] = Field(20,
|
||||
description="Maximum number of recommendations to return (default: 20, maximum: 100)")
|
||||
|
||||
|
||||
class GetRecommendationsTool(MoviePilotTool):
|
||||
name: str = "get_recommendations"
|
||||
description: str = "Get trending and popular media recommendations from various sources. Returns curated lists of popular movies, TV shows, and anime based on different criteria like trending, ratings, or calendar schedules."
|
||||
args_schema: Type[BaseModel] = GetRecommendationsInput
|
||||
|
||||
async def run(self, source: Optional[str] = "tmdb_trending",
|
||||
media_type: Optional[str] = "all", limit: Optional[int] = 20, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: source={source}, media_type={media_type}, limit={limit}")
|
||||
try:
|
||||
name_dicts = {
|
||||
"tmdb_trending": "TMDB 热门推荐",
|
||||
"douban_hot": "豆瓣热门推荐",
|
||||
"bangumi_calendar": "番组计划推荐"
|
||||
}
|
||||
recommend_chain = RecommendChain()
|
||||
results = []
|
||||
if source == "tmdb_trending":
|
||||
results = await recommend_chain.async_tmdb_trending(limit=limit)
|
||||
elif source == "douban_hot":
|
||||
if media_type == "movie":
|
||||
results = await recommend_chain.async_douban_movie_hot(limit=limit)
|
||||
elif media_type == "tv":
|
||||
results = await recommend_chain.async_douban_tv_hot(limit=limit)
|
||||
else: # all
|
||||
results.extend(await recommend_chain.async_douban_movie_hot(limit=limit))
|
||||
results.extend(await recommend_chain.async_douban_tv_hot(limit=limit))
|
||||
elif source == "bangumi_calendar":
|
||||
results = await recommend_chain.async_bangumi_calendar(limit=limit)
|
||||
|
||||
if results:
|
||||
# 限制最多20条结果
|
||||
total_count = len(results)
|
||||
limited_results = results[:20]
|
||||
# 精简字段,只保留关键信息
|
||||
simplified_results = []
|
||||
for r in limited_results:
|
||||
# r 已经是字典格式(to_dict的结果)
|
||||
simplified = {
|
||||
"title": r.get("title"),
|
||||
"en_title": r.get("en_title"),
|
||||
"year": r.get("year"),
|
||||
"type": r.get("type"),
|
||||
"season": r.get("season"),
|
||||
"tmdb_id": r.get("tmdb_id"),
|
||||
"imdb_id": r.get("imdb_id"),
|
||||
"douban_id": r.get("douban_id"),
|
||||
"overview": r.get("overview", "")[:200] + "..." if r.get("overview") and len(r.get("overview", "")) > 200 else r.get("overview"),
|
||||
"vote_average": r.get("vote_average"),
|
||||
"poster_path": r.get("poster_path"),
|
||||
"detail_link": r.get("detail_link")
|
||||
}
|
||||
simplified_results.append(simplified)
|
||||
result_json = json.dumps(simplified_results, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 20:
|
||||
return f"注意:推荐结果共找到 {total_count} 条,为节省上下文空间,仅显示前 20 条结果。\n\n{result_json}"
|
||||
return result_json
|
||||
return "未找到推荐内容。"
|
||||
except Exception as e:
|
||||
logger.error(f"获取推荐失败: {e}", exc_info=True)
|
||||
return f"获取推荐时发生错误: {str(e)}"
|
||||
34
app/agent/tools/impl/query_downloaders.py
Normal file
34
app/agent/tools/impl/query_downloaders.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""查询下载器工具"""
|
||||
|
||||
import json
|
||||
from typing import Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.log import logger
|
||||
from app.schemas.types import SystemConfigKey
|
||||
|
||||
|
||||
class QueryDownloadersInput(BaseModel):
|
||||
"""查询下载器工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
|
||||
|
||||
class QueryDownloadersTool(MoviePilotTool):
|
||||
name: str = "query_downloaders"
|
||||
description: str = "Query downloader configuration and list all available downloaders. Shows downloader status, connection details, and configuration settings."
|
||||
args_schema: Type[BaseModel] = QueryDownloadersInput
|
||||
|
||||
async def run(self, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}")
|
||||
try:
|
||||
system_config_oper = SystemConfigOper()
|
||||
downloaders_config = system_config_oper.get(SystemConfigKey.Downloaders)
|
||||
if downloaders_config:
|
||||
return json.dumps(downloaders_config, ensure_ascii=False, indent=2)
|
||||
return "未配置下载器。"
|
||||
except Exception as e:
|
||||
logger.error(f"查询下载器失败: {e}")
|
||||
return f"查询下载器时发生错误: {str(e)}"
|
||||
80
app/agent/tools/impl/query_downloads.py
Normal file
80
app/agent/tools/impl/query_downloads.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""查询下载工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.download import DownloadChain
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QueryDownloadsInput(BaseModel):
|
||||
"""查询下载工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
downloader: Optional[str] = Field(None,
|
||||
description="Name of specific downloader to query (optional, if not provided queries all configured downloaders)")
|
||||
status: Optional[str] = Field("all",
|
||||
description="Filter downloads by status: 'downloading' for active downloads, 'completed' for finished downloads, 'paused' for paused downloads, 'all' for all downloads")
|
||||
|
||||
|
||||
class QueryDownloadsTool(MoviePilotTool):
|
||||
name: str = "query_downloads"
|
||||
description: str = "Query download status and list all active download tasks. Shows download progress, completion status, and task details from configured downloaders."
|
||||
args_schema: Type[BaseModel] = QueryDownloadsInput
|
||||
|
||||
async def run(self, downloader: Optional[str] = None,
|
||||
status: Optional[str] = "all", **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: downloader={downloader}, status={status}")
|
||||
try:
|
||||
download_chain = DownloadChain()
|
||||
# 使用 DownloadChain.downloading 方法获取正在下载的任务
|
||||
downloads = download_chain.downloading(name=downloader)
|
||||
filtered_downloads = []
|
||||
for dl in downloads:
|
||||
if downloader and dl.downloader != downloader:
|
||||
continue
|
||||
if status != "all" and dl.status != status:
|
||||
continue
|
||||
filtered_downloads.append(dl)
|
||||
if filtered_downloads:
|
||||
# 限制最多20条结果
|
||||
total_count = len(filtered_downloads)
|
||||
limited_downloads = filtered_downloads[:20]
|
||||
# 精简字段,只保留关键信息
|
||||
simplified_downloads = []
|
||||
for d in limited_downloads:
|
||||
simplified = {
|
||||
"downloader": d.downloader,
|
||||
"hash": d.hash,
|
||||
"title": d.title,
|
||||
"name": d.name,
|
||||
"year": d.year,
|
||||
"season_episode": d.season_episode,
|
||||
"size": d.size,
|
||||
"progress": d.progress,
|
||||
"state": d.state,
|
||||
"upspeed": d.upspeed,
|
||||
"dlspeed": d.dlspeed,
|
||||
"left_time": d.left_time
|
||||
}
|
||||
# 精简 media 字段
|
||||
if d.media:
|
||||
simplified["media"] = {
|
||||
"tmdbid": d.media.get("tmdbid"),
|
||||
"type": d.media.get("type"),
|
||||
"title": d.media.get("title"),
|
||||
"season": d.media.get("season"),
|
||||
"episode": d.media.get("episode")
|
||||
}
|
||||
simplified_downloads.append(simplified)
|
||||
result_json = json.dumps(simplified_downloads, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 20:
|
||||
return f"注意:查询结果共找到 {total_count} 条,为节省上下文空间,仅显示前 20 条结果。\n\n{result_json}"
|
||||
return result_json
|
||||
return "未找到相关下载任务"
|
||||
except Exception as e:
|
||||
logger.error(f"查询下载失败: {e}", exc_info=True)
|
||||
return f"查询下载时发生错误: {str(e)}"
|
||||
41
app/agent/tools/impl/query_media_library.py
Normal file
41
app/agent/tools/impl/query_media_library.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""查询媒体库工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, List, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.db.mediaserver_oper import MediaServerOper
|
||||
from app.log import logger
|
||||
from app.schemas import MediaServerItem
|
||||
|
||||
|
||||
class QueryMediaLibraryInput(BaseModel):
|
||||
"""查询媒体库工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
media_type: Optional[str] = Field("all",
|
||||
description="Type of media content: '电影' for films, '电视剧' for television series or anime series, 'all' for all types")
|
||||
title: Optional[str] = Field(None,
|
||||
description="Specific media title to check if it exists in the media library (optional, if provided checks for that specific media)")
|
||||
year: Optional[str] = Field(None,
|
||||
description="Release year of the media (optional, helps narrow down search results)")
|
||||
|
||||
|
||||
class QueryMediaLibraryTool(MoviePilotTool):
|
||||
name: str = "query_media_library"
|
||||
description: str = "Check if a specific media resource already exists in the media library (Plex, Emby, Jellyfin). Use this tool to verify whether a movie or TV series has been successfully processed and added to the media server before performing operations like downloading or subscribing."
|
||||
args_schema: Type[BaseModel] = QueryMediaLibraryInput
|
||||
|
||||
async def run(self, media_type: Optional[str] = "all",
|
||||
title: Optional[str] = None, year: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: media_type={media_type}, title={title}")
|
||||
try:
|
||||
media_server_oper = MediaServerOper()
|
||||
filtered_medias: List[MediaServerItem] = await media_server_oper.async_exists(title=title, year=year, mtype=media_type)
|
||||
if filtered_medias:
|
||||
return json.dumps([m.to_dict() for m in filtered_medias])
|
||||
return "媒体库中未找到相关媒体"
|
||||
except Exception as e:
|
||||
logger.error(f"查询媒体库失败: {e}", exc_info=True)
|
||||
return f"查询媒体库时发生错误: {str(e)}"
|
||||
66
app/agent/tools/impl/query_sites.py
Normal file
66
app/agent/tools/impl/query_sites.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""查询站点工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.db.site_oper import SiteOper
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QuerySitesInput(BaseModel):
|
||||
"""查询站点工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
status: Optional[str] = Field("all",
|
||||
description="Filter sites by status: 'active' for enabled sites, 'inactive' for disabled sites, 'all' for all sites")
|
||||
name: Optional[str] = Field(None,
|
||||
description="Filter sites by name (partial match, optional)")
|
||||
|
||||
|
||||
class QuerySitesTool(MoviePilotTool):
|
||||
name: str = "query_sites"
|
||||
description: str = "Query site status and list all configured sites. Shows site name, domain, status, priority, and basic configuration."
|
||||
args_schema: Type[BaseModel] = QuerySitesInput
|
||||
|
||||
async def run(self, status: Optional[str] = "all", name: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: status={status}, name={name}")
|
||||
try:
|
||||
site_oper = SiteOper()
|
||||
# 获取所有站点(按优先级排序)
|
||||
sites = await site_oper.async_list()
|
||||
filtered_sites = []
|
||||
for site in sites:
|
||||
# 按状态过滤
|
||||
if status == "active" and not site.is_active:
|
||||
continue
|
||||
if status == "inactive" and site.is_active:
|
||||
continue
|
||||
# 按名称过滤(部分匹配)
|
||||
if name and name.lower() not in (site.name or "").lower():
|
||||
continue
|
||||
filtered_sites.append(site)
|
||||
if filtered_sites:
|
||||
# 精简字段,只保留关键信息
|
||||
simplified_sites = []
|
||||
for s in filtered_sites:
|
||||
simplified = {
|
||||
"id": s.id,
|
||||
"name": s.name,
|
||||
"domain": s.domain,
|
||||
"url": s.url,
|
||||
"pri": s.pri,
|
||||
"is_active": s.is_active,
|
||||
"downloader": s.downloader,
|
||||
"proxy": s.proxy,
|
||||
"timeout": s.timeout
|
||||
}
|
||||
simplified_sites.append(simplified)
|
||||
result_json = json.dumps(simplified_sites, ensure_ascii=False, indent=2)
|
||||
return result_json
|
||||
return "未找到相关站点"
|
||||
except Exception as e:
|
||||
logger.error(f"查询站点失败: {e}", exc_info=True)
|
||||
return f"查询站点时发生错误: {str(e)}"
|
||||
|
||||
73
app/agent/tools/impl/query_subscribes.py
Normal file
73
app/agent/tools/impl/query_subscribes.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""查询订阅工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.db.subscribe_oper import SubscribeOper
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QuerySubscribesInput(BaseModel):
|
||||
"""查询订阅工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
status: Optional[str] = Field("all",
|
||||
description="Filter subscriptions by status: 'R' for enabled subscriptions, 'P' for disabled ones, 'all' for all subscriptions")
|
||||
media_type: Optional[str] = Field("all",
|
||||
description="Filter by media type: 'movie' for films, 'tv' for television series, 'all' for all types")
|
||||
|
||||
|
||||
class QuerySubscribesTool(MoviePilotTool):
|
||||
name: str = "query_subscribes"
|
||||
description: str = "Query subscription status and list all user subscriptions. Shows active subscriptions, their download status, and configuration details."
|
||||
args_schema: Type[BaseModel] = QuerySubscribesInput
|
||||
|
||||
async def run(self, status: Optional[str] = "all", media_type: Optional[str] = "all", **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: status={status}, media_type={media_type}")
|
||||
try:
|
||||
subscribe_oper = SubscribeOper()
|
||||
subscribes = await subscribe_oper.async_list()
|
||||
filtered_subscribes = []
|
||||
for sub in subscribes:
|
||||
if status != "all" and sub.state != status:
|
||||
continue
|
||||
if media_type != "all" and sub.type != media_type:
|
||||
continue
|
||||
filtered_subscribes.append(sub)
|
||||
if filtered_subscribes:
|
||||
# 限制最多20条结果
|
||||
total_count = len(filtered_subscribes)
|
||||
limited_subscribes = filtered_subscribes[:20]
|
||||
# 精简字段,只保留关键信息
|
||||
simplified_subscribes = []
|
||||
for s in limited_subscribes:
|
||||
simplified = {
|
||||
"id": s.id,
|
||||
"name": s.name,
|
||||
"year": s.year,
|
||||
"type": s.type,
|
||||
"season": s.season,
|
||||
"tmdbid": s.tmdbid,
|
||||
"doubanid": s.doubanid,
|
||||
"bangumiid": s.bangumiid,
|
||||
"poster": s.poster,
|
||||
"vote": s.vote,
|
||||
"description": s.description[:200] + "..." if s.description and len(s.description) > 200 else s.description,
|
||||
"state": s.state,
|
||||
"total_episode": s.total_episode,
|
||||
"lack_episode": s.lack_episode,
|
||||
"last_update": s.last_update,
|
||||
"username": s.username
|
||||
}
|
||||
simplified_subscribes.append(simplified)
|
||||
result_json = json.dumps(simplified_subscribes, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 20:
|
||||
return f"注意:查询结果共找到 {total_count} 条,为节省上下文空间,仅显示前 20 条结果。\n\n{result_json}"
|
||||
return result_json
|
||||
return "未找到相关订阅"
|
||||
except Exception as e:
|
||||
logger.error(f"查询订阅失败: {e}", exc_info=True)
|
||||
return f"查询订阅时发生错误: {str(e)}"
|
||||
96
app/agent/tools/impl/search_media.py
Normal file
96
app/agent/tools/impl/search_media.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""搜索媒体工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.media import MediaChain
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType
|
||||
|
||||
|
||||
class SearchMediaInput(BaseModel):
|
||||
"""搜索媒体工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
title: str = Field(..., description="The title of the media to search for (e.g., 'The Matrix', 'Breaking Bad')")
|
||||
year: Optional[str] = Field(None, description="Release year of the media (optional, helps narrow down results)")
|
||||
media_type: Optional[str] = Field(None,
|
||||
description="Type of media content: '电影' for films, '电视剧' for television series or anime series")
|
||||
season: Optional[int] = Field(None,
|
||||
description="Season number for TV shows and anime (optional, only applicable for series)")
|
||||
|
||||
|
||||
class SearchMediaTool(MoviePilotTool):
|
||||
name: str = "search_media"
|
||||
description: str = "Search for media resources including movies, TV shows, anime, etc. Supports searching by title, year, type, and other criteria. Returns detailed media information from TMDB database."
|
||||
args_schema: Type[BaseModel] = SearchMediaInput
|
||||
|
||||
async def run(self, title: str, year: Optional[str] = None,
|
||||
media_type: Optional[str] = None, season: Optional[int] = None, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}")
|
||||
|
||||
try:
|
||||
media_chain = MediaChain()
|
||||
# 构建搜索标题
|
||||
search_title = title
|
||||
if year:
|
||||
search_title = f"{title} {year}"
|
||||
if media_type:
|
||||
search_title = f"{search_title} {media_type}"
|
||||
if season:
|
||||
search_title = f"{search_title} S{season:02d}"
|
||||
|
||||
# 使用 MediaChain.search 方法
|
||||
meta, results = await media_chain.async_search(title=search_title)
|
||||
|
||||
# 过滤结果
|
||||
if results:
|
||||
filtered_results = []
|
||||
for result in results:
|
||||
if year and result.year != year:
|
||||
continue
|
||||
if media_type:
|
||||
if result.type != MediaType(media_type):
|
||||
continue
|
||||
if season and result.season != season:
|
||||
continue
|
||||
filtered_results.append(result)
|
||||
|
||||
if filtered_results:
|
||||
# 限制最多20条结果
|
||||
total_count = len(filtered_results)
|
||||
limited_results = filtered_results[:20]
|
||||
# 精简字段,只保留关键信息
|
||||
simplified_results = []
|
||||
for r in limited_results:
|
||||
simplified = {
|
||||
"title": r.title,
|
||||
"en_title": r.en_title,
|
||||
"year": r.year,
|
||||
"type": r.type.value if r.type else None,
|
||||
"season": r.season,
|
||||
"tmdb_id": r.tmdb_id,
|
||||
"imdb_id": r.imdb_id,
|
||||
"douban_id": r.douban_id,
|
||||
"overview": r.overview[:200] + "..." if r.overview and len(r.overview) > 200 else r.overview,
|
||||
"vote_average": r.vote_average,
|
||||
"poster_path": r.poster_path,
|
||||
"detail_link": r.detail_link
|
||||
}
|
||||
simplified_results.append(simplified)
|
||||
result_json = json.dumps(simplified_results, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 20:
|
||||
return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 20 条结果。\n\n{result_json}"
|
||||
return result_json
|
||||
else:
|
||||
return f"未找到符合条件的媒体资源: {title}"
|
||||
else:
|
||||
return f"未找到相关媒体资源: {title}"
|
||||
except Exception as e:
|
||||
error_message = f"搜索媒体失败: {str(e)}"
|
||||
logger.error(f"搜索媒体失败: {e}", exc_info=True)
|
||||
return error_message
|
||||
122
app/agent/tools/impl/search_torrents.py
Normal file
122
app/agent/tools/impl/search_torrents.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""搜索种子工具"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.search import SearchChain
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType
|
||||
|
||||
|
||||
class SearchTorrentsInput(BaseModel):
|
||||
"""搜索种子工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
title: str = Field(...,
|
||||
description="The title of the media resource to search for (e.g., 'The Matrix 1999', 'Breaking Bad S01E01')")
|
||||
year: Optional[str] = Field(None,
|
||||
description="Release year of the media (optional, helps narrow down search results)")
|
||||
media_type: Optional[str] = Field(None,
|
||||
description="Type of media content: '电影' for films, '电视剧' for television series or anime series")
|
||||
season: Optional[int] = Field(None, description="Season number for TV shows (optional, only applicable for series)")
|
||||
sites: Optional[List[int]] = Field(None,
|
||||
description="Array of specific site IDs to search on (optional, if not provided searches all configured sites)")
|
||||
filter_pattern: Optional[str] = Field(None,
|
||||
description="Regular expression pattern to filter torrent titles by resolution, quality, or other keywords (e.g., '4K|2160p|UHD' for 4K content, '1080p|BluRay' for 1080p BluRay)")
|
||||
|
||||
|
||||
class SearchTorrentsTool(MoviePilotTool):
|
||||
name: str = "search_torrents"
|
||||
description: str = "Search for torrent files across configured indexer sites based on media information. Returns available torrent downloads with details like file size, quality, and download links."
|
||||
args_schema: Type[BaseModel] = SearchTorrentsInput
|
||||
|
||||
async def run(self, title: str, year: Optional[str] = None,
|
||||
media_type: Optional[str] = None, season: Optional[int] = None,
|
||||
sites: Optional[List[int]] = None, filter_pattern: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}, sites={sites}, filter_pattern={filter_pattern}")
|
||||
|
||||
try:
|
||||
search_chain = SearchChain()
|
||||
torrents = await search_chain.async_search_by_title(title=title, sites=sites)
|
||||
filtered_torrents = []
|
||||
# 编译正则表达式(如果提供)
|
||||
regex_pattern = None
|
||||
if filter_pattern:
|
||||
try:
|
||||
regex_pattern = re.compile(filter_pattern, re.IGNORECASE)
|
||||
except re.error as e:
|
||||
logger.warning(f"正则表达式编译失败: {filter_pattern}, 错误: {e}")
|
||||
return f"正则表达式格式错误: {str(e)}"
|
||||
|
||||
for torrent in torrents:
|
||||
# torrent 是 Context 对象,需要通过 meta_info 和 media_info 访问属性
|
||||
if year and torrent.meta_info and torrent.meta_info.year != year:
|
||||
continue
|
||||
if media_type and torrent.media_info:
|
||||
if torrent.media_info.type != MediaType(media_type):
|
||||
continue
|
||||
if season and torrent.meta_info and torrent.meta_info.begin_season != season:
|
||||
continue
|
||||
# 使用正则表达式过滤标题(分辨率、质量等关键字)
|
||||
if regex_pattern and torrent.torrent_info and torrent.torrent_info.title:
|
||||
if not regex_pattern.search(torrent.torrent_info.title):
|
||||
continue
|
||||
filtered_torrents.append(torrent)
|
||||
|
||||
if filtered_torrents:
|
||||
# 限制最多50条结果
|
||||
total_count = len(filtered_torrents)
|
||||
limited_torrents = filtered_torrents[:50]
|
||||
# 精简字段,只保留关键信息
|
||||
simplified_torrents = []
|
||||
for t in limited_torrents:
|
||||
simplified = {}
|
||||
# 精简 torrent_info
|
||||
if t.torrent_info:
|
||||
simplified["torrent_info"] = {
|
||||
"title": t.torrent_info.title,
|
||||
"size": t.torrent_info.size,
|
||||
"seeders": t.torrent_info.seeders,
|
||||
"peers": t.torrent_info.peers,
|
||||
"site_name": t.torrent_info.site_name,
|
||||
"enclosure": t.torrent_info.enclosure,
|
||||
"page_url": t.torrent_info.page_url,
|
||||
"volume_factor": t.torrent_info.volume_factor,
|
||||
"pubdate": t.torrent_info.pubdate
|
||||
}
|
||||
# 精简 media_info
|
||||
if t.media_info:
|
||||
simplified["media_info"] = {
|
||||
"title": t.media_info.title,
|
||||
"en_title": t.media_info.en_title,
|
||||
"year": t.media_info.year,
|
||||
"type": t.media_info.type.value if t.media_info.type else None,
|
||||
"season": t.media_info.season,
|
||||
"tmdb_id": t.media_info.tmdb_id
|
||||
}
|
||||
# 精简 meta_info
|
||||
if t.meta_info:
|
||||
simplified["meta_info"] = {
|
||||
"name": t.meta_info.name,
|
||||
"cn_name": t.meta_info.cn_name,
|
||||
"en_name": t.meta_info.en_name,
|
||||
"year": t.meta_info.year,
|
||||
"type": t.meta_info.type.value if t.meta_info.type else None,
|
||||
"begin_season": t.meta_info.begin_season
|
||||
}
|
||||
simplified_torrents.append(simplified)
|
||||
result_json = json.dumps(simplified_torrents, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 50:
|
||||
return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 50 条结果。\n\n{result_json}"
|
||||
return result_json
|
||||
else:
|
||||
return f"未找到相关种子资源: {title}"
|
||||
except Exception as e:
|
||||
error_message = f"搜索种子时发生错误: {str(e)}"
|
||||
logger.error(f"搜索种子失败: {e}", exc_info=True)
|
||||
return error_message
|
||||
31
app/agent/tools/impl/send_message.py
Normal file
31
app/agent/tools/impl/send_message.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""发送消息工具"""
|
||||
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class SendMessageInput(BaseModel):
|
||||
"""发送消息工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
message: str = Field(..., description="The message content to send to the user (should be clear and informative)")
|
||||
message_type: Optional[str] = Field("info",
|
||||
description="Type of message: 'info' for general information, 'success' for successful operations, 'warning' for warnings, 'error' for error messages")
|
||||
|
||||
|
||||
class SendMessageTool(MoviePilotTool):
|
||||
name: str = "send_message"
|
||||
description: str = "Send notification message to the user through configured notification channels (Telegram, Slack, WeChat, etc.). Used to inform users about operation results, errors, or important updates."
|
||||
args_schema: Type[BaseModel] = SendMessageInput
|
||||
|
||||
async def run(self, message: str, message_type: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: message={message}, message_type={message_type}")
|
||||
try:
|
||||
await self.send_tool_message(message, title=message_type)
|
||||
return "消息已发送"
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e}")
|
||||
return f"发送消息时发生错误: {str(e)}"
|
||||
@@ -1,6 +1,6 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api.endpoints import login, user, site, message, webhook, subscribe, \
|
||||
from app.api.endpoints import login, user, webhook, message, site, subscribe, \
|
||||
media, douban, search, plugin, tmdb, history, system, download, dashboard, \
|
||||
transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent
|
||||
|
||||
|
||||
@@ -11,63 +11,63 @@ router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/credits/{bangumiid}", summary="查询Bangumi演职员表", response_model=List[schemas.MediaPerson])
|
||||
def bangumi_credits(bangumiid: int,
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 20,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def bangumi_credits(bangumiid: int,
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 20,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
查询Bangumi演职员表
|
||||
"""
|
||||
persons = BangumiChain().bangumi_credits(bangumiid)
|
||||
persons = await BangumiChain().async_bangumi_credits(bangumiid)
|
||||
if persons:
|
||||
return persons[(page - 1) * count: page * count]
|
||||
return []
|
||||
|
||||
|
||||
@router.get("/recommend/{bangumiid}", summary="查询Bangumi推荐", response_model=List[schemas.MediaInfo])
|
||||
def bangumi_recommend(bangumiid: int,
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 20,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def bangumi_recommend(bangumiid: int,
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 20,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
查询Bangumi推荐
|
||||
"""
|
||||
medias = BangumiChain().bangumi_recommend(bangumiid)
|
||||
medias = await BangumiChain().async_bangumi_recommend(bangumiid)
|
||||
if medias:
|
||||
return [media.to_dict() for media in medias[(page - 1) * count: page * count]]
|
||||
return []
|
||||
|
||||
|
||||
@router.get("/person/{person_id}", summary="人物详情", response_model=schemas.MediaPerson)
|
||||
def bangumi_person(person_id: int,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def bangumi_person(person_id: int,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
根据人物ID查询人物详情
|
||||
"""
|
||||
return BangumiChain().person_detail(person_id=person_id)
|
||||
return await BangumiChain().async_person_detail(person_id=person_id)
|
||||
|
||||
|
||||
@router.get("/person/credits/{person_id}", summary="人物参演作品", response_model=List[schemas.MediaInfo])
|
||||
def bangumi_person_credits(person_id: int,
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 20,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def bangumi_person_credits(person_id: int,
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 20,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
根据人物ID查询人物参演作品
|
||||
"""
|
||||
medias = BangumiChain().person_credits(person_id=person_id)
|
||||
medias = await BangumiChain().async_person_credits(person_id=person_id)
|
||||
if medias:
|
||||
return [media.to_dict() for media in medias[(page - 1) * count: page * count]]
|
||||
return []
|
||||
|
||||
|
||||
@router.get("/{bangumiid}", summary="查询Bangumi详情", response_model=schemas.MediaInfo)
|
||||
def bangumi_info(bangumiid: int,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def bangumi_info(bangumiid: int,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
查询Bangumi详情
|
||||
"""
|
||||
info = BangumiChain().bangumi_info(bangumiid)
|
||||
info = await BangumiChain().async_bangumi_info(bangumiid)
|
||||
if info:
|
||||
return MediaInfo(bangumi_info=info).to_dict()
|
||||
else:
|
||||
|
||||
@@ -111,7 +111,7 @@ def downloader2(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
|
||||
|
||||
|
||||
@router.get("/schedule", summary="后台服务", response_model=List[schemas.ScheduleInfo])
|
||||
def schedule(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def schedule(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
查询后台服务信息
|
||||
"""
|
||||
@@ -119,24 +119,25 @@ def schedule(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
|
||||
|
||||
@router.get("/schedule2", summary="后台服务(API_TOKEN)", response_model=List[schemas.ScheduleInfo])
|
||||
def schedule2(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
|
||||
async def schedule2(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
|
||||
"""
|
||||
查询下载器信息 API_TOKEN认证(?token=xxx)
|
||||
"""
|
||||
return schedule()
|
||||
return await schedule()
|
||||
|
||||
|
||||
@router.get("/transfer", summary="文件整理统计", response_model=List[int])
|
||||
def transfer(days: Optional[int] = 7, db: Session = Depends(get_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def transfer(days: Optional[int] = 7,
|
||||
db: Session = Depends(get_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
查询文件整理统计信息
|
||||
"""
|
||||
transfer_stat = TransferHistory.statistic(db, days)
|
||||
transfer_stat = await TransferHistory.async_statistic(db, days)
|
||||
return [stat[1] for stat in transfer_stat]
|
||||
|
||||
|
||||
@router.get("/cpu", summary="获取当前CPU使用率", response_model=int)
|
||||
@router.get("/cpu", summary="获取当前CPU使用率", response_model=float)
|
||||
def cpu(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
获取当前CPU使用率
|
||||
@@ -144,7 +145,7 @@ def cpu(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
return SystemUtils.cpu_usage()
|
||||
|
||||
|
||||
@router.get("/cpu2", summary="获取当前CPU使用率(API_TOKEN)", response_model=int)
|
||||
@router.get("/cpu2", summary="获取当前CPU使用率(API_TOKEN)", response_model=float)
|
||||
def cpu2(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
|
||||
"""
|
||||
获取当前CPU使用率 API_TOKEN认证(?token=xxx)
|
||||
@@ -166,3 +167,19 @@ def memory2(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
|
||||
获取当前内存使用率 API_TOKEN认证(?token=xxx)
|
||||
"""
|
||||
return memory()
|
||||
|
||||
|
||||
@router.get("/network", summary="获取当前网络流量", response_model=List[int])
|
||||
def network(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
获取当前网络流量(上行和下行流量,单位:bytes/s)
|
||||
"""
|
||||
return SystemUtils.network_usage()
|
||||
|
||||
|
||||
@router.get("/network2", summary="获取当前网络流量(API_TOKEN)", response_model=List[int])
|
||||
def network2(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
|
||||
"""
|
||||
获取当前网络流量 API_TOKEN认证(?token=xxx)
|
||||
"""
|
||||
return network()
|
||||
|
||||
@@ -3,13 +3,13 @@ from typing import Any, List, Optional
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from app import schemas
|
||||
from app.chain.bangumi import BangumiChain
|
||||
from app.chain.douban import DoubanChain
|
||||
from app.chain.tmdb import TmdbChain
|
||||
from app.core.event import eventmanager
|
||||
from app.core.security import verify_token
|
||||
from app.schemas import DiscoverSourceEventData
|
||||
from app.schemas.types import ChainEventType, MediaType
|
||||
from app.chain.bangumi import BangumiChain
|
||||
from app.chain.douban import DoubanChain
|
||||
from app.chain.tmdb import TmdbChain
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -31,100 +31,100 @@ def source(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
|
||||
|
||||
@router.get("/bangumi", summary="探索Bangumi", response_model=List[schemas.MediaInfo])
|
||||
def bangumi(type: Optional[int] = 2,
|
||||
cat: Optional[int] = None,
|
||||
sort: Optional[str] = 'rank',
|
||||
year: Optional[str] = None,
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def bangumi(type: Optional[int] = 2,
|
||||
cat: Optional[int] = None,
|
||||
sort: Optional[str] = 'rank',
|
||||
year: Optional[str] = None,
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
探索Bangumi
|
||||
"""
|
||||
medias = BangumiChain().discover(type=type, cat=cat, sort=sort, year=year,
|
||||
limit=count, offset=(page - 1) * count)
|
||||
medias = await BangumiChain().async_discover(type=type, cat=cat, sort=sort, year=year,
|
||||
limit=count, offset=(page - 1) * count)
|
||||
if medias:
|
||||
return [media.to_dict() for media in medias]
|
||||
return []
|
||||
|
||||
|
||||
@router.get("/douban_movies", summary="探索豆瓣电影", response_model=List[schemas.MediaInfo])
|
||||
def douban_movies(sort: Optional[str] = "R",
|
||||
tags: Optional[str] = "",
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def douban_movies(sort: Optional[str] = "R",
|
||||
tags: Optional[str] = "",
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
浏览豆瓣电影信息
|
||||
"""
|
||||
movies = DoubanChain().douban_discover(mtype=MediaType.MOVIE,
|
||||
sort=sort, tags=tags, page=page, count=count)
|
||||
movies = await DoubanChain().async_douban_discover(mtype=MediaType.MOVIE,
|
||||
sort=sort, tags=tags, page=page, count=count)
|
||||
return [media.to_dict() for media in movies] if movies else []
|
||||
|
||||
|
||||
@router.get("/douban_tvs", summary="探索豆瓣剧集", response_model=List[schemas.MediaInfo])
|
||||
def douban_tvs(sort: Optional[str] = "R",
|
||||
tags: Optional[str] = "",
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def douban_tvs(sort: Optional[str] = "R",
|
||||
tags: Optional[str] = "",
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
浏览豆瓣剧集信息
|
||||
"""
|
||||
tvs = DoubanChain().douban_discover(mtype=MediaType.TV,
|
||||
sort=sort, tags=tags, page=page, count=count)
|
||||
tvs = await DoubanChain().async_douban_discover(mtype=MediaType.TV,
|
||||
sort=sort, tags=tags, page=page, count=count)
|
||||
return [media.to_dict() for media in tvs] if tvs else []
|
||||
|
||||
|
||||
@router.get("/tmdb_movies", summary="探索TMDB电影", response_model=List[schemas.MediaInfo])
|
||||
def tmdb_movies(sort_by: Optional[str] = "popularity.desc",
|
||||
with_genres: Optional[str] = "",
|
||||
with_original_language: Optional[str] = "",
|
||||
with_keywords: Optional[str] = "",
|
||||
with_watch_providers: Optional[str] = "",
|
||||
vote_average: Optional[float] = 0.0,
|
||||
vote_count: Optional[int] = 0,
|
||||
release_date: Optional[str] = "",
|
||||
page: Optional[int] = 1,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def tmdb_movies(sort_by: Optional[str] = "popularity.desc",
|
||||
with_genres: Optional[str] = "",
|
||||
with_original_language: Optional[str] = "",
|
||||
with_keywords: Optional[str] = "",
|
||||
with_watch_providers: Optional[str] = "",
|
||||
vote_average: Optional[float] = 0.0,
|
||||
vote_count: Optional[int] = 0,
|
||||
release_date: Optional[str] = "",
|
||||
page: Optional[int] = 1,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
浏览TMDB电影信息
|
||||
"""
|
||||
movies = TmdbChain().tmdb_discover(mtype=MediaType.MOVIE,
|
||||
sort_by=sort_by,
|
||||
with_genres=with_genres,
|
||||
with_original_language=with_original_language,
|
||||
with_keywords=with_keywords,
|
||||
with_watch_providers=with_watch_providers,
|
||||
vote_average=vote_average,
|
||||
vote_count=vote_count,
|
||||
release_date=release_date,
|
||||
page=page)
|
||||
movies = await TmdbChain().async_tmdb_discover(mtype=MediaType.MOVIE,
|
||||
sort_by=sort_by,
|
||||
with_genres=with_genres,
|
||||
with_original_language=with_original_language,
|
||||
with_keywords=with_keywords,
|
||||
with_watch_providers=with_watch_providers,
|
||||
vote_average=vote_average,
|
||||
vote_count=vote_count,
|
||||
release_date=release_date,
|
||||
page=page)
|
||||
return [movie.to_dict() for movie in movies] if movies else []
|
||||
|
||||
|
||||
@router.get("/tmdb_tvs", summary="探索TMDB剧集", response_model=List[schemas.MediaInfo])
|
||||
def tmdb_tvs(sort_by: Optional[str] = "popularity.desc",
|
||||
with_genres: Optional[str] = "",
|
||||
with_original_language: Optional[str] = "",
|
||||
with_keywords: Optional[str] = "",
|
||||
with_watch_providers: Optional[str] = "",
|
||||
vote_average: Optional[float] = 0.0,
|
||||
vote_count: Optional[int] = 0,
|
||||
release_date: Optional[str] = "",
|
||||
page: Optional[int] = 1,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def tmdb_tvs(sort_by: Optional[str] = "popularity.desc",
|
||||
with_genres: Optional[str] = "",
|
||||
with_original_language: Optional[str] = "",
|
||||
with_keywords: Optional[str] = "",
|
||||
with_watch_providers: Optional[str] = "",
|
||||
vote_average: Optional[float] = 0.0,
|
||||
vote_count: Optional[int] = 0,
|
||||
release_date: Optional[str] = "",
|
||||
page: Optional[int] = 1,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
浏览TMDB剧集信息
|
||||
"""
|
||||
tvs = TmdbChain().tmdb_discover(mtype=MediaType.TV,
|
||||
sort_by=sort_by,
|
||||
with_genres=with_genres,
|
||||
with_original_language=with_original_language,
|
||||
with_keywords=with_keywords,
|
||||
with_watch_providers=with_watch_providers,
|
||||
vote_average=vote_average,
|
||||
vote_count=vote_count,
|
||||
release_date=release_date,
|
||||
page=page)
|
||||
tvs = await TmdbChain().async_tmdb_discover(mtype=MediaType.TV,
|
||||
sort_by=sort_by,
|
||||
with_genres=with_genres,
|
||||
with_original_language=with_original_language,
|
||||
with_keywords=with_keywords,
|
||||
with_watch_providers=with_watch_providers,
|
||||
vote_average=vote_average,
|
||||
vote_count=vote_count,
|
||||
release_date=release_date,
|
||||
page=page)
|
||||
return [tv.to_dict() for tv in tvs] if tvs else []
|
||||
|
||||
@@ -12,54 +12,54 @@ router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/person/{person_id}", summary="人物详情", response_model=schemas.MediaPerson)
|
||||
def douban_person(person_id: int,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def douban_person(person_id: int,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
根据人物ID查询人物详情
|
||||
"""
|
||||
return DoubanChain().person_detail(person_id=person_id)
|
||||
return await DoubanChain().async_person_detail(person_id=person_id)
|
||||
|
||||
|
||||
@router.get("/person/credits/{person_id}", summary="人物参演作品", response_model=List[schemas.MediaInfo])
|
||||
def douban_person_credits(person_id: int,
|
||||
page: Optional[int] = 1,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def douban_person_credits(person_id: int,
|
||||
page: Optional[int] = 1,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
根据人物ID查询人物参演作品
|
||||
"""
|
||||
medias = DoubanChain().person_credits(person_id=person_id, page=page)
|
||||
medias = await DoubanChain().async_person_credits(person_id=person_id, page=page)
|
||||
if medias:
|
||||
return [media.to_dict() for media in medias]
|
||||
return []
|
||||
|
||||
|
||||
@router.get("/credits/{doubanid}/{type_name}", summary="豆瓣演员阵容", response_model=List[schemas.MediaPerson])
|
||||
def douban_credits(doubanid: str,
|
||||
type_name: str,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def douban_credits(doubanid: str,
|
||||
type_name: str,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
根据豆瓣ID查询演员阵容,type_name: 电影/电视剧
|
||||
"""
|
||||
mediatype = MediaType(type_name)
|
||||
if mediatype == MediaType.MOVIE:
|
||||
return DoubanChain().movie_credits(doubanid=doubanid)
|
||||
return await DoubanChain().async_movie_credits(doubanid=doubanid)
|
||||
elif mediatype == MediaType.TV:
|
||||
return DoubanChain().tv_credits(doubanid=doubanid)
|
||||
return await DoubanChain().async_tv_credits(doubanid=doubanid)
|
||||
return []
|
||||
|
||||
|
||||
@router.get("/recommend/{doubanid}/{type_name}", summary="豆瓣推荐电影/电视剧", response_model=List[schemas.MediaInfo])
|
||||
def douban_recommend(doubanid: str,
|
||||
type_name: str,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def douban_recommend(doubanid: str,
|
||||
type_name: str,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
根据豆瓣ID查询推荐电影/电视剧,type_name: 电影/电视剧
|
||||
"""
|
||||
mediatype = MediaType(type_name)
|
||||
if mediatype == MediaType.MOVIE:
|
||||
medias = DoubanChain().movie_recommend(doubanid=doubanid)
|
||||
medias = await DoubanChain().async_movie_recommend(doubanid=doubanid)
|
||||
elif mediatype == MediaType.TV:
|
||||
medias = DoubanChain().tv_recommend(doubanid=doubanid)
|
||||
medias = await DoubanChain().async_tv_recommend(doubanid=doubanid)
|
||||
else:
|
||||
return []
|
||||
if medias:
|
||||
@@ -68,12 +68,12 @@ def douban_recommend(doubanid: str,
|
||||
|
||||
|
||||
@router.get("/{doubanid}", summary="查询豆瓣详情", response_model=schemas.MediaInfo)
|
||||
def douban_info(doubanid: str,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def douban_info(doubanid: str,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
根据豆瓣ID查询豆瓣媒体信息
|
||||
"""
|
||||
doubaninfo = DoubanChain().douban_info(doubanid=doubanid)
|
||||
doubaninfo = await DoubanChain().async_douban_info(doubanid=doubanid)
|
||||
if doubaninfo:
|
||||
return MediaInfo(douban_info=doubaninfo).to_dict()
|
||||
else:
|
||||
|
||||
@@ -40,10 +40,12 @@ def download(
|
||||
metainfo = MetaInfo(title=torrent_in.title, subtitle=torrent_in.description)
|
||||
# 媒体信息
|
||||
mediainfo = MediaInfo()
|
||||
mediainfo.from_dict(media_in.dict())
|
||||
mediainfo.from_dict(media_in.model_dump())
|
||||
# 种子信息
|
||||
torrentinfo = TorrentInfo()
|
||||
torrentinfo.from_dict(torrent_in.dict())
|
||||
torrentinfo.from_dict(torrent_in.model_dump())
|
||||
# 手动下载始终使用选择的下载器
|
||||
torrentinfo.site_downloader = downloader
|
||||
# 上下文
|
||||
context = Context(
|
||||
meta_info=metainfo,
|
||||
@@ -51,7 +53,7 @@ def download(
|
||||
torrent_info=torrentinfo
|
||||
)
|
||||
did = DownloadChain().download_single(context=context, username=current_user.name,
|
||||
downloader=downloader, save_path=save_path, source="Manual")
|
||||
save_path=save_path, source="Manual")
|
||||
if not did:
|
||||
return schemas.Response(success=False, message="任务添加失败")
|
||||
return schemas.Response(success=True, data={
|
||||
@@ -62,6 +64,9 @@ def download(
|
||||
@router.post("/add", summary="添加下载(不含媒体信息)", response_model=schemas.Response)
|
||||
def add(
|
||||
torrent_in: schemas.TorrentInfo,
|
||||
tmdbid: Annotated[int | None, Body()] = None,
|
||||
doubanid: Annotated[str | None, Body()] = None,
|
||||
bangumiid: Annotated[int | None, Body()] = None,
|
||||
downloader: Annotated[str | None, Body()] = None,
|
||||
save_path: Annotated[str | None, Body()] = None,
|
||||
current_user: User = Depends(get_current_active_user)) -> Any:
|
||||
@@ -71,12 +76,12 @@ def add(
|
||||
# 元数据
|
||||
metainfo = MetaInfo(title=torrent_in.title, subtitle=torrent_in.description)
|
||||
# 媒体信息
|
||||
mediainfo = MediaChain().recognize_media(meta=metainfo)
|
||||
mediainfo = MediaChain().recognize_media(meta=metainfo, tmdbid=tmdbid, doubanid=doubanid, bangumiid=bangumiid)
|
||||
if not mediainfo:
|
||||
return schemas.Response(success=False, message="无法识别媒体信息")
|
||||
# 种子信息
|
||||
torrentinfo = TorrentInfo()
|
||||
torrentinfo.from_dict(torrent_in.dict())
|
||||
torrentinfo.from_dict(torrent_in.model_dump())
|
||||
# 上下文
|
||||
context = Context(
|
||||
meta_info=metainfo,
|
||||
@@ -114,7 +119,7 @@ def stop(hashString: str, name: Optional[str] = None,
|
||||
|
||||
|
||||
@router.get("/clients", summary="查询可用下载器", response_model=List[dict])
|
||||
def clients(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def clients(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
查询可用下载器
|
||||
"""
|
||||
|
||||
@@ -2,51 +2,52 @@ from typing import List, Any, Optional
|
||||
|
||||
import jieba
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app import schemas
|
||||
from app.chain.storage import StorageChain
|
||||
from app.core.event import eventmanager
|
||||
from app.core.security import verify_token
|
||||
from app.db import get_db
|
||||
from app.db import get_async_db, get_db
|
||||
from app.db.models import User
|
||||
from app.db.models.downloadhistory import DownloadHistory
|
||||
from app.db.models.transferhistory import TransferHistory
|
||||
from app.db.user_oper import get_current_active_superuser
|
||||
from app.schemas.types import EventType, MediaType
|
||||
from app.db.user_oper import get_current_active_superuser_async, get_current_active_superuser
|
||||
from app.schemas.types import EventType
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/download", summary="查询下载历史记录", response_model=List[schemas.DownloadHistory])
|
||||
def download_history(page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
db: Session = Depends(get_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def download_history(page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
查询下载历史记录
|
||||
"""
|
||||
return DownloadHistory.list_by_page(db, page, count)
|
||||
return await DownloadHistory.async_list_by_page(db, page, count)
|
||||
|
||||
|
||||
@router.delete("/download", summary="删除下载历史记录", response_model=schemas.Response)
|
||||
def delete_download_history(history_in: schemas.DownloadHistory,
|
||||
db: Session = Depends(get_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def delete_download_history(history_in: schemas.DownloadHistory,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
删除下载历史记录
|
||||
"""
|
||||
DownloadHistory.delete(db, history_in.id)
|
||||
await DownloadHistory.async_delete(db, history_in.id)
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.get("/transfer", summary="查询整理记录", response_model=schemas.Response)
|
||||
def transfer_history(title: Optional[str] = None,
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
status: Optional[bool] = None,
|
||||
db: Session = Depends(get_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def transfer_history(title: Optional[str] = None,
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
status: Optional[bool] = None,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
查询整理记录
|
||||
"""
|
||||
@@ -60,16 +61,16 @@ def transfer_history(title: Optional[str] = None,
|
||||
if title:
|
||||
words = jieba.cut(title, HMM=False)
|
||||
title = "%".join(words)
|
||||
total = TransferHistory.count_by_title(db, title=title, status=status)
|
||||
result = TransferHistory.list_by_title(db, title=title, page=page,
|
||||
count=count, status=status)
|
||||
total = await TransferHistory.async_count_by_title(db, title=title, status=status)
|
||||
result = await TransferHistory.async_list_by_title(db, title=title, page=page,
|
||||
count=count, status=status)
|
||||
else:
|
||||
result = TransferHistory.list_by_page(db, page=page, count=count, status=status)
|
||||
total = TransferHistory.count(db, status=status)
|
||||
result = await TransferHistory.async_list_by_page(db, page=page, count=count, status=status)
|
||||
total = await TransferHistory.async_count(db, status=status)
|
||||
|
||||
return schemas.Response(success=True,
|
||||
data={
|
||||
"list": result,
|
||||
"list": [item.to_dict() for item in result],
|
||||
"total": total,
|
||||
})
|
||||
|
||||
@@ -79,7 +80,7 @@ def delete_transfer_history(history_in: schemas.TransferHistory,
|
||||
deletesrc: Optional[bool] = False,
|
||||
deletedest: Optional[bool] = False,
|
||||
db: Session = Depends(get_db),
|
||||
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
|
||||
_: User = Depends(get_current_active_superuser)) -> Any:
|
||||
"""
|
||||
删除整理记录
|
||||
"""
|
||||
@@ -89,7 +90,7 @@ def delete_transfer_history(history_in: schemas.TransferHistory,
|
||||
# 册除媒体库文件
|
||||
if deletedest and history.dest_fileitem:
|
||||
dest_fileitem = schemas.FileItem(**history.dest_fileitem)
|
||||
StorageChain().delete_media_file(fileitem=dest_fileitem, mtype=MediaType(history.type))
|
||||
StorageChain().delete_media_file(dest_fileitem)
|
||||
|
||||
# 删除源文件
|
||||
if deletesrc and history.src_fileitem:
|
||||
@@ -111,10 +112,10 @@ def delete_transfer_history(history_in: schemas.TransferHistory,
|
||||
|
||||
|
||||
@router.get("/empty/transfer", summary="清空整理记录", response_model=schemas.Response)
|
||||
def delete_transfer_history(db: Session = Depends(get_db),
|
||||
_: User = Depends(get_current_active_superuser)) -> Any:
|
||||
async def empty_transfer_history(db: AsyncSession = Depends(get_async_db),
|
||||
_: User = Depends(get_current_active_superuser_async)) -> Any:
|
||||
"""
|
||||
清空整理记录
|
||||
"""
|
||||
TransferHistory.truncate(db)
|
||||
await TransferHistory.async_truncate(db)
|
||||
return schemas.Response(success=True)
|
||||
|
||||
@@ -8,8 +8,10 @@ from app import schemas
|
||||
from app.chain.user import UserChain
|
||||
from app.core import security
|
||||
from app.core.config import settings
|
||||
from app.helper.sites import SitesHelper
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.helper.sites import SitesHelper # noqa
|
||||
from app.helper.wallpaper import WallpaperHelper
|
||||
from app.schemas.types import SystemConfigKey
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -29,7 +31,10 @@ def login_access_token(
|
||||
if not success:
|
||||
raise HTTPException(status_code=401, detail=user_or_message)
|
||||
|
||||
# 用户等级
|
||||
level = SitesHelper().auth_level
|
||||
# 是否显示配置向导
|
||||
show_wizard = not SystemConfigOper().get(SystemConfigKey.SetupWizardState) and not settings.ADVANCED_MODE
|
||||
return schemas.Token(
|
||||
access_token=security.create_access_token(
|
||||
userid=user_or_message.id,
|
||||
@@ -44,7 +49,8 @@ def login_access_token(
|
||||
user_name=user_or_message.name,
|
||||
avatar=user_or_message.avatar,
|
||||
level=level,
|
||||
permissions= user_or_message.permissions or {},
|
||||
permissions=user_or_message.permissions or {},
|
||||
widzard=show_wizard
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -18,61 +18,61 @@ router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/recognize", summary="识别媒体信息(种子)", response_model=schemas.Context)
|
||||
def recognize(title: str,
|
||||
subtitle: Optional[str] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def recognize(title: str,
|
||||
subtitle: Optional[str] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
根据标题、副标题识别媒体信息
|
||||
"""
|
||||
# 识别媒体信息
|
||||
metainfo = MetaInfo(title, subtitle)
|
||||
mediainfo = MediaChain().recognize_by_meta(metainfo)
|
||||
mediainfo = await MediaChain().async_recognize_by_meta(metainfo)
|
||||
if mediainfo:
|
||||
return Context(meta_info=metainfo, media_info=mediainfo).to_dict()
|
||||
return schemas.Context()
|
||||
|
||||
|
||||
@router.get("/recognize2", summary="识别种子媒体信息(API_TOKEN)", response_model=schemas.Context)
|
||||
def recognize2(_: Annotated[str, Depends(verify_apitoken)],
|
||||
title: str,
|
||||
subtitle: Optional[str] = None
|
||||
) -> Any:
|
||||
async def recognize2(_: Annotated[str, Depends(verify_apitoken)],
|
||||
title: str,
|
||||
subtitle: Optional[str] = None
|
||||
) -> Any:
|
||||
"""
|
||||
根据标题、副标题识别媒体信息 API_TOKEN认证(?token=xxx)
|
||||
"""
|
||||
# 识别媒体信息
|
||||
return recognize(title, subtitle)
|
||||
return await recognize(title, subtitle)
|
||||
|
||||
|
||||
@router.get("/recognize_file", summary="识别媒体信息(文件)", response_model=schemas.Context)
|
||||
def recognize_file(path: str,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def recognize_file(path: str,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
根据文件路径识别媒体信息
|
||||
"""
|
||||
# 识别媒体信息
|
||||
context = MediaChain().recognize_by_path(path)
|
||||
context = await MediaChain().async_recognize_by_path(path)
|
||||
if context:
|
||||
return context.to_dict()
|
||||
return schemas.Context()
|
||||
|
||||
|
||||
@router.get("/recognize_file2", summary="识别文件媒体信息(API_TOKEN)", response_model=schemas.Context)
|
||||
def recognize_file2(path: str,
|
||||
_: Annotated[str, Depends(verify_apitoken)]) -> Any:
|
||||
async def recognize_file2(path: str,
|
||||
_: Annotated[str, Depends(verify_apitoken)]) -> Any:
|
||||
"""
|
||||
根据文件路径识别媒体信息 API_TOKEN认证(?token=xxx)
|
||||
"""
|
||||
# 识别媒体信息
|
||||
return recognize_file(path)
|
||||
return await recognize_file(path)
|
||||
|
||||
|
||||
@router.get("/search", summary="搜索媒体/人物信息", response_model=List[dict])
|
||||
def search(title: str,
|
||||
type: Optional[str] = "media",
|
||||
page: int = 1,
|
||||
count: int = 8,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def search(title: str,
|
||||
type: Optional[str] = "media",
|
||||
page: int = 1,
|
||||
count: int = 8,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
模糊搜索媒体/人物信息列表 media:媒体信息,person:人物信息
|
||||
"""
|
||||
@@ -86,14 +86,15 @@ def search(title: str,
|
||||
return obj.source
|
||||
|
||||
result = []
|
||||
media_chain = MediaChain()
|
||||
if type == "media":
|
||||
_, medias = MediaChain().search(title=title)
|
||||
_, medias = await media_chain.async_search(title=title)
|
||||
if medias:
|
||||
result = [media.to_dict() for media in medias]
|
||||
elif type == "collection":
|
||||
result = MediaChain().search_collections(name=title)
|
||||
result = await media_chain.async_search_collections(name=title)
|
||||
else:
|
||||
result = MediaChain().search_persons(name=title)
|
||||
result = await media_chain.async_search_persons(name=title)
|
||||
if result:
|
||||
# 按设置的顺序对结果进行排序
|
||||
setting_order = settings.SEARCH_SOURCE.split(',') or []
|
||||
@@ -101,7 +102,8 @@ def search(title: str,
|
||||
for index, source in enumerate(setting_order):
|
||||
sort_order[source] = index
|
||||
result = sorted(result, key=lambda x: sort_order.get(__get_source(x), 4))
|
||||
return result[(page - 1) * count:page * count]
|
||||
return result[(page - 1) * count:page * count]
|
||||
return []
|
||||
|
||||
|
||||
@router.post("/scrape/{storage}", summary="刮削媒体信息", response_model=schemas.Response)
|
||||
@@ -123,13 +125,13 @@ def scrape(fileitem: schemas.FileItem,
|
||||
if storage == "local":
|
||||
if not scrape_path.exists():
|
||||
return schemas.Response(success=False, message="刮削路径不存在")
|
||||
# 手动刮削
|
||||
# 手动刮削 (暂时使用同步版本,可以后续优化为异步)
|
||||
chain.scrape_metadata(fileitem=fileitem, meta=meta, mediainfo=mediainfo, overwrite=True)
|
||||
return schemas.Response(success=True, message=f"{fileitem.path} 刮削完成")
|
||||
|
||||
|
||||
@router.get("/category", summary="查询自动分类配置", response_model=dict)
|
||||
def category(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def category(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
查询自动分类配置
|
||||
"""
|
||||
@@ -137,37 +139,37 @@ def category(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
|
||||
|
||||
@router.get("/group/seasons/{episode_group}", summary="查询剧集组季信息", response_model=List[schemas.MediaSeason])
|
||||
def group_seasons(episode_group: str, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def group_seasons(episode_group: str, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
查询剧集组季信息(themoviedb)
|
||||
"""
|
||||
return TmdbChain().tmdb_group_seasons(group_id=episode_group)
|
||||
return await TmdbChain().async_tmdb_group_seasons(group_id=episode_group)
|
||||
|
||||
|
||||
@router.get("/groups/{tmdbid}", summary="查询媒体剧集组", response_model=List[dict])
|
||||
def seasons(tmdbid: int, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def groups(tmdbid: int, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
查询媒体剧集组列表(themoviedb)
|
||||
"""
|
||||
mediainfo = MediaChain().recognize_media(tmdbid=tmdbid, mtype=MediaType.TV)
|
||||
mediainfo = await MediaChain().async_recognize_media(tmdbid=tmdbid, mtype=MediaType.TV)
|
||||
if not mediainfo:
|
||||
return []
|
||||
return mediainfo.episode_groups
|
||||
|
||||
|
||||
@router.get("/seasons", summary="查询媒体季信息", response_model=List[schemas.MediaSeason])
|
||||
def seasons(mediaid: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
year: str = None,
|
||||
season: int = None,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def seasons(mediaid: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
year: str = None,
|
||||
season: int = None,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
查询媒体季信息
|
||||
"""
|
||||
if mediaid:
|
||||
if mediaid.startswith("tmdb:"):
|
||||
tmdbid = int(mediaid[5:])
|
||||
seasons_info = TmdbChain().tmdb_seasons(tmdbid=tmdbid)
|
||||
seasons_info = await TmdbChain().async_tmdb_seasons(tmdbid=tmdbid)
|
||||
if seasons_info:
|
||||
if season:
|
||||
return [sea for sea in seasons_info if sea.season_number == season]
|
||||
@@ -176,17 +178,17 @@ def seasons(mediaid: Optional[str] = None,
|
||||
meta = MetaInfo(title)
|
||||
if year:
|
||||
meta.year = year
|
||||
mediainfo = MediaChain().recognize_media(meta, mtype=MediaType.TV)
|
||||
mediainfo = await MediaChain().async_recognize_media(meta, mtype=MediaType.TV)
|
||||
if mediainfo:
|
||||
if settings.RECOGNIZE_SOURCE == "themoviedb":
|
||||
seasons_info = TmdbChain().tmdb_seasons(tmdbid=mediainfo.tmdb_id)
|
||||
seasons_info = await TmdbChain().async_tmdb_seasons(tmdbid=mediainfo.tmdb_id)
|
||||
if seasons_info:
|
||||
if season:
|
||||
return [sea for sea in seasons_info if sea.season_number == season]
|
||||
return seasons_info
|
||||
else:
|
||||
sea = season or 1
|
||||
return schemas.MediaSeason(
|
||||
return [schemas.MediaSeason(
|
||||
season_number=sea,
|
||||
poster_path=mediainfo.poster_path,
|
||||
name=f"第 {sea} 季",
|
||||
@@ -194,39 +196,40 @@ def seasons(mediaid: Optional[str] = None,
|
||||
overview=mediainfo.overview,
|
||||
vote_average=mediainfo.vote_average,
|
||||
episode_count=mediainfo.number_of_episodes
|
||||
)
|
||||
)]
|
||||
return []
|
||||
|
||||
|
||||
@router.get("/{mediaid}", summary="查询媒体详情", response_model=schemas.MediaInfo)
|
||||
def detail(mediaid: str, type_name: str, title: Optional[str] = None, year: str = None,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def detail(mediaid: str, type_name: str, title: Optional[str] = None, year: str = None,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
根据媒体ID查询themoviedb或豆瓣媒体信息,type_name: 电影/电视剧
|
||||
"""
|
||||
mtype = MediaType(type_name)
|
||||
mediainfo = None
|
||||
mediachain = MediaChain()
|
||||
if mediaid.startswith("tmdb:"):
|
||||
mediainfo = MediaChain().recognize_media(tmdbid=int(mediaid[5:]), mtype=mtype)
|
||||
mediainfo = await mediachain.async_recognize_media(tmdbid=int(mediaid[5:]), mtype=mtype)
|
||||
elif mediaid.startswith("douban:"):
|
||||
mediainfo = MediaChain().recognize_media(doubanid=mediaid[7:], mtype=mtype)
|
||||
mediainfo = await mediachain.async_recognize_media(doubanid=mediaid[7:], mtype=mtype)
|
||||
elif mediaid.startswith("bangumi:"):
|
||||
mediainfo = MediaChain().recognize_media(bangumiid=int(mediaid[8:]), mtype=mtype)
|
||||
mediainfo = await mediachain.async_recognize_media(bangumiid=int(mediaid[8:]), mtype=mtype)
|
||||
else:
|
||||
# 广播事件解析媒体信息
|
||||
event_data = MediaRecognizeConvertEventData(
|
||||
mediaid=mediaid,
|
||||
convert_type=settings.RECOGNIZE_SOURCE
|
||||
)
|
||||
event = eventmanager.send_event(ChainEventType.MediaRecognizeConvert, event_data)
|
||||
event = await eventmanager.async_send_event(ChainEventType.MediaRecognizeConvert, event_data)
|
||||
# 使用事件返回的上下文数据
|
||||
if event and event.event_data and event.event_data.media_dict:
|
||||
event_data: MediaRecognizeConvertEventData = event.event_data
|
||||
new_id = event_data.media_dict.get("id")
|
||||
if event_data.convert_type == "themoviedb":
|
||||
mediainfo = MediaChain().recognize_media(tmdbid=new_id, mtype=mtype)
|
||||
mediainfo = await mediachain.async_recognize_media(tmdbid=new_id, mtype=mtype)
|
||||
elif event_data.convert_type == "douban":
|
||||
mediainfo = MediaChain().recognize_media(doubanid=new_id, mtype=mtype)
|
||||
mediainfo = await mediachain.async_recognize_media(doubanid=new_id, mtype=mtype)
|
||||
elif title:
|
||||
# 使用名称识别兜底
|
||||
meta = MetaInfo(title)
|
||||
@@ -234,10 +237,10 @@ def detail(mediaid: str, type_name: str, title: Optional[str] = None, year: str
|
||||
meta.year = year
|
||||
if mtype:
|
||||
meta.type = mtype
|
||||
mediainfo = MediaChain().recognize_media(meta=meta)
|
||||
mediainfo = await mediachain.async_recognize_media(meta=meta)
|
||||
# 识别
|
||||
if mediainfo:
|
||||
MediaChain().obtain_images(mediainfo)
|
||||
await mediachain.async_obtain_images(mediainfo)
|
||||
return mediainfo.to_dict()
|
||||
|
||||
return schemas.MediaInfo()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Any, List, Dict, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app import schemas
|
||||
from app.chain.download import DownloadChain
|
||||
@@ -9,7 +9,7 @@ from app.chain.mediaserver import MediaServerChain
|
||||
from app.core.context import MediaInfo
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.core.security import verify_token
|
||||
from app.db import get_db
|
||||
from app.db import get_async_db
|
||||
from app.db.mediaserver_oper import MediaServerOper
|
||||
from app.db.models import MediaServerItem
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
@@ -43,13 +43,13 @@ def play_item(itemid: str, _: schemas.TokenPayload = Depends(verify_token)) -> s
|
||||
|
||||
|
||||
@router.get("/exists", summary="查询本地是否存在(数据库)", response_model=schemas.Response)
|
||||
def exists_local(title: Optional[str] = None,
|
||||
year: Optional[str] = None,
|
||||
mtype: Optional[str] = None,
|
||||
tmdbid: Optional[int] = None,
|
||||
season: Optional[int] = None,
|
||||
db: Session = Depends(get_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def exists_local(title: Optional[str] = None,
|
||||
year: Optional[str] = None,
|
||||
mtype: Optional[str] = None,
|
||||
tmdbid: Optional[int] = None,
|
||||
season: Optional[int] = None,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
判断本地是否存在
|
||||
"""
|
||||
@@ -59,7 +59,7 @@ def exists_local(title: Optional[str] = None,
|
||||
# 返回对象
|
||||
ret_info = {}
|
||||
# 本地数据库是否存在
|
||||
exist: MediaServerItem = MediaServerOper(db).exists(
|
||||
exist: MediaServerItem = await MediaServerOper(db).async_exists(
|
||||
title=meta.name, year=year, mtype=mtype, tmdbid=tmdbid, season=season
|
||||
)
|
||||
if exist:
|
||||
@@ -79,7 +79,7 @@ def exists(media_in: schemas.MediaInfo,
|
||||
"""
|
||||
# 转化为媒体信息对象
|
||||
mediainfo = MediaInfo()
|
||||
mediainfo.from_dict(media_in.dict())
|
||||
mediainfo.from_dict(media_in.model_dump())
|
||||
existsinfo: schemas.ExistMediaInfo = MediaServerChain().media_exists(mediainfo=mediainfo)
|
||||
if not existsinfo:
|
||||
return []
|
||||
@@ -108,7 +108,7 @@ def not_exists(media_in: schemas.MediaInfo,
|
||||
meta.year = media_in.year
|
||||
# 转化为媒体信息对象
|
||||
mediainfo = MediaInfo()
|
||||
mediainfo.from_dict(media_in.dict())
|
||||
mediainfo.from_dict(media_in.model_dump())
|
||||
exist_flag, no_exists = DownloadChain().get_no_exists_info(meta=meta, mediainfo=mediainfo)
|
||||
mediakey = mediainfo.tmdb_id or mediainfo.douban_id
|
||||
if mediainfo.type == MediaType.MOVIE:
|
||||
@@ -148,7 +148,7 @@ def library(server: str, hidden: Optional[bool] = False,
|
||||
|
||||
|
||||
@router.get("/clients", summary="查询可用媒体服务器", response_model=List[dict])
|
||||
def clients(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def clients(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
查询可用媒体服务器
|
||||
"""
|
||||
|
||||
@@ -3,14 +3,14 @@ from typing import Union, Any, List, Optional
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, Request
|
||||
from pywebpush import WebPushException, webpush
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from starlette.responses import PlainTextResponse
|
||||
|
||||
from app import schemas
|
||||
from app.chain.message import MessageChain
|
||||
from app.core.config import settings, global_vars
|
||||
from app.core.security import verify_token, verify_apitoken
|
||||
from app.db import get_db
|
||||
from app.db import get_async_db
|
||||
from app.db.models import User
|
||||
from app.db.models.message import Message
|
||||
from app.db.user_oper import get_current_active_superuser
|
||||
@@ -58,15 +58,15 @@ def web_message(text: str, current_user: User = Depends(get_current_active_super
|
||||
|
||||
|
||||
@router.get("/web", summary="获取WEB消息", response_model=List[dict])
|
||||
def get_web_message(_: schemas.TokenPayload = Depends(verify_token),
|
||||
db: Session = Depends(get_db),
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 20):
|
||||
async def get_web_message(_: schemas.TokenPayload = Depends(verify_token),
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 20):
|
||||
"""
|
||||
获取WEB消息列表
|
||||
"""
|
||||
ret_messages = []
|
||||
messages = Message.list_by_page(db, page=page, count=count)
|
||||
messages = await Message.async_list_by_page(db, page=page, count=count)
|
||||
for message in messages:
|
||||
try:
|
||||
ret_messages.append(message.to_dict())
|
||||
@@ -128,11 +128,11 @@ def incoming_verify(token: Optional[str] = None, echostr: Optional[str] = None,
|
||||
|
||||
|
||||
@router.post("/webpush/subscribe", summary="客户端webpush通知订阅", response_model=schemas.Response)
|
||||
def subscribe(subscription: schemas.Subscription, _: schemas.TokenPayload = Depends(verify_token)):
|
||||
async def subscribe(subscription: schemas.Subscription, _: schemas.TokenPayload = Depends(verify_token)):
|
||||
"""
|
||||
客户端webpush通知订阅
|
||||
"""
|
||||
subinfo = subscription.dict()
|
||||
subinfo = subscription.model_dump()
|
||||
if subinfo not in global_vars.get_subscriptions():
|
||||
global_vars.push_subscription(subinfo)
|
||||
logger.debug(f"通知订阅成功: {subinfo}")
|
||||
@@ -148,7 +148,7 @@ def send_notification(payload: schemas.SubscriptionMessage, _: schemas.TokenPayl
|
||||
try:
|
||||
webpush(
|
||||
subscription_info=sub,
|
||||
data=json.dumps(payload.dict()),
|
||||
data=json.dumps(payload.model_dump()),
|
||||
vapid_private_key=settings.VAPID.get("privateKey"),
|
||||
vapid_claims={
|
||||
"sub": settings.VAPID.get("subject")
|
||||
|
||||
@@ -2,17 +2,21 @@ import mimetypes
|
||||
import shutil
|
||||
from typing import Annotated, Any, List, Optional
|
||||
|
||||
import aiofiles
|
||||
from anyio import Path as AsyncPath
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from starlette import status
|
||||
from starlette.responses import FileResponse
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from app import schemas
|
||||
from app.command import Command
|
||||
from app.core.config import settings
|
||||
from app.core.plugin import PluginManager
|
||||
from app.core.security import verify_apikey, verify_token
|
||||
from app.db.models import User
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.db.user_oper import get_current_active_superuser
|
||||
from app.db.user_oper import get_current_active_superuser, get_current_active_superuser_async
|
||||
from app.factory import app
|
||||
from app.helper.plugin import PluginHelper
|
||||
from app.log import logger
|
||||
@@ -136,13 +140,14 @@ def register_plugin(plugin_id: str):
|
||||
|
||||
|
||||
@router.get("/", summary="所有插件", response_model=List[schemas.Plugin])
|
||||
def all_plugins(_: schemas.TokenPayload = Depends(get_current_active_superuser),
|
||||
state: Optional[str] = "all", force: bool = False) -> List[schemas.Plugin]:
|
||||
async def all_plugins(_: User = Depends(get_current_active_superuser_async),
|
||||
state: Optional[str] = "all", force: bool = False) -> List[schemas.Plugin]:
|
||||
"""
|
||||
查询所有插件清单,包括本地插件和在线插件,插件状态:installed, market, all
|
||||
"""
|
||||
# 本地插件
|
||||
local_plugins = PluginManager().get_local_plugins()
|
||||
plugin_manager = PluginManager()
|
||||
local_plugins = plugin_manager.get_local_plugins()
|
||||
# 已安装插件
|
||||
installed_plugins = [plugin for plugin in local_plugins if plugin.installed]
|
||||
if state == "installed":
|
||||
@@ -151,7 +156,7 @@ def all_plugins(_: schemas.TokenPayload = Depends(get_current_active_superuser),
|
||||
# 未安装的本地插件
|
||||
not_installed_plugins = [plugin for plugin in local_plugins if not plugin.installed]
|
||||
# 在线插件
|
||||
online_plugins = PluginManager().get_online_plugins(force)
|
||||
online_plugins = await plugin_manager.async_get_online_plugins(force)
|
||||
if not online_plugins:
|
||||
# 没有获取在线插件
|
||||
if state == "market":
|
||||
@@ -184,7 +189,7 @@ def all_plugins(_: schemas.TokenPayload = Depends(get_current_active_superuser),
|
||||
|
||||
|
||||
@router.get("/installed", summary="已安装插件", response_model=List[str])
|
||||
def installed(_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
|
||||
async def installed(_: User = Depends(get_current_active_superuser_async)) -> Any:
|
||||
"""
|
||||
查询用户已安装插件清单
|
||||
"""
|
||||
@@ -192,15 +197,15 @@ def installed(_: schemas.TokenPayload = Depends(get_current_active_superuser)) -
|
||||
|
||||
|
||||
@router.get("/statistic", summary="插件安装统计", response_model=dict)
|
||||
def statistic(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def statistic(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
插件安装统计
|
||||
"""
|
||||
return PluginHelper().get_statistic()
|
||||
return await PluginHelper().async_get_statistic()
|
||||
|
||||
|
||||
@router.get("/reload/{plugin_id}", summary="重新加载插件", response_model=schemas.Response)
|
||||
def reload_plugin(plugin_id: str, _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
|
||||
def reload_plugin(plugin_id: str, _: User = Depends(get_current_active_superuser)) -> Any:
|
||||
"""
|
||||
重新加载插件
|
||||
"""
|
||||
@@ -212,22 +217,23 @@ def reload_plugin(plugin_id: str, _: schemas.TokenPayload = Depends(get_current_
|
||||
|
||||
|
||||
@router.get("/install/{plugin_id}", summary="安装插件", response_model=schemas.Response)
|
||||
def install(plugin_id: str,
|
||||
repo_url: Optional[str] = "",
|
||||
force: Optional[bool] = False,
|
||||
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
|
||||
async def install(plugin_id: str,
|
||||
repo_url: Optional[str] = "",
|
||||
force: Optional[bool] = False,
|
||||
_: User = Depends(get_current_active_superuser_async)) -> Any:
|
||||
"""
|
||||
安装插件
|
||||
"""
|
||||
# 已安装插件
|
||||
install_plugins = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or []
|
||||
# 首先检查插件是否已经存在,并且是否强制安装,否则只进行安装统计
|
||||
plugin_helper = PluginHelper()
|
||||
if not force and plugin_id in PluginManager().get_plugin_ids():
|
||||
PluginHelper().install_reg(pid=plugin_id)
|
||||
await plugin_helper.async_install_reg(pid=plugin_id)
|
||||
else:
|
||||
# 插件不存在或需要强制安装,下载安装并注册插件
|
||||
if repo_url:
|
||||
state, msg = PluginHelper().install(pid=plugin_id, repo_url=repo_url)
|
||||
state, msg = await plugin_helper.async_install(pid=plugin_id, repo_url=repo_url)
|
||||
# 安装失败则直接响应
|
||||
if not state:
|
||||
return schemas.Response(success=False, message=msg)
|
||||
@@ -238,14 +244,14 @@ def install(plugin_id: str,
|
||||
if plugin_id not in install_plugins:
|
||||
install_plugins.append(plugin_id)
|
||||
# 保存设置
|
||||
SystemConfigOper().set(SystemConfigKey.UserInstalledPlugins, install_plugins)
|
||||
await SystemConfigOper().async_set(SystemConfigKey.UserInstalledPlugins, install_plugins)
|
||||
# 重新加载插件
|
||||
reload_plugin(plugin_id)
|
||||
await run_in_threadpool(reload_plugin, plugin_id)
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.get("/remotes", summary="获取插件联邦组件列表", response_model=List[dict])
|
||||
def remotes(token: str) -> Any:
|
||||
async def remotes(token: str) -> Any:
|
||||
"""
|
||||
获取插件联邦组件列表
|
||||
"""
|
||||
@@ -256,11 +262,12 @@ def remotes(token: str) -> Any:
|
||||
|
||||
@router.get("/form/{plugin_id}", summary="获取插件表单页面")
|
||||
def plugin_form(plugin_id: str,
|
||||
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> dict:
|
||||
_: User = Depends(get_current_active_superuser)) -> dict:
|
||||
"""
|
||||
根据插件ID获取插件配置表单或Vue组件URL
|
||||
"""
|
||||
plugin_instance = PluginManager().running_plugins.get(plugin_id)
|
||||
plugin_manager = PluginManager()
|
||||
plugin_instance = plugin_manager.running_plugins.get(plugin_id)
|
||||
if not plugin_instance:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"插件 {plugin_id} 不存在或未加载")
|
||||
|
||||
@@ -271,7 +278,7 @@ def plugin_form(plugin_id: str,
|
||||
return {
|
||||
"render_mode": render_mode,
|
||||
"conf": conf,
|
||||
"model": PluginManager().get_plugin_config(plugin_id) or model
|
||||
"model": plugin_manager.get_plugin_config(plugin_id) or model
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"插件 {plugin_id} 调用方法 get_form 出错: {str(e)}")
|
||||
@@ -279,7 +286,7 @@ def plugin_form(plugin_id: str,
|
||||
|
||||
|
||||
@router.get("/page/{plugin_id}", summary="获取插件数据页面")
|
||||
def plugin_page(plugin_id: str, _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> dict:
|
||||
def plugin_page(plugin_id: str, _: User = Depends(get_current_active_superuser)) -> dict:
|
||||
"""
|
||||
根据插件ID获取插件数据页面
|
||||
"""
|
||||
@@ -328,7 +335,7 @@ def plugin_dashboard(plugin_id: str, user_agent: Annotated[str | None, Header()]
|
||||
|
||||
@router.get("/reset/{plugin_id}", summary="重置插件配置及数据", response_model=schemas.Response)
|
||||
def reset_plugin(plugin_id: str,
|
||||
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
|
||||
_: User = Depends(get_current_active_superuser)) -> Any:
|
||||
"""
|
||||
根据插件ID重置插件配置及数据
|
||||
"""
|
||||
@@ -343,7 +350,7 @@ def reset_plugin(plugin_id: str,
|
||||
|
||||
|
||||
@router.get("/file/{plugin_id}/{filepath:path}", summary="获取插件静态文件")
|
||||
def plugin_static_file(plugin_id: str, filepath: str):
|
||||
async def plugin_static_file(plugin_id: str, filepath: str):
|
||||
"""
|
||||
获取插件静态文件
|
||||
"""
|
||||
@@ -352,11 +359,11 @@ def plugin_static_file(plugin_id: str, filepath: str):
|
||||
logger.warning(f"Static File API: Path traversal attempt detected: {plugin_id}/{filepath}")
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden")
|
||||
|
||||
plugin_base_dir = settings.ROOT_PATH / "app" / "plugins" / plugin_id.lower()
|
||||
plugin_base_dir = AsyncPath(settings.ROOT_PATH) / "app" / "plugins" / plugin_id.lower()
|
||||
plugin_file_path = plugin_base_dir / filepath
|
||||
if not plugin_file_path.exists():
|
||||
if not await plugin_file_path.exists():
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"{plugin_file_path} 不存在")
|
||||
if not plugin_file_path.is_file():
|
||||
if not await plugin_file_path.is_file():
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"{plugin_file_path} 不是文件")
|
||||
|
||||
# 判断 MIME 类型
|
||||
@@ -371,14 +378,25 @@ def plugin_static_file(plugin_id: str, filepath: str):
|
||||
response_type = 'application/octet-stream'
|
||||
|
||||
try:
|
||||
return FileResponse(plugin_file_path, media_type=response_type)
|
||||
# 异步生成器函数,用于流式读取文件
|
||||
async def file_generator():
|
||||
async with aiofiles.open(plugin_file_path, mode='rb') as file:
|
||||
# 8KB 块大小
|
||||
while chunk := await file.read(8192):
|
||||
yield chunk
|
||||
|
||||
return StreamingResponse(
|
||||
file_generator(),
|
||||
media_type=response_type,
|
||||
headers={"Content-Disposition": f"inline; filename={plugin_file_path.name}"}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating/sending FileResponse for {plugin_file_path}: {e}", exc_info=True)
|
||||
logger.error(f"Error creating/sending StreamingResponse for {plugin_file_path}: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal Server Error")
|
||||
|
||||
|
||||
@router.get("/folders", summary="获取插件文件夹配置", response_model=dict)
|
||||
def get_plugin_folders(_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> dict:
|
||||
async def get_plugin_folders(_: User = Depends(get_current_active_superuser_async)) -> dict:
|
||||
"""
|
||||
获取插件文件夹分组配置
|
||||
"""
|
||||
@@ -391,7 +409,7 @@ def get_plugin_folders(_: schemas.TokenPayload = Depends(get_current_active_supe
|
||||
|
||||
|
||||
@router.post("/folders", summary="保存插件文件夹配置", response_model=schemas.Response)
|
||||
def save_plugin_folders(folders: dict, _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
|
||||
async def save_plugin_folders(folders: dict, _: User = Depends(get_current_active_superuser_async)) -> Any:
|
||||
"""
|
||||
保存插件文件夹分组配置
|
||||
"""
|
||||
@@ -404,7 +422,8 @@ def save_plugin_folders(folders: dict, _: schemas.TokenPayload = Depends(get_cur
|
||||
|
||||
|
||||
@router.post("/folders/{folder_name}", summary="创建插件文件夹", response_model=schemas.Response)
|
||||
def create_plugin_folder(folder_name: str, _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
|
||||
async def create_plugin_folder(folder_name: str,
|
||||
_: User = Depends(get_current_active_superuser_async)) -> Any:
|
||||
"""
|
||||
创建新的插件文件夹
|
||||
"""
|
||||
@@ -418,33 +437,65 @@ def create_plugin_folder(folder_name: str, _: schemas.TokenPayload = Depends(get
|
||||
|
||||
|
||||
@router.delete("/folders/{folder_name}", summary="删除插件文件夹", response_model=schemas.Response)
|
||||
def delete_plugin_folder(folder_name: str, _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
|
||||
async def delete_plugin_folder(folder_name: str,
|
||||
_: User = Depends(get_current_active_superuser_async)) -> Any:
|
||||
"""
|
||||
删除插件文件夹
|
||||
"""
|
||||
folders = SystemConfigOper().get(SystemConfigKey.PluginFolders) or {}
|
||||
if folder_name in folders:
|
||||
del folders[folder_name]
|
||||
SystemConfigOper().set(SystemConfigKey.PluginFolders, folders)
|
||||
await SystemConfigOper().async_set(SystemConfigKey.PluginFolders, folders)
|
||||
return schemas.Response(success=True, message=f"文件夹 '{folder_name}' 删除成功")
|
||||
else:
|
||||
return schemas.Response(success=False, message=f"文件夹 '{folder_name}' 不存在")
|
||||
|
||||
|
||||
@router.put("/folders/{folder_name}/plugins", summary="更新文件夹中的插件", response_model=schemas.Response)
|
||||
def update_folder_plugins(folder_name: str, plugin_ids: List[str], _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
|
||||
async def update_folder_plugins(folder_name: str, plugin_ids: List[str],
|
||||
_: User = Depends(get_current_active_superuser_async)) -> Any:
|
||||
"""
|
||||
更新指定文件夹中的插件列表
|
||||
"""
|
||||
folders = SystemConfigOper().get(SystemConfigKey.PluginFolders) or {}
|
||||
folders[folder_name] = plugin_ids
|
||||
SystemConfigOper().set(SystemConfigKey.PluginFolders, folders)
|
||||
await SystemConfigOper().async_set(SystemConfigKey.PluginFolders, folders)
|
||||
return schemas.Response(success=True, message=f"文件夹 '{folder_name}' 中的插件已更新")
|
||||
|
||||
|
||||
@router.post("/clone/{plugin_id}", summary="创建插件分身", response_model=schemas.Response)
|
||||
def clone_plugin(plugin_id: str,
|
||||
clone_data: dict,
|
||||
_: User = Depends(get_current_active_superuser)) -> Any:
|
||||
"""
|
||||
创建插件分身
|
||||
"""
|
||||
try:
|
||||
success, message = PluginManager().clone_plugin(
|
||||
plugin_id=plugin_id,
|
||||
suffix=clone_data.get("suffix", ""),
|
||||
name=clone_data.get("name", ""),
|
||||
description=clone_data.get("description", ""),
|
||||
version=clone_data.get("version", ""),
|
||||
icon=clone_data.get("icon", "")
|
||||
)
|
||||
|
||||
if success:
|
||||
# 注册插件服务
|
||||
reload_plugin(message)
|
||||
# 将分身插件添加到原插件所在的文件夹中
|
||||
_add_clone_to_plugin_folder(plugin_id, message)
|
||||
return schemas.Response(success=True, message="插件分身创建成功")
|
||||
else:
|
||||
return schemas.Response(success=False, message=message)
|
||||
except Exception as e:
|
||||
logger.error(f"创建插件分身失败:{str(e)}")
|
||||
return schemas.Response(success=False, message=f"创建插件分身失败:{str(e)}")
|
||||
|
||||
|
||||
@router.get("/{plugin_id}", summary="获取插件配置")
|
||||
def plugin_config(plugin_id: str,
|
||||
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> dict:
|
||||
async def plugin_config(plugin_id: str,
|
||||
_: User = Depends(get_current_active_superuser_async)) -> dict:
|
||||
"""
|
||||
根据插件ID获取插件配置信息
|
||||
"""
|
||||
@@ -453,7 +504,7 @@ def plugin_config(plugin_id: str,
|
||||
|
||||
@router.put("/{plugin_id}", summary="更新插件配置", response_model=schemas.Response)
|
||||
def set_plugin_config(plugin_id: str, conf: dict,
|
||||
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
|
||||
_: User = Depends(get_current_active_superuser)) -> Any:
|
||||
"""
|
||||
更新插件配置
|
||||
"""
|
||||
@@ -469,7 +520,7 @@ def set_plugin_config(plugin_id: str, conf: dict,
|
||||
|
||||
@router.delete("/{plugin_id}", summary="卸载插件", response_model=schemas.Response)
|
||||
def uninstall_plugin(plugin_id: str,
|
||||
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
|
||||
_: User = Depends(get_current_active_superuser)) -> Any:
|
||||
"""
|
||||
卸载插件
|
||||
"""
|
||||
@@ -507,36 +558,6 @@ def uninstall_plugin(plugin_id: str,
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.post("/clone/{plugin_id}", summary="创建插件分身", response_model=schemas.Response)
|
||||
def clone_plugin(plugin_id: str,
|
||||
clone_data: dict,
|
||||
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
|
||||
"""
|
||||
创建插件分身
|
||||
"""
|
||||
try:
|
||||
success, message = PluginManager().clone_plugin(
|
||||
plugin_id=plugin_id,
|
||||
suffix=clone_data.get("suffix", ""),
|
||||
name=clone_data.get("name", ""),
|
||||
description=clone_data.get("description", ""),
|
||||
version=clone_data.get("version", ""),
|
||||
icon=clone_data.get("icon", "")
|
||||
)
|
||||
|
||||
if success:
|
||||
# 注册插件服务
|
||||
reload_plugin(message)
|
||||
# 将分身插件添加到原插件所在的文件夹中
|
||||
_add_clone_to_plugin_folder(plugin_id, message)
|
||||
return schemas.Response(success=True, message="插件分身创建成功")
|
||||
else:
|
||||
return schemas.Response(success=False, message=message)
|
||||
except Exception as e:
|
||||
logger.error(f"创建插件分身失败:{str(e)}")
|
||||
return schemas.Response(success=False, message=f"创建插件分身失败:{str(e)}")
|
||||
|
||||
|
||||
def _add_clone_to_plugin_folder(original_plugin_id: str, clone_plugin_id: str):
|
||||
"""
|
||||
将分身插件添加到原插件所在的文件夹中
|
||||
|
||||
@@ -3,11 +3,11 @@ from typing import Any, List, Optional
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from app import schemas
|
||||
from app.chain.recommend import RecommendChain
|
||||
from app.core.event import eventmanager
|
||||
from app.core.security import verify_token
|
||||
from app.schemas.types import ChainEventType
|
||||
from app.chain.recommend import RecommendChain
|
||||
from app.schemas import RecommendSourceEventData
|
||||
from app.schemas.types import ChainEventType
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -29,163 +29,163 @@ def source(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
|
||||
|
||||
@router.get("/bangumi_calendar", summary="Bangumi每日放送", response_model=List[schemas.MediaInfo])
|
||||
def bangumi_calendar(page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def bangumi_calendar(page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
浏览Bangumi每日放送
|
||||
"""
|
||||
return RecommendChain().bangumi_calendar(page=page, count=count)
|
||||
return await RecommendChain().async_bangumi_calendar(page=page, count=count)
|
||||
|
||||
|
||||
@router.get("/douban_showing", summary="豆瓣正在热映", response_model=List[schemas.MediaInfo])
|
||||
def douban_showing(page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def douban_showing(page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
浏览豆瓣正在热映
|
||||
"""
|
||||
return RecommendChain().douban_movie_showing(page=page, count=count)
|
||||
return await RecommendChain().async_douban_movie_showing(page=page, count=count)
|
||||
|
||||
|
||||
@router.get("/douban_movies", summary="豆瓣电影", response_model=List[schemas.MediaInfo])
|
||||
def douban_movies(sort: Optional[str] = "R",
|
||||
tags: Optional[str] = "",
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def douban_movies(sort: Optional[str] = "R",
|
||||
tags: Optional[str] = "",
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
浏览豆瓣电影信息
|
||||
"""
|
||||
return RecommendChain().douban_movies(sort=sort, tags=tags, page=page, count=count)
|
||||
return await RecommendChain().async_douban_movies(sort=sort, tags=tags, page=page, count=count)
|
||||
|
||||
|
||||
@router.get("/douban_tvs", summary="豆瓣剧集", response_model=List[schemas.MediaInfo])
|
||||
def douban_tvs(sort: Optional[str] = "R",
|
||||
tags: Optional[str] = "",
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
浏览豆瓣剧集信息
|
||||
"""
|
||||
return RecommendChain().douban_tvs(sort=sort, tags=tags, page=page, count=count)
|
||||
|
||||
|
||||
@router.get("/douban_movie_top250", summary="豆瓣电影TOP250", response_model=List[schemas.MediaInfo])
|
||||
def douban_movie_top250(page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
浏览豆瓣剧集信息
|
||||
"""
|
||||
return RecommendChain().douban_movie_top250(page=page, count=count)
|
||||
|
||||
|
||||
@router.get("/douban_tv_weekly_chinese", summary="豆瓣国产剧集周榜", response_model=List[schemas.MediaInfo])
|
||||
def douban_tv_weekly_chinese(page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
中国每周剧集口碑榜
|
||||
"""
|
||||
return RecommendChain().douban_tv_weekly_chinese(page=page, count=count)
|
||||
|
||||
|
||||
@router.get("/douban_tv_weekly_global", summary="豆瓣全球剧集周榜", response_model=List[schemas.MediaInfo])
|
||||
def douban_tv_weekly_global(page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
全球每周剧集口碑榜
|
||||
"""
|
||||
return RecommendChain().douban_tv_weekly_global(page=page, count=count)
|
||||
|
||||
|
||||
@router.get("/douban_tv_animation", summary="豆瓣动画剧集", response_model=List[schemas.MediaInfo])
|
||||
def douban_tv_animation(page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
热门动画剧集
|
||||
"""
|
||||
return RecommendChain().douban_tv_animation(page=page, count=count)
|
||||
|
||||
|
||||
@router.get("/douban_movie_hot", summary="豆瓣热门电影", response_model=List[schemas.MediaInfo])
|
||||
def douban_movie_hot(page: Optional[int] = 1,
|
||||
async def douban_tvs(sort: Optional[str] = "R",
|
||||
tags: Optional[str] = "",
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
浏览豆瓣剧集信息
|
||||
"""
|
||||
return await RecommendChain().async_douban_tvs(sort=sort, tags=tags, page=page, count=count)
|
||||
|
||||
|
||||
@router.get("/douban_movie_top250", summary="豆瓣电影TOP250", response_model=List[schemas.MediaInfo])
|
||||
async def douban_movie_top250(page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
浏览豆瓣剧集信息
|
||||
"""
|
||||
return await RecommendChain().async_douban_movie_top250(page=page, count=count)
|
||||
|
||||
|
||||
@router.get("/douban_tv_weekly_chinese", summary="豆瓣国产剧集周榜", response_model=List[schemas.MediaInfo])
|
||||
async def douban_tv_weekly_chinese(page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
中国每周剧集口碑榜
|
||||
"""
|
||||
return await RecommendChain().async_douban_tv_weekly_chinese(page=page, count=count)
|
||||
|
||||
|
||||
@router.get("/douban_tv_weekly_global", summary="豆瓣全球剧集周榜", response_model=List[schemas.MediaInfo])
|
||||
async def douban_tv_weekly_global(page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
全球每周剧集口碑榜
|
||||
"""
|
||||
return await RecommendChain().async_douban_tv_weekly_global(page=page, count=count)
|
||||
|
||||
|
||||
@router.get("/douban_tv_animation", summary="豆瓣动画剧集", response_model=List[schemas.MediaInfo])
|
||||
async def douban_tv_animation(page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
热门动画剧集
|
||||
"""
|
||||
return await RecommendChain().async_douban_tv_animation(page=page, count=count)
|
||||
|
||||
|
||||
@router.get("/douban_movie_hot", summary="豆瓣热门电影", response_model=List[schemas.MediaInfo])
|
||||
async def douban_movie_hot(page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
热门电影
|
||||
"""
|
||||
return RecommendChain().douban_movie_hot(page=page, count=count)
|
||||
return await RecommendChain().async_douban_movie_hot(page=page, count=count)
|
||||
|
||||
|
||||
@router.get("/douban_tv_hot", summary="豆瓣热门电视剧", response_model=List[schemas.MediaInfo])
|
||||
def douban_tv_hot(page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def douban_tv_hot(page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
热门电视剧
|
||||
"""
|
||||
return RecommendChain().douban_tv_hot(page=page, count=count)
|
||||
return await RecommendChain().async_douban_tv_hot(page=page, count=count)
|
||||
|
||||
|
||||
@router.get("/tmdb_movies", summary="TMDB电影", response_model=List[schemas.MediaInfo])
|
||||
def tmdb_movies(sort_by: Optional[str] = "popularity.desc",
|
||||
with_genres: Optional[str] = "",
|
||||
with_original_language: Optional[str] = "",
|
||||
with_keywords: Optional[str] = "",
|
||||
with_watch_providers: Optional[str] = "",
|
||||
vote_average: Optional[float] = 0.0,
|
||||
vote_count: Optional[int] = 0,
|
||||
release_date: Optional[str] = "",
|
||||
page: Optional[int] = 1,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def tmdb_movies(sort_by: Optional[str] = "popularity.desc",
|
||||
with_genres: Optional[str] = "",
|
||||
with_original_language: Optional[str] = "",
|
||||
with_keywords: Optional[str] = "",
|
||||
with_watch_providers: Optional[str] = "",
|
||||
vote_average: Optional[float] = 0.0,
|
||||
vote_count: Optional[int] = 0,
|
||||
release_date: Optional[str] = "",
|
||||
page: Optional[int] = 1,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
浏览TMDB电影信息
|
||||
"""
|
||||
return RecommendChain().tmdb_movies(sort_by=sort_by,
|
||||
with_genres=with_genres,
|
||||
with_original_language=with_original_language,
|
||||
with_keywords=with_keywords,
|
||||
with_watch_providers=with_watch_providers,
|
||||
vote_average=vote_average,
|
||||
vote_count=vote_count,
|
||||
release_date=release_date,
|
||||
page=page)
|
||||
return await RecommendChain().async_tmdb_movies(sort_by=sort_by,
|
||||
with_genres=with_genres,
|
||||
with_original_language=with_original_language,
|
||||
with_keywords=with_keywords,
|
||||
with_watch_providers=with_watch_providers,
|
||||
vote_average=vote_average,
|
||||
vote_count=vote_count,
|
||||
release_date=release_date,
|
||||
page=page)
|
||||
|
||||
|
||||
@router.get("/tmdb_tvs", summary="TMDB剧集", response_model=List[schemas.MediaInfo])
|
||||
def tmdb_tvs(sort_by: Optional[str] = "popularity.desc",
|
||||
with_genres: Optional[str] = "",
|
||||
with_original_language: Optional[str] = "",
|
||||
with_keywords: Optional[str] = "",
|
||||
with_watch_providers: Optional[str] = "",
|
||||
vote_average: Optional[float] = 0.0,
|
||||
vote_count: Optional[int] = 0,
|
||||
release_date: Optional[str] = "",
|
||||
page: Optional[int] = 1,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def tmdb_tvs(sort_by: Optional[str] = "popularity.desc",
|
||||
with_genres: Optional[str] = "",
|
||||
with_original_language: Optional[str] = "",
|
||||
with_keywords: Optional[str] = "",
|
||||
with_watch_providers: Optional[str] = "",
|
||||
vote_average: Optional[float] = 0.0,
|
||||
vote_count: Optional[int] = 0,
|
||||
release_date: Optional[str] = "",
|
||||
page: Optional[int] = 1,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
浏览TMDB剧集信息
|
||||
"""
|
||||
return RecommendChain().tmdb_tvs(sort_by=sort_by,
|
||||
with_genres=with_genres,
|
||||
with_original_language=with_original_language,
|
||||
with_keywords=with_keywords,
|
||||
with_watch_providers=with_watch_providers,
|
||||
vote_average=vote_average,
|
||||
vote_count=vote_count,
|
||||
release_date=release_date,
|
||||
page=page)
|
||||
return await RecommendChain().async_tmdb_tvs(sort_by=sort_by,
|
||||
with_genres=with_genres,
|
||||
with_original_language=with_original_language,
|
||||
with_keywords=with_keywords,
|
||||
with_watch_providers=with_watch_providers,
|
||||
vote_average=vote_average,
|
||||
vote_count=vote_count,
|
||||
release_date=release_date,
|
||||
page=page)
|
||||
|
||||
|
||||
@router.get("/tmdb_trending", summary="TMDB流行趋势", response_model=List[schemas.MediaInfo])
|
||||
def tmdb_trending(page: Optional[int] = 1,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def tmdb_trending(page: Optional[int] = 1,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
TMDB流行趋势
|
||||
"""
|
||||
return RecommendChain().tmdb_trending(page=page)
|
||||
return await RecommendChain().async_tmdb_trending(page=page)
|
||||
|
||||
@@ -16,23 +16,23 @@ router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/last", summary="查询搜索结果", response_model=List[schemas.Context])
|
||||
def search_latest(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def search_latest(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
查询搜索结果
|
||||
"""
|
||||
torrents = SearchChain().last_search_results()
|
||||
torrents = await SearchChain().async_last_search_results() or []
|
||||
return [torrent.to_dict() for torrent in torrents]
|
||||
|
||||
|
||||
@router.get("/media/{mediaid}", summary="精确搜索资源", response_model=schemas.Response)
|
||||
def search_by_id(mediaid: str,
|
||||
mtype: Optional[str] = None,
|
||||
area: Optional[str] = "title",
|
||||
title: Optional[str] = None,
|
||||
year: Optional[str] = None,
|
||||
season: Optional[str] = None,
|
||||
sites: Optional[str] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def search_by_id(mediaid: str,
|
||||
mtype: Optional[str] = None,
|
||||
area: Optional[str] = "title",
|
||||
title: Optional[str] = None,
|
||||
year: Optional[str] = None,
|
||||
season: Optional[str] = None,
|
||||
sites: Optional[str] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
根据TMDBID/豆瓣ID精确搜索站点资源 tmdb:/douban:/bangumi:
|
||||
"""
|
||||
@@ -49,55 +49,59 @@ def search_by_id(mediaid: str,
|
||||
else:
|
||||
site_list = None
|
||||
torrents = None
|
||||
media_chain = MediaChain()
|
||||
search_chain = SearchChain()
|
||||
# 根据前缀识别媒体ID
|
||||
if mediaid.startswith("tmdb:"):
|
||||
tmdbid = int(mediaid.replace("tmdb:", ""))
|
||||
if settings.RECOGNIZE_SOURCE == "douban":
|
||||
# 通过TMDBID识别豆瓣ID
|
||||
doubaninfo = MediaChain().get_doubaninfo_by_tmdbid(tmdbid=tmdbid, mtype=media_type)
|
||||
doubaninfo = await media_chain.async_get_doubaninfo_by_tmdbid(tmdbid=tmdbid, mtype=media_type)
|
||||
if doubaninfo:
|
||||
torrents = SearchChain().search_by_id(doubanid=doubaninfo.get("id"),
|
||||
mtype=media_type, area=area, season=media_season,
|
||||
sites=site_list, cache_local=True)
|
||||
torrents = await search_chain.async_search_by_id(doubanid=doubaninfo.get("id"),
|
||||
mtype=media_type, area=area, season=media_season,
|
||||
sites=site_list, cache_local=True)
|
||||
else:
|
||||
return schemas.Response(success=False, message="未识别到豆瓣媒体信息")
|
||||
else:
|
||||
torrents = SearchChain().search_by_id(tmdbid=tmdbid, mtype=media_type, area=area, season=media_season,
|
||||
sites=site_list, cache_local=True)
|
||||
torrents = await search_chain.async_search_by_id(tmdbid=tmdbid, mtype=media_type, area=area,
|
||||
season=media_season,
|
||||
sites=site_list, cache_local=True)
|
||||
elif mediaid.startswith("douban:"):
|
||||
doubanid = mediaid.replace("douban:", "")
|
||||
if settings.RECOGNIZE_SOURCE == "themoviedb":
|
||||
# 通过豆瓣ID识别TMDBID
|
||||
tmdbinfo = MediaChain().get_tmdbinfo_by_doubanid(doubanid=doubanid, mtype=media_type)
|
||||
tmdbinfo = await media_chain.async_get_tmdbinfo_by_doubanid(doubanid=doubanid, mtype=media_type)
|
||||
if tmdbinfo:
|
||||
if tmdbinfo.get('season') and not media_season:
|
||||
media_season = tmdbinfo.get('season')
|
||||
torrents = SearchChain().search_by_id(tmdbid=tmdbinfo.get("id"),
|
||||
mtype=media_type, area=area, season=media_season,
|
||||
sites=site_list, cache_local=True)
|
||||
torrents = await search_chain.async_search_by_id(tmdbid=tmdbinfo.get("id"),
|
||||
mtype=media_type, area=area, season=media_season,
|
||||
sites=site_list, cache_local=True)
|
||||
else:
|
||||
return schemas.Response(success=False, message="未识别到TMDB媒体信息")
|
||||
else:
|
||||
torrents = SearchChain().search_by_id(doubanid=doubanid, mtype=media_type, area=area, season=media_season,
|
||||
sites=site_list, cache_local=True)
|
||||
torrents = await search_chain.async_search_by_id(doubanid=doubanid, mtype=media_type, area=area,
|
||||
season=media_season,
|
||||
sites=site_list, cache_local=True)
|
||||
elif mediaid.startswith("bangumi:"):
|
||||
bangumiid = int(mediaid.replace("bangumi:", ""))
|
||||
if settings.RECOGNIZE_SOURCE == "themoviedb":
|
||||
# 通过BangumiID识别TMDBID
|
||||
tmdbinfo = MediaChain().get_tmdbinfo_by_bangumiid(bangumiid=bangumiid)
|
||||
tmdbinfo = await media_chain.async_get_tmdbinfo_by_bangumiid(bangumiid=bangumiid)
|
||||
if tmdbinfo:
|
||||
torrents = SearchChain().search_by_id(tmdbid=tmdbinfo.get("id"),
|
||||
mtype=media_type, area=area, season=media_season,
|
||||
sites=site_list, cache_local=True)
|
||||
torrents = await search_chain.async_search_by_id(tmdbid=tmdbinfo.get("id"),
|
||||
mtype=media_type, area=area, season=media_season,
|
||||
sites=site_list, cache_local=True)
|
||||
else:
|
||||
return schemas.Response(success=False, message="未识别到TMDB媒体信息")
|
||||
else:
|
||||
# 通过BangumiID识别豆瓣ID
|
||||
doubaninfo = MediaChain().get_doubaninfo_by_bangumiid(bangumiid=bangumiid)
|
||||
doubaninfo = await media_chain.async_get_doubaninfo_by_bangumiid(bangumiid=bangumiid)
|
||||
if doubaninfo:
|
||||
torrents = SearchChain().search_by_id(doubanid=doubaninfo.get("id"),
|
||||
mtype=media_type, area=area, season=media_season,
|
||||
sites=site_list, cache_local=True)
|
||||
torrents = await search_chain.async_search_by_id(doubanid=doubaninfo.get("id"),
|
||||
mtype=media_type, area=area, season=media_season,
|
||||
sites=site_list, cache_local=True)
|
||||
else:
|
||||
return schemas.Response(success=False, message="未识别到豆瓣媒体信息")
|
||||
else:
|
||||
@@ -106,18 +110,18 @@ def search_by_id(mediaid: str,
|
||||
mediaid=mediaid,
|
||||
convert_type=settings.RECOGNIZE_SOURCE
|
||||
)
|
||||
event = eventmanager.send_event(ChainEventType.MediaRecognizeConvert, event_data)
|
||||
event = await eventmanager.async_send_event(ChainEventType.MediaRecognizeConvert, event_data)
|
||||
# 使用事件返回的上下文数据
|
||||
if event and event.event_data:
|
||||
event_data: MediaRecognizeConvertEventData = event.event_data
|
||||
if event_data.media_dict:
|
||||
search_id = event_data.media_dict.get("id")
|
||||
if event_data.convert_type == "themoviedb":
|
||||
torrents = SearchChain().search_by_id(tmdbid=search_id, mtype=media_type, area=area,
|
||||
season=media_season, cache_local=True)
|
||||
torrents = await search_chain.async_search_by_id(tmdbid=search_id, mtype=media_type, area=area,
|
||||
season=media_season, cache_local=True)
|
||||
elif event_data.convert_type == "douban":
|
||||
torrents = SearchChain().search_by_id(doubanid=search_id, mtype=media_type, area=area,
|
||||
season=media_season, cache_local=True)
|
||||
torrents = await search_chain.async_search_by_id(doubanid=search_id, mtype=media_type, area=area,
|
||||
season=media_season, cache_local=True)
|
||||
else:
|
||||
if not title:
|
||||
return schemas.Response(success=False, message="未知的媒体ID")
|
||||
@@ -130,14 +134,16 @@ def search_by_id(mediaid: str,
|
||||
if media_season:
|
||||
meta.type = MediaType.TV
|
||||
meta.begin_season = media_season
|
||||
mediainfo = MediaChain().recognize_media(meta=meta)
|
||||
mediainfo = await media_chain.async_recognize_media(meta=meta)
|
||||
if mediainfo:
|
||||
if settings.RECOGNIZE_SOURCE == "themoviedb":
|
||||
torrents = SearchChain().search_by_id(tmdbid=mediainfo.tmdb_id, mtype=media_type, area=area,
|
||||
season=media_season, cache_local=True)
|
||||
torrents = await search_chain.async_search_by_id(tmdbid=mediainfo.tmdb_id, mtype=media_type,
|
||||
area=area,
|
||||
season=media_season, cache_local=True)
|
||||
else:
|
||||
torrents = SearchChain().search_by_id(doubanid=mediainfo.douban_id, mtype=media_type, area=area,
|
||||
season=media_season, cache_local=True)
|
||||
torrents = await search_chain.async_search_by_id(doubanid=mediainfo.douban_id, mtype=media_type,
|
||||
area=area,
|
||||
season=media_season, cache_local=True)
|
||||
# 返回搜索结果
|
||||
if not torrents:
|
||||
return schemas.Response(success=False, message="未搜索到任何资源")
|
||||
@@ -146,16 +152,18 @@ def search_by_id(mediaid: str,
|
||||
|
||||
|
||||
@router.get("/title", summary="模糊搜索资源", response_model=schemas.Response)
|
||||
def search_by_title(keyword: Optional[str] = None,
|
||||
page: Optional[int] = 0,
|
||||
sites: Optional[str] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def search_by_title(keyword: Optional[str] = None,
|
||||
page: Optional[int] = 0,
|
||||
sites: Optional[str] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
根据名称模糊搜索站点资源,支持分页,关键词为空是返回首页资源
|
||||
"""
|
||||
torrents = SearchChain().search_by_title(title=keyword, page=page,
|
||||
sites=[int(site) for site in sites.split(",") if site] if sites else None,
|
||||
cache_local=True)
|
||||
torrents = await SearchChain().async_search_by_title(
|
||||
title=keyword, page=page,
|
||||
sites=[int(site) for site in sites.split(",") if site] if sites else None,
|
||||
cache_local=True
|
||||
)
|
||||
if not torrents:
|
||||
return schemas.Response(success=False, message="未搜索到任何资源")
|
||||
return schemas.Response(success=True, data=[torrent.to_dict() for torrent in torrents])
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from typing import List, Any, Dict, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.background import BackgroundTasks
|
||||
|
||||
@@ -9,10 +10,10 @@ from app.api.endpoints.plugin import register_plugin_api
|
||||
from app.chain.site import SiteChain
|
||||
from app.chain.torrents import TorrentsChain
|
||||
from app.command import Command
|
||||
from app.core.event import EventManager
|
||||
from app.core.event import eventmanager
|
||||
from app.core.plugin import PluginManager
|
||||
from app.core.security import verify_token
|
||||
from app.db import get_db
|
||||
from app.db import get_db, get_async_db
|
||||
from app.db.models import User
|
||||
from app.db.models.site import Site
|
||||
from app.db.models.siteicon import SiteIcon
|
||||
@@ -20,8 +21,8 @@ from app.db.models.sitestatistic import SiteStatistic
|
||||
from app.db.models.siteuserdata import SiteUserData
|
||||
from app.db.site_oper import SiteOper
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.db.user_oper import get_current_active_superuser
|
||||
from app.helper.sites import SitesHelper
|
||||
from app.db.user_oper import get_current_active_superuser, get_current_active_superuser_async
|
||||
from app.helper.sites import SitesHelper # noqa
|
||||
from app.scheduler import Scheduler
|
||||
from app.schemas.types import SystemConfigKey, EventType
|
||||
from app.utils.string import StringUtils
|
||||
@@ -30,20 +31,20 @@ router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/", summary="所有站点", response_model=List[schemas.Site])
|
||||
def read_sites(db: Session = Depends(get_db),
|
||||
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> List[dict]:
|
||||
async def read_sites(db: AsyncSession = Depends(get_async_db),
|
||||
_: User = Depends(get_current_active_superuser)) -> List[dict]:
|
||||
"""
|
||||
获取站点列表
|
||||
"""
|
||||
return Site.list_order_by_pri(db)
|
||||
return await Site.async_list_order_by_pri(db)
|
||||
|
||||
|
||||
@router.post("/", summary="新增站点", response_model=schemas.Response)
|
||||
def add_site(
|
||||
async def add_site(
|
||||
*,
|
||||
db: Session = Depends(get_db),
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
site_in: schemas.Site,
|
||||
_: schemas.TokenPayload = Depends(get_current_active_superuser)
|
||||
_: User = Depends(get_current_active_superuser)
|
||||
) -> Any:
|
||||
"""
|
||||
新增站点
|
||||
@@ -53,10 +54,10 @@ def add_site(
|
||||
if SitesHelper().auth_level < 2:
|
||||
return schemas.Response(success=False, message="用户未通过认证,无法使用站点功能!")
|
||||
domain = StringUtils.get_url_domain(site_in.url)
|
||||
site_info = SitesHelper().get_indexer(domain)
|
||||
site_info = await SitesHelper().async_get_indexer(domain)
|
||||
if not site_info:
|
||||
return schemas.Response(success=False, message="该站点不支持,请检查站点域名是否正确")
|
||||
if Site.get_by_domain(db, domain):
|
||||
if await Site.async_get_by_domain(db, domain):
|
||||
return schemas.Response(success=False, message=f"{domain} 站点己存在")
|
||||
# 保存站点信息
|
||||
site_in.domain = domain
|
||||
@@ -66,42 +67,42 @@ def add_site(
|
||||
site_in.name = site_info.get("name")
|
||||
site_in.id = None
|
||||
site_in.public = 1 if site_info.get("public") else 0
|
||||
site = Site(**site_in.dict())
|
||||
site = Site(**site_in.model_dump())
|
||||
site.create(db)
|
||||
# 通知站点更新
|
||||
EventManager().send_event(EventType.SiteUpdated, {
|
||||
await eventmanager.async_send_event(EventType.SiteUpdated, {
|
||||
"domain": domain
|
||||
})
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.put("/", summary="更新站点", response_model=schemas.Response)
|
||||
def update_site(
|
||||
async def update_site(
|
||||
*,
|
||||
db: Session = Depends(get_db),
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
site_in: schemas.Site,
|
||||
_: schemas.TokenPayload = Depends(get_current_active_superuser)
|
||||
_: User = Depends(get_current_active_superuser)
|
||||
) -> Any:
|
||||
"""
|
||||
更新站点信息
|
||||
"""
|
||||
site = Site.get(db, site_in.id)
|
||||
site = await Site.async_get(db, site_in.id)
|
||||
if not site:
|
||||
return schemas.Response(success=False, message="站点不存在")
|
||||
# 校正地址格式
|
||||
_scheme, _netloc = StringUtils.get_url_netloc(site_in.url)
|
||||
site_in.url = f"{_scheme}://{_netloc}/"
|
||||
site.update(db, site_in.dict())
|
||||
await site.async_update(db, site_in.model_dump())
|
||||
# 通知站点更新
|
||||
EventManager().send_event(EventType.SiteUpdated, {
|
||||
await eventmanager.async_send_event(EventType.SiteUpdated, {
|
||||
"domain": site_in.domain
|
||||
})
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.get("/cookiecloud", summary="CookieCloud同步", response_model=schemas.Response)
|
||||
def cookie_cloud_sync(background_tasks: BackgroundTasks,
|
||||
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
|
||||
async def cookie_cloud_sync(background_tasks: BackgroundTasks,
|
||||
_: User = Depends(get_current_active_superuser_async)) -> Any:
|
||||
"""
|
||||
运行CookieCloud同步站点信息
|
||||
"""
|
||||
@@ -110,7 +111,7 @@ def cookie_cloud_sync(background_tasks: BackgroundTasks,
|
||||
|
||||
|
||||
@router.get("/reset", summary="重置站点", response_model=schemas.Response)
|
||||
def reset(db: Session = Depends(get_db),
|
||||
def reset(db: AsyncSession = Depends(get_db),
|
||||
_: User = Depends(get_current_active_superuser)) -> Any:
|
||||
"""
|
||||
清空所有站点数据并重新同步CookieCloud站点信息
|
||||
@@ -121,25 +122,25 @@ def reset(db: Session = Depends(get_db),
|
||||
# 启动定时服务
|
||||
Scheduler().start("cookiecloud", manual=True)
|
||||
# 插件站点删除
|
||||
EventManager().send_event(EventType.SiteDeleted,
|
||||
{
|
||||
"site_id": "*"
|
||||
})
|
||||
eventmanager.send_event(EventType.SiteDeleted,
|
||||
{
|
||||
"site_id": "*"
|
||||
})
|
||||
return schemas.Response(success=True, message="站点已重置!")
|
||||
|
||||
|
||||
@router.post("/priorities", summary="批量更新站点优先级", response_model=schemas.Response)
|
||||
def update_sites_priority(
|
||||
async def update_sites_priority(
|
||||
priorities: List[dict],
|
||||
db: Session = Depends(get_db),
|
||||
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: User = Depends(get_current_active_superuser_async)) -> Any:
|
||||
"""
|
||||
批量更新站点优先级
|
||||
"""
|
||||
for priority in priorities:
|
||||
site = Site.get(db, priority.get("id"))
|
||||
site = await Site.async_get(db, priority.get("id"))
|
||||
if site:
|
||||
site.update(db, {"pri": priority.get("pri")})
|
||||
await site.async_update(db, {"pri": priority.get("pri")})
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@@ -150,7 +151,7 @@ def update_cookie(
|
||||
password: str,
|
||||
code: Optional[str] = None,
|
||||
db: Session = Depends(get_db),
|
||||
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
|
||||
_: User = Depends(get_current_active_superuser)) -> Any:
|
||||
"""
|
||||
使用用户密码更新站点Cookie
|
||||
"""
|
||||
@@ -173,7 +174,7 @@ def update_cookie(
|
||||
def refresh_userdata(
|
||||
site_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
|
||||
_: User = Depends(get_current_active_superuser)) -> Any:
|
||||
"""
|
||||
刷新站点用户数据
|
||||
"""
|
||||
@@ -191,34 +192,34 @@ def refresh_userdata(
|
||||
|
||||
|
||||
@router.get("/userdata/latest", summary="查询所有站点最新用户数据", response_model=List[schemas.SiteUserData])
|
||||
def read_userdata_latest(
|
||||
db: Session = Depends(get_db),
|
||||
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
|
||||
async def read_userdata_latest(
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: User = Depends(get_current_active_superuser_async)) -> Any:
|
||||
"""
|
||||
查询所有站点最新用户数据
|
||||
"""
|
||||
user_datas = SiteUserData.get_latest(db)
|
||||
user_datas = await SiteUserData.async_get_latest(db)
|
||||
if not user_datas:
|
||||
return []
|
||||
return [user_data.to_dict() for user_data in user_datas]
|
||||
|
||||
|
||||
@router.get("/userdata/{site_id}", summary="查询某站点用户数据", response_model=schemas.Response)
|
||||
def read_userdata(
|
||||
async def read_userdata(
|
||||
site_id: int,
|
||||
workdate: Optional[str] = None,
|
||||
db: Session = Depends(get_db),
|
||||
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: User = Depends(get_current_active_superuser_async)) -> Any:
|
||||
"""
|
||||
查询站点用户数据
|
||||
"""
|
||||
site = Site.get(db, site_id)
|
||||
site = await Site.async_get(db, site_id)
|
||||
if not site:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"站点 {site_id} 不存在",
|
||||
)
|
||||
user_data = SiteUserData.get_by_domain(db, domain=site.domain, workdate=workdate)
|
||||
user_data = await SiteUserData.async_get_by_domain(db, domain=site.domain, workdate=workdate)
|
||||
if not user_data:
|
||||
return schemas.Response(success=False, data=[])
|
||||
return schemas.Response(success=True, data=user_data)
|
||||
@@ -242,19 +243,19 @@ def test_site(site_id: int,
|
||||
|
||||
|
||||
@router.get("/icon/{site_id}", summary="站点图标", response_model=schemas.Response)
|
||||
def site_icon(site_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def site_icon(site_id: int,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
获取站点图标:base64或者url
|
||||
"""
|
||||
site = Site.get(db, site_id)
|
||||
site = await Site.async_get(db, site_id)
|
||||
if not site:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"站点 {site_id} 不存在",
|
||||
)
|
||||
icon = SiteIcon.get_by_domain(db, site.domain)
|
||||
icon = await SiteIcon.async_get_by_domain(db, site.domain)
|
||||
if not icon:
|
||||
return schemas.Response(success=False, message="站点图标不存在!")
|
||||
return schemas.Response(success=True, data={
|
||||
@@ -263,19 +264,19 @@ def site_icon(site_id: int,
|
||||
|
||||
|
||||
@router.get("/category/{site_id}", summary="站点分类", response_model=List[schemas.SiteCategory])
|
||||
def site_category(site_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def site_category(site_id: int,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
获取站点分类
|
||||
"""
|
||||
site = Site.get(db, site_id)
|
||||
site = await Site.async_get(db, site_id)
|
||||
if not site:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"站点 {site_id} 不存在",
|
||||
)
|
||||
indexer = SitesHelper().get_indexer(site.domain)
|
||||
indexer = await SitesHelper().async_get_indexer(site.domain)
|
||||
if not indexer:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
@@ -293,38 +294,38 @@ def site_category(site_id: int,
|
||||
|
||||
|
||||
@router.get("/resource/{site_id}", summary="站点资源", response_model=List[schemas.TorrentInfo])
|
||||
def site_resource(site_id: int,
|
||||
keyword: Optional[str] = None,
|
||||
cat: Optional[str] = None,
|
||||
page: Optional[int] = 0,
|
||||
db: Session = Depends(get_db),
|
||||
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
|
||||
async def site_resource(site_id: int,
|
||||
keyword: Optional[str] = None,
|
||||
cat: Optional[str] = None,
|
||||
page: Optional[int] = 0,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: User = Depends(get_current_active_superuser_async)) -> Any:
|
||||
"""
|
||||
浏览站点资源
|
||||
"""
|
||||
site = Site.get(db, site_id)
|
||||
site = await Site.async_get(db, site_id)
|
||||
if not site:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"站点 {site_id} 不存在",
|
||||
)
|
||||
torrents = TorrentsChain().browse(domain=site.domain, keyword=keyword, cat=cat, page=page)
|
||||
torrents = await TorrentsChain().async_browse(domain=site.domain, keyword=keyword, cat=cat, page=page)
|
||||
if not torrents:
|
||||
return []
|
||||
return [torrent.to_dict() for torrent in torrents]
|
||||
|
||||
|
||||
@router.get("/domain/{site_url}", summary="站点详情", response_model=schemas.Site)
|
||||
def read_site_by_domain(
|
||||
async def read_site_by_domain(
|
||||
site_url: str,
|
||||
db: Session = Depends(get_db),
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)
|
||||
) -> Any:
|
||||
"""
|
||||
通过域名获取站点信息
|
||||
"""
|
||||
domain = StringUtils.get_url_domain(site_url)
|
||||
site = Site.get_by_domain(db, domain)
|
||||
site = await Site.async_get_by_domain(db, domain)
|
||||
if not site:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
@@ -333,25 +334,36 @@ def read_site_by_domain(
|
||||
return site
|
||||
|
||||
|
||||
@router.get("/statistic/{site_url}", summary="站点统计信息", response_model=schemas.SiteStatistic)
|
||||
def read_site_by_domain(
|
||||
@router.get("/statistic/{site_url}", summary="特定站点统计信息", response_model=schemas.SiteStatistic)
|
||||
async def read_statistic_by_domain(
|
||||
site_url: str,
|
||||
db: Session = Depends(get_db),
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)
|
||||
) -> Any:
|
||||
"""
|
||||
通过域名获取站点统计信息
|
||||
"""
|
||||
domain = StringUtils.get_url_domain(site_url)
|
||||
sitestatistic = SiteStatistic.get_by_domain(db, domain)
|
||||
sitestatistic = await SiteStatistic.async_get_by_domain(db, domain)
|
||||
if sitestatistic:
|
||||
return sitestatistic
|
||||
return schemas.SiteStatistic(domain=domain)
|
||||
|
||||
|
||||
@router.get("/statistic", summary="所有站点统计信息", response_model=List[schemas.SiteStatistic])
|
||||
async def read_statistics(
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)
|
||||
) -> Any:
|
||||
"""
|
||||
获取所有站点统计信息
|
||||
"""
|
||||
return await SiteStatistic.async_list(db)
|
||||
|
||||
|
||||
@router.get("/rss", summary="所有订阅站点", response_model=List[schemas.Site])
|
||||
def read_rss_sites(db: Session = Depends(get_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> List[dict]:
|
||||
async def read_rss_sites(db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> List[dict]:
|
||||
"""
|
||||
获取站点列表
|
||||
"""
|
||||
@@ -359,7 +371,7 @@ def read_rss_sites(db: Session = Depends(get_db),
|
||||
selected_sites = SystemConfigOper().get(SystemConfigKey.RssSites) or []
|
||||
|
||||
# 所有站点
|
||||
all_site = Site.list_order_by_pri(db)
|
||||
all_site = await Site.async_list_order_by_pri(db)
|
||||
if not selected_sites:
|
||||
return all_site
|
||||
|
||||
@@ -369,7 +381,7 @@ def read_rss_sites(db: Session = Depends(get_db),
|
||||
|
||||
|
||||
@router.get("/auth", summary="查询认证站点", response_model=dict)
|
||||
def read_auth_sites(_: schemas.TokenPayload = Depends(verify_token)) -> dict:
|
||||
async def read_auth_sites(_: schemas.TokenPayload = Depends(verify_token)) -> dict:
|
||||
"""
|
||||
获取可认证站点列表
|
||||
"""
|
||||
@@ -387,7 +399,7 @@ def auth_site(
|
||||
if not auth_info or not auth_info.site or not auth_info.params:
|
||||
return schemas.Response(success=False, message="请输入认证站点和认证参数")
|
||||
status, msg = SitesHelper().check_user(auth_info.site, auth_info.params)
|
||||
SystemConfigOper().set(SystemConfigKey.UserSiteAuthParams, auth_info.dict())
|
||||
SystemConfigOper().set(SystemConfigKey.UserSiteAuthParams, auth_info.model_dump())
|
||||
# 认证成功后,重新初始化插件
|
||||
PluginManager().init_config()
|
||||
Scheduler().init_plugin_jobs()
|
||||
@@ -397,12 +409,12 @@ def auth_site(
|
||||
|
||||
|
||||
@router.get("/mapping", summary="获取站点域名到名称的映射", response_model=schemas.Response)
|
||||
def site_mapping(_: User = Depends(get_current_active_superuser)):
|
||||
async def site_mapping(_: User = Depends(get_current_active_superuser_async)):
|
||||
"""
|
||||
获取站点域名到名称的映射关系
|
||||
"""
|
||||
try:
|
||||
sites = SiteOper().list()
|
||||
sites = await SiteOper().async_list()
|
||||
mapping = {}
|
||||
for site in sites:
|
||||
mapping[site.domain] = site.name
|
||||
@@ -411,16 +423,24 @@ def site_mapping(_: User = Depends(get_current_active_superuser)):
|
||||
return schemas.Response(success=False, message=f"获取映射失败:{str(e)}")
|
||||
|
||||
|
||||
@router.get("/supporting", summary="获取支持的站点列表", response_model=dict)
|
||||
async def support_sites(_: User = Depends(get_current_active_superuser_async)):
|
||||
"""
|
||||
获取支持的站点列表
|
||||
"""
|
||||
return SitesHelper().get_indexsites()
|
||||
|
||||
|
||||
@router.get("/{site_id}", summary="站点详情", response_model=schemas.Site)
|
||||
def read_site(
|
||||
async def read_site(
|
||||
site_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
_: schemas.TokenPayload = Depends(get_current_active_superuser)
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: User = Depends(get_current_active_superuser_async)
|
||||
) -> Any:
|
||||
"""
|
||||
通过ID获取站点信息
|
||||
"""
|
||||
site = Site.get(db, site_id)
|
||||
site = await Site.async_get(db, site_id)
|
||||
if not site:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
@@ -430,18 +450,18 @@ def read_site(
|
||||
|
||||
|
||||
@router.delete("/{site_id}", summary="删除站点", response_model=schemas.Response)
|
||||
def delete_site(
|
||||
async def delete_site(
|
||||
site_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
_: User = Depends(get_current_active_superuser)
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: User = Depends(get_current_active_superuser_async)
|
||||
) -> Any:
|
||||
"""
|
||||
删除站点
|
||||
"""
|
||||
Site.delete(db, site_id)
|
||||
await Site.async_delete(db, site_id)
|
||||
# 插件站点删除
|
||||
EventManager().send_event(EventType.SiteDeleted,
|
||||
{
|
||||
"site_id": site_id
|
||||
})
|
||||
await eventmanager.async_send_event(EventType.SiteDeleted,
|
||||
{
|
||||
"site_id": site_id
|
||||
})
|
||||
return schemas.Response(success=True)
|
||||
|
||||
@@ -12,9 +12,10 @@ from app.core.config import settings
|
||||
from app.core.metainfo import MetaInfoPath
|
||||
from app.core.security import verify_token
|
||||
from app.db.models import User
|
||||
from app.db.user_oper import get_current_active_superuser
|
||||
from app.db.user_oper import get_current_active_superuser, get_current_active_superuser_async
|
||||
from app.helper.progress import ProgressHelper
|
||||
from app.schemas.types import ProgressKey
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -80,7 +81,7 @@ def list_files(fileitem: schemas.FileItem,
|
||||
file_list = StorageChain().list_files(fileitem)
|
||||
if file_list:
|
||||
if sort == "name":
|
||||
file_list.sort(key=lambda x: x.name or "")
|
||||
file_list.sort(key=lambda x: StringUtils.natural_sort_key(x.name or ""))
|
||||
else:
|
||||
file_list.sort(key=lambda x: x.modify_time or datetime.min, reverse=True)
|
||||
return file_list
|
||||
@@ -171,15 +172,14 @@ def rename(fileitem: schemas.FileItem,
|
||||
sub_files: List[schemas.FileItem] = StorageChain().list_files(fileitem)
|
||||
if sub_files:
|
||||
# 开始进度
|
||||
progress = ProgressHelper()
|
||||
progress.start(ProgressKey.BatchRename)
|
||||
progress = ProgressHelper(ProgressKey.BatchRename)
|
||||
progress.start()
|
||||
total = len(sub_files)
|
||||
handled = 0
|
||||
for sub_file in sub_files:
|
||||
handled += 1
|
||||
progress.update(value=handled / total * 100,
|
||||
text=f"正在处理 {sub_file.name} ...",
|
||||
key=ProgressKey.BatchRename)
|
||||
text=f"正在处理 {sub_file.name} ...")
|
||||
if sub_file.type == "dir":
|
||||
continue
|
||||
if not sub_file.extension:
|
||||
@@ -190,19 +190,19 @@ def rename(fileitem: schemas.FileItem,
|
||||
meta = MetaInfoPath(sub_path)
|
||||
mediainfo = transferchain.recognize_media(meta)
|
||||
if not mediainfo:
|
||||
progress.end(ProgressKey.BatchRename)
|
||||
progress.end()
|
||||
return schemas.Response(success=False, message=f"{sub_path.name} 未识别到媒体信息")
|
||||
new_path = transferchain.recommend_name(meta=meta, mediainfo=mediainfo)
|
||||
if not new_path:
|
||||
progress.end(ProgressKey.BatchRename)
|
||||
progress.end()
|
||||
return schemas.Response(success=False, message=f"{sub_path.name} 未识别到新名称")
|
||||
ret: schemas.Response = rename(fileitem=sub_file,
|
||||
new_name=Path(new_path).name,
|
||||
recursive=False)
|
||||
if not ret.success:
|
||||
progress.end(ProgressKey.BatchRename)
|
||||
progress.end()
|
||||
return schemas.Response(success=False, message=f"{sub_path.name} 重命名失败!")
|
||||
progress.end(ProgressKey.BatchRename)
|
||||
progress.end()
|
||||
# 重命名自己
|
||||
result = StorageChain().rename_file(fileitem, new_name)
|
||||
if result:
|
||||
@@ -222,7 +222,7 @@ def usage(name: str, _: User = Depends(get_current_active_superuser)) -> Any:
|
||||
|
||||
|
||||
@router.get("/transtype/{name}", summary="支持的整理方式获取", response_model=schemas.StorageTransType)
|
||||
def transtype(name: str, _: User = Depends(get_current_active_superuser)) -> Any:
|
||||
async def transtype(name: str, _: User = Depends(get_current_active_superuser_async)) -> Any:
|
||||
"""
|
||||
查询支持的整理方式
|
||||
"""
|
||||
|
||||
@@ -2,6 +2,7 @@ from typing import List, Any, Annotated, Optional
|
||||
|
||||
import cn2an
|
||||
from fastapi import APIRouter, Request, BackgroundTasks, Depends, HTTPException, Header
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app import schemas
|
||||
@@ -11,12 +12,12 @@ from app.core.context import MediaInfo
|
||||
from app.core.event import eventmanager
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.core.security import verify_token, verify_apitoken
|
||||
from app.db import get_db
|
||||
from app.db import get_async_db, get_db
|
||||
from app.db.models.subscribe import Subscribe
|
||||
from app.db.models.subscribehistory import SubscribeHistory
|
||||
from app.db.models.user import User
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.db.user_oper import get_current_active_user
|
||||
from app.db.user_oper import get_current_active_user_async
|
||||
from app.helper.subscribe import SubscribeHelper
|
||||
from app.scheduler import Scheduler
|
||||
from app.schemas.types import MediaType, EventType, SystemConfigKey
|
||||
@@ -34,28 +35,28 @@ def start_subscribe_add(title: str, year: str,
|
||||
|
||||
|
||||
@router.get("/", summary="查询所有订阅", response_model=List[schemas.Subscribe])
|
||||
def read_subscribes(
|
||||
db: Session = Depends(get_db),
|
||||
async def read_subscribes(
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
查询所有订阅
|
||||
"""
|
||||
return Subscribe.list(db)
|
||||
return await Subscribe.async_list(db)
|
||||
|
||||
|
||||
@router.get("/list", summary="查询所有订阅(API_TOKEN)", response_model=List[schemas.Subscribe])
|
||||
def list_subscribes(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
|
||||
async def list_subscribes(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
|
||||
"""
|
||||
查询所有订阅 API_TOKEN认证(?token=xxx)
|
||||
"""
|
||||
return read_subscribes()
|
||||
return await read_subscribes()
|
||||
|
||||
|
||||
@router.post("/", summary="新增订阅", response_model=schemas.Response)
|
||||
def create_subscribe(
|
||||
async def create_subscribe(
|
||||
*,
|
||||
subscribe_in: schemas.Subscribe,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
current_user: User = Depends(get_current_active_user_async),
|
||||
) -> schemas.Response:
|
||||
"""
|
||||
新增订阅
|
||||
@@ -77,31 +78,35 @@ def create_subscribe(
|
||||
title = None
|
||||
# 订阅用户
|
||||
subscribe_in.username = current_user.name
|
||||
sid, message = SubscribeChain().add(mtype=mtype,
|
||||
title=title,
|
||||
exist_ok=True,
|
||||
**subscribe_in.dict())
|
||||
# 转化为字典
|
||||
subscribe_dict = subscribe_in.model_dump()
|
||||
if subscribe_in.id:
|
||||
subscribe_dict.pop("id", None)
|
||||
sid, message = await SubscribeChain().async_add(mtype=mtype,
|
||||
title=title,
|
||||
exist_ok=True,
|
||||
**subscribe_dict)
|
||||
return schemas.Response(
|
||||
success=bool(sid), message=message, data={"id": sid}
|
||||
)
|
||||
|
||||
|
||||
@router.put("/", summary="更新订阅", response_model=schemas.Response)
|
||||
def update_subscribe(
|
||||
async def update_subscribe(
|
||||
*,
|
||||
subscribe_in: schemas.Subscribe,
|
||||
db: Session = Depends(get_db),
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)
|
||||
) -> Any:
|
||||
"""
|
||||
更新订阅信息
|
||||
"""
|
||||
subscribe = Subscribe.get(db, subscribe_in.id)
|
||||
subscribe = await Subscribe.async_get(db, subscribe_in.id)
|
||||
if not subscribe:
|
||||
return schemas.Response(success=False, message="订阅不存在")
|
||||
# 避免更新缺失集数
|
||||
old_subscribe_dict = subscribe.to_dict()
|
||||
subscribe_dict = subscribe_in.dict()
|
||||
subscribe_dict = subscribe_in.model_dump()
|
||||
if not subscribe_in.lack_episode:
|
||||
# 没有缺失集数时,缺失集数清空,避免更新为0
|
||||
subscribe_dict.pop("lack_episode")
|
||||
@@ -114,50 +119,55 @@ def update_subscribe(
|
||||
# 是否手动修改过总集数
|
||||
if subscribe_in.total_episode != subscribe.total_episode:
|
||||
subscribe_dict["manual_total_episode"] = 1
|
||||
subscribe.update(db, subscribe_dict)
|
||||
# 更新到数据库
|
||||
await subscribe.async_update(db, subscribe_dict)
|
||||
# 重新获取更新后的订阅数据
|
||||
updated_subscribe = await Subscribe.async_get(db, subscribe_in.id)
|
||||
# 发送订阅调整事件
|
||||
eventmanager.send_event(EventType.SubscribeModified, {
|
||||
"subscribe_id": subscribe.id,
|
||||
await eventmanager.async_send_event(EventType.SubscribeModified, {
|
||||
"subscribe_id": subscribe_in.id,
|
||||
"old_subscribe_info": old_subscribe_dict,
|
||||
"subscribe_info": subscribe.to_dict(),
|
||||
"subscribe_info": updated_subscribe.to_dict() if updated_subscribe else {},
|
||||
})
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.put("/status/{subid}", summary="更新订阅状态", response_model=schemas.Response)
|
||||
def update_subscribe_status(
|
||||
async def update_subscribe_status(
|
||||
subid: int,
|
||||
state: str,
|
||||
db: Session = Depends(get_db),
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
更新订阅状态
|
||||
"""
|
||||
subscribe = Subscribe.get(db, subid)
|
||||
subscribe = await Subscribe.async_get(db, subid)
|
||||
if not subscribe:
|
||||
return schemas.Response(success=False, message="订阅不存在")
|
||||
valid_states = ["R", "P", "S"]
|
||||
if state not in valid_states:
|
||||
return schemas.Response(success=False, message="无效的订阅状态")
|
||||
old_subscribe_dict = subscribe.to_dict()
|
||||
subscribe.update(db, {
|
||||
await subscribe.async_update(db, {
|
||||
"state": state
|
||||
})
|
||||
# 重新获取更新后的订阅数据
|
||||
updated_subscribe = await Subscribe.async_get(db, subid)
|
||||
# 发送订阅调整事件
|
||||
eventmanager.send_event(EventType.SubscribeModified, {
|
||||
"subscribe_id": subscribe.id,
|
||||
await eventmanager.async_send_event(EventType.SubscribeModified, {
|
||||
"subscribe_id": subid,
|
||||
"old_subscribe_info": old_subscribe_dict,
|
||||
"subscribe_info": subscribe.to_dict(),
|
||||
"subscribe_info": updated_subscribe.to_dict() if updated_subscribe else {},
|
||||
})
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.get("/media/{mediaid}", summary="查询订阅", response_model=schemas.Subscribe)
|
||||
def subscribe_mediaid(
|
||||
async def subscribe_mediaid(
|
||||
mediaid: str,
|
||||
season: Optional[int] = None,
|
||||
title: Optional[str] = None,
|
||||
db: Session = Depends(get_db),
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
根据 TMDBID/豆瓣ID/BangumiId 查询订阅 tmdb:/douban:
|
||||
@@ -167,23 +177,23 @@ def subscribe_mediaid(
|
||||
tmdbid = mediaid[5:]
|
||||
if not tmdbid or not str(tmdbid).isdigit():
|
||||
return Subscribe()
|
||||
result = Subscribe.exists(db, tmdbid=int(tmdbid), season=season)
|
||||
result = await Subscribe.async_exists(db, tmdbid=int(tmdbid), season=season)
|
||||
elif mediaid.startswith("douban:"):
|
||||
doubanid = mediaid[7:]
|
||||
if not doubanid:
|
||||
return Subscribe()
|
||||
result = Subscribe.get_by_doubanid(db, doubanid)
|
||||
result = await Subscribe.async_get_by_doubanid(db, doubanid)
|
||||
if not result and title:
|
||||
title_check = True
|
||||
elif mediaid.startswith("bangumi:"):
|
||||
bangumiid = mediaid[8:]
|
||||
if not bangumiid or not str(bangumiid).isdigit():
|
||||
return Subscribe()
|
||||
result = Subscribe.get_by_bangumiid(db, int(bangumiid))
|
||||
result = await Subscribe.async_get_by_bangumiid(db, int(bangumiid))
|
||||
if not result and title:
|
||||
title_check = True
|
||||
else:
|
||||
result = Subscribe.get_by_mediaid(db, mediaid)
|
||||
result = await Subscribe.async_get_by_mediaid(db, mediaid)
|
||||
if not result and title:
|
||||
title_check = True
|
||||
# 使用名称检查订阅
|
||||
@@ -191,7 +201,7 @@ def subscribe_mediaid(
|
||||
meta = MetaInfo(title)
|
||||
if season:
|
||||
meta.begin_season = season
|
||||
result = Subscribe.get_by_title(db, title=meta.name, season=meta.begin_season)
|
||||
result = await Subscribe.async_get_by_title(db, title=meta.name, season=meta.begin_season)
|
||||
|
||||
return result if result else Subscribe()
|
||||
|
||||
@@ -207,26 +217,30 @@ def refresh_subscribes(
|
||||
|
||||
|
||||
@router.get("/reset/{subid}", summary="重置订阅", response_model=schemas.Response)
|
||||
def reset_subscribes(
|
||||
async def reset_subscribes(
|
||||
subid: int,
|
||||
db: Session = Depends(get_db),
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
重置订阅
|
||||
"""
|
||||
subscribe = Subscribe.get(db, subid)
|
||||
subscribe = await Subscribe.async_get(db, subid)
|
||||
if subscribe:
|
||||
# 在更新之前获取旧数据
|
||||
old_subscribe_dict = subscribe.to_dict()
|
||||
subscribe.update(db, {
|
||||
# 更新订阅
|
||||
await subscribe.async_update(db, {
|
||||
"note": [],
|
||||
"lack_episode": subscribe.total_episode,
|
||||
"state": "R"
|
||||
})
|
||||
# 重新获取更新后的订阅数据
|
||||
updated_subscribe = await Subscribe.async_get(db, subid)
|
||||
# 发送订阅调整事件
|
||||
eventmanager.send_event(EventType.SubscribeModified, {
|
||||
"subscribe_id": subscribe.id,
|
||||
await eventmanager.async_send_event(EventType.SubscribeModified, {
|
||||
"subscribe_id": subid,
|
||||
"old_subscribe_info": old_subscribe_dict,
|
||||
"subscribe_info": subscribe.to_dict(),
|
||||
"subscribe_info": updated_subscribe.to_dict() if updated_subscribe else {},
|
||||
})
|
||||
return schemas.Response(success=True)
|
||||
return schemas.Response(success=False, message="订阅不存在")
|
||||
@@ -243,7 +257,7 @@ def check_subscribes(
|
||||
|
||||
|
||||
@router.get("/search", summary="搜索所有订阅", response_model=schemas.Response)
|
||||
def search_subscribes(
|
||||
async def search_subscribes(
|
||||
background_tasks: BackgroundTasks,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
@@ -262,7 +276,7 @@ def search_subscribes(
|
||||
|
||||
|
||||
@router.get("/search/{subscribe_id}", summary="搜索订阅", response_model=schemas.Response)
|
||||
def search_subscribe(
|
||||
async def search_subscribe(
|
||||
subscribe_id: int,
|
||||
background_tasks: BackgroundTasks,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
@@ -282,10 +296,10 @@ def search_subscribe(
|
||||
|
||||
|
||||
@router.delete("/media/{mediaid}", summary="删除订阅", response_model=schemas.Response)
|
||||
def delete_subscribe_by_mediaid(
|
||||
async def delete_subscribe_by_mediaid(
|
||||
mediaid: str,
|
||||
season: Optional[int] = None,
|
||||
db: Session = Depends(get_db),
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)
|
||||
) -> Any:
|
||||
"""
|
||||
@@ -296,25 +310,28 @@ def delete_subscribe_by_mediaid(
|
||||
tmdbid = mediaid[5:]
|
||||
if not tmdbid or not str(tmdbid).isdigit():
|
||||
return schemas.Response(success=False)
|
||||
subscribes = Subscribe().get_by_tmdbid(db, int(tmdbid), season)
|
||||
subscribes = await Subscribe.async_get_by_tmdbid(db, int(tmdbid), season)
|
||||
delete_subscribes.extend(subscribes)
|
||||
elif mediaid.startswith("douban:"):
|
||||
doubanid = mediaid[7:]
|
||||
if not doubanid:
|
||||
return schemas.Response(success=False)
|
||||
subscribe = Subscribe().get_by_doubanid(db, doubanid)
|
||||
subscribe = await Subscribe.async_get_by_doubanid(db, doubanid)
|
||||
if subscribe:
|
||||
delete_subscribes.append(subscribe)
|
||||
else:
|
||||
subscribe = Subscribe().get_by_mediaid(db, mediaid)
|
||||
subscribe = await Subscribe.async_get_by_mediaid(db, mediaid)
|
||||
if subscribe:
|
||||
delete_subscribes.append(subscribe)
|
||||
for subscribe in delete_subscribes:
|
||||
Subscribe().delete(db, subscribe.id)
|
||||
# 在删除之前获取订阅信息
|
||||
subscribe_info = subscribe.to_dict()
|
||||
subscribe_id = subscribe.id
|
||||
await Subscribe.async_delete(db, subscribe_id)
|
||||
# 发送事件
|
||||
eventmanager.send_event(EventType.SubscribeDeleted, {
|
||||
"subscribe_id": subscribe.id,
|
||||
"subscribe_info": subscribe.to_dict()
|
||||
await eventmanager.async_send_event(EventType.SubscribeDeleted, {
|
||||
"subscribe_id": subscribe_id,
|
||||
"subscribe_info": subscribe_info
|
||||
})
|
||||
return schemas.Response(success=True)
|
||||
|
||||
@@ -373,42 +390,54 @@ async def seerr_subscribe(request: Request, background_tasks: BackgroundTasks,
|
||||
|
||||
|
||||
@router.get("/history/{mtype}", summary="查询订阅历史", response_model=List[schemas.Subscribe])
|
||||
def subscribe_history(
|
||||
async def subscribe_history(
|
||||
mtype: str,
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
db: Session = Depends(get_db),
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
查询电影/电视剧订阅历史
|
||||
"""
|
||||
return SubscribeHistory.list_by_type(db, mtype=mtype, page=page, count=count)
|
||||
return await SubscribeHistory.async_list_by_type(db, mtype=mtype, page=page, count=count)
|
||||
|
||||
|
||||
@router.delete("/history/{history_id}", summary="删除订阅历史", response_model=schemas.Response)
|
||||
def delete_subscribe(
|
||||
async def delete_subscribe(
|
||||
history_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)
|
||||
) -> Any:
|
||||
"""
|
||||
删除订阅历史
|
||||
"""
|
||||
SubscribeHistory.delete(db, history_id)
|
||||
await SubscribeHistory.async_delete(db, history_id)
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.get("/popular", summary="热门订阅(基于用户共享数据)", response_model=List[schemas.MediaInfo])
|
||||
def popular_subscribes(
|
||||
async def popular_subscribes(
|
||||
stype: str,
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
min_sub: Optional[int] = None,
|
||||
genre_id: Optional[int] = None,
|
||||
min_rating: Optional[float] = None,
|
||||
max_rating: Optional[float] = None,
|
||||
sort_type: Optional[str] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
查询热门订阅
|
||||
"""
|
||||
subscribes = SubscribeHelper().get_statistic(stype=stype, page=page, count=count)
|
||||
subscribes = await SubscribeHelper().async_get_statistic(
|
||||
stype=stype,
|
||||
page=page,
|
||||
count=count,
|
||||
genre_id=genre_id,
|
||||
min_rating=min_rating,
|
||||
max_rating=max_rating,
|
||||
sort_type=sort_type
|
||||
)
|
||||
if subscribes:
|
||||
ret_medias = []
|
||||
for sub in subscribes:
|
||||
@@ -444,14 +473,14 @@ def popular_subscribes(
|
||||
|
||||
|
||||
@router.get("/user/{username}", summary="用户订阅", response_model=List[schemas.Subscribe])
|
||||
def user_subscribes(
|
||||
async def user_subscribes(
|
||||
username: str,
|
||||
db: Session = Depends(get_db),
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
查询用户订阅
|
||||
"""
|
||||
return Subscribe.list_by_username(db, username)
|
||||
return await Subscribe.async_list_by_username(db, username)
|
||||
|
||||
|
||||
@router.get("/files/{subscribe_id}", summary="订阅相关文件信息", response_model=schemas.SubscrbieInfo)
|
||||
@@ -469,51 +498,51 @@ def subscribe_files(
|
||||
|
||||
|
||||
@router.post("/share", summary="分享订阅", response_model=schemas.Response)
|
||||
def subscribe_share(
|
||||
async def subscribe_share(
|
||||
sub: schemas.SubscribeShare,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
分享订阅
|
||||
"""
|
||||
state, errmsg = SubscribeHelper().sub_share(subscribe_id=sub.subscribe_id,
|
||||
share_title=sub.share_title,
|
||||
share_comment=sub.share_comment,
|
||||
share_user=sub.share_user)
|
||||
state, errmsg = await SubscribeHelper().async_sub_share(subscribe_id=sub.subscribe_id,
|
||||
share_title=sub.share_title,
|
||||
share_comment=sub.share_comment,
|
||||
share_user=sub.share_user)
|
||||
return schemas.Response(success=state, message=errmsg)
|
||||
|
||||
|
||||
@router.delete("/share/{share_id}", summary="删除分享", response_model=schemas.Response)
|
||||
def subscribe_share_delete(
|
||||
async def subscribe_share_delete(
|
||||
share_id: int,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
删除分享
|
||||
"""
|
||||
state, errmsg = SubscribeHelper().share_delete(share_id=share_id)
|
||||
state, errmsg = await SubscribeHelper().async_share_delete(share_id=share_id)
|
||||
return schemas.Response(success=state, message=errmsg)
|
||||
|
||||
|
||||
@router.post("/fork", summary="复用订阅", response_model=schemas.Response)
|
||||
def subscribe_fork(
|
||||
async def subscribe_fork(
|
||||
sub: schemas.SubscribeShare,
|
||||
current_user: User = Depends(get_current_active_user)) -> Any:
|
||||
current_user: User = Depends(get_current_active_user_async)) -> Any:
|
||||
"""
|
||||
复用订阅
|
||||
"""
|
||||
sub_dict = sub.dict()
|
||||
sub_dict = sub.model_dump()
|
||||
sub_dict.pop("id")
|
||||
for key in list(sub_dict.keys()):
|
||||
if not hasattr(schemas.Subscribe(), key):
|
||||
sub_dict.pop(key)
|
||||
result = create_subscribe(subscribe_in=schemas.Subscribe(**sub_dict),
|
||||
current_user=current_user)
|
||||
result = await create_subscribe(subscribe_in=schemas.Subscribe(**sub_dict),
|
||||
current_user=current_user)
|
||||
if result.success:
|
||||
SubscribeHelper().sub_fork(share_id=sub.id)
|
||||
await SubscribeHelper().async_sub_fork(share_id=sub.id)
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/follow", summary="查询已Follow的订阅分享人", response_model=List[str])
|
||||
def followed_subscribers(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def followed_subscribers(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
查询已Follow的订阅分享人
|
||||
"""
|
||||
@@ -521,7 +550,7 @@ def followed_subscribers(_: schemas.TokenPayload = Depends(verify_token)) -> Any
|
||||
|
||||
|
||||
@router.post("/follow", summary="Follow订阅分享人", response_model=schemas.Response)
|
||||
def follow_subscriber(
|
||||
async def follow_subscriber(
|
||||
share_uid: Optional[str] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
@@ -530,12 +559,12 @@ def follow_subscriber(
|
||||
subscribers = SystemConfigOper().get(SystemConfigKey.FollowSubscribers) or []
|
||||
if share_uid and share_uid not in subscribers:
|
||||
subscribers.append(share_uid)
|
||||
SystemConfigOper().set(SystemConfigKey.FollowSubscribers, subscribers)
|
||||
await SystemConfigOper().async_set(SystemConfigKey.FollowSubscribers, subscribers)
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.delete("/follow", summary="取消Follow订阅分享人", response_model=schemas.Response)
|
||||
def unfollow_subscriber(
|
||||
async def unfollow_subscriber(
|
||||
share_uid: Optional[str] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
@@ -544,51 +573,74 @@ def unfollow_subscriber(
|
||||
subscribers = SystemConfigOper().get(SystemConfigKey.FollowSubscribers) or []
|
||||
if share_uid and share_uid in subscribers:
|
||||
subscribers.remove(share_uid)
|
||||
SystemConfigOper().set(SystemConfigKey.FollowSubscribers, subscribers)
|
||||
await SystemConfigOper().async_set(SystemConfigKey.FollowSubscribers, subscribers)
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.get("/shares", summary="查询分享的订阅", response_model=List[schemas.SubscribeShare])
|
||||
def popular_subscribes(
|
||||
async def popular_subscribes(
|
||||
name: Optional[str] = None,
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
genre_id: Optional[int] = None,
|
||||
min_rating: Optional[float] = None,
|
||||
max_rating: Optional[float] = None,
|
||||
sort_type: Optional[str] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
查询分享的订阅
|
||||
"""
|
||||
return SubscribeHelper().get_shares(name=name, page=page, count=count)
|
||||
return await SubscribeHelper().async_get_shares(
|
||||
name=name,
|
||||
page=page,
|
||||
count=count,
|
||||
genre_id=genre_id,
|
||||
min_rating=min_rating,
|
||||
max_rating=max_rating,
|
||||
sort_type=sort_type
|
||||
)
|
||||
|
||||
|
||||
@router.get("/share/statistics", summary="查询订阅分享统计", response_model=List[schemas.SubscribeShareStatistics])
|
||||
async def subscribe_share_statistics(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
查询订阅分享统计
|
||||
返回每个分享人分享的媒体数量以及总的复用人次
|
||||
"""
|
||||
return await SubscribeHelper().async_get_share_statistics()
|
||||
|
||||
|
||||
@router.get("/{subscribe_id}", summary="订阅详情", response_model=schemas.Subscribe)
|
||||
def read_subscribe(
|
||||
async def read_subscribe(
|
||||
subscribe_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
根据订阅编号查询订阅信息
|
||||
"""
|
||||
if not subscribe_id:
|
||||
return Subscribe()
|
||||
return Subscribe.get(db, subscribe_id)
|
||||
return await Subscribe.async_get(db, subscribe_id)
|
||||
|
||||
|
||||
@router.delete("/{subscribe_id}", summary="删除订阅", response_model=schemas.Response)
|
||||
def delete_subscribe(
|
||||
async def delete_subscribe(
|
||||
subscribe_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)
|
||||
) -> Any:
|
||||
"""
|
||||
删除订阅信息
|
||||
"""
|
||||
subscribe = Subscribe.get(db, subscribe_id)
|
||||
subscribe = await Subscribe.async_get(db, subscribe_id)
|
||||
if subscribe:
|
||||
subscribe.delete(db, subscribe_id)
|
||||
# 在删除之前获取订阅信息
|
||||
subscribe_info = subscribe.to_dict()
|
||||
await Subscribe.async_delete(db, subscribe_id)
|
||||
# 发送事件
|
||||
eventmanager.send_event(EventType.SubscribeDeleted, {
|
||||
await eventmanager.async_send_event(EventType.SubscribeDeleted, {
|
||||
"subscribe_id": subscribe_id,
|
||||
"subscribe_info": subscribe.to_dict()
|
||||
"subscribe_info": subscribe_info
|
||||
})
|
||||
# 统计订阅
|
||||
SubscribeHelper().sub_done_async({
|
||||
|
||||
@@ -2,7 +2,6 @@ import asyncio
|
||||
import io
|
||||
import json
|
||||
import re
|
||||
import tempfile
|
||||
from collections import deque
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
@@ -11,25 +10,29 @@ from typing import Optional, Union, Annotated
|
||||
import aiofiles
|
||||
import pillow_avif # noqa 用于自动注册AVIF支持
|
||||
from PIL import Image
|
||||
from anyio import Path as AsyncPath
|
||||
from app.helper.sites import SitesHelper # noqa # noqa
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Header, Request, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from app import schemas
|
||||
from app.chain.mediaserver import MediaServerChain
|
||||
from app.chain.search import SearchChain
|
||||
from app.chain.system import SystemChain
|
||||
from app.core.cache import AsyncFileCache
|
||||
from app.core.config import global_vars, settings
|
||||
from app.core.event import eventmanager
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.core.module import ModuleManager
|
||||
from app.core.security import verify_apitoken, verify_resource_token, verify_token
|
||||
from app.core.event import eventmanager
|
||||
from app.db.models import User
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.db.user_oper import get_current_active_superuser
|
||||
from app.db.user_oper import get_current_active_superuser, get_current_active_superuser_async, \
|
||||
get_current_active_user_async
|
||||
from app.helper.mediaserver import MediaServerHelper
|
||||
from app.helper.message import MessageHelper
|
||||
from app.helper.progress import ProgressHelper
|
||||
from app.helper.rule import RuleHelper
|
||||
from app.helper.sites import SitesHelper
|
||||
from app.helper.subscribe import SubscribeHelper
|
||||
from app.helper.system import SystemHelper
|
||||
from app.log import logger
|
||||
@@ -37,7 +40,7 @@ from app.scheduler import Scheduler
|
||||
from app.schemas import ConfigChangeEventData
|
||||
from app.schemas.types import SystemConfigKey, EventType
|
||||
from app.utils.crypto import HashUtils
|
||||
from app.utils.http import RequestUtils
|
||||
from app.utils.http import RequestUtils, AsyncRequestUtils
|
||||
from app.utils.security import SecurityUtils
|
||||
from app.utils.url import UrlUtils
|
||||
from version import APP_VERSION
|
||||
@@ -45,85 +48,84 @@ from version import APP_VERSION
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def fetch_image(
|
||||
async def fetch_image(
|
||||
url: str,
|
||||
proxy: bool = False,
|
||||
use_disk_cache: bool = False,
|
||||
use_cache: bool = False,
|
||||
if_none_match: Optional[str] = None,
|
||||
allowed_domains: Optional[set[str]] = None) -> Response:
|
||||
cookies: Optional[str | dict] = None,
|
||||
allowed_domains: Optional[set[str]] = None) -> Optional[Response]:
|
||||
"""
|
||||
处理图片缓存逻辑,支持HTTP缓存和磁盘缓存
|
||||
"""
|
||||
if not url:
|
||||
raise HTTPException(status_code=404, detail="URL not provided")
|
||||
return None
|
||||
|
||||
if allowed_domains is None:
|
||||
allowed_domains = set(settings.SECURITY_IMAGE_DOMAINS)
|
||||
|
||||
# 验证URL安全性
|
||||
if not SecurityUtils.is_safe_url(url, allowed_domains):
|
||||
raise HTTPException(status_code=404, detail="Unsafe URL")
|
||||
|
||||
# 后续观察系统性能表现,如果发现磁盘缓存和HTTP缓存无法满足高并发情况下的响应速度需求,可以考虑重新引入内存缓存
|
||||
cache_path = None
|
||||
if use_disk_cache:
|
||||
# 生成缓存路径
|
||||
sanitized_path = SecurityUtils.sanitize_url_path(url)
|
||||
cache_path = settings.CACHE_PATH / "images" / sanitized_path
|
||||
logger.warn(f"Blocked unsafe image URL: {url}")
|
||||
return None
|
||||
|
||||
# 缓存路径
|
||||
sanitized_path = SecurityUtils.sanitize_url_path(url)
|
||||
cache_path = Path("images") / sanitized_path
|
||||
if not cache_path.suffix:
|
||||
# 没有文件类型,则添加后缀,在恶意文件类型和实际需求下的折衷选择
|
||||
if not cache_path.suffix:
|
||||
cache_path = cache_path.with_suffix(".jpg")
|
||||
cache_path = cache_path.with_suffix(".jpg")
|
||||
|
||||
# 确保缓存路径和文件类型合法
|
||||
if not SecurityUtils.is_safe_path(settings.CACHE_PATH, cache_path, settings.SECURITY_IMAGE_SUFFIXES):
|
||||
raise HTTPException(status_code=400, detail="Invalid cache path or file type")
|
||||
# 缓存对像,缓存过期时间为全局图片缓存天数
|
||||
cache_backend = AsyncFileCache(base=settings.CACHE_PATH,
|
||||
ttl=settings.GLOBAL_IMAGE_CACHE_DAYS * 24 * 3600)
|
||||
|
||||
# 目前暂不考虑磁盘缓存文件是否过期,后续通过缓存清理机制处理
|
||||
if cache_path.exists():
|
||||
try:
|
||||
content = cache_path.read_bytes()
|
||||
etag = HashUtils.md5(content)
|
||||
headers = RequestUtils.generate_cache_headers(etag, max_age=86400 * 7)
|
||||
if if_none_match == etag:
|
||||
return Response(status_code=304, headers=headers)
|
||||
return Response(content=content, media_type="image/jpeg", headers=headers)
|
||||
except Exception as e:
|
||||
# 如果读取磁盘缓存发生异常,这里仅记录日志,尝试再次请求远端进行处理
|
||||
logger.debug(f"Failed to read cache file {cache_path}: {e}")
|
||||
if use_cache:
|
||||
content = await cache_backend.get(cache_path.as_posix(), region="images")
|
||||
if content:
|
||||
# 检查 If-None-Match
|
||||
etag = HashUtils.md5(content)
|
||||
headers = RequestUtils.generate_cache_headers(etag, max_age=86400 * 7)
|
||||
if if_none_match == etag:
|
||||
return Response(status_code=304, headers=headers)
|
||||
# 返回缓存图片
|
||||
return Response(
|
||||
content=content,
|
||||
media_type=UrlUtils.get_mime_type(url, "image/jpeg"),
|
||||
headers=headers
|
||||
)
|
||||
|
||||
# 请求远程图片
|
||||
referer = "https://movie.douban.com/" if "doubanio.com" in url else None
|
||||
proxies = settings.PROXY if proxy else None
|
||||
response = RequestUtils(ua=settings.USER_AGENT, proxies=proxies, referer=referer,
|
||||
accept_type="image/avif,image/webp,image/apng,*/*").get_res(url=url)
|
||||
response = await AsyncRequestUtils(
|
||||
ua=settings.NORMAL_USER_AGENT,
|
||||
proxies=proxies,
|
||||
referer=referer,
|
||||
cookies=cookies,
|
||||
accept_type="image/avif,image/webp,image/apng,*/*",
|
||||
).get_res(url=url)
|
||||
if not response:
|
||||
raise HTTPException(status_code=502, detail="Failed to fetch the image from the remote server")
|
||||
logger.warn(f"Failed to fetch image from URL: {url}")
|
||||
return None
|
||||
|
||||
# 验证下载的内容是否为有效图片
|
||||
try:
|
||||
Image.open(io.BytesIO(response.content)).verify()
|
||||
content = response.content
|
||||
Image.open(io.BytesIO(content)).verify()
|
||||
except Exception as e:
|
||||
logger.debug(f"Invalid image format for URL {url}: {e}")
|
||||
raise HTTPException(status_code=502, detail="Invalid image format")
|
||||
logger.warn(f"Invalid image format for URL {url}: {e}")
|
||||
return None
|
||||
|
||||
content = response.content
|
||||
# 获取请求响应头
|
||||
response_headers = response.headers
|
||||
|
||||
cache_control_header = response_headers.get("Cache-Control", "")
|
||||
cache_directive, max_age = RequestUtils.parse_cache_control(cache_control_header)
|
||||
|
||||
# 如果需要使用磁盘缓存,则保存到磁盘
|
||||
if use_disk_cache and cache_path:
|
||||
try:
|
||||
if not cache_path.parent.exists():
|
||||
cache_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with tempfile.NamedTemporaryFile(dir=cache_path.parent, delete=False) as tmp_file:
|
||||
tmp_file.write(content)
|
||||
temp_path = Path(tmp_file.name)
|
||||
temp_path.replace(cache_path)
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to write cache file {cache_path}: {e}")
|
||||
# 保存缓存
|
||||
if use_cache:
|
||||
await cache_backend.set(cache_path.as_posix(), content, region="images")
|
||||
logger.debug(f"Image cached at {cache_path.as_posix()}")
|
||||
|
||||
# 检查 If-None-Match
|
||||
etag = HashUtils.md5(content)
|
||||
@@ -131,8 +133,8 @@ def fetch_image(
|
||||
headers = RequestUtils.generate_cache_headers(etag, cache_directive, max_age)
|
||||
return Response(status_code=304, headers=headers)
|
||||
|
||||
# 响应
|
||||
headers = RequestUtils.generate_cache_headers(etag, cache_directive, max_age)
|
||||
|
||||
return Response(
|
||||
content=content,
|
||||
media_type=response_headers.get("Content-Type") or UrlUtils.get_mime_type(url, "image/jpeg"),
|
||||
@@ -141,9 +143,11 @@ def fetch_image(
|
||||
|
||||
|
||||
@router.get("/img/{proxy}", summary="图片代理")
|
||||
def proxy_img(
|
||||
async def proxy_img(
|
||||
imgurl: str,
|
||||
proxy: bool = False,
|
||||
cache: bool = False,
|
||||
use_cookies: bool = False,
|
||||
if_none_match: Annotated[str | None, Header()] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token)
|
||||
) -> Response:
|
||||
@@ -154,12 +158,17 @@ def proxy_img(
|
||||
hosts = [config.config.get("host") for config in MediaServerHelper().get_configs().values() if
|
||||
config and config.config and config.config.get("host")]
|
||||
allowed_domains = set(settings.SECURITY_IMAGE_DOMAINS) | set(hosts)
|
||||
return fetch_image(url=imgurl, proxy=proxy, use_disk_cache=False,
|
||||
if_none_match=if_none_match, allowed_domains=allowed_domains)
|
||||
cookies = (
|
||||
MediaServerChain().get_image_cookies(server=None, image_url=imgurl)
|
||||
if use_cookies
|
||||
else None
|
||||
)
|
||||
return await fetch_image(url=imgurl, proxy=proxy, use_cache=cache, cookies=cookies,
|
||||
if_none_match=if_none_match, allowed_domains=allowed_domains)
|
||||
|
||||
|
||||
@router.get("/cache/image", summary="图片缓存")
|
||||
def cache_img(
|
||||
async def cache_img(
|
||||
url: str,
|
||||
if_none_match: Annotated[str | None, Header()] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token)
|
||||
@@ -169,7 +178,8 @@ def cache_img(
|
||||
"""
|
||||
# 如果没有启用全局图片缓存,则不使用磁盘缓存
|
||||
proxy = "doubanio.com" not in url
|
||||
return fetch_image(url=url, proxy=proxy, use_disk_cache=settings.GLOBAL_IMAGE_CACHE, if_none_match=if_none_match)
|
||||
return await fetch_image(url=url, proxy=proxy, use_cache=settings.GLOBAL_IMAGE_CACHE,
|
||||
if_none_match=if_none_match)
|
||||
|
||||
|
||||
@router.get("/global", summary="查询非敏感系统设置", response_model=schemas.Response)
|
||||
@@ -181,25 +191,28 @@ def get_global_setting(token: str):
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
|
||||
# FIXME: 新增敏感配置项时要在此处添加排除项
|
||||
info = settings.dict(
|
||||
info = settings.model_dump(
|
||||
exclude={"SECRET_KEY", "RESOURCE_SECRET_KEY", "API_TOKEN", "TMDB_API_KEY", "TVDB_API_KEY", "FANART_API_KEY",
|
||||
"COOKIECLOUD_KEY", "COOKIECLOUD_PASSWORD", "GITHUB_TOKEN", "REPO_GITHUB_TOKEN"}
|
||||
"COOKIECLOUD_KEY", "COOKIECLOUD_PASSWORD", "GITHUB_TOKEN", "REPO_GITHUB_TOKEN", "U115_APP_ID",
|
||||
"ALIPAN_APP_ID", "TVDB_V4_API_KEY", "TVDB_V4_API_PIN"}
|
||||
)
|
||||
# 追加用户唯一ID和订阅分享管理权限
|
||||
share_admin = SubscribeHelper().is_admin_user()
|
||||
info.update({
|
||||
"USER_UNIQUE_ID": SubscribeHelper().get_user_uuid(),
|
||||
"SUBSCRIBE_SHARE_MANAGE": SubscribeHelper().is_admin_user(),
|
||||
"SUBSCRIBE_SHARE_MANAGE": share_admin,
|
||||
"WORKFLOW_SHARE_MANAGE": share_admin
|
||||
})
|
||||
return schemas.Response(success=True,
|
||||
data=info)
|
||||
|
||||
|
||||
@router.get("/env", summary="查询系统配置", response_model=schemas.Response)
|
||||
def get_env_setting(_: User = Depends(get_current_active_superuser)):
|
||||
async def get_env_setting(_: User = Depends(get_current_active_user_async)):
|
||||
"""
|
||||
查询系统环境变量,包括当前版本号(仅管理员)
|
||||
"""
|
||||
info = settings.dict(
|
||||
info = settings.model_dump(
|
||||
exclude={"SECRET_KEY", "RESOURCE_SECRET_KEY"}
|
||||
)
|
||||
info.update({
|
||||
@@ -213,8 +226,8 @@ def get_env_setting(_: User = Depends(get_current_active_superuser)):
|
||||
|
||||
|
||||
@router.post("/env", summary="更新系统配置", response_model=schemas.Response)
|
||||
def set_env_setting(env: dict,
|
||||
_: User = Depends(get_current_active_superuser)):
|
||||
async def set_env_setting(env: dict,
|
||||
_: User = Depends(get_current_active_superuser_async)):
|
||||
"""
|
||||
更新系统环境变量(仅管理员)
|
||||
"""
|
||||
@@ -236,7 +249,7 @@ def set_env_setting(env: dict,
|
||||
if success_updates:
|
||||
for key in success_updates.keys():
|
||||
# 发送配置变更事件
|
||||
eventmanager.send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData(
|
||||
await eventmanager.async_send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData(
|
||||
key=key,
|
||||
value=getattr(settings, key, None),
|
||||
change_type="update"
|
||||
@@ -256,16 +269,16 @@ async def get_progress(request: Request, process_type: str, _: schemas.TokenPayl
|
||||
"""
|
||||
实时获取处理进度,返回格式为SSE
|
||||
"""
|
||||
progress = ProgressHelper()
|
||||
progress = ProgressHelper(process_type)
|
||||
|
||||
async def event_generator():
|
||||
try:
|
||||
while not global_vars.is_system_stopped:
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
detail = progress.get(process_type)
|
||||
detail = progress.get()
|
||||
yield f"data: {json.dumps(detail)}\n\n"
|
||||
await asyncio.sleep(0.2)
|
||||
await asyncio.sleep(0.5)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
|
||||
@@ -273,8 +286,8 @@ async def get_progress(request: Request, process_type: str, _: schemas.TokenPayl
|
||||
|
||||
|
||||
@router.get("/setting/{key}", summary="查询系统设置", response_model=schemas.Response)
|
||||
def get_setting(key: str,
|
||||
_: User = Depends(get_current_active_superuser)):
|
||||
async def get_setting(key: str,
|
||||
_: User = Depends(get_current_active_user_async)):
|
||||
"""
|
||||
查询系统设置(仅管理员)
|
||||
"""
|
||||
@@ -288,10 +301,10 @@ def get_setting(key: str,
|
||||
|
||||
|
||||
@router.post("/setting/{key}", summary="更新系统设置", response_model=schemas.Response)
|
||||
def set_setting(
|
||||
key: str,
|
||||
value: Annotated[Union[list, dict, bool, int, str] | None, Body()] = None,
|
||||
_: User = Depends(get_current_active_superuser),
|
||||
async def set_setting(
|
||||
key: str,
|
||||
value: Annotated[Union[list, dict, bool, int, str] | None, Body()] = None,
|
||||
_: User = Depends(get_current_active_superuser_async),
|
||||
):
|
||||
"""
|
||||
更新系统设置(仅管理员)
|
||||
@@ -300,7 +313,7 @@ def set_setting(
|
||||
success, message = settings.update_setting(key=key, value=value)
|
||||
if success:
|
||||
# 发送配置变更事件
|
||||
eventmanager.send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData(
|
||||
await eventmanager.async_send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData(
|
||||
key=key,
|
||||
value=value,
|
||||
change_type="update"
|
||||
@@ -312,10 +325,10 @@ def set_setting(
|
||||
if isinstance(value, list):
|
||||
value = list(filter(None, value))
|
||||
value = value if value else None
|
||||
success = SystemConfigOper().set(key, value)
|
||||
success = await SystemConfigOper().async_set(key, value)
|
||||
if success:
|
||||
# 发送配置变更事件
|
||||
eventmanager.send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData(
|
||||
await eventmanager.async_send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData(
|
||||
key=key,
|
||||
value=value,
|
||||
change_type="update"
|
||||
@@ -355,60 +368,106 @@ async def get_logging(request: Request, length: Optional[int] = 50, logfile: Opt
|
||||
length = -1 时, 返回text/plain
|
||||
否则 返回格式SSE
|
||||
"""
|
||||
log_path = settings.LOG_PATH / logfile
|
||||
base_path = AsyncPath(settings.LOG_PATH)
|
||||
log_path = base_path / logfile
|
||||
|
||||
if not SecurityUtils.is_safe_path(settings.LOG_PATH, log_path, allowed_suffixes={".log"}):
|
||||
if not await SecurityUtils.async_is_safe_path(base_path=base_path, user_path=log_path, allowed_suffixes={".log"}):
|
||||
raise HTTPException(status_code=404, detail="Not Found")
|
||||
|
||||
if not log_path.exists() or not log_path.is_file():
|
||||
if not await log_path.exists() or not await log_path.is_file():
|
||||
raise HTTPException(status_code=404, detail="Not Found")
|
||||
|
||||
async def log_generator():
|
||||
try:
|
||||
# 使用固定大小的双向队列来限制内存使用
|
||||
lines_queue = deque(maxlen=max(length, 50))
|
||||
# 使用 aiofiles 异步读取文件
|
||||
async with aiofiles.open(log_path, mode="r", encoding="utf-8") as f:
|
||||
# 逐行读取文件,将每一行存入队列
|
||||
file_content = await f.read()
|
||||
for line in file_content.splitlines():
|
||||
# 获取文件大小
|
||||
file_stat = await log_path.stat()
|
||||
file_size = file_stat.st_size
|
||||
|
||||
# 读取历史日志
|
||||
async with aiofiles.open(log_path, mode="r", encoding="utf-8", errors="ignore") as f:
|
||||
# 优化大文件读取策略
|
||||
if file_size > 100 * 1024:
|
||||
# 只读取最后100KB的内容
|
||||
bytes_to_read = min(file_size, 100 * 1024)
|
||||
position = file_size - bytes_to_read
|
||||
await f.seek(position)
|
||||
content = await f.read()
|
||||
# 找到第一个完整的行
|
||||
first_newline = content.find('\n')
|
||||
if first_newline != -1:
|
||||
content = content[first_newline + 1:]
|
||||
else:
|
||||
# 小文件直接读取全部内容
|
||||
content = await f.read()
|
||||
|
||||
# 按行分割并添加到队列,只保留非空行
|
||||
lines = [line.strip() for line in content.splitlines() if line.strip()]
|
||||
# 只取最后N行
|
||||
for line in lines[-max(length, 50):]:
|
||||
lines_queue.append(line)
|
||||
for line in lines_queue:
|
||||
yield f"data: {line}\n\n"
|
||||
|
||||
# 输出历史日志
|
||||
for line in lines_queue:
|
||||
yield f"data: {line}\n\n"
|
||||
|
||||
# 实时监听新日志
|
||||
async with aiofiles.open(log_path, mode="r", encoding="utf-8", errors="ignore") as f:
|
||||
# 移动文件指针到文件末尾,继续监听新增内容
|
||||
await f.seek(0, 2)
|
||||
# 记录初始文件大小
|
||||
initial_stat = await log_path.stat()
|
||||
initial_size = initial_stat.st_size
|
||||
# 实时监听新日志,使用更短的轮询间隔
|
||||
while not global_vars.is_system_stopped:
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
line = await f.readline()
|
||||
if not line:
|
||||
# 检查文件是否有新内容
|
||||
current_stat = await log_path.stat()
|
||||
current_size = current_stat.st_size
|
||||
if current_size > initial_size:
|
||||
# 文件有新内容,读取新行
|
||||
line = await f.readline()
|
||||
if line:
|
||||
line = line.strip()
|
||||
if line:
|
||||
yield f"data: {line}\n\n"
|
||||
initial_size = current_size
|
||||
else:
|
||||
# 没有新内容,短暂等待
|
||||
await asyncio.sleep(0.5)
|
||||
continue
|
||||
yield f"data: {line}\n\n"
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as err:
|
||||
logger.error(f"日志读取异常: {err}")
|
||||
yield f"data: 日志读取异常: {err}\n\n"
|
||||
|
||||
# 根据length参数返回不同的响应
|
||||
if length == -1:
|
||||
# 返回全部日志作为文本响应
|
||||
if not log_path.exists():
|
||||
if not await log_path.exists():
|
||||
return Response(content="日志文件不存在!", media_type="text/plain")
|
||||
with open(log_path, "r", encoding='utf-8') as file:
|
||||
text = file.read()
|
||||
# 倒序输出
|
||||
text = "\n".join(text.split("\n")[::-1])
|
||||
return Response(content=text, media_type="text/plain")
|
||||
try:
|
||||
# 使用 aiofiles 异步读取文件
|
||||
async with aiofiles.open(log_path, mode="r", encoding="utf-8", errors="ignore") as file:
|
||||
text = await file.read()
|
||||
# 倒序输出
|
||||
text = "\n".join(text.split("\n")[::-1])
|
||||
return Response(content=text, media_type="text/plain")
|
||||
except Exception as e:
|
||||
return Response(content=f"读取日志文件失败: {e}", media_type="text/plain")
|
||||
else:
|
||||
# 返回SSE流响应
|
||||
return StreamingResponse(log_generator(), media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.get("/versions", summary="查询Github所有Release版本", response_model=schemas.Response)
|
||||
def latest_version(_: schemas.TokenPayload = Depends(verify_token)):
|
||||
async def latest_version(_: schemas.TokenPayload = Depends(verify_token)):
|
||||
"""
|
||||
查询Github所有Release版本
|
||||
"""
|
||||
version_res = RequestUtils(proxies=settings.PROXY, headers=settings.GITHUB_HEADERS).get_res(
|
||||
version_res = await AsyncRequestUtils(proxies=settings.PROXY, headers=settings.GITHUB_HEADERS).get_res(
|
||||
f"https://api.github.com/repos/jxxghp/MoviePilot/releases")
|
||||
if version_res:
|
||||
ver_json = version_res.json()
|
||||
@@ -450,11 +509,11 @@ def ruletest(title: str,
|
||||
|
||||
|
||||
@router.get("/nettest", summary="测试网络连通性")
|
||||
def nettest(
|
||||
url: str,
|
||||
proxy: bool,
|
||||
include: Optional[str] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_token),
|
||||
async def nettest(
|
||||
url: str,
|
||||
proxy: bool,
|
||||
include: Optional[str] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_token),
|
||||
):
|
||||
"""
|
||||
测试网络连通性
|
||||
@@ -462,43 +521,68 @@ def nettest(
|
||||
# 记录开始的毫秒数
|
||||
start_time = datetime.now()
|
||||
headers = None
|
||||
if "github" in url or "{GITHUB_PROXY}" in url:
|
||||
# 当前使用的加速代理
|
||||
proxy_name = ""
|
||||
if "github" in url:
|
||||
# 这是github的连通性测试
|
||||
headers = settings.GITHUB_HEADERS
|
||||
if "{GITHUB_PROXY}" in url:
|
||||
url = url.replace(
|
||||
"{GITHUB_PROXY}", UrlUtils.standardize_base_url(settings.GITHUB_PROXY or "")
|
||||
)
|
||||
headers = settings.GITHUB_HEADERS
|
||||
if settings.GITHUB_PROXY:
|
||||
proxy_name = "Github加速代理"
|
||||
if "{PIP_PROXY}" in url:
|
||||
url = url.replace(
|
||||
"{PIP_PROXY}",
|
||||
UrlUtils.standardize_base_url(
|
||||
settings.PIP_PROXY or "https://pypi.org/simple/"
|
||||
),
|
||||
)
|
||||
if settings.PIP_PROXY:
|
||||
proxy_name = "PIP加速代理"
|
||||
url = url.replace("{TMDBAPIKEY}", settings.TMDB_API_KEY)
|
||||
url = url.replace(
|
||||
"{PIP_PROXY}",
|
||||
UrlUtils.standardize_base_url(settings.PIP_PROXY or "https://pypi.org/simple/"),
|
||||
)
|
||||
result = RequestUtils(
|
||||
result = await AsyncRequestUtils(
|
||||
proxies=settings.PROXY if proxy else None,
|
||||
headers=headers,
|
||||
timeout=10,
|
||||
ua=settings.USER_AGENT,
|
||||
ua=settings.NORMAL_USER_AGENT,
|
||||
).get_res(url)
|
||||
# 计时结束的毫秒数
|
||||
end_time = datetime.now()
|
||||
time = round((end_time - start_time).total_seconds() * 1000)
|
||||
# 计算相关秒数
|
||||
if result is None:
|
||||
return schemas.Response(success=False, message="无法连接", data={"time": time})
|
||||
return schemas.Response(
|
||||
success=False, message=f"{proxy_name}无法连接", data={"time": time}
|
||||
)
|
||||
elif result.status_code == 200:
|
||||
if include and not re.search(r"%s" % include, result.text, re.IGNORECASE):
|
||||
# 通常是被加速代理跳转到其它页面了
|
||||
logger.error(f"{url} 的响应内容不匹配包含规则 {include}")
|
||||
if proxy_name:
|
||||
message = f"{proxy_name}已失效,请检查配置"
|
||||
else:
|
||||
message = f"无效响应,不匹配 {include}"
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message=f"无效响应,不匹配 {include}",
|
||||
message=message,
|
||||
data={"time": time},
|
||||
)
|
||||
return schemas.Response(success=True, data={"time": time})
|
||||
else:
|
||||
return schemas.Response(
|
||||
success=False, message=f"错误码:{result.status_code}", data={"time": time}
|
||||
)
|
||||
if proxy_name:
|
||||
# 加速代理失败
|
||||
message = f"{proxy_name}已失效,错误码:{result.status_code}"
|
||||
else:
|
||||
message = f"错误码:{result.status_code}"
|
||||
if "github" in url:
|
||||
# 非加速代理访问github
|
||||
if result.status_code == 401:
|
||||
message = "Github Token已失效,请检查配置"
|
||||
elif result.status_code in {403, 429}:
|
||||
message = "触发限流,请配置Github Token"
|
||||
return schemas.Response(success=False, message=message, data={"time": time})
|
||||
|
||||
|
||||
@router.get("/modulelist", summary="查询已加载的模块ID列表", response_model=schemas.Response)
|
||||
|
||||
@@ -11,28 +11,28 @@ router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/seasons/{tmdbid}", summary="TMDB所有季", response_model=List[schemas.TmdbSeason])
|
||||
def tmdb_seasons(tmdbid: int, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def tmdb_seasons(tmdbid: int, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
根据TMDBID查询themoviedb所有季信息
|
||||
"""
|
||||
seasons_info = TmdbChain().tmdb_seasons(tmdbid=tmdbid)
|
||||
seasons_info = await TmdbChain().async_tmdb_seasons(tmdbid=tmdbid)
|
||||
if seasons_info:
|
||||
return seasons_info
|
||||
return []
|
||||
|
||||
|
||||
@router.get("/similar/{tmdbid}/{type_name}", summary="类似电影/电视剧", response_model=List[schemas.MediaInfo])
|
||||
def tmdb_similar(tmdbid: int,
|
||||
type_name: str,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def tmdb_similar(tmdbid: int,
|
||||
type_name: str,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
根据TMDBID查询类似电影/电视剧,type_name: 电影/电视剧
|
||||
"""
|
||||
mediatype = MediaType(type_name)
|
||||
if mediatype == MediaType.MOVIE:
|
||||
medias = TmdbChain().movie_similar(tmdbid=tmdbid)
|
||||
medias = await TmdbChain().async_movie_similar(tmdbid=tmdbid)
|
||||
elif mediatype == MediaType.TV:
|
||||
medias = TmdbChain().tv_similar(tmdbid=tmdbid)
|
||||
medias = await TmdbChain().async_tv_similar(tmdbid=tmdbid)
|
||||
else:
|
||||
return []
|
||||
if medias:
|
||||
@@ -41,17 +41,17 @@ def tmdb_similar(tmdbid: int,
|
||||
|
||||
|
||||
@router.get("/recommend/{tmdbid}/{type_name}", summary="推荐电影/电视剧", response_model=List[schemas.MediaInfo])
|
||||
def tmdb_recommend(tmdbid: int,
|
||||
type_name: str,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def tmdb_recommend(tmdbid: int,
|
||||
type_name: str,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
根据TMDBID查询推荐电影/电视剧,type_name: 电影/电视剧
|
||||
"""
|
||||
mediatype = MediaType(type_name)
|
||||
if mediatype == MediaType.MOVIE:
|
||||
medias = TmdbChain().movie_recommend(tmdbid=tmdbid)
|
||||
medias = await TmdbChain().async_movie_recommend(tmdbid=tmdbid)
|
||||
elif mediatype == MediaType.TV:
|
||||
medias = TmdbChain().tv_recommend(tmdbid=tmdbid)
|
||||
medias = await TmdbChain().async_tv_recommend(tmdbid=tmdbid)
|
||||
else:
|
||||
return []
|
||||
if medias:
|
||||
@@ -60,63 +60,63 @@ def tmdb_recommend(tmdbid: int,
|
||||
|
||||
|
||||
@router.get("/collection/{collection_id}", summary="系列合集详情", response_model=List[schemas.MediaInfo])
|
||||
def tmdb_collection(collection_id: int,
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 20,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def tmdb_collection(collection_id: int,
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 20,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
根据合集ID查询合集详情
|
||||
"""
|
||||
medias = TmdbChain().tmdb_collection(collection_id=collection_id)
|
||||
medias = await TmdbChain().async_tmdb_collection(collection_id=collection_id)
|
||||
if medias:
|
||||
return [media.to_dict() for media in medias][(page - 1) * count:page * count]
|
||||
return []
|
||||
|
||||
|
||||
@router.get("/credits/{tmdbid}/{type_name}", summary="演员阵容", response_model=List[schemas.MediaPerson])
|
||||
def tmdb_credits(tmdbid: int,
|
||||
type_name: str,
|
||||
page: Optional[int] = 1,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def tmdb_credits(tmdbid: int,
|
||||
type_name: str,
|
||||
page: Optional[int] = 1,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
根据TMDBID查询演员阵容,type_name: 电影/电视剧
|
||||
"""
|
||||
mediatype = MediaType(type_name)
|
||||
if mediatype == MediaType.MOVIE:
|
||||
persons = TmdbChain().movie_credits(tmdbid=tmdbid, page=page)
|
||||
persons = await TmdbChain().async_movie_credits(tmdbid=tmdbid, page=page)
|
||||
elif mediatype == MediaType.TV:
|
||||
persons = TmdbChain().tv_credits(tmdbid=tmdbid, page=page)
|
||||
persons = await TmdbChain().async_tv_credits(tmdbid=tmdbid, page=page)
|
||||
else:
|
||||
return []
|
||||
return persons or []
|
||||
|
||||
|
||||
@router.get("/person/{person_id}", summary="人物详情", response_model=schemas.MediaPerson)
|
||||
def tmdb_person(person_id: int,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def tmdb_person(person_id: int,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
根据人物ID查询人物详情
|
||||
"""
|
||||
return TmdbChain().person_detail(person_id=person_id)
|
||||
return await TmdbChain().async_person_detail(person_id=person_id)
|
||||
|
||||
|
||||
@router.get("/person/credits/{person_id}", summary="人物参演作品", response_model=List[schemas.MediaInfo])
|
||||
def tmdb_person_credits(person_id: int,
|
||||
page: Optional[int] = 1,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def tmdb_person_credits(person_id: int,
|
||||
page: Optional[int] = 1,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
根据人物ID查询人物参演作品
|
||||
"""
|
||||
medias = TmdbChain().person_credits(person_id=person_id, page=page)
|
||||
medias = await TmdbChain().async_person_credits(person_id=person_id, page=page)
|
||||
if medias:
|
||||
return [media.to_dict() for media in medias]
|
||||
return []
|
||||
|
||||
|
||||
@router.get("/{tmdbid}/{season}", summary="TMDB季所有集", response_model=List[schemas.TmdbEpisode])
|
||||
def tmdb_season_episodes(tmdbid: int, season: int, episode_group: Optional[str] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def tmdb_season_episodes(tmdbid: int, season: int, episode_group: Optional[str] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
根据TMDBID查询某季的所有信信息
|
||||
"""
|
||||
return TmdbChain().tmdb_episodes(tmdbid=tmdbid, season=season, episode_group=episode_group)
|
||||
return await TmdbChain().async_tmdb_episodes(tmdbid=tmdbid, season=season, episode_group=episode_group)
|
||||
|
||||
@@ -9,14 +9,14 @@ from app.core.config import settings
|
||||
from app.core.context import MediaInfo
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.db.models import User
|
||||
from app.db.user_oper import get_current_active_superuser
|
||||
from app.db.user_oper import get_current_active_superuser, get_current_active_superuser_async
|
||||
from app.utils.crypto import HashUtils
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/cache", summary="获取种子缓存", response_model=schemas.Response)
|
||||
def torrents_cache(_: User = Depends(get_current_active_superuser)):
|
||||
async def torrents_cache(_: User = Depends(get_current_active_superuser_async)):
|
||||
"""
|
||||
获取当前种子缓存数据
|
||||
"""
|
||||
@@ -24,9 +24,9 @@ def torrents_cache(_: User = Depends(get_current_active_superuser)):
|
||||
|
||||
# 获取spider和rss两种缓存
|
||||
if settings.SUBSCRIBE_MODE == "rss":
|
||||
cache_info = torrents_chain.get_torrents("rss")
|
||||
cache_info = await torrents_chain.async_get_torrents("rss")
|
||||
else:
|
||||
cache_info = torrents_chain.get_torrents("spider")
|
||||
cache_info = await torrents_chain.async_get_torrents("spider")
|
||||
|
||||
# 统计信息
|
||||
torrent_count = sum(len(torrents) for torrents in cache_info.values())
|
||||
@@ -62,9 +62,8 @@ def torrents_cache(_: User = Depends(get_current_active_superuser)):
|
||||
})
|
||||
|
||||
|
||||
@router.delete("/cache/{domain}/{torrent_hash}", summary="删除指定种子缓存",
|
||||
response_model=schemas.Response)
|
||||
def delete_cache(domain: str, torrent_hash: str, _: User = Depends(get_current_active_superuser)):
|
||||
@router.delete("/cache/{domain}/{torrent_hash}", summary="删除指定种子缓存", response_model=schemas.Response)
|
||||
async def delete_cache(domain: str, torrent_hash: str, _: User = Depends(get_current_active_superuser_async)):
|
||||
"""
|
||||
删除指定的种子缓存
|
||||
:param domain: 站点域名
|
||||
@@ -76,7 +75,7 @@ def delete_cache(domain: str, torrent_hash: str, _: User = Depends(get_current_a
|
||||
|
||||
try:
|
||||
# 获取当前缓存
|
||||
cache_data = torrents_chain.get_torrents()
|
||||
cache_data = await torrents_chain.async_get_torrents()
|
||||
|
||||
if domain not in cache_data:
|
||||
return schemas.Response(success=False, message=f"站点 {domain} 缓存不存在")
|
||||
@@ -92,7 +91,7 @@ def delete_cache(domain: str, torrent_hash: str, _: User = Depends(get_current_a
|
||||
return schemas.Response(success=False, message="未找到指定的种子")
|
||||
|
||||
# 保存更新后的缓存
|
||||
torrents_chain.save_cache(cache_data, torrents_chain.cache_file)
|
||||
await torrents_chain.async_save_cache(cache_data, torrents_chain.cache_file)
|
||||
|
||||
return schemas.Response(success=True, message="种子删除成功")
|
||||
except Exception as e:
|
||||
@@ -100,14 +99,14 @@ def delete_cache(domain: str, torrent_hash: str, _: User = Depends(get_current_a
|
||||
|
||||
|
||||
@router.delete("/cache", summary="清理种子缓存", response_model=schemas.Response)
|
||||
def clear_cache(_: User = Depends(get_current_active_superuser)):
|
||||
async def clear_cache(_: User = Depends(get_current_active_superuser_async)):
|
||||
"""
|
||||
清理所有种子缓存
|
||||
"""
|
||||
torrents_chain = TorrentsChain()
|
||||
|
||||
try:
|
||||
torrents_chain.clear_torrents()
|
||||
await torrents_chain.async_clear_torrents()
|
||||
return schemas.Response(success=True, message="种子缓存清理完成")
|
||||
except Exception as e:
|
||||
return schemas.Response(success=False, message=f"清理失败:{str(e)}")
|
||||
@@ -135,9 +134,9 @@ def refresh_cache(_: User = Depends(get_current_active_superuser)):
|
||||
|
||||
|
||||
@router.post("/cache/reidentify/{domain}/{torrent_hash}", summary="重新识别种子", response_model=schemas.Response)
|
||||
def reidentify_cache(domain: str, torrent_hash: str,
|
||||
tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
|
||||
_: User = Depends(get_current_active_superuser)):
|
||||
async def reidentify_cache(domain: str, torrent_hash: str,
|
||||
tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
|
||||
_: User = Depends(get_current_active_superuser_async)):
|
||||
"""
|
||||
重新识别指定的种子
|
||||
:param domain: 站点域名
|
||||
@@ -152,7 +151,7 @@ def reidentify_cache(domain: str, torrent_hash: str,
|
||||
|
||||
try:
|
||||
# 获取当前缓存
|
||||
cache_data = torrents_chain.get_torrents()
|
||||
cache_data = await torrents_chain.async_get_torrents()
|
||||
|
||||
if domain not in cache_data:
|
||||
return schemas.Response(success=False, message=f"站点 {domain} 缓存不存在")
|
||||
@@ -168,14 +167,13 @@ def reidentify_cache(domain: str, torrent_hash: str,
|
||||
return schemas.Response(success=False, message="未找到指定的种子")
|
||||
|
||||
# 重新识别
|
||||
meta = MetaInfo(title=target_context.torrent_info.title,
|
||||
subtitle=target_context.torrent_info.description)
|
||||
meta = MetaInfo(title=target_context.torrent_info.title, subtitle=target_context.torrent_info.description)
|
||||
if tmdbid or doubanid:
|
||||
# 手动指定媒体信息
|
||||
mediainfo = MediaChain().recognize_media(meta=meta, tmdbid=tmdbid, doubanid=doubanid)
|
||||
mediainfo = await media_chain.async_recognize_media(meta=meta, tmdbid=tmdbid, doubanid=doubanid)
|
||||
else:
|
||||
# 自动重新识别
|
||||
mediainfo = media_chain.recognize_by_meta(meta)
|
||||
mediainfo = await media_chain.async_recognize_by_meta(meta)
|
||||
|
||||
if not mediainfo:
|
||||
# 创建空的媒体信息
|
||||
@@ -188,7 +186,7 @@ def reidentify_cache(domain: str, torrent_hash: str,
|
||||
target_context.media_info = mediainfo
|
||||
|
||||
# 保存更新后的缓存
|
||||
torrents_chain.save_cache(cache_data, TorrentsChain().cache_file)
|
||||
await torrents_chain.async_save_cache(cache_data, TorrentsChain().cache_file)
|
||||
|
||||
return schemas.Response(success=True, message="重新识别完成", data={
|
||||
"media_name": mediainfo.title if mediainfo else "",
|
||||
|
||||
@@ -8,11 +8,14 @@ from app import schemas
|
||||
from app.chain.media import MediaChain
|
||||
from app.chain.storage import StorageChain
|
||||
from app.chain.transfer import TransferChain
|
||||
from app.core.config import settings, global_vars
|
||||
from app.core.metainfo import MetaInfoPath
|
||||
from app.core.security import verify_token, verify_apitoken
|
||||
from app.db import get_db
|
||||
from app.db.models import User
|
||||
from app.db.models.transferhistory import TransferHistory
|
||||
from app.db.user_oper import get_current_active_superuser
|
||||
from app.helper.directory import DirectoryHelper
|
||||
from app.schemas import MediaType, FileItem, ManualTransferItem
|
||||
|
||||
router = APIRouter()
|
||||
@@ -35,11 +38,19 @@ def query_name(path: str, filetype: str,
|
||||
if not new_path:
|
||||
return schemas.Response(success=False, message="未识别到新名称")
|
||||
if filetype == "dir":
|
||||
parents = Path(new_path).parents
|
||||
if len(parents) > 2:
|
||||
new_name = parents[1].name
|
||||
media_path = DirectoryHelper.get_media_root_path(
|
||||
rename_format=settings.RENAME_FORMAT(mediainfo.type),
|
||||
rename_path=Path(new_path),
|
||||
)
|
||||
if media_path:
|
||||
new_name = media_path.name
|
||||
else:
|
||||
new_name = parents[0].name
|
||||
# fallback
|
||||
parents = Path(new_path).parents
|
||||
if len(parents) > 2:
|
||||
new_name = parents[1].name
|
||||
else:
|
||||
new_name = parents[0].name
|
||||
else:
|
||||
new_name = Path(new_path).name
|
||||
return schemas.Response(success=True, data={
|
||||
@@ -48,7 +59,7 @@ def query_name(path: str, filetype: str,
|
||||
|
||||
|
||||
@router.get("/queue", summary="查询整理队列", response_model=List[schemas.TransferJob])
|
||||
def query_queue(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def query_queue(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
查询整理队列
|
||||
:param _: Token校验
|
||||
@@ -57,13 +68,15 @@ def query_queue(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
|
||||
|
||||
@router.delete("/queue", summary="从整理队列中删除任务", response_model=schemas.Response)
|
||||
def remove_queue(fileitem: schemas.FileItem, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def remove_queue(fileitem: schemas.FileItem, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
查询整理队列
|
||||
:param fileitem: 文件项
|
||||
:param _: Token校验
|
||||
"""
|
||||
TransferChain().remove_from_queue(fileitem)
|
||||
# 取消整理
|
||||
global_vars.stop_transfer(fileitem.path)
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@@ -71,7 +84,7 @@ def remove_queue(fileitem: schemas.FileItem, _: schemas.TokenPayload = Depends(v
|
||||
def manual_transfer(transer_item: ManualTransferItem,
|
||||
background: Optional[bool] = False,
|
||||
db: Session = Depends(get_db),
|
||||
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
|
||||
_: User = Depends(get_current_active_superuser)) -> Any:
|
||||
"""
|
||||
手动转移,文件或历史记录,支持自定义剧集识别格式
|
||||
:param transer_item: 手工整理项
|
||||
@@ -98,7 +111,7 @@ def manual_transfer(transer_item: ManualTransferItem,
|
||||
if history.dest_fileitem:
|
||||
# 删除旧的已整理文件
|
||||
dest_fileitem = FileItem(**history.dest_fileitem)
|
||||
state = StorageChain().delete_media_file(dest_fileitem, mtype=MediaType(history.type))
|
||||
state = StorageChain().delete_media_file(dest_fileitem)
|
||||
if not state:
|
||||
return schemas.Response(success=False, message=f"{dest_fileitem.path} 删除失败")
|
||||
|
||||
|
||||
@@ -3,13 +3,14 @@ import re
|
||||
from typing import Annotated, Any, List, Union
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, UploadFile, File
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app import schemas
|
||||
from app.core.security import get_password_hash
|
||||
from app.db import get_db
|
||||
from app.db import get_async_db
|
||||
from app.db.models.user import User
|
||||
from app.db.user_oper import get_current_active_superuser, get_current_active_user
|
||||
from app.db.user_oper import get_current_active_superuser_async, \
|
||||
get_current_active_user_async, get_current_active_user
|
||||
from app.db.userconfig_oper import UserConfigOper
|
||||
from app.utils.otp import OtpUtils
|
||||
|
||||
@@ -17,50 +18,48 @@ router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/", summary="所有用户", response_model=List[schemas.User])
|
||||
def list_users(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_active_superuser),
|
||||
async def list_users(
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
current_user: User = Depends(get_current_active_superuser_async),
|
||||
) -> Any:
|
||||
"""
|
||||
查询用户列表
|
||||
"""
|
||||
users = current_user.list(db)
|
||||
return users
|
||||
return await current_user.async_list(db)
|
||||
|
||||
|
||||
@router.post("/", summary="新增用户", response_model=schemas.Response)
|
||||
def create_user(
|
||||
async def create_user(
|
||||
*,
|
||||
db: Session = Depends(get_db),
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
user_in: schemas.UserCreate,
|
||||
current_user: User = Depends(get_current_active_superuser),
|
||||
current_user: User = Depends(get_current_active_superuser_async),
|
||||
) -> Any:
|
||||
"""
|
||||
新增用户
|
||||
"""
|
||||
user = current_user.get_by_name(db, name=user_in.name)
|
||||
user = await current_user.async_get_by_name(db, name=user_in.name)
|
||||
if user:
|
||||
return schemas.Response(success=False, message="用户已存在")
|
||||
user_info = user_in.dict()
|
||||
user_info = user_in.model_dump()
|
||||
if user_info.get("password"):
|
||||
user_info["hashed_password"] = get_password_hash(user_info["password"])
|
||||
user_info.pop("password")
|
||||
user = User(**user_info)
|
||||
user.create(db)
|
||||
return schemas.Response(success=True)
|
||||
user = await User(**user_info).async_create(db)
|
||||
return schemas.Response(success=True if user else False)
|
||||
|
||||
|
||||
@router.put("/", summary="更新用户", response_model=schemas.Response)
|
||||
def update_user(
|
||||
async def update_user(
|
||||
*,
|
||||
db: Session = Depends(get_db),
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
user_in: schemas.UserUpdate,
|
||||
_: User = Depends(get_current_active_superuser),
|
||||
current_user: User = Depends(get_current_active_superuser_async),
|
||||
) -> Any:
|
||||
"""
|
||||
更新用户
|
||||
"""
|
||||
user_info = user_in.dict()
|
||||
user_info = user_in.model_dump()
|
||||
if user_info.get("password"):
|
||||
# 正则表达式匹配密码包含字母、数字、特殊字符中的至少两项
|
||||
pattern = r'^(?![a-zA-Z]+$)(?!\d+$)(?![^\da-zA-Z\s]+$).{6,50}$'
|
||||
@@ -69,24 +68,24 @@ def update_user(
|
||||
message="密码需要同时包含字母、数字、特殊字符中的至少两项,且长度大于6位")
|
||||
user_info["hashed_password"] = get_password_hash(user_info["password"])
|
||||
user_info.pop("password")
|
||||
user = User.get_by_id(db, user_id=user_info["id"])
|
||||
user = await current_user.async_get_by_id(db, user_id=user_info["id"])
|
||||
user_name = user_info.get("name")
|
||||
if not user_name:
|
||||
return schemas.Response(success=False, message="用户名不能为空")
|
||||
# 新用户名去重
|
||||
users = User.list(db)
|
||||
users = await current_user.async_list(db)
|
||||
for u in users:
|
||||
if u.name == user_name and u.id != user_info["id"]:
|
||||
return schemas.Response(success=False, message="用户名已被使用")
|
||||
if not user:
|
||||
return schemas.Response(success=False, message="用户不存在")
|
||||
user.update(db, user_info)
|
||||
await user.async_update(db, user_info)
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.get("/current", summary="当前登录用户信息", response_model=schemas.User)
|
||||
def read_current_user(
|
||||
current_user: User = Depends(get_current_active_user)
|
||||
async def read_current_user(
|
||||
current_user: User = Depends(get_current_active_user_async)
|
||||
) -> Any:
|
||||
"""
|
||||
当前登录用户信息
|
||||
@@ -95,18 +94,18 @@ def read_current_user(
|
||||
|
||||
|
||||
@router.post("/avatar/{user_id}", summary="上传用户头像", response_model=schemas.Response)
|
||||
def upload_avatar(user_id: int, db: Session = Depends(get_db), file: UploadFile = File(...),
|
||||
_: User = Depends(get_current_active_user)):
|
||||
async def upload_avatar(user_id: int, db: AsyncSession = Depends(get_async_db), file: UploadFile = File(...),
|
||||
_: User = Depends(get_current_active_user_async)):
|
||||
"""
|
||||
上传用户头像
|
||||
"""
|
||||
# 将文件转换为Base64
|
||||
file_base64 = base64.b64encode(file.file.read())
|
||||
# 更新到用户表
|
||||
user = User.get(db, user_id)
|
||||
user = await User.async_get(db, user_id)
|
||||
if not user:
|
||||
return schemas.Response(success=False, message="用户不存在")
|
||||
user.update(db, {
|
||||
await user.async_update(db, {
|
||||
"avatar": f"data:image/ico;base64,{file_base64}"
|
||||
})
|
||||
return schemas.Response(success=True, message=file.filename)
|
||||
@@ -121,31 +120,31 @@ def otp_generate(
|
||||
|
||||
|
||||
@router.post('/otp/judge', summary='判断otp验证是否通过', response_model=schemas.Response)
|
||||
def otp_judge(
|
||||
async def otp_judge(
|
||||
data: dict,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_active_user)
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
current_user: User = Depends(get_current_active_user_async)
|
||||
) -> Any:
|
||||
uri = data.get("uri")
|
||||
otp_password = data.get("otpPassword")
|
||||
if not OtpUtils.is_legal(uri, otp_password):
|
||||
return schemas.Response(success=False, message="验证码错误")
|
||||
current_user.update_otp_by_name(db, current_user.name, True, OtpUtils.get_secret(uri))
|
||||
await current_user.async_update_otp_by_name(db, current_user.name, True, OtpUtils.get_secret(uri))
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.post('/otp/disable', summary='关闭当前用户的otp验证', response_model=schemas.Response)
|
||||
def otp_disable(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_active_user)
|
||||
async def otp_disable(
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
current_user: User = Depends(get_current_active_user_async)
|
||||
) -> Any:
|
||||
current_user.update_otp_by_name(db, current_user.name, False, "")
|
||||
await current_user.async_update_otp_by_name(db, current_user.name, False, "")
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.get('/otp/{userid}', summary='判断当前用户是否开启otp验证', response_model=schemas.Response)
|
||||
def otp_enable(userid: str, db: Session = Depends(get_db)) -> Any:
|
||||
user: User = User.get_by_name(db, userid)
|
||||
async def otp_enable(userid: str, db: AsyncSession = Depends(get_async_db)) -> Any:
|
||||
user: User = await User.async_get_by_name(db, userid)
|
||||
if not user:
|
||||
return schemas.Response(success=False)
|
||||
return schemas.Response(success=user.is_otp)
|
||||
@@ -165,9 +164,9 @@ def get_config(key: str,
|
||||
|
||||
@router.post("/config/{key}", summary="更新用户配置", response_model=schemas.Response)
|
||||
def set_config(
|
||||
key: str,
|
||||
value: Annotated[Union[list, dict, bool, int, str] | None, Body()] = None,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
key: str,
|
||||
value: Annotated[Union[list, dict, bool, int, str] | None, Body()] = None,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
):
|
||||
"""
|
||||
更新用户配置
|
||||
@@ -177,49 +176,49 @@ def set_config(
|
||||
|
||||
|
||||
@router.delete("/id/{user_id}", summary="删除用户", response_model=schemas.Response)
|
||||
def delete_user_by_id(
|
||||
async def delete_user_by_id(
|
||||
*,
|
||||
db: Session = Depends(get_db),
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
user_id: int,
|
||||
current_user: User = Depends(get_current_active_superuser),
|
||||
current_user: User = Depends(get_current_active_superuser_async),
|
||||
) -> Any:
|
||||
"""
|
||||
通过唯一ID删除用户
|
||||
"""
|
||||
user = current_user.get_by_id(db, user_id=user_id)
|
||||
user = await current_user.async_get_by_id(db, user_id=user_id)
|
||||
if not user:
|
||||
return schemas.Response(success=False, message="用户不存在")
|
||||
user.delete_by_id(db, user_id)
|
||||
await current_user.async_delete(db, user_id)
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.delete("/name/{user_name}", summary="删除用户", response_model=schemas.Response)
|
||||
def delete_user_by_name(
|
||||
async def delete_user_by_name(
|
||||
*,
|
||||
db: Session = Depends(get_db),
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
user_name: str,
|
||||
current_user: User = Depends(get_current_active_superuser),
|
||||
current_user: User = Depends(get_current_active_superuser_async),
|
||||
) -> Any:
|
||||
"""
|
||||
通过用户名删除用户
|
||||
"""
|
||||
user = current_user.get_by_name(db, name=user_name)
|
||||
user = await current_user.async_get_by_name(db, name=user_name)
|
||||
if not user:
|
||||
return schemas.Response(success=False, message="用户不存在")
|
||||
user.delete_by_name(db, user_name)
|
||||
await current_user.async_delete(db, user.id)
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.get("/{username}", summary="用户详情", response_model=schemas.User)
|
||||
def read_user_by_name(
|
||||
async def read_user_by_name(
|
||||
username: str,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_active_user_async),
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
) -> Any:
|
||||
"""
|
||||
查询用户详情
|
||||
"""
|
||||
user = current_user.get_by_name(db, name=username)
|
||||
user = await current_user.async_get_by_name(db, name=username)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
|
||||
@@ -32,8 +32,8 @@ async def webhook_message(background_tasks: BackgroundTasks,
|
||||
|
||||
|
||||
@router.get("/", summary="Webhook消息响应", response_model=schemas.Response)
|
||||
def webhook_message(background_tasks: BackgroundTasks,
|
||||
request: Request, _: Annotated[str, Depends(verify_apitoken)]) -> Any:
|
||||
async def webhook_message(background_tasks: BackgroundTasks,
|
||||
request: Request, _: Annotated[str, Depends(verify_apitoken)]) -> Any:
|
||||
"""
|
||||
Webhook响应,配置请求中需要添加参数:token=API_TOKEN&source=媒体服务器名
|
||||
"""
|
||||
|
||||
@@ -1,51 +1,59 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import List, Any, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app import schemas
|
||||
from app.chain.workflow import WorkflowChain
|
||||
from app.core.config import global_vars
|
||||
from app.core.plugin import PluginManager
|
||||
from app.core.workflow import WorkFlowManager
|
||||
from app.db import get_db
|
||||
from app.db.models.workflow import Workflow
|
||||
from app.core.security import verify_token
|
||||
from app.workflow import WorkFlowManager
|
||||
from app.db import get_async_db, get_db
|
||||
from app.db.models import Workflow
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.db.user_oper import get_current_active_user
|
||||
from app.chain.workflow import WorkflowChain
|
||||
from app.db.workflow_oper import WorkflowOper
|
||||
from app.helper.workflow import WorkflowHelper
|
||||
from app.scheduler import Scheduler
|
||||
from app.schemas.types import EventType, EVENT_TYPE_NAMES
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/", summary="所有工作流", response_model=List[schemas.Workflow])
|
||||
def list_workflows(db: Session = Depends(get_db),
|
||||
_: schemas.TokenPayload = Depends(get_current_active_user)) -> Any:
|
||||
async def list_workflows(db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
获取工作流列表
|
||||
"""
|
||||
return Workflow.list(db)
|
||||
return await WorkflowOper(db).async_list()
|
||||
|
||||
|
||||
@router.post("/", summary="创建工作流", response_model=schemas.Response)
|
||||
def create_workflow(workflow: schemas.Workflow,
|
||||
db: Session = Depends(get_db),
|
||||
_: schemas.TokenPayload = Depends(get_current_active_user)) -> Any:
|
||||
async def create_workflow(workflow: schemas.Workflow,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
创建工作流
|
||||
"""
|
||||
if Workflow.get_by_name(db, workflow.name):
|
||||
if workflow.name and await WorkflowOper(db).async_get_by_name(workflow.name):
|
||||
return schemas.Response(success=False, message="已存在相同名称的工作流")
|
||||
if not workflow.add_time:
|
||||
workflow.add_time = datetime.strftime(datetime.now(), "%Y-%m-%d %H:%M:%S")
|
||||
if not workflow.state:
|
||||
workflow.state = "P"
|
||||
Workflow(**workflow.dict()).create(db)
|
||||
if not workflow.trigger_type:
|
||||
workflow.trigger_type = "timer"
|
||||
workflow_obj = Workflow(**workflow.model_dump())
|
||||
await workflow_obj.async_create(db)
|
||||
return schemas.Response(success=True, message="创建工作流成功")
|
||||
|
||||
|
||||
@router.get("/plugin/actions", summary="查询插件动作", response_model=List[dict])
|
||||
def list_plugin_actions(plugin_id: str = None, _: schemas.TokenPayload = Depends(get_current_active_user)) -> Any:
|
||||
def list_plugin_actions(plugin_id: str = None, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
获取所有动作
|
||||
"""
|
||||
@@ -53,60 +61,124 @@ def list_plugin_actions(plugin_id: str = None, _: schemas.TokenPayload = Depends
|
||||
|
||||
|
||||
@router.get("/actions", summary="所有动作", response_model=List[dict])
|
||||
def list_actions(_: schemas.TokenPayload = Depends(get_current_active_user)) -> Any:
|
||||
async def list_actions(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
获取所有动作
|
||||
"""
|
||||
return WorkFlowManager().list_actions()
|
||||
|
||||
|
||||
@router.get("/{workflow_id}", summary="工作流详情", response_model=schemas.Workflow)
|
||||
def get_workflow(workflow_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
_: schemas.TokenPayload = Depends(get_current_active_user)) -> Any:
|
||||
@router.get("/event_types", summary="获取所有事件类型", response_model=List[dict])
|
||||
async def get_event_types(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
获取工作流详情
|
||||
获取所有事件类型
|
||||
"""
|
||||
return Workflow.get(db, workflow_id)
|
||||
return [{
|
||||
"title": EVENT_TYPE_NAMES.get(event_type, event_type.name),
|
||||
"value": event_type.value
|
||||
} for event_type in EventType]
|
||||
|
||||
|
||||
@router.put("/{workflow_id}", summary="更新工作流", response_model=schemas.Response)
|
||||
def update_workflow(workflow: schemas.Workflow,
|
||||
db: Session = Depends(get_db),
|
||||
_: schemas.TokenPayload = Depends(get_current_active_user)) -> Any:
|
||||
@router.post("/share", summary="分享工作流", response_model=schemas.Response)
|
||||
async def workflow_share(
|
||||
workflow: schemas.WorkflowShare,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
更新工作流
|
||||
分享工作流
|
||||
"""
|
||||
wf = Workflow.get(db, workflow.id)
|
||||
if not wf:
|
||||
return schemas.Response(success=False, message="工作流不存在")
|
||||
wf.update(db, workflow.dict())
|
||||
return schemas.Response(success=True, message="更新成功")
|
||||
if not workflow.id or not workflow.share_title or not workflow.share_user:
|
||||
return schemas.Response(success=False, message="请填写工作流ID、分享标题和分享人")
|
||||
|
||||
state, errmsg = await WorkflowHelper().async_workflow_share(workflow_id=workflow.id,
|
||||
share_title=workflow.share_title or "",
|
||||
share_comment=workflow.share_comment or "",
|
||||
share_user=workflow.share_user or "")
|
||||
return schemas.Response(success=state, message=errmsg)
|
||||
|
||||
|
||||
@router.delete("/{workflow_id}", summary="删除工作流", response_model=schemas.Response)
|
||||
def delete_workflow(workflow_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
_: schemas.TokenPayload = Depends(get_current_active_user)) -> Any:
|
||||
@router.delete("/share/{share_id}", summary="删除分享", response_model=schemas.Response)
|
||||
async def workflow_share_delete(
|
||||
share_id: int,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
删除工作流
|
||||
删除分享
|
||||
"""
|
||||
workflow = Workflow.get(db, workflow_id)
|
||||
if not workflow:
|
||||
return schemas.Response(success=False, message="工作流不存在")
|
||||
# 删除定时任务
|
||||
Scheduler().remove_workflow_job(workflow)
|
||||
# 删除工作流
|
||||
Workflow.delete(db, workflow_id)
|
||||
# 删除缓存
|
||||
SystemConfigOper().delete(f"WorkflowCache-{workflow_id}")
|
||||
return schemas.Response(success=True, message="删除成功")
|
||||
state, errmsg = await WorkflowHelper().async_share_delete(share_id=share_id)
|
||||
return schemas.Response(success=state, message=errmsg)
|
||||
|
||||
|
||||
@router.post("/fork", summary="复用工作流", response_model=schemas.Response)
|
||||
async def workflow_fork(
|
||||
workflow: schemas.WorkflowShare,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.User = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
复用工作流
|
||||
"""
|
||||
if not workflow.name:
|
||||
return schemas.Response(success=False, message="工作流名称不能为空")
|
||||
|
||||
# 解析JSON数据,添加错误处理
|
||||
try:
|
||||
actions = json.loads(workflow.actions or "[]")
|
||||
except json.JSONDecodeError:
|
||||
return schemas.Response(success=False, message="actions字段JSON格式错误")
|
||||
|
||||
try:
|
||||
flows = json.loads(workflow.flows or "[]")
|
||||
except json.JSONDecodeError:
|
||||
return schemas.Response(success=False, message="flows字段JSON格式错误")
|
||||
|
||||
try:
|
||||
context = json.loads(workflow.context or "{}")
|
||||
except json.JSONDecodeError:
|
||||
return schemas.Response(success=False, message="context字段JSON格式错误")
|
||||
|
||||
# 创建工作流
|
||||
workflow_dict = {
|
||||
"name": workflow.name,
|
||||
"description": workflow.description,
|
||||
"timer": workflow.timer,
|
||||
"trigger_type": workflow.trigger_type or "timer",
|
||||
"event_type": workflow.event_type,
|
||||
"event_conditions": json.loads(workflow.event_conditions or "{}") if workflow.event_conditions else {},
|
||||
"actions": actions,
|
||||
"flows": flows,
|
||||
"context": context,
|
||||
"state": "P" # 默认暂停状态
|
||||
}
|
||||
|
||||
# 检查名称是否重复
|
||||
workflow_oper = WorkflowOper(db)
|
||||
if await workflow_oper.async_get_by_name(workflow_dict["name"]):
|
||||
return schemas.Response(success=False, message="已存在相同名称的工作流")
|
||||
|
||||
# 创建新工作流
|
||||
workflow = await Workflow(**workflow_dict).async_create(db)
|
||||
|
||||
# 更新复用次数
|
||||
if workflow:
|
||||
await WorkflowHelper().async_workflow_fork(share_id=workflow.id)
|
||||
|
||||
return schemas.Response(success=True, message="复用成功")
|
||||
|
||||
|
||||
@router.get("/shares", summary="查询分享的工作流", response_model=List[schemas.WorkflowShare])
|
||||
async def workflow_shares(
|
||||
name: Optional[str] = None,
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
查询分享的工作流
|
||||
"""
|
||||
return await WorkflowHelper().async_get_shares(name=name, page=page, count=count)
|
||||
|
||||
|
||||
@router.post("/{workflow_id}/run", summary="执行工作流", response_model=schemas.Response)
|
||||
def run_workflow(workflow_id: int,
|
||||
from_begin: Optional[bool] = True,
|
||||
_: schemas.TokenPayload = Depends(get_current_active_user)) -> Any:
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
执行工作流
|
||||
"""
|
||||
@@ -119,15 +191,19 @@ def run_workflow(workflow_id: int,
|
||||
@router.post("/{workflow_id}/start", summary="启用工作流", response_model=schemas.Response)
|
||||
def start_workflow(workflow_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
_: schemas.TokenPayload = Depends(get_current_active_user)) -> Any:
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
启用工作流
|
||||
"""
|
||||
workflow = Workflow.get(db, workflow_id)
|
||||
workflow = WorkflowOper(db).get(workflow_id)
|
||||
if not workflow:
|
||||
return schemas.Response(success=False, message="工作流不存在")
|
||||
# 添加定时任务
|
||||
Scheduler().update_workflow_job(workflow)
|
||||
if not workflow.trigger_type or workflow.trigger_type == "timer":
|
||||
# 添加定时任务
|
||||
Scheduler().update_workflow_job(workflow)
|
||||
else:
|
||||
# 事件触发:添加到事件触发器
|
||||
WorkFlowManager().load_workflow_events(workflow_id)
|
||||
# 更新状态
|
||||
workflow.update_state(db, workflow_id, "W")
|
||||
return schemas.Response(success=True)
|
||||
@@ -136,15 +212,20 @@ def start_workflow(workflow_id: int,
|
||||
@router.post("/{workflow_id}/pause", summary="停用工作流", response_model=schemas.Response)
|
||||
def pause_workflow(workflow_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
_: schemas.TokenPayload = Depends(get_current_active_user)) -> Any:
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
停用工作流
|
||||
"""
|
||||
workflow = Workflow.get(db, workflow_id)
|
||||
workflow = WorkflowOper(db).get(workflow_id)
|
||||
if not workflow:
|
||||
return schemas.Response(success=False, message="工作流不存在")
|
||||
# 删除定时任务
|
||||
Scheduler().remove_workflow_job(workflow)
|
||||
# 根据触发类型进行不同处理
|
||||
if workflow.trigger_type == "timer":
|
||||
# 定时触发:移除定时任务
|
||||
Scheduler().remove_workflow_job(workflow)
|
||||
elif workflow.trigger_type == "event":
|
||||
# 事件触发:从事件触发器中移除
|
||||
WorkFlowManager().remove_workflow_event(workflow_id, workflow.event_type)
|
||||
# 停止工作流
|
||||
global_vars.stop_workflow(workflow_id)
|
||||
# 更新状态
|
||||
@@ -153,19 +234,77 @@ def pause_workflow(workflow_id: int,
|
||||
|
||||
|
||||
@router.post("/{workflow_id}/reset", summary="重置工作流", response_model=schemas.Response)
|
||||
def reset_workflow(workflow_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
_: schemas.TokenPayload = Depends(get_current_active_user)) -> Any:
|
||||
async def reset_workflow(workflow_id: int,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
重置工作流
|
||||
"""
|
||||
workflow = Workflow.get(db, workflow_id)
|
||||
workflow = await WorkflowOper(db).async_get(workflow_id)
|
||||
if not workflow:
|
||||
return schemas.Response(success=False, message="工作流不存在")
|
||||
# 停止工作流
|
||||
global_vars.stop_workflow(workflow_id)
|
||||
# 重置工作流
|
||||
workflow.reset(db, workflow_id, reset_count=True)
|
||||
await Workflow.async_reset(db, workflow_id, reset_count=True)
|
||||
# 删除缓存
|
||||
SystemConfigOper().delete(f"WorkflowCache-{workflow_id}")
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.get("/{workflow_id}", summary="工作流详情", response_model=schemas.Workflow)
|
||||
async def get_workflow(workflow_id: int,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
获取工作流详情
|
||||
"""
|
||||
return await WorkflowOper(db).async_get(workflow_id)
|
||||
|
||||
|
||||
@router.put("/{workflow_id}", summary="更新工作流", response_model=schemas.Response)
|
||||
def update_workflow(workflow: schemas.Workflow,
|
||||
db: Session = Depends(get_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
更新工作流
|
||||
"""
|
||||
if not workflow.id:
|
||||
return schemas.Response(success=False, message="工作流ID不能为空")
|
||||
workflow_oper = WorkflowOper(db)
|
||||
wf = workflow_oper.get(workflow.id)
|
||||
if not wf:
|
||||
return schemas.Response(success=False, message="工作流不存在")
|
||||
if not wf.trigger_type:
|
||||
workflow.trigger_type = "timer"
|
||||
wf.update(db, workflow.model_dump())
|
||||
# 更新后的工作流对象
|
||||
updated_workflow = workflow_oper.get(workflow.id)
|
||||
# 更新定时任务
|
||||
Scheduler().update_workflow_job(updated_workflow)
|
||||
# 更新事件注册
|
||||
WorkFlowManager().update_workflow_event(updated_workflow)
|
||||
return schemas.Response(success=True, message="更新成功")
|
||||
|
||||
|
||||
@router.delete("/{workflow_id}", summary="删除工作流", response_model=schemas.Response)
|
||||
def delete_workflow(workflow_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
删除工作流
|
||||
"""
|
||||
workflow = WorkflowOper(db).get(workflow_id)
|
||||
if not workflow:
|
||||
return schemas.Response(success=False, message="工作流不存在")
|
||||
if not workflow.trigger_type or workflow.trigger_type == "timer":
|
||||
# 定时触发:删除定时任务
|
||||
Scheduler().remove_workflow_job(workflow)
|
||||
else:
|
||||
# 事件触发:从事件触发器中移除
|
||||
WorkFlowManager().remove_workflow_event(workflow_id, workflow.event_type)
|
||||
# 删除工作流
|
||||
Workflow.delete(db, workflow_id)
|
||||
# 删除缓存
|
||||
SystemConfigOper().delete(f"WorkflowCache-{workflow_id}")
|
||||
return schemas.Response(success=True, message="删除成功")
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
from typing import Any, List, Annotated
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app import schemas
|
||||
from app.chain.media import MediaChain
|
||||
from app.chain.tvdb import TvdbChain
|
||||
from app.chain.subscribe import SubscribeChain
|
||||
from app.chain.tvdb import TvdbChain
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.core.security import verify_apikey
|
||||
from app.db import get_db
|
||||
from app.db import get_db, get_async_db
|
||||
from app.db.models.subscribe import Subscribe
|
||||
from app.schemas import RadarrMovie, SonarrSeries
|
||||
from app.schemas.types import MediaType
|
||||
@@ -19,7 +20,7 @@ arr_router = APIRouter(tags=['servarr'])
|
||||
|
||||
|
||||
@arr_router.get("/system/status", summary="系统状态")
|
||||
def arr_system_status(_: Annotated[str, Depends(verify_apikey)]) -> Any:
|
||||
async def arr_system_status(_: Annotated[str, Depends(verify_apikey)]) -> Any:
|
||||
"""
|
||||
模拟Radarr、Sonarr系统状态
|
||||
"""
|
||||
@@ -73,7 +74,7 @@ def arr_system_status(_: Annotated[str, Depends(verify_apikey)]) -> Any:
|
||||
|
||||
|
||||
@arr_router.get("/qualityProfile", summary="质量配置")
|
||||
def arr_qualityProfile(_: Annotated[str, Depends(verify_apikey)]) -> Any:
|
||||
async def arr_qualityProfile(_: Annotated[str, Depends(verify_apikey)]) -> Any:
|
||||
"""
|
||||
模拟Radarr、Sonarr质量配置
|
||||
"""
|
||||
@@ -114,7 +115,7 @@ def arr_qualityProfile(_: Annotated[str, Depends(verify_apikey)]) -> Any:
|
||||
|
||||
|
||||
@arr_router.get("/rootfolder", summary="根目录")
|
||||
def arr_rootfolder(_: Annotated[str, Depends(verify_apikey)]) -> Any:
|
||||
async def arr_rootfolder(_: Annotated[str, Depends(verify_apikey)]) -> Any:
|
||||
"""
|
||||
模拟Radarr、Sonarr根目录
|
||||
"""
|
||||
@@ -130,7 +131,7 @@ def arr_rootfolder(_: Annotated[str, Depends(verify_apikey)]) -> Any:
|
||||
|
||||
|
||||
@arr_router.get("/tag", summary="标签")
|
||||
def arr_tag(_: Annotated[str, Depends(verify_apikey)]) -> Any:
|
||||
async def arr_tag(_: Annotated[str, Depends(verify_apikey)]) -> Any:
|
||||
"""
|
||||
模拟Radarr、Sonarr标签
|
||||
"""
|
||||
@@ -143,7 +144,7 @@ def arr_tag(_: Annotated[str, Depends(verify_apikey)]) -> Any:
|
||||
|
||||
|
||||
@arr_router.get("/languageprofile", summary="语言")
|
||||
def arr_languageprofile(_: Annotated[str, Depends(verify_apikey)]) -> Any:
|
||||
async def arr_languageprofile(_: Annotated[str, Depends(verify_apikey)]) -> Any:
|
||||
"""
|
||||
模拟Radarr、Sonarr语言
|
||||
"""
|
||||
@@ -169,7 +170,7 @@ def arr_languageprofile(_: Annotated[str, Depends(verify_apikey)]) -> Any:
|
||||
|
||||
|
||||
@arr_router.get("/movie", summary="所有订阅电影", response_model=List[schemas.RadarrMovie])
|
||||
def arr_movies(_: Annotated[str, Depends(verify_apikey)], db: Session = Depends(get_db)) -> Any:
|
||||
async def arr_movies(_: Annotated[str, Depends(verify_apikey)], db: AsyncSession = Depends(get_async_db)) -> Any:
|
||||
"""
|
||||
查询Rardar电影
|
||||
"""
|
||||
@@ -240,7 +241,7 @@ def arr_movies(_: Annotated[str, Depends(verify_apikey)], db: Session = Depends(
|
||||
"""
|
||||
# 查询所有电影订阅
|
||||
result = []
|
||||
subscribes = Subscribe.list(db)
|
||||
subscribes = await Subscribe.async_list(db)
|
||||
for subscribe in subscribes:
|
||||
if subscribe.type != MediaType.MOVIE.value:
|
||||
continue
|
||||
@@ -306,11 +307,12 @@ def arr_movie_lookup(term: str, _: Annotated[str, Depends(verify_apikey)], db: S
|
||||
|
||||
|
||||
@arr_router.get("/movie/{mid}", summary="电影订阅详情", response_model=schemas.RadarrMovie)
|
||||
def arr_movie(mid: int, _: Annotated[str, Depends(verify_apikey)], db: Session = Depends(get_db)) -> Any:
|
||||
async def arr_movie(mid: int, _: Annotated[str, Depends(verify_apikey)],
|
||||
db: AsyncSession = Depends(get_async_db)) -> Any:
|
||||
"""
|
||||
查询Rardar电影订阅
|
||||
"""
|
||||
subscribe = Subscribe.get(db, mid)
|
||||
subscribe = await Subscribe.async_get(db, mid)
|
||||
if subscribe:
|
||||
return RadarrMovie(
|
||||
id=subscribe.id,
|
||||
@@ -332,25 +334,25 @@ def arr_movie(mid: int, _: Annotated[str, Depends(verify_apikey)], db: Session =
|
||||
|
||||
|
||||
@arr_router.post("/movie", summary="新增电影订阅")
|
||||
def arr_add_movie(_: Annotated[str, Depends(verify_apikey)],
|
||||
movie: RadarrMovie,
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
async def arr_add_movie(_: Annotated[str, Depends(verify_apikey)],
|
||||
movie: RadarrMovie,
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
) -> Any:
|
||||
"""
|
||||
新增Rardar电影订阅
|
||||
"""
|
||||
# 检查订阅是否已存在
|
||||
subscribe = Subscribe.get_by_tmdbid(db, movie.tmdbId)
|
||||
subscribe = await Subscribe.async_get_by_tmdbid(db, movie.tmdbId)
|
||||
if subscribe:
|
||||
return {
|
||||
"id": subscribe.id
|
||||
}
|
||||
# 添加订阅
|
||||
sid, message = SubscribeChain().add(title=movie.title,
|
||||
year=movie.year,
|
||||
mtype=MediaType.MOVIE,
|
||||
tmdbid=movie.tmdbId,
|
||||
username="Seerr")
|
||||
sid, message = await SubscribeChain().async_add(title=movie.title,
|
||||
year=movie.year,
|
||||
mtype=MediaType.MOVIE,
|
||||
tmdbid=movie.tmdbId,
|
||||
username="Seerr")
|
||||
if sid:
|
||||
return {
|
||||
"id": sid
|
||||
@@ -363,13 +365,14 @@ def arr_add_movie(_: Annotated[str, Depends(verify_apikey)],
|
||||
|
||||
|
||||
@arr_router.delete("/movie/{mid}", summary="删除电影订阅", response_model=schemas.Response)
|
||||
def arr_remove_movie(mid: int, _: Annotated[str, Depends(verify_apikey)], db: Session = Depends(get_db)) -> Any:
|
||||
async def arr_remove_movie(mid: int, _: Annotated[str, Depends(verify_apikey)],
|
||||
db: AsyncSession = Depends(get_async_db)) -> Any:
|
||||
"""
|
||||
删除Rardar电影订阅
|
||||
"""
|
||||
subscribe = Subscribe.get(db, mid)
|
||||
subscribe = await Subscribe.async_get(db, mid)
|
||||
if subscribe:
|
||||
subscribe.delete(db, mid)
|
||||
await subscribe.async_delete(db, mid)
|
||||
return schemas.Response(success=True)
|
||||
else:
|
||||
raise HTTPException(
|
||||
@@ -379,7 +382,7 @@ def arr_remove_movie(mid: int, _: Annotated[str, Depends(verify_apikey)], db: Se
|
||||
|
||||
|
||||
@arr_router.get("/series", summary="所有剧集", response_model=List[schemas.SonarrSeries])
|
||||
def arr_series(_: Annotated[str, Depends(verify_apikey)], db: Session = Depends(get_db)) -> Any:
|
||||
async def arr_series(_: Annotated[str, Depends(verify_apikey)], db: AsyncSession = Depends(get_async_db)) -> Any:
|
||||
"""
|
||||
查询Sonarr剧集
|
||||
"""
|
||||
@@ -487,7 +490,7 @@ def arr_series(_: Annotated[str, Depends(verify_apikey)], db: Session = Depends(
|
||||
"""
|
||||
# 查询所有电视剧订阅
|
||||
result = []
|
||||
subscribes = Subscribe.list(db)
|
||||
subscribes = await Subscribe.async_list(db)
|
||||
for subscribe in subscribes:
|
||||
if subscribe.type != MediaType.TV.value:
|
||||
continue
|
||||
@@ -605,11 +608,12 @@ def arr_series_lookup(term: str, _: Annotated[str, Depends(verify_apikey)], db:
|
||||
|
||||
|
||||
@arr_router.get("/series/{tid}", summary="剧集详情")
|
||||
def arr_serie(tid: int, _: Annotated[str, Depends(verify_apikey)], db: Session = Depends(get_db)) -> Any:
|
||||
async def arr_serie(tid: int, _: Annotated[str, Depends(verify_apikey)],
|
||||
db: AsyncSession = Depends(get_async_db)) -> Any:
|
||||
"""
|
||||
查询Sonarr剧集
|
||||
"""
|
||||
subscribe = Subscribe.get(db, tid)
|
||||
subscribe = await Subscribe.async_get(db, tid)
|
||||
if subscribe:
|
||||
return SonarrSeries(
|
||||
id=subscribe.id,
|
||||
@@ -639,17 +643,17 @@ def arr_serie(tid: int, _: Annotated[str, Depends(verify_apikey)], db: Session =
|
||||
|
||||
|
||||
@arr_router.post("/series", summary="新增剧集订阅")
|
||||
def arr_add_series(tv: schemas.SonarrSeries,
|
||||
_: Annotated[str, Depends(verify_apikey)],
|
||||
db: Session = Depends(get_db)) -> Any:
|
||||
async def arr_add_series(tv: schemas.SonarrSeries,
|
||||
_: Annotated[str, Depends(verify_apikey)],
|
||||
db: AsyncSession = Depends(get_async_db)) -> Any:
|
||||
"""
|
||||
新增Sonarr剧集订阅
|
||||
"""
|
||||
# 检查订阅是否存在
|
||||
left_seasons = []
|
||||
for season in tv.seasons:
|
||||
subscribe = Subscribe.get_by_tmdbid(db, tmdbid=tv.tmdbId,
|
||||
season=season.get("seasonNumber"))
|
||||
subscribe = await Subscribe.async_get_by_tmdbid(db, tmdbid=tv.tmdbId,
|
||||
season=season.get("seasonNumber"))
|
||||
if subscribe:
|
||||
continue
|
||||
left_seasons.append(season)
|
||||
@@ -664,12 +668,12 @@ def arr_add_series(tv: schemas.SonarrSeries,
|
||||
for season in left_seasons:
|
||||
if not season.get("monitored"):
|
||||
continue
|
||||
sid, message = SubscribeChain().add(title=tv.title,
|
||||
year=tv.year,
|
||||
season=season.get("seasonNumber"),
|
||||
tmdbid=tv.tmdbId,
|
||||
mtype=MediaType.TV,
|
||||
username="Seerr")
|
||||
sid, message = await SubscribeChain().async_add(title=tv.title,
|
||||
year=tv.year,
|
||||
season=season.get("seasonNumber"),
|
||||
tmdbid=tv.tmdbId,
|
||||
mtype=MediaType.TV,
|
||||
username="Seerr")
|
||||
|
||||
if sid:
|
||||
return {
|
||||
@@ -683,21 +687,22 @@ def arr_add_series(tv: schemas.SonarrSeries,
|
||||
|
||||
|
||||
@arr_router.put("/series", summary="更新剧集订阅")
|
||||
def arr_update_series(tv: schemas.SonarrSeries) -> Any:
|
||||
async def arr_update_series(tv: schemas.SonarrSeries, _: Annotated[str, Depends(verify_apikey)]) -> Any:
|
||||
"""
|
||||
更新Sonarr剧集订阅
|
||||
"""
|
||||
return arr_add_series(tv)
|
||||
return await arr_add_series(tv)
|
||||
|
||||
|
||||
@arr_router.delete("/series/{tid}", summary="删除剧集订阅")
|
||||
def arr_remove_series(tid: int, _: Annotated[str, Depends(verify_apikey)], db: Session = Depends(get_db)) -> Any:
|
||||
async def arr_remove_series(tid: int, _: Annotated[str, Depends(verify_apikey)],
|
||||
db: AsyncSession = Depends(get_async_db)) -> Any:
|
||||
"""
|
||||
删除Sonarr剧集订阅
|
||||
"""
|
||||
subscribe = Subscribe.get(db, tid)
|
||||
subscribe = await Subscribe.async_get(db, tid)
|
||||
if subscribe:
|
||||
subscribe.delete(db, tid)
|
||||
await subscribe.async_delete(db, tid)
|
||||
return schemas.Response(success=True)
|
||||
else:
|
||||
raise HTTPException(
|
||||
|
||||
@@ -2,6 +2,8 @@ import gzip
|
||||
import json
|
||||
from typing import Annotated, Callable, Any, Dict, Optional
|
||||
|
||||
import aiofiles
|
||||
from anyio import Path as AsyncPath
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, Request, Response
|
||||
from fastapi.responses import PlainTextResponse
|
||||
from fastapi.routing import APIRoute
|
||||
@@ -19,7 +21,7 @@ class GzipRequest(Request):
|
||||
body = await super().body()
|
||||
if "gzip" in self.headers.getlist("Content-Encoding"):
|
||||
body = gzip.decompress(body)
|
||||
self._body = body # noqa
|
||||
self._body = body # noqa
|
||||
return self._body
|
||||
|
||||
|
||||
@@ -50,12 +52,12 @@ cookie_router = APIRouter(route_class=GzipRoute,
|
||||
|
||||
|
||||
@cookie_router.get("/", response_class=PlainTextResponse)
|
||||
def get_root():
|
||||
async def get_root():
|
||||
return "Hello MoviePilot! COOKIECLOUD API ROOT = /cookiecloud"
|
||||
|
||||
|
||||
@cookie_router.post("/", response_class=PlainTextResponse)
|
||||
def post_root():
|
||||
async def post_root():
|
||||
return "Hello MoviePilot! COOKIECLOUD API ROOT = /cookiecloud"
|
||||
|
||||
|
||||
@@ -64,31 +66,31 @@ async def update_cookie(req: schemas.CookieData):
|
||||
"""
|
||||
上传Cookie数据
|
||||
"""
|
||||
file_path = settings.COOKIE_PATH / f"{req.uuid}.json"
|
||||
file_path = AsyncPath(settings.COOKIE_PATH) / f"{req.uuid}.json"
|
||||
content = json.dumps({"encrypted": req.encrypted})
|
||||
with open(file_path, encoding="utf-8", mode="w") as file:
|
||||
file.write(content)
|
||||
with open(file_path, encoding="utf-8", mode="r") as file:
|
||||
read_content = file.read()
|
||||
async with aiofiles.open(file_path, encoding="utf-8", mode="w") as file:
|
||||
await file.write(content)
|
||||
async with aiofiles.open(file_path, encoding="utf-8", mode="r") as file:
|
||||
read_content = await file.read()
|
||||
if read_content == content:
|
||||
return {"action": "done"}
|
||||
else:
|
||||
return {"action": "error"}
|
||||
|
||||
|
||||
def load_encrypt_data(uuid: str) -> Dict[str, Any]:
|
||||
async def load_encrypt_data(uuid: str) -> Dict[str, Any]:
|
||||
"""
|
||||
加载本地加密原始数据
|
||||
"""
|
||||
file_path = settings.COOKIE_PATH / f"{uuid}.json"
|
||||
file_path = AsyncPath(settings.COOKIE_PATH) / f"{uuid}.json"
|
||||
|
||||
# 检查文件是否存在
|
||||
if not file_path.exists():
|
||||
raise HTTPException(status_code=404, detail="Item not found")
|
||||
|
||||
# 读取文件
|
||||
with open(file_path, encoding="utf-8", mode="r") as file:
|
||||
read_content = file.read()
|
||||
async with aiofiles.open(file_path, encoding="utf-8", mode="r") as file:
|
||||
read_content = await file.read()
|
||||
data = json.loads(read_content.encode("utf-8"))
|
||||
return data
|
||||
|
||||
@@ -120,7 +122,7 @@ async def get_cookie(
|
||||
"""
|
||||
GET 下载加密数据
|
||||
"""
|
||||
return load_encrypt_data(uuid)
|
||||
return await load_encrypt_data(uuid)
|
||||
|
||||
|
||||
@cookie_router.post("/get/{uuid}")
|
||||
@@ -130,5 +132,5 @@ async def post_cookie(
|
||||
"""
|
||||
POST 下载加密数据
|
||||
"""
|
||||
data = load_encrypt_data(uuid)
|
||||
data = await load_encrypt_data(uuid)
|
||||
return get_decrypted_cookie_data(uuid, request.password, data["encrypted"])
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import copy
|
||||
import inspect
|
||||
import pickle
|
||||
import traceback
|
||||
from abc import ABCMeta
|
||||
@@ -6,9 +7,11 @@ from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Optional, Any, Tuple, List, Set, Union, Dict
|
||||
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from qbittorrentapi import TorrentFilesList
|
||||
from transmission_rpc import File
|
||||
|
||||
from app.core.cache import FileCache, AsyncFileCache, fresh, async_fresh
|
||||
from app.core.config import settings
|
||||
from app.core.context import Context, MediaInfo, TorrentInfo
|
||||
from app.core.event import EventManager
|
||||
@@ -43,58 +46,127 @@ class ChainBase(metaclass=ABCMeta):
|
||||
send_callback=self.run_module
|
||||
)
|
||||
self.pluginmanager = PluginManager()
|
||||
self.filecache = FileCache()
|
||||
self.async_filecache = AsyncFileCache()
|
||||
|
||||
@staticmethod
|
||||
def load_cache(filename: str) -> Any:
|
||||
def load_cache(self, filename: str) -> Any:
|
||||
"""
|
||||
从本地加载缓存
|
||||
加载缓存
|
||||
"""
|
||||
cache_path = settings.TEMP_PATH / filename
|
||||
if cache_path.exists():
|
||||
try:
|
||||
with open(cache_path, 'rb') as f:
|
||||
return pickle.load(f)
|
||||
except Exception as err:
|
||||
logger.error(f"加载缓存 {filename} 出错:{str(err)}")
|
||||
return None
|
||||
content = self.filecache.get(filename)
|
||||
if not content:
|
||||
return None
|
||||
try:
|
||||
return pickle.loads(content)
|
||||
except Exception as err:
|
||||
logger.error(f"加载缓存 {filename} 出错:{str(err)}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def save_cache(cache: Any, filename: str) -> None:
|
||||
async def async_load_cache(self, filename: str) -> Any:
|
||||
"""
|
||||
保存缓存到本地
|
||||
异步加载缓存
|
||||
"""
|
||||
content = await self.async_filecache.get(filename)
|
||||
if not content:
|
||||
return None
|
||||
try:
|
||||
return pickle.loads(content)
|
||||
except Exception as err:
|
||||
logger.error(f"异步加载缓存 {filename} 出错:{str(err)}")
|
||||
return None
|
||||
|
||||
async def async_save_cache(self, cache: Any, filename: str) -> None:
|
||||
"""
|
||||
异步保存缓存
|
||||
"""
|
||||
try:
|
||||
with open(settings.TEMP_PATH / filename, 'wb') as f:
|
||||
pickle.dump(cache, f) # noqa
|
||||
await self.async_filecache.set(filename, pickle.dumps(cache))
|
||||
except Exception as err:
|
||||
logger.error(f"异步保存缓存 {filename} 出错:{str(err)}")
|
||||
return
|
||||
|
||||
def save_cache(self, cache: Any, filename: str) -> None:
|
||||
"""
|
||||
保存缓存
|
||||
"""
|
||||
try:
|
||||
self.filecache.set(filename, pickle.dumps(cache))
|
||||
except Exception as err:
|
||||
logger.error(f"保存缓存 {filename} 出错:{str(err)}")
|
||||
return
|
||||
|
||||
def remove_cache(self, filename: str) -> None:
|
||||
"""
|
||||
删除缓存,同时删除Redis和本地缓存
|
||||
"""
|
||||
self.filecache.delete(filename)
|
||||
|
||||
async def async_remove_cache(self, filename: str) -> None:
|
||||
"""
|
||||
异步删除缓存,同时删除Redis和本地缓存
|
||||
"""
|
||||
await self.async_filecache.delete(filename)
|
||||
|
||||
@staticmethod
|
||||
def remove_cache(filename: str) -> None:
|
||||
def __is_valid_empty(ret):
|
||||
"""
|
||||
删除本地缓存
|
||||
判断结果是否为空
|
||||
"""
|
||||
cache_path = settings.TEMP_PATH / filename
|
||||
if cache_path.exists():
|
||||
cache_path.unlink()
|
||||
if isinstance(ret, tuple):
|
||||
return all(value is None for value in ret)
|
||||
else:
|
||||
return ret is None
|
||||
|
||||
def run_module(self, method: str, *args, **kwargs) -> Any:
|
||||
def __handle_plugin_error(self, err: Exception, plugin_id: str, plugin_name: str, method: str, **kwargs):
|
||||
"""
|
||||
运行包含该方法的所有模块,然后返回结果
|
||||
当kwargs包含命名参数raise_exception时,如模块方法抛出异常且raise_exception为True,则同步抛出异常
|
||||
处理插件模块执行错误
|
||||
"""
|
||||
if kwargs.get("raise_exception"):
|
||||
raise
|
||||
logger.error(
|
||||
f"运行插件 {plugin_id} 模块 {method} 出错:{str(err)}\n{traceback.format_exc()}")
|
||||
self.messagehelper.put(title=f"{plugin_name} 发生了错误",
|
||||
message=str(err),
|
||||
role="plugin")
|
||||
self.eventmanager.send_event(
|
||||
EventType.SystemError,
|
||||
{
|
||||
"type": "plugin",
|
||||
"plugin_id": plugin_id,
|
||||
"plugin_name": plugin_name,
|
||||
"plugin_method": method,
|
||||
"error": str(err),
|
||||
"traceback": traceback.format_exc()
|
||||
}
|
||||
)
|
||||
|
||||
def is_result_empty(ret):
|
||||
"""
|
||||
判断结果是否为空
|
||||
"""
|
||||
if isinstance(ret, tuple):
|
||||
return all(value is None for value in ret)
|
||||
else:
|
||||
return ret is None
|
||||
def __handle_system_error(self, err: Exception, module_id: str, module_name: str, method: str, **kwargs):
|
||||
"""
|
||||
处理系统模块执行错误
|
||||
"""
|
||||
if kwargs.get("raise_exception"):
|
||||
raise
|
||||
logger.error(
|
||||
f"运行模块 {module_id}.{method} 出错:{str(err)}\n{traceback.format_exc()}")
|
||||
self.messagehelper.put(title=f"{module_name}发生了错误",
|
||||
message=str(err),
|
||||
role="system")
|
||||
self.eventmanager.send_event(
|
||||
EventType.SystemError,
|
||||
{
|
||||
"type": "module",
|
||||
"module_id": module_id,
|
||||
"module_name": module_name,
|
||||
"module_method": method,
|
||||
"error": str(err),
|
||||
"traceback": traceback.format_exc()
|
||||
}
|
||||
)
|
||||
|
||||
result = None
|
||||
# 插件模块
|
||||
def __execute_plugin_modules(self, method: str, result: Any, *args, **kwargs) -> Any:
|
||||
"""
|
||||
执行插件模块
|
||||
"""
|
||||
for plugin, module_dict in self.pluginmanager.get_plugin_modules().items():
|
||||
plugin_id, plugin_name = plugin
|
||||
if method in module_dict:
|
||||
@@ -102,7 +174,7 @@ class ChainBase(metaclass=ABCMeta):
|
||||
if func:
|
||||
try:
|
||||
logger.info(f"请求插件 {plugin_name} 执行:{method} ...")
|
||||
if is_result_empty(result):
|
||||
if self.__is_valid_empty(result):
|
||||
# 返回None,第一次执行或者需继续执行下一模块
|
||||
result = func(*args, **kwargs)
|
||||
elif isinstance(result, list):
|
||||
@@ -113,29 +185,46 @@ class ChainBase(metaclass=ABCMeta):
|
||||
else:
|
||||
break
|
||||
except Exception as err:
|
||||
if kwargs.get("raise_exception"):
|
||||
raise
|
||||
logger.error(
|
||||
f"运行插件 {plugin_id} 模块 {method} 出错:{str(err)}\n{traceback.format_exc()}")
|
||||
self.messagehelper.put(title=f"{plugin_name} 发生了错误",
|
||||
message=str(err),
|
||||
role="plugin")
|
||||
self.eventmanager.send_event(
|
||||
EventType.SystemError,
|
||||
{
|
||||
"type": "plugin",
|
||||
"plugin_id": plugin_id,
|
||||
"plugin_name": plugin_name,
|
||||
"plugin_method": method,
|
||||
"error": str(err),
|
||||
"traceback": traceback.format_exc()
|
||||
}
|
||||
)
|
||||
if not is_result_empty(result) and not isinstance(result, list):
|
||||
# 插件模块返回结果不为空且不是列表,直接返回
|
||||
return result
|
||||
self.__handle_plugin_error(err, plugin_id, plugin_name, method, **kwargs)
|
||||
return result
|
||||
|
||||
# 系统模块
|
||||
async def __async_execute_plugin_modules(self, method: str, result: Any, *args, **kwargs) -> Any:
|
||||
"""
|
||||
异步执行插件模块
|
||||
"""
|
||||
for plugin, module_dict in self.pluginmanager.get_plugin_modules().items():
|
||||
plugin_id, plugin_name = plugin
|
||||
if method in module_dict:
|
||||
func = module_dict[method]
|
||||
if func:
|
||||
try:
|
||||
logger.info(f"请求插件 {plugin_name} 执行:{method} ...")
|
||||
if self.__is_valid_empty(result):
|
||||
# 返回None,第一次执行或者需继续执行下一模块
|
||||
if inspect.iscoroutinefunction(func):
|
||||
result = await func(*args, **kwargs)
|
||||
else:
|
||||
# 插件同步函数在异步环境中运行,避免阻塞
|
||||
result = await run_in_threadpool(func, *args, **kwargs)
|
||||
elif isinstance(result, list):
|
||||
# 返回为列表,有多个模块运行结果时进行合并
|
||||
if inspect.iscoroutinefunction(func):
|
||||
temp = await func(*args, **kwargs)
|
||||
else:
|
||||
# 插件同步函数在异步环境中运行,避免阻塞
|
||||
temp = await run_in_threadpool(func, *args, **kwargs)
|
||||
if isinstance(temp, list):
|
||||
result.extend(temp)
|
||||
else:
|
||||
break
|
||||
except Exception as err:
|
||||
self.__handle_plugin_error(err, plugin_id, plugin_name, method, **kwargs)
|
||||
return result
|
||||
|
||||
def __execute_system_modules(self, method: str, result: Any, *args, **kwargs) -> Any:
|
||||
"""
|
||||
执行系统模块
|
||||
"""
|
||||
logger.debug(f"请求系统模块执行:{method} ...")
|
||||
for module in sorted(self.modulemanager.get_running_modules(method), key=lambda x: x.get_priority()):
|
||||
module_id = module.__class__.__name__
|
||||
@@ -146,7 +235,7 @@ class ChainBase(metaclass=ABCMeta):
|
||||
module_name = module_id
|
||||
try:
|
||||
func = getattr(module, method)
|
||||
if is_result_empty(result):
|
||||
if self.__is_valid_empty(result):
|
||||
# 返回None,第一次执行或者需继续执行下一模块
|
||||
result = func(*args, **kwargs)
|
||||
elif ObjectUtils.check_signature(func, result):
|
||||
@@ -161,26 +250,85 @@ class ChainBase(metaclass=ABCMeta):
|
||||
# 中止继续执行
|
||||
break
|
||||
except Exception as err:
|
||||
if kwargs.get("raise_exception"):
|
||||
raise
|
||||
logger.error(
|
||||
f"运行模块 {module_id}.{method} 出错:{str(err)}\n{traceback.format_exc()}")
|
||||
self.messagehelper.put(title=f"{module_name}发生了错误",
|
||||
message=str(err),
|
||||
role="system")
|
||||
self.eventmanager.send_event(
|
||||
EventType.SystemError,
|
||||
{
|
||||
"type": "module",
|
||||
"module_id": module_id,
|
||||
"module_name": module_name,
|
||||
"module_method": method,
|
||||
"error": str(err),
|
||||
"traceback": traceback.format_exc()
|
||||
}
|
||||
)
|
||||
self.__handle_system_error(err, module_id, module_name, method, **kwargs)
|
||||
return result
|
||||
|
||||
async def __async_execute_system_modules(self, method: str, result: Any, *args, **kwargs) -> Any:
|
||||
"""
|
||||
异步执行系统模块
|
||||
"""
|
||||
logger.debug(f"请求系统模块执行:{method} ...")
|
||||
for module in sorted(self.modulemanager.get_running_modules(method), key=lambda x: x.get_priority()):
|
||||
module_id = module.__class__.__name__
|
||||
try:
|
||||
module_name = module.get_name()
|
||||
except Exception as err:
|
||||
logger.debug(f"获取模块名称出错:{str(err)}")
|
||||
module_name = module_id
|
||||
try:
|
||||
func = getattr(module, method)
|
||||
if self.__is_valid_empty(result):
|
||||
# 返回None,第一次执行或者需继续执行下一模块
|
||||
if inspect.iscoroutinefunction(func):
|
||||
result = await func(*args, **kwargs)
|
||||
else:
|
||||
result = func(*args, **kwargs)
|
||||
elif ObjectUtils.check_signature(func, result):
|
||||
# 返回结果与方法签名一致,将结果传入
|
||||
if inspect.iscoroutinefunction(func):
|
||||
result = await func(result)
|
||||
else:
|
||||
result = func(result)
|
||||
elif isinstance(result, list):
|
||||
# 返回为列表,有多个模块运行结果时进行合并
|
||||
if inspect.iscoroutinefunction(func):
|
||||
temp = await func(*args, **kwargs)
|
||||
else:
|
||||
temp = func(*args, **kwargs)
|
||||
if isinstance(temp, list):
|
||||
result.extend(temp)
|
||||
else:
|
||||
# 中止继续执行
|
||||
break
|
||||
except Exception as err:
|
||||
self.__handle_system_error(err, module_id, module_name, method, **kwargs)
|
||||
return result
|
||||
|
||||
def run_module(self, method: str, *args, **kwargs) -> Any:
|
||||
"""
|
||||
运行包含该方法的所有模块,然后返回结果
|
||||
当kwargs包含命名参数raise_exception时,如模块方法抛出异常且raise_exception为True,则同步抛出异常
|
||||
"""
|
||||
result = None
|
||||
|
||||
# 执行插件模块
|
||||
result = self.__execute_plugin_modules(method, result, *args, **kwargs)
|
||||
|
||||
if not self.__is_valid_empty(result) and not isinstance(result, list):
|
||||
# 插件模块返回结果不为空且不是列表,直接返回
|
||||
return result
|
||||
|
||||
# 执行系统模块
|
||||
return self.__execute_system_modules(method, result, *args, **kwargs)
|
||||
|
||||
async def async_run_module(self, method: str, *args, **kwargs) -> Any:
|
||||
"""
|
||||
异步运行包含该方法的所有模块,然后返回结果
|
||||
当kwargs包含命名参数raise_exception时,如模块方法抛出异常且raise_exception为True,则同步抛出异常
|
||||
支持异步和同步方法的混合调用
|
||||
"""
|
||||
result = None
|
||||
|
||||
# 执行插件模块
|
||||
result = await self.__async_execute_plugin_modules(method, result, *args, **kwargs)
|
||||
|
||||
if not self.__is_valid_empty(result) and not isinstance(result, list):
|
||||
# 插件模块返回结果不为空且不是列表,直接返回
|
||||
return result
|
||||
|
||||
# 执行系统模块
|
||||
return await self.__async_execute_system_modules(method, result, *args, **kwargs)
|
||||
|
||||
def recognize_media(self, meta: MetaBase = None,
|
||||
mtype: Optional[MediaType] = None,
|
||||
tmdbid: Optional[int] = None,
|
||||
@@ -210,9 +358,44 @@ class ChainBase(metaclass=ABCMeta):
|
||||
if tmdbid:
|
||||
doubanid = None
|
||||
bangumiid = None
|
||||
return self.run_module("recognize_media", meta=meta, mtype=mtype,
|
||||
tmdbid=tmdbid, doubanid=doubanid, bangumiid=bangumiid,
|
||||
episode_group=episode_group, cache=cache)
|
||||
with fresh(not cache):
|
||||
return self.run_module("recognize_media", meta=meta, mtype=mtype,
|
||||
tmdbid=tmdbid, doubanid=doubanid, bangumiid=bangumiid,
|
||||
episode_group=episode_group, cache=cache)
|
||||
|
||||
async def async_recognize_media(self, meta: MetaBase = None,
|
||||
mtype: Optional[MediaType] = None,
|
||||
tmdbid: Optional[int] = None,
|
||||
doubanid: Optional[str] = None,
|
||||
bangumiid: Optional[int] = None,
|
||||
episode_group: Optional[str] = None,
|
||||
cache: bool = True) -> Optional[MediaInfo]:
|
||||
"""
|
||||
识别媒体信息,不含Fanart图片(异步版本)
|
||||
:param meta: 识别的元数据
|
||||
:param mtype: 识别的媒体类型,与tmdbid配套
|
||||
:param tmdbid: tmdbid
|
||||
:param doubanid: 豆瓣ID
|
||||
:param bangumiid: BangumiID
|
||||
:param episode_group: 剧集组
|
||||
:param cache: 是否使用缓存
|
||||
:return: 识别的媒体信息,包括剧集信息
|
||||
"""
|
||||
# 识别用名中含指定信息情形
|
||||
if not mtype and meta and meta.type in [MediaType.TV, MediaType.MOVIE]:
|
||||
mtype = meta.type
|
||||
if not tmdbid and hasattr(meta, "tmdbid"):
|
||||
tmdbid = meta.tmdbid
|
||||
if not doubanid and hasattr(meta, "doubanid"):
|
||||
doubanid = meta.doubanid
|
||||
# 有tmdbid时不使用其它ID
|
||||
if tmdbid:
|
||||
doubanid = None
|
||||
bangumiid = None
|
||||
async with async_fresh(not cache):
|
||||
return await self.async_run_module("async_recognize_media", meta=meta, mtype=mtype,
|
||||
tmdbid=tmdbid, doubanid=doubanid, bangumiid=bangumiid,
|
||||
episode_group=episode_group, cache=cache)
|
||||
|
||||
def match_doubaninfo(self, name: str, imdbid: Optional[str] = None,
|
||||
mtype: Optional[MediaType] = None, year: Optional[str] = None, season: Optional[int] = None,
|
||||
@@ -229,6 +412,22 @@ class ChainBase(metaclass=ABCMeta):
|
||||
return self.run_module("match_doubaninfo", name=name, imdbid=imdbid,
|
||||
mtype=mtype, year=year, season=season, raise_exception=raise_exception)
|
||||
|
||||
async def async_match_doubaninfo(self, name: str, imdbid: Optional[str] = None,
|
||||
mtype: Optional[MediaType] = None, year: Optional[str] = None,
|
||||
season: Optional[int] = None,
|
||||
raise_exception: bool = False) -> Optional[dict]:
|
||||
"""
|
||||
搜索和匹配豆瓣信息(异步版本)
|
||||
:param name: 标题
|
||||
:param imdbid: imdbid
|
||||
:param mtype: 类型
|
||||
:param year: 年份
|
||||
:param season: 季
|
||||
:param raise_exception: 触发速率限制时是否抛出异常
|
||||
"""
|
||||
return await self.async_run_module("async_match_doubaninfo", name=name, imdbid=imdbid,
|
||||
mtype=mtype, year=year, season=season, raise_exception=raise_exception)
|
||||
|
||||
def match_tmdbinfo(self, name: str, mtype: Optional[MediaType] = None,
|
||||
year: Optional[str] = None, season: Optional[int] = None) -> Optional[dict]:
|
||||
"""
|
||||
@@ -241,6 +440,18 @@ class ChainBase(metaclass=ABCMeta):
|
||||
return self.run_module("match_tmdbinfo", name=name,
|
||||
mtype=mtype, year=year, season=season)
|
||||
|
||||
async def async_match_tmdbinfo(self, name: str, mtype: Optional[MediaType] = None,
|
||||
year: Optional[str] = None, season: Optional[int] = None) -> Optional[dict]:
|
||||
"""
|
||||
搜索和匹配TMDB信息(异步版本)
|
||||
:param name: 标题
|
||||
:param mtype: 类型
|
||||
:param year: 年份
|
||||
:param season: 季
|
||||
"""
|
||||
return await self.async_run_module("async_match_tmdbinfo", name=name,
|
||||
mtype=mtype, year=year, season=season)
|
||||
|
||||
def obtain_images(self, mediainfo: MediaInfo) -> Optional[MediaInfo]:
|
||||
"""
|
||||
补充抓取媒体信息图片
|
||||
@@ -249,6 +460,14 @@ class ChainBase(metaclass=ABCMeta):
|
||||
"""
|
||||
return self.run_module("obtain_images", mediainfo=mediainfo)
|
||||
|
||||
async def async_obtain_images(self, mediainfo: MediaInfo) -> Optional[MediaInfo]:
|
||||
"""
|
||||
补充抓取媒体信息图片(异步版本)
|
||||
:param mediainfo: 识别的媒体信息
|
||||
:return: 更新后的媒体信息
|
||||
"""
|
||||
return await self.async_run_module("async_obtain_images", mediainfo=mediainfo)
|
||||
|
||||
def obtain_specific_image(self, mediaid: Union[str, int], mtype: MediaType,
|
||||
image_type: MediaImageType, image_prefix: Optional[str] = None,
|
||||
season: Optional[int] = None, episode: Optional[int] = None) -> Optional[str]:
|
||||
@@ -276,6 +495,18 @@ class ChainBase(metaclass=ABCMeta):
|
||||
"""
|
||||
return self.run_module("douban_info", doubanid=doubanid, mtype=mtype, raise_exception=raise_exception)
|
||||
|
||||
async def async_douban_info(self, doubanid: str, mtype: Optional[MediaType] = None,
|
||||
raise_exception: bool = False) -> Optional[dict]:
|
||||
"""
|
||||
获取豆瓣信息(异步版本)
|
||||
:param doubanid: 豆瓣ID
|
||||
:param mtype: 媒体类型
|
||||
:return: 豆瓣信息
|
||||
:param raise_exception: 触发速率限制时是否抛出异常
|
||||
"""
|
||||
return await self.async_run_module("async_douban_info", doubanid=doubanid, mtype=mtype,
|
||||
raise_exception=raise_exception)
|
||||
|
||||
def tvdb_info(self, tvdbid: int) -> Optional[dict]:
|
||||
"""
|
||||
获取TVDB信息
|
||||
@@ -294,6 +525,16 @@ class ChainBase(metaclass=ABCMeta):
|
||||
"""
|
||||
return self.run_module("tmdb_info", tmdbid=tmdbid, mtype=mtype, season=season)
|
||||
|
||||
async def async_tmdb_info(self, tmdbid: int, mtype: MediaType, season: Optional[int] = None) -> Optional[dict]:
|
||||
"""
|
||||
获取TMDB信息(异步版本)
|
||||
:param tmdbid: int
|
||||
:param mtype: 媒体类型
|
||||
:param season: 季
|
||||
:return: TVDB信息
|
||||
"""
|
||||
return await self.async_run_module("async_tmdb_info", tmdbid=tmdbid, mtype=mtype, season=season)
|
||||
|
||||
def bangumi_info(self, bangumiid: int) -> Optional[dict]:
|
||||
"""
|
||||
获取Bangumi信息
|
||||
@@ -302,6 +543,14 @@ class ChainBase(metaclass=ABCMeta):
|
||||
"""
|
||||
return self.run_module("bangumi_info", bangumiid=bangumiid)
|
||||
|
||||
async def async_bangumi_info(self, bangumiid: int) -> Optional[dict]:
|
||||
"""
|
||||
获取Bangumi信息(异步版本)
|
||||
:param bangumiid: int
|
||||
:return: Bangumi信息
|
||||
"""
|
||||
return await self.async_run_module("async_bangumi_info", bangumiid=bangumiid)
|
||||
|
||||
def message_parser(self, source: str, body: Any, form: Any,
|
||||
args: Any) -> Optional[CommingMessage]:
|
||||
"""
|
||||
@@ -335,6 +584,14 @@ class ChainBase(metaclass=ABCMeta):
|
||||
"""
|
||||
return self.run_module("search_medias", meta=meta)
|
||||
|
||||
async def async_search_medias(self, meta: MetaBase) -> Optional[List[MediaInfo]]:
|
||||
"""
|
||||
搜索媒体信息(异步版本)
|
||||
:param meta: 识别的元数据
|
||||
:reutrn: 媒体信息列表
|
||||
"""
|
||||
return await self.async_run_module("async_search_medias", meta=meta)
|
||||
|
||||
def search_persons(self, name: str) -> Optional[List[MediaPerson]]:
|
||||
"""
|
||||
搜索人物信息
|
||||
@@ -342,6 +599,13 @@ class ChainBase(metaclass=ABCMeta):
|
||||
"""
|
||||
return self.run_module("search_persons", name=name)
|
||||
|
||||
async def async_search_persons(self, name: str) -> Optional[List[MediaPerson]]:
|
||||
"""
|
||||
搜索人物信息(异步版本)
|
||||
:param name: 人物名称
|
||||
"""
|
||||
return await self.async_run_module("async_search_persons", name=name)
|
||||
|
||||
def search_collections(self, name: str) -> Optional[List[MediaInfo]]:
|
||||
"""
|
||||
搜索集合信息
|
||||
@@ -349,21 +613,43 @@ class ChainBase(metaclass=ABCMeta):
|
||||
"""
|
||||
return self.run_module("search_collections", name=name)
|
||||
|
||||
async def async_search_collections(self, name: str) -> Optional[List[MediaInfo]]:
|
||||
"""
|
||||
搜索集合信息(异步版本)
|
||||
:param name: 集合名称
|
||||
"""
|
||||
return await self.async_run_module("async_search_collections", name=name)
|
||||
|
||||
def search_torrents(self, site: dict,
|
||||
keywords: List[str],
|
||||
keyword: str,
|
||||
mtype: Optional[MediaType] = None,
|
||||
page: Optional[int] = 0) -> List[TorrentInfo]:
|
||||
"""
|
||||
搜索一个站点的种子资源
|
||||
:param site: 站点
|
||||
:param keywords: 搜索关键词列表
|
||||
:param keyword: 搜索关键词
|
||||
:param mtype: 媒体类型
|
||||
:param page: 页码
|
||||
:reutrn: 资源列表
|
||||
"""
|
||||
return self.run_module("search_torrents", site=site, keywords=keywords,
|
||||
return self.run_module("search_torrents", site=site, keyword=keyword,
|
||||
mtype=mtype, page=page)
|
||||
|
||||
async def async_search_torrents(self, site: dict,
|
||||
keyword: str,
|
||||
mtype: Optional[MediaType] = None,
|
||||
page: Optional[int] = 0) -> List[TorrentInfo]:
|
||||
"""
|
||||
异步搜索一个站点的种子资源
|
||||
:param site: 站点
|
||||
:param keyword: 搜索关键词
|
||||
:param mtype: 媒体类型
|
||||
:param page: 页码
|
||||
:reutrn: 资源列表
|
||||
"""
|
||||
return await self.async_run_module("async_search_torrents", site=site, keyword=keyword,
|
||||
mtype=mtype, page=page)
|
||||
|
||||
def refresh_torrents(self, site: dict, keyword: Optional[str] = None,
|
||||
cat: Optional[str] = None, page: Optional[int] = 0) -> List[TorrentInfo]:
|
||||
"""
|
||||
@@ -376,6 +662,19 @@ class ChainBase(metaclass=ABCMeta):
|
||||
"""
|
||||
return self.run_module("refresh_torrents", site=site, keyword=keyword, cat=cat, page=page)
|
||||
|
||||
async def async_refresh_torrents(self, site: dict, keyword: Optional[str] = None,
|
||||
cat: Optional[str] = None, page: Optional[int] = 0) -> List[TorrentInfo]:
|
||||
"""
|
||||
异步获取站点最新一页的种子,多个站点需要多线程处理
|
||||
:param site: 站点
|
||||
:param keyword: 标题
|
||||
:param cat: 分类
|
||||
:param page: 页码
|
||||
:reutrn: 种子资源列表
|
||||
"""
|
||||
return await self.async_run_module("async_refresh_torrents",
|
||||
site=site, keyword=keyword, cat=cat, page=page)
|
||||
|
||||
def filter_torrents(self, rule_groups: List[str],
|
||||
torrent_list: List[TorrentInfo],
|
||||
mediainfo: MediaInfo = None) -> List[TorrentInfo]:
|
||||
@@ -389,13 +688,13 @@ class ChainBase(metaclass=ABCMeta):
|
||||
return self.run_module("filter_torrents", rule_groups=rule_groups,
|
||||
torrent_list=torrent_list, mediainfo=mediainfo)
|
||||
|
||||
def download(self, content: Union[Path, str], download_dir: Path, cookie: str,
|
||||
def download(self, content: Union[Path, str, bytes], download_dir: Path, cookie: str,
|
||||
episodes: Set[int] = None, category: Optional[str] = None, label: Optional[str] = None,
|
||||
downloader: Optional[str] = None
|
||||
) -> Optional[Tuple[Optional[str], Optional[str], Optional[str], str]]:
|
||||
"""
|
||||
根据种子文件,选择并添加下载任务
|
||||
:param content: 种子文件地址或者磁力链接
|
||||
:param content: 种子文件地址或者磁力链接或者种子内容
|
||||
:param download_dir: 下载目录
|
||||
:param cookie: cookie
|
||||
:param episodes: 需要下载的集数
|
||||
@@ -408,15 +707,16 @@ class ChainBase(metaclass=ABCMeta):
|
||||
cookie=cookie, episodes=episodes, category=category, label=label,
|
||||
downloader=downloader)
|
||||
|
||||
def download_added(self, context: Context, download_dir: Path, torrent_path: Path = None) -> None:
|
||||
def download_added(self, context: Context, download_dir: Path, torrent_content: Union[str, bytes] = None) -> None:
|
||||
"""
|
||||
添加下载任务成功后,从站点下载字幕,保存到下载目录
|
||||
:param context: 上下文,包括识别信息、媒体信息、种子信息
|
||||
:param download_dir: 下载目录
|
||||
:param torrent_path: 种子文件地址
|
||||
:param torrent_content: 种子内容,如果有则直接使用该内容,否则从context中获取种子文件路径
|
||||
:return: None,该方法可被多个模块同时处理
|
||||
"""
|
||||
return self.run_module("download_added", context=context, torrent_path=torrent_path,
|
||||
return self.run_module("download_added", context=context,
|
||||
torrent_content=torrent_content,
|
||||
download_dir=download_dir)
|
||||
|
||||
def list_torrents(self, status: TorrentStatus = None,
|
||||
@@ -552,9 +852,13 @@ class ChainBase(metaclass=ABCMeta):
|
||||
# 渲染消息
|
||||
message = MessageTemplateHelper.render(message=message, meta=meta, mediainfo=mediainfo,
|
||||
torrentinfo=torrentinfo, transferinfo=transferinfo, **kwargs)
|
||||
# 检查消息是否有效
|
||||
if not message:
|
||||
logger.warning("消息为空,跳过发送")
|
||||
return
|
||||
# 保存消息
|
||||
self.messagehelper.put(message, role="user", title=message.title)
|
||||
self.messageoper.add(**message.dict())
|
||||
self.messageoper.add(**message.model_dump())
|
||||
# 发送消息按设置隔离
|
||||
if not message.userid and message.mtype:
|
||||
# 消息隔离设置
|
||||
@@ -601,15 +905,99 @@ class ChainBase(metaclass=ABCMeta):
|
||||
break
|
||||
# 按设定发送
|
||||
self.eventmanager.send_event(etype=EventType.NoticeMessage,
|
||||
data={**send_message.dict(), "type": send_message.mtype})
|
||||
self.messagequeue.send_message("post_message", message=send_message)
|
||||
data={**send_message.model_dump(), "type": send_message.mtype})
|
||||
self.messagequeue.send_message("post_message", message=send_message, **kwargs)
|
||||
if not send_orignal:
|
||||
return
|
||||
# 发送消息事件
|
||||
self.eventmanager.send_event(etype=EventType.NoticeMessage, data={**message.dict(), "type": message.mtype})
|
||||
self.eventmanager.send_event(etype=EventType.NoticeMessage, data={**message.model_dump(), "type": message.mtype})
|
||||
# 按原消息发送
|
||||
self.messagequeue.send_message("post_message", message=message,
|
||||
immediately=True if message.userid else False)
|
||||
immediately=True if message.userid else False, **kwargs)
|
||||
|
||||
async def async_post_message(self,
|
||||
message: Optional[Notification] = None,
|
||||
meta: Optional[MetaBase] = None,
|
||||
mediainfo: Optional[MediaInfo] = None,
|
||||
torrentinfo: Optional[TorrentInfo] = None,
|
||||
transferinfo: Optional[TransferInfo] = None,
|
||||
**kwargs) -> None:
|
||||
"""
|
||||
异步发送消息
|
||||
:param message: Notification实例
|
||||
:param meta: 元数据
|
||||
:param mediainfo: 媒体信息
|
||||
:param torrentinfo: 种子信息
|
||||
:param transferinfo: 文件整理信息
|
||||
:param kwargs: 其他参数(覆盖业务对象属性值)
|
||||
:return: 成功或失败
|
||||
"""
|
||||
# 渲染消息
|
||||
message = MessageTemplateHelper.render(message=message, meta=meta, mediainfo=mediainfo,
|
||||
torrentinfo=torrentinfo, transferinfo=transferinfo, **kwargs)
|
||||
# 检查消息是否有效
|
||||
if not message:
|
||||
logger.warning("消息为空,跳过发送")
|
||||
return
|
||||
# 保存消息
|
||||
self.messagehelper.put(message, role="user", title=message.title)
|
||||
await self.messageoper.async_add(**message.model_dump())
|
||||
# 发送消息按设置隔离
|
||||
if not message.userid and message.mtype:
|
||||
# 消息隔离设置
|
||||
notify_action = ServiceConfigHelper.get_notification_switch(message.mtype)
|
||||
if notify_action:
|
||||
# 'admin' 'user,admin' 'user' 'all'
|
||||
actions = notify_action.split(",")
|
||||
# 是否已发送管理员标志
|
||||
admin_sended = False
|
||||
send_orignal = False
|
||||
useroper = UserOper()
|
||||
for action in actions:
|
||||
send_message = copy.deepcopy(message)
|
||||
if action == "admin" and not admin_sended:
|
||||
# 仅发送管理员
|
||||
logger.info(f"{send_message.mtype} 的消息已设置发送给管理员")
|
||||
# 读取管理员消息IDS
|
||||
send_message.targets = useroper.get_settings(settings.SUPERUSER)
|
||||
admin_sended = True
|
||||
elif action == "user" and send_message.username:
|
||||
# 发送对应用户
|
||||
logger.info(f"{send_message.mtype} 的消息已设置发送给用户 {send_message.username}")
|
||||
# 读取用户消息IDS
|
||||
send_message.targets = useroper.get_settings(send_message.username)
|
||||
if send_message.targets is None:
|
||||
# 没有找到用户
|
||||
if not admin_sended:
|
||||
# 回滚发送管理员
|
||||
logger.info(f"用户 {send_message.username} 不存在,消息将发送给管理员")
|
||||
# 读取管理员消息IDS
|
||||
send_message.targets = useroper.get_settings(settings.SUPERUSER)
|
||||
admin_sended = True
|
||||
else:
|
||||
# 管理员发过了,此消息不发了
|
||||
logger.info(f"用户 {send_message.username} 不存在,消息无法发送到对应用户")
|
||||
continue
|
||||
elif send_message.username == settings.SUPERUSER:
|
||||
# 管理员同名已发送
|
||||
admin_sended = True
|
||||
else:
|
||||
# 按原消息发送全体
|
||||
if not admin_sended:
|
||||
send_orignal = True
|
||||
break
|
||||
# 按设定发送
|
||||
await self.eventmanager.async_send_event(etype=EventType.NoticeMessage,
|
||||
data={**send_message.model_dump(), "type": send_message.mtype})
|
||||
await self.messagequeue.async_send_message("post_message", message=send_message, **kwargs)
|
||||
if not send_orignal:
|
||||
return
|
||||
# 发送消息事件
|
||||
await self.eventmanager.async_send_event(etype=EventType.NoticeMessage,
|
||||
data={**message.model_dump(), "type": message.mtype})
|
||||
# 按原消息发送
|
||||
await self.messagequeue.async_send_message("post_message", message=message,
|
||||
immediately=True if message.userid else False, **kwargs)
|
||||
|
||||
def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None:
|
||||
"""
|
||||
@@ -620,7 +1008,7 @@ class ChainBase(metaclass=ABCMeta):
|
||||
"""
|
||||
note_list = [media.to_dict() for media in medias]
|
||||
self.messagehelper.put(message, role="user", note=note_list, title=message.title)
|
||||
self.messageoper.add(**message.dict(), note=note_list)
|
||||
self.messageoper.add(**message.model_dump(), note=note_list)
|
||||
return self.messagequeue.send_message("post_medias_message", message=message, medias=medias,
|
||||
immediately=True if message.userid else False)
|
||||
|
||||
@@ -633,7 +1021,7 @@ class ChainBase(metaclass=ABCMeta):
|
||||
"""
|
||||
note_list = [torrent.torrent_info.to_dict() for torrent in torrents]
|
||||
self.messagehelper.put(message, role="user", note=note_list, title=message.title)
|
||||
self.messageoper.add(**message.dict(), note=note_list)
|
||||
self.messageoper.add(**message.model_dump(), note=note_list)
|
||||
return self.messagequeue.send_message("post_torrents_message", message=message, torrents=torrents,
|
||||
immediately=True if message.userid else False)
|
||||
|
||||
|
||||
@@ -57,3 +57,51 @@ class BangumiChain(ChainBase):
|
||||
:param person_id: 人物ID
|
||||
"""
|
||||
return self.run_module("bangumi_person_credits", person_id=person_id)
|
||||
|
||||
async def async_calendar(self) -> Optional[List[MediaInfo]]:
|
||||
"""
|
||||
获取Bangumi每日放送(异步版本)
|
||||
"""
|
||||
return await self.async_run_module("async_bangumi_calendar")
|
||||
|
||||
async def async_discover(self, **kwargs) -> Optional[List[MediaInfo]]:
|
||||
"""
|
||||
发现Bangumi番剧(异步版本)
|
||||
"""
|
||||
return await self.async_run_module("async_bangumi_discover", **kwargs)
|
||||
|
||||
async def async_bangumi_info(self, bangumiid: int) -> Optional[dict]:
|
||||
"""
|
||||
获取Bangumi信息(异步版本)
|
||||
:param bangumiid: BangumiID
|
||||
:return: Bangumi信息
|
||||
"""
|
||||
return await self.async_run_module("async_bangumi_info", bangumiid=bangumiid)
|
||||
|
||||
async def async_bangumi_credits(self, bangumiid: int) -> List[schemas.MediaPerson]:
|
||||
"""
|
||||
根据BangumiID查询电影演职员表(异步版本)
|
||||
:param bangumiid: BangumiID
|
||||
"""
|
||||
return await self.async_run_module("async_bangumi_credits", bangumiid=bangumiid)
|
||||
|
||||
async def async_bangumi_recommend(self, bangumiid: int) -> Optional[List[MediaInfo]]:
|
||||
"""
|
||||
根据BangumiID查询推荐电影(异步版本)
|
||||
:param bangumiid: BangumiID
|
||||
"""
|
||||
return await self.async_run_module("async_bangumi_recommend", bangumiid=bangumiid)
|
||||
|
||||
async def async_person_detail(self, person_id: int) -> Optional[schemas.MediaPerson]:
|
||||
"""
|
||||
根据人物ID查询Bangumi人物详情(异步版本)
|
||||
:param person_id: 人物ID
|
||||
"""
|
||||
return await self.async_run_module("async_bangumi_person_detail", person_id=person_id)
|
||||
|
||||
async def async_person_credits(self, person_id: int) -> Optional[List[MediaInfo]]:
|
||||
"""
|
||||
根据人物ID查询人物参演作品(异步版本)
|
||||
:param person_id: 人物ID
|
||||
"""
|
||||
return await self.async_run_module("async_bangumi_person_credits", person_id=person_id)
|
||||
|
||||
@@ -111,3 +111,111 @@ class DoubanChain(ChainBase):
|
||||
:param doubanid: 豆瓣ID
|
||||
"""
|
||||
return self.run_module("douban_tv_recommend", doubanid=doubanid)
|
||||
|
||||
async def async_person_detail(self, person_id: int) -> Optional[schemas.MediaPerson]:
|
||||
"""
|
||||
根据人物ID查询豆瓣人物详情(异步版本)
|
||||
:param person_id: 人物ID
|
||||
"""
|
||||
return await self.async_run_module("async_douban_person_detail", person_id=person_id)
|
||||
|
||||
async def async_person_credits(self, person_id: int, page: Optional[int] = 1) -> List[MediaInfo]:
|
||||
"""
|
||||
根据人物ID查询人物参演作品(异步版本)
|
||||
:param person_id: 人物ID
|
||||
:param page: 页码
|
||||
"""
|
||||
return await self.async_run_module("async_douban_person_credits", person_id=person_id, page=page)
|
||||
|
||||
async def async_movie_top250(self, page: Optional[int] = 1,
|
||||
count: Optional[int] = 30) -> Optional[List[MediaInfo]]:
|
||||
"""
|
||||
获取豆瓣电影TOP250(异步版本)
|
||||
:param page: 页码
|
||||
:param count: 每页数量
|
||||
"""
|
||||
return await self.async_run_module("async_movie_top250", page=page, count=count)
|
||||
|
||||
async def async_movie_showing(self, page: Optional[int] = 1,
|
||||
count: Optional[int] = 30) -> Optional[List[MediaInfo]]:
|
||||
"""
|
||||
获取正在上映的电影(异步版本)
|
||||
"""
|
||||
return await self.async_run_module("async_movie_showing", page=page, count=count)
|
||||
|
||||
async def async_tv_weekly_chinese(self, page: Optional[int] = 1,
|
||||
count: Optional[int] = 30) -> Optional[List[MediaInfo]]:
|
||||
"""
|
||||
获取本周中国剧集榜(异步版本)
|
||||
"""
|
||||
return await self.async_run_module("async_tv_weekly_chinese", page=page, count=count)
|
||||
|
||||
async def async_tv_weekly_global(self, page: Optional[int] = 1,
|
||||
count: Optional[int] = 30) -> Optional[List[MediaInfo]]:
|
||||
"""
|
||||
获取本周全球剧集榜(异步版本)
|
||||
"""
|
||||
return await self.async_run_module("async_tv_weekly_global", page=page, count=count)
|
||||
|
||||
async def async_douban_discover(self, mtype: MediaType, sort: str, tags: str,
|
||||
page: Optional[int] = 0, count: Optional[int] = 30) -> Optional[List[MediaInfo]]:
|
||||
"""
|
||||
发现豆瓣电影、剧集(异步版本)
|
||||
:param mtype: 媒体类型
|
||||
:param sort: 排序方式
|
||||
:param tags: 标签
|
||||
:param page: 页码
|
||||
:param count: 数量
|
||||
:return: 媒体信息列表
|
||||
"""
|
||||
return await self.async_run_module("async_douban_discover", mtype=mtype, sort=sort, tags=tags,
|
||||
page=page, count=count)
|
||||
|
||||
async def async_tv_animation(self, page: Optional[int] = 1,
|
||||
count: Optional[int] = 30) -> Optional[List[MediaInfo]]:
|
||||
"""
|
||||
获取动画剧集(异步版本)
|
||||
"""
|
||||
return await self.async_run_module("async_tv_animation", page=page, count=count)
|
||||
|
||||
async def async_movie_hot(self, page: Optional[int] = 1,
|
||||
count: Optional[int] = 30) -> Optional[List[MediaInfo]]:
|
||||
"""
|
||||
获取热门电影(异步版本)
|
||||
"""
|
||||
return await self.async_run_module("async_movie_hot", page=page, count=count)
|
||||
|
||||
async def async_tv_hot(self, page: Optional[int] = 1,
|
||||
count: Optional[int] = 30) -> Optional[List[MediaInfo]]:
|
||||
"""
|
||||
获取热门剧集(异步版本)
|
||||
"""
|
||||
return await self.async_run_module("async_tv_hot", page=page, count=count)
|
||||
|
||||
async def async_movie_credits(self, doubanid: str) -> Optional[List[schemas.MediaPerson]]:
|
||||
"""
|
||||
根据TMDBID查询电影演职人员(异步版本)
|
||||
:param doubanid: 豆瓣ID
|
||||
"""
|
||||
return await self.async_run_module("async_douban_movie_credits", doubanid=doubanid)
|
||||
|
||||
async def async_tv_credits(self, doubanid: str) -> Optional[List[schemas.MediaPerson]]:
|
||||
"""
|
||||
根据TMDBID查询电视剧演职人员(异步版本)
|
||||
:param doubanid: 豆瓣ID
|
||||
"""
|
||||
return await self.async_run_module("async_douban_tv_credits", doubanid=doubanid)
|
||||
|
||||
async def async_movie_recommend(self, doubanid: str) -> List[MediaInfo]:
|
||||
"""
|
||||
根据豆瓣ID查询推荐电影(异步版本)
|
||||
:param doubanid: 豆瓣ID
|
||||
"""
|
||||
return await self.async_run_module("async_douban_movie_recommend", doubanid=doubanid)
|
||||
|
||||
async def async_tv_recommend(self, doubanid: str) -> List[MediaInfo]:
|
||||
"""
|
||||
根据豆瓣ID查询推荐电视剧(异步版本)
|
||||
:param doubanid: 豆瓣ID
|
||||
"""
|
||||
return await self.async_run_module("async_douban_tv_recommend", doubanid=doubanid)
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import List, Optional, Tuple, Set, Dict, Union
|
||||
|
||||
from app import schemas
|
||||
from app.chain import ChainBase
|
||||
from app.core.cache import FileCache
|
||||
from app.core.config import settings, global_vars
|
||||
from app.core.context import MediaInfo, TorrentInfo, Context
|
||||
from app.core.event import eventmanager, Event
|
||||
@@ -35,10 +36,10 @@ class DownloadChain(ChainBase):
|
||||
channel: MessageChannel = None,
|
||||
source: Optional[str] = None,
|
||||
userid: Union[str, int] = None
|
||||
) -> Tuple[Optional[Union[Path, str]], str, list]:
|
||||
) -> Tuple[Optional[Union[str, bytes]], str, list]:
|
||||
"""
|
||||
下载种子文件,如果是磁力链,会返回磁力链接本身
|
||||
:return: 种子路径,种子目录名,种子文件清单
|
||||
:return: 种子内容,种子目录名,种子文件清单
|
||||
"""
|
||||
|
||||
def __get_redict_url(url: str, ua: Optional[str] = None, cookie: Optional[str] = None) -> Optional[str]:
|
||||
@@ -60,6 +61,8 @@ class DownloadChain(ChainBase):
|
||||
# 是否使用cookie
|
||||
if not req_params.get('cookie'):
|
||||
cookie = None
|
||||
# 代理
|
||||
proxy = req_params.get('proxy')
|
||||
# 请求头
|
||||
if req_params.get('header'):
|
||||
headers = req_params.get('header')
|
||||
@@ -70,14 +73,16 @@ class DownloadChain(ChainBase):
|
||||
res = RequestUtils(
|
||||
ua=ua,
|
||||
cookies=cookie,
|
||||
headers=headers
|
||||
headers=headers,
|
||||
proxies=settings.PROXY if proxy else None
|
||||
).get_res(url, params=req_params.get('params'))
|
||||
else:
|
||||
# POST请求
|
||||
res = RequestUtils(
|
||||
ua=ua,
|
||||
cookies=cookie,
|
||||
headers=headers
|
||||
headers=headers,
|
||||
proxies=settings.PROXY if proxy else None
|
||||
).post_res(url, params=req_params.get('params'))
|
||||
if not res:
|
||||
return None
|
||||
@@ -113,7 +118,7 @@ class DownloadChain(ChainBase):
|
||||
logger.error(f"{torrent.title} 无法获取下载地址:{torrent.enclosure}!")
|
||||
return None, "", []
|
||||
# 下载种子文件
|
||||
torrent_file, content, download_folder, files, error_msg = TorrentHelper().download_torrent(
|
||||
_, content, download_folder, files, error_msg = TorrentHelper().download_torrent(
|
||||
url=torrent_url,
|
||||
cookie=site_cookie,
|
||||
ua=torrent.site_ua or settings.USER_AGENT,
|
||||
@@ -123,7 +128,7 @@ class DownloadChain(ChainBase):
|
||||
# 磁力链
|
||||
return content, "", []
|
||||
|
||||
if not torrent_file:
|
||||
if not content:
|
||||
logger.error(f"下载种子文件失败:{torrent.title} - {torrent_url}")
|
||||
self.post_message(Notification(
|
||||
channel=channel,
|
||||
@@ -135,9 +140,11 @@ class DownloadChain(ChainBase):
|
||||
return None, "", []
|
||||
|
||||
# 返回 种子文件路径,种子目录名,种子文件清单
|
||||
return torrent_file, download_folder, files
|
||||
return content, download_folder, files
|
||||
|
||||
def download_single(self, context: Context, torrent_file: Path = None,
|
||||
def download_single(self, context: Context,
|
||||
torrent_file: Path = None,
|
||||
torrent_content: Optional[Union[str, bytes]] = None,
|
||||
episodes: Set[int] = None,
|
||||
channel: MessageChannel = None,
|
||||
source: Optional[str] = None,
|
||||
@@ -150,6 +157,7 @@ class DownloadChain(ChainBase):
|
||||
下载及发送通知
|
||||
:param context: 资源上下文
|
||||
:param torrent_file: 种子文件路径
|
||||
:param torrent_content: 种子内容(磁力链或种子文件内容)
|
||||
:param episodes: 需要下载的集数
|
||||
:param channel: 通知渠道
|
||||
:param source: 来源(消息通知、Subscribe、Manual等)
|
||||
@@ -188,6 +196,9 @@ class DownloadChain(ChainBase):
|
||||
f"Resource download canceled by event: {event_data.source},"
|
||||
f"Reason: {event_data.reason}")
|
||||
return None
|
||||
# 如果事件修改了下载路径,使用新路径
|
||||
if event_data.options and event_data.options.get("save_path"):
|
||||
save_path = event_data.options.get("save_path")
|
||||
|
||||
# 补充完整的media数据
|
||||
if not _media.genre_ids:
|
||||
@@ -200,18 +211,26 @@ class DownloadChain(ChainBase):
|
||||
# 实际下载的集数
|
||||
download_episodes = StringUtils.format_ep(list(episodes)) if episodes else None
|
||||
_folder_name = ""
|
||||
if not torrent_file:
|
||||
if not torrent_file and not torrent_content:
|
||||
# 下载种子文件,得到的可能是文件也可能是磁力链
|
||||
content, _folder_name, _file_list = self.download_torrent(_torrent,
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid)
|
||||
if not content:
|
||||
return None
|
||||
else:
|
||||
content = torrent_file
|
||||
# 获取种子文件的文件夹名和文件清单
|
||||
_folder_name, _file_list = TorrentHelper().get_torrent_info(torrent_file)
|
||||
torrent_content, _folder_name, _file_list = self.download_torrent(_torrent,
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid)
|
||||
elif torrent_file:
|
||||
if torrent_file.exists():
|
||||
torrent_content = torrent_file.read_bytes()
|
||||
else:
|
||||
# 缓存处理器
|
||||
cache_backend = FileCache()
|
||||
# 读取缓存的种子文件
|
||||
torrent_content = cache_backend.get(torrent_file.as_posix(), region="torrents")
|
||||
|
||||
if not torrent_content:
|
||||
return None
|
||||
|
||||
# 获取种子文件的文件夹名和文件清单
|
||||
_folder_name, _file_list = TorrentHelper().get_fileinfo_from_torrent_content(torrent_content)
|
||||
|
||||
# 下载目录
|
||||
if save_path:
|
||||
@@ -242,7 +261,7 @@ class DownloadChain(ChainBase):
|
||||
return None
|
||||
|
||||
# 添加下载
|
||||
result: Optional[tuple] = self.download(content=content,
|
||||
result: Optional[tuple] = self.download(content=torrent_content,
|
||||
cookie=_torrent.site_cookie,
|
||||
episodes=episodes,
|
||||
download_dir=download_dir,
|
||||
@@ -271,7 +290,7 @@ class DownloadChain(ChainBase):
|
||||
# 登记下载记录
|
||||
downloadhis = DownloadHistoryOper()
|
||||
downloadhis.add(
|
||||
path=str(download_path),
|
||||
path=download_path.as_posix(),
|
||||
type=_media.type.value,
|
||||
title=_media.title,
|
||||
year=_media.year,
|
||||
@@ -312,8 +331,8 @@ class DownloadChain(ChainBase):
|
||||
files_to_add.append({
|
||||
"download_hash": _hash,
|
||||
"downloader": _downloader,
|
||||
"fullpath": str(_save_path / file),
|
||||
"savepath": str(_save_path),
|
||||
"fullpath": (_save_path / file).as_posix(),
|
||||
"savepath": _save_path.as_posix(),
|
||||
"filepath": file,
|
||||
"torrentname": _meta.org_string,
|
||||
})
|
||||
@@ -339,7 +358,7 @@ class DownloadChain(ChainBase):
|
||||
username=username,
|
||||
)
|
||||
# 下载成功后处理
|
||||
self.download_added(context=context, download_dir=download_dir, torrent_path=torrent_file)
|
||||
self.download_added(context=context, download_dir=download_dir, torrent_content=torrent_content)
|
||||
# 广播事件
|
||||
self.eventmanager.send_event(EventType.DownloadAdded, {
|
||||
"hash": _hash,
|
||||
@@ -553,7 +572,7 @@ class DownloadChain(ChainBase):
|
||||
logger.info(f"开始下载 {torrent.title} ...")
|
||||
download_id = self.download_single(
|
||||
context=context,
|
||||
torrent_file=content if isinstance(content, Path) else None,
|
||||
torrent_content=content,
|
||||
save_path=save_path,
|
||||
channel=channel,
|
||||
source=source,
|
||||
@@ -720,7 +739,7 @@ class DownloadChain(ChainBase):
|
||||
logger.info(f"开始下载 {torrent.title} ...")
|
||||
download_id = self.download_single(
|
||||
context=context,
|
||||
torrent_file=content if isinstance(content, Path) else None,
|
||||
torrent_content=content,
|
||||
episodes=selected_episodes,
|
||||
save_path=save_path,
|
||||
channel=channel,
|
||||
@@ -975,7 +994,7 @@ class DownloadChain(ChainBase):
|
||||
# 发出下载任务删除事件,如需处理辅种,可监听该事件
|
||||
self.eventmanager.send_event(EventType.DownloadDeleted, {
|
||||
"hash": hash_str,
|
||||
"torrents": [torrent.dict() for torrent in torrents]
|
||||
"torrents": [torrent.model_dump() for torrent in torrents]
|
||||
})
|
||||
else:
|
||||
logger.info(f"没有在下载器中查询到 {hash_str} 对应的下载任务")
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from tempfile import NamedTemporaryFile
|
||||
from threading import Lock
|
||||
from typing import Optional, List, Tuple, Union
|
||||
|
||||
@@ -19,7 +21,9 @@ from app.utils.string import StringUtils
|
||||
|
||||
recognize_lock = Lock()
|
||||
scraping_lock = Lock()
|
||||
scraping_files = []
|
||||
|
||||
current_umask = os.umask(0)
|
||||
os.umask(current_umask)
|
||||
|
||||
|
||||
class MediaChain(ChainBase):
|
||||
@@ -35,25 +39,25 @@ class MediaChain(ChainBase):
|
||||
switchs = SystemConfigOper().get(SystemConfigKey.ScrapingSwitchs) or {}
|
||||
# 默认配置
|
||||
default_switchs = {
|
||||
'movie_nfo': True, # 电影NFO
|
||||
'movie_poster': True, # 电影海报
|
||||
'movie_backdrop': True, # 电影背景图
|
||||
'movie_logo': True, # 电影Logo
|
||||
'movie_disc': True, # 电影光盘图
|
||||
'movie_banner': True, # 电影横幅图
|
||||
'movie_thumb': True, # 电影缩略图
|
||||
'tv_nfo': True, # 电视剧NFO
|
||||
'tv_poster': True, # 电视剧海报
|
||||
'tv_backdrop': True, # 电视剧背景图
|
||||
'tv_banner': True, # 电视剧横幅图
|
||||
'tv_logo': True, # 电视剧Logo
|
||||
'tv_thumb': True, # 电视剧缩略图
|
||||
'season_nfo': True, # 季NFO
|
||||
'season_poster': True, # 季海报
|
||||
'season_banner': True, # 季横幅图
|
||||
'season_thumb': True, # 季缩略图
|
||||
'episode_nfo': True, # 集NFO
|
||||
'episode_thumb': True # 集缩略图
|
||||
'movie_nfo': True, # 电影NFO
|
||||
'movie_poster': True, # 电影海报
|
||||
'movie_backdrop': True, # 电影背景图
|
||||
'movie_logo': True, # 电影Logo
|
||||
'movie_disc': True, # 电影光盘图
|
||||
'movie_banner': True, # 电影横幅图
|
||||
'movie_thumb': True, # 电影缩略图
|
||||
'tv_nfo': True, # 电视剧NFO
|
||||
'tv_poster': True, # 电视剧海报
|
||||
'tv_backdrop': True, # 电视剧背景图
|
||||
'tv_banner': True, # 电视剧横幅图
|
||||
'tv_logo': True, # 电视剧Logo
|
||||
'tv_thumb': True, # 电视剧缩略图
|
||||
'season_nfo': True, # 季NFO
|
||||
'season_poster': True, # 季海报
|
||||
'season_banner': True, # 季横幅图
|
||||
'season_thumb': True, # 季缩略图
|
||||
'episode_nfo': True, # 集NFO
|
||||
'episode_thumb': True # 集缩略图
|
||||
}
|
||||
# 合并用户配置和默认配置
|
||||
for key, default_value in default_switchs.items():
|
||||
@@ -231,17 +235,15 @@ class MediaChain(ChainBase):
|
||||
meta_names = list(dict.fromkeys([k for k in [meta_org.name,
|
||||
meta.cn_name,
|
||||
meta.en_name] if k]))
|
||||
for name in meta_names:
|
||||
tmdbinfo = self.match_tmdbinfo(
|
||||
name=name,
|
||||
year=meta.year,
|
||||
mtype=mtype or meta.type,
|
||||
season=meta.begin_season
|
||||
)
|
||||
if tmdbinfo:
|
||||
# 合季季后返回
|
||||
tmdbinfo['season'] = meta.begin_season
|
||||
break
|
||||
tmdbinfo = self._match_tmdb_with_names(
|
||||
meta_names=meta_names,
|
||||
year=meta.year,
|
||||
mtype=mtype or meta.type,
|
||||
season=meta.begin_season
|
||||
)
|
||||
if tmdbinfo:
|
||||
# 合季季后返回
|
||||
tmdbinfo['season'] = meta.begin_season
|
||||
return tmdbinfo
|
||||
|
||||
def get_tmdbinfo_by_bangumiid(self, bangumiid: int) -> Optional[dict]:
|
||||
@@ -257,23 +259,17 @@ class MediaChain(ChainBase):
|
||||
else:
|
||||
meta_cn = meta = MetaInfo(title=bangumiinfo.get("name"))
|
||||
# 年份
|
||||
release_date = bangumiinfo.get("date") or bangumiinfo.get("air_date")
|
||||
if release_date:
|
||||
year = release_date[:4]
|
||||
else:
|
||||
year = None
|
||||
year = self._extract_year_from_bangumi(bangumiinfo)
|
||||
# 识别TMDB媒体信息
|
||||
meta_names = list(dict.fromkeys([k for k in [meta_cn.name,
|
||||
meta.name] if k]))
|
||||
for name in meta_names:
|
||||
tmdbinfo = self.match_tmdbinfo(
|
||||
name=name,
|
||||
year=year,
|
||||
mtype=MediaType.TV,
|
||||
season=meta.begin_season
|
||||
)
|
||||
if tmdbinfo:
|
||||
return tmdbinfo
|
||||
tmdbinfo = self._match_tmdb_with_names(
|
||||
meta_names=meta_names,
|
||||
year=year,
|
||||
mtype=MediaType.TV,
|
||||
season=meta.begin_season
|
||||
)
|
||||
return tmdbinfo
|
||||
return None
|
||||
|
||||
def get_doubaninfo_by_tmdbid(self, tmdbid: int,
|
||||
@@ -286,19 +282,7 @@ class MediaChain(ChainBase):
|
||||
# 名称
|
||||
name = tmdbinfo.get("title") or tmdbinfo.get("name")
|
||||
# 年份
|
||||
year = None
|
||||
if tmdbinfo.get('release_date'):
|
||||
year = tmdbinfo['release_date'][:4]
|
||||
elif tmdbinfo.get('seasons') and season:
|
||||
for seainfo in tmdbinfo['seasons']:
|
||||
# 季
|
||||
season_number = seainfo.get("season_number")
|
||||
if not season_number:
|
||||
continue
|
||||
air_date = seainfo.get("air_date")
|
||||
if air_date and season_number == season:
|
||||
year = air_date[:4]
|
||||
break
|
||||
year = self._extract_year_from_tmdb(tmdbinfo, season)
|
||||
# IMDBID
|
||||
imdbid = tmdbinfo.get("external_ids", {}).get("imdb_id")
|
||||
return self.match_doubaninfo(
|
||||
@@ -321,11 +305,7 @@ class MediaChain(ChainBase):
|
||||
else:
|
||||
meta = MetaInfo(title=bangumiinfo.get("name"))
|
||||
# 年份
|
||||
release_date = bangumiinfo.get("date") or bangumiinfo.get("air_date")
|
||||
if release_date:
|
||||
year = release_date[:4]
|
||||
else:
|
||||
year = None
|
||||
year = self._extract_year_from_bangumi(bangumiinfo)
|
||||
# 使用名称识别豆瓣媒体信息
|
||||
return self.match_doubaninfo(
|
||||
name=meta.name,
|
||||
@@ -335,6 +315,21 @@ class MediaChain(ChainBase):
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def is_bluray_folder(fileitem: schemas.FileItem) -> bool:
|
||||
"""
|
||||
判断是否为原盘目录
|
||||
"""
|
||||
if not fileitem or fileitem.type != "dir":
|
||||
return False
|
||||
# 蓝光原盘目录必备的文件或文件夹
|
||||
required_files = ['BDMV', 'CERTIFICATE']
|
||||
# 检查目录下是否存在所需文件或文件夹
|
||||
for item in StorageChain().list_files(fileitem):
|
||||
if item.name in required_files:
|
||||
return True
|
||||
return False
|
||||
|
||||
@eventmanager.register(EventType.MetadataScrape)
|
||||
def scrape_metadata_event(self, event: Event):
|
||||
"""
|
||||
@@ -343,29 +338,101 @@ class MediaChain(ChainBase):
|
||||
if not event:
|
||||
return
|
||||
event_data = event.event_data or {}
|
||||
# 媒体根目录
|
||||
fileitem: FileItem = event_data.get("fileitem")
|
||||
# 媒体文件列表
|
||||
file_list: List[str] = event_data.get("file_list", [])
|
||||
# 媒体元数据
|
||||
meta: MetaBase = event_data.get("meta")
|
||||
# 媒体信息
|
||||
mediainfo: MediaInfo = event_data.get("mediainfo")
|
||||
# 是否覆盖
|
||||
overwrite = event_data.get("overwrite", False)
|
||||
# 检查媒体根目录
|
||||
if not fileitem:
|
||||
return
|
||||
|
||||
# 刮削锁
|
||||
with scraping_lock:
|
||||
if fileitem.path in scraping_files:
|
||||
# 检查文件项是否存在
|
||||
storagechain = StorageChain()
|
||||
if not storagechain.get_item(fileitem):
|
||||
logger.warn(f"文件项不存在:{fileitem.path}")
|
||||
return
|
||||
scraping_files.append(fileitem.path)
|
||||
try:
|
||||
# 执行刮削
|
||||
self.scrape_metadata(fileitem=fileitem, meta=meta, mediainfo=mediainfo, overwrite=overwrite)
|
||||
finally:
|
||||
# 释放锁
|
||||
with scraping_lock:
|
||||
scraping_files.remove(fileitem.path)
|
||||
# 检查是否为目录
|
||||
if fileitem.type == "file":
|
||||
# 单个文件刮削
|
||||
self.scrape_metadata(fileitem=fileitem,
|
||||
mediainfo=mediainfo,
|
||||
init_folder=False,
|
||||
parent=storagechain.get_parent_item(fileitem),
|
||||
overwrite=overwrite)
|
||||
else:
|
||||
if file_list:
|
||||
# 如果是BDMV原盘目录,只对根目录进行刮削,不处理子目录
|
||||
if self.is_bluray_folder(fileitem):
|
||||
logger.info(f"检测到BDMV原盘目录,只对根目录进行刮削:{fileitem.path}")
|
||||
self.scrape_metadata(fileitem=fileitem,
|
||||
mediainfo=mediainfo,
|
||||
init_folder=True,
|
||||
recursive=False,
|
||||
overwrite=overwrite)
|
||||
else:
|
||||
# 1. 收集fileitem和file_list中每个文件之间所有子目录
|
||||
all_dirs = set()
|
||||
root_path = Path(fileitem.path)
|
||||
|
||||
logger.debug(f"开始收集目录,根目录:{root_path}")
|
||||
# 收集根目录
|
||||
all_dirs.add(root_path)
|
||||
|
||||
# 收集所有目录(包括所有层级)
|
||||
for sub_file in file_list:
|
||||
sub_path = Path(sub_file)
|
||||
# 收集从根目录到文件的所有父目录
|
||||
current_path = sub_path.parent
|
||||
while current_path != root_path and current_path.is_relative_to(root_path):
|
||||
all_dirs.add(current_path)
|
||||
current_path = current_path.parent
|
||||
|
||||
logger.debug(f"共收集到 {len(all_dirs)} 个目录")
|
||||
|
||||
# 2. 初始化一遍子目录,但不处理文件
|
||||
for sub_dir in all_dirs:
|
||||
sub_dir_item = storagechain.get_file_item(storage=fileitem.storage, path=sub_dir)
|
||||
if sub_dir_item:
|
||||
logger.info(f"为目录生成海报和nfo:{sub_dir}")
|
||||
# 初始化目录元数据,但不处理文件
|
||||
self.scrape_metadata(fileitem=sub_dir_item,
|
||||
mediainfo=mediainfo,
|
||||
init_folder=True,
|
||||
recursive=False,
|
||||
overwrite=overwrite)
|
||||
else:
|
||||
logger.warn(f"无法获取目录项:{sub_dir}")
|
||||
|
||||
# 3. 刮削每个文件
|
||||
logger.info(f"开始刮削 {len(file_list)} 个文件")
|
||||
for sub_file_path in file_list:
|
||||
sub_file_item = storagechain.get_file_item(storage=fileitem.storage,
|
||||
path=Path(sub_file_path))
|
||||
if sub_file_item:
|
||||
self.scrape_metadata(fileitem=sub_file_item,
|
||||
mediainfo=mediainfo,
|
||||
init_folder=False,
|
||||
overwrite=overwrite)
|
||||
else:
|
||||
logger.warn(f"无法获取文件项:{sub_file_path}")
|
||||
else:
|
||||
# 执行全量刮削
|
||||
logger.info(f"开始刮削目录 {fileitem.path} ...")
|
||||
self.scrape_metadata(fileitem=fileitem, meta=meta, init_folder=True,
|
||||
mediainfo=mediainfo, overwrite=overwrite)
|
||||
|
||||
def scrape_metadata(self, fileitem: schemas.FileItem,
|
||||
meta: MetaBase = None, mediainfo: MediaInfo = None,
|
||||
init_folder: bool = True, parent: schemas.FileItem = None,
|
||||
overwrite: bool = False):
|
||||
overwrite: bool = False, recursive: bool = True):
|
||||
"""
|
||||
手动刮削媒体信息
|
||||
:param fileitem: 刮削目录或文件
|
||||
@@ -374,24 +441,11 @@ class MediaChain(ChainBase):
|
||||
:param init_folder: 是否刮削根目录
|
||||
:param parent: 上级目录
|
||||
:param overwrite: 是否覆盖已有文件
|
||||
:param recursive: 是否递归处理目录内文件
|
||||
"""
|
||||
|
||||
storagechain = StorageChain()
|
||||
|
||||
def is_bluray_folder(_fileitem: schemas.FileItem) -> bool:
|
||||
"""
|
||||
判断是否为原盘目录
|
||||
"""
|
||||
if not _fileitem or _fileitem.type != "dir":
|
||||
return False
|
||||
# 蓝光原盘目录必备的文件或文件夹
|
||||
required_files = ['BDMV', 'CERTIFICATE']
|
||||
# 检查目录下是否存在所需文件或文件夹
|
||||
for item in storagechain.list_files(_fileitem):
|
||||
if item.name in required_files:
|
||||
return True
|
||||
return False
|
||||
|
||||
def __list_files(_fileitem: schemas.FileItem):
|
||||
"""
|
||||
列出下级文件
|
||||
@@ -407,34 +461,68 @@ class MediaChain(ChainBase):
|
||||
"""
|
||||
if not _fileitem or not _content or not _path:
|
||||
return
|
||||
# 保存文件到临时目录,文件名随机
|
||||
tmp_file = settings.TEMP_PATH / f"{_path.name}.{StringUtils.generate_random_str(10)}"
|
||||
tmp_file.write_bytes(_content)
|
||||
# 获取文件的父目录
|
||||
try:
|
||||
item = storagechain.upload_file(fileitem=_fileitem, path=tmp_file, new_name=_path.name)
|
||||
# 使用tempfile创建临时文件,自动删除
|
||||
with NamedTemporaryFile(delete=True, delete_on_close=False, suffix=_path.suffix) as tmp_file:
|
||||
tmp_file_path = Path(tmp_file.name)
|
||||
# 写入内容
|
||||
if isinstance(_content, bytes):
|
||||
tmp_file.write(_content)
|
||||
else:
|
||||
tmp_file.write(_content.encode('utf-8'))
|
||||
tmp_file.flush()
|
||||
tmp_file.close() # 关闭文件句柄
|
||||
|
||||
# 刮削文件只需要读写权限
|
||||
tmp_file_path.chmod(0o666 & ~current_umask)
|
||||
|
||||
# 上传文件
|
||||
item = storagechain.upload_file(fileitem=_fileitem, path=tmp_file_path, new_name=_path.name)
|
||||
if item:
|
||||
logger.info(f"已保存文件:{item.path}")
|
||||
else:
|
||||
logger.warn(f"文件保存失败:{_path}")
|
||||
finally:
|
||||
if tmp_file.exists():
|
||||
tmp_file.unlink()
|
||||
|
||||
def __download_image(_url: str) -> Optional[bytes]:
|
||||
def __download_and_save_image(_fileitem: schemas.FileItem, _path: Path, _url: str):
|
||||
"""
|
||||
下载图片并保存
|
||||
流式下载图片并直接保存到文件(减少内存占用)
|
||||
:param _fileitem: 关联的媒体文件项
|
||||
:param _path: 图片文件路径
|
||||
:param _url: 图片下载URL
|
||||
"""
|
||||
if not _fileitem or not _url or not _path:
|
||||
return
|
||||
try:
|
||||
logger.info(f"正在下载图片:{_url} ...")
|
||||
r = RequestUtils(proxies=settings.PROXY, ua=settings.USER_AGENT).get_res(url=_url)
|
||||
if r:
|
||||
return r.content
|
||||
else:
|
||||
logger.info(f"{_url} 图片下载失败,请检查网络连通性!")
|
||||
request_utils = RequestUtils(proxies=settings.PROXY, ua=settings.NORMAL_USER_AGENT)
|
||||
with request_utils.get_stream(url=_url) as r:
|
||||
if r and r.status_code == 200:
|
||||
# 使用tempfile创建临时文件,自动删除
|
||||
with NamedTemporaryFile(delete=True, delete_on_close=False, suffix=_path.suffix) as tmp_file:
|
||||
tmp_file_path = Path(tmp_file.name)
|
||||
# 流式写入文件
|
||||
for chunk in r.iter_content(chunk_size=8192):
|
||||
if chunk:
|
||||
tmp_file.write(chunk)
|
||||
tmp_file.flush()
|
||||
tmp_file.close() # 关闭文件句柄
|
||||
|
||||
# 刮削的图片只需要读写权限
|
||||
tmp_file_path.chmod(0o666 & ~current_umask)
|
||||
|
||||
# 上传文件
|
||||
item = storagechain.upload_file(fileitem=_fileitem, path=tmp_file_path,
|
||||
new_name=_path.name)
|
||||
if item:
|
||||
logger.info(f"已保存图片:{item.path}")
|
||||
else:
|
||||
logger.warn(f"图片保存失败:{_path}")
|
||||
else:
|
||||
logger.info(f"{_url} 图片下载失败")
|
||||
except Exception as err:
|
||||
logger.error(f"{_url} 图片下载失败:{str(err)}!")
|
||||
return None
|
||||
|
||||
if not fileitem:
|
||||
return
|
||||
|
||||
# 当前文件路径
|
||||
filepath = Path(fileitem.path)
|
||||
@@ -448,7 +536,7 @@ class MediaChain(ChainBase):
|
||||
if not mediainfo:
|
||||
logger.warn(f"{filepath} 无法识别文件媒体信息!")
|
||||
return
|
||||
|
||||
|
||||
# 获取刮削开关配置
|
||||
scraping_switchs = self._get_scraping_switchs()
|
||||
logger.info(f"开始刮削:{filepath} ...")
|
||||
@@ -464,6 +552,8 @@ class MediaChain(ChainBase):
|
||||
movie_nfo = self.metadata_nfo(meta=meta, mediainfo=mediainfo)
|
||||
if movie_nfo:
|
||||
# 保存或上传nfo文件到上级目录
|
||||
if not parent:
|
||||
parent = storagechain.get_parent_item(fileitem)
|
||||
__save_file(_fileitem=parent, _path=nfo_path, _content=movie_nfo)
|
||||
else:
|
||||
logger.warn(f"{filepath.name} nfo文件生成失败!")
|
||||
@@ -473,30 +563,36 @@ class MediaChain(ChainBase):
|
||||
logger.info("电影NFO刮削已关闭,跳过")
|
||||
else:
|
||||
# 电影目录
|
||||
if is_bluray_folder(fileitem):
|
||||
# 原盘目录
|
||||
if scraping_switchs.get('movie_nfo', True):
|
||||
nfo_path = filepath / (filepath.name + ".nfo")
|
||||
if overwrite or not storagechain.get_file_item(storage=fileitem.storage, path=nfo_path):
|
||||
# 生成原盘nfo
|
||||
movie_nfo = self.metadata_nfo(meta=meta, mediainfo=mediainfo)
|
||||
if movie_nfo:
|
||||
# 保存或上传nfo文件到当前目录
|
||||
__save_file(_fileitem=fileitem, _path=nfo_path, _content=movie_nfo)
|
||||
if recursive:
|
||||
# 处理文件
|
||||
if self.is_bluray_folder(fileitem):
|
||||
# 原盘目录
|
||||
if scraping_switchs.get('movie_nfo', True):
|
||||
nfo_path = filepath / (filepath.name + ".nfo")
|
||||
if overwrite or not storagechain.get_file_item(storage=fileitem.storage, path=nfo_path):
|
||||
# 生成原盘nfo
|
||||
movie_nfo = self.metadata_nfo(meta=meta, mediainfo=mediainfo)
|
||||
if movie_nfo:
|
||||
# 保存或上传nfo文件到当前目录
|
||||
__save_file(_fileitem=fileitem, _path=nfo_path, _content=movie_nfo)
|
||||
else:
|
||||
logger.warn(f"{filepath.name} nfo文件生成失败!")
|
||||
else:
|
||||
logger.warn(f"{filepath.name} nfo文件生成失败!")
|
||||
logger.info(f"已存在nfo文件:{nfo_path}")
|
||||
else:
|
||||
logger.info(f"已存在nfo文件:{nfo_path}")
|
||||
logger.info("电影NFO刮削已关闭,跳过")
|
||||
else:
|
||||
logger.info("电影NFO刮削已关闭,跳过")
|
||||
else:
|
||||
# 处理目录内的文件
|
||||
files = __list_files(_fileitem=fileitem)
|
||||
for file in files:
|
||||
self.scrape_metadata(fileitem=file,
|
||||
meta=meta, mediainfo=mediainfo,
|
||||
init_folder=False, parent=fileitem,
|
||||
overwrite=overwrite)
|
||||
# 处理目录内的文件
|
||||
files = __list_files(_fileitem=fileitem)
|
||||
for file in files:
|
||||
if file.type == "dir":
|
||||
# 电影不处理子目录
|
||||
continue
|
||||
self.scrape_metadata(fileitem=file,
|
||||
mediainfo=mediainfo,
|
||||
init_folder=False,
|
||||
parent=fileitem,
|
||||
overwrite=overwrite)
|
||||
# 生成目录内图片文件
|
||||
if init_folder:
|
||||
# 图片
|
||||
@@ -520,16 +616,13 @@ class MediaChain(ChainBase):
|
||||
should_scrape = scraping_switchs.get('movie_thumb', True)
|
||||
else:
|
||||
should_scrape = True # 未知类型默认刮削
|
||||
|
||||
|
||||
if should_scrape:
|
||||
image_path = filepath.with_name(image_name)
|
||||
if overwrite or not storagechain.get_file_item(storage=fileitem.storage,
|
||||
path=image_path):
|
||||
# 下载图片
|
||||
content = __download_image(image_url)
|
||||
# 写入图片到当前目录
|
||||
if content:
|
||||
__save_file(_fileitem=fileitem, _path=image_path, _content=content)
|
||||
# 流式下载图片并直接保存
|
||||
__download_and_save_image(_fileitem=fileitem, _path=image_path, _url=image_url)
|
||||
else:
|
||||
logger.info(f"已存在图片文件:{image_path}")
|
||||
else:
|
||||
@@ -575,26 +668,27 @@ class MediaChain(ChainBase):
|
||||
for episode, image_url in image_dict.items():
|
||||
image_path = filepath.with_suffix(Path(image_url).suffix)
|
||||
if overwrite or not storagechain.get_file_item(storage=fileitem.storage, path=image_path):
|
||||
# 下载图片
|
||||
content = __download_image(image_url)
|
||||
# 保存图片文件到当前目录
|
||||
if content:
|
||||
if not parent:
|
||||
parent = storagechain.get_parent_item(fileitem)
|
||||
__save_file(_fileitem=parent, _path=image_path, _content=content)
|
||||
# 流式下载图片并直接保存
|
||||
if not parent:
|
||||
parent = storagechain.get_parent_item(fileitem)
|
||||
__download_and_save_image(_fileitem=parent, _path=image_path, _url=image_url)
|
||||
else:
|
||||
logger.info(f"已存在图片文件:{image_path}")
|
||||
else:
|
||||
logger.info("集缩略图刮削已关闭,跳过")
|
||||
else:
|
||||
# 当前为目录,处理目录内的文件
|
||||
files = __list_files(_fileitem=fileitem)
|
||||
for file in files:
|
||||
self.scrape_metadata(fileitem=file,
|
||||
meta=meta, mediainfo=mediainfo,
|
||||
parent=fileitem if file.type == "file" else None,
|
||||
init_folder=True if file.type == "dir" else False,
|
||||
overwrite=overwrite)
|
||||
# 当前为电视剧目录,处理目录内的文件
|
||||
if recursive:
|
||||
files = __list_files(_fileitem=fileitem)
|
||||
for file in files:
|
||||
if file.type == "dir" and not file.name.lower().startswith("season"):
|
||||
# 电视剧不处理非季子目录
|
||||
continue
|
||||
self.scrape_metadata(fileitem=file,
|
||||
mediainfo=mediainfo,
|
||||
parent=fileitem if file.type == "file" else None,
|
||||
init_folder=True if file.type == "dir" else False,
|
||||
overwrite=overwrite)
|
||||
# 生成目录的nfo和图片
|
||||
if init_folder:
|
||||
# 识别文件夹名称
|
||||
@@ -628,13 +722,10 @@ class MediaChain(ChainBase):
|
||||
image_path = filepath.with_name(image_name)
|
||||
if overwrite or not storagechain.get_file_item(storage=fileitem.storage,
|
||||
path=image_path):
|
||||
# 下载图片
|
||||
content = __download_image(image_url)
|
||||
# 保存图片文件到剧集目录
|
||||
if content:
|
||||
if not parent:
|
||||
parent = storagechain.get_parent_item(fileitem)
|
||||
__save_file(_fileitem=parent, _path=image_path, _content=content)
|
||||
# 流式下载图片并直接保存
|
||||
if not parent:
|
||||
parent = storagechain.get_parent_item(fileitem)
|
||||
__download_and_save_image(_fileitem=parent, _path=image_path, _url=image_url)
|
||||
else:
|
||||
logger.info(f"已存在图片文件:{image_path}")
|
||||
else:
|
||||
@@ -653,23 +744,22 @@ class MediaChain(ChainBase):
|
||||
should_scrape = scraping_switchs.get('season_thumb', True)
|
||||
else:
|
||||
should_scrape = True # 未知类型默认刮削
|
||||
|
||||
|
||||
if should_scrape:
|
||||
image_path = filepath.with_name(image_name)
|
||||
# 只下载当前刮削季的图片
|
||||
image_season = "00" if "specials" in image_name else image_name[6:8]
|
||||
if image_season != str(season_meta.begin_season).rjust(2, '0'):
|
||||
logger.info(f"当前刮削季为:{season_meta.begin_season},跳过文件:{image_path}")
|
||||
logger.info(
|
||||
f"当前刮削季为:{season_meta.begin_season},跳过文件:{image_path}")
|
||||
continue
|
||||
if overwrite or not storagechain.get_file_item(storage=fileitem.storage,
|
||||
path=image_path):
|
||||
# 下载图片
|
||||
content = __download_image(image_url)
|
||||
# 保存图片文件到当前目录
|
||||
if content:
|
||||
if not parent:
|
||||
parent = storagechain.get_parent_item(fileitem)
|
||||
__save_file(_fileitem=parent, _path=image_path, _content=content)
|
||||
# 流式下载图片并直接保存
|
||||
if not parent:
|
||||
parent = storagechain.get_parent_item(fileitem)
|
||||
__download_and_save_image(_fileitem=parent, _path=image_path,
|
||||
_url=image_url)
|
||||
else:
|
||||
logger.info(f"已存在图片文件:{image_path}")
|
||||
else:
|
||||
@@ -714,18 +804,307 @@ class MediaChain(ChainBase):
|
||||
should_scrape = scraping_switchs.get('tv_thumb', True)
|
||||
else:
|
||||
should_scrape = True # 未知类型默认刮削
|
||||
|
||||
|
||||
if should_scrape:
|
||||
image_path = filepath / image_name
|
||||
if overwrite or not storagechain.get_file_item(storage=fileitem.storage,
|
||||
path=image_path):
|
||||
# 下载图片
|
||||
content = __download_image(image_url)
|
||||
# 保存图片文件到当前目录
|
||||
if content:
|
||||
__save_file(_fileitem=fileitem, _path=image_path, _content=content)
|
||||
# 流式下载图片并直接保存
|
||||
__download_and_save_image(_fileitem=fileitem, _path=image_path, _url=image_url)
|
||||
else:
|
||||
logger.info(f"已存在图片文件:{image_path}")
|
||||
else:
|
||||
logger.info(f"电视剧图片刮削已关闭,跳过:{image_name}")
|
||||
logger.info(f"{filepath.name} 刮削完成")
|
||||
|
||||
async def async_recognize_by_meta(self, metainfo: MetaBase,
|
||||
episode_group: Optional[str] = None) -> Optional[MediaInfo]:
|
||||
"""
|
||||
根据主副标题识别媒体信息(异步版本)
|
||||
"""
|
||||
title = metainfo.title
|
||||
# 识别媒体信息
|
||||
mediainfo: MediaInfo = await self.async_recognize_media(meta=metainfo, episode_group=episode_group)
|
||||
if not mediainfo:
|
||||
# 尝试使用辅助识别,如果有注册响应事件的话
|
||||
if eventmanager.check(ChainEventType.NameRecognize):
|
||||
logger.info(f'请求辅助识别,标题:{title} ...')
|
||||
mediainfo = await self.async_recognize_help(title=title, org_meta=metainfo)
|
||||
if not mediainfo:
|
||||
logger.warn(f'{title} 未识别到媒体信息')
|
||||
return None
|
||||
# 识别成功
|
||||
logger.info(f'{title} 识别到媒体信息:{mediainfo.type.value} {mediainfo.title_year}')
|
||||
# 更新媒体图片
|
||||
await self.async_obtain_images(mediainfo=mediainfo)
|
||||
# 返回上下文
|
||||
return mediainfo
|
||||
|
||||
async def async_recognize_help(self, title: str, org_meta: MetaBase) -> Optional[MediaInfo]:
|
||||
"""
|
||||
请求辅助识别,返回媒体信息(异步版本)
|
||||
:param title: 标题
|
||||
:param org_meta: 原始元数据
|
||||
"""
|
||||
# 发送请求事件,等待结果
|
||||
result: Event = await eventmanager.async_send_event(
|
||||
ChainEventType.NameRecognize,
|
||||
{
|
||||
'title': title,
|
||||
}
|
||||
)
|
||||
if not result:
|
||||
return None
|
||||
# 获取返回事件数据
|
||||
event_data = result.event_data or {}
|
||||
logger.info(f'获取到辅助识别结果:{event_data}')
|
||||
# 处理数据格式
|
||||
title, year, season_number, episode_number = None, None, None, None
|
||||
if event_data.get("name"):
|
||||
title = str(event_data["name"]).split("/")[0].strip().replace(".", " ")
|
||||
if event_data.get("year"):
|
||||
year = str(event_data["year"]).split("/")[0].strip()
|
||||
if event_data.get("season") and str(event_data["season"]).isdigit():
|
||||
season_number = int(event_data["season"])
|
||||
if event_data.get("episode") and str(event_data["episode"]).isdigit():
|
||||
episode_number = int(event_data["episode"])
|
||||
if not title:
|
||||
return None
|
||||
if title == 'Unknown':
|
||||
return None
|
||||
if not str(year).isdigit():
|
||||
year = None
|
||||
# 结果赋值
|
||||
if title == org_meta.name and year == org_meta.year:
|
||||
logger.info(f'辅助识别与原始识别结果一致,无需重新识别媒体信息')
|
||||
return None
|
||||
logger.info(f'辅助识别结果与原始识别结果不一致,重新匹配媒体信息 ...')
|
||||
org_meta.name = title
|
||||
org_meta.year = year
|
||||
org_meta.begin_season = season_number
|
||||
org_meta.begin_episode = episode_number
|
||||
if org_meta.begin_season or org_meta.begin_episode:
|
||||
org_meta.type = MediaType.TV
|
||||
# 重新识别
|
||||
return await self.async_recognize_media(meta=org_meta)
|
||||
|
||||
async def async_recognize_by_path(self, path: str, episode_group: Optional[str] = None) -> Optional[Context]:
|
||||
"""
|
||||
根据文件路径识别媒体信息(异步版本)
|
||||
"""
|
||||
logger.info(f'开始识别媒体信息,文件:{path} ...')
|
||||
file_path = Path(path)
|
||||
# 元数据
|
||||
file_meta = MetaInfoPath(file_path)
|
||||
# 识别媒体信息
|
||||
mediainfo = await self.async_recognize_media(meta=file_meta, episode_group=episode_group)
|
||||
if not mediainfo:
|
||||
# 尝试使用辅助识别,如果有注册响应事件的话
|
||||
if eventmanager.check(ChainEventType.NameRecognize):
|
||||
logger.info(f'请求辅助识别,标题:{file_path.name} ...')
|
||||
mediainfo = await self.async_recognize_help(title=path, org_meta=file_meta)
|
||||
if not mediainfo:
|
||||
logger.warn(f'{path} 未识别到媒体信息')
|
||||
return Context(meta_info=file_meta)
|
||||
logger.info(f'{path} 识别到媒体信息:{mediainfo.type.value} {mediainfo.title_year}')
|
||||
# 更新媒体图片
|
||||
await self.async_obtain_images(mediainfo=mediainfo)
|
||||
# 返回上下文
|
||||
return Context(meta_info=file_meta, media_info=mediainfo)
|
||||
|
||||
async def async_search(self, title: str) -> Tuple[Optional[MetaBase], List[MediaInfo]]:
|
||||
"""
|
||||
搜索媒体/人物信息(异步版本)
|
||||
:param title: 搜索内容
|
||||
:return: 识别元数据,媒体信息列表
|
||||
"""
|
||||
# 提取要素
|
||||
mtype, key_word, season_num, episode_num, year, content = StringUtils.get_keyword(title)
|
||||
# 识别
|
||||
meta = MetaInfo(content)
|
||||
if not meta.name:
|
||||
meta.cn_name = content
|
||||
# 合并信息
|
||||
if mtype:
|
||||
meta.type = mtype
|
||||
if season_num:
|
||||
meta.begin_season = season_num
|
||||
if episode_num:
|
||||
meta.begin_episode = episode_num
|
||||
if year:
|
||||
meta.year = year
|
||||
# 开始搜索
|
||||
logger.info(f"开始搜索媒体信息:{meta.name}")
|
||||
medias: Optional[List[MediaInfo]] = await self.async_search_medias(meta=meta)
|
||||
if not medias:
|
||||
logger.warn(f"{meta.name} 没有找到对应的媒体信息!")
|
||||
return meta, []
|
||||
logger.info(f"{content} 搜索到 {len(medias)} 条相关媒体信息")
|
||||
# 识别的元数据,媒体信息列表
|
||||
return meta, medias
|
||||
|
||||
@staticmethod
|
||||
def _extract_year_from_bangumi(bangumiinfo: dict) -> Optional[str]:
|
||||
"""
|
||||
从Bangumi信息中提取年份
|
||||
"""
|
||||
release_date = bangumiinfo.get("date") or bangumiinfo.get("air_date")
|
||||
if release_date:
|
||||
return release_date[:4]
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_year_from_tmdb(tmdbinfo: dict, season: Optional[int] = None) -> Optional[str]:
|
||||
"""
|
||||
从TMDB信息中提取年份
|
||||
"""
|
||||
year = None
|
||||
if tmdbinfo.get('release_date'):
|
||||
year = tmdbinfo['release_date'][:4]
|
||||
elif tmdbinfo.get('seasons') and season:
|
||||
for seainfo in tmdbinfo['seasons']:
|
||||
season_number = seainfo.get("season_number")
|
||||
if not season_number:
|
||||
continue
|
||||
air_date = seainfo.get("air_date")
|
||||
if air_date and season_number == season:
|
||||
year = air_date[:4]
|
||||
break
|
||||
return year
|
||||
|
||||
def _match_tmdb_with_names(self, meta_names: list, year: Optional[str],
|
||||
mtype: MediaType, season: Optional[int] = None) -> Optional[dict]:
|
||||
"""
|
||||
使用名称列表匹配TMDB信息
|
||||
"""
|
||||
for name in meta_names:
|
||||
tmdbinfo = self.match_tmdbinfo(
|
||||
name=name,
|
||||
year=year,
|
||||
mtype=mtype,
|
||||
season=season
|
||||
)
|
||||
if tmdbinfo:
|
||||
return tmdbinfo
|
||||
return None
|
||||
|
||||
async def _async_match_tmdb_with_names(self, meta_names: list, year: Optional[str],
|
||||
mtype: MediaType, season: Optional[int] = None) -> Optional[dict]:
|
||||
"""
|
||||
使用名称列表匹配TMDB信息(异步版本)
|
||||
"""
|
||||
for name in meta_names:
|
||||
tmdbinfo = await self.async_match_tmdbinfo(
|
||||
name=name,
|
||||
year=year,
|
||||
mtype=mtype,
|
||||
season=season
|
||||
)
|
||||
if tmdbinfo:
|
||||
return tmdbinfo
|
||||
return None
|
||||
|
||||
async def async_get_tmdbinfo_by_doubanid(self, doubanid: str, mtype: MediaType = None) -> Optional[dict]:
|
||||
"""
|
||||
根据豆瓣ID获取TMDB信息(异步版本)
|
||||
"""
|
||||
tmdbinfo = None
|
||||
doubaninfo = await self.async_douban_info(doubanid=doubanid, mtype=mtype)
|
||||
if doubaninfo:
|
||||
# 优先使用原标题匹配
|
||||
if doubaninfo.get("original_title"):
|
||||
meta = MetaInfo(title=doubaninfo.get("title"))
|
||||
meta_org = MetaInfo(title=doubaninfo.get("original_title"))
|
||||
else:
|
||||
meta_org = meta = MetaInfo(title=doubaninfo.get("title"))
|
||||
# 年份
|
||||
if doubaninfo.get("year"):
|
||||
meta.year = doubaninfo.get("year")
|
||||
# 处理类型
|
||||
if isinstance(doubaninfo.get('media_type'), MediaType):
|
||||
meta.type = doubaninfo.get('media_type')
|
||||
else:
|
||||
meta.type = MediaType.MOVIE if doubaninfo.get("type") == "movie" else MediaType.TV
|
||||
# 匹配TMDB信息
|
||||
meta_names = list(dict.fromkeys([k for k in [meta_org.name,
|
||||
meta.cn_name,
|
||||
meta.en_name] if k]))
|
||||
tmdbinfo = await self._async_match_tmdb_with_names(
|
||||
meta_names=meta_names,
|
||||
year=meta.year,
|
||||
mtype=mtype or meta.type,
|
||||
season=meta.begin_season
|
||||
)
|
||||
if tmdbinfo:
|
||||
# 合季季后返回
|
||||
tmdbinfo['season'] = meta.begin_season
|
||||
return tmdbinfo
|
||||
|
||||
async def async_get_tmdbinfo_by_bangumiid(self, bangumiid: int) -> Optional[dict]:
|
||||
"""
|
||||
根据BangumiID获取TMDB信息(异步版本)
|
||||
"""
|
||||
bangumiinfo = await self.async_bangumi_info(bangumiid=bangumiid)
|
||||
if bangumiinfo:
|
||||
# 优先使用原标题匹配
|
||||
if bangumiinfo.get("name_cn"):
|
||||
meta = MetaInfo(title=bangumiinfo.get("name"))
|
||||
meta_cn = MetaInfo(title=bangumiinfo.get("name_cn"))
|
||||
else:
|
||||
meta_cn = meta = MetaInfo(title=bangumiinfo.get("name"))
|
||||
# 年份
|
||||
year = self._extract_year_from_bangumi(bangumiinfo)
|
||||
# 识别TMDB媒体信息
|
||||
meta_names = list(dict.fromkeys([k for k in [meta_cn.name,
|
||||
meta.name] if k]))
|
||||
tmdbinfo = await self._async_match_tmdb_with_names(
|
||||
meta_names=meta_names,
|
||||
year=year,
|
||||
mtype=MediaType.TV,
|
||||
season=meta.begin_season
|
||||
)
|
||||
return tmdbinfo
|
||||
return None
|
||||
|
||||
async def async_get_doubaninfo_by_tmdbid(self, tmdbid: int, mtype: MediaType = None,
|
||||
season: Optional[int] = None) -> Optional[dict]:
|
||||
"""
|
||||
根据TMDBID获取豆瓣信息(异步版本)
|
||||
"""
|
||||
tmdbinfo = await self.async_tmdb_info(tmdbid=tmdbid, mtype=mtype)
|
||||
if tmdbinfo:
|
||||
# 名称
|
||||
name = tmdbinfo.get("title") or tmdbinfo.get("name")
|
||||
# 年份
|
||||
year = self._extract_year_from_tmdb(tmdbinfo, season)
|
||||
# IMDBID
|
||||
imdbid = tmdbinfo.get("external_ids", {}).get("imdb_id")
|
||||
return await self.async_match_doubaninfo(
|
||||
name=name,
|
||||
year=year,
|
||||
mtype=mtype,
|
||||
imdbid=imdbid
|
||||
)
|
||||
return None
|
||||
|
||||
async def async_get_doubaninfo_by_bangumiid(self, bangumiid: int) -> Optional[dict]:
|
||||
"""
|
||||
根据BangumiID获取豆瓣信息(异步版本)
|
||||
"""
|
||||
bangumiinfo = await self.async_bangumi_info(bangumiid=bangumiid)
|
||||
if bangumiinfo:
|
||||
# 优先使用中文标题匹配
|
||||
if bangumiinfo.get("name_cn"):
|
||||
meta = MetaInfo(title=bangumiinfo.get("name_cn"))
|
||||
else:
|
||||
meta = MetaInfo(title=bangumiinfo.get("name"))
|
||||
# 年份
|
||||
year = self._extract_year_from_bangumi(bangumiinfo)
|
||||
# 使用名称识别豆瓣媒体信息
|
||||
return await self.async_match_doubaninfo(
|
||||
name=meta.name,
|
||||
year=year,
|
||||
mtype=MediaType.TV,
|
||||
season=meta.begin_season
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -113,6 +113,16 @@ class MediaServerChain(ChainBase):
|
||||
"""
|
||||
return self.run_module("mediaserver_play_url", server=server, item_id=item_id)
|
||||
|
||||
def get_image_cookies(
|
||||
self, server: Optional[str], image_url: str
|
||||
) -> Optional[str | dict]:
|
||||
"""
|
||||
获取图片的Cookies
|
||||
"""
|
||||
return self.run_module(
|
||||
"mediaserver_image_cookies", server=server, image_url=image_url
|
||||
)
|
||||
|
||||
def sync(self):
|
||||
"""
|
||||
同步媒体库所有数据到本地数据库
|
||||
@@ -167,7 +177,7 @@ class MediaServerChain(ChainBase):
|
||||
for episode in espisodes_info:
|
||||
seasoninfo[episode.season] = episode.episodes
|
||||
# 插入数据
|
||||
item_dict = item.dict()
|
||||
item_dict = item.model_dump()
|
||||
item_dict["seasoninfo"] = seasoninfo
|
||||
item_dict["item_type"] = item_type
|
||||
dboper.add(**item_dict)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import asyncio
|
||||
import re
|
||||
from typing import Any, Optional, Dict, Union, List
|
||||
|
||||
from app.agent import agent_manager
|
||||
from app.chain import ChainBase
|
||||
from app.chain.download import DownloadChain
|
||||
from app.chain.media import MediaChain
|
||||
@@ -163,6 +165,10 @@ class MessageChain(ChainBase):
|
||||
original_message_id=original_message_id, original_chat_id=original_chat_id)
|
||||
else:
|
||||
logger.warning(f"渠道 {channel.value} 不支持回调,但收到了回调消息:{text}")
|
||||
elif text.startswith('/ai') or text.startswith('/AI'):
|
||||
# AI智能体处理
|
||||
self._handle_ai_message(text=text, channel=channel, source=source,
|
||||
userid=userid, username=username)
|
||||
elif text.startswith('/'):
|
||||
# 执行命令
|
||||
self.eventmanager.send_event(
|
||||
@@ -815,3 +821,86 @@ class MessageChain(ChainBase):
|
||||
buttons.append(page_buttons)
|
||||
|
||||
return buttons
|
||||
|
||||
def _handle_ai_message(self, text: str, channel: MessageChannel, source: str,
|
||||
userid: Union[str, int], username: str) -> None:
|
||||
"""
|
||||
处理AI智能体消息
|
||||
"""
|
||||
try:
|
||||
# 检查AI智能体是否启用
|
||||
if not settings.AI_AGENT_ENABLE:
|
||||
self.post_message(Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title="MoviePilot智能助手未启用,请在系统设置中启用"
|
||||
))
|
||||
return
|
||||
|
||||
# 检查LLM配置
|
||||
if not settings.LLM_API_KEY:
|
||||
self.post_message(Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title="MoviePilot智能助未配置,请在系统设置中配置"
|
||||
))
|
||||
return
|
||||
|
||||
# 提取用户消息
|
||||
user_message = text[3:].strip() # 移除 "/ai" 前缀
|
||||
if not user_message:
|
||||
self.post_message(Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title="请输入您的问题或需求"
|
||||
))
|
||||
return
|
||||
|
||||
# 发送处理中消息
|
||||
self.post_message(Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title="MoviePilot助手已收到您的请求,请稍候..."
|
||||
))
|
||||
|
||||
# 生成会话ID
|
||||
session_id = f"user_{userid}_{hash(user_message) % 10000}"
|
||||
|
||||
# 在事件循环中处理
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(
|
||||
agent_manager.process_message(
|
||||
session_id=session_id,
|
||||
user_id=str(userid),
|
||||
message=user_message,
|
||||
channel=channel.value if channel else None,
|
||||
source=source,
|
||||
username=username
|
||||
)
|
||||
)
|
||||
except RuntimeError:
|
||||
# 如果没有事件循环,创建新的
|
||||
asyncio.run(
|
||||
agent_manager.process_message(
|
||||
session_id=session_id,
|
||||
user_id=str(userid),
|
||||
message=user_message,
|
||||
channel=channel.value if channel else None,
|
||||
source=source,
|
||||
username=username
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理AI智能体消息失败: {e}")
|
||||
self.messagehelper.put(f"AI智能体处理失败: {str(e)}", role="system", title="MoviePilot助手")
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import io
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
@@ -10,7 +9,7 @@ from app.chain import ChainBase
|
||||
from app.chain.bangumi import BangumiChain
|
||||
from app.chain.douban import DoubanChain
|
||||
from app.chain.tmdb import TmdbChain
|
||||
from app.core.cache import cache_backend, cached
|
||||
from app.core.cache import cached, FileCache
|
||||
from app.core.config import settings, global_vars
|
||||
from app.log import logger
|
||||
from app.schemas import MediaType
|
||||
@@ -19,26 +18,24 @@ from app.utils.http import RequestUtils
|
||||
from app.utils.security import SecurityUtils
|
||||
from app.utils.singleton import Singleton
|
||||
|
||||
# 推荐相关的专用缓存
|
||||
recommend_ttl = 24 * 3600
|
||||
recommend_cache_region = "recommend"
|
||||
|
||||
|
||||
class RecommendChain(ChainBase, metaclass=Singleton):
|
||||
"""
|
||||
推荐处理链,单例运行
|
||||
"""
|
||||
|
||||
# 推荐数据的缓存页数
|
||||
# 推荐缓存时间
|
||||
recommend_ttl = 24 * 3600
|
||||
# 推荐缓存页数
|
||||
cache_max_pages = 5
|
||||
# 推荐缓存区域
|
||||
recommend_cache_region = "recommend"
|
||||
|
||||
def refresh_recommend(self):
|
||||
"""
|
||||
刷新推荐
|
||||
"""
|
||||
logger.debug("Starting to refresh Recommend data.")
|
||||
cache_backend.clear(region=recommend_cache_region)
|
||||
logger.debug("Recommend Cache has been cleared.")
|
||||
|
||||
# 推荐来源方法
|
||||
recommend_methods = [
|
||||
@@ -106,31 +103,26 @@ class RecommendChain(ChainBase, metaclass=Singleton):
|
||||
请求并保存图片
|
||||
:param url: 图片路径
|
||||
"""
|
||||
if not settings.GLOBAL_IMAGE_CACHE or not url:
|
||||
return
|
||||
|
||||
# 生成缓存路径
|
||||
sanitized_path = SecurityUtils.sanitize_url_path(url)
|
||||
cache_path = settings.CACHE_PATH / "images" / sanitized_path
|
||||
|
||||
cache_path = Path("images") / sanitized_path
|
||||
# 没有文件类型,则添加后缀,在恶意文件类型和实际需求下的折衷选择
|
||||
if not cache_path.suffix:
|
||||
cache_path = cache_path.with_suffix(".jpg")
|
||||
|
||||
# 确保缓存路径和文件类型合法
|
||||
if not SecurityUtils.is_safe_path(settings.CACHE_PATH, cache_path, settings.SECURITY_IMAGE_SUFFIXES):
|
||||
logger.debug(f"Invalid cache path or file type for URL: {url}, sanitized path: {sanitized_path}")
|
||||
return
|
||||
# 获取缓存后端,并设置缓存时间为全局配置的缓存天数
|
||||
cache_backend = FileCache(base=settings.CACHE_PATH,
|
||||
ttl=settings.GLOBAL_IMAGE_CACHE_DAYS * 24 * 3600)
|
||||
|
||||
# 本地存在缓存图片,则直接跳过
|
||||
if cache_path.exists():
|
||||
if cache_backend.get(cache_path.as_posix(), region="images"):
|
||||
logger.debug(f"Cache hit: Image already exists at {cache_path}")
|
||||
return
|
||||
|
||||
# 请求远程图片
|
||||
referer = "https://movie.douban.com/" if "doubanio.com" in url else None
|
||||
proxies = settings.PROXY if not referer else None
|
||||
response = RequestUtils(ua=settings.USER_AGENT, proxies=proxies, referer=referer).get_res(url=url)
|
||||
response = RequestUtils(ua=settings.NORMAL_USER_AGENT, proxies=proxies, referer=referer).get_res(url=url)
|
||||
if not response:
|
||||
logger.debug(f"Empty response for URL: {url}")
|
||||
return
|
||||
@@ -142,19 +134,9 @@ class RecommendChain(ChainBase, metaclass=Singleton):
|
||||
logger.debug(f"Invalid image format for URL {url}: {e}")
|
||||
return
|
||||
|
||||
if not cache_path:
|
||||
return
|
||||
|
||||
try:
|
||||
if not cache_path.parent.exists():
|
||||
cache_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with tempfile.NamedTemporaryFile(dir=cache_path.parent, delete=False) as tmp_file:
|
||||
tmp_file.write(response.content)
|
||||
temp_path = Path(tmp_file.name)
|
||||
temp_path.replace(cache_path)
|
||||
logger.debug(f"Successfully cached image at {cache_path} for URL: {url}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to write cache file {cache_path} for URL {url}: {e}")
|
||||
# 保存缓存
|
||||
cache_backend.set(cache_path.as_posix(), response.content, region="images")
|
||||
logger.debug(f"Successfully cached image at {cache_path} for URL: {url}")
|
||||
|
||||
@log_execution_time(logger=logger)
|
||||
@cached(ttl=recommend_ttl, region=recommend_cache_region)
|
||||
@@ -310,3 +292,158 @@ class RecommendChain(ChainBase, metaclass=Singleton):
|
||||
"""
|
||||
tvs = DoubanChain().tv_hot(page=page, count=count)
|
||||
return [media.to_dict() for media in tvs] if tvs else []
|
||||
|
||||
@log_execution_time(logger=logger)
|
||||
@cached(ttl=recommend_ttl, region=recommend_cache_region)
|
||||
async def async_tmdb_movies(self, sort_by: Optional[str] = "popularity.desc",
|
||||
with_genres: Optional[str] = "",
|
||||
with_original_language: Optional[str] = "",
|
||||
with_keywords: Optional[str] = "",
|
||||
with_watch_providers: Optional[str] = "",
|
||||
vote_average: Optional[float] = 0.0,
|
||||
vote_count: Optional[int] = 0,
|
||||
release_date: Optional[str] = "",
|
||||
page: Optional[int] = 1) -> List[dict]:
|
||||
"""
|
||||
异步TMDB热门电影
|
||||
"""
|
||||
movies = await TmdbChain().async_run_module("async_tmdb_discover", mtype=MediaType.MOVIE,
|
||||
sort_by=sort_by,
|
||||
with_genres=with_genres,
|
||||
with_original_language=with_original_language,
|
||||
with_keywords=with_keywords,
|
||||
with_watch_providers=with_watch_providers,
|
||||
vote_average=vote_average,
|
||||
vote_count=vote_count,
|
||||
release_date=release_date,
|
||||
page=page)
|
||||
return [movie.to_dict() for movie in movies] if movies else []
|
||||
|
||||
@log_execution_time(logger=logger)
|
||||
@cached(ttl=recommend_ttl, region=recommend_cache_region)
|
||||
async def async_tmdb_tvs(self, sort_by: Optional[str] = "popularity.desc",
|
||||
with_genres: Optional[str] = "",
|
||||
with_original_language: Optional[str] = "zh|en|ja|ko",
|
||||
with_keywords: Optional[str] = "",
|
||||
with_watch_providers: Optional[str] = "",
|
||||
vote_average: Optional[float] = 0.0,
|
||||
vote_count: Optional[int] = 0,
|
||||
release_date: Optional[str] = "",
|
||||
page: Optional[int] = 1) -> List[dict]:
|
||||
"""
|
||||
异步TMDB热门电视剧
|
||||
"""
|
||||
tvs = await TmdbChain().async_run_module("async_tmdb_discover", mtype=MediaType.TV,
|
||||
sort_by=sort_by,
|
||||
with_genres=with_genres,
|
||||
with_original_language=with_original_language,
|
||||
with_keywords=with_keywords,
|
||||
with_watch_providers=with_watch_providers,
|
||||
vote_average=vote_average,
|
||||
vote_count=vote_count,
|
||||
release_date=release_date,
|
||||
page=page)
|
||||
return [tv.to_dict() for tv in tvs] if tvs else []
|
||||
|
||||
@log_execution_time(logger=logger)
|
||||
@cached(ttl=recommend_ttl, region=recommend_cache_region)
|
||||
async def async_tmdb_trending(self, page: Optional[int] = 1) -> List[dict]:
|
||||
"""
|
||||
异步TMDB流行趋势
|
||||
"""
|
||||
infos = await TmdbChain().async_run_module("async_tmdb_trending", page=page)
|
||||
return [info.to_dict() for info in infos] if infos else []
|
||||
|
||||
@log_execution_time(logger=logger)
|
||||
@cached(ttl=recommend_ttl, region=recommend_cache_region)
|
||||
async def async_bangumi_calendar(self, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
|
||||
"""
|
||||
异步Bangumi每日放送
|
||||
"""
|
||||
medias = await BangumiChain().async_run_module("async_bangumi_calendar")
|
||||
return [media.to_dict() for media in medias[(page - 1) * count: page * count]] if medias else []
|
||||
|
||||
@log_execution_time(logger=logger)
|
||||
@cached(ttl=recommend_ttl, region=recommend_cache_region)
|
||||
async def async_douban_movie_showing(self, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
|
||||
"""
|
||||
异步豆瓣正在热映
|
||||
"""
|
||||
movies = await DoubanChain().async_run_module("async_movie_showing", page=page, count=count)
|
||||
return [media.to_dict() for media in movies] if movies else []
|
||||
|
||||
@log_execution_time(logger=logger)
|
||||
@cached(ttl=recommend_ttl, region=recommend_cache_region)
|
||||
async def async_douban_movies(self, sort: Optional[str] = "R", tags: Optional[str] = "",
|
||||
page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
|
||||
"""
|
||||
异步豆瓣最新电影
|
||||
"""
|
||||
movies = await DoubanChain().async_run_module("async_douban_discover", mtype=MediaType.MOVIE,
|
||||
sort=sort, tags=tags, page=page, count=count)
|
||||
return [media.to_dict() for media in movies] if movies else []
|
||||
|
||||
@log_execution_time(logger=logger)
|
||||
@cached(ttl=recommend_ttl, region=recommend_cache_region)
|
||||
async def async_douban_tvs(self, sort: Optional[str] = "R", tags: Optional[str] = "",
|
||||
page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
|
||||
"""
|
||||
异步豆瓣最新电视剧
|
||||
"""
|
||||
tvs = await DoubanChain().async_run_module("async_douban_discover", mtype=MediaType.TV,
|
||||
sort=sort, tags=tags, page=page, count=count)
|
||||
return [media.to_dict() for media in tvs] if tvs else []
|
||||
|
||||
@log_execution_time(logger=logger)
|
||||
@cached(ttl=recommend_ttl, region=recommend_cache_region)
|
||||
async def async_douban_movie_top250(self, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
|
||||
"""
|
||||
异步豆瓣电影TOP250
|
||||
"""
|
||||
movies = await DoubanChain().async_run_module("async_movie_top250", page=page, count=count)
|
||||
return [media.to_dict() for media in movies] if movies else []
|
||||
|
||||
@log_execution_time(logger=logger)
|
||||
@cached(ttl=recommend_ttl, region=recommend_cache_region)
|
||||
async def async_douban_tv_weekly_chinese(self, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
|
||||
"""
|
||||
异步豆瓣国产剧集榜
|
||||
"""
|
||||
tvs = await DoubanChain().async_run_module("async_tv_weekly_chinese", page=page, count=count)
|
||||
return [media.to_dict() for media in tvs] if tvs else []
|
||||
|
||||
@log_execution_time(logger=logger)
|
||||
@cached(ttl=recommend_ttl, region=recommend_cache_region)
|
||||
async def async_douban_tv_weekly_global(self, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
|
||||
"""
|
||||
异步豆瓣全球剧集榜
|
||||
"""
|
||||
tvs = await DoubanChain().async_run_module("async_tv_weekly_global", page=page, count=count)
|
||||
return [media.to_dict() for media in tvs] if tvs else []
|
||||
|
||||
@log_execution_time(logger=logger)
|
||||
@cached(ttl=recommend_ttl, region=recommend_cache_region)
|
||||
async def async_douban_tv_animation(self, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
|
||||
"""
|
||||
异步豆瓣热门动漫
|
||||
"""
|
||||
tvs = await DoubanChain().async_run_module("async_tv_animation", page=page, count=count)
|
||||
return [media.to_dict() for media in tvs] if tvs else []
|
||||
|
||||
@log_execution_time(logger=logger)
|
||||
@cached(ttl=recommend_ttl, region=recommend_cache_region)
|
||||
async def async_douban_movie_hot(self, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
|
||||
"""
|
||||
异步豆瓣热门电影
|
||||
"""
|
||||
movies = await DoubanChain().async_run_module("async_movie_hot", page=page, count=count)
|
||||
return [media.to_dict() for media in movies] if movies else []
|
||||
|
||||
@log_execution_time(logger=logger)
|
||||
@cached(ttl=recommend_ttl, region=recommend_cache_region)
|
||||
async def async_douban_tv_hot(self, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
|
||||
"""
|
||||
异步豆瓣热门电视剧
|
||||
"""
|
||||
tvs = await DoubanChain().async_run_module("async_tv_hot", page=page, count=count)
|
||||
return [media.to_dict() for media in tvs] if tvs else []
|
||||
|
||||
@@ -1,19 +1,22 @@
|
||||
import pickle
|
||||
import traceback
|
||||
import asyncio
|
||||
import random
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from datetime import datetime
|
||||
from typing import Dict
|
||||
from typing import Dict, Tuple
|
||||
from typing import List, Optional
|
||||
|
||||
from app.helper.sites import SitesHelper # noqa
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import global_vars
|
||||
from app.core.config import global_vars, settings
|
||||
from app.core.context import Context
|
||||
from app.core.context import MediaInfo, TorrentInfo
|
||||
from app.core.event import eventmanager, Event
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.helper.progress import ProgressHelper
|
||||
from app.helper.sites import SitesHelper
|
||||
from app.helper.torrent import TorrentHelper
|
||||
from app.log import logger
|
||||
from app.schemas import NotExistMediaInfo
|
||||
@@ -54,7 +57,7 @@ class SearchChain(ChainBase):
|
||||
results = self.process(mediainfo=mediainfo, sites=sites, area=area, no_exists=no_exists)
|
||||
# 保存到本地文件
|
||||
if cache_local:
|
||||
self.save_cache(pickle.dumps(results), self.__result_temp_file)
|
||||
self.save_cache(results, self.__result_temp_file)
|
||||
return results
|
||||
|
||||
def search_by_title(self, title: str, page: Optional[int] = 0,
|
||||
@@ -71,7 +74,7 @@ class SearchChain(ChainBase):
|
||||
else:
|
||||
logger.info(f'开始浏览资源,站点:{sites} ...')
|
||||
# 搜索
|
||||
torrents = self.__search_all_sites(keywords=[title], sites=sites, page=page) or []
|
||||
torrents = self.__search_all_sites(keyword=title, sites=sites, page=page) or []
|
||||
if not torrents:
|
||||
logger.warn(f'{title} 未搜索到资源')
|
||||
return []
|
||||
@@ -80,67 +83,85 @@ class SearchChain(ChainBase):
|
||||
torrent_info=torrent) for torrent in torrents]
|
||||
# 保存到本地文件
|
||||
if cache_local:
|
||||
self.save_cache(pickle.dumps(contexts), self.__result_temp_file)
|
||||
self.save_cache(contexts, self.__result_temp_file)
|
||||
return contexts
|
||||
|
||||
def last_search_results(self) -> List[Context]:
|
||||
def last_search_results(self) -> Optional[List[Context]]:
|
||||
"""
|
||||
获取上次搜索结果
|
||||
"""
|
||||
# 读取本地文件缓存
|
||||
content = self.load_cache(self.__result_temp_file)
|
||||
if not content:
|
||||
return []
|
||||
try:
|
||||
return pickle.loads(content)
|
||||
except Exception as e:
|
||||
logger.error(f'加载搜索结果失败:{str(e)} - {traceback.format_exc()}')
|
||||
return []
|
||||
return self.load_cache(self.__result_temp_file)
|
||||
|
||||
def process(self, mediainfo: MediaInfo,
|
||||
keyword: Optional[str] = None,
|
||||
no_exists: Dict[int, Dict[int, NotExistMediaInfo]] = None,
|
||||
sites: List[int] = None,
|
||||
rule_groups: List[str] = None,
|
||||
area: Optional[str] = "title",
|
||||
custom_words: List[str] = None,
|
||||
filter_params: Dict[str, str] = None) -> List[Context]:
|
||||
async def async_last_search_results(self) -> Optional[List[Context]]:
|
||||
"""
|
||||
根据媒体信息搜索种子资源,精确匹配,应用过滤规则,同时根据no_exists过滤本地已存在的资源
|
||||
:param mediainfo: 媒体信息
|
||||
:param keyword: 搜索关键词
|
||||
:param no_exists: 缺失的媒体信息
|
||||
:param sites: 站点ID列表,为空时搜索所有站点
|
||||
:param rule_groups: 过滤规则组名称列表
|
||||
异步获取上次搜索结果
|
||||
"""
|
||||
return await self.async_load_cache(self.__result_temp_file)
|
||||
|
||||
async def async_search_by_id(self, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
|
||||
mtype: MediaType = None, area: Optional[str] = "title", season: Optional[int] = None,
|
||||
sites: List[int] = None, cache_local: bool = False) -> List[Context]:
|
||||
"""
|
||||
根据TMDBID/豆瓣ID异步搜索资源,精确匹配,不过滤本地存在的资源
|
||||
:param tmdbid: TMDB ID
|
||||
:param doubanid: 豆瓣 ID
|
||||
:param mtype: 媒体,电影 or 电视剧
|
||||
:param area: 搜索范围,title or imdbid
|
||||
:param custom_words: 自定义识别词列表
|
||||
:param filter_params: 过滤参数
|
||||
:param season: 季数
|
||||
:param sites: 站点ID列表
|
||||
:param cache_local: 是否缓存到本地
|
||||
"""
|
||||
mediainfo = await self.async_recognize_media(tmdbid=tmdbid, doubanid=doubanid, mtype=mtype)
|
||||
if not mediainfo:
|
||||
logger.error(f'{tmdbid} 媒体信息识别失败!')
|
||||
return []
|
||||
no_exists = None
|
||||
if season:
|
||||
no_exists = {
|
||||
tmdbid or doubanid: {
|
||||
season: NotExistMediaInfo(episodes=[])
|
||||
}
|
||||
}
|
||||
results = await self.async_process(mediainfo=mediainfo, sites=sites, area=area, no_exists=no_exists)
|
||||
# 保存到本地文件
|
||||
if cache_local:
|
||||
await self.async_save_cache(results, self.__result_temp_file)
|
||||
return results
|
||||
|
||||
def __do_filter(torrent_list: List[TorrentInfo]) -> List[TorrentInfo]:
|
||||
"""
|
||||
执行优先级过滤
|
||||
"""
|
||||
return self.filter_torrents(rule_groups=rule_groups,
|
||||
torrent_list=torrent_list,
|
||||
mediainfo=mediainfo) or []
|
||||
|
||||
# 豆瓣标题处理
|
||||
if not mediainfo.tmdb_id:
|
||||
meta = MetaInfo(title=mediainfo.title)
|
||||
mediainfo.title = meta.name
|
||||
mediainfo.season = meta.begin_season
|
||||
logger.info(f'开始搜索资源,关键词:{keyword or mediainfo.title} ...')
|
||||
|
||||
# 补充媒体信息
|
||||
if not mediainfo.names:
|
||||
mediainfo: MediaInfo = self.recognize_media(mtype=mediainfo.type,
|
||||
tmdbid=mediainfo.tmdb_id,
|
||||
doubanid=mediainfo.douban_id)
|
||||
if not mediainfo:
|
||||
logger.error(f'媒体信息识别失败!')
|
||||
return []
|
||||
async def async_search_by_title(self, title: str, page: Optional[int] = 0,
|
||||
sites: List[int] = None, cache_local: Optional[bool] = False) -> List[Context]:
|
||||
"""
|
||||
根据标题异步搜索资源,不识别不过滤,直接返回站点内容
|
||||
:param title: 标题,为空时返回所有站点首页内容
|
||||
:param page: 页码
|
||||
:param sites: 站点ID列表
|
||||
:param cache_local: 是否缓存到本地
|
||||
"""
|
||||
if title:
|
||||
logger.info(f'开始搜索资源,关键词:{title} ...')
|
||||
else:
|
||||
logger.info(f'开始浏览资源,站点:{sites} ...')
|
||||
# 搜索
|
||||
torrents = await self.__async_search_all_sites(keyword=title, sites=sites, page=page) or []
|
||||
if not torrents:
|
||||
logger.warn(f'{title} 未搜索到资源')
|
||||
return []
|
||||
# 组装上下文
|
||||
contexts = [Context(meta_info=MetaInfo(title=torrent.title, subtitle=torrent.description),
|
||||
torrent_info=torrent) for torrent in torrents]
|
||||
# 保存到本地文件
|
||||
if cache_local:
|
||||
await self.async_save_cache(contexts, self.__result_temp_file)
|
||||
return contexts
|
||||
|
||||
@staticmethod
|
||||
def __prepare_params(mediainfo: MediaInfo,
|
||||
keyword: Optional[str] = None,
|
||||
no_exists: Dict[int, Dict[int, NotExistMediaInfo]] = None
|
||||
) -> Tuple[Dict[int, List[int]], List[str]]:
|
||||
"""
|
||||
准备搜索参数
|
||||
"""
|
||||
# 缺失的季集
|
||||
mediakey = mediainfo.tmdb_id or mediainfo.douban_id
|
||||
if no_exists and no_exists.get(mediakey):
|
||||
@@ -164,25 +185,41 @@ class SearchChain(ChainBase):
|
||||
mediainfo.hk_title,
|
||||
mediainfo.tw_title,
|
||||
mediainfo.sg_title] if k]))
|
||||
# 限制搜索关键词数量
|
||||
if settings.MAX_SEARCH_NAME_LIMIT:
|
||||
keywords = keywords[:settings.MAX_SEARCH_NAME_LIMIT]
|
||||
|
||||
return season_episodes, keywords
|
||||
|
||||
def __parse_result(self, torrents: List[TorrentInfo],
|
||||
mediainfo: MediaInfo,
|
||||
keyword: Optional[str] = None,
|
||||
rule_groups: List[str] = None,
|
||||
season_episodes: Dict[int, List[int]] = None,
|
||||
custom_words: List[str] = None,
|
||||
filter_params: Dict[str, str] = None) -> List[Context]:
|
||||
"""
|
||||
处理搜索结果
|
||||
"""
|
||||
|
||||
def __do_filter(torrent_list: List[TorrentInfo]) -> List[TorrentInfo]:
|
||||
"""
|
||||
执行优先级过滤
|
||||
"""
|
||||
return self.filter_torrents(rule_groups=rule_groups,
|
||||
torrent_list=torrent_list,
|
||||
mediainfo=mediainfo) or []
|
||||
|
||||
# 执行搜索
|
||||
torrents: List[TorrentInfo] = self.__search_all_sites(
|
||||
mediainfo=mediainfo,
|
||||
keywords=keywords,
|
||||
sites=sites,
|
||||
area=area
|
||||
)
|
||||
if not torrents:
|
||||
logger.warn(f'{keyword or mediainfo.title} 未搜索到资源')
|
||||
return []
|
||||
|
||||
# 开始新进度
|
||||
progress = ProgressHelper()
|
||||
progress.start(ProgressKey.Search)
|
||||
progress = ProgressHelper(ProgressKey.Search)
|
||||
progress.start()
|
||||
|
||||
# 开始过滤
|
||||
progress.update(value=0, text=f'开始过滤,总 {len(torrents)} 个资源,请稍候...',
|
||||
key=ProgressKey.Search)
|
||||
progress.update(value=0, text=f'开始过滤,总 {len(torrents)} 个资源,请稍候...')
|
||||
# 匹配订阅附加参数
|
||||
if filter_params:
|
||||
logger.info(f'开始附加参数过滤,附加参数:{filter_params} ...')
|
||||
@@ -200,7 +237,7 @@ class SearchChain(ChainBase):
|
||||
logger.info(f"过滤规则/剧集过滤完成,剩余 {len(torrents)} 个资源")
|
||||
|
||||
# 过滤完成
|
||||
progress.update(value=50, text=f'过滤完成,剩余 {len(torrents)} 个资源', key=ProgressKey.Search)
|
||||
progress.update(value=50, text=f'过滤完成,剩余 {len(torrents)} 个资源')
|
||||
|
||||
# 总数
|
||||
_total = len(torrents)
|
||||
@@ -213,14 +250,13 @@ class SearchChain(ChainBase):
|
||||
try:
|
||||
# 英文标题应该在别名/原标题中,不需要再匹配
|
||||
logger.info(f"开始匹配结果 标题:{mediainfo.title},原标题:{mediainfo.original_title},别名:{mediainfo.names}")
|
||||
progress.update(value=51, text=f'开始匹配,总 {_total} 个资源 ...', key=ProgressKey.Search)
|
||||
progress.update(value=51, text=f'开始匹配,总 {_total} 个资源 ...')
|
||||
for torrent in torrents:
|
||||
if global_vars.is_system_stopped:
|
||||
break
|
||||
_count += 1
|
||||
progress.update(value=(_count / _total) * 96,
|
||||
text=f'正在匹配 {torrent.site_name},已完成 {_count} / {_total} ...',
|
||||
key=ProgressKey.Search)
|
||||
text=f'正在匹配 {torrent.site_name},已完成 {_count} / {_total} ...')
|
||||
if not torrent.title:
|
||||
continue
|
||||
|
||||
@@ -253,8 +289,7 @@ class SearchChain(ChainBase):
|
||||
# 匹配完成
|
||||
logger.info(f"匹配完成,共匹配到 {len(_match_torrents)} 个资源")
|
||||
progress.update(value=97,
|
||||
text=f'匹配完成,共匹配到 {len(_match_torrents)} 个资源',
|
||||
key=ProgressKey.Search)
|
||||
text=f'匹配完成,共匹配到 {len(_match_torrents)} 个资源')
|
||||
|
||||
# 去掉mediainfo中多余的数据
|
||||
mediainfo.clear()
|
||||
@@ -270,21 +305,192 @@ class SearchChain(ChainBase):
|
||||
|
||||
# 排序
|
||||
progress.update(value=99,
|
||||
text=f'正在对 {len(contexts)} 个资源进行排序,请稍候...',
|
||||
key=ProgressKey.Search)
|
||||
text=f'正在对 {len(contexts)} 个资源进行排序,请稍候...')
|
||||
contexts = torrenthelper.sort_torrents(contexts)
|
||||
|
||||
# 结束进度
|
||||
logger.info(f'搜索完成,共 {len(contexts)} 个资源')
|
||||
progress.update(value=100,
|
||||
text=f'搜索完成,共 {len(contexts)} 个资源',
|
||||
key=ProgressKey.Search)
|
||||
progress.end(ProgressKey.Search)
|
||||
text=f'搜索完成,共 {len(contexts)} 个资源')
|
||||
progress.end()
|
||||
|
||||
# 返回
|
||||
return contexts
|
||||
# 去重后返回
|
||||
return self.__remove_duplicate(contexts)
|
||||
|
||||
def __search_all_sites(self, keywords: List[str],
|
||||
@staticmethod
|
||||
def __remove_duplicate(_torrents: List[Context]) -> List[Context]:
|
||||
"""
|
||||
去除重复的种子
|
||||
:param _torrents: 种子列表
|
||||
:return: 去重后的种子列表
|
||||
"""
|
||||
return list({f"{t.torrent_info.site_name}_{t.torrent_info.title}_{t.torrent_info.description}": t
|
||||
for t in _torrents}.values())
|
||||
|
||||
def process(self, mediainfo: MediaInfo,
|
||||
keyword: Optional[str] = None,
|
||||
no_exists: Dict[int, Dict[int, NotExistMediaInfo]] = None,
|
||||
sites: List[int] = None,
|
||||
rule_groups: List[str] = None,
|
||||
area: Optional[str] = "title",
|
||||
custom_words: List[str] = None,
|
||||
filter_params: Dict[str, str] = None) -> List[Context]:
|
||||
"""
|
||||
根据媒体信息搜索种子资源,精确匹配,应用过滤规则,同时根据no_exists过滤本地已存在的资源
|
||||
:param mediainfo: 媒体信息
|
||||
:param keyword: 搜索关键词
|
||||
:param no_exists: 缺失的媒体信息
|
||||
:param sites: 站点ID列表,为空时搜索所有站点
|
||||
:param rule_groups: 过滤规则组名称列表
|
||||
:param area: 搜索范围,title or imdbid
|
||||
:param custom_words: 自定义识别词列表
|
||||
:param filter_params: 过滤参数
|
||||
"""
|
||||
|
||||
# 豆瓣标题处理
|
||||
if not mediainfo.tmdb_id:
|
||||
meta = MetaInfo(title=mediainfo.title)
|
||||
mediainfo.title = meta.name
|
||||
mediainfo.season = meta.begin_season
|
||||
logger.info(f'开始搜索资源,关键词:{keyword or mediainfo.title} ...')
|
||||
|
||||
# 补充媒体信息
|
||||
if not mediainfo.names:
|
||||
mediainfo: MediaInfo = self.recognize_media(mtype=mediainfo.type,
|
||||
tmdbid=mediainfo.tmdb_id,
|
||||
doubanid=mediainfo.douban_id)
|
||||
if not mediainfo:
|
||||
logger.error(f'媒体信息识别失败!')
|
||||
return []
|
||||
|
||||
# 准备搜索参数
|
||||
season_episodes, keywords = self.__prepare_params(
|
||||
mediainfo=mediainfo,
|
||||
keyword=keyword,
|
||||
no_exists=no_exists
|
||||
)
|
||||
|
||||
# 站点搜索结果
|
||||
torrents: List[TorrentInfo] = []
|
||||
# 站点搜索次数
|
||||
search_count = 0
|
||||
|
||||
# 多关键字执行搜索
|
||||
for search_word in keywords:
|
||||
# 强制休眠 1-10 秒
|
||||
if search_count > 0:
|
||||
logger.info(f"已搜索 {search_count} 次,强制休眠 1-10 秒 ...")
|
||||
time.sleep(random.randint(1, 10))
|
||||
|
||||
# 搜索站点
|
||||
results = self.__search_all_sites(
|
||||
mediainfo=mediainfo,
|
||||
keyword=search_word,
|
||||
sites=sites,
|
||||
area=area
|
||||
) or []
|
||||
# 合并结果
|
||||
|
||||
search_count += 1
|
||||
torrents.extend(results)
|
||||
|
||||
# 有结果则停止
|
||||
if not settings.SEARCH_MULTIPLE_NAME and torrents:
|
||||
logger.info(f"共搜索到 {len(torrents)} 个资源,停止搜索")
|
||||
break
|
||||
|
||||
# 处理结果
|
||||
return self.__parse_result(
|
||||
torrents=torrents,
|
||||
mediainfo=mediainfo,
|
||||
keyword=keyword,
|
||||
rule_groups=rule_groups,
|
||||
season_episodes=season_episodes,
|
||||
custom_words=custom_words,
|
||||
filter_params=filter_params
|
||||
)
|
||||
|
||||
async def async_process(self, mediainfo: MediaInfo,
|
||||
keyword: Optional[str] = None,
|
||||
no_exists: Dict[int, Dict[int, NotExistMediaInfo]] = None,
|
||||
sites: List[int] = None,
|
||||
rule_groups: List[str] = None,
|
||||
area: Optional[str] = "title",
|
||||
custom_words: List[str] = None,
|
||||
filter_params: Dict[str, str] = None) -> List[Context]:
|
||||
"""
|
||||
根据媒体信息异步搜索种子资源,精确匹配,应用过滤规则,同时根据no_exists过滤本地已存在的资源
|
||||
:param mediainfo: 媒体信息
|
||||
:param keyword: 搜索关键词
|
||||
:param no_exists: 缺失的媒体信息
|
||||
:param sites: 站点ID列表,为空时搜索所有站点
|
||||
:param rule_groups: 过滤规则组名称列表
|
||||
:param area: 搜索范围,title or imdbid
|
||||
:param custom_words: 自定义识别词列表
|
||||
:param filter_params: 过滤参数
|
||||
"""
|
||||
|
||||
# 豆瓣标题处理
|
||||
if not mediainfo.tmdb_id:
|
||||
meta = MetaInfo(title=mediainfo.title)
|
||||
mediainfo.title = meta.name
|
||||
mediainfo.season = meta.begin_season
|
||||
logger.info(f'开始搜索资源,关键词:{keyword or mediainfo.title} ...')
|
||||
|
||||
# 补充媒体信息
|
||||
if not mediainfo.names:
|
||||
mediainfo: MediaInfo = await self.async_recognize_media(mtype=mediainfo.type,
|
||||
tmdbid=mediainfo.tmdb_id,
|
||||
doubanid=mediainfo.douban_id)
|
||||
if not mediainfo:
|
||||
logger.error(f'媒体信息识别失败!')
|
||||
return []
|
||||
|
||||
# 准备搜索参数
|
||||
season_episodes, keywords = self.__prepare_params(
|
||||
mediainfo=mediainfo,
|
||||
keyword=keyword,
|
||||
no_exists=no_exists
|
||||
)
|
||||
|
||||
# 站点搜索结果
|
||||
torrents: List[TorrentInfo] = []
|
||||
# 站点搜索次数
|
||||
search_count = 0
|
||||
|
||||
# 多关键字执行搜索
|
||||
for search_word in keywords:
|
||||
# 强制休眠 1-10 秒
|
||||
if search_count > 0:
|
||||
logger.info(f"已搜索 {search_count} 次,强制休眠 1-10 秒 ...")
|
||||
await asyncio.sleep(random.randint(1, 10))
|
||||
# 搜索站点
|
||||
torrents.extend(
|
||||
await self.__async_search_all_sites(
|
||||
mediainfo=mediainfo,
|
||||
keyword=search_word,
|
||||
sites=sites,
|
||||
area=area
|
||||
) or []
|
||||
)
|
||||
search_count += 1
|
||||
# 有结果则停止
|
||||
if torrents:
|
||||
logger.info(f"共搜索到 {len(torrents)} 个资源,停止搜索")
|
||||
break
|
||||
|
||||
# 处理结果
|
||||
return await run_in_threadpool(self.__parse_result,
|
||||
torrents=torrents,
|
||||
mediainfo=mediainfo,
|
||||
keyword=keyword,
|
||||
rule_groups=rule_groups,
|
||||
season_episodes=season_episodes,
|
||||
custom_words=custom_words,
|
||||
filter_params=filter_params
|
||||
)
|
||||
|
||||
def __search_all_sites(self, keyword: str,
|
||||
mediainfo: Optional[MediaInfo] = None,
|
||||
sites: List[int] = None,
|
||||
page: Optional[int] = 0,
|
||||
@@ -292,7 +498,7 @@ class SearchChain(ChainBase):
|
||||
"""
|
||||
多线程搜索多个站点
|
||||
:param mediainfo: 识别的媒体信息
|
||||
:param keywords: 搜索关键词列表
|
||||
:param keyword: 搜索关键词
|
||||
:param sites: 指定站点ID列表,如有则只搜索指定站点,否则搜索所有站点
|
||||
:param page: 搜索页码
|
||||
:param area: 搜索区域 title or imdbid
|
||||
@@ -314,8 +520,8 @@ class SearchChain(ChainBase):
|
||||
return []
|
||||
|
||||
# 开始进度
|
||||
progress = ProgressHelper()
|
||||
progress.start(ProgressKey.Search)
|
||||
progress = ProgressHelper(ProgressKey.Search)
|
||||
progress.start()
|
||||
# 开始计时
|
||||
start_time = datetime.now()
|
||||
# 总数
|
||||
@@ -324,8 +530,7 @@ class SearchChain(ChainBase):
|
||||
finish_count = 0
|
||||
# 更新进度
|
||||
progress.update(value=0,
|
||||
text=f"开始搜索,共 {total_num} 个站点 ...",
|
||||
key=ProgressKey.Search)
|
||||
text=f"开始搜索,共 {total_num} 个站点 ...")
|
||||
# 结果集
|
||||
results = []
|
||||
# 多线程
|
||||
@@ -335,13 +540,13 @@ class SearchChain(ChainBase):
|
||||
if area == "imdbid":
|
||||
# 搜索IMDBID
|
||||
task = executor.submit(self.search_torrents, site=site,
|
||||
keywords=[mediainfo.imdb_id] if mediainfo else None,
|
||||
keyword=mediainfo.imdb_id if mediainfo else None,
|
||||
mtype=mediainfo.type if mediainfo else None,
|
||||
page=page)
|
||||
else:
|
||||
# 搜索标题
|
||||
task = executor.submit(self.search_torrents, site=site,
|
||||
keywords=keywords,
|
||||
keyword=keyword,
|
||||
mtype=mediainfo.type if mediainfo else None,
|
||||
page=page)
|
||||
all_task.append(task)
|
||||
@@ -354,17 +559,100 @@ class SearchChain(ChainBase):
|
||||
results.extend(result)
|
||||
logger.info(f"站点搜索进度:{finish_count} / {total_num}")
|
||||
progress.update(value=finish_count / total_num * 100,
|
||||
text=f"正在搜索{keywords or ''},已完成 {finish_count} / {total_num} 个站点 ...",
|
||||
key=ProgressKey.Search)
|
||||
text=f"正在搜索{keyword or ''},已完成 {finish_count} / {total_num} 个站点 ...")
|
||||
# 计算耗时
|
||||
end_time = datetime.now()
|
||||
# 更新进度
|
||||
progress.update(value=100,
|
||||
text=f"站点搜索完成,有效资源数:{len(results)},总耗时 {(end_time - start_time).seconds} 秒",
|
||||
key=ProgressKey.Search)
|
||||
text=f"站点搜索完成,有效资源数:{len(results)},总耗时 {(end_time - start_time).seconds} 秒")
|
||||
logger.info(f"站点搜索完成,有效资源数:{len(results)},总耗时 {(end_time - start_time).seconds} 秒")
|
||||
# 结束进度
|
||||
progress.end(ProgressKey.Search)
|
||||
progress.end()
|
||||
|
||||
# 返回
|
||||
return results
|
||||
|
||||
async def __async_search_all_sites(self, keyword: str,
|
||||
mediainfo: Optional[MediaInfo] = None,
|
||||
sites: List[int] = None,
|
||||
page: Optional[int] = 0,
|
||||
area: Optional[str] = "title") -> Optional[List[TorrentInfo]]:
|
||||
"""
|
||||
异步搜索多个站点
|
||||
:param mediainfo: 识别的媒体信息
|
||||
:param keyword: 搜索关键词
|
||||
:param sites: 指定站点ID列表,如有则只搜索指定站点,否则搜索所有站点
|
||||
:param page: 搜索页码
|
||||
:param area: 搜索区域 title or imdbid
|
||||
:reutrn: 资源列表
|
||||
"""
|
||||
# 未开启的站点不搜索
|
||||
indexer_sites = []
|
||||
|
||||
# 配置的索引站点
|
||||
if not sites:
|
||||
sites = SystemConfigOper().get(SystemConfigKey.IndexerSites) or []
|
||||
|
||||
for indexer in await SitesHelper().async_get_indexers():
|
||||
# 检查站点索引开关
|
||||
if not sites or indexer.get("id") in sites:
|
||||
indexer_sites.append(indexer)
|
||||
if not indexer_sites:
|
||||
logger.warn('未开启任何有效站点,无法搜索资源')
|
||||
return []
|
||||
|
||||
# 开始进度
|
||||
progress = ProgressHelper(ProgressKey.Search)
|
||||
progress.start()
|
||||
# 开始计时
|
||||
start_time = datetime.now()
|
||||
# 总数
|
||||
total_num = len(indexer_sites)
|
||||
# 完成数
|
||||
finish_count = 0
|
||||
# 更新进度
|
||||
progress.update(value=0,
|
||||
text=f"开始搜索,共 {total_num} 个站点 ...")
|
||||
# 结果集
|
||||
results = []
|
||||
|
||||
# 创建异步任务列表
|
||||
tasks = []
|
||||
for site in indexer_sites:
|
||||
if area == "imdbid":
|
||||
# 搜索IMDBID
|
||||
task = self.async_search_torrents(site=site,
|
||||
keyword=mediainfo.imdb_id if mediainfo else None,
|
||||
mtype=mediainfo.type if mediainfo else None,
|
||||
page=page)
|
||||
else:
|
||||
# 搜索标题
|
||||
task = self.async_search_torrents(site=site,
|
||||
keyword=keyword,
|
||||
mtype=mediainfo.type if mediainfo else None,
|
||||
page=page)
|
||||
tasks.append(task)
|
||||
|
||||
# 使用asyncio.as_completed来处理并发任务
|
||||
for future in asyncio.as_completed(tasks):
|
||||
if global_vars.is_system_stopped:
|
||||
break
|
||||
finish_count += 1
|
||||
result = await future
|
||||
if result:
|
||||
results.extend(result)
|
||||
logger.info(f"站点搜索进度:{finish_count} / {total_num}")
|
||||
progress.update(value=finish_count / total_num * 100,
|
||||
text=f"正在搜索{keyword or ''},已完成 {finish_count} / {total_num} 个站点 ...")
|
||||
|
||||
# 计算耗时
|
||||
end_time = datetime.now()
|
||||
# 更新进度
|
||||
progress.update(value=100,
|
||||
text=f"站点搜索完成,有效资源数:{len(results)},总耗时 {(end_time - start_time).seconds} 秒")
|
||||
logger.info(f"站点搜索完成,有效资源数:{len(results)},总耗时 {(end_time - start_time).seconds} 秒")
|
||||
# 结束进度
|
||||
progress.end()
|
||||
|
||||
# 返回
|
||||
return results
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import base64
|
||||
import gc
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Optional, Tuple, Union, Dict
|
||||
@@ -9,7 +8,7 @@ from lxml import etree
|
||||
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import global_vars, settings
|
||||
from app.core.event import Event, EventManager, eventmanager
|
||||
from app.core.event import Event, eventmanager
|
||||
from app.db.models.site import Site
|
||||
from app.db.site_oper import SiteOper
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
@@ -18,7 +17,7 @@ from app.helper.cloudflare import under_challenge
|
||||
from app.helper.cookie import CookieHelper
|
||||
from app.helper.cookiecloud import CookieCloudHelper
|
||||
from app.helper.rss import RssHelper
|
||||
from app.helper.sites import SitesHelper
|
||||
from app.helper.sites import SitesHelper # noqa
|
||||
from app.log import logger
|
||||
from app.schemas import MessageChannel, Notification, SiteUserData
|
||||
from app.schemas.types import EventType, NotificationType
|
||||
@@ -57,9 +56,9 @@ class SiteChain(ChainBase):
|
||||
if userdata:
|
||||
SiteOper().update_userdata(domain=StringUtils.get_url_domain(site.get("domain")),
|
||||
name=site.get("name"),
|
||||
payload=userdata.dict())
|
||||
payload=userdata.model_dump())
|
||||
# 发送事件
|
||||
EventManager().send_event(EventType.SiteRefreshed, {
|
||||
eventmanager.send_event(EventType.SiteRefreshed, {
|
||||
"site_id": site.get("id")
|
||||
})
|
||||
# 发送站点消息
|
||||
@@ -104,14 +103,10 @@ class SiteChain(ChainBase):
|
||||
any_site_updated = True
|
||||
result[site.get("name")] = userdata
|
||||
if any_site_updated:
|
||||
EventManager().send_event(EventType.SiteRefreshed, {
|
||||
eventmanager.send_event(EventType.SiteRefreshed, {
|
||||
"site_id": "*"
|
||||
})
|
||||
|
||||
# 如果不是大内存模式,进行垃圾回收
|
||||
if not settings.BIG_MEMORY_MODE:
|
||||
gc.collect()
|
||||
|
||||
return result
|
||||
|
||||
def is_special_site(self, domain: str) -> bool:
|
||||
@@ -271,16 +266,20 @@ class SiteChain(ChainBase):
|
||||
logger.error(f"获取站点页面失败:{url}")
|
||||
return favicon_url, None
|
||||
html = etree.HTML(html_text)
|
||||
if StringUtils.is_valid_html_element(html):
|
||||
fav_link = html.xpath('//head/link[contains(@rel, "icon")]/@href')
|
||||
if fav_link:
|
||||
favicon_url = urljoin(url, fav_link[0])
|
||||
try:
|
||||
if StringUtils.is_valid_html_element(html):
|
||||
fav_link = html.xpath('//head/link[contains(@rel, "icon")]/@href')
|
||||
if fav_link:
|
||||
favicon_url = urljoin(url, fav_link[0])
|
||||
|
||||
res = RequestUtils(cookies=cookie, timeout=15, ua=ua).get_res(url=favicon_url)
|
||||
if res:
|
||||
return favicon_url, base64.b64encode(res.content).decode()
|
||||
else:
|
||||
logger.error(f"获取站点图标失败:{favicon_url}")
|
||||
res = RequestUtils(cookies=cookie, timeout=15, ua=ua).get_res(url=favicon_url)
|
||||
if res:
|
||||
return favicon_url, base64.b64encode(res.content).decode()
|
||||
else:
|
||||
logger.error(f"获取站点图标失败:{favicon_url}")
|
||||
finally:
|
||||
if html is not None:
|
||||
del html
|
||||
return favicon_url, None
|
||||
|
||||
def sync_cookies(self, manual=False) -> Tuple[bool, str]:
|
||||
@@ -314,11 +313,16 @@ class SiteChain(ChainBase):
|
||||
siteoper = SiteOper()
|
||||
rsshelper = RssHelper()
|
||||
for domain, cookie in cookies.items():
|
||||
# 检查系统是否停止
|
||||
if global_vars.is_system_stopped:
|
||||
logger.info("系统正在停止,中断CookieCloud同步")
|
||||
return False, "系统正在停止,同步被中断"
|
||||
|
||||
# 索引器信息
|
||||
indexer = siteshelper.get_indexer(domain)
|
||||
# 数据库的站点信息
|
||||
site_info = siteoper.get_by_domain(domain)
|
||||
if site_info and site_info.is_active == 1:
|
||||
if site_info and site_info.is_active:
|
||||
# 站点已存在,检查站点连通性
|
||||
status, msg = self.test(domain)
|
||||
# 更新站点Cookie
|
||||
@@ -331,7 +335,8 @@ class SiteChain(ChainBase):
|
||||
url=site_info.url,
|
||||
cookie=cookie,
|
||||
ua=site_info.ua or settings.USER_AGENT,
|
||||
proxy=True if site_info.proxy else False
|
||||
proxy=True if site_info.proxy else False,
|
||||
timeout=site_info.timeout or 15
|
||||
)
|
||||
if rss_url:
|
||||
logger.info(f"更新站点 {domain} RSS地址 ...")
|
||||
@@ -416,7 +421,7 @@ class SiteChain(ChainBase):
|
||||
|
||||
# 通知站点更新
|
||||
if indexer:
|
||||
EventManager().send_event(EventType.SiteUpdated, {
|
||||
eventmanager.send_event(EventType.SiteUpdated, {
|
||||
"domain": domain,
|
||||
})
|
||||
# 处理完成
|
||||
@@ -559,13 +564,15 @@ class SiteChain(ChainBase):
|
||||
public = site_info.public
|
||||
proxies = settings.PROXY if site_info.proxy else None
|
||||
proxy_server = settings.PROXY_SERVER if site_info.proxy else None
|
||||
timeout = site_info.timeout or 60
|
||||
|
||||
# 访问链接
|
||||
if render:
|
||||
page_source = PlaywrightHelper().get_page_source(url=site_url,
|
||||
cookies=site_cookie,
|
||||
ua=ua,
|
||||
proxies=proxy_server)
|
||||
proxies=proxy_server,
|
||||
timeout=timeout)
|
||||
if not public and not SiteUtils.is_logged_in(page_source):
|
||||
if under_challenge(page_source):
|
||||
return False, f"无法通过Cloudflare!"
|
||||
@@ -698,7 +705,8 @@ class SiteChain(ChainBase):
|
||||
username=username,
|
||||
password=password,
|
||||
two_step_code=two_step_code,
|
||||
proxies=settings.PROXY_HOST if site_info.proxy else None
|
||||
proxies=settings.PROXY_SERVER if site_info.proxy else None,
|
||||
timeout=site_info.timeout or 60
|
||||
)
|
||||
if result:
|
||||
cookie, ua, msg = result
|
||||
|
||||
@@ -6,7 +6,6 @@ from app.chain import ChainBase
|
||||
from app.core.config import settings
|
||||
from app.helper.directory import DirectoryHelper
|
||||
from app.log import logger
|
||||
from app.schemas import MediaType
|
||||
|
||||
|
||||
class StorageChain(ChainBase):
|
||||
@@ -110,11 +109,17 @@ class StorageChain(ChainBase):
|
||||
"""
|
||||
return self.run_module("get_parent_item", fileitem=fileitem)
|
||||
|
||||
def snapshot_storage(self, storage: str, path: Path) -> Optional[Dict[str, float]]:
|
||||
def snapshot_storage(self, storage: str, path: Path,
|
||||
last_snapshot_time: float = None, max_depth: int = 5) -> Optional[Dict[str, Dict]]:
|
||||
"""
|
||||
快照存储
|
||||
:param storage: 存储类型
|
||||
:param path: 路径
|
||||
:param last_snapshot_time: 上次快照时间,用于增量快照
|
||||
:param max_depth: 最大递归深度,避免过深遍历
|
||||
"""
|
||||
return self.run_module("snapshot_storage", storage=storage, path=path)
|
||||
return self.run_module("snapshot_storage", storage=storage, path=path,
|
||||
last_snapshot_time=last_snapshot_time, max_depth=max_depth)
|
||||
|
||||
def storage_usage(self, storage: str) -> Optional[schemas.StorageUsage]:
|
||||
"""
|
||||
@@ -128,8 +133,7 @@ class StorageChain(ChainBase):
|
||||
"""
|
||||
return self.run_module("support_transtype", storage=storage)
|
||||
|
||||
def delete_media_file(self, fileitem: schemas.FileItem,
|
||||
mtype: MediaType = None, delete_self: bool = True) -> bool:
|
||||
def delete_media_file(self, fileitem: schemas.FileItem, delete_self: bool = True) -> bool:
|
||||
"""
|
||||
删除媒体文件,以及不含媒体文件的目录
|
||||
"""
|
||||
@@ -146,7 +150,8 @@ class StorageChain(ChainBase):
|
||||
return False
|
||||
|
||||
media_exts = settings.RMT_MEDIAEXT + settings.DOWNLOAD_TMPEXT
|
||||
if fileitem.path == "/" or len(Path(fileitem.path).parts) <= 2:
|
||||
fileitem_path = Path(fileitem.path) if fileitem.path else Path("")
|
||||
if len(fileitem_path.parts) <= 2:
|
||||
logger.warn(f"【{fileitem.storage}】{fileitem.path} 根目录或一级目录不允许删除")
|
||||
return False
|
||||
if fileitem.type == "dir":
|
||||
@@ -156,13 +161,7 @@ class StorageChain(ChainBase):
|
||||
if not self.delete_file(fileitem):
|
||||
logger.warn(f"【{fileitem.storage}】{fileitem.path} 删除失败")
|
||||
return False
|
||||
elif self.any_files(fileitem, extensions=media_exts) is False:
|
||||
logger.warn(f"【{fileitem.storage}】{fileitem.path} 不存在其它媒体文件,正在删除空目录")
|
||||
if not self.delete_file(fileitem):
|
||||
logger.warn(f"【{fileitem.storage}】{fileitem.path} 删除失败")
|
||||
return False
|
||||
# 不处理父目录
|
||||
return True
|
||||
|
||||
elif delete_self:
|
||||
# 本身是文件,需要删除文件
|
||||
logger.warn(f"正在删除文件【{fileitem.storage}】{fileitem.path}")
|
||||
@@ -170,36 +169,43 @@ class StorageChain(ChainBase):
|
||||
logger.warn(f"【{fileitem.storage}】{fileitem.path} 删除失败")
|
||||
return False
|
||||
|
||||
if mtype:
|
||||
# 重命名格式
|
||||
rename_format = settings.TV_RENAME_FORMAT \
|
||||
if mtype == MediaType.TV else settings.MOVIE_RENAME_FORMAT
|
||||
# 计算重命名中的文件夹层数
|
||||
rename_format_level = len(rename_format.split("/")) - 1
|
||||
if rename_format_level < 1:
|
||||
return True
|
||||
# 处理媒体文件根目录
|
||||
dir_item = self.get_file_item(storage=fileitem.storage,
|
||||
path=Path(fileitem.path).parents[rename_format_level - 1])
|
||||
else:
|
||||
# 处理上级目录
|
||||
dir_item = self.get_parent_item(fileitem)
|
||||
# 检查和删除上级空目录
|
||||
dir_item = fileitem if fileitem.type == "dir" else self.get_parent_item(fileitem)
|
||||
if not dir_item:
|
||||
logger.warn(f"【{fileitem.storage}】{fileitem.path} 上级目录不存在")
|
||||
return True
|
||||
|
||||
# 检查和删除上级目录
|
||||
if dir_item and len(Path(dir_item.path).parts) > 2:
|
||||
# 如何目录是所有下载目录、媒体库目录的上级,则不处理
|
||||
for d in DirectoryHelper().get_dirs():
|
||||
if d.download_path and Path(d.download_path).is_relative_to(Path(dir_item.path)):
|
||||
logger.debug(f"【{dir_item.storage}】{dir_item.path} 是下载目录本级或上级目录,不删除")
|
||||
return True
|
||||
if d.library_path and Path(d.library_path).is_relative_to(Path(dir_item.path)):
|
||||
logger.debug(f"【{dir_item.storage}】{dir_item.path} 是媒体库目录本级或上级目录,不删除")
|
||||
return True
|
||||
# 不存在其他媒体文件,删除空目录
|
||||
if self.any_files(dir_item, extensions=media_exts) is False:
|
||||
logger.warn(f"【{dir_item.storage}】{dir_item.path} 不存在其它媒体文件,正在删除空目录")
|
||||
if not self.delete_file(dir_item):
|
||||
logger.warn(f"【{dir_item.storage}】{dir_item.path} 删除失败")
|
||||
return False
|
||||
# 查找操作文件项匹配的配置目录(资源目录、媒体库目录)
|
||||
associated_dir = max(
|
||||
(
|
||||
Path(p)
|
||||
for d in DirectoryHelper().get_dirs()
|
||||
for p in (d.download_path, d.library_path)
|
||||
if p and fileitem_path.is_relative_to(p)
|
||||
),
|
||||
key=lambda path: len(path.parts),
|
||||
default=None,
|
||||
)
|
||||
|
||||
while dir_item and len(Path(dir_item.path).parts) > 2:
|
||||
# 目录是资源目录、媒体库目录的上级,则不处理
|
||||
if associated_dir and associated_dir.is_relative_to(Path(dir_item.path)):
|
||||
logger.debug(f"【{dir_item.storage}】{dir_item.path} 位于资源或媒体库目录结构中,不删除")
|
||||
break
|
||||
|
||||
elif not associated_dir and self.list_files(dir_item, recursion=False):
|
||||
logger.debug(f"【{dir_item.storage}】{dir_item.path} 不是空目录,不删除")
|
||||
break
|
||||
|
||||
if self.any_files(dir_item, extensions=media_exts) is not False:
|
||||
logger.debug(f"【{dir_item.storage}】{dir_item.path} 存在媒体文件,不删除")
|
||||
break
|
||||
|
||||
# 删除空目录并继续处理父目录
|
||||
logger.warn(f"【{dir_item.storage}】{dir_item.path} 不存在其它媒体文件,正在删除空目录")
|
||||
if not self.delete_file(dir_item):
|
||||
logger.warn(f"【{dir_item.storage}】{dir_item.path} 删除失败")
|
||||
return False
|
||||
dir_item = self.get_parent_item(dir_item)
|
||||
|
||||
return True
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import copy
|
||||
import gc
|
||||
import json
|
||||
import random
|
||||
import threading
|
||||
@@ -16,7 +15,7 @@ from app.chain.tmdb import TmdbChain
|
||||
from app.chain.torrents import TorrentsChain
|
||||
from app.core.config import settings, global_vars
|
||||
from app.core.context import TorrentInfo, Context, MediaInfo
|
||||
from app.core.event import eventmanager, Event, EventManager
|
||||
from app.core.event import eventmanager, Event
|
||||
from app.core.meta import MetaBase
|
||||
from app.core.meta.words import WordsMatcher
|
||||
from app.core.metainfo import MetaInfo
|
||||
@@ -39,6 +38,84 @@ class SubscribeChain(ChainBase):
|
||||
"""
|
||||
|
||||
_rlock = threading.RLock()
|
||||
# 避免莫名原因导致长时间持有锁
|
||||
_LOCK_TIMOUT = 3600 * 2
|
||||
|
||||
@staticmethod
|
||||
def __get_event_meida(_mediaid: str, _meta: MetaBase) -> Optional[MediaInfo]:
|
||||
"""
|
||||
广播事件解析媒体信息
|
||||
"""
|
||||
event_data = MediaRecognizeConvertEventData(
|
||||
mediaid=_mediaid,
|
||||
convert_type=settings.RECOGNIZE_SOURCE
|
||||
)
|
||||
event = eventmanager.send_event(ChainEventType.MediaRecognizeConvert, event_data)
|
||||
# 使用事件返回的上下文数据
|
||||
if event and event.event_data:
|
||||
event_data: MediaRecognizeConvertEventData = event.event_data
|
||||
if event_data.media_dict:
|
||||
mediachain = MediaChain()
|
||||
new_id = event_data.media_dict.get("id")
|
||||
if event_data.convert_type == "themoviedb":
|
||||
return mediachain.recognize_media(meta=_meta, tmdbid=new_id)
|
||||
elif event_data.convert_type == "douban":
|
||||
return mediachain.recognize_media(meta=_meta, doubanid=new_id)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def __async_get_event_meida(_mediaid: str, _meta: MetaBase) -> Optional[MediaInfo]:
|
||||
"""
|
||||
广播事件解析媒体信息
|
||||
"""
|
||||
event_data = MediaRecognizeConvertEventData(
|
||||
mediaid=_mediaid,
|
||||
convert_type=settings.RECOGNIZE_SOURCE
|
||||
)
|
||||
event = await eventmanager.async_send_event(ChainEventType.MediaRecognizeConvert, event_data)
|
||||
# 使用事件返回的上下文数据
|
||||
if event and event.event_data:
|
||||
event_data: MediaRecognizeConvertEventData = event.event_data
|
||||
if event_data.media_dict:
|
||||
mediachain = MediaChain()
|
||||
new_id = event_data.media_dict.get("id")
|
||||
if event_data.convert_type == "themoviedb":
|
||||
return await mediachain.async_recognize_media(meta=_meta, tmdbid=new_id)
|
||||
elif event_data.convert_type == "douban":
|
||||
return await mediachain.async_recognize_media(meta=_meta, doubanid=new_id)
|
||||
return None
|
||||
|
||||
def __get_default_kwargs(self, mtype: MediaType, **kwargs) -> dict:
|
||||
"""
|
||||
获取订阅默认配置
|
||||
:param mtype: 媒体类型
|
||||
:param key: 配置键
|
||||
:return: 配置值
|
||||
"""
|
||||
return {
|
||||
'quality': self.__get_default_subscribe_config(mtype, "quality") if not kwargs.get(
|
||||
"quality") else kwargs.get("quality"),
|
||||
'resolution': self.__get_default_subscribe_config(mtype, "resolution") if not kwargs.get(
|
||||
"resolution") else kwargs.get("resolution"),
|
||||
'effect': self.__get_default_subscribe_config(mtype, "effect") if not kwargs.get(
|
||||
"effect") else kwargs.get("effect"),
|
||||
'include': self.__get_default_subscribe_config(mtype, "include") if not kwargs.get(
|
||||
"include") else kwargs.get("include"),
|
||||
'exclude': self.__get_default_subscribe_config(mtype, "exclude") if not kwargs.get(
|
||||
"exclude") else kwargs.get("exclude"),
|
||||
'best_version': self.__get_default_subscribe_config(mtype, "best_version") if not kwargs.get(
|
||||
"best_version") else kwargs.get("best_version"),
|
||||
'search_imdbid': self.__get_default_subscribe_config(mtype, "search_imdbid") if not kwargs.get(
|
||||
"search_imdbid") else kwargs.get("search_imdbid"),
|
||||
'sites': self.__get_default_subscribe_config(mtype, "sites") or None if not kwargs.get(
|
||||
"sites") else kwargs.get("sites"),
|
||||
'downloader': self.__get_default_subscribe_config(mtype, "downloader") if not kwargs.get(
|
||||
"downloader") else kwargs.get("downloader"),
|
||||
'save_path': self.__get_default_subscribe_config(mtype, "save_path") if not kwargs.get(
|
||||
"save_path") else kwargs.get("save_path"),
|
||||
'filter_groups': self.__get_default_subscribe_config(mtype, "filter_groups") if not kwargs.get(
|
||||
"filter_groups") else kwargs.get("filter_groups")
|
||||
}
|
||||
|
||||
def add(self, title: str, year: str,
|
||||
mtype: MediaType = None,
|
||||
@@ -59,27 +136,6 @@ class SubscribeChain(ChainBase):
|
||||
识别媒体信息并添加订阅
|
||||
"""
|
||||
|
||||
def __get_event_meida(_mediaid: str, _meta: MetaBase) -> Optional[MediaInfo]:
|
||||
"""
|
||||
广播事件解析媒体信息
|
||||
"""
|
||||
event_data = MediaRecognizeConvertEventData(
|
||||
mediaid=_mediaid,
|
||||
convert_type=settings.RECOGNIZE_SOURCE
|
||||
)
|
||||
event = eventmanager.send_event(ChainEventType.MediaRecognizeConvert, event_data)
|
||||
# 使用事件返回的上下文数据
|
||||
if event and event.event_data:
|
||||
event_data: MediaRecognizeConvertEventData = event.event_data
|
||||
if event_data.media_dict:
|
||||
mediachain = MediaChain()
|
||||
new_id = event_data.media_dict.get("id")
|
||||
if event_data.convert_type == "themoviedb":
|
||||
return mediachain.recognize_media(meta=_meta, tmdbid=new_id)
|
||||
elif event_data.convert_type == "douban":
|
||||
return mediachain.recognize_media(meta=_meta, doubanid=new_id)
|
||||
return None
|
||||
|
||||
logger.info(f'开始添加订阅,标题:{title} ...')
|
||||
|
||||
mediainfo = None
|
||||
@@ -102,7 +158,7 @@ class SubscribeChain(ChainBase):
|
||||
mediainfo = MediaInfo(tmdb_info=tmdbinfo)
|
||||
elif mediaid:
|
||||
# 未知前缀,广播事件解析媒体信息
|
||||
mediainfo = __get_event_meida(mediaid, metainfo)
|
||||
mediainfo = self.__get_event_meida(mediaid, metainfo)
|
||||
else:
|
||||
# 使用TMDBID识别
|
||||
mediainfo = self.recognize_media(meta=metainfo, mtype=mtype, tmdbid=tmdbid,
|
||||
@@ -113,7 +169,7 @@ class SubscribeChain(ChainBase):
|
||||
mediainfo = self.recognize_media(meta=metainfo, mtype=mtype, doubanid=doubanid, cache=False)
|
||||
elif mediaid:
|
||||
# 未知前缀,广播事件解析媒体信息
|
||||
mediainfo = __get_event_meida(mediaid, metainfo)
|
||||
mediainfo = self.__get_event_meida(mediaid, metainfo)
|
||||
if mediainfo:
|
||||
# 豆瓣标题处理
|
||||
meta = MetaInfo(mediainfo.title)
|
||||
@@ -175,30 +231,8 @@ class SubscribeChain(ChainBase):
|
||||
mediainfo.bangumi_id = bangumiid
|
||||
|
||||
# 添加订阅
|
||||
kwargs.update({
|
||||
'quality': self.__get_default_subscribe_config(mediainfo.type, "quality") if not kwargs.get(
|
||||
"quality") else kwargs.get("quality"),
|
||||
'resolution': self.__get_default_subscribe_config(mediainfo.type, "resolution") if not kwargs.get(
|
||||
"resolution") else kwargs.get("resolution"),
|
||||
'effect': self.__get_default_subscribe_config(mediainfo.type, "effect") if not kwargs.get(
|
||||
"effect") else kwargs.get("effect"),
|
||||
'include': self.__get_default_subscribe_config(mediainfo.type, "include") if not kwargs.get(
|
||||
"include") else kwargs.get("include"),
|
||||
'exclude': self.__get_default_subscribe_config(mediainfo.type, "exclude") if not kwargs.get(
|
||||
"exclude") else kwargs.get("exclude"),
|
||||
'best_version': self.__get_default_subscribe_config(mediainfo.type, "best_version") if not kwargs.get(
|
||||
"best_version") else kwargs.get("best_version"),
|
||||
'search_imdbid': self.__get_default_subscribe_config(mediainfo.type, "search_imdbid") if not kwargs.get(
|
||||
"search_imdbid") else kwargs.get("search_imdbid"),
|
||||
'sites': self.__get_default_subscribe_config(mediainfo.type, "sites") or None if not kwargs.get(
|
||||
"sites") else kwargs.get("sites"),
|
||||
'downloader': self.__get_default_subscribe_config(mediainfo.type, "downloader") if not kwargs.get(
|
||||
"downloader") else kwargs.get("downloader"),
|
||||
'save_path': self.__get_default_subscribe_config(mediainfo.type, "save_path") if not kwargs.get(
|
||||
"save_path") else kwargs.get("save_path"),
|
||||
'filter_groups': self.__get_default_subscribe_config(mediainfo.type, "filter_groups") if not kwargs.get(
|
||||
"filter_groups") else kwargs.get("filter_groups")
|
||||
})
|
||||
kwargs.update(self.__get_default_kwargs(mediainfo.type, **kwargs))
|
||||
|
||||
# 操作数据库
|
||||
sid, err_msg = SubscribeOper().add(mediainfo=mediainfo, season=season, username=username, **kwargs)
|
||||
if not sid:
|
||||
@@ -236,7 +270,7 @@ class SubscribeChain(ChainBase):
|
||||
username=username
|
||||
)
|
||||
# 发送事件
|
||||
EventManager().send_event(EventType.SubscribeAdded, {
|
||||
eventmanager.send_event(EventType.SubscribeAdded, {
|
||||
"subscribe_id": sid,
|
||||
"username": username,
|
||||
"mediainfo": mediainfo.to_dict(),
|
||||
@@ -260,6 +294,183 @@ class SubscribeChain(ChainBase):
|
||||
# 返回结果
|
||||
return sid, ""
|
||||
|
||||
async def async_add(self, title: str, year: str,
|
||||
mtype: MediaType = None,
|
||||
tmdbid: Optional[int] = None,
|
||||
doubanid: Optional[str] = None,
|
||||
bangumiid: Optional[int] = None,
|
||||
mediaid: Optional[str] = None,
|
||||
episode_group: Optional[str] = None,
|
||||
season: Optional[int] = None,
|
||||
channel: MessageChannel = None,
|
||||
source: Optional[str] = None,
|
||||
userid: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
message: Optional[bool] = True,
|
||||
exist_ok: Optional[bool] = False,
|
||||
**kwargs) -> Tuple[Optional[int], str]:
|
||||
"""
|
||||
异步识别媒体信息并添加订阅
|
||||
"""
|
||||
|
||||
logger.info(f'开始添加订阅,标题:{title} ...')
|
||||
|
||||
mediainfo = None
|
||||
metainfo = MetaInfo(title)
|
||||
if year:
|
||||
metainfo.year = year
|
||||
if mtype:
|
||||
metainfo.type = mtype
|
||||
if season:
|
||||
metainfo.type = MediaType.TV
|
||||
metainfo.begin_season = season
|
||||
# 识别媒体信息
|
||||
if settings.RECOGNIZE_SOURCE == "themoviedb":
|
||||
# TMDB识别模式
|
||||
if not tmdbid:
|
||||
if doubanid:
|
||||
# 将豆瓣信息转换为TMDB信息
|
||||
tmdbinfo = await MediaChain().async_get_tmdbinfo_by_doubanid(doubanid=doubanid, mtype=mtype)
|
||||
if tmdbinfo:
|
||||
mediainfo = MediaInfo(tmdb_info=tmdbinfo)
|
||||
elif mediaid:
|
||||
# 未知前缀,广播事件解析媒体信息
|
||||
mediainfo = await self.__async_get_event_meida(mediaid, metainfo)
|
||||
else:
|
||||
# 使用TMDBID识别
|
||||
mediainfo = await self.async_recognize_media(meta=metainfo, mtype=mtype, tmdbid=tmdbid,
|
||||
episode_group=episode_group, cache=False)
|
||||
else:
|
||||
if doubanid:
|
||||
# 豆瓣识别模式,不使用缓存
|
||||
mediainfo = await self.async_recognize_media(meta=metainfo, mtype=mtype, doubanid=doubanid, cache=False)
|
||||
elif mediaid:
|
||||
# 未知前缀,广播事件解析媒体信息
|
||||
mediainfo = await self.__async_get_event_meida(mediaid, metainfo)
|
||||
if mediainfo:
|
||||
# 豆瓣标题处理
|
||||
meta = MetaInfo(mediainfo.title)
|
||||
mediainfo.title = meta.name
|
||||
if not season:
|
||||
season = meta.begin_season
|
||||
|
||||
# 使用名称识别兜底
|
||||
if not mediainfo:
|
||||
mediainfo = await self.async_recognize_media(meta=metainfo, episode_group=episode_group)
|
||||
|
||||
# 识别失败
|
||||
if not mediainfo:
|
||||
logger.warn(f'未识别到媒体信息,标题:{title},tmdbid:{tmdbid},doubanid:{doubanid}')
|
||||
return None, "未识别到媒体信息"
|
||||
|
||||
# 总集数
|
||||
if mediainfo.type == MediaType.TV:
|
||||
if not season:
|
||||
season = 1
|
||||
# 总集数
|
||||
if not kwargs.get('total_episode'):
|
||||
if not mediainfo.seasons or episode_group:
|
||||
# 补充媒体信息
|
||||
mediainfo = await self.async_recognize_media(mtype=mediainfo.type,
|
||||
tmdbid=mediainfo.tmdb_id,
|
||||
doubanid=mediainfo.douban_id,
|
||||
bangumiid=mediainfo.bangumi_id,
|
||||
episode_group=episode_group,
|
||||
cache=False)
|
||||
if not mediainfo:
|
||||
logger.error(f"媒体信息识别失败!")
|
||||
return None, "媒体信息识别失败"
|
||||
if not mediainfo.seasons:
|
||||
logger.error(f"媒体信息中没有季集信息,标题:{title},tmdbid:{tmdbid},doubanid:{doubanid}")
|
||||
return None, "媒体信息中没有季集信息"
|
||||
total_episode = len(mediainfo.seasons.get(season) or [])
|
||||
if not total_episode:
|
||||
logger.error(f'未获取到总集数,标题:{title},tmdbid:{tmdbid}, doubanid:{doubanid}')
|
||||
return None, f"未获取到第 {season} 季的总集数"
|
||||
kwargs.update({
|
||||
'total_episode': total_episode
|
||||
})
|
||||
# 缺失集
|
||||
if not kwargs.get('lack_episode'):
|
||||
kwargs.update({
|
||||
'lack_episode': kwargs.get('total_episode')
|
||||
})
|
||||
else:
|
||||
# 避免season为0的问题
|
||||
season = None
|
||||
|
||||
# 更新媒体图片
|
||||
await self.async_obtain_images(mediainfo=mediainfo)
|
||||
# 合并信息
|
||||
if doubanid:
|
||||
mediainfo.douban_id = doubanid
|
||||
if bangumiid:
|
||||
mediainfo.bangumi_id = bangumiid
|
||||
|
||||
# 列新默认参数
|
||||
kwargs.update(self.__get_default_kwargs(mediainfo.type, **kwargs))
|
||||
|
||||
# 操作数据库
|
||||
sid, err_msg = await SubscribeOper().async_add(mediainfo=mediainfo, season=season, username=username, **kwargs)
|
||||
if not sid:
|
||||
logger.error(f'{mediainfo.title_year} {err_msg}')
|
||||
if not exist_ok and message:
|
||||
# 失败发回原用户
|
||||
await self.async_post_message(schemas.Notification(channel=channel,
|
||||
source=source,
|
||||
mtype=NotificationType.Subscribe,
|
||||
title=f"{mediainfo.title_year} {metainfo.season} "
|
||||
f"添加订阅失败!",
|
||||
text=f"{err_msg}",
|
||||
image=mediainfo.get_message_image(),
|
||||
userid=userid))
|
||||
return None, err_msg
|
||||
elif message:
|
||||
if mediainfo.type == MediaType.TV:
|
||||
link = settings.MP_DOMAIN('#/subscribe/tv?tab=mysub')
|
||||
else:
|
||||
link = settings.MP_DOMAIN('#/subscribe/movie?tab=mysub')
|
||||
# 订阅成功按规则发送消息
|
||||
await self.async_post_message(
|
||||
schemas.Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
mtype=NotificationType.Subscribe,
|
||||
ctype=ContentType.SubscribeAdded,
|
||||
image=mediainfo.get_message_image(),
|
||||
link=link,
|
||||
userid=userid,
|
||||
username=username
|
||||
),
|
||||
meta=metainfo,
|
||||
mediainfo=mediainfo,
|
||||
username=username
|
||||
)
|
||||
# 发送事件
|
||||
await eventmanager.async_send_event(EventType.SubscribeAdded, {
|
||||
"subscribe_id": sid,
|
||||
"username": username,
|
||||
"mediainfo": mediainfo.to_dict(),
|
||||
})
|
||||
# 统计订阅
|
||||
await SubscribeHelper().async_sub_reg({
|
||||
"name": title,
|
||||
"year": year,
|
||||
"type": metainfo.type.value,
|
||||
"tmdbid": mediainfo.tmdb_id,
|
||||
"imdbid": mediainfo.imdb_id,
|
||||
"tvdbid": mediainfo.tvdb_id,
|
||||
"doubanid": mediainfo.douban_id,
|
||||
"bangumiid": mediainfo.bangumi_id,
|
||||
"season": metainfo.begin_season,
|
||||
"poster": mediainfo.get_poster_image(),
|
||||
"backdrop": mediainfo.get_backdrop_image(),
|
||||
"vote": mediainfo.vote_average,
|
||||
"description": mediainfo.overview
|
||||
})
|
||||
# 返回结果
|
||||
return sid, ""
|
||||
|
||||
@staticmethod
|
||||
def exists(mediainfo: MediaInfo, meta: MetaBase = None):
|
||||
"""
|
||||
@@ -279,8 +490,15 @@ class SubscribeChain(ChainBase):
|
||||
:param manual: 是否手动搜索
|
||||
:return: 更新订阅状态为R或删除订阅
|
||||
"""
|
||||
with self._rlock:
|
||||
logger.debug(f"search lock acquired at {datetime.now()}")
|
||||
lock_acquired = False
|
||||
try:
|
||||
if lock_acquired := self._rlock.acquire(
|
||||
blocking=True, timeout=self._LOCK_TIMOUT
|
||||
):
|
||||
logger.debug(f"search lock acquired at {datetime.now()}")
|
||||
else:
|
||||
logger.warn("search上锁超时")
|
||||
|
||||
subscribeoper = SubscribeOper()
|
||||
if sid:
|
||||
subscribe = subscribeoper.get(sid)
|
||||
@@ -434,14 +652,13 @@ class SubscribeChain(ChainBase):
|
||||
else:
|
||||
self.messagehelper.put('没有找到订阅!', title="订阅搜索", role="system")
|
||||
|
||||
logger.debug(f"search Lock released at {datetime.now()}")
|
||||
finally:
|
||||
subscribes.clear()
|
||||
del subscribes
|
||||
|
||||
# 如果不是大内存模式,进行垃圾回收
|
||||
if not settings.BIG_MEMORY_MODE:
|
||||
gc.collect()
|
||||
finally:
|
||||
if lock_acquired:
|
||||
self._rlock.release()
|
||||
logger.debug(f"search Lock released at {datetime.now()}")
|
||||
|
||||
def update_subscribe_priority(self, subscribe: Subscribe, meta: MetaBase,
|
||||
mediainfo: MediaInfo, downloads: Optional[List[Context]]):
|
||||
@@ -514,9 +731,6 @@ class SubscribeChain(ChainBase):
|
||||
self.match(
|
||||
TorrentsChain().refresh(sites=sites)
|
||||
)
|
||||
# 如果不是大内存模式,进行垃圾回收
|
||||
if not settings.BIG_MEMORY_MODE:
|
||||
gc.collect()
|
||||
|
||||
@staticmethod
|
||||
def get_sub_sites(subscribe: Subscribe) -> List[int]:
|
||||
@@ -546,10 +760,15 @@ class SubscribeChain(ChainBase):
|
||||
:return: 返回[]代表所有站点命中,返回None代表没有订阅
|
||||
"""
|
||||
ret_sites = []
|
||||
subscribes = SubscribeOper().list()
|
||||
if not subscribes:
|
||||
# 没有订阅
|
||||
return None
|
||||
# 刷新订阅选中的Rss站点
|
||||
for subscribe in SubscribeOper().list(self.get_states_for_search('R')):
|
||||
for subscribe in subscribes:
|
||||
# 刷新选中的站点
|
||||
ret_sites.extend(self.get_sub_sites(subscribe))
|
||||
if subscribe.state in self.get_states_for_search('R'):
|
||||
ret_sites.extend(self.get_sub_sites(subscribe))
|
||||
# 去重
|
||||
if ret_sites:
|
||||
ret_sites = list(set(ret_sites))
|
||||
@@ -564,8 +783,14 @@ class SubscribeChain(ChainBase):
|
||||
logger.warn('没有缓存资源,无法匹配订阅')
|
||||
return
|
||||
|
||||
with self._rlock:
|
||||
logger.debug(f"match lock acquired at {datetime.now()}")
|
||||
lock_acquired = False
|
||||
try:
|
||||
if lock_acquired := self._rlock.acquire(
|
||||
blocking=True, timeout=self._LOCK_TIMOUT
|
||||
):
|
||||
logger.debug(f"match lock acquired at {datetime.now()}")
|
||||
else:
|
||||
logger.warn("match上锁超时")
|
||||
|
||||
# 预识别所有未识别的种子
|
||||
processed_torrents: Dict[str, List[Context]] = {}
|
||||
@@ -576,15 +801,27 @@ class SubscribeChain(ChainBase):
|
||||
for context in contexts:
|
||||
if global_vars.is_system_stopped:
|
||||
break
|
||||
# 如果种子未识别,尝试识别
|
||||
if not context.media_info or (not context.media_info.tmdb_id
|
||||
and not context.media_info.douban_id):
|
||||
# 如果种子未识别且失败次数未超过3次,尝试识别
|
||||
if (not context.media_info or (not context.media_info.tmdb_id
|
||||
and not context.media_info.douban_id)) and context.media_recognize_fail_count < 3:
|
||||
logger.debug(
|
||||
f'尝试重新识别种子:{context.torrent_info.title},当前失败次数:{context.media_recognize_fail_count}/3')
|
||||
re_mediainfo = self.recognize_media(meta=context.meta_info)
|
||||
if re_mediainfo:
|
||||
# 清理多余信息
|
||||
re_mediainfo.clear()
|
||||
# 更新种子缓存
|
||||
context.media_info = re_mediainfo
|
||||
# 重置失败次数
|
||||
context.media_recognize_fail_count = 0
|
||||
logger.debug(f'种子 {context.torrent_info.title} 重新识别成功')
|
||||
else:
|
||||
# 识别失败,增加失败次数
|
||||
context.media_recognize_fail_count += 1
|
||||
logger.debug(
|
||||
f'种子 {context.torrent_info.title} 媒体识别失败,失败次数:{context.media_recognize_fail_count}/3')
|
||||
elif context.media_recognize_fail_count >= 3:
|
||||
logger.debug(f'种子 {context.torrent_info.title} 已达到最大识别失败次数(3次),跳过识别')
|
||||
# 添加已预处理
|
||||
processed_torrents[domain].append(context)
|
||||
|
||||
@@ -647,153 +884,150 @@ class SubscribeChain(ChainBase):
|
||||
if domains and domain not in domains:
|
||||
continue
|
||||
logger.debug(f'开始匹配站点:{domain},共缓存了 {len(contexts)} 个种子...')
|
||||
try:
|
||||
for context in contexts:
|
||||
if global_vars.is_system_stopped:
|
||||
break
|
||||
# 提取信息
|
||||
_context = copy.copy(context)
|
||||
torrent_meta = _context.meta_info
|
||||
torrent_mediainfo = _context.media_info
|
||||
torrent_info = _context.torrent_info
|
||||
for context in contexts:
|
||||
if global_vars.is_system_stopped:
|
||||
break
|
||||
# 提取信息
|
||||
_context = copy.copy(context)
|
||||
torrent_meta = _context.meta_info
|
||||
torrent_mediainfo = _context.media_info
|
||||
torrent_info = _context.torrent_info
|
||||
|
||||
# 不在订阅站点范围的不处理
|
||||
sub_sites = self.get_sub_sites(subscribe)
|
||||
if sub_sites and torrent_info.site not in sub_sites:
|
||||
logger.debug(f"{torrent_info.site_name} - {torrent_info.title} 不符合订阅站点要求")
|
||||
continue
|
||||
# 不在订阅站点范围的不处理
|
||||
sub_sites = self.get_sub_sites(subscribe)
|
||||
if sub_sites and torrent_info.site not in sub_sites:
|
||||
logger.debug(f"{torrent_info.site_name} - {torrent_info.title} 不符合订阅站点要求")
|
||||
continue
|
||||
|
||||
# 有自定义识别词时,需要判断是否需要重新识别
|
||||
if custom_words_list:
|
||||
# 使用org_string,应用一次后理论上不能再次应用
|
||||
_, apply_words = wordsmatcher.prepare(torrent_meta.org_string,
|
||||
custom_words=custom_words_list)
|
||||
if apply_words:
|
||||
logger.info(
|
||||
f'{torrent_info.site_name} - {torrent_info.title} 因订阅存在自定义识别词,重新识别元数据...')
|
||||
# 重新识别元数据
|
||||
torrent_meta = MetaInfo(title=torrent_info.title, subtitle=torrent_info.description,
|
||||
custom_words=custom_words_list)
|
||||
# 更新元数据缓存
|
||||
_context.meta_info = torrent_meta
|
||||
# 重新识别媒体信息
|
||||
torrent_mediainfo = self.recognize_media(meta=torrent_meta,
|
||||
episode_group=subscribe.episode_group)
|
||||
if torrent_mediainfo:
|
||||
# 清理多余信息
|
||||
torrent_mediainfo.clear()
|
||||
# 更新种子缓存
|
||||
_context.media_info = torrent_mediainfo
|
||||
|
||||
# 如果仍然没有识别到媒体信息,尝试标题匹配
|
||||
if not torrent_mediainfo or (not torrent_mediainfo.tmdb_id and not torrent_mediainfo.douban_id):
|
||||
# 有自定义识别词时,需要判断是否需要重新识别
|
||||
if custom_words_list:
|
||||
# 使用org_string,应用一次后理论上不能再次应用
|
||||
_, apply_words = wordsmatcher.prepare(torrent_meta.org_string,
|
||||
custom_words=custom_words_list)
|
||||
if apply_words:
|
||||
logger.info(
|
||||
f'{torrent_info.site_name} - {torrent_info.title} 重新识别失败,尝试通过标题匹配...')
|
||||
if torrenthelper.match_torrent(mediainfo=mediainfo,
|
||||
torrent_meta=torrent_meta,
|
||||
torrent=torrent_info):
|
||||
# 匹配成功
|
||||
logger.info(
|
||||
f'{mediainfo.title_year} 通过标题匹配到可选资源:{torrent_info.site_name} - {torrent_info.title}')
|
||||
torrent_mediainfo = mediainfo
|
||||
f'{torrent_info.site_name} - {torrent_info.title} 因订阅存在自定义识别词,重新识别元数据...')
|
||||
# 重新识别元数据
|
||||
torrent_meta = MetaInfo(title=torrent_info.title, subtitle=torrent_info.description,
|
||||
custom_words=custom_words_list)
|
||||
# 更新元数据缓存
|
||||
_context.meta_info = torrent_meta
|
||||
# 重新识别媒体信息
|
||||
torrent_mediainfo = self.recognize_media(meta=torrent_meta,
|
||||
episode_group=subscribe.episode_group)
|
||||
if torrent_mediainfo:
|
||||
# 清理多余信息
|
||||
torrent_mediainfo.clear()
|
||||
# 更新种子缓存
|
||||
_context.media_info = mediainfo
|
||||
else:
|
||||
continue
|
||||
_context.media_info = torrent_mediainfo
|
||||
|
||||
# 直接比对媒体信息
|
||||
if torrent_mediainfo and (torrent_mediainfo.tmdb_id or torrent_mediainfo.douban_id):
|
||||
if torrent_mediainfo.type != mediainfo.type:
|
||||
continue
|
||||
if torrent_mediainfo.tmdb_id \
|
||||
and torrent_mediainfo.tmdb_id != mediainfo.tmdb_id:
|
||||
continue
|
||||
if torrent_mediainfo.douban_id \
|
||||
and torrent_mediainfo.douban_id != mediainfo.douban_id:
|
||||
continue
|
||||
# 如果仍然没有识别到媒体信息,尝试标题匹配
|
||||
if not torrent_mediainfo or (
|
||||
not torrent_mediainfo.tmdb_id and not torrent_mediainfo.douban_id):
|
||||
logger.debug(
|
||||
f'{torrent_info.site_name} - {torrent_info.title} 重新识别失败,尝试通过标题匹配...')
|
||||
if torrenthelper.match_torrent(mediainfo=mediainfo,
|
||||
torrent_meta=torrent_meta,
|
||||
torrent=torrent_info):
|
||||
# 匹配成功
|
||||
logger.info(
|
||||
f'{mediainfo.title_year} 通过媒体信ID匹配到可选资源:{torrent_info.site_name} - {torrent_info.title}')
|
||||
f'{mediainfo.title_year} 通过标题匹配到可选资源:{torrent_info.site_name} - {torrent_info.title}')
|
||||
torrent_mediainfo = mediainfo
|
||||
# 更新种子缓存
|
||||
_context.media_info = mediainfo
|
||||
else:
|
||||
continue
|
||||
|
||||
# 如果是电视剧
|
||||
if torrent_mediainfo.type == MediaType.TV:
|
||||
# 有多季的不要
|
||||
if len(torrent_meta.season_list) > 1:
|
||||
logger.debug(f'{torrent_info.title} 有多季,不处理')
|
||||
continue
|
||||
# 比对季
|
||||
if torrent_meta.begin_season:
|
||||
if meta.begin_season != torrent_meta.begin_season:
|
||||
logger.debug(f'{torrent_info.title} 季不匹配')
|
||||
continue
|
||||
elif meta.begin_season != 1:
|
||||
# 直接比对媒体信息
|
||||
if torrent_mediainfo and (torrent_mediainfo.tmdb_id or torrent_mediainfo.douban_id):
|
||||
if torrent_mediainfo.type != mediainfo.type:
|
||||
continue
|
||||
if torrent_mediainfo.tmdb_id \
|
||||
and torrent_mediainfo.tmdb_id != mediainfo.tmdb_id:
|
||||
continue
|
||||
if torrent_mediainfo.douban_id \
|
||||
and torrent_mediainfo.douban_id != mediainfo.douban_id:
|
||||
continue
|
||||
logger.info(
|
||||
f'{mediainfo.title_year} 通过媒体信ID匹配到可选资源:{torrent_info.site_name} - {torrent_info.title}')
|
||||
else:
|
||||
continue
|
||||
|
||||
# 如果是电视剧
|
||||
if torrent_mediainfo.type == MediaType.TV:
|
||||
# 有多季的不要
|
||||
if len(torrent_meta.season_list) > 1:
|
||||
logger.debug(f'{torrent_info.title} 有多季,不处理')
|
||||
continue
|
||||
# 比对季
|
||||
if torrent_meta.begin_season:
|
||||
if meta.begin_season != torrent_meta.begin_season:
|
||||
logger.debug(f'{torrent_info.title} 季不匹配')
|
||||
continue
|
||||
# 非洗版
|
||||
if not subscribe.best_version:
|
||||
# 不是缺失的剧集不要
|
||||
if no_exists and no_exists.get(mediakey):
|
||||
# 缺失集
|
||||
no_exists_info = no_exists.get(mediakey).get(subscribe.season)
|
||||
if no_exists_info:
|
||||
# 是否有交集
|
||||
if no_exists_info.episodes and \
|
||||
torrent_meta.episode_list and \
|
||||
not set(no_exists_info.episodes).intersection(
|
||||
set(torrent_meta.episode_list)
|
||||
):
|
||||
logger.debug(
|
||||
f'{torrent_info.title} 对应剧集 {torrent_meta.episode_list} 未包含缺失的剧集'
|
||||
)
|
||||
continue
|
||||
else:
|
||||
# 洗版时,非整季不要
|
||||
if meta.type == MediaType.TV:
|
||||
if torrent_meta.episode_list:
|
||||
logger.debug(f'{subscribe.name} 正在洗版,{torrent_info.title} 不是整季')
|
||||
elif meta.begin_season != 1:
|
||||
logger.debug(f'{torrent_info.title} 季不匹配')
|
||||
continue
|
||||
# 非洗版
|
||||
if not subscribe.best_version:
|
||||
# 不是缺失的剧集不要
|
||||
if no_exists and no_exists.get(mediakey):
|
||||
# 缺失集
|
||||
no_exists_info = no_exists.get(mediakey).get(subscribe.season)
|
||||
if no_exists_info:
|
||||
# 是否有交集
|
||||
if no_exists_info.episodes and \
|
||||
torrent_meta.episode_list and \
|
||||
not set(no_exists_info.episodes).intersection(
|
||||
set(torrent_meta.episode_list)
|
||||
):
|
||||
logger.debug(
|
||||
f'{torrent_info.title} 对应剧集 {torrent_meta.episode_list} 未包含缺失的剧集'
|
||||
)
|
||||
continue
|
||||
|
||||
# 匹配订阅附加参数
|
||||
if not torrenthelper.filter_torrent(torrent_info=torrent_info,
|
||||
filter_params=self.get_params(subscribe)):
|
||||
continue
|
||||
|
||||
# 优先级过滤规则
|
||||
if subscribe.best_version:
|
||||
rule_groups = subscribe.filter_groups \
|
||||
or systemconfig.get(SystemConfigKey.BestVersionFilterRuleGroups)
|
||||
else:
|
||||
rule_groups = subscribe.filter_groups \
|
||||
or systemconfig.get(SystemConfigKey.SubscribeFilterRuleGroups)
|
||||
result: List[TorrentInfo] = self.filter_torrents(
|
||||
rule_groups=rule_groups,
|
||||
torrent_list=[torrent_info],
|
||||
mediainfo=torrent_mediainfo)
|
||||
if result is not None and not result:
|
||||
# 不符合过滤规则
|
||||
logger.debug(f"{torrent_info.title} 不匹配过滤规则")
|
||||
# 洗版时,非整季不要
|
||||
if meta.type == MediaType.TV:
|
||||
if torrent_meta.episode_list:
|
||||
logger.debug(f'{subscribe.name} 正在洗版,{torrent_info.title} 不是整季')
|
||||
continue
|
||||
|
||||
# 匹配订阅附加参数
|
||||
if not torrenthelper.filter_torrent(torrent_info=torrent_info,
|
||||
filter_params=self.get_params(subscribe)):
|
||||
continue
|
||||
|
||||
# 优先级过滤规则
|
||||
if subscribe.best_version:
|
||||
rule_groups = subscribe.filter_groups \
|
||||
or systemconfig.get(SystemConfigKey.BestVersionFilterRuleGroups)
|
||||
else:
|
||||
rule_groups = subscribe.filter_groups \
|
||||
or systemconfig.get(SystemConfigKey.SubscribeFilterRuleGroups)
|
||||
result: List[TorrentInfo] = self.filter_torrents(
|
||||
rule_groups=rule_groups,
|
||||
torrent_list=[torrent_info],
|
||||
mediainfo=torrent_mediainfo)
|
||||
if result is not None and not result:
|
||||
# 不符合过滤规则
|
||||
logger.debug(f"{torrent_info.title} 不匹配过滤规则")
|
||||
continue
|
||||
|
||||
# 洗版时,优先级小于已下载优先级的不要
|
||||
if subscribe.best_version:
|
||||
if subscribe.current_priority \
|
||||
and torrent_info.pri_order <= subscribe.current_priority:
|
||||
logger.info(
|
||||
f'{subscribe.name} 正在洗版,{torrent_info.title} 优先级低于或等于已下载优先级')
|
||||
continue
|
||||
|
||||
# 洗版时,优先级小于已下载优先级的不要
|
||||
if subscribe.best_version:
|
||||
if subscribe.current_priority \
|
||||
and torrent_info.pri_order <= subscribe.current_priority:
|
||||
logger.info(
|
||||
f'{subscribe.name} 正在洗版,{torrent_info.title} 优先级低于或等于已下载优先级')
|
||||
continue
|
||||
|
||||
# 匹配成功
|
||||
logger.info(f'{mediainfo.title_year} 匹配成功:{torrent_info.title}')
|
||||
# 自定义属性
|
||||
if subscribe.media_category:
|
||||
torrent_mediainfo.category = subscribe.media_category
|
||||
if subscribe.episode_group:
|
||||
torrent_mediainfo.episode_group = subscribe.episode_group
|
||||
_match_context.append(_context)
|
||||
finally:
|
||||
contexts.clear()
|
||||
del contexts
|
||||
# 匹配成功
|
||||
logger.info(f'{mediainfo.title_year} 匹配成功:{torrent_info.title}')
|
||||
# 自定义属性
|
||||
if subscribe.media_category:
|
||||
torrent_mediainfo.category = subscribe.media_category
|
||||
if subscribe.episode_group:
|
||||
torrent_mediainfo.episode_group = subscribe.episode_group
|
||||
_match_context.append(_context)
|
||||
|
||||
if not _match_context:
|
||||
# 未匹配到资源
|
||||
@@ -809,7 +1043,8 @@ class SubscribeChain(ChainBase):
|
||||
username=subscribe.username,
|
||||
save_path=subscribe.save_path,
|
||||
downloader=subscribe.downloader,
|
||||
source=self.get_subscribe_source_keyword(subscribe)
|
||||
source=self.get_subscribe_source_keyword(
|
||||
subscribe)
|
||||
)
|
||||
|
||||
# 同步外部修改,更新订阅信息
|
||||
@@ -824,8 +1059,10 @@ class SubscribeChain(ChainBase):
|
||||
del processed_torrents
|
||||
subscribes.clear()
|
||||
del subscribes
|
||||
|
||||
logger.debug(f"match Lock released at {datetime.now()}")
|
||||
finally:
|
||||
if lock_acquired:
|
||||
self._rlock.release()
|
||||
logger.debug(f"match Lock released at {datetime.now()}")
|
||||
|
||||
def check(self):
|
||||
"""
|
||||
@@ -947,6 +1184,42 @@ class SubscribeChain(ChainBase):
|
||||
logger.error(f'follow用户分享订阅 {title} 添加失败:{message}')
|
||||
logger.info(f'follow用户分享订阅刷新完成,共添加 {success_count} 个订阅')
|
||||
|
||||
async def cache_calendar(self):
|
||||
"""
|
||||
预缓存订阅日历,实际上就是查询一遍所有订阅的媒体信息
|
||||
前端请示是异常的,所以需要使用异步缓存方法
|
||||
"""
|
||||
logger.info(f'开始预缓存订阅日历 ...')
|
||||
for subscribe in await SubscribeOper().async_list():
|
||||
if global_vars.is_system_stopped:
|
||||
break
|
||||
try:
|
||||
mtype = MediaType(subscribe.type)
|
||||
except ValueError:
|
||||
logger.error(f'订阅 {subscribe.name} 类型错误:{subscribe.type}')
|
||||
continue
|
||||
# 识别媒体信息
|
||||
if mtype == MediaType.MOVIE:
|
||||
mediainfo: MediaInfo = await self.async_recognize_media(mtype=mtype,
|
||||
tmdbid=subscribe.tmdbid,
|
||||
doubanid=subscribe.doubanid,
|
||||
bangumiid=subscribe.bangumiid,
|
||||
episode_group=subscribe.episode_group,
|
||||
cache=False)
|
||||
if not mediainfo:
|
||||
logger.warn(
|
||||
f'未识别到媒体信息,标题:{subscribe.name},tmdbid:{subscribe.tmdbid},doubanid:{subscribe.doubanid}')
|
||||
continue
|
||||
else:
|
||||
episodes = await TmdbChain().async_tmdb_episodes(tmdbid=subscribe.tmdbid,
|
||||
season=subscribe.season,
|
||||
episode_group=subscribe.episode_group)
|
||||
if not episodes:
|
||||
logger.warn(
|
||||
f'未识别到季集信息,标题:{subscribe.name},tmdbid:{subscribe.tmdbid},豆瓣ID:{subscribe.doubanid},季:{subscribe.season}')
|
||||
continue
|
||||
logger.info(f'订阅日历预缓存完成')
|
||||
|
||||
@staticmethod
|
||||
def __update_subscribe_note(subscribe: Subscribe, downloads: Optional[List[Context]]):
|
||||
"""
|
||||
@@ -1075,7 +1348,7 @@ class SubscribeChain(ChainBase):
|
||||
username=subscribe.username
|
||||
)
|
||||
# 发送事件
|
||||
EventManager().send_event(EventType.SubscribeComplete, {
|
||||
eventmanager.send_event(EventType.SubscribeComplete, {
|
||||
"subscribe_id": subscribe.id,
|
||||
"subscribe_info": subscribe.to_dict(),
|
||||
"mediainfo": mediainfo.to_dict(),
|
||||
|
||||
@@ -6,12 +6,12 @@ from typing import Union, Optional
|
||||
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import settings
|
||||
from app.core.plugin import PluginManager
|
||||
from app.helper.system import SystemHelper
|
||||
from app.log import logger
|
||||
from app.schemas import Notification, MessageChannel
|
||||
from app.utils.http import RequestUtils
|
||||
from app.utils.system import SystemUtils
|
||||
from app.helper.system import SystemHelper
|
||||
from app.helper.plugin import PluginHelper
|
||||
from version import FRONTEND_VERSION, APP_VERSION
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ class SystemChain(ChainBase):
|
||||
重启系统
|
||||
"""
|
||||
from app.core.config import global_vars
|
||||
|
||||
|
||||
if channel and userid:
|
||||
self.post_message(Notification(channel=channel, source=source,
|
||||
title="系统正在重启,请耐心等候!", userid=userid))
|
||||
@@ -136,13 +136,6 @@ class SystemChain(ChainBase):
|
||||
shutil.rmtree(target_path)
|
||||
shutil.copytree(item, target_path)
|
||||
logger.info(f"已恢复插件目录: {item.name}")
|
||||
# 安装依赖
|
||||
requirements_file = target_path / "requirements.txt"
|
||||
if requirements_file.exists():
|
||||
logger.info(f"正在安装插件 {item.name} 的依赖...")
|
||||
success, message = PluginHelper.pip_install_with_fallback(requirements_file)
|
||||
if not success:
|
||||
logger.warn(f"插件 {item.name} 依赖安装失败: {message}")
|
||||
restored_count += 1
|
||||
# 如果是文件
|
||||
elif item.is_file():
|
||||
@@ -155,6 +148,9 @@ class SystemChain(ChainBase):
|
||||
|
||||
logger.info(f"插件恢复完成,共恢复 {restored_count} 个项目")
|
||||
|
||||
# 安装缺少的依赖
|
||||
PluginManager.install_plugin_missing_dependencies()
|
||||
|
||||
# 删除备份目录
|
||||
try:
|
||||
shutil.rmtree(backup_dir)
|
||||
|
||||
@@ -164,3 +164,159 @@ class TmdbChain(ChainBase):
|
||||
if infos:
|
||||
return [info.backdrop_path for info in infos if info and info.backdrop_path][:num]
|
||||
return []
|
||||
|
||||
async def async_tmdb_discover(self, mtype: MediaType,
|
||||
sort_by: str,
|
||||
with_genres: str,
|
||||
with_original_language: str,
|
||||
with_keywords: str,
|
||||
with_watch_providers: str,
|
||||
vote_average: float,
|
||||
vote_count: int,
|
||||
release_date: str,
|
||||
page: Optional[int] = 1) -> Optional[List[MediaInfo]]:
|
||||
"""
|
||||
发现TMDB电影、剧集(异步版本)
|
||||
:param mtype: 媒体类型
|
||||
:param sort_by: 排序方式
|
||||
:param with_genres: 类型
|
||||
:param with_original_language: 语言
|
||||
:param with_keywords: 关键字
|
||||
:param with_watch_providers: 提供商
|
||||
:param vote_average: 评分
|
||||
:param vote_count: 评分人数
|
||||
:param release_date: 上映日期
|
||||
:param page: 页码
|
||||
:return: 媒体信息列表
|
||||
"""
|
||||
return await self.async_run_module("async_tmdb_discover", mtype=mtype,
|
||||
sort_by=sort_by,
|
||||
with_genres=with_genres,
|
||||
with_original_language=with_original_language,
|
||||
with_keywords=with_keywords,
|
||||
with_watch_providers=with_watch_providers,
|
||||
vote_average=vote_average,
|
||||
vote_count=vote_count,
|
||||
release_date=release_date,
|
||||
page=page)
|
||||
|
||||
async def async_tmdb_trending(self, page: Optional[int] = 1) -> Optional[List[MediaInfo]]:
|
||||
"""
|
||||
TMDB流行趋势(异步版本)
|
||||
:param page: 第几页
|
||||
:return: TMDB信息列表
|
||||
"""
|
||||
return await self.async_run_module("async_tmdb_trending", page=page)
|
||||
|
||||
async def async_tmdb_collection(self, collection_id: int) -> Optional[List[MediaInfo]]:
|
||||
"""
|
||||
根据合集ID查询集合(异步版本)
|
||||
:param collection_id: 合集ID
|
||||
"""
|
||||
return await self.async_run_module("async_tmdb_collection", collection_id=collection_id)
|
||||
|
||||
async def async_tmdb_seasons(self, tmdbid: int) -> List[schemas.TmdbSeason]:
|
||||
"""
|
||||
根据TMDBID查询themoviedb所有季信息(异步版本)
|
||||
:param tmdbid: TMDBID
|
||||
"""
|
||||
return await self.async_run_module("async_tmdb_seasons", tmdbid=tmdbid)
|
||||
|
||||
async def async_tmdb_group_seasons(self, group_id: str) -> List[schemas.TmdbSeason]:
|
||||
"""
|
||||
根据剧集组ID查询themoviedb所有季集信息(异步版本)
|
||||
:param group_id: 剧集组ID
|
||||
"""
|
||||
return await self.async_run_module("async_tmdb_group_seasons", group_id=group_id)
|
||||
|
||||
async def async_tmdb_episodes(self, tmdbid: int, season: int,
|
||||
episode_group: Optional[str] = None) -> List[schemas.TmdbEpisode]:
|
||||
"""
|
||||
根据TMDBID查询某季的所有信信息(异步版本)
|
||||
:param tmdbid: TMDBID
|
||||
:param season: 季
|
||||
:param episode_group: 剧集组
|
||||
"""
|
||||
return await self.async_run_module("async_tmdb_episodes", tmdbid=tmdbid, season=season,
|
||||
episode_group=episode_group)
|
||||
|
||||
async def async_movie_similar(self, tmdbid: int) -> Optional[List[MediaInfo]]:
|
||||
"""
|
||||
根据TMDBID查询类似电影(异步版本)
|
||||
:param tmdbid: TMDBID
|
||||
"""
|
||||
return await self.async_run_module("async_tmdb_movie_similar", tmdbid=tmdbid)
|
||||
|
||||
async def async_tv_similar(self, tmdbid: int) -> Optional[List[MediaInfo]]:
|
||||
"""
|
||||
根据TMDBID查询类似电视剧(异步版本)
|
||||
:param tmdbid: TMDBID
|
||||
"""
|
||||
return await self.async_run_module("async_tmdb_tv_similar", tmdbid=tmdbid)
|
||||
|
||||
async def async_movie_recommend(self, tmdbid: int) -> Optional[List[MediaInfo]]:
|
||||
"""
|
||||
根据TMDBID查询推荐电影(异步版本)
|
||||
:param tmdbid: TMDBID
|
||||
"""
|
||||
return await self.async_run_module("async_tmdb_movie_recommend", tmdbid=tmdbid)
|
||||
|
||||
async def async_tv_recommend(self, tmdbid: int) -> Optional[List[MediaInfo]]:
|
||||
"""
|
||||
根据TMDBID查询推荐电视剧(异步版本)
|
||||
:param tmdbid: TMDBID
|
||||
"""
|
||||
return await self.async_run_module("async_tmdb_tv_recommend", tmdbid=tmdbid)
|
||||
|
||||
async def async_movie_credits(self, tmdbid: int, page: Optional[int] = 1) -> Optional[List[schemas.MediaPerson]]:
|
||||
"""
|
||||
根据TMDBID查询电影演职人员(异步版本)
|
||||
:param tmdbid: TMDBID
|
||||
:param page: 页码
|
||||
"""
|
||||
return await self.async_run_module("async_tmdb_movie_credits", tmdbid=tmdbid, page=page)
|
||||
|
||||
async def async_tv_credits(self, tmdbid: int, page: Optional[int] = 1) -> Optional[List[schemas.MediaPerson]]:
|
||||
"""
|
||||
根据TMDBID查询电视剧演职人员(异步版本)
|
||||
:param tmdbid: TMDBID
|
||||
:param page: 页码
|
||||
"""
|
||||
return await self.async_run_module("async_tmdb_tv_credits", tmdbid=tmdbid, page=page)
|
||||
|
||||
async def async_person_detail(self, person_id: int) -> Optional[schemas.MediaPerson]:
|
||||
"""
|
||||
根据TMDBID查询演职员详情(异步版本)
|
||||
:param person_id: 人物ID
|
||||
"""
|
||||
return await self.async_run_module("async_tmdb_person_detail", person_id=person_id)
|
||||
|
||||
async def async_person_credits(self, person_id: int, page: Optional[int] = 1) -> Optional[List[MediaInfo]]:
|
||||
"""
|
||||
根据人物ID查询人物参演作品(异步版本)
|
||||
:param person_id: 人物ID
|
||||
:param page: 页码
|
||||
"""
|
||||
return await self.async_run_module("async_tmdb_person_credits", person_id=person_id, page=page)
|
||||
|
||||
async def async_get_random_wallpager(self) -> Optional[str]:
|
||||
"""
|
||||
获取随机壁纸(异步版本),缓存1个小时
|
||||
"""
|
||||
infos = await self.async_tmdb_trending()
|
||||
if infos:
|
||||
# 随机一个电影
|
||||
while True:
|
||||
info = random.choice(infos)
|
||||
if info and info.backdrop_path:
|
||||
return info.backdrop_path
|
||||
return None
|
||||
|
||||
async def async_get_trending_wallpapers(self, num: Optional[int] = 10) -> List[str]:
|
||||
"""
|
||||
获取所有流行壁纸(异步版本)
|
||||
"""
|
||||
infos = await self.async_tmdb_trending()
|
||||
if infos:
|
||||
return [info.backdrop_path for info in infos if info and info.backdrop_path][:num]
|
||||
return []
|
||||
|
||||
@@ -10,7 +10,7 @@ from app.core.metainfo import MetaInfo
|
||||
from app.db.site_oper import SiteOper
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.helper.rss import RssHelper
|
||||
from app.helper.sites import SitesHelper
|
||||
from app.helper.sites import SitesHelper # noqa
|
||||
from app.helper.torrent import TorrentHelper
|
||||
from app.log import logger
|
||||
from app.schemas import Notification
|
||||
@@ -56,9 +56,34 @@ class TorrentsChain(ChainBase):
|
||||
|
||||
# 读取缓存
|
||||
if stype == 'spider':
|
||||
return self.load_cache(self._spider_file) or {}
|
||||
torrents_cache = self.load_cache(self._spider_file) or {}
|
||||
else:
|
||||
return self.load_cache(self._rss_file) or {}
|
||||
torrents_cache = self.load_cache(self._rss_file) or {}
|
||||
|
||||
# 兼容性处理:为旧版本的Context对象添加失败次数字段
|
||||
self._ensure_context_compatibility(torrents_cache)
|
||||
|
||||
return torrents_cache
|
||||
|
||||
async def async_get_torrents(self, stype: Optional[str] = None) -> Dict[str, List[Context]]:
|
||||
"""
|
||||
异步获取当前缓存的种子
|
||||
:param stype: 强制指定缓存类型,spider:爬虫缓存,rss:rss缓存
|
||||
"""
|
||||
|
||||
if not stype:
|
||||
stype = settings.SUBSCRIBE_MODE
|
||||
|
||||
# 异步读取缓存
|
||||
if stype == 'spider':
|
||||
torrents_cache = await self.async_load_cache(self._spider_file) or {}
|
||||
else:
|
||||
torrents_cache = await self.async_load_cache(self._rss_file) or {}
|
||||
|
||||
# 兼容性处理:为旧版本的Context对象添加失败次数字段
|
||||
self._ensure_context_compatibility(torrents_cache)
|
||||
|
||||
return torrents_cache
|
||||
|
||||
def clear_torrents(self):
|
||||
"""
|
||||
@@ -69,6 +94,15 @@ class TorrentsChain(ChainBase):
|
||||
self.remove_cache(self._rss_file)
|
||||
logger.info(f'种子缓存数据清理完成')
|
||||
|
||||
async def async_clear_torrents(self):
|
||||
"""
|
||||
异步清理种子缓存数据
|
||||
"""
|
||||
logger.info(f'开始异步清理种子缓存数据 ...')
|
||||
await self.async_remove_cache(self._spider_file)
|
||||
await self.async_remove_cache(self._rss_file)
|
||||
logger.info(f'异步种子缓存数据清理完成')
|
||||
|
||||
def browse(self, domain: str, keyword: Optional[str] = None, cat: Optional[str] = None,
|
||||
page: Optional[int] = 0) -> List[TorrentInfo]:
|
||||
"""
|
||||
@@ -85,6 +119,22 @@ class TorrentsChain(ChainBase):
|
||||
return []
|
||||
return self.refresh_torrents(site=site, keyword=keyword, cat=cat, page=page)
|
||||
|
||||
async def async_browse(self, domain: str, keyword: Optional[str] = None, cat: Optional[str] = None,
|
||||
page: Optional[int] = 0) -> List[TorrentInfo]:
|
||||
"""
|
||||
异步浏览站点首页内容,返回种子清单,TTL缓存5分钟
|
||||
:param domain: 站点域名
|
||||
:param keyword: 搜索标题
|
||||
:param cat: 搜索分类
|
||||
:param page: 页码
|
||||
"""
|
||||
logger.info(f'开始获取站点 {domain} 最新种子 ...')
|
||||
site = await SitesHelper().async_get_indexer(domain)
|
||||
if not site:
|
||||
logger.error(f'站点 {domain} 不存在!')
|
||||
return []
|
||||
return await self.async_refresh_torrents(site=site, keyword=keyword, cat=cat, page=page)
|
||||
|
||||
def rss(self, domain: str) -> List[TorrentInfo]:
|
||||
"""
|
||||
获取站点RSS内容,返回种子清单,TTL缓存3分钟
|
||||
@@ -100,7 +150,7 @@ class TorrentsChain(ChainBase):
|
||||
return []
|
||||
# 解析RSS
|
||||
rss_items = RssHelper().parse(site.get("rss"), True if site.get("proxy") else False,
|
||||
timeout=int(site.get("timeout") or 30))
|
||||
timeout=int(site.get("timeout") or 30), ua=site.get("ua") if site.get("ua") else None)
|
||||
if rss_items is None:
|
||||
# rss过期,尝试保留原配置生成新的rss
|
||||
self.__renew_rss_url(domain=domain, site=site)
|
||||
@@ -140,6 +190,16 @@ class TorrentsChain(ChainBase):
|
||||
:param stype: 强制指定缓存类型,spider:爬虫缓存,rss:rss缓存
|
||||
:param sites: 强制指定站点ID列表,为空则读取设置的订阅站点
|
||||
"""
|
||||
|
||||
def __is_no_cache_site(_domain: str) -> bool:
|
||||
"""
|
||||
判断站点是否不需要缓存
|
||||
"""
|
||||
for url_key in settings.NO_CACHE_SITE_KEY.split(','):
|
||||
if url_key in _domain:
|
||||
return True
|
||||
return False
|
||||
|
||||
# 刷新类型
|
||||
if not stype:
|
||||
stype = settings.SUBSCRIBE_MODE
|
||||
@@ -169,7 +229,15 @@ class TorrentsChain(ChainBase):
|
||||
domains.append(domain)
|
||||
if stype == "spider":
|
||||
# 刷新首页种子
|
||||
torrents: List[TorrentInfo] = self.browse(domain=domain)
|
||||
torrents: List[TorrentInfo] = []
|
||||
# 读取第0页和第1页
|
||||
for page in range(2):
|
||||
page_torrents = self.browse(domain=domain, page=page)
|
||||
if page_torrents:
|
||||
torrents.extend(page_torrents)
|
||||
else:
|
||||
# 如果某一页没有数据,说明已经到最后一页,停止获取
|
||||
break
|
||||
else:
|
||||
# 刷新RSS种子
|
||||
torrents: List[TorrentInfo] = self.rss(domain=domain)
|
||||
@@ -178,11 +246,16 @@ class TorrentsChain(ChainBase):
|
||||
# 取前N条
|
||||
torrents = torrents[:settings.CONF.refresh]
|
||||
if torrents:
|
||||
# 过滤出没有处理过的种子 - 优化:使用集合查找,避免重复创建字符串列表
|
||||
cached_signatures = {f'{t.torrent_info.title}{t.torrent_info.description}'
|
||||
for t in torrents_cache.get(domain) or []}
|
||||
torrents = [torrent for torrent in torrents
|
||||
if f'{torrent.title}{torrent.description}' not in cached_signatures]
|
||||
if __is_no_cache_site(domain):
|
||||
# 不需要缓存的站点,直接处理
|
||||
logger.info(f'{indexer.get("name")} 有 {len(torrents)} 个种子 (不缓存)')
|
||||
torrents_cache[domain] = []
|
||||
else:
|
||||
# 过滤出没有处理过的种子 - 优化:使用集合查找,避免重复创建字符串列表
|
||||
cached_signatures = {f'{t.torrent_info.title}{t.torrent_info.description}'
|
||||
for t in torrents_cache.get(domain) or []}
|
||||
torrents = [torrent for torrent in torrents
|
||||
if f'{torrent.title}{torrent.description}' not in cached_signatures]
|
||||
if torrents:
|
||||
logger.info(f'{indexer.get("name")} 有 {len(torrents)} 个新种子')
|
||||
else:
|
||||
@@ -211,6 +284,9 @@ class TorrentsChain(ChainBase):
|
||||
mediainfo.clear()
|
||||
# 上下文
|
||||
context = Context(meta_info=meta, media_info=mediainfo, torrent_info=torrent)
|
||||
# 如果未识别到媒体信息,设置初始失败次数为1
|
||||
if not mediainfo or (not mediainfo.tmdb_id and not mediainfo.douban_id):
|
||||
context.media_recognize_fail_count = 1
|
||||
# 添加到缓存
|
||||
if not torrents_cache.get(domain):
|
||||
torrents_cache[domain] = [context]
|
||||
@@ -237,6 +313,21 @@ class TorrentsChain(ChainBase):
|
||||
|
||||
return torrents_cache
|
||||
|
||||
@staticmethod
|
||||
def _ensure_context_compatibility(torrents_cache: Dict[str, List[Context]]):
|
||||
"""
|
||||
确保Context对象的兼容性,为旧版本添加缺失的字段
|
||||
"""
|
||||
for domain, contexts in torrents_cache.items():
|
||||
for context in contexts:
|
||||
# 如果Context对象没有media_recognize_fail_count字段,添加默认值
|
||||
if not hasattr(context, 'media_recognize_fail_count'):
|
||||
context.media_recognize_fail_count = 0
|
||||
# 如果媒体信息未识别,设置初始失败次数
|
||||
if (not context.media_info or
|
||||
(not context.media_info.tmdb_id and not context.media_info.douban_id)):
|
||||
context.media_recognize_fail_count = 1
|
||||
|
||||
def __renew_rss_url(self, domain: str, site: dict):
|
||||
"""
|
||||
保留原配置生成新的rss地址
|
||||
@@ -249,7 +340,8 @@ class TorrentsChain(ChainBase):
|
||||
url=site.get("url"),
|
||||
cookie=site.get("cookie"),
|
||||
ua=site.get("ua") or settings.USER_AGENT,
|
||||
proxy=True if site.get("proxy") else False
|
||||
proxy=True if site.get("proxy") else False,
|
||||
timeout=site.get("timeout"),
|
||||
)
|
||||
if rss_url:
|
||||
# 获取新的日期的passkey
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
import gc
|
||||
import queue
|
||||
import re
|
||||
import threading
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from time import sleep
|
||||
from typing import List, Optional, Tuple, Union, Dict, Callable
|
||||
|
||||
@@ -16,9 +14,9 @@ from app.chain.storage import StorageChain
|
||||
from app.chain.tmdb import TmdbChain
|
||||
from app.core.config import settings, global_vars
|
||||
from app.core.context import MediaInfo
|
||||
from app.core.event import eventmanager
|
||||
from app.core.meta import MetaBase
|
||||
from app.core.metainfo import MetaInfoPath
|
||||
from app.core.event import eventmanager
|
||||
from app.db.downloadhistory_oper import DownloadHistoryOper
|
||||
from app.db.models.downloadhistory import DownloadHistory
|
||||
from app.db.models.transferhistory import TransferHistory
|
||||
@@ -28,13 +26,14 @@ from app.helper.directory import DirectoryHelper
|
||||
from app.helper.format import FormatParser
|
||||
from app.helper.progress import ProgressHelper
|
||||
from app.log import logger
|
||||
from app.schemas import StorageOperSelectionEventData
|
||||
from app.schemas import TransferInfo, TransferTorrent, Notification, EpisodeFormat, FileItem, TransferDirectoryConf, \
|
||||
TransferTask, TransferQueue, TransferJob, TransferJobTask
|
||||
from app.schemas.types import TorrentStatus, EventType, MediaType, ProgressKey, NotificationType, MessageChannel, \
|
||||
SystemConfigKey, ChainEventType, ContentType
|
||||
from app.schemas import StorageOperSelectionEventData
|
||||
from app.utils.singleton import Singleton
|
||||
from app.utils.string import StringUtils
|
||||
from app.utils.system import SystemUtils
|
||||
|
||||
downloader_lock = threading.Lock()
|
||||
job_lock = threading.Lock()
|
||||
@@ -213,6 +212,7 @@ class JobManager:
|
||||
set(self._season_episodes[mediaid]) - set(task.meta.episode_list)
|
||||
)
|
||||
return task
|
||||
return None
|
||||
|
||||
def remove_job(self, task: TransferTask) -> Optional[TransferJob]:
|
||||
"""
|
||||
@@ -226,6 +226,7 @@ class JobManager:
|
||||
if __mediaid__ in self._season_episodes:
|
||||
self._season_episodes.pop(__mediaid__)
|
||||
return self._job_view.pop(__mediaid__)
|
||||
return None
|
||||
|
||||
def is_done(self, task: TransferTask) -> bool:
|
||||
"""
|
||||
@@ -311,7 +312,7 @@ class JobManager:
|
||||
|
||||
def count(self, media: MediaInfo, season: Optional[int] = None) -> int:
|
||||
"""
|
||||
获取某项任务总数
|
||||
获取某项任务成功总数
|
||||
"""
|
||||
__mediaid__ = self.__get_media_id(media=media, season=season)
|
||||
with job_lock:
|
||||
@@ -322,15 +323,19 @@ class JobManager:
|
||||
|
||||
def size(self, media: MediaInfo, season: Optional[int] = None) -> int:
|
||||
"""
|
||||
获取某项任务总大小
|
||||
获取某项任务成功文件总大小
|
||||
"""
|
||||
__mediaid__ = self.__get_media_id(media=media, season=season)
|
||||
with job_lock:
|
||||
# 计算状态为完成的任务数
|
||||
if __mediaid__ not in self._job_view:
|
||||
return 0
|
||||
return sum([task.fileitem.size for task in self._job_view[__mediaid__].tasks if
|
||||
task.state == "completed" and task.fileitem.size is not None])
|
||||
return sum([
|
||||
task.fileitem.size if task.fileitem.size is not None
|
||||
else (SystemUtils.get_directory_size(Path(task.fileitem.path)) if task.fileitem.storage == "local" else 0)
|
||||
for task in self._job_view[__mediaid__].tasks
|
||||
if task.state == "completed"
|
||||
])
|
||||
|
||||
def total(self) -> int:
|
||||
"""
|
||||
@@ -359,22 +364,20 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
文件整理处理链
|
||||
"""
|
||||
|
||||
# 可处理的文件后缀
|
||||
all_exts = settings.RMT_MEDIAEXT
|
||||
|
||||
# 待整理任务队列
|
||||
_queue = Queue()
|
||||
|
||||
# 文件整理线程
|
||||
_transfer_thread = None
|
||||
|
||||
# 队列间隔时间(秒)
|
||||
_transfer_interval = 15
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# 可处理的文件后缀
|
||||
self.all_exts = settings.RMT_MEDIAEXT
|
||||
# 待整理任务队列
|
||||
self._queue = queue.Queue()
|
||||
# 文件整理线程
|
||||
self._transfer_thread = None
|
||||
# 队列间隔时间(秒)
|
||||
self._transfer_interval = 15
|
||||
# 事件管理器
|
||||
self.jobview = JobManager()
|
||||
|
||||
# 车移成功的文件清单
|
||||
self._success_target_files: Dict[str, List[str]] = {}
|
||||
# 启动整理任务
|
||||
self.__init()
|
||||
|
||||
@@ -391,6 +394,44 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
"""
|
||||
整理完成后处理
|
||||
"""
|
||||
|
||||
def __do_finished():
|
||||
"""
|
||||
完成时发送消息、刮削事件、移除任务等
|
||||
"""
|
||||
# 更新文件数量
|
||||
transferinfo.file_count = self.jobview.count(task.mediainfo, task.meta.begin_season) or 1
|
||||
# 更新文件大小
|
||||
transferinfo.total_size = self.jobview.size(task.mediainfo,
|
||||
task.meta.begin_season) or task.fileitem.size
|
||||
# 更新文件清单
|
||||
transferinfo.file_list_new = self._success_target_files.pop(transferinfo.target_diritem.path, [])
|
||||
# 发送通知,实时手动整理时不发
|
||||
if transferinfo.need_notify and (task.background or not task.manual):
|
||||
se_str = None
|
||||
if task.mediainfo.type == MediaType.TV:
|
||||
season_episodes = self.jobview.season_episodes(task.mediainfo, task.meta.begin_season)
|
||||
if season_episodes:
|
||||
se_str = f"{task.meta.season} {StringUtils.format_ep(season_episodes)}"
|
||||
else:
|
||||
se_str = f"{task.meta.season}"
|
||||
self.send_transfer_message(meta=task.meta,
|
||||
mediainfo=task.mediainfo,
|
||||
transferinfo=transferinfo,
|
||||
season_episode=se_str,
|
||||
username=task.username)
|
||||
# 刮削事件
|
||||
if transferinfo.need_scrape:
|
||||
self.eventmanager.send_event(EventType.MetadataScrape, {
|
||||
'meta': task.meta,
|
||||
'mediainfo': task.mediainfo,
|
||||
'fileitem': transferinfo.target_diritem,
|
||||
'file_list': transferinfo.file_list_new,
|
||||
'overwrite': False
|
||||
})
|
||||
# 移除已完成的任务
|
||||
self.jobview.remove_job(task)
|
||||
|
||||
transferhis = TransferHistoryOper()
|
||||
if not transferinfo.success:
|
||||
# 转移失败
|
||||
@@ -416,6 +457,10 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
))
|
||||
# 整理失败
|
||||
self.jobview.fail_task(task)
|
||||
with task_lock:
|
||||
# 整理完成且有成功的任务时
|
||||
if self.jobview.is_finished(task):
|
||||
__do_finished()
|
||||
return False, transferinfo.message
|
||||
|
||||
# 转移成功
|
||||
@@ -444,55 +489,32 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
})
|
||||
|
||||
with task_lock:
|
||||
# 登记转移成功文件清单
|
||||
target_dir_path = transferinfo.target_diritem.path
|
||||
target_files = transferinfo.file_list_new
|
||||
if self._success_target_files.get(target_dir_path):
|
||||
self._success_target_files[target_dir_path].extend(target_files)
|
||||
else:
|
||||
self._success_target_files[target_dir_path] = target_files
|
||||
# 全部整理成功时
|
||||
if self.jobview.is_success(task):
|
||||
# 移动模式删除空目录
|
||||
if transferinfo.transfer_type in ["move"]:
|
||||
# 所有成功的业务
|
||||
tasks = self.jobview.success_tasks(task.mediainfo, task.meta.begin_season)
|
||||
# 记录已处理的种子hash
|
||||
processed_hashes = set()
|
||||
storagechain = StorageChain()
|
||||
# 获取整理屏蔽词
|
||||
transfer_exclude_words = SystemConfigOper().get(SystemConfigKey.TransferExcludeWords)
|
||||
for t in tasks:
|
||||
# 下载器hash
|
||||
if t.download_hash and t.download_hash not in processed_hashes:
|
||||
processed_hashes.add(t.download_hash)
|
||||
if t.download_hash and self._can_delete_torrent(t.download_hash, t.downloader,
|
||||
transfer_exclude_words):
|
||||
if self.remove_torrents(t.download_hash, downloader=t.downloader):
|
||||
logger.info(f"移动模式删除种子成功:{t.download_hash} ")
|
||||
# 删除残留目录
|
||||
logger.info(f"移动模式删除种子成功:{t.download_hash}")
|
||||
if t.fileitem:
|
||||
storagechain.delete_media_file(t.fileitem, delete_self=False)
|
||||
# 整理完成且有成功的任务时
|
||||
if self.jobview.is_finished(task):
|
||||
# 发送通知,实时手动整理时不发
|
||||
if transferinfo.need_notify and (task.background or not task.manual):
|
||||
se_str = None
|
||||
if task.mediainfo.type == MediaType.TV:
|
||||
season_episodes = self.jobview.season_episodes(task.mediainfo, task.meta.begin_season)
|
||||
if season_episodes:
|
||||
se_str = f"{task.meta.season} {StringUtils.format_ep(season_episodes)}"
|
||||
else:
|
||||
se_str = f"{task.meta.season}"
|
||||
# 更新文件数量
|
||||
transferinfo.file_count = self.jobview.count(task.mediainfo, task.meta.begin_season) or 1
|
||||
# 更新文件大小
|
||||
transferinfo.total_size = self.jobview.size(task.mediainfo,
|
||||
task.meta.begin_season) or task.fileitem.size
|
||||
self.send_transfer_message(meta=task.meta,
|
||||
mediainfo=task.mediainfo,
|
||||
transferinfo=transferinfo,
|
||||
season_episode=se_str,
|
||||
username=task.username)
|
||||
# 刮削事件
|
||||
if transferinfo.need_scrape:
|
||||
self.eventmanager.send_event(EventType.MetadataScrape, {
|
||||
'meta': task.meta,
|
||||
'mediainfo': task.mediainfo,
|
||||
'fileitem': transferinfo.target_diritem
|
||||
})
|
||||
|
||||
# 移除已完成的任务
|
||||
self.jobview.remove_job(task)
|
||||
__do_finished()
|
||||
|
||||
return True, ""
|
||||
|
||||
@@ -538,8 +560,10 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
processed_num = 0
|
||||
# 失败数量
|
||||
fail_num = 0
|
||||
# 已完成文件
|
||||
finished_files = []
|
||||
|
||||
progress = ProgressHelper()
|
||||
progress = ProgressHelper(ProgressKey.FileTransfer)
|
||||
|
||||
while not global_vars.is_system_stopped:
|
||||
try:
|
||||
@@ -554,7 +578,7 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
if __queue_start:
|
||||
logger.info("开始整理队列处理...")
|
||||
# 启动进度
|
||||
progress.start(ProgressKey.FileTransfer)
|
||||
progress.start()
|
||||
# 重置计数
|
||||
processed_num = 0
|
||||
fail_num = 0
|
||||
@@ -562,8 +586,7 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
__process_msg = f"开始整理队列处理,当前共 {total_num} 个文件 ..."
|
||||
logger.info(__process_msg)
|
||||
progress.update(value=0,
|
||||
text=__process_msg,
|
||||
key=ProgressKey.FileTransfer)
|
||||
text=__process_msg)
|
||||
# 队列已开始
|
||||
__queue_start = False
|
||||
# 更新进度
|
||||
@@ -571,7 +594,10 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
logger.info(__process_msg)
|
||||
progress.update(value=processed_num / total_num * 100,
|
||||
text=__process_msg,
|
||||
key=ProgressKey.FileTransfer)
|
||||
data={
|
||||
"current": Path(fileitem.path).as_posix(),
|
||||
"finished": finished_files
|
||||
})
|
||||
# 整理
|
||||
state, err_msg = self.__handle_transfer(task=task, callback=item.callback)
|
||||
if not state:
|
||||
@@ -579,20 +605,20 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
fail_num += 1
|
||||
# 更新进度
|
||||
processed_num += 1
|
||||
finished_files.append(Path(fileitem.path).as_posix())
|
||||
__process_msg = f"{fileitem.name} 整理完成"
|
||||
logger.info(__process_msg)
|
||||
progress.update(value=processed_num / total_num * 100,
|
||||
progress.update(value=(processed_num / total_num) * 100,
|
||||
text=__process_msg,
|
||||
key=ProgressKey.FileTransfer)
|
||||
data={})
|
||||
except queue.Empty:
|
||||
if not __queue_start:
|
||||
# 结束进度
|
||||
__end_msg = f"整理队列处理完成,共整理 {processed_num} 个文件,失败 {fail_num} 个"
|
||||
logger.info(__end_msg)
|
||||
progress.update(value=100,
|
||||
text=__end_msg,
|
||||
key=ProgressKey.FileTransfer)
|
||||
progress.end(ProgressKey.FileTransfer)
|
||||
text=__end_msg)
|
||||
progress.end()
|
||||
# 重置计数
|
||||
processed_num = 0
|
||||
fail_num = 0
|
||||
@@ -847,7 +873,7 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
state, errmsg = self.do_transfer(
|
||||
fileitem=FileItem(
|
||||
storage="local",
|
||||
path=str(file_path).replace("\\", "/"),
|
||||
path=file_path.as_posix(),
|
||||
type="dir" if not file_path.is_file() else "file",
|
||||
name=file_path.name,
|
||||
size=file_path.stat().st_size,
|
||||
@@ -867,10 +893,6 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
torrents.clear()
|
||||
del torrents
|
||||
|
||||
# 如果不是大内存模式,进行垃圾回收
|
||||
if not settings.BIG_MEMORY_MODE:
|
||||
gc.collect()
|
||||
|
||||
# 结束
|
||||
logger.info("所有下载器中下载完成的文件已整理完成")
|
||||
return True
|
||||
@@ -880,7 +902,7 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
) -> List[Tuple[FileItem, bool]]:
|
||||
"""
|
||||
获取整理目录或文件列表
|
||||
|
||||
|
||||
:param fileitem: 文件项
|
||||
:param depth: 递归深度,默认为1
|
||||
"""
|
||||
@@ -1058,16 +1080,7 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
continue
|
||||
|
||||
# 整理屏蔽词不处理
|
||||
is_blocked = False
|
||||
if transfer_exclude_words:
|
||||
for keyword in transfer_exclude_words:
|
||||
if not keyword:
|
||||
continue
|
||||
if keyword and re.search(r"%s" % keyword, file_item.path, re.IGNORECASE):
|
||||
logger.info(f"{file_item.path} 命中整理屏蔽词 {keyword},不处理")
|
||||
is_blocked = True
|
||||
break
|
||||
if is_blocked:
|
||||
if self._is_blocked_by_exclude_words(file_item.path, transfer_exclude_words):
|
||||
continue
|
||||
|
||||
# 整理成功的不再处理
|
||||
@@ -1099,9 +1112,11 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
# 自定义识别
|
||||
if formaterHandler:
|
||||
# 开始集、结束集、PART
|
||||
begin_ep, end_ep, part = formaterHandler.split_episode(file_name=file_path.name, file_meta=file_meta)
|
||||
begin_ep, end_ep, part = formaterHandler.split_episode(file_name=file_path.name,
|
||||
file_meta=file_meta)
|
||||
if begin_ep is not None:
|
||||
file_meta.begin_episode = begin_ep
|
||||
if part is not None:
|
||||
file_meta.part = part
|
||||
if end_ep is not None:
|
||||
file_meta.end_episode = end_ep
|
||||
@@ -1111,10 +1126,10 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
downloadhis = DownloadHistoryOper()
|
||||
if bluray_dir:
|
||||
# 蓝光原盘,按目录名查询
|
||||
download_history = downloadhis.get_by_path(str(file_path))
|
||||
download_history = downloadhis.get_by_path(file_path.as_posix())
|
||||
else:
|
||||
# 按文件全路径查询
|
||||
download_file = downloadhis.get_file_by_fullpath(str(file_path))
|
||||
download_file = downloadhis.get_file_by_fullpath(file_path.as_posix())
|
||||
if download_file:
|
||||
download_history = downloadhis.get_by_hash(download_file.download_hash)
|
||||
|
||||
@@ -1160,15 +1175,16 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
processed_num = 0
|
||||
# 失败数量
|
||||
fail_num = 0
|
||||
# 已完成文件
|
||||
finished_files = []
|
||||
|
||||
# 启动进度
|
||||
progress = ProgressHelper()
|
||||
progress.start(ProgressKey.FileTransfer)
|
||||
progress = ProgressHelper(ProgressKey.FileTransfer)
|
||||
progress.start()
|
||||
__process_msg = f"开始整理,共 {total_num} 个文件 ..."
|
||||
logger.info(__process_msg)
|
||||
progress.update(value=0,
|
||||
text=__process_msg,
|
||||
key=ProgressKey.FileTransfer)
|
||||
text=__process_msg)
|
||||
try:
|
||||
for transfer_task in transfer_tasks:
|
||||
if global_vars.is_system_stopped:
|
||||
@@ -1180,7 +1196,10 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
logger.info(__process_msg)
|
||||
progress.update(value=(processed_num + fail_num) / total_num * 100,
|
||||
text=__process_msg,
|
||||
key=ProgressKey.FileTransfer)
|
||||
data={
|
||||
"current": Path(transfer_task.fileitem.path).as_posix(),
|
||||
"finished": finished_files,
|
||||
})
|
||||
state, err_msg = self.__handle_transfer(
|
||||
task=transfer_task,
|
||||
callback=self.__default_callback
|
||||
@@ -1192,6 +1211,8 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
fail_num += 1
|
||||
else:
|
||||
processed_num += 1
|
||||
# 记录已完成
|
||||
finished_files.append(Path(transfer_task.fileitem.path).as_posix())
|
||||
finally:
|
||||
transfer_tasks.clear()
|
||||
del transfer_tasks
|
||||
@@ -1201,10 +1222,11 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
logger.info(__end_msg)
|
||||
progress.update(value=100,
|
||||
text=__end_msg,
|
||||
key=ProgressKey.FileTransfer)
|
||||
progress.end(ProgressKey.FileTransfer)
|
||||
data={})
|
||||
progress.end()
|
||||
|
||||
return all_success, ",".join(err_msgs)
|
||||
error_msg = "、".join(err_msgs[:2]) + (f",等{len(err_msgs)}个文件错误!" if len(err_msgs) > 2 else "")
|
||||
return all_success, error_msg
|
||||
|
||||
def remote_transfer(self, arg_str: str, channel: MessageChannel,
|
||||
userid: Union[str, int] = None, source: Optional[str] = None):
|
||||
@@ -1341,16 +1363,12 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
mediainfo: MediaInfo = MediaChain().recognize_media(tmdbid=tmdbid, doubanid=doubanid,
|
||||
mtype=mtype, episode_group=episode_group)
|
||||
if not mediainfo:
|
||||
return False, f"媒体信息识别失败,tmdbid:{tmdbid},doubanid:{doubanid},type: {mtype.value}"
|
||||
return (False,
|
||||
f"媒体信息识别失败,tmdbid:{tmdbid},doubanid:{doubanid},type: {mtype.value if mtype else None}")
|
||||
else:
|
||||
# 更新媒体图片
|
||||
self.obtain_images(mediainfo=mediainfo)
|
||||
# 开始进度
|
||||
progress = ProgressHelper()
|
||||
progress.start(ProgressKey.FileTransfer)
|
||||
progress.update(value=0,
|
||||
text=f"开始整理 {fileitem.path} ...",
|
||||
key=ProgressKey.FileTransfer)
|
||||
|
||||
# 开始整理
|
||||
state, errmsg = self.do_transfer(
|
||||
fileitem=fileitem,
|
||||
@@ -1371,7 +1389,6 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
if not state:
|
||||
return False, errmsg
|
||||
|
||||
progress.end(ProgressKey.FileTransfer)
|
||||
logger.info(f"{fileitem.path} 整理完成")
|
||||
return True, ""
|
||||
else:
|
||||
@@ -1411,3 +1428,63 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
season_episode=season_episode,
|
||||
username=username
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _is_blocked_by_exclude_words(file_path: str, exclude_words: list) -> bool:
|
||||
"""
|
||||
检查文件是否被整理屏蔽词阻止处理
|
||||
:param file_path: 文件路径
|
||||
:param exclude_words: 整理屏蔽词列表
|
||||
:return: 如果被屏蔽返回True,否则返回False
|
||||
"""
|
||||
if not exclude_words:
|
||||
return False
|
||||
|
||||
for keyword in exclude_words:
|
||||
if keyword and re.search(r"%s" % keyword, file_path, re.IGNORECASE):
|
||||
logger.warn(f"{file_path} 命中屏蔽词 {keyword}")
|
||||
return True
|
||||
return False
|
||||
|
||||
def _can_delete_torrent(self, download_hash: str, downloader: str, transfer_exclude_words) -> bool:
|
||||
"""
|
||||
检查是否可以删除种子文件
|
||||
:param download_hash: 种子Hash
|
||||
:param downloader: 下载器名称
|
||||
:param transfer_exclude_words: 整理屏蔽词
|
||||
:return: 如果可以删除返回True,否则返回False
|
||||
"""
|
||||
try:
|
||||
# 获取种子信息
|
||||
torrents = self.list_torrents(hashs=download_hash, downloader=downloader)
|
||||
if not torrents:
|
||||
return False
|
||||
|
||||
# 未下载完成
|
||||
if torrents[0].progress < 100:
|
||||
return False
|
||||
|
||||
# 获取种子文件列表
|
||||
torrent_files = self.torrent_files(download_hash, downloader)
|
||||
if not torrent_files:
|
||||
return False
|
||||
|
||||
if not isinstance(torrent_files, list):
|
||||
torrent_files = torrent_files.data
|
||||
|
||||
# 检查是否有媒体文件未被屏蔽且存在
|
||||
save_path = torrents[0].path.parent
|
||||
for file in torrent_files:
|
||||
file_path = save_path / file.name
|
||||
# 如果存在未被屏蔽的媒体文件,则不删除种子
|
||||
if (file_path.suffix in self.all_exts
|
||||
and not self._is_blocked_by_exclude_words(file_path.as_posix(), transfer_exclude_words)
|
||||
and file_path.exists()):
|
||||
return False
|
||||
|
||||
# 所有媒体文件都被屏蔽或不存在,可以删除种子
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"检查种子 {download_hash} 是否需要删除失败:{e}")
|
||||
return False
|
||||
|
||||
@@ -10,11 +10,13 @@ from pydantic.fields import Callable
|
||||
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import global_vars
|
||||
from app.core.workflow import WorkFlowManager
|
||||
from app.core.event import Event, eventmanager
|
||||
from app.workflow import WorkFlowManager
|
||||
from app.db.models import Workflow
|
||||
from app.db.workflow_oper import WorkflowOper
|
||||
from app.log import logger
|
||||
from app.schemas import ActionContext, ActionFlow, Action, ActionExecution
|
||||
from app.schemas.types import EventType
|
||||
|
||||
|
||||
class WorkflowExecutor:
|
||||
@@ -178,7 +180,7 @@ class WorkflowExecutor:
|
||||
"""
|
||||
合并上下文
|
||||
"""
|
||||
for key, value in context.dict().items():
|
||||
for key, value in context.model_dump().items():
|
||||
if not getattr(self.context, key, None):
|
||||
setattr(self.context, key, value)
|
||||
|
||||
@@ -188,6 +190,16 @@ class WorkflowChain(ChainBase):
|
||||
工作流链
|
||||
"""
|
||||
|
||||
@eventmanager.register(EventType.WorkflowExecute)
|
||||
def event_process(self, event: Event):
|
||||
"""
|
||||
事件触发工作流执行
|
||||
"""
|
||||
workflow_id = event.event_data.get('workflow_id')
|
||||
if not workflow_id:
|
||||
return
|
||||
self.process(workflow_id, from_begin=False)
|
||||
|
||||
@staticmethod
|
||||
def process(workflow_id: int, from_begin: Optional[bool] = True) -> Tuple[bool, str]:
|
||||
"""
|
||||
@@ -225,7 +237,7 @@ class WorkflowChain(ChainBase):
|
||||
logger.warn(f"工作流 {workflow.name} 无流程")
|
||||
return False, "工作流无流程"
|
||||
|
||||
logger.info(f"开始处理 {workflow.name},共 {len(workflow.actions)} 个动作 ...")
|
||||
logger.info(f"开始执行工作流 {workflow.name},共 {len(workflow.actions)} 个动作 ...")
|
||||
workflowoper.start(workflow_id)
|
||||
|
||||
# 执行工作流
|
||||
@@ -247,3 +259,17 @@ class WorkflowChain(ChainBase):
|
||||
获取工作流列表
|
||||
"""
|
||||
return WorkflowOper().list_enabled()
|
||||
|
||||
@staticmethod
|
||||
def get_timer_workflows() -> List[Workflow]:
|
||||
"""
|
||||
获取定时触发的工作流列表
|
||||
"""
|
||||
return WorkflowOper().get_timer_triggered_workflows()
|
||||
|
||||
@staticmethod
|
||||
def get_event_workflows() -> List[Workflow]:
|
||||
"""
|
||||
获取事件触发的工作流列表
|
||||
"""
|
||||
return WorkflowOper().get_event_triggered_workflows()
|
||||
|
||||
@@ -215,7 +215,7 @@ class Command(metaclass=Singleton):
|
||||
except Exception as e:
|
||||
logger.error(f"Error occurred during command initialization in background: {e}", exc_info=True)
|
||||
|
||||
def __trigger_register_commands_event(self) -> (Optional[Event], dict):
|
||||
def __trigger_register_commands_event(self) -> tuple[Optional[Event], dict]:
|
||||
"""
|
||||
触发事件,允许调整命令数据
|
||||
"""
|
||||
@@ -225,6 +225,9 @@ class Command(metaclass=Singleton):
|
||||
添加命令集合
|
||||
"""
|
||||
for cmd, command in source.items():
|
||||
if not command.get("show", True):
|
||||
continue
|
||||
|
||||
command_data = {
|
||||
"type": command_type,
|
||||
"description": command.get("description"),
|
||||
@@ -261,6 +264,7 @@ class Command(metaclass=Singleton):
|
||||
"func": self.send_plugin_event,
|
||||
"description": command.get("desc"),
|
||||
"category": command.get("category"),
|
||||
"show": command.get("show", True),
|
||||
"data": {
|
||||
"etype": command.get("event"),
|
||||
"data": command.get("data")
|
||||
@@ -335,7 +339,8 @@ class Command(metaclass=Singleton):
|
||||
return self._commands.get(cmd, {})
|
||||
|
||||
def register(self, cmd: str, func: Any, data: Optional[dict] = None,
|
||||
desc: Optional[str] = None, category: Optional[str] = None) -> None:
|
||||
desc: Optional[str] = None, category: Optional[str] = None,
|
||||
show: bool = True) -> None:
|
||||
"""
|
||||
注册单个命令
|
||||
"""
|
||||
@@ -344,7 +349,8 @@ class Command(metaclass=Singleton):
|
||||
"func": func,
|
||||
"description": desc,
|
||||
"category": category,
|
||||
"data": data or {}
|
||||
"data": data or {},
|
||||
"show": show
|
||||
}
|
||||
|
||||
def execute(self, cmd: str, data_str: Optional[str] = "",
|
||||
|
||||
1491
app/core/cache.py
1491
app/core/cache.py
File diff suppressed because it is too large
Load Diff
@@ -1,18 +1,24 @@
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import secrets
|
||||
import sys
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from dotenv import set_key
|
||||
from pydantic import BaseModel, BaseSettings, validator, Field
|
||||
from pydantic import BaseModel, Field, ConfigDict, model_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
from app.log import logger, log_settings, LogConfigModel
|
||||
from app.schemas import MediaType
|
||||
from app.utils.system import SystemUtils
|
||||
from app.utils.url import UrlUtils
|
||||
from version import APP_VERSION
|
||||
|
||||
|
||||
class SystemConfModel(BaseModel):
|
||||
@@ -37,10 +43,6 @@ class SystemConfModel(BaseModel):
|
||||
scheduler: int = 0
|
||||
# 线程池大小
|
||||
threadpool: int = 0
|
||||
# 数据库连接池大小
|
||||
dbpool: int = 0
|
||||
# 数据库连接池溢出数量
|
||||
dbpooloverflow: int = 0
|
||||
|
||||
|
||||
class ConfigModel(BaseModel):
|
||||
@@ -48,9 +50,9 @@ class ConfigModel(BaseModel):
|
||||
Pydantic 配置模型,描述所有配置项及其类型和默认值
|
||||
"""
|
||||
|
||||
class Config:
|
||||
extra = "ignore" # 忽略未定义的配置项
|
||||
model_config = ConfigDict(extra="ignore") # 忽略未定义的配置项
|
||||
|
||||
# ==================== 基础应用配置 ====================
|
||||
# 项目名称
|
||||
PROJECT_NAME: str = "MoviePilot"
|
||||
# 域名 格式;https://movie-pilot.org
|
||||
@@ -59,6 +61,24 @@ class ConfigModel(BaseModel):
|
||||
API_V1_STR: str = "/api/v1"
|
||||
# 前端资源路径
|
||||
FRONTEND_PATH: str = "/public"
|
||||
# 时区
|
||||
TZ: str = "Asia/Shanghai"
|
||||
# API监听地址
|
||||
HOST: str = "0.0.0.0"
|
||||
# API监听端口
|
||||
PORT: int = 3001
|
||||
# 前端监听端口
|
||||
NGINX_PORT: int = 3000
|
||||
# 配置文件目录
|
||||
CONFIG_DIR: Optional[str] = None
|
||||
# 是否调试模式
|
||||
DEBUG: bool = False
|
||||
# 是否开发模式
|
||||
DEV: bool = False
|
||||
# 高级设置模式
|
||||
ADVANCED_MODE: bool = True
|
||||
|
||||
# ==================== 安全认证配置 ====================
|
||||
# 密钥
|
||||
SECRET_KEY: str = secrets.token_urlsafe(32)
|
||||
# RESOURCE密钥
|
||||
@@ -69,20 +89,26 @@ class ConfigModel(BaseModel):
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8
|
||||
# RESOURCE_TOKEN过期时间
|
||||
RESOURCE_ACCESS_TOKEN_EXPIRE_SECONDS: int = 60 * 30
|
||||
# 时区
|
||||
TZ: str = "Asia/Shanghai"
|
||||
# API监听地址
|
||||
HOST: str = "0.0.0.0"
|
||||
# API监听端口
|
||||
PORT: int = 3001
|
||||
# 前端监听端口
|
||||
NGINX_PORT: int = 3000
|
||||
# 是否调试模式
|
||||
DEBUG: bool = False
|
||||
# 是否开发模式
|
||||
DEV: bool = False
|
||||
# 超级管理员初始用户名
|
||||
SUPERUSER: str = "admin"
|
||||
# 超级管理员初始密码
|
||||
SUPERUSER_PASSWORD: Optional[str] = None
|
||||
# 辅助认证,允许通过外部服务进行认证、单点登录以及自动创建用户
|
||||
AUXILIARY_AUTH_ENABLE: bool = False
|
||||
# API密钥,需要更换
|
||||
API_TOKEN: Optional[str] = None
|
||||
# 用户认证站点
|
||||
AUTH_SITE: str = ""
|
||||
|
||||
# ==================== 数据库配置 ====================
|
||||
# 数据库类型,支持 sqlite 和 postgresql,默认使用 sqlite
|
||||
DB_TYPE: str = "sqlite"
|
||||
# 是否在控制台输出 SQL 语句,默认关闭
|
||||
DB_ECHO: bool = False
|
||||
# 数据库连接超时时间(秒),默认为 60 秒
|
||||
DB_TIMEOUT: int = 60
|
||||
# 是否启用 WAL 模式,仅适用于SQLite,默认开启
|
||||
DB_WAL_ENABLE: bool = True
|
||||
# 数据库连接池类型,QueuePool, NullPool
|
||||
DB_POOL_TYPE: str = "QueuePool"
|
||||
# 是否在获取连接时进行预先 ping 操作
|
||||
@@ -91,71 +117,44 @@ class ConfigModel(BaseModel):
|
||||
DB_POOL_RECYCLE: int = 300
|
||||
# 数据库连接池获取连接的超时时间(秒)
|
||||
DB_POOL_TIMEOUT: int = 30
|
||||
# SQLite 的 busy_timeout 参数,默认为 60 秒
|
||||
DB_TIMEOUT: int = 60
|
||||
# SQLite 是否启用 WAL 模式,默认开启
|
||||
DB_WAL_ENABLE: bool = True
|
||||
# SQLite 连接池大小
|
||||
DB_SQLITE_POOL_SIZE: int = 10
|
||||
# SQLite 连接池溢出数量
|
||||
DB_SQLITE_MAX_OVERFLOW: int = 50
|
||||
# PostgreSQL 主机地址
|
||||
DB_POSTGRESQL_HOST: str = "localhost"
|
||||
# PostgreSQL 端口
|
||||
DB_POSTGRESQL_PORT: int = 5432
|
||||
# PostgreSQL 数据库名
|
||||
DB_POSTGRESQL_DATABASE: str = "moviepilot"
|
||||
# PostgreSQL 用户名
|
||||
DB_POSTGRESQL_USERNAME: str = "moviepilot"
|
||||
# PostgreSQL 密码
|
||||
DB_POSTGRESQL_PASSWORD: str = "moviepilot"
|
||||
# PostgreSQL 连接池大小
|
||||
DB_POSTGRESQL_POOL_SIZE: int = 10
|
||||
# PostgreSQL 连接池溢出数量
|
||||
DB_POSTGRESQL_MAX_OVERFLOW: int = 50
|
||||
|
||||
# ==================== 缓存配置 ====================
|
||||
# 缓存类型,支持 cachetools 和 redis,默认使用 cachetools
|
||||
CACHE_BACKEND_TYPE: str = "cachetools"
|
||||
# 缓存连接字符串,仅外部缓存(如 Redis、Memcached)需要
|
||||
CACHE_BACKEND_URL: Optional[str] = None
|
||||
CACHE_BACKEND_URL: Optional[str] = "redis://localhost:6379"
|
||||
# Redis 缓存最大内存限制,未配置时,如开启大内存模式时为 "1024mb",未开启时为 "256mb"
|
||||
CACHE_REDIS_MAXMEMORY: Optional[str] = None
|
||||
# 配置文件目录
|
||||
CONFIG_DIR: Optional[str] = None
|
||||
# 超级管理员
|
||||
SUPERUSER: str = "admin"
|
||||
# 辅助认证,允许通过外部服务进行认证、单点登录以及自动创建用户
|
||||
AUXILIARY_AUTH_ENABLE: bool = False
|
||||
# API密钥,需要更换
|
||||
API_TOKEN: Optional[str] = None
|
||||
# 全局图片缓存,将媒体图片缓存到本地
|
||||
GLOBAL_IMAGE_CACHE: bool = False
|
||||
# 全局图片缓存保留天数
|
||||
GLOBAL_IMAGE_CACHE_DAYS: int = 7
|
||||
# 临时文件保留天数
|
||||
TEMP_FILE_DAYS: int = 3
|
||||
# 元数据识别缓存过期时间(小时),0为自动
|
||||
META_CACHE_EXPIRE: int = 0
|
||||
|
||||
# ==================== 网络代理配置 ====================
|
||||
# 网络代理服务器地址
|
||||
PROXY_HOST: Optional[str] = None
|
||||
# 登录页面电影海报,tmdb/bing/mediaserver
|
||||
WALLPAPER: str = "tmdb"
|
||||
# 自定义壁纸api地址
|
||||
CUSTOMIZE_WALLPAPER_API_URL: Optional[str] = None
|
||||
# 媒体搜索来源 themoviedb/douban/bangumi,多个用,分隔
|
||||
SEARCH_SOURCE: str = "themoviedb,douban,bangumi"
|
||||
# 媒体识别来源 themoviedb/douban
|
||||
RECOGNIZE_SOURCE: str = "themoviedb"
|
||||
# 刮削来源 themoviedb/douban
|
||||
SCRAP_SOURCE: str = "themoviedb"
|
||||
# 新增已入库媒体是否跟随TMDB信息变化
|
||||
SCRAP_FOLLOW_TMDB: bool = True
|
||||
# TMDB图片地址
|
||||
TMDB_IMAGE_DOMAIN: str = "image.tmdb.org"
|
||||
# TMDB API地址
|
||||
TMDB_API_DOMAIN: str = "api.themoviedb.org"
|
||||
# TMDB元数据语言
|
||||
TMDB_LOCALE: str = "zh"
|
||||
# 刮削使用TMDB原始语种图片
|
||||
TMDB_SCRAP_ORIGINAL_IMAGE: bool = False
|
||||
# TMDB API Key
|
||||
TMDB_API_KEY: str = "db55323b8d3e4154498498a75642b381"
|
||||
# TVDB API Key
|
||||
TVDB_V4_API_KEY: str = "ed2aa66b-7899-4677-92a7-67bc9ce3d93a"
|
||||
TVDB_V4_API_PIN: str = ""
|
||||
# Fanart开关
|
||||
FANART_ENABLE: bool = True
|
||||
# Fanart语言
|
||||
FANART_LANG: str = "zh,en"
|
||||
# Fanart API Key
|
||||
FANART_API_KEY: str = "d2d31f9ecabea050fc7d68aa3146015f"
|
||||
# 115 AppId
|
||||
U115_APP_ID: str = "100196807"
|
||||
# Alipan AppId
|
||||
ALIPAN_APP_ID: str = "ac1bf04dc9fd4d9aaabb65b4a668d403"
|
||||
# 元数据识别缓存过期时间(小时)
|
||||
META_CACHE_EXPIRE: int = 0
|
||||
# 电视剧动漫的分类genre_ids
|
||||
ANIME_GENREIDS: List[int] = Field(default=[16])
|
||||
# 用户认证站点
|
||||
AUTH_SITE: str = ""
|
||||
# 重启自动升级
|
||||
MOVIEPILOT_AUTO_UPDATE: str = 'release'
|
||||
# 自动检查和更新站点资源包(站点索引、认证等)
|
||||
AUTO_UPDATE_RESOURCE: bool = True
|
||||
# 是否启用DOH解析域名
|
||||
DOH_ENABLE: bool = False
|
||||
# 使用 DOH 解析的域名列表
|
||||
@@ -169,6 +168,55 @@ class ConfigModel(BaseModel):
|
||||
"api.telegram.org")
|
||||
# DOH 解析服务器列表
|
||||
DOH_RESOLVERS: str = "1.0.0.1,1.1.1.1,9.9.9.9,149.112.112.112"
|
||||
|
||||
# ==================== 媒体元数据配置 ====================
|
||||
# 媒体搜索来源 themoviedb/douban/bangumi,多个用,分隔
|
||||
SEARCH_SOURCE: str = "themoviedb"
|
||||
# 媒体识别来源 themoviedb/douban
|
||||
RECOGNIZE_SOURCE: str = "themoviedb"
|
||||
# 刮削来源 themoviedb/douban
|
||||
SCRAP_SOURCE: str = "themoviedb"
|
||||
# 电视剧动漫的分类genre_ids
|
||||
ANIME_GENREIDS: List[int] = Field(default=[16])
|
||||
|
||||
# ==================== TMDB配置 ====================
|
||||
# TMDB图片地址
|
||||
TMDB_IMAGE_DOMAIN: str = "image.tmdb.org"
|
||||
# TMDB API地址
|
||||
TMDB_API_DOMAIN: str = "api.themoviedb.org"
|
||||
# TMDB元数据语言
|
||||
TMDB_LOCALE: str = "zh"
|
||||
# 刮削使用TMDB原始语种图片
|
||||
TMDB_SCRAP_ORIGINAL_IMAGE: bool = False
|
||||
# TMDB API Key
|
||||
TMDB_API_KEY: str = "db55323b8d3e4154498498a75642b381"
|
||||
|
||||
# ==================== TVDB配置 ====================
|
||||
# TVDB API Key
|
||||
TVDB_V4_API_KEY: str = "ed2aa66b-7899-4677-92a7-67bc9ce3d93a"
|
||||
TVDB_V4_API_PIN: str = ""
|
||||
|
||||
# ==================== Fanart配置 ====================
|
||||
# Fanart开关
|
||||
FANART_ENABLE: bool = True
|
||||
# Fanart语言
|
||||
FANART_LANG: str = "zh,en"
|
||||
# Fanart API Key
|
||||
FANART_API_KEY: str = "d2d31f9ecabea050fc7d68aa3146015f"
|
||||
|
||||
# ==================== 云盘配置 ====================
|
||||
# 115 AppId
|
||||
U115_APP_ID: str = "100196807"
|
||||
# Alipan AppId
|
||||
ALIPAN_APP_ID: str = "ac1bf04dc9fd4d9aaabb65b4a668d403"
|
||||
|
||||
# ==================== 系统升级配置 ====================
|
||||
# 重启自动升级
|
||||
MOVIEPILOT_AUTO_UPDATE: str = 'release'
|
||||
# 自动检查和更新站点资源包(站点索引、认证等)
|
||||
AUTO_UPDATE_RESOURCE: bool = True
|
||||
|
||||
# ==================== 媒体文件格式配置 ====================
|
||||
# 支持的后缀格式
|
||||
RMT_MEDIAEXT: list = Field(
|
||||
default_factory=lambda: ['.mp4', '.mkv', '.ts', '.iso',
|
||||
@@ -191,10 +239,12 @@ class ConfigModel(BaseModel):
|
||||
'.aifc', '.aiff', '.alac', '.adif', '.adts',
|
||||
'.flac', '.midi', '.opus', '.sfalc']
|
||||
)
|
||||
# 下载器临时文件后缀
|
||||
DOWNLOAD_TMPEXT: list = Field(default_factory=lambda: ['.!qb', '.part'])
|
||||
|
||||
# ==================== 媒体服务器配置 ====================
|
||||
# 媒体服务器同步间隔(小时)
|
||||
MEDIASERVER_SYNC_INTERVAL: int = 6
|
||||
|
||||
# ==================== 订阅配置 ====================
|
||||
# 订阅模式
|
||||
SUBSCRIBE_MODE: str = "spider"
|
||||
# RSS订阅模式刷新时间间隔(分钟)
|
||||
@@ -203,20 +253,42 @@ class ConfigModel(BaseModel):
|
||||
SUBSCRIBE_STATISTIC_SHARE: bool = True
|
||||
# 订阅搜索开关
|
||||
SUBSCRIBE_SEARCH: bool = False
|
||||
# 订阅搜索时间间隔(小时)
|
||||
SUBSCRIBE_SEARCH_INTERVAL: int = 24
|
||||
# 检查本地媒体库是否存在资源开关
|
||||
LOCAL_EXISTS_SEARCH: bool = False
|
||||
# 搜索多个名称
|
||||
SEARCH_MULTIPLE_NAME: bool = False
|
||||
LOCAL_EXISTS_SEARCH: bool = True
|
||||
|
||||
# ==================== 站点配置 ====================
|
||||
# 站点数据刷新间隔(小时)
|
||||
SITEDATA_REFRESH_INTERVAL: int = 6
|
||||
# 读取和发送站点消息
|
||||
SITE_MESSAGE: bool = True
|
||||
# 不能缓存站点资源的站点域名,多个使用,分隔
|
||||
NO_CACHE_SITE_KEY: str = "m-team"
|
||||
# OCR服务器地址,用于识别站点验证码
|
||||
OCR_HOST: str = "https://movie-pilot.org"
|
||||
# 仿真类型:playwright 或 flaresolverr
|
||||
BROWSER_EMULATION: str = "playwright"
|
||||
# FlareSolverr 服务地址,例如 http://127.0.0.1:8191
|
||||
FLARESOLVERR_URL: Optional[str] = None
|
||||
|
||||
# ==================== 搜索配置 ====================
|
||||
# 搜索多个名称
|
||||
SEARCH_MULTIPLE_NAME: bool = False
|
||||
# 最大搜索名称数量
|
||||
MAX_SEARCH_NAME_LIMIT: int = 2
|
||||
|
||||
# ==================== 下载配置 ====================
|
||||
# 种子标签
|
||||
TORRENT_TAG: str = "MOVIEPILOT"
|
||||
# 下载站点字幕
|
||||
DOWNLOAD_SUBTITLE: bool = True
|
||||
# 交互搜索自动下载用户ID,使用,分割
|
||||
AUTO_DOWNLOAD_USER: Optional[str] = None
|
||||
# 下载器临时文件后缀
|
||||
DOWNLOAD_TMPEXT: list = Field(default_factory=lambda: ['.!qb', '.part'])
|
||||
|
||||
# ==================== CookieCloud配置 ====================
|
||||
# CookieCloud是否启动本地服务
|
||||
COOKIECLOUD_ENABLE_LOCAL: Optional[bool] = False
|
||||
# CookieCloud服务器地址
|
||||
@@ -229,8 +301,8 @@ class ConfigModel(BaseModel):
|
||||
COOKIECLOUD_INTERVAL: Optional[int] = 60 * 24
|
||||
# CookieCloud同步黑名单,多个域名,分割
|
||||
COOKIECLOUD_BLACKLIST: Optional[str] = None
|
||||
# CookieCloud对应的浏览器UA
|
||||
USER_AGENT: str = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/113.0.0.0 Safari/537.36 Edg/113.0.1774.57"
|
||||
|
||||
# ==================== 整理配置 ====================
|
||||
# 电影重命名格式
|
||||
MOVIE_RENAME_FORMAT: str = "{{title}}{% if year %} ({{year}}){% endif %}" \
|
||||
"/{{title}}{% if year %} ({{year}}){% endif %}{% if part %}-{{part}}{% endif %}{% if videoFormat %} - {{videoFormat}}{% endif %}" \
|
||||
@@ -240,10 +312,24 @@ class ConfigModel(BaseModel):
|
||||
"/Season {{season}}" \
|
||||
"/{{title}} - {{season_episode}}{% if part %}-{{part}}{% endif %}{% if episode %} - 第 {{episode}} 集{% endif %}" \
|
||||
"{{fileExt}}"
|
||||
# OCR服务器地址
|
||||
OCR_HOST: str = "https://movie-pilot.org"
|
||||
# 重命名时支持的S0别名
|
||||
RENAME_FORMAT_S0_NAMES: list = Field(default=["Specials", "SPs"])
|
||||
# 为指定默认字幕添加.default后缀
|
||||
DEFAULT_SUB: Optional[str] = "zh-cn"
|
||||
# 新增已入库媒体是否跟随TMDB信息变化
|
||||
SCRAP_FOLLOW_TMDB: bool = True
|
||||
|
||||
# ==================== 服务地址配置 ====================
|
||||
# 服务器地址,对应 https://github.com/jxxghp/MoviePilot-Server 项目
|
||||
MP_SERVER_HOST: str = "https://movie-pilot.org"
|
||||
|
||||
# ==================== 个性化 ====================
|
||||
# 登录页面电影海报,tmdb/bing/mediaserver
|
||||
WALLPAPER: str = "tmdb"
|
||||
# 自定义壁纸api地址
|
||||
CUSTOMIZE_WALLPAPER_API_URL: Optional[str] = None
|
||||
|
||||
# ==================== 插件配置 ====================
|
||||
# 插件市场仓库地址,多个地址使用,分隔,地址以/结尾
|
||||
PLUGIN_MARKET: str = ("https://github.com/jxxghp/MoviePilot-Plugins,"
|
||||
"https://github.com/thsrite/MoviePilot-Plugins,"
|
||||
@@ -264,6 +350,8 @@ class ConfigModel(BaseModel):
|
||||
PLUGIN_STATISTIC_SHARE: bool = True
|
||||
# 是否开启插件热加载
|
||||
PLUGIN_AUTO_RELOAD: bool = False
|
||||
|
||||
# ==================== Github & PIP ====================
|
||||
# Github token,提高请求api限流阈值 ghp_****
|
||||
GITHUB_TOKEN: Optional[str] = None
|
||||
# Github代理服务器,格式:https://mirror.ghproxy.com/
|
||||
@@ -272,20 +360,18 @@ class ConfigModel(BaseModel):
|
||||
PIP_PROXY: Optional[str] = ''
|
||||
# 指定的仓库Github token,多个仓库使用,分隔,格式:{user1}/{repo1}:ghp_****,{user2}/{repo2}:github_pat_****
|
||||
REPO_GITHUB_TOKEN: Optional[str] = None
|
||||
|
||||
# ==================== 性能配置 ====================
|
||||
# 大内存模式
|
||||
BIG_MEMORY_MODE: bool = False
|
||||
# 是否启用内存监控
|
||||
MEMORY_ANALYSIS: bool = False
|
||||
# 内存快照间隔(分钟)
|
||||
MEMORY_SNAPSHOT_INTERVAL: int = 30
|
||||
# 保留的内存快照文件数量
|
||||
MEMORY_SNAPSHOT_KEEP_COUNT: int = 20
|
||||
# 全局图片缓存,将媒体图片缓存到本地
|
||||
GLOBAL_IMAGE_CACHE: bool = False
|
||||
# 是否启用编码探测的性能模式
|
||||
ENCODING_DETECTION_PERFORMANCE_MODE: bool = True
|
||||
# 编码探测的最低置信度阈值
|
||||
ENCODING_DETECTION_MIN_CONFIDENCE: float = 0.8
|
||||
# 主动内存回收时间间隔(分钟),0为不启用
|
||||
MEMORY_GC_INTERVAL: int = 30
|
||||
|
||||
# ==================== 安全配置 ====================
|
||||
# 允许的图片缓存域名
|
||||
SECURITY_IMAGE_DOMAINS: list = Field(default=[
|
||||
"image.tmdb.org",
|
||||
@@ -305,23 +391,58 @@ class ConfigModel(BaseModel):
|
||||
])
|
||||
# 允许的图片文件后缀格式
|
||||
SECURITY_IMAGE_SUFFIXES: list = Field(default=[".jpg", ".jpeg", ".png", ".webp", ".gif", ".svg", ".avif"])
|
||||
# 重命名时支持的S0别名
|
||||
RENAME_FORMAT_S0_NAMES: list = Field(default=["Specials", "SPs"])
|
||||
# 为指定默认字幕添加.default后缀
|
||||
DEFAULT_SUB: Optional[str] = "zh-cn"
|
||||
|
||||
# ==================== 工作流配置 ====================
|
||||
# 工作流数据共享
|
||||
WORKFLOW_STATISTIC_SHARE: bool = True
|
||||
|
||||
# ==================== 存储配置 ====================
|
||||
# 对rclone进行快照对比时,是否检查文件夹的修改时间
|
||||
RCLONE_SNAPSHOT_CHECK_FOLDER_MODTIME: bool = True
|
||||
# 对OpenList进行快照对比时,是否检查文件夹的修改时间
|
||||
OPENLIST_SNAPSHOT_CHECK_FOLDER_MODTIME: bool = True
|
||||
|
||||
# ==================== Docker配置 ====================
|
||||
# Docker Client API地址
|
||||
DOCKER_CLIENT_API: Optional[str] = "tcp://127.0.0.1:38379"
|
||||
|
||||
# ==================== AI智能体配置 ====================
|
||||
# AI智能体开关
|
||||
AI_AGENT_ENABLE: bool = False
|
||||
# LLM提供商 (openai/google/deepseek)
|
||||
LLM_PROVIDER: str = "deepseek"
|
||||
# LLM模型名称
|
||||
LLM_MODEL: str = "deepseek-chat"
|
||||
# LLM API密钥
|
||||
LLM_API_KEY: Optional[str] = None
|
||||
# LLM基础URL(用于自定义API端点)
|
||||
LLM_BASE_URL: Optional[str] = "https://api.deepseek.com"
|
||||
# LLM温度参数
|
||||
LLM_TEMPERATURE: float = 0.1
|
||||
# LLM最大迭代次数
|
||||
LLM_MAX_ITERATIONS: int = 15
|
||||
# LLM工具调用超时时间(秒)
|
||||
LLM_TOOL_TIMEOUT: int = 300
|
||||
# 是否启用详细日志
|
||||
LLM_VERBOSE: bool = False
|
||||
# 最大记忆消息数量
|
||||
LLM_MAX_MEMORY_MESSAGES: int = 50
|
||||
# 记忆保留天数
|
||||
LLM_MEMORY_RETENTION_DAYS: int = 30
|
||||
# Redis记忆保留天数(如果使用Redis)
|
||||
LLM_REDIS_MEMORY_RETENTION_DAYS: int = 7
|
||||
|
||||
|
||||
class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
"""
|
||||
系统配置类
|
||||
"""
|
||||
|
||||
class Config:
|
||||
case_sensitive = True
|
||||
env_file = SystemUtils.get_env_path()
|
||||
env_file_encoding = "utf-8"
|
||||
model_config = SettingsConfigDict(
|
||||
case_sensitive=True,
|
||||
env_file=SystemUtils.get_env_path(),
|
||||
env_file_encoding="utf-8",
|
||||
)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
@@ -418,33 +539,54 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
f"配置项 '{field_name}' 的值 '{value}' 无法转换成正确的类型,使用默认值 '{default}',错误信息: {e}")
|
||||
return default, True
|
||||
|
||||
@validator('*', pre=True, always=True)
|
||||
def generic_type_validator(cls, value: Any, field): # noqa
|
||||
@model_validator(mode='before')
|
||||
@classmethod
|
||||
def generic_type_validator(cls, data: Any): # noqa
|
||||
"""
|
||||
通用校验器,尝试将配置值转换为期望的类型
|
||||
"""
|
||||
if field.name == "API_TOKEN":
|
||||
converted_value, needs_update = cls.validate_api_token(value, value)
|
||||
else:
|
||||
converted_value, needs_update = cls.generic_type_converter(value, value, field.type_, field.default,
|
||||
field.name)
|
||||
if needs_update:
|
||||
cls.update_env_config(field, value, converted_value)
|
||||
return converted_value
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
|
||||
# 处理 API_TOKEN 特殊验证
|
||||
if 'API_TOKEN' in data:
|
||||
converted_value, needs_update = cls.validate_api_token(data['API_TOKEN'], data['API_TOKEN'])
|
||||
if needs_update:
|
||||
cls.update_env_config("API_TOKEN", data["API_TOKEN"], converted_value)
|
||||
data['API_TOKEN'] = converted_value
|
||||
|
||||
# 对其他字段进行类型转换
|
||||
for field_name, field_info in cls.model_fields.items():
|
||||
if field_name not in data:
|
||||
continue
|
||||
value = data[field_name]
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
field = cls.model_fields.get(field_name)
|
||||
if field:
|
||||
converted_value, needs_update = cls.generic_type_converter(
|
||||
value, value, field.annotation, field.default, field_name
|
||||
)
|
||||
if needs_update:
|
||||
cls.update_env_config(field_name, value, converted_value)
|
||||
data[field_name] = converted_value
|
||||
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
def update_env_config(field: Any, original_value: Any, converted_value: Any) -> Tuple[bool, str]:
|
||||
def update_env_config(field_name: str, original_value: Any, converted_value: Any) -> Tuple[bool, str]:
|
||||
"""
|
||||
更新 env 配置
|
||||
"""
|
||||
message = None
|
||||
is_converted = original_value is not None and str(original_value) != str(converted_value)
|
||||
if is_converted:
|
||||
message = f"配置项 '{field.name}' 的值 '{original_value}' 无效,已替换为 '{converted_value}'"
|
||||
message = f"配置项 '{field_name}' 的值 '{original_value}' 无效,已替换为 '{converted_value}'"
|
||||
logger.warning(message)
|
||||
|
||||
if field.name in os.environ:
|
||||
message = f"配置项 '{field.name}' 已在环境变量中设置,请手动更新以保持一致性"
|
||||
if field_name in os.environ:
|
||||
message = f"配置项 '{field_name}' 已在环境变量中设置,请手动更新以保持一致性"
|
||||
logger.warning(message)
|
||||
return False, message
|
||||
else:
|
||||
@@ -454,10 +596,10 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
else:
|
||||
value_to_write = str(converted_value) if converted_value is not None else ""
|
||||
|
||||
set_key(dotenv_path=SystemUtils.get_env_path(), key_to_set=field.name, value_to_set=value_to_write,
|
||||
set_key(dotenv_path=SystemUtils.get_env_path(), key_to_set=field_name, value_to_set=value_to_write,
|
||||
quote_mode="always")
|
||||
if is_converted:
|
||||
logger.info(f"配置项 '{field.name}' 已自动修正并写入到 'app.env' 文件")
|
||||
logger.info(f"配置项 '{field_name}' 已自动修正并写入到 'app.env' 文件")
|
||||
return True, message
|
||||
|
||||
def update_setting(self, key: str, value: Any) -> Tuple[Optional[bool], str]:
|
||||
@@ -471,19 +613,17 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
return False, f"配置项 '{key}' 不存在"
|
||||
|
||||
try:
|
||||
field = self.__fields__[key]
|
||||
field = Settings.model_fields[key]
|
||||
original_value = getattr(self, key)
|
||||
if field.name == "API_TOKEN":
|
||||
if key == "API_TOKEN":
|
||||
converted_value, needs_update = self.validate_api_token(value, original_value)
|
||||
else:
|
||||
converted_value, needs_update = self.generic_type_converter(value,
|
||||
original_value,
|
||||
field.type_,
|
||||
field.default,
|
||||
key)
|
||||
converted_value, needs_update = self.generic_type_converter(
|
||||
value, original_value, field.annotation, field.default, key
|
||||
)
|
||||
# 如果没有抛出异常,则统一使用 converted_value 进行更新
|
||||
if needs_update or str(value) != str(converted_value):
|
||||
success, message = self.update_env_config(field, value, converted_value)
|
||||
success, message = self.update_env_config(key, value, converted_value)
|
||||
# 仅成功更新配置时,才更新内存
|
||||
if success:
|
||||
setattr(self, key, converted_value)
|
||||
@@ -510,6 +650,20 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
"""
|
||||
return "v2"
|
||||
|
||||
@property
|
||||
def USER_AGENT(self) -> str:
|
||||
"""
|
||||
全局用户代理字符串
|
||||
"""
|
||||
return f"{self.PROJECT_NAME}/{APP_VERSION[1:]} ({platform.system()} {platform.release()}; {SystemUtils.cpu_arch()})"
|
||||
|
||||
@property
|
||||
def NORMAL_USER_AGENT(self) -> str:
|
||||
"""
|
||||
默认浏览器用户代理字符串
|
||||
"""
|
||||
return "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/138.0.0.0 Safari/537.36"
|
||||
|
||||
@property
|
||||
def INNER_CONFIG_PATH(self):
|
||||
return self.ROOT_PATH / "config"
|
||||
@@ -561,11 +715,9 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
douban=512,
|
||||
bangumi=512,
|
||||
fanart=512,
|
||||
meta=(self.META_CACHE_EXPIRE or 24) * 3600,
|
||||
meta=(self.META_CACHE_EXPIRE or 72) * 3600,
|
||||
scheduler=100,
|
||||
threadpool=100,
|
||||
dbpool=100,
|
||||
dbpooloverflow=50
|
||||
threadpool=100
|
||||
)
|
||||
return SystemConfModel(
|
||||
torrents=100,
|
||||
@@ -574,11 +726,9 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
douban=256,
|
||||
bangumi=256,
|
||||
fanart=128,
|
||||
meta=(self.META_CACHE_EXPIRE or 2) * 3600,
|
||||
meta=(self.META_CACHE_EXPIRE or 24) * 3600,
|
||||
scheduler=50,
|
||||
threadpool=50,
|
||||
dbpool=50,
|
||||
dbpooloverflow=20
|
||||
threadpool=50
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -593,9 +743,22 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
@property
|
||||
def PROXY_SERVER(self):
|
||||
if self.PROXY_HOST:
|
||||
return {
|
||||
"server": self.PROXY_HOST
|
||||
}
|
||||
try:
|
||||
parsed = urlparse(self.PROXY_HOST)
|
||||
if not parsed.scheme:
|
||||
return {"server": self.PROXY_HOST}
|
||||
host = parsed.hostname or ""
|
||||
port = f":{parsed.port}" if parsed.port else ""
|
||||
server = f"{parsed.scheme}://{host}{port}"
|
||||
proxy = {"server": server}
|
||||
if parsed.username:
|
||||
proxy["username"] = parsed.username
|
||||
if parsed.password:
|
||||
proxy["password"] = parsed.password
|
||||
return proxy
|
||||
except Exception as err:
|
||||
logger.error(f"解析代理服务器地址 '{self.PROXY_HOST}' 时出错: {err}")
|
||||
return {"server": self.PROXY_HOST}
|
||||
return None
|
||||
|
||||
@property
|
||||
@@ -606,7 +769,7 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
if self.GITHUB_TOKEN:
|
||||
return {
|
||||
"Authorization": f"Bearer {self.GITHUB_TOKEN}",
|
||||
"User-Agent": self.USER_AGENT,
|
||||
"User-Agent": self.NORMAL_USER_AGENT,
|
||||
}
|
||||
return {}
|
||||
|
||||
@@ -635,7 +798,7 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
continue
|
||||
headers[repo_info] = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"User-Agent": self.USER_AGENT,
|
||||
"User-Agent": self.NORMAL_USER_AGENT,
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"处理令牌对 '{token_pair}' 时出错: {e}")
|
||||
@@ -655,6 +818,23 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
return None
|
||||
return UrlUtils.combine_url(host=self.APP_DOMAIN, path=url)
|
||||
|
||||
def RENAME_FORMAT(self, media_type: MediaType):
|
||||
"""
|
||||
获取指定类型的重命名格式
|
||||
|
||||
:param media_type: MediaType.TV 或 MediaType.Movie
|
||||
:return: 重命名格式
|
||||
"""
|
||||
rename_format = (
|
||||
self.TV_RENAME_FORMAT
|
||||
if media_type == MediaType.TV
|
||||
else self.MOVIE_RENAME_FORMAT
|
||||
)
|
||||
# 规范重命名格式
|
||||
rename_format = rename_format.replace("\\", "/")
|
||||
rename_format = re.sub(r'/+', '/', rename_format)
|
||||
return rename_format.strip("/")
|
||||
|
||||
|
||||
# 实例化配置
|
||||
settings = Settings()
|
||||
@@ -670,6 +850,8 @@ class GlobalVar(object):
|
||||
SUBSCRIPTIONS: List[dict] = []
|
||||
# 需应急停止的工作流
|
||||
EMERGENCY_STOP_WORKFLOWS: List[int] = []
|
||||
# 需应急停止文件整理
|
||||
EMERGENCY_STOP_TRANSFER: List[str] = []
|
||||
|
||||
def stop_system(self):
|
||||
"""
|
||||
@@ -710,12 +892,30 @@ class GlobalVar(object):
|
||||
if workflow_id in self.EMERGENCY_STOP_WORKFLOWS:
|
||||
self.EMERGENCY_STOP_WORKFLOWS.remove(workflow_id)
|
||||
|
||||
def is_workflow_stopped(self, workflow_id: int):
|
||||
def is_workflow_stopped(self, workflow_id: int) -> bool:
|
||||
"""
|
||||
是否停止工作流
|
||||
"""
|
||||
return self.is_system_stopped or workflow_id in self.EMERGENCY_STOP_WORKFLOWS
|
||||
|
||||
def stop_transfer(self, path: str):
|
||||
"""
|
||||
停止文件整理
|
||||
"""
|
||||
if path not in self.EMERGENCY_STOP_TRANSFER:
|
||||
self.EMERGENCY_STOP_TRANSFER.append(path)
|
||||
|
||||
def is_transfer_stopped(self, path: str) -> bool:
|
||||
"""
|
||||
是否停止文件整理
|
||||
"""
|
||||
if self.is_system_stopped:
|
||||
return True
|
||||
if path in self.EMERGENCY_STOP_TRANSFER:
|
||||
self.EMERGENCY_STOP_TRANSFER.remove(path)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# 全局标识
|
||||
global_vars = GlobalVar()
|
||||
|
||||
@@ -193,7 +193,7 @@ class MediaInfo:
|
||||
# LOGO
|
||||
logo_path: str = None
|
||||
# 评分
|
||||
vote_average: float = 0.0
|
||||
vote_average: float = None
|
||||
# 描述
|
||||
overview: str = None
|
||||
# 风格ID
|
||||
@@ -237,9 +237,9 @@ class MediaInfo:
|
||||
# 流媒体平台
|
||||
networks: list = field(default_factory=list)
|
||||
# 集数
|
||||
number_of_episodes: int = 0
|
||||
number_of_episodes: int = None
|
||||
# 季数
|
||||
number_of_seasons: int = 0
|
||||
number_of_seasons: int = None
|
||||
# 原产国
|
||||
origin_country: list = field(default_factory=list)
|
||||
# 原名
|
||||
@@ -250,14 +250,16 @@ class MediaInfo:
|
||||
production_countries: list = field(default_factory=list)
|
||||
# 语种
|
||||
spoken_languages: list = field(default_factory=list)
|
||||
# 所有发行日期
|
||||
release_dates: list = field(default_factory=list)
|
||||
# 状态
|
||||
status: str = None
|
||||
# 标签
|
||||
tagline: str = None
|
||||
# 评价数量
|
||||
vote_count: int = 0
|
||||
vote_count: int = None
|
||||
# 流行度
|
||||
popularity: int = 0
|
||||
popularity: float = None
|
||||
# 时长
|
||||
runtime: int = None
|
||||
# 下一集
|
||||
@@ -433,6 +435,18 @@ class MediaInfo:
|
||||
if self.release_date:
|
||||
# 年份
|
||||
self.year = self.release_date[:4]
|
||||
# 所有发行日期
|
||||
self.release_dates = [
|
||||
{
|
||||
"date": release_date.get("release_date"),
|
||||
"iso_code": result.get("iso_3166_1"),
|
||||
"note": release_date.get("note"),
|
||||
"type": release_date.get("type"),
|
||||
}
|
||||
for result in info.get("release_dates", {}).get("results", [])
|
||||
for release_date in result.get("release_dates", [])
|
||||
if release_date.get("release_date")
|
||||
]
|
||||
else:
|
||||
# 电视剧
|
||||
self.title = info.get('name')
|
||||
@@ -474,7 +488,16 @@ class MediaInfo:
|
||||
self.names = info.get('names') or []
|
||||
# 剩余属性赋值
|
||||
for key, value in info.items():
|
||||
if hasattr(self, key) and not getattr(self, key):
|
||||
if not value:
|
||||
continue
|
||||
if not hasattr(self, key):
|
||||
continue
|
||||
current_value = getattr(self, key)
|
||||
if current_value:
|
||||
continue
|
||||
if current_value is None:
|
||||
setattr(self, key, value)
|
||||
elif type(current_value) is type(value):
|
||||
setattr(self, key, value)
|
||||
|
||||
def set_douban_info(self, info: dict):
|
||||
@@ -606,7 +629,16 @@ class MediaInfo:
|
||||
self.production_countries = [{"id": country, "name": country} for country in info.get("countries") or []]
|
||||
# 剩余属性赋值
|
||||
for key, value in info.items():
|
||||
if not value:
|
||||
continue
|
||||
if not hasattr(self, key):
|
||||
continue
|
||||
current_value = getattr(self, key)
|
||||
if current_value:
|
||||
continue
|
||||
if current_value is None:
|
||||
setattr(self, key, value)
|
||||
elif type(current_value) is type(value):
|
||||
setattr(self, key, value)
|
||||
|
||||
def set_bangumi_info(self, info: dict):
|
||||
@@ -796,6 +828,8 @@ class Context:
|
||||
media_info: MediaInfo = None
|
||||
# 种子信息
|
||||
torrent_info: TorrentInfo = None
|
||||
# 媒体识别失败次数
|
||||
media_recognize_fail_count: int = 0
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
@@ -804,5 +838,6 @@ class Context:
|
||||
return {
|
||||
"meta_info": self.meta_info.to_dict() if self.meta_info else None,
|
||||
"torrent_info": self.torrent_info.to_dict() if self.torrent_info else None,
|
||||
"media_info": self.media_info.to_dict() if self.media_info else None
|
||||
"media_info": self.media_info.to_dict() if self.media_info else None,
|
||||
"media_recognize_fail_count": self.media_recognize_fail_count
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import copy
|
||||
import asyncio
|
||||
import importlib
|
||||
import inspect
|
||||
import random
|
||||
@@ -7,7 +7,9 @@ import time
|
||||
import traceback
|
||||
import uuid
|
||||
from queue import Empty, PriorityQueue
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union, Any
|
||||
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
|
||||
from app.helper.thread import ThreadHelper
|
||||
from app.log import logger
|
||||
@@ -69,18 +71,27 @@ class EventManager(metaclass=Singleton):
|
||||
EventManager 负责管理和调度广播事件和链式事件,包括订阅、发送和处理事件
|
||||
"""
|
||||
|
||||
# 退出事件
|
||||
__event = threading.Event()
|
||||
|
||||
def __init__(self):
|
||||
self.__executor = ThreadHelper() # 动态线程池,用于消费事件
|
||||
self.__consumer_threads = [] # 用于保存启动的事件消费者线程
|
||||
self.__event_queue = PriorityQueue() # 优先级队列
|
||||
self.__broadcast_subscribers: Dict[EventType, Dict[str, Callable]] = {} # 广播事件的订阅者
|
||||
self.__chain_subscribers: Dict[ChainEventType, Dict[str, tuple[int, Callable]]] = {} # 链式事件的订阅者
|
||||
self.__disabled_handlers = set() # 禁用的事件处理器集合
|
||||
self.__disabled_classes = set() # 禁用的事件处理器类集合
|
||||
self.__lock = threading.Lock() # 线程锁
|
||||
# 动态线程池,用于消费事件
|
||||
self.__executor = ThreadHelper()
|
||||
# 用于保存启动的事件消费者线程
|
||||
self.__consumer_threads = []
|
||||
# 优先级队列
|
||||
self.__event_queue = PriorityQueue()
|
||||
# 广播事件的订阅者
|
||||
self.__broadcast_subscribers: Dict[EventType, Dict[str, Callable]] = {}
|
||||
# 链式事件的订阅者
|
||||
self.__chain_subscribers: Dict[ChainEventType, Dict[str, tuple[int, Callable]]] = {}
|
||||
# 禁用的事件处理器集合
|
||||
self.__disabled_handlers = set()
|
||||
# 禁用的事件处理器类集合
|
||||
self.__disabled_classes = set()
|
||||
# 线程锁
|
||||
self.__lock = threading.Lock()
|
||||
# 退出事件
|
||||
self.__event = threading.Event()
|
||||
# 当前事件循环
|
||||
self.loop = asyncio.get_event_loop()
|
||||
|
||||
def start(self):
|
||||
"""
|
||||
@@ -144,6 +155,25 @@ class EventManager(metaclass=Singleton):
|
||||
logger.error(f"Unknown event type: {etype}")
|
||||
return None
|
||||
|
||||
async def async_send_event(self, etype: Union[EventType, ChainEventType],
|
||||
data: Optional[Union[Dict, ChainEventData]] = None,
|
||||
priority: Optional[int] = DEFAULT_EVENT_PRIORITY) -> Optional[Event]:
|
||||
"""
|
||||
异步发送事件,根据事件类型决定是广播事件还是链式事件
|
||||
:param etype: 事件类型 (EventType 或 ChainEventType)
|
||||
:param data: 可选,事件数据
|
||||
:param priority: 广播事件的优先级,默认为 10
|
||||
:return: 如果是链式事件,返回处理后的事件数据;否则返回 None
|
||||
"""
|
||||
event = Event(etype, data, priority)
|
||||
if isinstance(etype, EventType):
|
||||
return self.__trigger_broadcast_event(event)
|
||||
elif isinstance(etype, ChainEventType):
|
||||
return await self.__trigger_chain_event_async(event)
|
||||
else:
|
||||
logger.error(f"Unknown event type: {etype}")
|
||||
return None
|
||||
|
||||
def add_event_listener(self, event_type: Union[EventType, ChainEventType], handler: Callable,
|
||||
priority: Optional[int] = DEFAULT_EVENT_PRIORITY):
|
||||
"""
|
||||
@@ -327,6 +357,14 @@ class EventManager(metaclass=Singleton):
|
||||
dispatch = self.__dispatch_chain_event(event)
|
||||
return event if dispatch else None
|
||||
|
||||
async def __trigger_chain_event_async(self, event: Event) -> Optional[Event]:
|
||||
"""
|
||||
异步触发链式事件,按顺序调用订阅的处理器,并记录处理耗时
|
||||
"""
|
||||
logger.debug(f"Triggering asynchronous chain event: {event}")
|
||||
dispatch = await self.__dispatch_chain_event_async(event)
|
||||
return event if dispatch else None
|
||||
|
||||
def __trigger_broadcast_event(self, event: Event):
|
||||
"""
|
||||
触发广播事件,将事件插入到优先级队列中
|
||||
@@ -364,6 +402,35 @@ class EventManager(metaclass=Singleton):
|
||||
self.__log_event_lifecycle(event, "Completed")
|
||||
return True
|
||||
|
||||
async def __dispatch_chain_event_async(self, event: Event) -> bool:
|
||||
"""
|
||||
异步方式调度链式事件,按优先级顺序逐个调用事件处理器,并记录每个处理器的处理时间
|
||||
:param event: 要调度的事件对象
|
||||
"""
|
||||
handlers = self.__chain_subscribers.get(event.event_type, {})
|
||||
if not handlers:
|
||||
logger.debug(f"No handlers found for chain event: {event}")
|
||||
return False
|
||||
|
||||
# 过滤出启用的处理器
|
||||
enabled_handlers = {handler_id: (priority, handler) for handler_id, (priority, handler) in handlers.items()
|
||||
if self.__is_handler_enabled(handler)}
|
||||
|
||||
if not enabled_handlers:
|
||||
logger.debug(f"No enabled handlers found for chain event: {event}. Skipping execution.")
|
||||
return False
|
||||
|
||||
self.__log_event_lifecycle(event, "Started")
|
||||
for handler_id, (priority, handler) in enabled_handlers.items():
|
||||
start_time = time.time()
|
||||
await self.__safe_invoke_handler_async(handler, event)
|
||||
logger.debug(
|
||||
f"{self.__get_handler_identifier(handler)} (Priority: {priority}), "
|
||||
f"completed in {time.time() - start_time:.3f}s for event: {event}"
|
||||
)
|
||||
self.__log_event_lifecycle(event, "Completed")
|
||||
return True
|
||||
|
||||
def __dispatch_broadcast_event(self, event: Event):
|
||||
"""
|
||||
异步方式调度广播事件,通过线程池逐个调用事件处理器
|
||||
@@ -373,8 +440,25 @@ class EventManager(metaclass=Singleton):
|
||||
if not handlers:
|
||||
logger.debug(f"No handlers found for broadcast event: {event}")
|
||||
return
|
||||
# 为每个处理器提供独立的事件实例,防止某个处理器对 event_data 的修改影响其他处理器
|
||||
for handler_id, handler in handlers.items():
|
||||
self.__executor.submit(self.__safe_invoke_handler, handler, event)
|
||||
# 仅浅拷贝顶层字典,避免不必要的深拷贝开销;这样可以隔离键级别的替换/赋值
|
||||
if isinstance(event.event_data, dict):
|
||||
event_data_copy = event.event_data.copy()
|
||||
else:
|
||||
event_data_copy = event.event_data
|
||||
isolated_event = Event(event_type=event.event_type,
|
||||
event_data=event_data_copy,
|
||||
priority=event.priority)
|
||||
if inspect.iscoroutinefunction(handler):
|
||||
# 对于异步函数,直接在事件循环中运行
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.__safe_invoke_handler_async(handler, isolated_event),
|
||||
self.loop
|
||||
)
|
||||
else:
|
||||
# 对于同步函数,在线程池中运行
|
||||
self.__executor.submit(self.__safe_invoke_handler, handler, isolated_event)
|
||||
|
||||
def __safe_invoke_handler(self, handler: Callable, event: Event):
|
||||
"""
|
||||
@@ -386,48 +470,162 @@ class EventManager(metaclass=Singleton):
|
||||
logger.debug(f"Handler {self.__get_handler_identifier(handler)} is disabled. Skipping execution")
|
||||
return
|
||||
|
||||
# 根据事件类型判断是否需要深复制
|
||||
is_broadcast_event = isinstance(event.event_type, EventType)
|
||||
event_to_process = copy.deepcopy(event) if is_broadcast_event else event
|
||||
self.__invoke_handler_by_type_sync(handler, event)
|
||||
|
||||
async def __safe_invoke_handler_async(self, handler: Callable, event: Event):
|
||||
"""
|
||||
异步调用处理器,处理链式事件
|
||||
:param handler: 处理器
|
||||
:param event: 事件对象
|
||||
"""
|
||||
if not self.__is_handler_enabled(handler):
|
||||
logger.debug(f"Handler {self.__get_handler_identifier(handler)} is disabled. Skipping execution")
|
||||
return
|
||||
|
||||
await self.__invoke_handler_by_type_async(handler, event)
|
||||
|
||||
def __invoke_handler_by_type_sync(self, handler: Callable, event: Event):
|
||||
"""
|
||||
同步方式根据处理器类型调用相应的方法
|
||||
:param handler: 处理器
|
||||
:param event: 要处理的事件对象
|
||||
"""
|
||||
class_name, method_name = self.__parse_handler_names(handler)
|
||||
|
||||
from app.core.plugin import PluginManager
|
||||
from app.core.module import ModuleManager
|
||||
|
||||
plugin_manager = PluginManager()
|
||||
module_manager = ModuleManager()
|
||||
|
||||
if class_name in plugin_manager.get_plugin_ids():
|
||||
# 插件处理器
|
||||
plugin = plugin_manager.running_plugins.get(class_name)
|
||||
if not plugin:
|
||||
return
|
||||
method = getattr(plugin, method_name, None)
|
||||
if not method:
|
||||
return
|
||||
try:
|
||||
method(event)
|
||||
except Exception as e:
|
||||
self.__handle_event_error(event=event, module_name=plugin.name,
|
||||
class_name=class_name, method_name=method_name, e=e)
|
||||
elif class_name in module_manager.get_module_ids():
|
||||
# 模块处理器
|
||||
module = module_manager.get_running_module(class_name)
|
||||
if not module:
|
||||
return
|
||||
method = getattr(module, method_name, None)
|
||||
if not method:
|
||||
return
|
||||
try:
|
||||
method(event)
|
||||
except Exception as e:
|
||||
self.__handle_event_error(event=event, module_name=module.get_name(),
|
||||
class_name=class_name, method_name=method_name, e=e)
|
||||
else:
|
||||
# 全局处理器
|
||||
class_obj = self.__get_class_instance(class_name)
|
||||
if not class_obj or not hasattr(class_obj, method_name):
|
||||
return
|
||||
method = getattr(class_obj, method_name, None)
|
||||
if not method:
|
||||
return
|
||||
try:
|
||||
method(event)
|
||||
except Exception as e:
|
||||
self.__handle_event_error(event=event, module_name=class_name,
|
||||
class_name=class_name, method_name=method_name, e=e)
|
||||
|
||||
async def __invoke_handler_by_type_async(self, handler: Callable, event: Event):
|
||||
"""
|
||||
异步方式根据处理器类型调用相应的方法
|
||||
:param handler: 处理器
|
||||
:param event: 要处理的事件对象
|
||||
"""
|
||||
class_name, method_name = self.__parse_handler_names(handler)
|
||||
|
||||
from app.core.plugin import PluginManager
|
||||
from app.core.module import ModuleManager
|
||||
|
||||
plugin_manager = PluginManager()
|
||||
module_manager = ModuleManager()
|
||||
|
||||
if class_name in plugin_manager.get_plugin_ids():
|
||||
await self.__invoke_plugin_method_async(plugin_manager, class_name, method_name, event)
|
||||
elif class_name in module_manager.get_module_ids():
|
||||
await self.__invoke_module_method_async(module_manager, class_name, method_name, event)
|
||||
else:
|
||||
await self.__invoke_global_method_async(class_name, method_name, event)
|
||||
|
||||
@staticmethod
|
||||
def __parse_handler_names(handler: Callable) -> Tuple[str, str]:
|
||||
"""
|
||||
解析处理器的类名和方法名
|
||||
:param handler: 处理器
|
||||
:return: (class_name, method_name)
|
||||
"""
|
||||
names = handler.__qualname__.split(".")
|
||||
class_name, method_name = names[0], names[1]
|
||||
return names[0], names[1]
|
||||
|
||||
async def __invoke_plugin_method_async(self, handler: Any, class_name: str, method_name: str, event: Event):
|
||||
"""
|
||||
异步调用插件方法
|
||||
"""
|
||||
plugin = handler.running_plugins.get(class_name)
|
||||
if not plugin:
|
||||
return
|
||||
method = getattr(plugin, method_name, None)
|
||||
if not method:
|
||||
return
|
||||
try:
|
||||
from app.core.plugin import PluginManager
|
||||
from app.core.module import ModuleManager
|
||||
|
||||
if class_name in PluginManager().get_plugin_ids():
|
||||
def plugin_callable():
|
||||
"""
|
||||
插件调用函数
|
||||
"""
|
||||
PluginManager().run_plugin_method(class_name, method_name, event_to_process)
|
||||
|
||||
if is_broadcast_event:
|
||||
self.__executor.submit(plugin_callable)
|
||||
else:
|
||||
plugin_callable()
|
||||
elif class_name in ModuleManager().get_module_ids():
|
||||
module = ModuleManager().get_running_module(class_name)
|
||||
if module:
|
||||
method = getattr(module, method_name, None)
|
||||
if method:
|
||||
if is_broadcast_event:
|
||||
self.__executor.submit(method, event_to_process)
|
||||
else:
|
||||
method(event_to_process)
|
||||
if inspect.iscoroutinefunction(method):
|
||||
await method(event)
|
||||
else:
|
||||
# 获取全局对象或模块类的实例
|
||||
class_obj = self.__get_class_instance(class_name)
|
||||
if class_obj and hasattr(class_obj, method_name):
|
||||
method = getattr(class_obj, method_name)
|
||||
if is_broadcast_event:
|
||||
self.__executor.submit(method, event_to_process)
|
||||
else:
|
||||
method(event_to_process)
|
||||
# 插件同步函数在异步环境中运行,避免阻塞
|
||||
await run_in_threadpool(method, event)
|
||||
except Exception as e:
|
||||
self.__handle_event_error(event, handler, e)
|
||||
self.__handle_event_error(event=event, module_name=plugin.name,
|
||||
class_name=class_name, method_name=method_name, e=e)
|
||||
|
||||
async def __invoke_module_method_async(self, handler: Any, class_name: str, method_name: str, event: Event):
|
||||
"""
|
||||
异步调用模块方法
|
||||
"""
|
||||
module = handler.get_running_module(class_name)
|
||||
if not module:
|
||||
return
|
||||
method = getattr(module, method_name, None)
|
||||
if not method:
|
||||
return
|
||||
try:
|
||||
if inspect.iscoroutinefunction(method):
|
||||
await method(event)
|
||||
else:
|
||||
method(event)
|
||||
except Exception as e:
|
||||
self.__handle_event_error(event=event, module_name=module.get_name(),
|
||||
class_name=class_name, method_name=method_name, e=e)
|
||||
|
||||
async def __invoke_global_method_async(self, class_name: str, method_name: str, event: Event):
|
||||
"""
|
||||
异步调用全局对象方法
|
||||
"""
|
||||
class_obj = self.__get_class_instance(class_name)
|
||||
if not class_obj:
|
||||
return
|
||||
method = getattr(class_obj, method_name, None)
|
||||
if not method:
|
||||
return
|
||||
try:
|
||||
if inspect.iscoroutinefunction(method):
|
||||
await method(event)
|
||||
else:
|
||||
method(event)
|
||||
except Exception as e:
|
||||
self.__handle_event_error(event=event, module_name=class_name,
|
||||
class_name=class_name, method_name=method_name, e=e)
|
||||
|
||||
@staticmethod
|
||||
def __get_class_instance(class_name: str):
|
||||
@@ -454,7 +652,11 @@ class EventManager(metaclass=Singleton):
|
||||
module_name = f"app.chain.{class_name[:-5].lower()}"
|
||||
module = importlib.import_module(module_name)
|
||||
elif class_name.endswith("Helper"):
|
||||
module_name = f"app.helper.{class_name[:-6].lower()}"
|
||||
# 特殊处理 Async 类
|
||||
if class_name.startswith("Async"):
|
||||
module_name = f"app.helper.{class_name[5:-6].lower()}"
|
||||
else:
|
||||
module_name = f"app.helper.{class_name[:-6].lower()}"
|
||||
module = importlib.import_module(module_name)
|
||||
else:
|
||||
module_name = f"app.{class_name.lower()}"
|
||||
@@ -494,18 +696,16 @@ class EventManager(metaclass=Singleton):
|
||||
"""
|
||||
logger.debug(f"{stage} - {event}")
|
||||
|
||||
def __handle_event_error(self, event: Event, handler: Callable, e: Exception):
|
||||
def __handle_event_error(self, event: Event, module_name: str,
|
||||
class_name: str, method_name: str, e: Exception):
|
||||
"""
|
||||
全局错误处理器,用于处理事件处理中的异常
|
||||
"""
|
||||
logger.error(f"事件处理出错:{str(e)} - {traceback.format_exc()}")
|
||||
|
||||
names = handler.__qualname__.split(".")
|
||||
class_name, method_name = names[0], names[1]
|
||||
logger.error(f"{module_name} 事件处理出错:{str(e)} - {traceback.format_exc()}")
|
||||
|
||||
# 发送系统错误通知
|
||||
from app.helper.message import MessageHelper
|
||||
MessageHelper().put(title=f"{event.event_type} 事件处理出错",
|
||||
MessageHelper().put(title=f"{module_name} 处理事件 {event.event_type} 时出错",
|
||||
message=f"{class_name}.{method_name}:{str(e)}",
|
||||
role="system")
|
||||
self.send_event(
|
||||
|
||||
@@ -9,8 +9,6 @@ class CustomizationMatcher(metaclass=Singleton):
|
||||
"""
|
||||
识别自定义占位符
|
||||
"""
|
||||
customization = None
|
||||
custom_separator = None
|
||||
|
||||
def __init__(self):
|
||||
self.systemconfig = SystemConfigOper()
|
||||
|
||||
@@ -55,6 +55,8 @@ class MetaBase(object):
|
||||
resource_team: Optional[str] = None
|
||||
# 识别的自定义占位符
|
||||
customization: Optional[str] = None
|
||||
# 识别的流媒体平台
|
||||
web_source: Optional[str] = None
|
||||
# 视频编码
|
||||
video_encode: Optional[str] = None
|
||||
# 音频编码
|
||||
|
||||
@@ -67,7 +67,6 @@ class MetaVideo(MetaBase):
|
||||
original_title = title
|
||||
self._source = ""
|
||||
self._effect = []
|
||||
self.web_source = None
|
||||
self._index = 0
|
||||
# 判断是否纯数字命名
|
||||
if isfile \
|
||||
@@ -95,7 +94,6 @@ class MetaVideo(MetaBase):
|
||||
title = re.sub(r'\d{4}[\s._-]\d{1,2}[\s._-]\d{1,2}', "", title)
|
||||
# 拆分tokens
|
||||
tokens = Tokens(title)
|
||||
self.tokens = tokens
|
||||
# 实例化StreamingPlatforms对象
|
||||
streaming_platforms = StreamingPlatforms()
|
||||
# 解析名称、年份、季、集、资源类型、分辨率等
|
||||
@@ -103,7 +101,7 @@ class MetaVideo(MetaBase):
|
||||
while token:
|
||||
self._index += 1 # 更新当前处理的token索引
|
||||
# Part
|
||||
self.__init_part(token)
|
||||
self.__init_part(token, tokens)
|
||||
# 标题
|
||||
if self._continue_flag:
|
||||
self.__init_name(token)
|
||||
@@ -124,7 +122,7 @@ class MetaVideo(MetaBase):
|
||||
self.__init_resource_type(token)
|
||||
# 流媒体平台
|
||||
if self._continue_flag:
|
||||
self.__init_web_source(token, streaming_platforms)
|
||||
self.__init_web_source(token, tokens, streaming_platforms)
|
||||
# 视频编码
|
||||
if self._continue_flag:
|
||||
self.__init_video_encode(token)
|
||||
@@ -140,9 +138,6 @@ class MetaVideo(MetaBase):
|
||||
self.resource_effect = " ".join(self._effect)
|
||||
if self._source:
|
||||
self.resource_type = self._source.strip()
|
||||
# 添加流媒体平台
|
||||
if self.web_source:
|
||||
self.resource_type = f"{self.web_source} {self.resource_type}"
|
||||
# 提取原盘DIY
|
||||
if self.resource_type and "BluRay" in self.resource_type:
|
||||
if (self.subtitle and re.findall(r'D[Ii]Y', self.subtitle)) \
|
||||
@@ -204,7 +199,7 @@ class MetaVideo(MetaBase):
|
||||
name = re.sub(r'%s' % self._name_nostring_re, '', name,
|
||||
flags=re.IGNORECASE).strip()
|
||||
name = re.sub(r'\s+', ' ', name)
|
||||
if name.isdigit() \
|
||||
if name.isdecimal() \
|
||||
and int(name) < 1800 \
|
||||
and not self.year \
|
||||
and not self.begin_season \
|
||||
@@ -315,7 +310,7 @@ class MetaVideo(MetaBase):
|
||||
self.en_name = token
|
||||
self._last_token_type = "enname"
|
||||
|
||||
def __init_part(self, token: str):
|
||||
def __init_part(self, token: str, tokens: Tokens):
|
||||
"""
|
||||
识别Part
|
||||
"""
|
||||
@@ -331,12 +326,12 @@ class MetaVideo(MetaBase):
|
||||
if re_res:
|
||||
if not self.part:
|
||||
self.part = re_res.group(1)
|
||||
nextv = self.tokens.cur()
|
||||
nextv = tokens.cur()
|
||||
if nextv \
|
||||
and ((nextv.isdigit() and (len(nextv) == 1 or len(nextv) == 2 and nextv.startswith('0')))
|
||||
or nextv.upper() in ['A', 'B', 'C', 'I', 'II', 'III']):
|
||||
self.part = "%s%s" % (self.part, nextv)
|
||||
self.tokens.get_next()
|
||||
tokens.get_next()
|
||||
self._last_token_type = "part"
|
||||
self._continue_flag = False
|
||||
# self._stop_name_flag = False
|
||||
@@ -586,7 +581,7 @@ class MetaVideo(MetaBase):
|
||||
self._effect.append(effect)
|
||||
self._last_token = effect.upper()
|
||||
|
||||
def __init_web_source(self, token: str, streaming_platforms: StreamingPlatforms):
|
||||
def __init_web_source(self, token: str, tokens: Tokens, streaming_platforms: StreamingPlatforms):
|
||||
"""
|
||||
识别流媒体平台
|
||||
"""
|
||||
@@ -598,10 +593,10 @@ class MetaVideo(MetaBase):
|
||||
|
||||
prev_token = None
|
||||
prev_idx = self._index - 2
|
||||
if 0 <= prev_idx < len(self.tokens.tokens):
|
||||
prev_token = self.tokens.tokens[prev_idx]
|
||||
if 0 <= prev_idx < len(tokens.tokens):
|
||||
prev_token = tokens.tokens[prev_idx]
|
||||
|
||||
next_token = self.tokens.peek()
|
||||
next_token = tokens.peek()
|
||||
|
||||
if streaming_platforms.is_streaming_platform(token):
|
||||
platform_name = streaming_platforms.get_streaming_platform_name(token)
|
||||
@@ -620,7 +615,7 @@ class MetaVideo(MetaBase):
|
||||
platform_name = streaming_platforms.get_streaming_platform_name(combined_token)
|
||||
query_range = 2
|
||||
if is_next:
|
||||
self.tokens.get_next()
|
||||
tokens.get_next()
|
||||
break
|
||||
|
||||
if not platform_name:
|
||||
@@ -630,8 +625,8 @@ class MetaVideo(MetaBase):
|
||||
match_start_idx = self._index - query_range
|
||||
match_end_idx = self._index - 1
|
||||
start_index = max(0, match_start_idx - query_range)
|
||||
end_index = min(len(self.tokens.tokens), match_end_idx + 1 + query_range)
|
||||
tokens_to_check = self.tokens.tokens[start_index:end_index]
|
||||
end_index = min(len(tokens.tokens), match_end_idx + 1 + query_range)
|
||||
tokens_to_check = tokens.tokens[start_index:end_index]
|
||||
|
||||
if any(tok and tok.upper() in web_tokens for tok in tokens_to_check):
|
||||
self.web_source = platform_name
|
||||
|
||||
@@ -9,7 +9,6 @@ class ReleaseGroupsMatcher(metaclass=Singleton):
|
||||
"""
|
||||
识别制作组、字幕组
|
||||
"""
|
||||
__release_groups: str = None
|
||||
# 内置组
|
||||
RELEASE_GROUPS: dict = {
|
||||
"0ff": ['FF(?:(?:A|WE)B|CD|E(?:DU|B)|TV)'],
|
||||
@@ -48,7 +47,7 @@ class ReleaseGroupsMatcher(metaclass=Singleton):
|
||||
"joyhd": [],
|
||||
"keepfrds": ['FRDS', 'Yumi', 'cXcY'],
|
||||
"lemonhd": ['L(?:eague(?:(?:C|H)D|(?:M|T)V|NF|WEB)|HD)', 'i18n', 'CiNT'],
|
||||
"mteam": ['MTeam(?:TV|)', 'MPAD'],
|
||||
"mteam": ['MTeam(?:TV|)', 'MPAD', 'MWeb'],
|
||||
"nanyangpt": [],
|
||||
"nicept": [],
|
||||
"oshen": [],
|
||||
@@ -70,7 +69,7 @@ class ReleaseGroupsMatcher(metaclass=Singleton):
|
||||
"U2": [],
|
||||
"ultrahd": [],
|
||||
"others": ['B(?:MDru|eyondHD|TN)', 'C(?:fandora|trlhd|MRG)', 'DON', 'EVO', 'FLUX', 'HONE(?:yG|)',
|
||||
'N(?:oGroup|T(?:b|G))', 'PandaMoon', 'SMURF', 'T(?:EPES|aengoo|rollHD )',],
|
||||
'N(?:oGroup|T(?:b|G))', 'PandaMoon', 'SMURF', 'T(?:EPES|aengoo|rollHD )'],
|
||||
"anime": ['ANi', 'HYSUB', 'KTXP', 'LoliHouse', 'MCE', 'Nekomoe kissaten', 'SweetSub', 'MingY',
|
||||
'(?:Lilith|NC)-Raws', '织梦字幕组', '枫叶字幕组', '猎户手抄部', '喵萌奶茶屋', '漫猫字幕社',
|
||||
'霜庭云花Sub', '北宇治字幕组', '氢气烤肉架', '云歌字幕组', '萌樱字幕组', '极影字幕社',
|
||||
@@ -106,10 +105,11 @@ class ReleaseGroupsMatcher(metaclass=Singleton):
|
||||
else:
|
||||
groups = self.__release_groups
|
||||
title = f"{title} "
|
||||
groups_re = re.compile(r"(?<=[-@\[£【&])(?:%s)(?=[@.\s\S\]\[】&])" % groups, re.I)
|
||||
# 处理一个制作组识别多次的情况,保留顺序
|
||||
groups_re = re.compile(r"(?<=[-@\[£【&])(?:(?:%s))(?=[@.\s\S\]\[】&])" % groups, re.I)
|
||||
unique_groups = []
|
||||
for item in re.findall(groups_re, title):
|
||||
if item not in unique_groups:
|
||||
unique_groups.append(item)
|
||||
item_str = item[0] if isinstance(item, tuple) else item
|
||||
if item_str not in unique_groups:
|
||||
unique_groups.append(item_str)
|
||||
|
||||
return "@".join(unique_groups)
|
||||
|
||||
@@ -312,4 +312,3 @@ class StreamingPlatforms(metaclass=Singleton):
|
||||
if name is None:
|
||||
return False
|
||||
return name.upper() in self._lookup_cache
|
||||
|
||||
|
||||
@@ -154,35 +154,35 @@ def find_metainfo(title: str) -> Tuple[str, dict]:
|
||||
# 去除title中该部分
|
||||
if tmdbid or mtype or begin_season or end_season or begin_episode or end_episode:
|
||||
title = title.replace(f"{{[{result}]}}", '')
|
||||
|
||||
|
||||
# 支持Emby格式的ID标签
|
||||
# 1. [tmdbid=xxxx] 或 [tmdbid-xxxx] 格式
|
||||
tmdb_match = re.search(r'\[tmdbid[=\-](\d+)\]', title)
|
||||
if tmdb_match:
|
||||
metainfo['tmdbid'] = tmdb_match.group(1)
|
||||
title = re.sub(r'\[tmdbid[=\-](\d+)\]', '', title).strip()
|
||||
|
||||
|
||||
# 2. [tmdb=xxxx] 或 [tmdb-xxxx] 格式
|
||||
if not metainfo['tmdbid']:
|
||||
tmdb_match = re.search(r'\[tmdb[=\-](\d+)\]', title)
|
||||
if tmdb_match:
|
||||
metainfo['tmdbid'] = tmdb_match.group(1)
|
||||
title = re.sub(r'\[tmdb[=\-](\d+)\]', '', title).strip()
|
||||
|
||||
|
||||
# 3. {tmdbid=xxxx} 或 {tmdbid-xxxx} 格式
|
||||
if not metainfo['tmdbid']:
|
||||
tmdb_match = re.search(r'\{tmdbid[=\-](\d+)\}', title)
|
||||
if tmdb_match:
|
||||
metainfo['tmdbid'] = tmdb_match.group(1)
|
||||
title = re.sub(r'\{tmdbid[=\-](\d+)\}', '', title).strip()
|
||||
|
||||
|
||||
# 4. {tmdb=xxxx} 或 {tmdb-xxxx} 格式
|
||||
if not metainfo['tmdbid']:
|
||||
tmdb_match = re.search(r'\{tmdb[=\-](\d+)\}', title)
|
||||
if tmdb_match:
|
||||
metainfo['tmdbid'] = tmdb_match.group(1)
|
||||
title = re.sub(r'\{tmdb[=\-](\d+)\}', '', title).strip()
|
||||
|
||||
|
||||
# 计算季集总数
|
||||
if metainfo.get('begin_season') and metainfo.get('end_season'):
|
||||
if metainfo['begin_season'] > metainfo['end_season']:
|
||||
|
||||
@@ -16,14 +16,14 @@ class ModuleManager(metaclass=Singleton):
|
||||
模块管理器
|
||||
"""
|
||||
|
||||
# 模块列表
|
||||
_modules: dict = {}
|
||||
# 运行态模块列表
|
||||
_running_modules: dict = {}
|
||||
# 子模块类型集合
|
||||
SubType = Union[DownloaderType, MediaServerType, MessageChannel, StorageSchema, OtherModulesType]
|
||||
|
||||
def __init__(self):
|
||||
# 模块列表
|
||||
self._modules: dict = {}
|
||||
# 运行态模块列表
|
||||
self._running_modules: dict = {}
|
||||
self.load_modules()
|
||||
|
||||
def load_modules(self):
|
||||
@@ -48,7 +48,7 @@ class ModuleManager(metaclass=Singleton):
|
||||
# 通过模板开关控制加载
|
||||
_module.init_module()
|
||||
self._running_modules[module_id] = _module
|
||||
logger.info(f"Moudle Loaded:{module_id}")
|
||||
logger.debug(f"Moudle Loaded:{module_id}")
|
||||
except Exception as err:
|
||||
logger.error(f"Load Moudle Error:{module_id},{str(err)} - {traceback.format_exc()}", exc_info=True)
|
||||
|
||||
@@ -61,7 +61,7 @@ class ModuleManager(metaclass=Singleton):
|
||||
if hasattr(module, "stop"):
|
||||
try:
|
||||
module.stop()
|
||||
logger.info(f"Moudle Stoped:{module_id}")
|
||||
logger.debug(f"Moudle Stoped:{module_id}")
|
||||
except Exception as err:
|
||||
logger.error(f"Stop Moudle Error:{module_id},{str(err)} - {traceback.format_exc()}", exc_info=True)
|
||||
logger.info("所有模块停止完成")
|
||||
|
||||
@@ -1,102 +1,55 @@
|
||||
import ast
|
||||
import asyncio
|
||||
import concurrent
|
||||
import concurrent.futures
|
||||
import importlib.util
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
import threading
|
||||
from typing import Any, Dict, List, Optional, Type, Union, Callable, Tuple
|
||||
|
||||
from fastapi import HTTPException
|
||||
from starlette import status
|
||||
from watchdog.events import FileSystemEventHandler
|
||||
from watchdog.observers import Observer
|
||||
from watchfiles import watch
|
||||
|
||||
from app import schemas
|
||||
from app.core.cache import fresh, async_fresh
|
||||
from app.core.config import settings
|
||||
from app.core.event import eventmanager, Event
|
||||
from app.db.plugindata_oper import PluginDataOper
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.helper.plugin import PluginHelper
|
||||
from app.helper.sites import SitesHelper
|
||||
from app.helper.sites import SitesHelper # noqa
|
||||
from app.log import logger
|
||||
from app.schemas.types import EventType, SystemConfigKey
|
||||
from app.utils.crypto import RSAUtils
|
||||
from app.utils.limit import rate_limit_window
|
||||
from app.utils.object import ObjectUtils
|
||||
from app.utils.singleton import Singleton
|
||||
from app.utils.string import StringUtils
|
||||
from app.utils.system import SystemUtils
|
||||
|
||||
|
||||
class PluginMonitorHandler(FileSystemEventHandler):
|
||||
|
||||
def on_modified(self, event):
|
||||
"""
|
||||
插件文件修改后重载
|
||||
"""
|
||||
if event.is_directory:
|
||||
return
|
||||
# 使用 pathlib 处理文件路径,跳过非 .py 文件以及 pycache 目录中的文件
|
||||
event_path = Path(event.src_path)
|
||||
if not event_path.name.endswith(".py") or "pycache" in event_path.parts:
|
||||
return
|
||||
|
||||
# 读取插件根目录下的__init__.py文件,读取class XXXX(_PluginBase)的类名
|
||||
try:
|
||||
plugins_root = settings.ROOT_PATH / "app" / "plugins"
|
||||
# 确保修改的文件在 plugins 目录下
|
||||
if plugins_root not in event_path.parents:
|
||||
return
|
||||
# 获取插件目录路径,没有找到__init__.py时,说明不是有效包,跳过插件重载
|
||||
# 插件重载目前没有支持app/plugins/plugin/package/__init__.py的场景,这里也不做支持
|
||||
plugin_dir = event_path.parent
|
||||
init_file = plugin_dir / "__init__.py"
|
||||
if not init_file.exists():
|
||||
logger.debug(f"{plugin_dir} 下没有找到 __init__.py,跳过插件重载")
|
||||
return
|
||||
|
||||
with open(init_file, "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
pid = None
|
||||
for line in lines:
|
||||
if line.startswith("class") and "(_PluginBase)" in line:
|
||||
pid = line.split("class ")[1].split("(_PluginBase)")[0].strip()
|
||||
if pid:
|
||||
self.__reload_plugin(pid)
|
||||
except Exception as e:
|
||||
logger.error(f"插件文件修改后重载出错:{str(e)}")
|
||||
|
||||
@staticmethod
|
||||
@rate_limit_window(max_calls=1, window_seconds=2, source="PluginMonitor", enable_logging=False)
|
||||
def __reload_plugin(pid):
|
||||
"""
|
||||
重新加载插件
|
||||
"""
|
||||
try:
|
||||
logger.info(f"插件 {pid} 文件修改,重新加载...")
|
||||
PluginManager().reload_plugin(pid)
|
||||
except Exception as e:
|
||||
logger.error(f"插件文件修改后重载出错:{str(e)}")
|
||||
|
||||
|
||||
class PluginManager(metaclass=Singleton):
|
||||
"""
|
||||
插件管理器
|
||||
"""
|
||||
|
||||
# 插件列表
|
||||
_plugins: dict = {}
|
||||
# 运行态插件列表
|
||||
_running_plugins: dict = {}
|
||||
# 配置Key
|
||||
_config_key: str = "plugin.%s"
|
||||
# 监听器
|
||||
_observer: Observer = None
|
||||
|
||||
def __init__(self):
|
||||
# 插件列表
|
||||
self._plugins: dict = {}
|
||||
# 运行态插件列表
|
||||
self._running_plugins: dict = {}
|
||||
# 配置Key
|
||||
self._config_key: str = "plugin.%s"
|
||||
# 监控线程
|
||||
self._monitor_thread: Optional[threading.Thread] = None
|
||||
# 监控停止事件
|
||||
self._stop_monitor_event = threading.Event()
|
||||
# 开发者模式监测插件修改
|
||||
if settings.DEV or settings.PLUGIN_AUTO_RELOAD:
|
||||
self.__start_monitor()
|
||||
@@ -198,10 +151,14 @@ class PluginManager(metaclass=Singleton):
|
||||
# 清空指定插件
|
||||
self._plugins.pop(pid, None)
|
||||
self._running_plugins.pop(pid, None)
|
||||
# 清除插件模块缓存,包括所有子模块
|
||||
self._clear_plugin_modules(pid)
|
||||
else:
|
||||
# 清空
|
||||
self._plugins = {}
|
||||
self._running_plugins = {}
|
||||
# 清除所有插件模块缓存
|
||||
self._clear_plugin_modules()
|
||||
logger.info("插件停止完成")
|
||||
|
||||
@staticmethod
|
||||
@@ -259,7 +216,6 @@ class PluginManager(metaclass=Singleton):
|
||||
|
||||
# 导入模块
|
||||
module = importlib.import_module(module_name)
|
||||
importlib.reload(module)
|
||||
|
||||
# 检查模块中的类
|
||||
for name, obj in module.__dict__.items():
|
||||
@@ -313,10 +269,9 @@ class PluginManager(metaclass=Singleton):
|
||||
重新加载插件文件修改监测
|
||||
"""
|
||||
if settings.DEV or settings.PLUGIN_AUTO_RELOAD:
|
||||
if self._observer and self._observer.is_alive():
|
||||
logger.info("插件文件修改监测已经在运行中...")
|
||||
else:
|
||||
self.__start_monitor()
|
||||
# 先关闭已有监测,再重新启动
|
||||
self.stop_monitor()
|
||||
self.__start_monitor()
|
||||
else:
|
||||
self.stop_monitor()
|
||||
|
||||
@@ -324,25 +279,123 @@ class PluginManager(metaclass=Singleton):
|
||||
"""
|
||||
启用监测插件文件修改监测
|
||||
"""
|
||||
if self._monitor_thread and self._monitor_thread.is_alive():
|
||||
logger.info("插件文件修改监测已经在运行中...")
|
||||
return
|
||||
|
||||
logger.info("开始监测插件文件修改...")
|
||||
monitor_handler = PluginMonitorHandler()
|
||||
self._observer = Observer()
|
||||
self._observer.schedule(monitor_handler, str(settings.ROOT_PATH / "app" / "plugins"), recursive=True)
|
||||
self._observer.start()
|
||||
|
||||
# 在启动新线程之前,确保停止事件是清除状态
|
||||
self._stop_monitor_event.clear()
|
||||
|
||||
# 创建并启动监控线程
|
||||
self._monitor_thread = threading.Thread(
|
||||
target=self._run_file_watcher,
|
||||
daemon=True
|
||||
)
|
||||
self._monitor_thread.start()
|
||||
|
||||
def stop_monitor(self):
|
||||
"""
|
||||
停止监测插件文件修改监测
|
||||
"""
|
||||
# 停止监测
|
||||
if self._observer and self._observer.is_alive():
|
||||
if self._monitor_thread and self._monitor_thread.is_alive():
|
||||
logger.info("正在停止插件文件修改监测...")
|
||||
self._observer.stop()
|
||||
self._observer.join()
|
||||
self._stop_monitor_event.set()
|
||||
self._monitor_thread.join(timeout=5)
|
||||
if self._monitor_thread.is_alive():
|
||||
logger.warning("插件文件修改监测线程在5秒内未能正常停止。")
|
||||
self._monitor_thread = None
|
||||
logger.info("插件文件修改监测停止完成")
|
||||
else:
|
||||
logger.info("未启用插件文件修改监测,无需停止")
|
||||
|
||||
def _run_file_watcher(self):
|
||||
"""
|
||||
运行 watchfiles 监视器的主循环。
|
||||
"""
|
||||
# 监视插件目录
|
||||
plugins_path = str(settings.ROOT_PATH / "app" / "plugins")
|
||||
logger.info(">>> 监控线程已启动,准备进入watch循环...")
|
||||
# 使用 watchfiles 监视目录变化,并响应变化事件
|
||||
# Todo: yield_on_timeout = True 时,每秒检查停止事件,会返回空集合;后续可以考虑用来做心跳之类的功能?
|
||||
for changes in watch(plugins_path, stop_event=self._stop_monitor_event, rust_timeout=1000,
|
||||
yield_on_timeout=True):
|
||||
# 如果收到停止事件,退出循环
|
||||
if not changes:
|
||||
continue
|
||||
|
||||
# 处理变化事件
|
||||
plugins_to_reload = set()
|
||||
for _change_type, path_str in changes:
|
||||
event_path = Path(path_str)
|
||||
|
||||
# 跳过非 .py 文件以及 pycache 目录中的文件
|
||||
if not event_path.name.endswith(".py") or "__pycache__" in event_path.parts:
|
||||
continue
|
||||
|
||||
# 解析插件ID
|
||||
pid = self._get_plugin_id_from_path(event_path)
|
||||
# 跳过无效插件文件
|
||||
if pid:
|
||||
# 收集需要重载的插件ID,自动去重,避免重复重载
|
||||
plugins_to_reload.add(pid)
|
||||
|
||||
# 触发重载
|
||||
if plugins_to_reload:
|
||||
logger.info(f"检测到插件文件变化,准备重载: {list(plugins_to_reload)}")
|
||||
for pid in plugins_to_reload:
|
||||
try:
|
||||
self.reload_plugin(pid)
|
||||
except Exception as e:
|
||||
logger.error(f"插件 {pid} 热重载失败: {e}", exc_info=True)
|
||||
|
||||
@staticmethod
|
||||
def _get_plugin_id_from_path(event_path: Path) -> Optional[str]:
|
||||
"""
|
||||
根据文件路径解析出插件的ID。
|
||||
:param event_path: 被修改文件的 Path 对象。
|
||||
:return: 插件ID字符串,如果不是有效插件文件则返回 None。
|
||||
"""
|
||||
try:
|
||||
plugins_root = settings.ROOT_PATH / "app" / "plugins"
|
||||
# 确保修改的文件在 plugins 目录下
|
||||
if not event_path.is_relative_to(plugins_root):
|
||||
return None
|
||||
|
||||
try:
|
||||
plugin_dir_name = event_path.relative_to(plugins_root).parts[0]
|
||||
plugin_dir = plugins_root / plugin_dir_name
|
||||
except (ValueError, IndexError):
|
||||
return None
|
||||
|
||||
init_file = plugin_dir / "__init__.py"
|
||||
if not init_file.exists():
|
||||
return None
|
||||
|
||||
# 读取 __init__.py 文件,查找插件主类名
|
||||
with open(init_file, "r", encoding="utf-8") as f:
|
||||
source_code = f.read()
|
||||
|
||||
tree = ast.parse(source_code)
|
||||
|
||||
# 遍历AST,查找继承自 _PluginBase 的类
|
||||
for node in ast.walk(tree):
|
||||
# 检查节点是否为类定义
|
||||
if isinstance(node, ast.ClassDef):
|
||||
# 遍历该类的所有基类
|
||||
for base in node.bases:
|
||||
# 检查基类是否是我们寻找的 _PluginBase
|
||||
# ast.Name 用于处理简单的基类名
|
||||
if isinstance(base, ast.Name) and base.id == '_PluginBase':
|
||||
# 返回这个类的名字
|
||||
return node.name
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"从路径解析插件ID时出错: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def __stop_plugin(plugin: Any):
|
||||
"""
|
||||
@@ -366,25 +419,55 @@ class PluginManager(metaclass=Singleton):
|
||||
"""
|
||||
self.stop(plugin_id)
|
||||
|
||||
# 从模块列表中移除插件
|
||||
from sys import modules
|
||||
try:
|
||||
del modules[f"app.plugins.{plugin_id.lower()}"]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def reload_plugin(self, plugin_id: str):
|
||||
"""
|
||||
将一个插件重新加载到内存
|
||||
:param plugin_id: 插件ID
|
||||
"""
|
||||
# 先移除
|
||||
# 先移除插件实例
|
||||
self.stop(plugin_id)
|
||||
# 重新加载
|
||||
self.start(plugin_id)
|
||||
# 广播事件
|
||||
eventmanager.send_event(EventType.PluginReload, data={"plugin_id": plugin_id})
|
||||
|
||||
@staticmethod
|
||||
def _clear_plugin_modules(plugin_id: Optional[str] = None):
|
||||
"""
|
||||
清除插件及其所有子模块的缓存
|
||||
:param plugin_id: 插件ID
|
||||
"""
|
||||
|
||||
# 构建插件模块前缀
|
||||
if plugin_id:
|
||||
plugin_module_prefix = f"app.plugins.{plugin_id.lower()}"
|
||||
else:
|
||||
plugin_module_prefix = "app.plugins"
|
||||
|
||||
# 收集需要删除的模块名(创建模块名列表的副本以避免迭代时修改字典)
|
||||
modules_to_remove = []
|
||||
for module_name in list(sys.modules.keys()):
|
||||
if module_name == plugin_module_prefix or module_name.startswith(plugin_module_prefix + "."):
|
||||
modules_to_remove.append(module_name)
|
||||
|
||||
# 删除模块
|
||||
for module_name in modules_to_remove:
|
||||
try:
|
||||
del sys.modules[module_name]
|
||||
logger.debug(f"已清除插件模块缓存:{module_name}")
|
||||
except KeyError:
|
||||
# 模块可能已经被删除
|
||||
pass
|
||||
|
||||
importlib.invalidate_caches()
|
||||
logger.debug("已清除查找器的缓存")
|
||||
|
||||
if plugin_id:
|
||||
if modules_to_remove:
|
||||
logger.info(f"插件 {plugin_id} 共清除 {len(modules_to_remove)} 个模块缓存:{modules_to_remove}")
|
||||
else:
|
||||
logger.debug(f"插件 {plugin_id} 没有找到需要清除的模块缓存")
|
||||
|
||||
def sync(self) -> List[str]:
|
||||
"""
|
||||
安装本地不存在或需要更新的插件
|
||||
@@ -662,6 +745,36 @@ class PluginManager(metaclass=Singleton):
|
||||
logger.error(f"获取插件 {plugin_id} 动作出错:{str(e)}")
|
||||
return ret_actions
|
||||
|
||||
def get_plugin_agent_tools(self, pid: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取插件智能体工具
|
||||
[{
|
||||
"plugin_id": "插件ID",
|
||||
"plugin_name": "插件名称",
|
||||
"tools": [ToolClass1, ToolClass2, ...]
|
||||
}]
|
||||
"""
|
||||
ret_tools = []
|
||||
# 创建字典快照避免并发修改
|
||||
running_plugins_snapshot = dict(self._running_plugins)
|
||||
for plugin_id, plugin in running_plugins_snapshot.items():
|
||||
if pid and pid != plugin_id:
|
||||
continue
|
||||
if hasattr(plugin, "get_agent_tools") and ObjectUtils.check_method(plugin.get_agent_tools):
|
||||
try:
|
||||
if not plugin.get_state():
|
||||
continue
|
||||
tools = plugin.get_agent_tools()
|
||||
if tools:
|
||||
ret_tools.append({
|
||||
"plugin_id": plugin_id,
|
||||
"plugin_name": plugin.plugin_name,
|
||||
"tools": tools
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"获取插件 {plugin_id} 智能体工具出错:{str(e)}")
|
||||
return ret_tools
|
||||
|
||||
@staticmethod
|
||||
def get_plugin_remote_entry(plugin_id: str, dist_path: str) -> str:
|
||||
"""
|
||||
@@ -801,6 +914,25 @@ class PluginManager(metaclass=Singleton):
|
||||
return None
|
||||
return getattr(plugin, method)(*args, **kwargs)
|
||||
|
||||
async def async_run_plugin_method(self, pid: str, method: str, *args, **kwargs) -> Any:
|
||||
"""
|
||||
异步运行插件方法
|
||||
:param pid: 插件ID
|
||||
:param method: 方法名
|
||||
:param args: 参数
|
||||
:param kwargs: 关键字参数
|
||||
"""
|
||||
plugin = self._running_plugins.get(pid)
|
||||
if not plugin:
|
||||
return None
|
||||
if not hasattr(plugin, method):
|
||||
return None
|
||||
method_func = getattr(plugin, method)
|
||||
if asyncio.iscoroutinefunction(method_func):
|
||||
return await method_func(*args, **kwargs)
|
||||
else:
|
||||
return method_func(*args, **kwargs)
|
||||
|
||||
def get_plugin_ids(self) -> List[str]:
|
||||
"""
|
||||
获取所有插件ID
|
||||
@@ -820,8 +952,6 @@ class PluginManager(metaclass=Singleton):
|
||||
if not settings.PLUGIN_MARKET:
|
||||
return []
|
||||
|
||||
# 返回值
|
||||
all_plugins = []
|
||||
# 用于存储高于 v1 版本的插件(如 v2, v3 等)
|
||||
higher_version_plugins = []
|
||||
# 用于存储 v1 版本插件
|
||||
@@ -854,25 +984,7 @@ class PluginManager(metaclass=Singleton):
|
||||
else:
|
||||
base_version_plugins.extend(plugins) # 收集 v1 版本插件
|
||||
|
||||
# 优先处理高版本插件
|
||||
all_plugins.extend(higher_version_plugins)
|
||||
# 将未出现在高版本插件列表中的 v1 插件加入 all_plugins
|
||||
higher_plugin_ids = {f"{p.id}{p.plugin_version}" for p in higher_version_plugins}
|
||||
all_plugins.extend([p for p in base_version_plugins if f"{p.id}{p.plugin_version}" not in higher_plugin_ids])
|
||||
# 去重
|
||||
all_plugins = list({f"{p.id}{p.plugin_version}": p for p in all_plugins}.values())
|
||||
# 所有插件按 repo 在设置中的顺序排序
|
||||
all_plugins.sort(
|
||||
key=lambda x: settings.PLUGIN_MARKET.split(",").index(x.repo_url) if x.repo_url else 0
|
||||
)
|
||||
# 相同 ID 的插件保留版本号最大的版本
|
||||
max_versions = {}
|
||||
for p in all_plugins:
|
||||
if p.id not in max_versions or StringUtils.compare_version(p.plugin_version, ">", max_versions[p.id]):
|
||||
max_versions[p.id] = p.plugin_version
|
||||
result = [p for p in all_plugins if p.plugin_version == max_versions[p.id]]
|
||||
logger.info(f"共获取到 {len(result)} 个线上插件")
|
||||
return result
|
||||
return self._process_plugins_list(higher_version_plugins, base_version_plugins)
|
||||
|
||||
def get_local_plugins(self) -> List[schemas.Plugin]:
|
||||
"""
|
||||
@@ -994,7 +1106,8 @@ class PluginManager(metaclass=Singleton):
|
||||
# 已安装插件
|
||||
installed_apps = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or []
|
||||
# 获取在线插件
|
||||
online_plugins = PluginHelper().get_plugins(market, package_version, force)
|
||||
with fresh(force):
|
||||
online_plugins = PluginHelper().get_plugins(market, package_version)
|
||||
if online_plugins is None:
|
||||
logger.warning(
|
||||
f"获取{package_version if package_version else ''}插件库失败:{market},请检查 GitHub 网络连接")
|
||||
@@ -1002,81 +1115,217 @@ class PluginManager(metaclass=Singleton):
|
||||
ret_plugins = []
|
||||
add_time = len(online_plugins)
|
||||
for pid, plugin_info in online_plugins.items():
|
||||
# 如 package_version 为空,则需要判断插件是否兼容当前版本
|
||||
if not package_version:
|
||||
if plugin_info.get(settings.VERSION_FLAG) is not True:
|
||||
# 插件当前版本不兼容
|
||||
continue
|
||||
# 运行状插件
|
||||
plugin_obj = self._running_plugins.get(pid)
|
||||
# 非运行态插件
|
||||
plugin_static = self._plugins.get(pid)
|
||||
# 基本属性
|
||||
plugin = schemas.Plugin()
|
||||
# ID
|
||||
plugin.id = pid
|
||||
# 安装状态
|
||||
if pid in installed_apps and plugin_static:
|
||||
plugin.installed = True
|
||||
else:
|
||||
plugin.installed = False
|
||||
# 是否有新版本
|
||||
plugin.has_update = False
|
||||
if plugin_static:
|
||||
installed_version = getattr(plugin_static, "plugin_version")
|
||||
if StringUtils.compare_version(installed_version, "<", plugin_info.get("version")):
|
||||
# 需要更新
|
||||
plugin.has_update = True
|
||||
# 运行状态
|
||||
if plugin_obj and hasattr(plugin_obj, "get_state"):
|
||||
try:
|
||||
state = plugin_obj.get_state()
|
||||
except Exception as e:
|
||||
logger.error(f"获取插件 {pid} 状态出错:{str(e)}")
|
||||
state = False
|
||||
plugin.state = state
|
||||
else:
|
||||
plugin.state = False
|
||||
# 是否有详情页面
|
||||
plugin.has_page = False
|
||||
if plugin_obj and hasattr(plugin_obj, "get_page"):
|
||||
if ObjectUtils.check_method(plugin_obj.get_page):
|
||||
plugin.has_page = True
|
||||
# 公钥
|
||||
if plugin_info.get("key"):
|
||||
plugin.plugin_public_key = plugin_info.get("key")
|
||||
# 权限
|
||||
if not self.__set_and_check_auth_level(plugin=plugin, source=plugin_info):
|
||||
plugin = self._process_plugin_info(pid, plugin_info, market, installed_apps, add_time, package_version)
|
||||
if plugin:
|
||||
ret_plugins.append(plugin)
|
||||
add_time -= 1
|
||||
|
||||
return ret_plugins
|
||||
|
||||
@staticmethod
|
||||
def _process_plugins_list(higher_version_plugins: List[schemas.Plugin],
|
||||
base_version_plugins: List[schemas.Plugin]) -> List[schemas.Plugin]:
|
||||
"""
|
||||
处理插件列表:合并、去重、排序、保留最高版本
|
||||
:param higher_version_plugins: 高版本插件列表
|
||||
:param base_version_plugins: 基础版本插件列表
|
||||
:return: 处理后的插件列表
|
||||
"""
|
||||
# 优先处理高版本插件
|
||||
all_plugins = []
|
||||
all_plugins.extend(higher_version_plugins)
|
||||
# 将未出现在高版本插件列表中的 v1 插件加入 all_plugins
|
||||
higher_plugin_ids = {f"{p.id}{p.plugin_version}" for p in higher_version_plugins}
|
||||
all_plugins.extend([p for p in base_version_plugins if f"{p.id}{p.plugin_version}" not in higher_plugin_ids])
|
||||
# 去重
|
||||
all_plugins = list({f"{p.id}{p.plugin_version}": p for p in all_plugins}.values())
|
||||
# 所有插件按 repo 在设置中的顺序排序
|
||||
all_plugins.sort(
|
||||
key=lambda x: settings.PLUGIN_MARKET.split(",").index(x.repo_url) if x.repo_url else 0
|
||||
)
|
||||
# 相同 ID 的插件保留版本号最大的版本
|
||||
max_versions = {}
|
||||
for p in all_plugins:
|
||||
if p.id not in max_versions or StringUtils.compare_version(p.plugin_version, ">", max_versions[p.id]):
|
||||
max_versions[p.id] = p.plugin_version
|
||||
result = [p for p in all_plugins if p.plugin_version == max_versions[p.id]]
|
||||
logger.info(f"共获取到 {len(result)} 个线上插件")
|
||||
return result
|
||||
|
||||
def _process_plugin_info(self, pid: str, plugin_info: dict, market: str,
|
||||
installed_apps: List[str], add_time: int,
|
||||
package_version: Optional[str] = None) -> Optional[schemas.Plugin]:
|
||||
"""
|
||||
处理单个插件信息,创建 schemas.Plugin 对象
|
||||
:param pid: 插件ID
|
||||
:param plugin_info: 插件信息字典
|
||||
:param market: 市场URL
|
||||
:param installed_apps: 已安装插件列表
|
||||
:param add_time: 添加顺序
|
||||
:param package_version: 包版本
|
||||
:return: 创建的插件对象,如果验证失败返回None
|
||||
"""
|
||||
if not isinstance(plugin_info, dict):
|
||||
return None
|
||||
|
||||
# 如 package_version 为空,则需要判断插件是否兼容当前版本
|
||||
if not package_version:
|
||||
if plugin_info.get(settings.VERSION_FLAG) is not True:
|
||||
# 插件当前版本不兼容
|
||||
return None
|
||||
|
||||
# 运行状插件
|
||||
plugin_obj = self._running_plugins.get(pid)
|
||||
# 非运行态插件
|
||||
plugin_static = self._plugins.get(pid)
|
||||
# 基本属性
|
||||
plugin = schemas.Plugin()
|
||||
# ID
|
||||
plugin.id = pid
|
||||
# 安装状态
|
||||
if pid in installed_apps and plugin_static:
|
||||
plugin.installed = True
|
||||
else:
|
||||
plugin.installed = False
|
||||
# 是否有新版本
|
||||
plugin.has_update = False
|
||||
if plugin_static:
|
||||
installed_version = getattr(plugin_static, "plugin_version")
|
||||
if StringUtils.compare_version(installed_version, "<", plugin_info.get("version")):
|
||||
# 需要更新
|
||||
plugin.has_update = True
|
||||
# 运行状态
|
||||
if plugin_obj and hasattr(plugin_obj, "get_state"):
|
||||
try:
|
||||
state = plugin_obj.get_state()
|
||||
except Exception as e:
|
||||
logger.error(f"获取插件 {pid} 状态出错:{str(e)}")
|
||||
state = False
|
||||
plugin.state = state
|
||||
else:
|
||||
plugin.state = False
|
||||
# 是否有详情页面
|
||||
plugin.has_page = False
|
||||
if plugin_obj and hasattr(plugin_obj, "get_page"):
|
||||
if ObjectUtils.check_method(plugin_obj.get_page):
|
||||
plugin.has_page = True
|
||||
# 公钥
|
||||
if plugin_info.get("key"):
|
||||
plugin.plugin_public_key = plugin_info.get("key")
|
||||
# 权限
|
||||
if not self.__set_and_check_auth_level(plugin=plugin, source=plugin_info):
|
||||
return None
|
||||
# 名称
|
||||
if plugin_info.get("name"):
|
||||
plugin.plugin_name = plugin_info.get("name")
|
||||
# 描述
|
||||
if plugin_info.get("description"):
|
||||
plugin.plugin_desc = plugin_info.get("description")
|
||||
# 版本
|
||||
if plugin_info.get("version"):
|
||||
plugin.plugin_version = plugin_info.get("version")
|
||||
# 图标
|
||||
if plugin_info.get("icon"):
|
||||
plugin.plugin_icon = plugin_info.get("icon")
|
||||
# 标签
|
||||
if plugin_info.get("labels"):
|
||||
plugin.plugin_label = plugin_info.get("labels")
|
||||
# 作者
|
||||
if plugin_info.get("author"):
|
||||
plugin.plugin_author = plugin_info.get("author")
|
||||
# 更新历史
|
||||
if plugin_info.get("history"):
|
||||
plugin.history = plugin_info.get("history")
|
||||
# 仓库链接
|
||||
plugin.repo_url = market
|
||||
# 本地标志
|
||||
plugin.is_local = False
|
||||
# 添加顺序
|
||||
plugin.add_time = add_time
|
||||
|
||||
return plugin
|
||||
|
||||
async def async_get_online_plugins(self, force: bool = False) -> List[schemas.Plugin]:
|
||||
"""
|
||||
异步获取所有在线插件信息
|
||||
:param force: 是否强制刷新(忽略缓存)
|
||||
"""
|
||||
if not settings.PLUGIN_MARKET:
|
||||
return []
|
||||
|
||||
# 用于存储高于 v1 版本的插件(如 v2, v3 等)
|
||||
higher_version_plugins = []
|
||||
# 用于存储 v1 版本插件
|
||||
base_version_plugins = []
|
||||
|
||||
# 使用异步并发获取线上插件
|
||||
import asyncio
|
||||
tasks = []
|
||||
task_to_version = {}
|
||||
|
||||
for m in settings.PLUGIN_MARKET.split(","):
|
||||
if not m:
|
||||
continue
|
||||
# 名称
|
||||
if plugin_info.get("name"):
|
||||
plugin.plugin_name = plugin_info.get("name")
|
||||
# 描述
|
||||
if plugin_info.get("description"):
|
||||
plugin.plugin_desc = plugin_info.get("description")
|
||||
# 版本
|
||||
if plugin_info.get("version"):
|
||||
plugin.plugin_version = plugin_info.get("version")
|
||||
# 图标
|
||||
if plugin_info.get("icon"):
|
||||
plugin.plugin_icon = plugin_info.get("icon")
|
||||
# 标签
|
||||
if plugin_info.get("labels"):
|
||||
plugin.plugin_label = plugin_info.get("labels")
|
||||
# 作者
|
||||
if plugin_info.get("author"):
|
||||
plugin.plugin_author = plugin_info.get("author")
|
||||
# 更新历史
|
||||
if plugin_info.get("history"):
|
||||
plugin.history = plugin_info.get("history")
|
||||
# 仓库链接
|
||||
plugin.repo_url = market
|
||||
# 本地标志
|
||||
plugin.is_local = False
|
||||
# 添加顺序
|
||||
plugin.add_time = add_time
|
||||
# 汇总
|
||||
ret_plugins.append(plugin)
|
||||
# 创建任务获取 v1 版本插件
|
||||
base_task = asyncio.create_task(self.async_get_plugins_from_market(m, None, force))
|
||||
tasks.append(base_task)
|
||||
task_to_version[base_task] = "base_version"
|
||||
|
||||
# 创建任务获取高版本插件(如 v2、v3)
|
||||
if settings.VERSION_FLAG:
|
||||
higher_version_task = asyncio.create_task(
|
||||
self.async_get_plugins_from_market(m, settings.VERSION_FLAG, force))
|
||||
tasks.append(higher_version_task)
|
||||
task_to_version[higher_version_task] = "higher_version"
|
||||
|
||||
# 并发执行所有任务
|
||||
if tasks:
|
||||
completed_tasks = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
for i, result in enumerate(completed_tasks):
|
||||
task = tasks[i]
|
||||
version = task_to_version[task]
|
||||
|
||||
# 检查是否有异常
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"获取插件市场数据失败:{str(result)}")
|
||||
continue
|
||||
|
||||
plugins = result
|
||||
if plugins:
|
||||
if version == "higher_version":
|
||||
higher_version_plugins.extend(plugins) # 收集高版本插件
|
||||
else:
|
||||
base_version_plugins.extend(plugins) # 收集 v1 版本插件
|
||||
|
||||
return self._process_plugins_list(higher_version_plugins, base_version_plugins)
|
||||
|
||||
async def async_get_plugins_from_market(self, market: str,
|
||||
package_version: Optional[str] = None,
|
||||
force: bool = False) -> Optional[List[schemas.Plugin]]:
|
||||
"""
|
||||
异步从指定的市场获取插件信息
|
||||
:param market: 市场的 URL 或标识
|
||||
:param package_version: 首选插件版本 (如 "v2", "v3"),如果不指定则获取 v1 版本
|
||||
:param force: 是否强制刷新(忽略缓存)
|
||||
:return: 返回插件的列表,若获取失败返回 []
|
||||
"""
|
||||
if not market:
|
||||
return []
|
||||
# 已安装插件
|
||||
installed_apps = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or []
|
||||
# 获取在线插件
|
||||
async with async_fresh(force):
|
||||
online_plugins = await PluginHelper().async_get_plugins(market, package_version)
|
||||
if online_plugins is None:
|
||||
logger.warning(
|
||||
f"获取{package_version if package_version else ''}插件库失败:{market},请检查 GitHub 网络连接")
|
||||
return []
|
||||
ret_plugins = []
|
||||
add_time = len(online_plugins)
|
||||
for pid, plugin_info in online_plugins.items():
|
||||
plugin = self._process_plugin_info(pid, plugin_info, market, installed_apps, add_time, package_version)
|
||||
if plugin:
|
||||
ret_plugins.append(plugin)
|
||||
add_time -= 1
|
||||
|
||||
return ret_plugins
|
||||
@@ -1416,8 +1665,9 @@ class PluginManager(metaclass=Singleton):
|
||||
content = f.read()
|
||||
|
||||
# 替换CSS中可能的类名引用
|
||||
content = content.replace(original_class_name.lower(), clone_class_name.lower())
|
||||
content = content.replace(original_class_name, clone_class_name)
|
||||
content = content.replace(original_class_name.lower(),
|
||||
clone_class_name.lower()).replace(original_class_name,
|
||||
clone_class_name)
|
||||
|
||||
with open(file_path, 'w', encoding='utf-8') as f:
|
||||
f.write(content)
|
||||
|
||||
@@ -252,19 +252,19 @@ def __verify_key(key: str, expected_key: str, key_type: str) -> str:
|
||||
def verify_apitoken(token: Annotated[str, Security(__get_api_token)]) -> str:
|
||||
"""
|
||||
使用 API Token 进行身份认证
|
||||
:param token: API Token,从 URL 查询参数中获取
|
||||
:param token: API Token,从 URL 查询参数中获取 token=xxx
|
||||
:return: 返回校验通过的 API Token
|
||||
"""
|
||||
return __verify_key(token, settings.API_TOKEN, "API_TOKEN")
|
||||
return __verify_key(token, settings.API_TOKEN, "token")
|
||||
|
||||
|
||||
def verify_apikey(apikey: Annotated[str, Security(__get_api_key)]) -> str:
|
||||
"""
|
||||
使用 API Key 进行身份认证
|
||||
:param apikey: API Key,从 URL 查询参数或请求头中获取
|
||||
:param apikey: API Key,从 URL 查询参数中获取 apikey=xxx
|
||||
:return: 返回校验通过的 API Key
|
||||
"""
|
||||
return __verify_key(apikey, settings.API_TOKEN, "API_KEY")
|
||||
return __verify_key(apikey, settings.API_TOKEN, "apikey")
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
|
||||
@@ -1,112 +0,0 @@
|
||||
from time import sleep
|
||||
from typing import Dict, Any, Tuple, List
|
||||
|
||||
from app.core.config import global_vars
|
||||
from app.helper.module import ModuleHelper
|
||||
from app.log import logger
|
||||
from app.schemas import Action, ActionContext
|
||||
from app.utils.singleton import Singleton
|
||||
|
||||
|
||||
class WorkFlowManager(metaclass=Singleton):
|
||||
"""
|
||||
工作流管理器
|
||||
"""
|
||||
|
||||
# 所有动作定义
|
||||
_actions: Dict[str, Any] = {}
|
||||
|
||||
def __init__(self):
|
||||
self.init()
|
||||
|
||||
def init(self):
|
||||
"""
|
||||
初始化
|
||||
"""
|
||||
|
||||
def filter_func(obj: Any):
|
||||
"""
|
||||
过滤函数,确保只加载新定义的类
|
||||
"""
|
||||
if not isinstance(obj, type):
|
||||
return False
|
||||
if not hasattr(obj, 'execute') or not hasattr(obj, "name"):
|
||||
return False
|
||||
if obj.__name__ == "BaseAction":
|
||||
return False
|
||||
return obj.__module__.startswith("app.actions")
|
||||
|
||||
# 加载所有动作
|
||||
self._actions = {}
|
||||
actions = ModuleHelper.load(
|
||||
"app.actions",
|
||||
filter_func=lambda _, obj: filter_func(obj)
|
||||
)
|
||||
for action in actions:
|
||||
logger.debug(f"加载动作: {action.__name__}")
|
||||
try:
|
||||
self._actions[action.__name__] = action
|
||||
except Exception as err:
|
||||
logger.error(f"加载动作失败: {action.__name__} - {err}")
|
||||
|
||||
def stop(self):
|
||||
"""
|
||||
停止
|
||||
"""
|
||||
pass
|
||||
|
||||
def excute(self, workflow_id: int, action: Action,
|
||||
context: ActionContext = None) -> Tuple[bool, str, ActionContext]:
|
||||
"""
|
||||
执行工作流动作
|
||||
"""
|
||||
if not context:
|
||||
context = ActionContext()
|
||||
if action.type in self._actions:
|
||||
# 实例化之前,清理掉类对象的数据
|
||||
|
||||
# 实例化
|
||||
action_obj = self._actions[action.type](action.id)
|
||||
# 执行
|
||||
logger.info(f"执行动作: {action.id} - {action.name}")
|
||||
try:
|
||||
result_context = action_obj.execute(workflow_id, action.data, context)
|
||||
except Exception as err:
|
||||
logger.error(f"{action.name} 执行失败: {err}")
|
||||
return False, f"{err}", context
|
||||
loop = action.data.get("loop")
|
||||
loop_interval = action.data.get("loop_interval")
|
||||
if loop and loop_interval:
|
||||
while not action_obj.done:
|
||||
if global_vars.is_workflow_stopped(workflow_id):
|
||||
break
|
||||
# 等待
|
||||
logger.info(f"{action.name} 等待 {loop_interval} 秒后继续执行 ...")
|
||||
sleep(loop_interval)
|
||||
# 执行
|
||||
logger.info(f"继续执行动作: {action.id} - {action.name}")
|
||||
result_context = action_obj.execute(workflow_id, action.data, result_context)
|
||||
if action_obj.success:
|
||||
logger.info(f"{action.name} 执行成功")
|
||||
else:
|
||||
logger.error(f"{action.name} 执行失败!")
|
||||
return action_obj.success, action_obj.message, result_context
|
||||
else:
|
||||
logger.error(f"未找到动作: {action.type} - {action.name}")
|
||||
return False, " ", context
|
||||
|
||||
def list_actions(self) -> List[dict]:
|
||||
"""
|
||||
获取所有动作
|
||||
"""
|
||||
return [
|
||||
{
|
||||
"type": key,
|
||||
"name": action.name,
|
||||
"description": action.description,
|
||||
"data": {
|
||||
"label": action.name,
|
||||
**action.data
|
||||
}
|
||||
} for key, action in self._actions.items()
|
||||
]
|
||||
@@ -1,45 +1,191 @@
|
||||
from typing import Any, Generator, List, Optional, Self, Tuple
|
||||
import asyncio
|
||||
from typing import Any, Generator, List, Optional, Self, Tuple, AsyncGenerator, Union
|
||||
|
||||
from sqlalchemy import NullPool, QueuePool, and_, create_engine, inspect, text
|
||||
from sqlalchemy import NullPool, QueuePool, and_, create_engine, inspect, text, select, delete, Column, Integer, \
|
||||
Sequence, Identity
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
||||
from sqlalchemy.orm import Session, as_declarative, declared_attr, scoped_session, sessionmaker
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
# 根据池类型设置 poolclass 和相关参数
|
||||
pool_class = NullPool if settings.DB_POOL_TYPE == "NullPool" else QueuePool
|
||||
connect_args = {
|
||||
"timeout": settings.DB_TIMEOUT
|
||||
}
|
||||
# 启用 WAL 模式时的额外配置
|
||||
if settings.DB_WAL_ENABLE:
|
||||
connect_args["check_same_thread"] = False
|
||||
db_kwargs = {
|
||||
"url": f"sqlite:///{settings.CONFIG_PATH}/user.db",
|
||||
"pool_pre_ping": settings.DB_POOL_PRE_PING,
|
||||
"echo": settings.DB_ECHO,
|
||||
"poolclass": pool_class,
|
||||
"pool_recycle": settings.DB_POOL_RECYCLE,
|
||||
"connect_args": connect_args
|
||||
}
|
||||
# 当使用 QueuePool 时,添加 QueuePool 特有的参数
|
||||
if pool_class == QueuePool:
|
||||
db_kwargs.update({
|
||||
"pool_size": settings.CONF.dbpool,
|
||||
"pool_timeout": settings.DB_POOL_TIMEOUT,
|
||||
"max_overflow": settings.CONF.dbpooloverflow
|
||||
})
|
||||
# 创建数据库引擎
|
||||
Engine = create_engine(**db_kwargs)
|
||||
# 根据配置设置日志模式
|
||||
journal_mode = "WAL" if settings.DB_WAL_ENABLE else "DELETE"
|
||||
with Engine.connect() as connection:
|
||||
current_mode = connection.execute(text(f"PRAGMA journal_mode={journal_mode};")).scalar()
|
||||
print(f"Database journal mode set to: {current_mode}")
|
||||
|
||||
# 会话工厂
|
||||
def get_id_column():
|
||||
"""
|
||||
根据数据库类型返回合适的ID列定义
|
||||
"""
|
||||
if settings.DB_TYPE.lower() == "postgresql":
|
||||
# PostgreSQL使用SERIAL类型,让数据库自动处理序列
|
||||
return Column(Integer, Identity(start=1, cycle=True), primary_key=True, index=True)
|
||||
else:
|
||||
# SQLite使用Sequence
|
||||
return Column(Integer, Sequence('id'), primary_key=True, index=True)
|
||||
|
||||
|
||||
def _get_database_engine(is_async: bool = False):
|
||||
"""
|
||||
获取数据库连接参数并设置WAL模式
|
||||
:param is_async: 是否创建异步引擎,True - 异步引擎, False - 同步引擎
|
||||
:return: 返回对应的数据库引擎
|
||||
"""
|
||||
# 根据数据库类型选择连接方式
|
||||
if settings.DB_TYPE.lower() == "postgresql":
|
||||
return _get_postgresql_engine(is_async)
|
||||
else:
|
||||
return _get_sqlite_engine(is_async)
|
||||
|
||||
|
||||
def _get_sqlite_engine(is_async: bool = False):
|
||||
"""
|
||||
获取SQLite数据库引擎
|
||||
"""
|
||||
# 连接参数
|
||||
_connect_args = {
|
||||
"timeout": settings.DB_TIMEOUT,
|
||||
}
|
||||
# 启用 WAL 模式时的额外配置
|
||||
if settings.DB_WAL_ENABLE:
|
||||
_connect_args["check_same_thread"] = False
|
||||
|
||||
# 创建同步引擎
|
||||
if not is_async:
|
||||
# 根据池类型设置 poolclass 和相关参数
|
||||
_pool_class = NullPool if settings.DB_POOL_TYPE == "NullPool" else QueuePool
|
||||
|
||||
# 数据库参数
|
||||
_db_kwargs = {
|
||||
"url": f"sqlite:///{settings.CONFIG_PATH}/user.db",
|
||||
"pool_pre_ping": settings.DB_POOL_PRE_PING,
|
||||
"echo": settings.DB_ECHO,
|
||||
"poolclass": _pool_class,
|
||||
"pool_recycle": settings.DB_POOL_RECYCLE,
|
||||
"connect_args": _connect_args
|
||||
}
|
||||
|
||||
# 当使用 QueuePool 时,添加 QueuePool 特有的参数
|
||||
if _pool_class == QueuePool:
|
||||
_db_kwargs.update({
|
||||
"pool_size": settings.DB_SQLITE_POOL_SIZE,
|
||||
"pool_timeout": settings.DB_POOL_TIMEOUT,
|
||||
"max_overflow": settings.DB_SQLITE_MAX_OVERFLOW
|
||||
})
|
||||
|
||||
# 创建数据库引擎
|
||||
engine = create_engine(**_db_kwargs)
|
||||
|
||||
# 设置WAL模式
|
||||
_journal_mode = "WAL" if settings.DB_WAL_ENABLE else "DELETE"
|
||||
with engine.connect() as connection:
|
||||
current_mode = connection.execute(text(f"PRAGMA journal_mode={_journal_mode};")).scalar()
|
||||
print(f"SQLite database journal mode set to: {current_mode}")
|
||||
|
||||
return engine
|
||||
else:
|
||||
# 数据库参数,只能使用 NullPool
|
||||
_db_kwargs = {
|
||||
"url": f"sqlite+aiosqlite:///{settings.CONFIG_PATH}/user.db",
|
||||
"pool_pre_ping": settings.DB_POOL_PRE_PING,
|
||||
"echo": settings.DB_ECHO,
|
||||
"poolclass": NullPool,
|
||||
"pool_recycle": settings.DB_POOL_RECYCLE,
|
||||
"connect_args": _connect_args
|
||||
}
|
||||
# 创建异步数据库引擎
|
||||
async_engine = create_async_engine(**_db_kwargs)
|
||||
|
||||
# 设置WAL模式
|
||||
_journal_mode = "WAL" if settings.DB_WAL_ENABLE else "DELETE"
|
||||
|
||||
async def set_async_wal_mode():
|
||||
"""
|
||||
设置异步引擎的WAL模式
|
||||
"""
|
||||
async with async_engine.connect() as _connection:
|
||||
result = await _connection.execute(text(f"PRAGMA journal_mode={_journal_mode};"))
|
||||
_current_mode = result.scalar()
|
||||
print(f"Async SQLite database journal mode set to: {_current_mode}")
|
||||
|
||||
try:
|
||||
asyncio.run(set_async_wal_mode())
|
||||
except Exception as e:
|
||||
print(f"Failed to set async SQLite WAL mode: {e}")
|
||||
|
||||
return async_engine
|
||||
|
||||
|
||||
def _get_postgresql_engine(is_async: bool = False):
|
||||
"""
|
||||
获取PostgreSQL数据库引擎
|
||||
"""
|
||||
# 构建PostgreSQL连接URL
|
||||
if settings.DB_POSTGRESQL_PASSWORD:
|
||||
db_url = f"postgresql://{settings.DB_POSTGRESQL_USERNAME}:{settings.DB_POSTGRESQL_PASSWORD}@{settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE}"
|
||||
else:
|
||||
db_url = f"postgresql://{settings.DB_POSTGRESQL_USERNAME}@{settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE}"
|
||||
|
||||
# PostgreSQL连接参数
|
||||
_connect_args = {}
|
||||
|
||||
# 创建同步引擎
|
||||
if not is_async:
|
||||
# 根据池类型设置 poolclass 和相关参数
|
||||
_pool_class = NullPool if settings.DB_POOL_TYPE == "NullPool" else QueuePool
|
||||
|
||||
# 数据库参数
|
||||
_db_kwargs = {
|
||||
"url": db_url,
|
||||
"pool_pre_ping": settings.DB_POOL_PRE_PING,
|
||||
"echo": settings.DB_ECHO,
|
||||
"poolclass": _pool_class,
|
||||
"pool_recycle": settings.DB_POOL_RECYCLE,
|
||||
"connect_args": _connect_args
|
||||
}
|
||||
|
||||
# 当使用 QueuePool 时,添加 QueuePool 特有的参数
|
||||
if _pool_class == QueuePool:
|
||||
_db_kwargs.update({
|
||||
"pool_size": settings.DB_POSTGRESQL_POOL_SIZE,
|
||||
"pool_timeout": settings.DB_POOL_TIMEOUT,
|
||||
"max_overflow": settings.DB_POSTGRESQL_MAX_OVERFLOW
|
||||
})
|
||||
|
||||
# 创建数据库引擎
|
||||
engine = create_engine(**_db_kwargs)
|
||||
print(f"PostgreSQL database connected to {settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE}")
|
||||
|
||||
return engine
|
||||
else:
|
||||
# 构建异步PostgreSQL连接URL
|
||||
async_db_url = f"postgresql+asyncpg://{settings.DB_POSTGRESQL_USERNAME}:{settings.DB_POSTGRESQL_PASSWORD}@{settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE}"
|
||||
|
||||
# 数据库参数,只能使用 NullPool
|
||||
_db_kwargs = {
|
||||
"url": async_db_url,
|
||||
"pool_pre_ping": settings.DB_POOL_PRE_PING,
|
||||
"echo": settings.DB_ECHO,
|
||||
"poolclass": NullPool,
|
||||
"pool_recycle": settings.DB_POOL_RECYCLE,
|
||||
"connect_args": _connect_args
|
||||
}
|
||||
# 创建异步数据库引擎
|
||||
async_engine = create_async_engine(**_db_kwargs)
|
||||
print(f"Async PostgreSQL database connected to {settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE}")
|
||||
|
||||
return async_engine
|
||||
|
||||
|
||||
# 同步数据库引擎
|
||||
Engine = _get_database_engine(is_async=False)
|
||||
|
||||
# 异步数据库引擎
|
||||
AsyncEngine = _get_database_engine(is_async=True)
|
||||
|
||||
# 同步会话工厂
|
||||
SessionFactory = sessionmaker(bind=Engine)
|
||||
|
||||
# 多线程全局使用的数据库会话
|
||||
# 异步会话工厂
|
||||
AsyncSessionFactory = async_sessionmaker(bind=AsyncEngine, class_=AsyncSession)
|
||||
|
||||
# 同步多线程全局使用的数据库会话
|
||||
ScopedSession = scoped_session(SessionFactory)
|
||||
|
||||
|
||||
@@ -57,37 +203,32 @@ def get_db() -> Generator:
|
||||
db.close()
|
||||
|
||||
|
||||
def perform_checkpoint(mode: str = "PASSIVE"):
|
||||
async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""
|
||||
执行 SQLite 的 checkpoint 操作,将 WAL 文件内容写回主数据库
|
||||
:param mode: checkpoint 模式,可选值包括 "PASSIVE"、"FULL"、"RESTART"、"TRUNCATE"
|
||||
默认为 "PASSIVE",即不锁定 WAL 文件的轻量级同步
|
||||
获取异步数据库会话,用于WEB请求
|
||||
:return: AsyncSession
|
||||
"""
|
||||
if not settings.DB_WAL_ENABLE:
|
||||
return
|
||||
valid_modes = {"PASSIVE", "FULL", "RESTART", "TRUNCATE"}
|
||||
if mode.upper() not in valid_modes:
|
||||
raise ValueError(f"Invalid checkpoint mode '{mode}'. Must be one of {valid_modes}")
|
||||
try:
|
||||
# 使用指定的 checkpoint 模式,确保 WAL 文件数据被正确写回主数据库
|
||||
with Engine.connect() as conn:
|
||||
conn.execute(text(f"PRAGMA wal_checkpoint({mode.upper()});"))
|
||||
except Exception as e:
|
||||
print(f"Error during WAL checkpoint: {e}")
|
||||
async with AsyncSessionFactory() as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
def close_database():
|
||||
async def close_database():
|
||||
"""
|
||||
关闭所有数据库连接并清理资源
|
||||
"""
|
||||
try:
|
||||
# 释放连接池,SQLite 会自动清空 WAL 文件,这里不单独再调用 checkpoint
|
||||
Engine.dispose()
|
||||
except Exception as e:
|
||||
print(f"Error while disposing database connections: {e}")
|
||||
# 释放同步连接池
|
||||
Engine.dispose() # noqa
|
||||
# 释放异步连接池
|
||||
await AsyncEngine.dispose()
|
||||
except Exception as err:
|
||||
print(f"Error while disposing database connections: {err}")
|
||||
|
||||
|
||||
def get_args_db(args: tuple, kwargs: dict) -> Optional[Session]:
|
||||
def _get_args_db(args: tuple, kwargs: dict) -> Optional[Session]:
|
||||
"""
|
||||
从参数中获取数据库Session对象
|
||||
"""
|
||||
@@ -105,7 +246,25 @@ def get_args_db(args: tuple, kwargs: dict) -> Optional[Session]:
|
||||
return db
|
||||
|
||||
|
||||
def update_args_db(args: tuple, kwargs: dict, db: Session) -> Tuple[tuple, dict]:
|
||||
def _get_args_async_db(args: tuple, kwargs: dict) -> Optional[AsyncSession]:
|
||||
"""
|
||||
从参数中获取异步数据库AsyncSession对象
|
||||
"""
|
||||
db = None
|
||||
if args:
|
||||
for arg in args:
|
||||
if isinstance(arg, AsyncSession):
|
||||
db = arg
|
||||
break
|
||||
if kwargs:
|
||||
for key, value in kwargs.items():
|
||||
if isinstance(value, AsyncSession):
|
||||
db = value
|
||||
break
|
||||
return db
|
||||
|
||||
|
||||
def _update_args_db(args: tuple, kwargs: dict, db: Session) -> Tuple[tuple, dict]:
|
||||
"""
|
||||
更新参数中的数据库Session对象,关键字传参时更新db的值,否则更新第1或第2个参数
|
||||
"""
|
||||
@@ -119,6 +278,20 @@ def update_args_db(args: tuple, kwargs: dict, db: Session) -> Tuple[tuple, dict]
|
||||
return args, kwargs
|
||||
|
||||
|
||||
def _update_args_async_db(args: tuple, kwargs: dict, db: AsyncSession) -> Tuple[tuple, dict]:
|
||||
"""
|
||||
更新参数中的异步数据库AsyncSession对象,关键字传参时更新db的值,否则更新第1或第2个参数
|
||||
"""
|
||||
if kwargs and 'db' in kwargs:
|
||||
kwargs['db'] = db
|
||||
elif args:
|
||||
if args[0] is None:
|
||||
args = (db, *args[1:])
|
||||
else:
|
||||
args = (args[0], db, *args[2:])
|
||||
return args, kwargs
|
||||
|
||||
|
||||
def db_update(func):
|
||||
"""
|
||||
数据库更新类操作装饰器,第一个参数必须是数据库会话或存在db参数
|
||||
@@ -128,14 +301,14 @@ def db_update(func):
|
||||
# 是否关闭数据库会话
|
||||
_close_db = False
|
||||
# 从参数中获取数据库会话
|
||||
db = get_args_db(args, kwargs)
|
||||
db = _get_args_db(args, kwargs)
|
||||
if not db:
|
||||
# 如果没有获取到数据库会话,创建一个
|
||||
db = ScopedSession()
|
||||
# 标记需要关闭数据库会话
|
||||
_close_db = True
|
||||
# 更新参数中的数据库会话
|
||||
args, kwargs = update_args_db(args, kwargs, db)
|
||||
args, kwargs = _update_args_db(args, kwargs, db)
|
||||
try:
|
||||
# 执行函数
|
||||
result = func(*args, **kwargs)
|
||||
@@ -154,6 +327,41 @@ def db_update(func):
|
||||
return wrapper
|
||||
|
||||
|
||||
def async_db_update(func):
|
||||
"""
|
||||
异步数据库更新类操作装饰器,第一个参数必须是异步数据库会话或存在db参数
|
||||
"""
|
||||
|
||||
async def wrapper(*args, **kwargs):
|
||||
# 是否关闭数据库会话
|
||||
_close_db = False
|
||||
# 从参数中获取异步数据库会话
|
||||
db = _get_args_async_db(args, kwargs)
|
||||
if not db:
|
||||
# 如果没有获取到异步数据库会话,创建一个
|
||||
db = AsyncSessionFactory()
|
||||
# 标记需要关闭数据库会话
|
||||
_close_db = True
|
||||
# 更新参数中的异步数据库会话
|
||||
args, kwargs = _update_args_async_db(args, kwargs, db)
|
||||
try:
|
||||
# 执行函数
|
||||
result = await func(*args, **kwargs)
|
||||
# 提交事务
|
||||
await db.commit()
|
||||
except Exception as err:
|
||||
# 回滚事务
|
||||
await db.rollback()
|
||||
raise err
|
||||
finally:
|
||||
# 关闭数据库会话
|
||||
if _close_db:
|
||||
await db.close()
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def db_query(func):
|
||||
"""
|
||||
数据库查询操作装饰器,第一个参数必须是数据库会话或存在db参数
|
||||
@@ -164,14 +372,14 @@ def db_query(func):
|
||||
# 是否关闭数据库会话
|
||||
_close_db = False
|
||||
# 从参数中获取数据库会话
|
||||
db = get_args_db(args, kwargs)
|
||||
db = _get_args_db(args, kwargs)
|
||||
if not db:
|
||||
# 如果没有获取到数据库会话,创建一个
|
||||
db = ScopedSession()
|
||||
# 标记需要关闭数据库会话
|
||||
_close_db = True
|
||||
# 更新参数中的数据库会话
|
||||
args, kwargs = update_args_db(args, kwargs, db)
|
||||
args, kwargs = _update_args_db(args, kwargs, db)
|
||||
try:
|
||||
# 执行函数
|
||||
result = func(*args, **kwargs)
|
||||
@@ -186,6 +394,38 @@ def db_query(func):
|
||||
return wrapper
|
||||
|
||||
|
||||
def async_db_query(func):
|
||||
"""
|
||||
异步数据库查询操作装饰器,第一个参数必须是异步数据库会话或存在db参数
|
||||
注意:db.query列表数据时,需要转换为list返回
|
||||
"""
|
||||
|
||||
async def wrapper(*args, **kwargs):
|
||||
# 是否关闭数据库会话
|
||||
_close_db = False
|
||||
# 从参数中获取异步数据库会话
|
||||
db = _get_args_async_db(args, kwargs)
|
||||
if not db:
|
||||
# 如果没有获取到异步数据库会话,创建一个
|
||||
db = AsyncSessionFactory()
|
||||
# 标记需要关闭数据库会话
|
||||
_close_db = True
|
||||
# 更新参数中的异步数据库会话
|
||||
args, kwargs = _update_args_async_db(args, kwargs, db)
|
||||
try:
|
||||
# 执行函数
|
||||
result = await func(*args, **kwargs)
|
||||
except Exception as err:
|
||||
raise err
|
||||
finally:
|
||||
# 关闭数据库会话
|
||||
if _close_db:
|
||||
await db.close()
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@as_declarative()
|
||||
class Base:
|
||||
id: Any
|
||||
@@ -195,11 +435,23 @@ class Base:
|
||||
def create(self, db: Session):
|
||||
db.add(self)
|
||||
|
||||
@async_db_update
|
||||
async def async_create(self, db: AsyncSession):
|
||||
db.add(self)
|
||||
await db.flush()
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def get(cls, db: Session, rid: int) -> Self:
|
||||
return db.query(cls).filter(and_(cls.id == rid)).first()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_get(cls, db: AsyncSession, rid: int) -> Self:
|
||||
result = await db.execute(select(cls).where(and_(cls.id == rid)))
|
||||
return result.scalars().first()
|
||||
|
||||
@db_update
|
||||
def update(self, db: Session, payload: dict):
|
||||
payload = {k: v for k, v in payload.items() if v is not None}
|
||||
@@ -208,23 +460,50 @@ class Base:
|
||||
if inspect(self).detached:
|
||||
db.add(self)
|
||||
|
||||
@async_db_update
|
||||
async def async_update(self, db: AsyncSession, payload: dict):
|
||||
payload = {k: v for k, v in payload.items() if v is not None}
|
||||
for key, value in payload.items():
|
||||
setattr(self, key, value)
|
||||
if inspect(self).detached:
|
||||
db.add(self)
|
||||
|
||||
@classmethod
|
||||
@db_update
|
||||
def delete(cls, db: Session, rid):
|
||||
db.query(cls).filter(and_(cls.id == rid)).delete()
|
||||
|
||||
@classmethod
|
||||
@async_db_update
|
||||
async def async_delete(cls, db: AsyncSession, rid):
|
||||
result = await db.execute(select(cls).where(and_(cls.id == rid)))
|
||||
user = result.scalars().first()
|
||||
if user:
|
||||
await db.delete(user)
|
||||
|
||||
@classmethod
|
||||
@db_update
|
||||
def truncate(cls, db: Session):
|
||||
db.query(cls).delete()
|
||||
|
||||
@classmethod
|
||||
@async_db_update
|
||||
async def async_truncate(cls, db: AsyncSession):
|
||||
await db.execute(delete(cls))
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def list(cls, db: Session) -> List[Self]:
|
||||
return db.query(cls).all()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_list(cls, db: AsyncSession) -> Sequence[Self]:
|
||||
result = await db.execute(select(cls))
|
||||
return result.scalars().all()
|
||||
|
||||
def to_dict(self):
|
||||
return {c.name: getattr(self, c.name, None) for c in self.__table__.columns} # noqa
|
||||
return {c.name: getattr(self, c.name, None) for c in self.__table__.columns} # noqa
|
||||
|
||||
@declared_attr
|
||||
def __tablename__(self) -> str:
|
||||
@@ -236,5 +515,5 @@ class DbOper:
|
||||
数据库操作基类
|
||||
"""
|
||||
|
||||
def __init__(self, db: Session = None):
|
||||
def __init__(self, db: Union[Session, AsyncSession] = None):
|
||||
self._db = db
|
||||
|
||||
@@ -18,12 +18,22 @@ def update_db():
|
||||
"""
|
||||
更新数据库
|
||||
"""
|
||||
db_location = settings.CONFIG_PATH / 'user.db'
|
||||
script_location = settings.ROOT_PATH / 'database'
|
||||
try:
|
||||
alembic_cfg = Config()
|
||||
alembic_cfg.set_main_option('script_location', str(script_location))
|
||||
alembic_cfg.set_main_option('sqlalchemy.url', f"sqlite:///{db_location}")
|
||||
|
||||
# 根据数据库类型设置不同的URL
|
||||
if settings.DB_TYPE.lower() == "postgresql":
|
||||
if settings.DB_POSTGRESQL_PASSWORD:
|
||||
db_url = f"postgresql://{settings.DB_POSTGRESQL_USERNAME}:{settings.DB_POSTGRESQL_PASSWORD}@{settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE}"
|
||||
else:
|
||||
db_url = f"postgresql://{settings.DB_POSTGRESQL_USERNAME}@{settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE}"
|
||||
else:
|
||||
db_location = settings.CONFIG_PATH / 'user.db'
|
||||
db_url = f"sqlite:///{db_location}"
|
||||
|
||||
alembic_cfg.set_main_option('sqlalchemy.url', db_url)
|
||||
upgrade(alembic_cfg, 'head')
|
||||
except Exception as e:
|
||||
logger.error(f'数据库更新失败:{str(e)}')
|
||||
|
||||
@@ -58,6 +58,32 @@ class MediaServerOper(DbOper):
|
||||
return None
|
||||
return item
|
||||
|
||||
async def async_exists(self, **kwargs) -> Optional[MediaServerItem]:
|
||||
"""
|
||||
异步判断媒体服务器数据是否存在
|
||||
"""
|
||||
if kwargs.get("tmdbid"):
|
||||
# 优先按TMDBID查
|
||||
item = await MediaServerItem.async_exist_by_tmdbid(self._db, tmdbid=kwargs.get("tmdbid"),
|
||||
mtype=kwargs.get("mtype"))
|
||||
elif kwargs.get("title"):
|
||||
# 按标题、类型、年份查
|
||||
item = await MediaServerItem.async_exists_by_title(self._db, title=kwargs.get("title"),
|
||||
mtype=kwargs.get("mtype"), year=kwargs.get("year"))
|
||||
else:
|
||||
return None
|
||||
if not item:
|
||||
return None
|
||||
|
||||
if kwargs.get("season"):
|
||||
# 判断季是否存在
|
||||
if not item.seasoninfo:
|
||||
return None
|
||||
seasoninfo = item.seasoninfo or {}
|
||||
if kwargs.get("season") not in seasoninfo.keys():
|
||||
return None
|
||||
return item
|
||||
|
||||
def get_item_id(self, **kwargs) -> Optional[str]:
|
||||
"""
|
||||
获取媒体服务器数据ID
|
||||
@@ -66,3 +92,12 @@ class MediaServerOper(DbOper):
|
||||
if not item:
|
||||
return None
|
||||
return str(item.item_id)
|
||||
|
||||
async def async_get_item_id(self, **kwargs) -> Optional[str]:
|
||||
"""
|
||||
异步获取媒体服务器数据ID
|
||||
"""
|
||||
item = await self.async_exists(**kwargs)
|
||||
if not item:
|
||||
return None
|
||||
return str(item.item_id)
|
||||
|
||||
@@ -29,7 +29,7 @@ class MessageOper(DbOper):
|
||||
note: Union[list, dict] = None,
|
||||
**kwargs):
|
||||
"""
|
||||
新增媒体服务器数据
|
||||
新增消息
|
||||
:param channel: 消息渠道
|
||||
:param source: 来源
|
||||
:param mtype: 消息类型
|
||||
@@ -57,11 +57,47 @@ class MessageOper(DbOper):
|
||||
|
||||
# 从kwargs中去掉Message中没有的字段
|
||||
for k in list(kwargs.keys()):
|
||||
if k not in Message.__table__.columns.keys(): # noqa
|
||||
if k not in Message.__table__.columns.keys(): # noqa
|
||||
kwargs.pop(k)
|
||||
|
||||
Message(**kwargs).create(self._db)
|
||||
|
||||
async def async_add(self,
|
||||
channel: MessageChannel = None,
|
||||
source: Optional[str] = None,
|
||||
mtype: NotificationType = None,
|
||||
title: Optional[str] = None,
|
||||
text: Optional[str] = None,
|
||||
image: Optional[str] = None,
|
||||
link: Optional[str] = None,
|
||||
userid: Optional[str] = None,
|
||||
action: Optional[int] = 1,
|
||||
note: Union[list, dict] = None,
|
||||
**kwargs):
|
||||
"""
|
||||
异步新增消息
|
||||
"""
|
||||
kwargs.update({
|
||||
"channel": channel.value if channel else '',
|
||||
"source": source,
|
||||
"mtype": mtype.value if mtype else '',
|
||||
"title": title,
|
||||
"text": text,
|
||||
"image": image,
|
||||
"link": link,
|
||||
"userid": userid,
|
||||
"action": action,
|
||||
"reg_time": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
|
||||
"note": note or {}
|
||||
})
|
||||
|
||||
# 从kwargs中去掉Message中没有的字段
|
||||
for k in list(kwargs.keys()):
|
||||
if k not in Message.__table__.columns.keys(): # noqa
|
||||
kwargs.pop(k)
|
||||
|
||||
await Message(**kwargs).async_create(self._db)
|
||||
|
||||
def list_by_page(self, page: Optional[int] = 1, count: Optional[int] = 30) -> Optional[str]:
|
||||
"""
|
||||
获取媒体服务器数据ID
|
||||
|
||||
@@ -9,4 +9,3 @@ from .transferhistory import TransferHistory
|
||||
from .user import User
|
||||
from .userconfig import UserConfig
|
||||
from .workflow import Workflow
|
||||
from .userrequest import UserRequest
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Column, Integer, String, Sequence, JSON
|
||||
from sqlalchemy import Column, Integer, String, JSON, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db import db_query, db_update, Base
|
||||
from app.db import db_query, db_update, get_id_column, Base, async_db_query
|
||||
|
||||
|
||||
class DownloadHistory(Base):
|
||||
"""
|
||||
下载历史记录
|
||||
"""
|
||||
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
|
||||
id = get_id_column()
|
||||
# 保存路径
|
||||
path = Column(String, nullable=False, index=True)
|
||||
# 类型 电影/电视剧
|
||||
@@ -55,35 +56,43 @@ class DownloadHistory(Base):
|
||||
# 剧集组
|
||||
episode_group = Column(String)
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_by_hash(db: Session, download_hash: str):
|
||||
def get_by_hash(cls, db: Session, download_hash: str):
|
||||
return db.query(DownloadHistory).filter(DownloadHistory.download_hash == download_hash).order_by(
|
||||
DownloadHistory.date.desc()
|
||||
).first()
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_by_mediaid(db: Session, tmdbid: int, doubanid: str):
|
||||
def get_by_mediaid(cls, db: Session, tmdbid: int, doubanid: str):
|
||||
if tmdbid:
|
||||
return db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid).all()
|
||||
elif doubanid:
|
||||
return db.query(DownloadHistory).filter(DownloadHistory.doubanid == doubanid).all()
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
@db_query
|
||||
def list_by_page(db: Session, page: Optional[int] = 1, count: Optional[int] = 30):
|
||||
def list_by_page(cls, db: Session, page: Optional[int] = 1, count: Optional[int] = 30):
|
||||
return db.query(DownloadHistory).offset((page - 1) * count).limit(count).all()
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_list_by_page(cls, db: AsyncSession, page: Optional[int] = 1, count: Optional[int] = 30):
|
||||
result = await db.execute(
|
||||
select(cls).offset((page - 1) * count).limit(count)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_by_path(db: Session, path: str):
|
||||
def get_by_path(cls, db: Session, path: str):
|
||||
return db.query(DownloadHistory).filter(DownloadHistory.path == path).first()
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_last_by(db: Session, mtype: Optional[str] = None, title: Optional[str] = None,
|
||||
def get_last_by(cls, db: Session, mtype: Optional[str] = None, title: Optional[str] = None,
|
||||
year: Optional[str] = None, season: Optional[str] = None,
|
||||
episode: Optional[str] = None, tmdbid: Optional[int] = None):
|
||||
"""
|
||||
@@ -133,9 +142,9 @@ class DownloadHistory(Base):
|
||||
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
@db_query
|
||||
def list_by_user_date(db: Session, date: str, username: Optional[str] = None):
|
||||
def list_by_user_date(cls, db: Session, date: str, username: Optional[str] = None):
|
||||
"""
|
||||
查询某用户某时间之后的下载历史
|
||||
"""
|
||||
@@ -147,9 +156,9 @@ class DownloadHistory(Base):
|
||||
return db.query(DownloadHistory).filter(DownloadHistory.date < date).order_by(
|
||||
DownloadHistory.id.desc()).all()
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
@db_query
|
||||
def list_by_date(db: Session, date: str, type: str, tmdbid: str, seasons: Optional[str] = None):
|
||||
def list_by_date(cls, db: Session, date: str, type: str, tmdbid: str, seasons: Optional[str] = None):
|
||||
"""
|
||||
查询某时间之后的下载历史
|
||||
"""
|
||||
@@ -165,9 +174,9 @@ class DownloadHistory(Base):
|
||||
DownloadHistory.tmdbid == tmdbid).order_by(
|
||||
DownloadHistory.id.desc()).all()
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
@db_query
|
||||
def list_by_type(db: Session, mtype: str, days: int):
|
||||
def list_by_type(cls, db: Session, mtype: str, days: int):
|
||||
return db.query(DownloadHistory) \
|
||||
.filter(DownloadHistory.type == mtype,
|
||||
DownloadHistory.date >= time.strftime("%Y-%m-%d %H:%M:%S",
|
||||
@@ -179,7 +188,7 @@ class DownloadFiles(Base):
|
||||
"""
|
||||
下载文件记录
|
||||
"""
|
||||
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
|
||||
id = get_id_column()
|
||||
# 下载器
|
||||
downloader = Column(String)
|
||||
# 下载任务Hash
|
||||
@@ -195,35 +204,35 @@ class DownloadFiles(Base):
|
||||
# 状态 0-已删除 1-正常
|
||||
state = Column(Integer, nullable=False, default=1)
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_by_hash(db: Session, download_hash: str, state: Optional[int] = None):
|
||||
def get_by_hash(cls, db: Session, download_hash: str, state: Optional[int] = None):
|
||||
if state:
|
||||
return db.query(DownloadFiles).filter(DownloadFiles.download_hash == download_hash,
|
||||
DownloadFiles.state == state).all()
|
||||
return db.query(cls).filter(cls.download_hash == download_hash,
|
||||
cls.state == state).all()
|
||||
else:
|
||||
return db.query(DownloadFiles).filter(DownloadFiles.download_hash == download_hash).all()
|
||||
return db.query(cls).filter(cls.download_hash == download_hash).all()
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_by_fullpath(db: Session, fullpath: str, all_files: bool = False):
|
||||
def get_by_fullpath(cls, db: Session, fullpath: str, all_files: bool = False):
|
||||
if not all_files:
|
||||
return db.query(DownloadFiles).filter(DownloadFiles.fullpath == fullpath).order_by(
|
||||
DownloadFiles.id.desc()).first()
|
||||
return db.query(cls).filter(cls.fullpath == fullpath).order_by(
|
||||
cls.id.desc()).first()
|
||||
else:
|
||||
return db.query(DownloadFiles).filter(DownloadFiles.fullpath == fullpath).order_by(
|
||||
DownloadFiles.id.desc()).all()
|
||||
return db.query(cls).filter(cls.fullpath == fullpath).order_by(
|
||||
cls.id.desc()).all()
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_by_savepath(db: Session, savepath: str):
|
||||
return db.query(DownloadFiles).filter(DownloadFiles.savepath == savepath).all()
|
||||
def get_by_savepath(cls, db: Session, savepath: str):
|
||||
return db.query(cls).filter(cls.savepath == savepath).all()
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
@db_update
|
||||
def delete_by_fullpath(db: Session, fullpath: str):
|
||||
db.query(DownloadFiles).filter(DownloadFiles.fullpath == fullpath,
|
||||
DownloadFiles.state == 1).update(
|
||||
def delete_by_fullpath(cls, db: Session, fullpath: str):
|
||||
db.query(cls).filter(cls.fullpath == fullpath,
|
||||
cls.state == 1).update(
|
||||
{
|
||||
"state": 0
|
||||
}
|
||||
|
||||
@@ -1,17 +1,19 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Column, Integer, String, Sequence, JSON
|
||||
from sqlalchemy import Column, Integer, String, JSON
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db import db_query, db_update, Base
|
||||
from app.db import db_query, db_update, get_id_column, async_db_query, Base
|
||||
|
||||
|
||||
class MediaServerItem(Base):
|
||||
"""
|
||||
媒体服务器媒体条目表
|
||||
"""
|
||||
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
|
||||
id = get_id_column()
|
||||
# 服务器类型
|
||||
server = Column(String)
|
||||
# 媒体库ID
|
||||
@@ -41,28 +43,49 @@ class MediaServerItem(Base):
|
||||
# 同步时间
|
||||
lst_mod_date = Column(String, default=datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_by_itemid(db: Session, item_id: str):
|
||||
return db.query(MediaServerItem).filter(MediaServerItem.item_id == item_id).first()
|
||||
def get_by_itemid(cls, db: Session, item_id: str):
|
||||
return db.query(cls).filter(cls.item_id == item_id).first()
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
@db_update
|
||||
def empty(db: Session, server: Optional[str] = None):
|
||||
def empty(cls, db: Session, server: Optional[str] = None):
|
||||
if server is None:
|
||||
db.query(MediaServerItem).delete()
|
||||
db.query(cls).delete()
|
||||
else:
|
||||
db.query(MediaServerItem).filter(MediaServerItem.server == server).delete()
|
||||
db.query(cls).filter(cls.server == server).delete()
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
@db_query
|
||||
def exist_by_tmdbid(db: Session, tmdbid: int, mtype: str):
|
||||
return db.query(MediaServerItem).filter(MediaServerItem.tmdbid == tmdbid,
|
||||
MediaServerItem.item_type == mtype).first()
|
||||
def exist_by_tmdbid(cls, db: Session, tmdbid: int, mtype: str):
|
||||
return db.query(cls).filter(cls.tmdbid == tmdbid,
|
||||
cls.item_type == mtype).first()
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
@db_query
|
||||
def exists_by_title(db: Session, title: str, mtype: str, year: str):
|
||||
return db.query(MediaServerItem).filter(MediaServerItem.title == title,
|
||||
MediaServerItem.item_type == mtype,
|
||||
MediaServerItem.year == str(year)).first()
|
||||
def exists_by_title(cls, db: Session, title: str, mtype: str, year: str):
|
||||
return db.query(cls).filter(cls.title == title,
|
||||
cls.item_type == mtype,
|
||||
cls.year == str(year)).first()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_get_by_itemid(cls, db: AsyncSession, item_id: str):
|
||||
result = await db.execute(select(cls).filter(cls.item_id == item_id))
|
||||
return result.scalars().first()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_exist_by_tmdbid(cls, db: AsyncSession, tmdbid: int, mtype: str):
|
||||
result = await db.execute(select(cls).filter(cls.tmdbid == tmdbid,
|
||||
cls.item_type == mtype))
|
||||
return result.scalars().first()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_exists_by_title(cls, db: AsyncSession, title: str, mtype: str, year: str):
|
||||
result = await db.execute(select(cls).filter(cls.title == title,
|
||||
cls.item_type == mtype,
|
||||
cls.year == str(year)))
|
||||
return result.scalars().first()
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Column, Integer, String, Sequence, JSON
|
||||
from sqlalchemy import Column, Integer, String, JSON, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db import db_query, Base
|
||||
from app.db import db_query, Base, get_id_column, async_db_query
|
||||
|
||||
|
||||
class Message(Base):
|
||||
"""
|
||||
消息表
|
||||
"""
|
||||
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
|
||||
id = get_id_column()
|
||||
# 消息渠道
|
||||
channel = Column(String)
|
||||
# 消息来源
|
||||
@@ -34,7 +35,15 @@ class Message(Base):
|
||||
# 附件json
|
||||
note = Column(JSON)
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
@db_query
|
||||
def list_by_page(db: Session, page: Optional[int] = 1, count: Optional[int] = 30):
|
||||
return db.query(Message).order_by(Message.reg_time.desc()).offset((page - 1) * count).limit(count).all()
|
||||
def list_by_page(cls, db: Session, page: Optional[int] = 1, count: Optional[int] = 30):
|
||||
return db.query(cls).order_by(cls.reg_time.desc()).offset((page - 1) * count).limit(count).all()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_list_by_page(cls, db: AsyncSession, page: Optional[int] = 1, count: Optional[int] = 30):
|
||||
result = await db.execute(
|
||||
select(cls).order_by(cls.reg_time.desc()).offset((page - 1) * count).limit(count)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
@@ -1,39 +1,39 @@
|
||||
from sqlalchemy import Column, Integer, String, Sequence, JSON
|
||||
from sqlalchemy import Column, String, JSON
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db import db_query, db_update, Base
|
||||
from app.db import db_query, db_update, get_id_column, Base
|
||||
|
||||
|
||||
class PluginData(Base):
|
||||
"""
|
||||
插件数据表
|
||||
"""
|
||||
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
|
||||
id = get_id_column()
|
||||
plugin_id = Column(String, nullable=False, index=True)
|
||||
key = Column(String, index=True, nullable=False)
|
||||
value = Column(JSON)
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_plugin_data(db: Session, plugin_id: str):
|
||||
return db.query(PluginData).filter(PluginData.plugin_id == plugin_id).all()
|
||||
def get_plugin_data(cls, db: Session, plugin_id: str):
|
||||
return db.query(cls).filter(cls.plugin_id == plugin_id).all()
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_plugin_data_by_key(db: Session, plugin_id: str, key: str):
|
||||
return db.query(PluginData).filter(PluginData.plugin_id == plugin_id, PluginData.key == key).first()
|
||||
def get_plugin_data_by_key(cls, db: Session, plugin_id: str, key: str):
|
||||
return db.query(cls).filter(cls.plugin_id == plugin_id, cls.key == key).first()
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
@db_update
|
||||
def del_plugin_data_by_key(db: Session, plugin_id: str, key: str):
|
||||
db.query(PluginData).filter(PluginData.plugin_id == plugin_id, PluginData.key == key).delete()
|
||||
def del_plugin_data_by_key(cls, db: Session, plugin_id: str, key: str):
|
||||
db.query(cls).filter(cls.plugin_id == plugin_id, cls.key == key).delete()
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
@db_update
|
||||
def del_plugin_data(db: Session, plugin_id: str):
|
||||
db.query(PluginData).filter(PluginData.plugin_id == plugin_id).delete()
|
||||
def del_plugin_data(cls, db: Session, plugin_id: str):
|
||||
db.query(cls).filter(cls.plugin_id == plugin_id).delete()
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_plugin_data_by_plugin_id(db: Session, plugin_id: str):
|
||||
return db.query(PluginData).filter(PluginData.plugin_id == plugin_id).all()
|
||||
def get_plugin_data_by_plugin_id(cls, db: Session, plugin_id: str):
|
||||
return db.query(cls).filter(cls.plugin_id == plugin_id).all()
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Boolean, Column, Integer, String, Sequence, JSON
|
||||
from sqlalchemy import Boolean, Column, Integer, String, JSON, select, delete
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db import db_query, db_update, Base
|
||||
from app.db import db_query, db_update, Base, async_db_query, async_db_update, get_id_column
|
||||
|
||||
|
||||
class Site(Base):
|
||||
"""
|
||||
站点表
|
||||
"""
|
||||
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
|
||||
id = get_id_column()
|
||||
# 站点名
|
||||
name = Column(String, nullable=False)
|
||||
# 域名Key
|
||||
@@ -54,27 +55,56 @@ class Site(Base):
|
||||
# 下载器
|
||||
downloader = Column(String)
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_by_domain(db: Session, domain: str):
|
||||
return db.query(Site).filter(Site.domain == domain).first()
|
||||
def get_by_domain(cls, db: Session, domain: str):
|
||||
return db.query(cls).filter(cls.domain == domain).first()
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_get_by_domain(cls, db: AsyncSession, domain: str):
|
||||
result = await db.execute(select(cls).where(cls.domain == domain))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_get_by_name(cls, db: AsyncSession, name: str):
|
||||
result = await db.execute(select(cls).where(cls.name == name))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_actives(db: Session):
|
||||
return db.query(Site).filter(Site.is_active == 1).all()
|
||||
def get_actives(cls, db: Session):
|
||||
return db.query(cls).filter(cls.is_active).all()
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_get_actives(cls, db: AsyncSession):
|
||||
result = await db.execute(select(cls).where(cls.is_active))
|
||||
return result.scalars().all()
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def list_order_by_pri(db: Session):
|
||||
return db.query(Site).order_by(Site.pri).all()
|
||||
def list_order_by_pri(cls, db: Session):
|
||||
return db.query(cls).order_by(cls.pri).all()
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_list_order_by_pri(cls, db: AsyncSession):
|
||||
result = await db.execute(select(cls).order_by(cls.pri))
|
||||
return result.scalars().all()
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_domains_by_ids(db: Session, ids: list):
|
||||
return [r[0] for r in db.query(Site.domain).filter(Site.id.in_(ids)).all()]
|
||||
def get_domains_by_ids(cls, db: Session, ids: list):
|
||||
return [r[0] for r in db.query(cls.domain).filter(cls.id.in_(ids)).all()]
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
@db_update
|
||||
def reset(db: Session):
|
||||
db.query(Site).delete()
|
||||
def reset(cls, db: Session):
|
||||
db.query(cls).delete()
|
||||
|
||||
@classmethod
|
||||
@async_db_update
|
||||
async def async_reset(cls, db: AsyncSession):
|
||||
await db.execute(delete(cls))
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
from sqlalchemy import Column, Integer, String, Sequence
|
||||
from sqlalchemy import Column, String, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db import db_query, Base
|
||||
from app.db import db_query, Base, get_id_column, async_db_query
|
||||
|
||||
|
||||
class SiteIcon(Base):
|
||||
"""
|
||||
站点图标表
|
||||
"""
|
||||
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
|
||||
id = get_id_column()
|
||||
# 站点名称
|
||||
name = Column(String, nullable=False)
|
||||
# 域名Key
|
||||
@@ -18,7 +19,13 @@ class SiteIcon(Base):
|
||||
# 图标Base64
|
||||
base64 = Column(String)
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_by_domain(db: Session, domain: str):
|
||||
return db.query(SiteIcon).filter(SiteIcon.domain == domain).first()
|
||||
def get_by_domain(cls, db: Session, domain: str):
|
||||
return db.query(cls).filter(cls.domain == domain).first()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_get_by_domain(cls, db: AsyncSession, domain: str):
|
||||
result = await db.execute(select(cls).where(cls.domain == domain))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, Integer, String, Sequence, JSON
|
||||
from sqlalchemy import Column, Integer, String, JSON, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db import db_query, db_update, Base
|
||||
from app.db import db_query, db_update, get_id_column, Base, async_db_query
|
||||
|
||||
|
||||
class SiteStatistic(Base):
|
||||
"""
|
||||
站点统计表
|
||||
"""
|
||||
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
|
||||
id = get_id_column()
|
||||
# 域名Key
|
||||
domain = Column(String, index=True)
|
||||
# 成功次数
|
||||
@@ -26,12 +27,18 @@ class SiteStatistic(Base):
|
||||
# 耗时记录 Json
|
||||
note = Column(JSON)
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_by_domain(db: Session, domain: str):
|
||||
return db.query(SiteStatistic).filter(SiteStatistic.domain == domain).first()
|
||||
def get_by_domain(cls, db: Session, domain: str):
|
||||
return db.query(cls).filter(cls.domain == domain).first()
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_get_by_domain(cls, db: AsyncSession, domain: str):
|
||||
result = await db.execute(select(cls).where(cls.domain == domain))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@classmethod
|
||||
@db_update
|
||||
def reset(db: Session):
|
||||
db.query(SiteStatistic).delete()
|
||||
def reset(cls, db: Session):
|
||||
db.query(cls).delete()
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Column, Integer, String, Sequence, Float, JSON, func, or_
|
||||
from sqlalchemy import Column, Integer, String, Float, JSON, func, or_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db import db_query, Base
|
||||
from app.db import db_query, Base, get_id_column, async_db_query
|
||||
|
||||
|
||||
class SiteUserData(Base):
|
||||
"""
|
||||
站点数据表
|
||||
"""
|
||||
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
|
||||
id = get_id_column()
|
||||
# 站点域名
|
||||
domain = Column(String, index=True)
|
||||
# 站点名称
|
||||
@@ -19,7 +20,7 @@ class SiteUserData(Base):
|
||||
# 用户名
|
||||
username = Column(String)
|
||||
# 用户ID
|
||||
userid = Column(Integer)
|
||||
userid = Column(String)
|
||||
# 用户等级
|
||||
user_level = Column(String)
|
||||
# 加入时间
|
||||
@@ -53,42 +54,78 @@ class SiteUserData(Base):
|
||||
# 更新时间
|
||||
updated_time = Column(String, default=datetime.now().strftime('%H:%M:%S'))
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_by_domain(db: Session, domain: str, workdate: Optional[str] = None, worktime: Optional[str] = None):
|
||||
def get_by_domain(cls, db: Session, domain: str, workdate: Optional[str] = None, worktime: Optional[str] = None):
|
||||
if workdate and worktime:
|
||||
return db.query(SiteUserData).filter(SiteUserData.domain == domain,
|
||||
SiteUserData.updated_day == workdate,
|
||||
SiteUserData.updated_time == worktime).all()
|
||||
return db.query(cls).filter(cls.domain == domain,
|
||||
cls.updated_day == workdate,
|
||||
cls.updated_time == worktime).all()
|
||||
elif workdate:
|
||||
return db.query(SiteUserData).filter(SiteUserData.domain == domain,
|
||||
SiteUserData.updated_day == workdate).all()
|
||||
return db.query(SiteUserData).filter(SiteUserData.domain == domain).all()
|
||||
return db.query(cls).filter(cls.domain == domain,
|
||||
cls.updated_day == workdate).all()
|
||||
return db.query(cls).filter(cls.domain == domain).all()
|
||||
|
||||
@staticmethod
|
||||
@db_query
|
||||
def get_by_date(db: Session, date: str):
|
||||
return db.query(SiteUserData).filter(SiteUserData.updated_day == date).all()
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_get_by_domain(cls, db: AsyncSession, domain: str, workdate: Optional[str] = None, worktime: Optional[str] = None):
|
||||
query = select(cls).filter(cls.domain == domain)
|
||||
if workdate and worktime:
|
||||
query = query.filter(cls.updated_day == workdate, cls.updated_time == worktime)
|
||||
elif workdate:
|
||||
query = query.filter(cls.updated_day == workdate)
|
||||
result = await db.execute(query)
|
||||
return result.scalars().all()
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_latest(db: Session):
|
||||
def get_by_date(cls, db: Session, date: str):
|
||||
return db.query(cls).filter(cls.updated_day == date).all()
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_latest(cls, db: Session):
|
||||
"""
|
||||
获取各站点最新一天的数据
|
||||
"""
|
||||
subquery = (
|
||||
db.query(
|
||||
SiteUserData.domain,
|
||||
func.max(SiteUserData.updated_day).label('latest_update_day')
|
||||
cls.domain,
|
||||
func.max(cls.updated_day).label('latest_update_day')
|
||||
)
|
||||
.group_by(SiteUserData.domain)
|
||||
.filter(or_(SiteUserData.err_msg.is_(None), SiteUserData.err_msg == ""))
|
||||
.group_by(cls.domain)
|
||||
.filter(or_(cls.err_msg.is_(None), cls.err_msg == ""))
|
||||
.subquery()
|
||||
)
|
||||
|
||||
# 主查询:按 domain 和 updated_day 获取最新的记录
|
||||
return db.query(SiteUserData).join(
|
||||
return db.query(cls).join(
|
||||
subquery,
|
||||
(SiteUserData.domain == subquery.c.domain) &
|
||||
(SiteUserData.updated_day == subquery.c.latest_update_day)
|
||||
).order_by(SiteUserData.updated_time.desc()).all()
|
||||
(cls.domain == subquery.c.domain) &
|
||||
(cls.updated_day == subquery.c.latest_update_day)
|
||||
).order_by(cls.updated_time.desc()).all()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_get_latest(cls, db: AsyncSession):
|
||||
"""
|
||||
异步获取各站点最新一天的数据
|
||||
"""
|
||||
subquery = (
|
||||
select(
|
||||
cls.domain,
|
||||
func.max(cls.updated_day).label('latest_update_day')
|
||||
)
|
||||
.group_by(cls.domain)
|
||||
.filter(or_(cls.err_msg.is_(None), cls.err_msg == ""))
|
||||
.subquery()
|
||||
)
|
||||
|
||||
# 主查询:按 domain 和 updated_day 获取最新的记录
|
||||
result = await db.execute(
|
||||
select(cls).join(
|
||||
subquery,
|
||||
(cls.domain == subquery.c.domain) &
|
||||
(cls.updated_day == subquery.c.latest_update_day)
|
||||
).order_by(cls.updated_time.desc()))
|
||||
return result.scalars().all()
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Column, Integer, String, Sequence, Float, JSON
|
||||
from sqlalchemy import Column, Integer, String, Float, JSON, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db import db_query, db_update, Base
|
||||
from app.db import db_query, db_update, get_id_column, Base, async_db_query, async_db_update
|
||||
|
||||
|
||||
class Subscribe(Base):
|
||||
"""
|
||||
订阅表
|
||||
"""
|
||||
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
|
||||
id = get_id_column()
|
||||
# 标题
|
||||
name = Column(String, nullable=False, index=True)
|
||||
# 年份
|
||||
@@ -87,59 +88,144 @@ class Subscribe(Base):
|
||||
# 选择的剧集组
|
||||
episode_group = Column(String)
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
@db_query
|
||||
def exists(db: Session, tmdbid: Optional[int] = None, doubanid: Optional[str] = None, season: Optional[int] = None):
|
||||
def exists(cls, db: Session, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
|
||||
season: Optional[int] = None):
|
||||
if tmdbid:
|
||||
if season:
|
||||
return db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid,
|
||||
Subscribe.season == season).first()
|
||||
return db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid).first()
|
||||
return db.query(cls).filter(cls.tmdbid == tmdbid,
|
||||
cls.season == season).first()
|
||||
return db.query(cls).filter(cls.tmdbid == tmdbid).first()
|
||||
elif doubanid:
|
||||
return db.query(Subscribe).filter(Subscribe.doubanid == doubanid).first()
|
||||
return db.query(cls).filter(cls.doubanid == doubanid).first()
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_exists(cls, db: AsyncSession, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
|
||||
season: Optional[int] = None):
|
||||
if tmdbid:
|
||||
if season:
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.tmdbid == tmdbid, cls.season == season)
|
||||
)
|
||||
else:
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.tmdbid == tmdbid)
|
||||
)
|
||||
elif doubanid:
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.doubanid == doubanid)
|
||||
)
|
||||
else:
|
||||
return None
|
||||
return result.scalars().first()
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_by_state(db: Session, state: str):
|
||||
def get_by_state(cls, db: Session, state: str):
|
||||
# 如果 state 为空或 None,返回所有订阅
|
||||
if not state:
|
||||
return db.query(Subscribe).all()
|
||||
return db.query(cls).all()
|
||||
else:
|
||||
# 如果传入的状态不为空,拆分成多个状态
|
||||
return db.query(Subscribe).filter(Subscribe.state.in_(state.split(','))).all()
|
||||
return db.query(cls).filter(cls.state.in_(state.split(','))).all()
|
||||
|
||||
@staticmethod
|
||||
@db_query
|
||||
def get_by_title(db: Session, title: str, season: Optional[int] = None):
|
||||
if season:
|
||||
return db.query(Subscribe).filter(Subscribe.name == title,
|
||||
Subscribe.season == season).first()
|
||||
return db.query(Subscribe).filter(Subscribe.name == title).first()
|
||||
|
||||
@staticmethod
|
||||
@db_query
|
||||
def get_by_tmdbid(db: Session, tmdbid: int, season: Optional[int] = None):
|
||||
if season:
|
||||
return db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid,
|
||||
Subscribe.season == season).all()
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_get_by_state(cls, db: AsyncSession, state: str):
|
||||
# 如果 state 为空或 None,返回所有订阅
|
||||
if not state:
|
||||
result = await db.execute(select(cls))
|
||||
else:
|
||||
return db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid).all()
|
||||
# 如果传入的状态不为空,拆分成多个状态
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.state.in_(state.split(',')))
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_by_doubanid(db: Session, doubanid: str):
|
||||
return db.query(Subscribe).filter(Subscribe.doubanid == doubanid).first()
|
||||
def get_by_title(cls, db: Session, title: str, season: Optional[int] = None):
|
||||
if season:
|
||||
return db.query(cls).filter(cls.name == title,
|
||||
cls.season == season).first()
|
||||
return db.query(cls).filter(cls.name == title).first()
|
||||
|
||||
@staticmethod
|
||||
@db_query
|
||||
def get_by_bangumiid(db: Session, bangumiid: int):
|
||||
return db.query(Subscribe).filter(Subscribe.bangumiid == bangumiid).first()
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_get_by_title(cls, db: AsyncSession, title: str, season: Optional[int] = None):
|
||||
if season:
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.name == title, cls.season == season)
|
||||
)
|
||||
else:
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.name == title)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_by_mediaid(db: Session, mediaid: str):
|
||||
return db.query(Subscribe).filter(Subscribe.mediaid == mediaid).first()
|
||||
def get_by_tmdbid(cls, db: Session, tmdbid: int, season: Optional[int] = None):
|
||||
if season:
|
||||
return db.query(cls).filter(cls.tmdbid == tmdbid,
|
||||
cls.season == season).all()
|
||||
else:
|
||||
return db.query(cls).filter(cls.tmdbid == tmdbid).all()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_get_by_tmdbid(cls, db: AsyncSession, tmdbid: int, season: Optional[int] = None):
|
||||
if season:
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.tmdbid == tmdbid, cls.season == season)
|
||||
)
|
||||
else:
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.tmdbid == tmdbid)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_by_doubanid(cls, db: Session, doubanid: str):
|
||||
return db.query(cls).filter(cls.doubanid == doubanid).first()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_get_by_doubanid(cls, db: AsyncSession, doubanid: str):
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.doubanid == doubanid)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_by_bangumiid(cls, db: Session, bangumiid: int):
|
||||
return db.query(cls).filter(cls.bangumiid == bangumiid).first()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_get_by_bangumiid(cls, db: AsyncSession, bangumiid: int):
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.bangumiid == bangumiid)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_by_mediaid(cls, db: Session, mediaid: str):
|
||||
return db.query(cls).filter(cls.mediaid == mediaid).first()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_get_by_mediaid(cls, db: AsyncSession, mediaid: str):
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.mediaid == mediaid)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@db_update
|
||||
def delete_by_tmdbid(self, db: Session, tmdbid: int, season: int):
|
||||
@@ -148,6 +234,13 @@ class Subscribe(Base):
|
||||
subscrbie.delete(db, subscrbie.id)
|
||||
return True
|
||||
|
||||
@async_db_update
|
||||
async def async_delete_by_tmdbid(self, db: AsyncSession, tmdbid: int, season: int):
|
||||
subscrbies = await self.async_get_by_tmdbid(db, tmdbid, season)
|
||||
for subscrbie in subscrbies:
|
||||
await subscrbie.async_delete(db, subscrbie.id)
|
||||
return True
|
||||
|
||||
@db_update
|
||||
def delete_by_doubanid(self, db: Session, doubanid: str):
|
||||
subscribe = self.get_by_doubanid(db, doubanid)
|
||||
@@ -155,6 +248,13 @@ class Subscribe(Base):
|
||||
subscribe.delete(db, subscribe.id)
|
||||
return True
|
||||
|
||||
@async_db_update
|
||||
async def async_delete_by_doubanid(self, db: AsyncSession, doubanid: str):
|
||||
subscribe = await self.async_get_by_doubanid(db, doubanid)
|
||||
if subscribe:
|
||||
await subscribe.async_delete(db, subscribe.id)
|
||||
return True
|
||||
|
||||
@db_update
|
||||
def delete_by_mediaid(self, db: Session, mediaid: str):
|
||||
subscribe = self.get_by_mediaid(db, mediaid)
|
||||
@@ -162,29 +262,72 @@ class Subscribe(Base):
|
||||
subscribe.delete(db, subscribe.id)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
@async_db_update
|
||||
async def async_delete_by_mediaid(self, db: AsyncSession, mediaid: str):
|
||||
subscribe = await self.async_get_by_mediaid(db, mediaid)
|
||||
if subscribe:
|
||||
await subscribe.async_delete(db, subscribe.id)
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def list_by_username(db: Session, username: str, state: Optional[str] = None, mtype: Optional[str] = None):
|
||||
def list_by_username(cls, db: Session, username: str, state: Optional[str] = None, mtype: Optional[str] = None):
|
||||
if mtype:
|
||||
if state:
|
||||
return db.query(Subscribe).filter(Subscribe.state == state,
|
||||
Subscribe.username == username,
|
||||
Subscribe.type == mtype).all()
|
||||
return db.query(cls).filter(cls.state == state,
|
||||
cls.username == username,
|
||||
cls.type == mtype).all()
|
||||
else:
|
||||
return db.query(Subscribe).filter(Subscribe.username == username,
|
||||
Subscribe.type == mtype).all()
|
||||
return db.query(cls).filter(cls.username == username,
|
||||
cls.type == mtype).all()
|
||||
else:
|
||||
if state:
|
||||
return db.query(Subscribe).filter(Subscribe.state == state,
|
||||
Subscribe.username == username).all()
|
||||
return db.query(cls).filter(cls.state == state,
|
||||
cls.username == username).all()
|
||||
else:
|
||||
return db.query(Subscribe).filter(Subscribe.username == username).all()
|
||||
return db.query(cls).filter(cls.username == username).all()
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_list_by_username(cls, db: AsyncSession, username: str, state: Optional[str] = None,
|
||||
mtype: Optional[str] = None):
|
||||
if mtype:
|
||||
if state:
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.state == state, cls.username == username, cls.type == mtype)
|
||||
)
|
||||
else:
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.username == username, cls.type == mtype)
|
||||
)
|
||||
else:
|
||||
if state:
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.state == state, cls.username == username)
|
||||
)
|
||||
else:
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.username == username)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def list_by_type(db: Session, mtype: str, days: int):
|
||||
return db.query(Subscribe) \
|
||||
.filter(Subscribe.type == mtype,
|
||||
Subscribe.date >= time.strftime("%Y-%m-%d %H:%M:%S",
|
||||
time.localtime(time.time() - 86400 * int(days)))
|
||||
def list_by_type(cls, db: Session, mtype: str, days: int):
|
||||
return db.query(cls) \
|
||||
.filter(cls.type == mtype,
|
||||
cls.date >= time.strftime("%Y-%m-%d %H:%M:%S",
|
||||
time.localtime(time.time() - 86400 * int(days)))
|
||||
).all()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_list_by_type(cls, db: AsyncSession, mtype: str, days: int):
|
||||
result = await db.execute(
|
||||
select(cls).filter(
|
||||
cls.type == mtype,
|
||||
cls.date >= time.strftime("%Y-%m-%d %H:%M:%S",
|
||||
time.localtime(time.time() - 86400 * int(days)))
|
||||
)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Column, Integer, String, Sequence, Float, JSON
|
||||
from sqlalchemy import Column, Integer, String, Float, JSON, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db import db_query, Base
|
||||
from app.db import db_query, Base, get_id_column, async_db_query
|
||||
|
||||
|
||||
class SubscribeHistory(Base):
|
||||
"""
|
||||
订阅历史表
|
||||
"""
|
||||
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
|
||||
id = get_id_column()
|
||||
# 标题
|
||||
name = Column(String, nullable=False, index=True)
|
||||
# 年份
|
||||
@@ -72,23 +73,57 @@ class SubscribeHistory(Base):
|
||||
# 剧集组
|
||||
episode_group = Column(String)
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
@db_query
|
||||
def list_by_type(db: Session, mtype: str, page: Optional[int] = 1, count: Optional[int] = 30):
|
||||
return db.query(SubscribeHistory).filter(
|
||||
SubscribeHistory.type == mtype
|
||||
def list_by_type(cls, db: Session, mtype: str, page: Optional[int] = 1, count: Optional[int] = 30):
|
||||
return db.query(cls).filter(
|
||||
cls.type == mtype
|
||||
).order_by(
|
||||
SubscribeHistory.date.desc()
|
||||
cls.date.desc()
|
||||
).offset((page - 1) * count).limit(count).all()
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_list_by_type(cls, db: AsyncSession, mtype: str, page: Optional[int] = 1, count: Optional[int] = 30):
|
||||
result = await db.execute(
|
||||
select(cls).filter(
|
||||
cls.type == mtype
|
||||
).order_by(
|
||||
cls.date.desc()
|
||||
).offset((page - 1) * count).limit(count)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def exists(db: Session, tmdbid: Optional[int] = None, doubanid: Optional[str] = None, season: Optional[int] = None):
|
||||
def exists(cls, db: Session, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
|
||||
season: Optional[int] = None):
|
||||
if tmdbid:
|
||||
if season:
|
||||
return db.query(SubscribeHistory).filter(SubscribeHistory.tmdbid == tmdbid,
|
||||
SubscribeHistory.season == season).first()
|
||||
return db.query(SubscribeHistory).filter(SubscribeHistory.tmdbid == tmdbid).first()
|
||||
return db.query(cls).filter(cls.tmdbid == tmdbid,
|
||||
cls.season == season).first()
|
||||
return db.query(cls).filter(cls.tmdbid == tmdbid).first()
|
||||
elif doubanid:
|
||||
return db.query(SubscribeHistory).filter(SubscribeHistory.doubanid == doubanid).first()
|
||||
return db.query(cls).filter(cls.doubanid == doubanid).first()
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_exists(cls, db: AsyncSession, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
|
||||
season: Optional[int] = None):
|
||||
if tmdbid:
|
||||
if season:
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.tmdbid == tmdbid, cls.season == season)
|
||||
)
|
||||
else:
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.tmdbid == tmdbid)
|
||||
)
|
||||
elif doubanid:
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.doubanid == doubanid)
|
||||
)
|
||||
else:
|
||||
return None
|
||||
return result.scalars().first()
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user