From 9b23265c3b08f6ea218d6e816149dcad2efaaaee Mon Sep 17 00:00:00 2001 From: jxxghp Date: Fri, 15 May 2026 22:43:40 +0800 Subject: [PATCH] feat(search): cache and expose last search parameters for replay and context retrieval - Add methods to save and retrieve last search parameters in SearchChain - Persist search params alongside results for replayable search context - Add /last/context endpoint to fetch last search results and parameters - Update tests to cover search param caching logic - Allow images.tmdb.org in SECURITY_IMAGE_DOMAINS --- app/api/endpoints/search.py | 17 ++++ app/chain/search.py | 146 ++++++++++++++++++++++++++++++ app/core/config.py | 1 + tests/test_search_ai_recommend.py | 37 +++++++- 4 files changed, 200 insertions(+), 1 deletion(-) diff --git a/app/api/endpoints/search.py b/app/api/endpoints/search.py index d082edbe..d89329c1 100644 --- a/app/api/endpoints/search.py +++ b/app/api/endpoints/search.py @@ -144,6 +144,23 @@ async def search_latest(_: schemas.TokenPayload = Depends(verify_token)) -> Any: return [torrent.to_dict() for torrent in torrents] +@router.get("/last/context", summary="查询上次搜索上下文", response_model=schemas.Response) +async def search_latest_context(_: schemas.TokenPayload = Depends(verify_token)) -> Any: + """ + 查询上次搜索结果及其对应的搜索参数。 + """ + search_chain = SearchChain() + torrents = await search_chain.async_last_search_results() or [] + params = await search_chain.async_last_search_params() or {} + return schemas.Response( + success=True, + data={ + "params": params, + "results": [torrent.to_dict() for torrent in torrents], + }, + ) + + @router.get("/media/{mediaid}/stream", summary="渐进式精确搜索资源") async def search_by_id_stream( request: Request, diff --git a/app/chain/search.py b/app/chain/search.py index fcf01059..4298e628 100644 --- a/app/chain/search.py +++ b/app/chain/search.py @@ -33,6 +33,7 @@ class SearchChain(ChainBase): """ __result_temp_file = "__search_result__" + __search_params_temp_file = "__search_params__" __ai_indices_cache_file = "__ai_recommend_indices__" _ai_recommend_running = False @@ -121,6 +122,115 @@ class SearchChain(ChainBase): state._ai_recommend_error = None self.remove_cache(self.__ai_indices_cache_file) + @staticmethod + def _build_search_keyword( + tmdbid: Optional[int] = None, doubanid: Optional[str] = None + ) -> str: + """ + 根据媒体ID生成可重放的搜索关键字。 + """ + if tmdbid is not None: + return f"tmdb:{tmdbid}" + if doubanid: + return f"douban:{doubanid}" + return "" + + @staticmethod + def _stringify_sites(sites: Optional[List[int]]) -> str: + """ + 将站点ID列表转换为前端可直接复用的查询字符串。 + """ + return ",".join(str(site) for site in sites) if sites else "" + + @staticmethod + def _normalize_search_params(params: Optional[Dict[str, Any]]) -> Optional[Dict[str, str]]: + """ + 规范化上次搜索参数,供前端结果页重新搜索使用。 + """ + if not isinstance(params, dict): + return None + + normalized = { + "keyword": str(params.get("keyword") or ""), + "type": str(params.get("type") or ""), + "area": str(params.get("area") or ""), + "title": str(params.get("title") or ""), + "year": str(params.get("year") or ""), + "season": str(params.get("season") or ""), + "sites": str(params.get("sites") or ""), + } + return normalized if normalized["keyword"] else None + + def save_last_search_params( + self, + *, + keyword: Optional[str], + mtype: Optional[MediaType] = None, + area: Optional[str] = "title", + title: Optional[str] = None, + year: Optional[str] = None, + season: Optional[int] = None, + sites: Optional[List[int]] = None, + ) -> None: + """ + 保存最后一次资源搜索参数。 + """ + params = self._normalize_search_params( + { + "keyword": keyword, + "type": mtype.value if isinstance(mtype, MediaType) else mtype, + "area": area, + "title": title, + "year": year, + "season": season, + "sites": self._stringify_sites(sites), + } + ) + if params: + self.save_cache(params, self.__search_params_temp_file) + + async def async_save_last_search_params( + self, + *, + keyword: Optional[str], + mtype: Optional[MediaType] = None, + area: Optional[str] = "title", + title: Optional[str] = None, + year: Optional[str] = None, + season: Optional[int] = None, + sites: Optional[List[int]] = None, + ) -> None: + """ + 异步保存最后一次资源搜索参数。 + """ + params = self._normalize_search_params( + { + "keyword": keyword, + "type": mtype.value if isinstance(mtype, MediaType) else mtype, + "area": area, + "title": title, + "year": year, + "season": season, + "sites": self._stringify_sites(sites), + } + ) + if params: + await self.async_save_cache(params, self.__search_params_temp_file) + + def last_search_params(self) -> Optional[Dict[str, str]]: + """ + 获取上次搜索使用的参数。 + """ + return self._normalize_search_params(self.load_cache(self.__search_params_temp_file)) + + async def async_last_search_params(self) -> Optional[Dict[str, str]]: + """ + 异步获取上次搜索使用的参数。 + """ + return self._normalize_search_params( + await self.async_load_cache(self.__search_params_temp_file) + ) + @staticmethod def _normalize_ai_indices(ai_indices: List[Any]) -> List[int]: """ @@ -337,6 +447,13 @@ class SearchChain(ChainBase): """ if cache_local: self.cancel_ai_recommend() + self.save_last_search_params( + keyword=self._build_search_keyword(tmdbid=tmdbid, doubanid=doubanid), + mtype=mtype, + area=area, + season=season, + sites=sites, + ) mediainfo = self.recognize_media(tmdbid=tmdbid, doubanid=doubanid, mtype=mtype) if not mediainfo: logger.error(f'{tmdbid} 媒体信息识别失败!') @@ -365,6 +482,11 @@ class SearchChain(ChainBase): """ if cache_local: self.cancel_ai_recommend() + self.save_last_search_params( + keyword=title, + area="title", + sites=sites, + ) if title: logger.info(f'开始搜索资源,关键词:{title} ...') else: @@ -414,6 +536,13 @@ class SearchChain(ChainBase): """ if cache_local: self.cancel_ai_recommend() + await self.async_save_last_search_params( + keyword=self._build_search_keyword(tmdbid=tmdbid, doubanid=doubanid), + mtype=mtype, + area=area, + season=season, + sites=sites, + ) mediainfo = await self.async_recognize_media(tmdbid=tmdbid, doubanid=doubanid, mtype=mtype) if not mediainfo: logger.error(f'{tmdbid} 媒体信息识别失败!') @@ -442,6 +571,11 @@ class SearchChain(ChainBase): """ if cache_local: self.cancel_ai_recommend() + await self.async_save_last_search_params( + keyword=title, + area="title", + sites=sites, + ) if title: logger.info(f'开始搜索资源,关键词:{title} ...') else: @@ -472,6 +606,11 @@ class SearchChain(ChainBase): """ if cache_local: self.cancel_ai_recommend() + await self.async_save_last_search_params( + keyword=title, + area="title", + sites=sites, + ) if title: logger.info(f'开始渐进式搜索资源,关键词:{title} ...') else: @@ -518,6 +657,13 @@ class SearchChain(ChainBase): """ if cache_local: self.cancel_ai_recommend() + await self.async_save_last_search_params( + keyword=self._build_search_keyword(tmdbid=tmdbid, doubanid=doubanid), + mtype=mtype, + area=area, + season=season, + sites=sites, + ) mediainfo = await self.async_recognize_media(tmdbid=tmdbid, doubanid=doubanid, mtype=mtype) if not mediainfo: logger.error(f'{tmdbid} 媒体信息识别失败!') diff --git a/app/core/config.py b/app/core/config.py index 381bd043..22f0432a 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -486,6 +486,7 @@ class ConfigModel(BaseModel): SECURITY_IMAGE_DOMAINS: list = Field( default=[ "image.tmdb.org", + "images.tmdb.org", "static-mdb.v.geilijiasu.com", "bing.com", "doubanio.com", diff --git a/tests/test_search_ai_recommend.py b/tests/test_search_ai_recommend.py index 2faad8a8..4271413b 100644 --- a/tests/test_search_ai_recommend.py +++ b/tests/test_search_ai_recommend.py @@ -28,6 +28,7 @@ from app.agent.tools.factory import MoviePilotToolFactory from app.agent import ReplyMode from app.chain.search import SearchChain from app.core.config import settings +from app.schemas.types import MediaType def _make_result(title: str, size: int, seeders: int): @@ -153,11 +154,45 @@ class SearchChainAIRecommendTest(unittest.IsolatedAsyncioTestCase): self.assertEqual(1, len(results)) self.assertEqual(["__ai_recommend_indices__"], removed) - self.assertEqual("__search_result__", cached[0][0]) + self.assertTrue(any(filename == "__search_result__" for filename, _ in cached)) + self.assertTrue(any(filename == "__search_params__" for filename, _ in cached)) self.assertIsNone(SearchChain._current_recommend_request_hash) self.assertIsNone(SearchChain._ai_recommend_result) self.assertIsNone(SearchChain._ai_recommend_error) + def test_search_by_id_caches_replayable_search_params_when_caching(self): + chain = self._make_chain() + cached = [] + chain.save_cache = lambda cache, filename: cached.append((filename, cache)) + chain.recognize_media = lambda **_kwargs: SimpleNamespace(title="Test") + chain.process = lambda **_kwargs: [SimpleNamespace(title="Result")] + + chain.search_by_id( + tmdbid=123, + mtype=MediaType.MOVIE, + area="title", + season=2, + sites=[1, 3], + cache_local=True, + ) + + self.assertIn( + ( + "__search_params__", + { + "keyword": "tmdb:123", + "type": "电影", + "area": "title", + "title": "", + "year": "", + "season": "2", + "sites": "1,3", + }, + ), + cached, + ) + self.assertTrue(any(filename == "__search_result__" for filename, _ in cached)) + def test_tool_factory_excludes_message_tools_when_disabled(self): with patch( "app.agent.tools.factory.PluginManager.get_plugin_agent_tools",