diff --git a/app/chain/__init__.py b/app/chain/__init__.py index 0073a537..5253f55d 100644 --- a/app/chain/__init__.py +++ b/app/chain/__init__.py @@ -21,6 +21,7 @@ from app.core.module import ModuleManager from app.core.plugin import PluginManager from app.db.message_oper import MessageOper from app.db.user_oper import UserOper +from app.helper.recognize import MediaRecognizeShareHelper from app.helper.message import MessageHelper, MessageQueueManager, MessageTemplateHelper from app.helper.service import ServiceConfigHelper from app.log import logger @@ -398,6 +399,22 @@ class ChainBase(metaclass=ABCMeta): method, result, *args, **kwargs ) + @staticmethod + def _can_use_media_recognize_share( + meta: Optional[MetaBase], + tmdbid: Optional[int], + doubanid: Optional[str], + bangumiid: Optional[int], + ) -> bool: + """ + 仅在名称识别场景下使用共享识别,显式ID识别不再重复回查 + """ + return bool( + settings.MEDIA_RECOGNIZE_SHARE + and meta + and not any([tmdbid, doubanid, bangumiid]) + ) + def recognize_media( self, meta: MetaBase = None, @@ -430,8 +447,9 @@ class ChainBase(metaclass=ABCMeta): bangumiid = None elif not mtype and meta and meta.type in [MediaType.TV, MediaType.MOVIE]: mtype = meta.type + share_helper = MediaRecognizeShareHelper() with fresh(not cache): - return self.run_module( + mediainfo = self.run_module( "recognize_media", meta=meta, mtype=mtype, @@ -441,6 +459,29 @@ class ChainBase(metaclass=ABCMeta): episode_group=episode_group, cache=cache, ) + if mediainfo: + share_helper.report(meta=meta, mediainfo=mediainfo) + return mediainfo + + if self._can_use_media_recognize_share(meta, tmdbid, doubanid, bangumiid): + shared_item = share_helper.query(meta=meta, mtype=mtype) + shared_params = share_helper.to_recognize_params(shared_item) + if shared_params: + with fresh(not cache): + mediainfo = self.run_module( + "recognize_media", + meta=meta, + mtype=shared_params.get("mtype") or mtype, + tmdbid=shared_params.get("tmdbid"), + doubanid=shared_params.get("doubanid"), + bangumiid=shared_params.get("bangumiid"), + episode_group=episode_group, + cache=cache, + ) + if mediainfo: + share_helper.report(meta=meta, mediainfo=mediainfo) + return mediainfo + return None async def async_recognize_media( self, @@ -474,8 +515,9 @@ class ChainBase(metaclass=ABCMeta): bangumiid = None elif not mtype and meta and meta.type in [MediaType.TV, MediaType.MOVIE]: mtype = meta.type + share_helper = MediaRecognizeShareHelper() async with async_fresh(not cache): - return await self.async_run_module( + mediainfo = await self.async_run_module( "async_recognize_media", meta=meta, mtype=mtype, @@ -485,6 +527,29 @@ class ChainBase(metaclass=ABCMeta): episode_group=episode_group, cache=cache, ) + if mediainfo: + await share_helper.async_report(meta=meta, mediainfo=mediainfo) + return mediainfo + + if self._can_use_media_recognize_share(meta, tmdbid, doubanid, bangumiid): + shared_item = await share_helper.async_query(meta=meta, mtype=mtype) + shared_params = share_helper.to_recognize_params(shared_item) + if shared_params: + async with async_fresh(not cache): + mediainfo = await self.async_run_module( + "async_recognize_media", + meta=meta, + mtype=shared_params.get("mtype") or mtype, + tmdbid=shared_params.get("tmdbid"), + doubanid=shared_params.get("doubanid"), + bangumiid=shared_params.get("bangumiid"), + episode_group=episode_group, + cache=cache, + ) + if mediainfo: + await share_helper.async_report(meta=meta, mediainfo=mediainfo) + return mediainfo + return None def match_doubaninfo( self, diff --git a/app/core/config.py b/app/core/config.py index 24316400..56d72cff 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -378,10 +378,14 @@ class ConfigModel(BaseModel): SCRAP_FOLLOW_TMDB: bool = True # 优先使用辅助识别 RECOGNIZE_PLUGIN_FIRST: bool = False + # 共享使用媒体识别数据 + MEDIA_RECOGNIZE_SHARE: bool = True # ==================== 服务地址配置 ==================== # 服务器地址,对应 https://github.com/jxxghp/MoviePilot-Server 项目 MP_SERVER_HOST: str = "https://movie-pilot.org" + # 共享媒体识别API地址,留空时默认拼接为 MP_SERVER_HOST + /recognize/share + MEDIA_RECOGNIZE_SHARE_API: Optional[str] = None # ==================== 个性化 ==================== # 登录页面电影海报,tmdb/bing/mediaserver diff --git a/app/helper/recognize.py b/app/helper/recognize.py new file mode 100644 index 00000000..3592e401 --- /dev/null +++ b/app/helper/recognize.py @@ -0,0 +1,368 @@ +import json +from typing import Optional + +from app.core.config import settings +from app.core.context import MediaInfo +from app.core.meta import MetaBase +from app.log import logger +from app.schemas.types import MediaType, media_type_to_agent +from app.utils.http import RequestUtils, AsyncRequestUtils +from app.utils.singleton import WeakSingleton + + +class MediaRecognizeShareHelper(metaclass=WeakSingleton): + """ + 共享媒体识别帮助类 + """ + + _default_path = "/recognize/share" + + @classmethod + def _normalize_media_type(cls, media_type: Optional[object]) -> Optional[str]: + """ + 统一媒体类型,兼容枚举、中文值和 agent 风格字符串 + """ + normalized = media_type_to_agent(media_type) + if normalized in {"movie", "tv"}: + return normalized + if isinstance(media_type, str): + if media_type == MediaType.MOVIE.value: + return "movie" + if media_type == MediaType.TV.value: + return "tv" + return None + + @staticmethod + def _extract_keyword(meta: Optional[MetaBase]) -> Optional[str]: + """ + 提取识别关键字 + """ + if not meta: + return None + keyword = getattr(meta, "name", None) + if keyword: + keyword = str(keyword).strip() + return keyword or None + + @classmethod + def _extract_media_type( + cls, + meta: Optional[MetaBase] = None, + mtype: Optional[MediaType] = None, + mediainfo: Optional[MediaInfo] = None, + ) -> Optional[str]: + """ + 提取媒体类型 + """ + media_type = cls._normalize_media_type(mtype) + if media_type: + return media_type + if mediainfo and mediainfo.type in {MediaType.MOVIE, MediaType.TV}: + return mediainfo.type.to_agent() + if meta and meta.type in {MediaType.MOVIE, MediaType.TV}: + return meta.type.to_agent() + if meta and (meta.begin_season is not None or meta.begin_episode is not None): + return "tv" + return None + + @classmethod + def _extract_season( + cls, + media_type: Optional[str], + meta: Optional[MetaBase] = None, + mediainfo: Optional[MediaInfo] = None, + ) -> Optional[int]: + """ + 提取季信息,仅电视剧使用 + """ + if media_type != "tv": + return None + season = getattr(meta, "begin_season", None) + if season is None and mediainfo: + season = mediainfo.season + try: + return int(season) if season is not None else None + except (TypeError, ValueError): + return None + + @staticmethod + def _extract_year( + meta: Optional[MetaBase] = None, + mediainfo: Optional[MediaInfo] = None, + ) -> Optional[str]: + """ + 提取年份 + """ + year = getattr(meta, "year", None) or (mediainfo.year if mediainfo else None) + if year is None: + return None + year_text = str(year).strip() + return year_text or None + + @classmethod + def _build_api_url(cls) -> Optional[str]: + """ + 获取共享识别API地址 + """ + custom_api = (settings.MEDIA_RECOGNIZE_SHARE_API or "").strip() + if custom_api: + return custom_api.rstrip("/") + server_host = (settings.MP_SERVER_HOST or "").strip().rstrip("/") + if not server_host: + return None + return f"{server_host}{cls._default_path}" + + @classmethod + def _build_query_params( + cls, meta: Optional[MetaBase], mtype: Optional[MediaType] = None + ) -> Optional[dict]: + """ + 组装共享识别查询参数 + """ + keyword = cls._extract_keyword(meta) + if not keyword: + return None + + media_type = cls._extract_media_type(meta=meta, mtype=mtype) + params = { + "keyword": keyword, + } + if media_type: + params["type"] = media_type + if year := cls._extract_year(meta=meta): + params["year"] = year + if season := cls._extract_season(media_type=media_type, meta=meta): + params["season"] = season + return params + + @classmethod + def _build_report_payload( + cls, meta: Optional[MetaBase], mediainfo: Optional[MediaInfo] + ) -> Optional[dict]: + """ + 组装共享识别上报载荷 + """ + if not meta or not mediainfo: + return None + + keyword = cls._extract_keyword(meta) + media_type = cls._extract_media_type(meta=meta, mediainfo=mediainfo) + if not keyword or not media_type: + return None + if not any([mediainfo.tmdb_id, mediainfo.douban_id, mediainfo.bangumi_id]): + return None + + return { + "keyword": keyword, + "type": media_type, + "title": mediainfo.title or keyword, + "year": cls._extract_year(meta=meta, mediainfo=mediainfo), + "season": cls._extract_season( + media_type=media_type, + meta=meta, + mediainfo=mediainfo, + ), + "tmdbid": mediainfo.tmdb_id, + "doubanid": mediainfo.douban_id, + "bangumiid": mediainfo.bangumi_id, + } + + @staticmethod + def _parse_response_item(data: Optional[dict]) -> Optional[dict]: + """ + 解析服务端返回的共享识别数据 + """ + if not isinstance(data, dict): + return None + item = (data.get("data") or {}).get("item") + if not isinstance(item, dict): + return None + return item + + @staticmethod + def _response_message(response) -> str: + """ + 获取响应消息,兼容非JSON响应 + """ + try: + payload = response.json() + return str(payload.get("message") or "") + except (json.JSONDecodeError, ValueError, AttributeError): + return "" + + @staticmethod + def _is_enabled() -> bool: + """ + 是否启用共享识别 + """ + return bool(settings.MEDIA_RECOGNIZE_SHARE) + + def query(self, meta: Optional[MetaBase], mtype: Optional[MediaType] = None) -> Optional[dict]: + """ + 查询共享识别结果 + """ + if not self._is_enabled(): + return None + + api_url = self._build_api_url() + params = self._build_query_params(meta=meta, mtype=mtype) + if not api_url or not params: + return None + + response = RequestUtils(proxies=settings.PROXY or {}, timeout=5).get_res( + api_url, + params=params, + ) + if not response or response.status_code != 200: + if response is not None: + logger.warn( + f"查询共享媒体识别失败:status={response.status_code} " + f"message={self._response_message(response)}" + ) + return None + + try: + payload = response.json() + except (json.JSONDecodeError, ValueError) as err: + logger.warn(f"解析共享媒体识别响应失败:{err}") + return None + + if payload.get("code") != 0: + return None + + item = self._parse_response_item(payload) + if item: + logger.info(f"共享媒体识别命中:{params.get('keyword')}") + return item + + async def async_query( + self, meta: Optional[MetaBase], mtype: Optional[MediaType] = None + ) -> Optional[dict]: + """ + 异步查询共享识别结果 + """ + if not self._is_enabled(): + return None + + api_url = self._build_api_url() + params = self._build_query_params(meta=meta, mtype=mtype) + if not api_url or not params: + return None + + response = await AsyncRequestUtils( + proxies=settings.PROXY or {}, + timeout=5, + ).get_res(api_url, params=params) + if not response or response.status_code != 200: + if response is not None: + logger.warn( + f"异步查询共享媒体识别失败:status={response.status_code} " + f"message={self._response_message(response)}" + ) + return None + + try: + payload = response.json() + except (json.JSONDecodeError, ValueError) as err: + logger.warn(f"解析共享媒体识别响应失败:{err}") + return None + + if payload.get("code") != 0: + return None + + item = self._parse_response_item(payload) + if item: + logger.info(f"共享媒体识别命中:{params.get('keyword')}") + return item + + def report(self, meta: Optional[MetaBase], mediainfo: Optional[MediaInfo]) -> bool: + """ + 上报共享识别结果 + """ + if not self._is_enabled(): + return False + + api_url = self._build_api_url() + payload = self._build_report_payload(meta=meta, mediainfo=mediainfo) + if not api_url or not payload: + return False + + response = RequestUtils( + proxies=settings.PROXY or {}, + timeout=5, + content_type="application/json", + ).post_res(api_url, json=payload) + if not response or response.status_code != 200: + if response is not None: + logger.warn( + f"上报共享媒体识别失败:status={response.status_code} " + f"message={self._response_message(response)}" + ) + return False + + try: + result = response.json() + except (json.JSONDecodeError, ValueError) as err: + logger.warn(f"解析共享媒体识别上报响应失败:{err}") + return False + + return result.get("code") == 0 + + async def async_report( + self, meta: Optional[MetaBase], mediainfo: Optional[MediaInfo] + ) -> bool: + """ + 异步上报共享识别结果 + """ + if not self._is_enabled(): + return False + + api_url = self._build_api_url() + payload = self._build_report_payload(meta=meta, mediainfo=mediainfo) + if not api_url or not payload: + return False + + response = await AsyncRequestUtils( + proxies=settings.PROXY or {}, + timeout=5, + content_type="application/json", + ).post_res(api_url, json=payload) + if not response or response.status_code != 200: + if response is not None: + logger.warn( + f"异步上报共享媒体识别失败:status={response.status_code} " + f"message={self._response_message(response)}" + ) + return False + + try: + result = response.json() + except (json.JSONDecodeError, ValueError) as err: + logger.warn(f"解析共享媒体识别上报响应失败:{err}") + return False + + return result.get("code") == 0 + + @classmethod + def to_recognize_params(cls, item: Optional[dict]) -> Optional[dict]: + """ + 将服务端返回的共享识别结果转成本地识别参数 + """ + if not isinstance(item, dict): + return None + + media_type = cls._normalize_media_type(item.get("type")) + mtype = MediaType.from_agent(media_type) if media_type else None + tmdbid = item.get("tmdbid") + doubanid = item.get("doubanid") + bangumiid = item.get("bangumiid") + if not any([tmdbid, doubanid, bangumiid]): + return None + + return { + "mtype": mtype, + "tmdbid": tmdbid, + "doubanid": doubanid, + "bangumiid": bangumiid, + "season": item.get("season"), + } diff --git a/tests/test_media_recognize_share.py b/tests/test_media_recognize_share.py new file mode 100644 index 00000000..49063a11 --- /dev/null +++ b/tests/test_media_recognize_share.py @@ -0,0 +1,132 @@ +import asyncio +import sys +import unittest +from types import ModuleType +from unittest.mock import AsyncMock, patch + +sys.modules.setdefault("qbittorrentapi", ModuleType("qbittorrentapi")) +setattr(sys.modules["qbittorrentapi"], "TorrentFilesList", list) +sys.modules.setdefault("transmission_rpc", ModuleType("transmission_rpc")) +setattr(sys.modules["transmission_rpc"], "File", object) +sys.modules.setdefault("psutil", ModuleType("psutil")) + +from app.chain import ChainBase +from app.core.context import MediaInfo +from app.core.meta import MetaBase +from app.schemas.types import MediaType + + +class TestMediaRecognizeShare(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.chain = ChainBase() + + @staticmethod + def _build_meta(name: str, media_type: MediaType = MediaType.UNKNOWN) -> MetaBase: + """ + 构造测试用元数据 + """ + meta = MetaBase(name) + meta.name = name + meta.type = media_type + return meta + + def test_report_shared_result_after_local_recognize_success(self): + """ + 本地识别成功后应上报共享识别结果 + """ + meta = self._build_meta("测试电影", MediaType.MOVIE) + mediainfo = MediaInfo(title="测试电影", year="2024", tmdb_id=100, type=MediaType.MOVIE) + + with patch.object(self.chain, "run_module", return_value=mediainfo) as run_module, patch( + "app.chain.MediaRecognizeShareHelper.report", + return_value=True, + ) as report_mock, patch( + "app.chain.MediaRecognizeShareHelper.query" + ) as query_mock: + result = self.chain.recognize_media(meta=meta, cache=False) + + self.assertIs(result, mediainfo) + run_module.assert_called_once() + report_mock.assert_called_once_with(meta=meta, mediainfo=mediainfo) + query_mock.assert_not_called() + + def test_query_shared_result_when_local_recognize_failed(self): + """ + 本地识别失败后应回查共享识别结果,并按共享ID再次识别 + """ + meta = self._build_meta("测试剧集") + shared_media = MediaInfo(title="测试剧集", year="2024", tmdb_id=200, type=MediaType.TV) + + with patch.object( + self.chain, + "run_module", + side_effect=[None, shared_media], + ) as run_module, patch( + "app.chain.MediaRecognizeShareHelper.query", + return_value={"type": "tv", "tmdbid": 200, "season": 1}, + ) as query_mock, patch( + "app.chain.MediaRecognizeShareHelper.to_recognize_params", + return_value={ + "mtype": MediaType.TV, + "tmdbid": 200, + "doubanid": None, + "bangumiid": None, + "season": 1, + }, + ), patch( + "app.chain.MediaRecognizeShareHelper.report", + return_value=False, + ): + result = self.chain.recognize_media(meta=meta, cache=False) + + self.assertIs(result, shared_media) + self.assertEqual(run_module.call_count, 2) + query_mock.assert_called_once_with(meta=meta, mtype=None) + second_call = run_module.call_args_list[1] + self.assertEqual(second_call.kwargs["tmdbid"], 200) + self.assertEqual(second_call.kwargs["mtype"], MediaType.TV) + self.assertEqual(meta.begin_season, 1) + + def test_async_query_shared_result_when_local_recognize_failed(self): + """ + 异步识别失败后也应回查共享识别结果 + """ + meta = self._build_meta("测试异步剧集") + shared_media = MediaInfo(title="测试异步剧集", year="2025", tmdb_id=300, type=MediaType.TV) + async_run_module = AsyncMock(side_effect=[None, shared_media]) + + async def runner(): + with patch.object( + self.chain, + "async_run_module", + async_run_module, + ), patch( + "app.chain.MediaRecognizeShareHelper.async_query", + AsyncMock(return_value={"type": "tv", "tmdbid": 300, "season": 2}), + ) as query_mock, patch( + "app.chain.MediaRecognizeShareHelper.to_recognize_params", + return_value={ + "mtype": MediaType.TV, + "tmdbid": 300, + "doubanid": None, + "bangumiid": None, + "season": 2, + }, + ), patch( + "app.chain.MediaRecognizeShareHelper.async_report", + AsyncMock(return_value=False), + ): + result = await self.chain.async_recognize_media(meta=meta, cache=False) + return result, query_mock + + result, query_mock = asyncio.run(runner()) + + self.assertIs(result, shared_media) + self.assertEqual(async_run_module.await_count, 2) + query_mock.assert_awaited_once_with(meta=meta, mtype=None) + self.assertEqual(meta.begin_season, 2) + + +if __name__ == "__main__": + unittest.main()