diff --git a/app/modules/themoviedb/tmdbv3api/tmdb.py b/app/modules/themoviedb/tmdbv3api/tmdb.py index e98a07d2..09836fdc 100644 --- a/app/modules/themoviedb/tmdbv3api/tmdb.py +++ b/app/modules/themoviedb/tmdbv3api/tmdb.py @@ -17,6 +17,7 @@ logger = logging.getLogger(__name__) class TMDb(object): + _RESPONSE_SNAPSHOT_MARKER = "__mp_tmdb_response_snapshot__" def __init__(self, session=None, language=None): self._api_key = settings.TMDB_API_KEY @@ -114,7 +115,7 @@ class TMDb(object): req = self._req.post_res(url, data=data, json=json) if req is None: raise TMDbException("无法连接TheMovieDb,请检查网络连接!") - return req + return self._snapshot_response(req) @cached(maxsize=settings.CONF.tmdb, ttl=settings.CONF.meta, skip_none=True) async def async_request(self, method, url, data, json, **kwargs): @@ -124,7 +125,28 @@ class TMDb(object): req = await self._async_req.post_res(url, data=data, json=json) if req is None: raise TMDbException("无法连接TheMovieDb,请检查网络连接!") - return req + return self._snapshot_response(req) + + @classmethod + def _snapshot_response(cls, response): + # Redis 不能稳定序列化 requests/httpx 响应对象,缓存里只保留当前流程会用到的数据。 + return { + cls._RESPONSE_SNAPSHOT_MARKER: True, + "headers": dict(response.headers.items()), + "json": response.json(), + } + + @classmethod + def _get_response_headers(cls, response): + if isinstance(response, dict) and response.get(cls._RESPONSE_SNAPSHOT_MARKER): + return response.get("headers") or {} + return response.headers + + @classmethod + def _get_response_json(cls, response): + if isinstance(response, dict) and response.get(cls._RESPONSE_SNAPSHOT_MARKER): + return response.get("json") + return response.json() def cache_clear(self): return self.request.cache_clear() @@ -143,11 +165,15 @@ class TMDb(object): ) def _handle_headers(self, headers): - if "X-RateLimit-Remaining" in headers: - self._remaining = int(headers["X-RateLimit-Remaining"]) + normalized_headers = { + str(key).lower(): value for key, value in dict(headers or {}).items() + } - if "X-RateLimit-Reset" in headers: - self._reset = int(headers["X-RateLimit-Reset"]) + if "x-ratelimit-remaining" in normalized_headers: + self._remaining = int(normalized_headers["x-ratelimit-remaining"]) + + if "x-ratelimit-reset" in normalized_headers: + self._reset = int(normalized_headers["x-ratelimit-reset"]) def _handle_rate_limit(self): if self._remaining < 1: @@ -191,7 +217,7 @@ class TMDb(object): if req is None: return None - self._handle_headers(req.headers) + self._handle_headers(self._get_response_headers(req)) rate_limit_result = self._handle_rate_limit() if rate_limit_result: @@ -199,7 +225,7 @@ class TMDb(object): time.sleep(rate_limit_result) return self._request_obj(action, params, False, method, data, json, key) - json_data = req.json() + json_data = self._get_response_json(req) self._process_json_response(json_data, is_async=False) self._handle_errors(json_data) @@ -219,7 +245,7 @@ class TMDb(object): if req is None: return None - self._handle_headers(req.headers) + self._handle_headers(self._get_response_headers(req)) rate_limit_result = self._handle_rate_limit() if rate_limit_result: @@ -227,7 +253,7 @@ class TMDb(object): await asyncio.sleep(rate_limit_result) return await self._async_request_obj(action, params, False, method, data, json, key) - json_data = req.json() + json_data = self._get_response_json(req) self._process_json_response(json_data, is_async=True) self._handle_errors(json_data) diff --git a/tests/test_tmdb_response_cache.py b/tests/test_tmdb_response_cache.py new file mode 100644 index 00000000..6f9900ab --- /dev/null +++ b/tests/test_tmdb_response_cache.py @@ -0,0 +1,182 @@ +import asyncio +import importlib.util +import pickle +import sys +from contextlib import asynccontextmanager, contextmanager +from functools import wraps +from pathlib import Path +from threading import RLock +from types import ModuleType, SimpleNamespace +from unittest import TestCase + + +TMDB_MODULE_NAME = "app.modules.themoviedb.tmdbv3api.tmdb" +TMDB_FILE_PATH = Path(__file__).resolve().parents[1] / "app/modules/themoviedb/tmdbv3api/tmdb.py" + + +def _ensure_package(name: str) -> ModuleType: + module = sys.modules.get(name) + if module is None: + module = ModuleType(name) + module.__path__ = [] + sys.modules[name] = module + return module + + +def _install_tmdb_test_stubs() -> None: + for package_name in [ + "app", + "app.core", + "app.utils", + "app.modules", + "app.modules.themoviedb", + "app.modules.themoviedb.tmdbv3api", + ]: + _ensure_package(package_name) + + cache_module = ModuleType("app.core.cache") + + def cached(*args, **kwargs): + def decorator(func): + if asyncio.iscoroutinefunction(func): + @wraps(func) + async def async_wrapper(*wrapper_args, **wrapper_kwargs): + return await func(*wrapper_args, **wrapper_kwargs) + + return async_wrapper + + @wraps(func) + def wrapper(*wrapper_args, **wrapper_kwargs): + return func(*wrapper_args, **wrapper_kwargs) + + return wrapper + + return decorator + + @contextmanager + def fresh(*args, **kwargs): + yield + + @asynccontextmanager + async def async_fresh(*args, **kwargs): + yield + + cache_module.cached = cached + cache_module.fresh = fresh + cache_module.async_fresh = async_fresh + sys.modules[cache_module.__name__] = cache_module + + config_module = ModuleType("app.core.config") + config_module.settings = SimpleNamespace( + TMDB_API_KEY="dummy-key", + TMDB_LOCALE="en-US", + PROXY=None, + TMDB_API_DOMAIN="example.com", + NORMAL_USER_AGENT="MoviePilot-Test-UA", + CONF=SimpleNamespace(tmdb=8, meta=60), + ) + sys.modules[config_module.__name__] = config_module + + http_module = ModuleType("app.utils.http") + + class RequestUtils: + def __init__(self, *args, **kwargs): + pass + + def get_res(self, *args, **kwargs): # pragma: no cover - 测试中会替换 + raise NotImplementedError + + def post_res(self, *args, **kwargs): # pragma: no cover - 测试中会替换 + raise NotImplementedError + + class AsyncRequestUtils: + def __init__(self, *args, **kwargs): + pass + + async def get_res(self, *args, **kwargs): # pragma: no cover - 测试中会替换 + raise NotImplementedError + + async def post_res(self, *args, **kwargs): # pragma: no cover - 测试中会替换 + raise NotImplementedError + + http_module.RequestUtils = RequestUtils + http_module.AsyncRequestUtils = AsyncRequestUtils + sys.modules[http_module.__name__] = http_module + + exceptions_module = ModuleType("app.modules.themoviedb.tmdbv3api.exceptions") + + class TMDbException(Exception): + pass + + exceptions_module.TMDbException = TMDbException + sys.modules[exceptions_module.__name__] = exceptions_module + + +def _load_tmdb_class(): + _install_tmdb_test_stubs() + sys.modules.pop(TMDB_MODULE_NAME, None) + spec = importlib.util.spec_from_file_location(TMDB_MODULE_NAME, TMDB_FILE_PATH) + module = importlib.util.module_from_spec(spec) + sys.modules[TMDB_MODULE_NAME] = module + assert spec and spec.loader + spec.loader.exec_module(module) + return module.TMDb + + +TMDb = _load_tmdb_class() + + +class _FakeResponse: + def __init__(self, payload: dict, headers: dict): + self._payload = payload + self.headers = headers + self._lock = RLock() + + def json(self): + return self._payload + + +class TmdbResponseCacheTest(TestCase): + def test_request_returns_pickleable_snapshot(self): + tmdb = TMDb() + response = _FakeResponse( + payload={"id": 1, "page": 2}, + headers={"X-RateLimit-Remaining": "39", "X-RateLimit-Reset": "1234567890"}, + ) + tmdb._req.get_res = lambda *args, **kwargs: response + + result = TMDb.request.__wrapped__(tmdb, "GET", "https://example.com", None, None) + + self.assertTrue(result[TMDb._RESPONSE_SNAPSHOT_MARKER]) + self.assertEqual(result["json"], {"id": 1, "page": 2}) + self.assertEqual(result["headers"]["X-RateLimit-Remaining"], "39") + pickle.dumps(result) + + def test_async_request_returns_pickleable_snapshot(self): + tmdb = TMDb() + response = _FakeResponse( + payload={"id": 2, "page": 3}, + headers={"x-ratelimit-remaining": "38", "x-ratelimit-reset": "1234567891"}, + ) + + async def _fake_get_res(*args, **kwargs): + return response + + tmdb._async_req.get_res = _fake_get_res + + result = asyncio.run( + TMDb.async_request.__wrapped__(tmdb, "GET", "https://example.com", None, None) + ) + + self.assertTrue(result[TMDb._RESPONSE_SNAPSHOT_MARKER]) + self.assertEqual(result["json"], {"id": 2, "page": 3}) + self.assertEqual(result["headers"]["x-ratelimit-remaining"], "38") + pickle.dumps(result) + + def test_handle_headers_accepts_snapshot_headers(self): + tmdb = TMDb() + + tmdb._handle_headers({"x-ratelimit-remaining": "7", "x-ratelimit-reset": "99"}) + + self.assertEqual(tmdb._remaining, 7) + self.assertEqual(tmdb._reset, 99)