mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-13 07:26:45 +00:00
feat(recognize): implement media recognition sharing functionality with API integration
This commit is contained in:
@@ -21,6 +21,7 @@ from app.core.module import ModuleManager
|
||||
from app.core.plugin import PluginManager
|
||||
from app.db.message_oper import MessageOper
|
||||
from app.db.user_oper import UserOper
|
||||
from app.helper.recognize import MediaRecognizeShareHelper
|
||||
from app.helper.message import MessageHelper, MessageQueueManager, MessageTemplateHelper
|
||||
from app.helper.service import ServiceConfigHelper
|
||||
from app.log import logger
|
||||
@@ -398,6 +399,22 @@ class ChainBase(metaclass=ABCMeta):
|
||||
method, result, *args, **kwargs
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _can_use_media_recognize_share(
|
||||
meta: Optional[MetaBase],
|
||||
tmdbid: Optional[int],
|
||||
doubanid: Optional[str],
|
||||
bangumiid: Optional[int],
|
||||
) -> bool:
|
||||
"""
|
||||
仅在名称识别场景下使用共享识别,显式ID识别不再重复回查
|
||||
"""
|
||||
return bool(
|
||||
settings.MEDIA_RECOGNIZE_SHARE
|
||||
and meta
|
||||
and not any([tmdbid, doubanid, bangumiid])
|
||||
)
|
||||
|
||||
def recognize_media(
|
||||
self,
|
||||
meta: MetaBase = None,
|
||||
@@ -430,8 +447,9 @@ class ChainBase(metaclass=ABCMeta):
|
||||
bangumiid = None
|
||||
elif not mtype and meta and meta.type in [MediaType.TV, MediaType.MOVIE]:
|
||||
mtype = meta.type
|
||||
share_helper = MediaRecognizeShareHelper()
|
||||
with fresh(not cache):
|
||||
return self.run_module(
|
||||
mediainfo = self.run_module(
|
||||
"recognize_media",
|
||||
meta=meta,
|
||||
mtype=mtype,
|
||||
@@ -441,6 +459,29 @@ class ChainBase(metaclass=ABCMeta):
|
||||
episode_group=episode_group,
|
||||
cache=cache,
|
||||
)
|
||||
if mediainfo:
|
||||
share_helper.report(meta=meta, mediainfo=mediainfo)
|
||||
return mediainfo
|
||||
|
||||
if self._can_use_media_recognize_share(meta, tmdbid, doubanid, bangumiid):
|
||||
shared_item = share_helper.query(meta=meta, mtype=mtype)
|
||||
shared_params = share_helper.to_recognize_params(shared_item)
|
||||
if shared_params:
|
||||
with fresh(not cache):
|
||||
mediainfo = self.run_module(
|
||||
"recognize_media",
|
||||
meta=meta,
|
||||
mtype=shared_params.get("mtype") or mtype,
|
||||
tmdbid=shared_params.get("tmdbid"),
|
||||
doubanid=shared_params.get("doubanid"),
|
||||
bangumiid=shared_params.get("bangumiid"),
|
||||
episode_group=episode_group,
|
||||
cache=cache,
|
||||
)
|
||||
if mediainfo:
|
||||
share_helper.report(meta=meta, mediainfo=mediainfo)
|
||||
return mediainfo
|
||||
return None
|
||||
|
||||
async def async_recognize_media(
|
||||
self,
|
||||
@@ -474,8 +515,9 @@ class ChainBase(metaclass=ABCMeta):
|
||||
bangumiid = None
|
||||
elif not mtype and meta and meta.type in [MediaType.TV, MediaType.MOVIE]:
|
||||
mtype = meta.type
|
||||
share_helper = MediaRecognizeShareHelper()
|
||||
async with async_fresh(not cache):
|
||||
return await self.async_run_module(
|
||||
mediainfo = await self.async_run_module(
|
||||
"async_recognize_media",
|
||||
meta=meta,
|
||||
mtype=mtype,
|
||||
@@ -485,6 +527,29 @@ class ChainBase(metaclass=ABCMeta):
|
||||
episode_group=episode_group,
|
||||
cache=cache,
|
||||
)
|
||||
if mediainfo:
|
||||
await share_helper.async_report(meta=meta, mediainfo=mediainfo)
|
||||
return mediainfo
|
||||
|
||||
if self._can_use_media_recognize_share(meta, tmdbid, doubanid, bangumiid):
|
||||
shared_item = await share_helper.async_query(meta=meta, mtype=mtype)
|
||||
shared_params = share_helper.to_recognize_params(shared_item)
|
||||
if shared_params:
|
||||
async with async_fresh(not cache):
|
||||
mediainfo = await self.async_run_module(
|
||||
"async_recognize_media",
|
||||
meta=meta,
|
||||
mtype=shared_params.get("mtype") or mtype,
|
||||
tmdbid=shared_params.get("tmdbid"),
|
||||
doubanid=shared_params.get("doubanid"),
|
||||
bangumiid=shared_params.get("bangumiid"),
|
||||
episode_group=episode_group,
|
||||
cache=cache,
|
||||
)
|
||||
if mediainfo:
|
||||
await share_helper.async_report(meta=meta, mediainfo=mediainfo)
|
||||
return mediainfo
|
||||
return None
|
||||
|
||||
def match_doubaninfo(
|
||||
self,
|
||||
|
||||
@@ -378,10 +378,14 @@ class ConfigModel(BaseModel):
|
||||
SCRAP_FOLLOW_TMDB: bool = True
|
||||
# 优先使用辅助识别
|
||||
RECOGNIZE_PLUGIN_FIRST: bool = False
|
||||
# 共享使用媒体识别数据
|
||||
MEDIA_RECOGNIZE_SHARE: bool = True
|
||||
|
||||
# ==================== 服务地址配置 ====================
|
||||
# 服务器地址,对应 https://github.com/jxxghp/MoviePilot-Server 项目
|
||||
MP_SERVER_HOST: str = "https://movie-pilot.org"
|
||||
# 共享媒体识别API地址,留空时默认拼接为 MP_SERVER_HOST + /recognize/share
|
||||
MEDIA_RECOGNIZE_SHARE_API: Optional[str] = None
|
||||
|
||||
# ==================== 个性化 ====================
|
||||
# 登录页面电影海报,tmdb/bing/mediaserver
|
||||
|
||||
368
app/helper/recognize.py
Normal file
368
app/helper/recognize.py
Normal file
@@ -0,0 +1,368 @@
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.context import MediaInfo
|
||||
from app.core.meta import MetaBase
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType, media_type_to_agent
|
||||
from app.utils.http import RequestUtils, AsyncRequestUtils
|
||||
from app.utils.singleton import WeakSingleton
|
||||
|
||||
|
||||
class MediaRecognizeShareHelper(metaclass=WeakSingleton):
|
||||
"""
|
||||
共享媒体识别帮助类
|
||||
"""
|
||||
|
||||
_default_path = "/recognize/share"
|
||||
|
||||
@classmethod
|
||||
def _normalize_media_type(cls, media_type: Optional[object]) -> Optional[str]:
|
||||
"""
|
||||
统一媒体类型,兼容枚举、中文值和 agent 风格字符串
|
||||
"""
|
||||
normalized = media_type_to_agent(media_type)
|
||||
if normalized in {"movie", "tv"}:
|
||||
return normalized
|
||||
if isinstance(media_type, str):
|
||||
if media_type == MediaType.MOVIE.value:
|
||||
return "movie"
|
||||
if media_type == MediaType.TV.value:
|
||||
return "tv"
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_keyword(meta: Optional[MetaBase]) -> Optional[str]:
|
||||
"""
|
||||
提取识别关键字
|
||||
"""
|
||||
if not meta:
|
||||
return None
|
||||
keyword = getattr(meta, "name", None)
|
||||
if keyword:
|
||||
keyword = str(keyword).strip()
|
||||
return keyword or None
|
||||
|
||||
@classmethod
|
||||
def _extract_media_type(
|
||||
cls,
|
||||
meta: Optional[MetaBase] = None,
|
||||
mtype: Optional[MediaType] = None,
|
||||
mediainfo: Optional[MediaInfo] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
提取媒体类型
|
||||
"""
|
||||
media_type = cls._normalize_media_type(mtype)
|
||||
if media_type:
|
||||
return media_type
|
||||
if mediainfo and mediainfo.type in {MediaType.MOVIE, MediaType.TV}:
|
||||
return mediainfo.type.to_agent()
|
||||
if meta and meta.type in {MediaType.MOVIE, MediaType.TV}:
|
||||
return meta.type.to_agent()
|
||||
if meta and (meta.begin_season is not None or meta.begin_episode is not None):
|
||||
return "tv"
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _extract_season(
|
||||
cls,
|
||||
media_type: Optional[str],
|
||||
meta: Optional[MetaBase] = None,
|
||||
mediainfo: Optional[MediaInfo] = None,
|
||||
) -> Optional[int]:
|
||||
"""
|
||||
提取季信息,仅电视剧使用
|
||||
"""
|
||||
if media_type != "tv":
|
||||
return None
|
||||
season = getattr(meta, "begin_season", None)
|
||||
if season is None and mediainfo:
|
||||
season = mediainfo.season
|
||||
try:
|
||||
return int(season) if season is not None else None
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_year(
|
||||
meta: Optional[MetaBase] = None,
|
||||
mediainfo: Optional[MediaInfo] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
提取年份
|
||||
"""
|
||||
year = getattr(meta, "year", None) or (mediainfo.year if mediainfo else None)
|
||||
if year is None:
|
||||
return None
|
||||
year_text = str(year).strip()
|
||||
return year_text or None
|
||||
|
||||
@classmethod
|
||||
def _build_api_url(cls) -> Optional[str]:
|
||||
"""
|
||||
获取共享识别API地址
|
||||
"""
|
||||
custom_api = (settings.MEDIA_RECOGNIZE_SHARE_API or "").strip()
|
||||
if custom_api:
|
||||
return custom_api.rstrip("/")
|
||||
server_host = (settings.MP_SERVER_HOST or "").strip().rstrip("/")
|
||||
if not server_host:
|
||||
return None
|
||||
return f"{server_host}{cls._default_path}"
|
||||
|
||||
@classmethod
|
||||
def _build_query_params(
|
||||
cls, meta: Optional[MetaBase], mtype: Optional[MediaType] = None
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
组装共享识别查询参数
|
||||
"""
|
||||
keyword = cls._extract_keyword(meta)
|
||||
if not keyword:
|
||||
return None
|
||||
|
||||
media_type = cls._extract_media_type(meta=meta, mtype=mtype)
|
||||
params = {
|
||||
"keyword": keyword,
|
||||
}
|
||||
if media_type:
|
||||
params["type"] = media_type
|
||||
if year := cls._extract_year(meta=meta):
|
||||
params["year"] = year
|
||||
if season := cls._extract_season(media_type=media_type, meta=meta):
|
||||
params["season"] = season
|
||||
return params
|
||||
|
||||
@classmethod
|
||||
def _build_report_payload(
|
||||
cls, meta: Optional[MetaBase], mediainfo: Optional[MediaInfo]
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
组装共享识别上报载荷
|
||||
"""
|
||||
if not meta or not mediainfo:
|
||||
return None
|
||||
|
||||
keyword = cls._extract_keyword(meta)
|
||||
media_type = cls._extract_media_type(meta=meta, mediainfo=mediainfo)
|
||||
if not keyword or not media_type:
|
||||
return None
|
||||
if not any([mediainfo.tmdb_id, mediainfo.douban_id, mediainfo.bangumi_id]):
|
||||
return None
|
||||
|
||||
return {
|
||||
"keyword": keyword,
|
||||
"type": media_type,
|
||||
"title": mediainfo.title or keyword,
|
||||
"year": cls._extract_year(meta=meta, mediainfo=mediainfo),
|
||||
"season": cls._extract_season(
|
||||
media_type=media_type,
|
||||
meta=meta,
|
||||
mediainfo=mediainfo,
|
||||
),
|
||||
"tmdbid": mediainfo.tmdb_id,
|
||||
"doubanid": mediainfo.douban_id,
|
||||
"bangumiid": mediainfo.bangumi_id,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _parse_response_item(data: Optional[dict]) -> Optional[dict]:
|
||||
"""
|
||||
解析服务端返回的共享识别数据
|
||||
"""
|
||||
if not isinstance(data, dict):
|
||||
return None
|
||||
item = (data.get("data") or {}).get("item")
|
||||
if not isinstance(item, dict):
|
||||
return None
|
||||
return item
|
||||
|
||||
@staticmethod
|
||||
def _response_message(response) -> str:
|
||||
"""
|
||||
获取响应消息,兼容非JSON响应
|
||||
"""
|
||||
try:
|
||||
payload = response.json()
|
||||
return str(payload.get("message") or "")
|
||||
except (json.JSONDecodeError, ValueError, AttributeError):
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _is_enabled() -> bool:
|
||||
"""
|
||||
是否启用共享识别
|
||||
"""
|
||||
return bool(settings.MEDIA_RECOGNIZE_SHARE)
|
||||
|
||||
def query(self, meta: Optional[MetaBase], mtype: Optional[MediaType] = None) -> Optional[dict]:
|
||||
"""
|
||||
查询共享识别结果
|
||||
"""
|
||||
if not self._is_enabled():
|
||||
return None
|
||||
|
||||
api_url = self._build_api_url()
|
||||
params = self._build_query_params(meta=meta, mtype=mtype)
|
||||
if not api_url or not params:
|
||||
return None
|
||||
|
||||
response = RequestUtils(proxies=settings.PROXY or {}, timeout=5).get_res(
|
||||
api_url,
|
||||
params=params,
|
||||
)
|
||||
if not response or response.status_code != 200:
|
||||
if response is not None:
|
||||
logger.warn(
|
||||
f"查询共享媒体识别失败:status={response.status_code} "
|
||||
f"message={self._response_message(response)}"
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
payload = response.json()
|
||||
except (json.JSONDecodeError, ValueError) as err:
|
||||
logger.warn(f"解析共享媒体识别响应失败:{err}")
|
||||
return None
|
||||
|
||||
if payload.get("code") != 0:
|
||||
return None
|
||||
|
||||
item = self._parse_response_item(payload)
|
||||
if item:
|
||||
logger.info(f"共享媒体识别命中:{params.get('keyword')}")
|
||||
return item
|
||||
|
||||
async def async_query(
|
||||
self, meta: Optional[MetaBase], mtype: Optional[MediaType] = None
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
异步查询共享识别结果
|
||||
"""
|
||||
if not self._is_enabled():
|
||||
return None
|
||||
|
||||
api_url = self._build_api_url()
|
||||
params = self._build_query_params(meta=meta, mtype=mtype)
|
||||
if not api_url or not params:
|
||||
return None
|
||||
|
||||
response = await AsyncRequestUtils(
|
||||
proxies=settings.PROXY or {},
|
||||
timeout=5,
|
||||
).get_res(api_url, params=params)
|
||||
if not response or response.status_code != 200:
|
||||
if response is not None:
|
||||
logger.warn(
|
||||
f"异步查询共享媒体识别失败:status={response.status_code} "
|
||||
f"message={self._response_message(response)}"
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
payload = response.json()
|
||||
except (json.JSONDecodeError, ValueError) as err:
|
||||
logger.warn(f"解析共享媒体识别响应失败:{err}")
|
||||
return None
|
||||
|
||||
if payload.get("code") != 0:
|
||||
return None
|
||||
|
||||
item = self._parse_response_item(payload)
|
||||
if item:
|
||||
logger.info(f"共享媒体识别命中:{params.get('keyword')}")
|
||||
return item
|
||||
|
||||
def report(self, meta: Optional[MetaBase], mediainfo: Optional[MediaInfo]) -> bool:
|
||||
"""
|
||||
上报共享识别结果
|
||||
"""
|
||||
if not self._is_enabled():
|
||||
return False
|
||||
|
||||
api_url = self._build_api_url()
|
||||
payload = self._build_report_payload(meta=meta, mediainfo=mediainfo)
|
||||
if not api_url or not payload:
|
||||
return False
|
||||
|
||||
response = RequestUtils(
|
||||
proxies=settings.PROXY or {},
|
||||
timeout=5,
|
||||
content_type="application/json",
|
||||
).post_res(api_url, json=payload)
|
||||
if not response or response.status_code != 200:
|
||||
if response is not None:
|
||||
logger.warn(
|
||||
f"上报共享媒体识别失败:status={response.status_code} "
|
||||
f"message={self._response_message(response)}"
|
||||
)
|
||||
return False
|
||||
|
||||
try:
|
||||
result = response.json()
|
||||
except (json.JSONDecodeError, ValueError) as err:
|
||||
logger.warn(f"解析共享媒体识别上报响应失败:{err}")
|
||||
return False
|
||||
|
||||
return result.get("code") == 0
|
||||
|
||||
async def async_report(
|
||||
self, meta: Optional[MetaBase], mediainfo: Optional[MediaInfo]
|
||||
) -> bool:
|
||||
"""
|
||||
异步上报共享识别结果
|
||||
"""
|
||||
if not self._is_enabled():
|
||||
return False
|
||||
|
||||
api_url = self._build_api_url()
|
||||
payload = self._build_report_payload(meta=meta, mediainfo=mediainfo)
|
||||
if not api_url or not payload:
|
||||
return False
|
||||
|
||||
response = await AsyncRequestUtils(
|
||||
proxies=settings.PROXY or {},
|
||||
timeout=5,
|
||||
content_type="application/json",
|
||||
).post_res(api_url, json=payload)
|
||||
if not response or response.status_code != 200:
|
||||
if response is not None:
|
||||
logger.warn(
|
||||
f"异步上报共享媒体识别失败:status={response.status_code} "
|
||||
f"message={self._response_message(response)}"
|
||||
)
|
||||
return False
|
||||
|
||||
try:
|
||||
result = response.json()
|
||||
except (json.JSONDecodeError, ValueError) as err:
|
||||
logger.warn(f"解析共享媒体识别上报响应失败:{err}")
|
||||
return False
|
||||
|
||||
return result.get("code") == 0
|
||||
|
||||
@classmethod
|
||||
def to_recognize_params(cls, item: Optional[dict]) -> Optional[dict]:
|
||||
"""
|
||||
将服务端返回的共享识别结果转成本地识别参数
|
||||
"""
|
||||
if not isinstance(item, dict):
|
||||
return None
|
||||
|
||||
media_type = cls._normalize_media_type(item.get("type"))
|
||||
mtype = MediaType.from_agent(media_type) if media_type else None
|
||||
tmdbid = item.get("tmdbid")
|
||||
doubanid = item.get("doubanid")
|
||||
bangumiid = item.get("bangumiid")
|
||||
if not any([tmdbid, doubanid, bangumiid]):
|
||||
return None
|
||||
|
||||
return {
|
||||
"mtype": mtype,
|
||||
"tmdbid": tmdbid,
|
||||
"doubanid": doubanid,
|
||||
"bangumiid": bangumiid,
|
||||
"season": item.get("season"),
|
||||
}
|
||||
132
tests/test_media_recognize_share.py
Normal file
132
tests/test_media_recognize_share.py
Normal file
@@ -0,0 +1,132 @@
|
||||
import asyncio
|
||||
import sys
|
||||
import unittest
|
||||
from types import ModuleType
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
sys.modules.setdefault("qbittorrentapi", ModuleType("qbittorrentapi"))
|
||||
setattr(sys.modules["qbittorrentapi"], "TorrentFilesList", list)
|
||||
sys.modules.setdefault("transmission_rpc", ModuleType("transmission_rpc"))
|
||||
setattr(sys.modules["transmission_rpc"], "File", object)
|
||||
sys.modules.setdefault("psutil", ModuleType("psutil"))
|
||||
|
||||
from app.chain import ChainBase
|
||||
from app.core.context import MediaInfo
|
||||
from app.core.meta import MetaBase
|
||||
from app.schemas.types import MediaType
|
||||
|
||||
|
||||
class TestMediaRecognizeShare(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.chain = ChainBase()
|
||||
|
||||
@staticmethod
|
||||
def _build_meta(name: str, media_type: MediaType = MediaType.UNKNOWN) -> MetaBase:
|
||||
"""
|
||||
构造测试用元数据
|
||||
"""
|
||||
meta = MetaBase(name)
|
||||
meta.name = name
|
||||
meta.type = media_type
|
||||
return meta
|
||||
|
||||
def test_report_shared_result_after_local_recognize_success(self):
|
||||
"""
|
||||
本地识别成功后应上报共享识别结果
|
||||
"""
|
||||
meta = self._build_meta("测试电影", MediaType.MOVIE)
|
||||
mediainfo = MediaInfo(title="测试电影", year="2024", tmdb_id=100, type=MediaType.MOVIE)
|
||||
|
||||
with patch.object(self.chain, "run_module", return_value=mediainfo) as run_module, patch(
|
||||
"app.chain.MediaRecognizeShareHelper.report",
|
||||
return_value=True,
|
||||
) as report_mock, patch(
|
||||
"app.chain.MediaRecognizeShareHelper.query"
|
||||
) as query_mock:
|
||||
result = self.chain.recognize_media(meta=meta, cache=False)
|
||||
|
||||
self.assertIs(result, mediainfo)
|
||||
run_module.assert_called_once()
|
||||
report_mock.assert_called_once_with(meta=meta, mediainfo=mediainfo)
|
||||
query_mock.assert_not_called()
|
||||
|
||||
def test_query_shared_result_when_local_recognize_failed(self):
|
||||
"""
|
||||
本地识别失败后应回查共享识别结果,并按共享ID再次识别
|
||||
"""
|
||||
meta = self._build_meta("测试剧集")
|
||||
shared_media = MediaInfo(title="测试剧集", year="2024", tmdb_id=200, type=MediaType.TV)
|
||||
|
||||
with patch.object(
|
||||
self.chain,
|
||||
"run_module",
|
||||
side_effect=[None, shared_media],
|
||||
) as run_module, patch(
|
||||
"app.chain.MediaRecognizeShareHelper.query",
|
||||
return_value={"type": "tv", "tmdbid": 200, "season": 1},
|
||||
) as query_mock, patch(
|
||||
"app.chain.MediaRecognizeShareHelper.to_recognize_params",
|
||||
return_value={
|
||||
"mtype": MediaType.TV,
|
||||
"tmdbid": 200,
|
||||
"doubanid": None,
|
||||
"bangumiid": None,
|
||||
"season": 1,
|
||||
},
|
||||
), patch(
|
||||
"app.chain.MediaRecognizeShareHelper.report",
|
||||
return_value=False,
|
||||
):
|
||||
result = self.chain.recognize_media(meta=meta, cache=False)
|
||||
|
||||
self.assertIs(result, shared_media)
|
||||
self.assertEqual(run_module.call_count, 2)
|
||||
query_mock.assert_called_once_with(meta=meta, mtype=None)
|
||||
second_call = run_module.call_args_list[1]
|
||||
self.assertEqual(second_call.kwargs["tmdbid"], 200)
|
||||
self.assertEqual(second_call.kwargs["mtype"], MediaType.TV)
|
||||
self.assertEqual(meta.begin_season, 1)
|
||||
|
||||
def test_async_query_shared_result_when_local_recognize_failed(self):
|
||||
"""
|
||||
异步识别失败后也应回查共享识别结果
|
||||
"""
|
||||
meta = self._build_meta("测试异步剧集")
|
||||
shared_media = MediaInfo(title="测试异步剧集", year="2025", tmdb_id=300, type=MediaType.TV)
|
||||
async_run_module = AsyncMock(side_effect=[None, shared_media])
|
||||
|
||||
async def runner():
|
||||
with patch.object(
|
||||
self.chain,
|
||||
"async_run_module",
|
||||
async_run_module,
|
||||
), patch(
|
||||
"app.chain.MediaRecognizeShareHelper.async_query",
|
||||
AsyncMock(return_value={"type": "tv", "tmdbid": 300, "season": 2}),
|
||||
) as query_mock, patch(
|
||||
"app.chain.MediaRecognizeShareHelper.to_recognize_params",
|
||||
return_value={
|
||||
"mtype": MediaType.TV,
|
||||
"tmdbid": 300,
|
||||
"doubanid": None,
|
||||
"bangumiid": None,
|
||||
"season": 2,
|
||||
},
|
||||
), patch(
|
||||
"app.chain.MediaRecognizeShareHelper.async_report",
|
||||
AsyncMock(return_value=False),
|
||||
):
|
||||
result = await self.chain.async_recognize_media(meta=meta, cache=False)
|
||||
return result, query_mock
|
||||
|
||||
result, query_mock = asyncio.run(runner())
|
||||
|
||||
self.assertIs(result, shared_media)
|
||||
self.assertEqual(async_run_module.await_count, 2)
|
||||
query_mock.assert_awaited_once_with(meta=meta, mtype=None)
|
||||
self.assertEqual(meta.begin_season, 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user