feat(search): cache and expose last search parameters for replay and context retrieval

- Add methods to save and retrieve last search parameters in SearchChain
- Persist search params alongside results for replayable search context
- Add /last/context endpoint to fetch last search results and parameters
- Update tests to cover search param caching logic
- Allow images.tmdb.org in SECURITY_IMAGE_DOMAINS
This commit is contained in:
jxxghp
2026-05-15 22:43:40 +08:00
parent 1f49f9b454
commit 9b23265c3b
4 changed files with 200 additions and 1 deletions

View File

@@ -144,6 +144,23 @@ async def search_latest(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
return [torrent.to_dict() for torrent in torrents]
@router.get("/last/context", summary="查询上次搜索上下文", response_model=schemas.Response)
async def search_latest_context(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询上次搜索结果及其对应的搜索参数。
"""
search_chain = SearchChain()
torrents = await search_chain.async_last_search_results() or []
params = await search_chain.async_last_search_params() or {}
return schemas.Response(
success=True,
data={
"params": params,
"results": [torrent.to_dict() for torrent in torrents],
},
)
@router.get("/media/{mediaid}/stream", summary="渐进式精确搜索资源")
async def search_by_id_stream(
request: Request,

View File

@@ -33,6 +33,7 @@ class SearchChain(ChainBase):
"""
__result_temp_file = "__search_result__"
__search_params_temp_file = "__search_params__"
__ai_indices_cache_file = "__ai_recommend_indices__"
_ai_recommend_running = False
@@ -121,6 +122,115 @@ class SearchChain(ChainBase):
state._ai_recommend_error = None
self.remove_cache(self.__ai_indices_cache_file)
@staticmethod
def _build_search_keyword(
tmdbid: Optional[int] = None, doubanid: Optional[str] = None
) -> str:
"""
根据媒体ID生成可重放的搜索关键字。
"""
if tmdbid is not None:
return f"tmdb:{tmdbid}"
if doubanid:
return f"douban:{doubanid}"
return ""
@staticmethod
def _stringify_sites(sites: Optional[List[int]]) -> str:
"""
将站点ID列表转换为前端可直接复用的查询字符串。
"""
return ",".join(str(site) for site in sites) if sites else ""
@staticmethod
def _normalize_search_params(params: Optional[Dict[str, Any]]) -> Optional[Dict[str, str]]:
"""
规范化上次搜索参数,供前端结果页重新搜索使用。
"""
if not isinstance(params, dict):
return None
normalized = {
"keyword": str(params.get("keyword") or ""),
"type": str(params.get("type") or ""),
"area": str(params.get("area") or ""),
"title": str(params.get("title") or ""),
"year": str(params.get("year") or ""),
"season": str(params.get("season") or ""),
"sites": str(params.get("sites") or ""),
}
return normalized if normalized["keyword"] else None
def save_last_search_params(
self,
*,
keyword: Optional[str],
mtype: Optional[MediaType] = None,
area: Optional[str] = "title",
title: Optional[str] = None,
year: Optional[str] = None,
season: Optional[int] = None,
sites: Optional[List[int]] = None,
) -> None:
"""
保存最后一次资源搜索参数。
"""
params = self._normalize_search_params(
{
"keyword": keyword,
"type": mtype.value if isinstance(mtype, MediaType) else mtype,
"area": area,
"title": title,
"year": year,
"season": season,
"sites": self._stringify_sites(sites),
}
)
if params:
self.save_cache(params, self.__search_params_temp_file)
async def async_save_last_search_params(
self,
*,
keyword: Optional[str],
mtype: Optional[MediaType] = None,
area: Optional[str] = "title",
title: Optional[str] = None,
year: Optional[str] = None,
season: Optional[int] = None,
sites: Optional[List[int]] = None,
) -> None:
"""
异步保存最后一次资源搜索参数。
"""
params = self._normalize_search_params(
{
"keyword": keyword,
"type": mtype.value if isinstance(mtype, MediaType) else mtype,
"area": area,
"title": title,
"year": year,
"season": season,
"sites": self._stringify_sites(sites),
}
)
if params:
await self.async_save_cache(params, self.__search_params_temp_file)
def last_search_params(self) -> Optional[Dict[str, str]]:
"""
获取上次搜索使用的参数。
"""
return self._normalize_search_params(self.load_cache(self.__search_params_temp_file))
async def async_last_search_params(self) -> Optional[Dict[str, str]]:
"""
异步获取上次搜索使用的参数。
"""
return self._normalize_search_params(
await self.async_load_cache(self.__search_params_temp_file)
)
@staticmethod
def _normalize_ai_indices(ai_indices: List[Any]) -> List[int]:
"""
@@ -337,6 +447,13 @@ class SearchChain(ChainBase):
"""
if cache_local:
self.cancel_ai_recommend()
self.save_last_search_params(
keyword=self._build_search_keyword(tmdbid=tmdbid, doubanid=doubanid),
mtype=mtype,
area=area,
season=season,
sites=sites,
)
mediainfo = self.recognize_media(tmdbid=tmdbid, doubanid=doubanid, mtype=mtype)
if not mediainfo:
logger.error(f'{tmdbid} 媒体信息识别失败!')
@@ -365,6 +482,11 @@ class SearchChain(ChainBase):
"""
if cache_local:
self.cancel_ai_recommend()
self.save_last_search_params(
keyword=title,
area="title",
sites=sites,
)
if title:
logger.info(f'开始搜索资源,关键词:{title} ...')
else:
@@ -414,6 +536,13 @@ class SearchChain(ChainBase):
"""
if cache_local:
self.cancel_ai_recommend()
await self.async_save_last_search_params(
keyword=self._build_search_keyword(tmdbid=tmdbid, doubanid=doubanid),
mtype=mtype,
area=area,
season=season,
sites=sites,
)
mediainfo = await self.async_recognize_media(tmdbid=tmdbid, doubanid=doubanid, mtype=mtype)
if not mediainfo:
logger.error(f'{tmdbid} 媒体信息识别失败!')
@@ -442,6 +571,11 @@ class SearchChain(ChainBase):
"""
if cache_local:
self.cancel_ai_recommend()
await self.async_save_last_search_params(
keyword=title,
area="title",
sites=sites,
)
if title:
logger.info(f'开始搜索资源,关键词:{title} ...')
else:
@@ -472,6 +606,11 @@ class SearchChain(ChainBase):
"""
if cache_local:
self.cancel_ai_recommend()
await self.async_save_last_search_params(
keyword=title,
area="title",
sites=sites,
)
if title:
logger.info(f'开始渐进式搜索资源,关键词:{title} ...')
else:
@@ -518,6 +657,13 @@ class SearchChain(ChainBase):
"""
if cache_local:
self.cancel_ai_recommend()
await self.async_save_last_search_params(
keyword=self._build_search_keyword(tmdbid=tmdbid, doubanid=doubanid),
mtype=mtype,
area=area,
season=season,
sites=sites,
)
mediainfo = await self.async_recognize_media(tmdbid=tmdbid, doubanid=doubanid, mtype=mtype)
if not mediainfo:
logger.error(f'{tmdbid} 媒体信息识别失败!')

View File

@@ -486,6 +486,7 @@ class ConfigModel(BaseModel):
SECURITY_IMAGE_DOMAINS: list = Field(
default=[
"image.tmdb.org",
"images.tmdb.org",
"static-mdb.v.geilijiasu.com",
"bing.com",
"doubanio.com",

View File

@@ -28,6 +28,7 @@ from app.agent.tools.factory import MoviePilotToolFactory
from app.agent import ReplyMode
from app.chain.search import SearchChain
from app.core.config import settings
from app.schemas.types import MediaType
def _make_result(title: str, size: int, seeders: int):
@@ -153,11 +154,45 @@ class SearchChainAIRecommendTest(unittest.IsolatedAsyncioTestCase):
self.assertEqual(1, len(results))
self.assertEqual(["__ai_recommend_indices__"], removed)
self.assertEqual("__search_result__", cached[0][0])
self.assertTrue(any(filename == "__search_result__" for filename, _ in cached))
self.assertTrue(any(filename == "__search_params__" for filename, _ in cached))
self.assertIsNone(SearchChain._current_recommend_request_hash)
self.assertIsNone(SearchChain._ai_recommend_result)
self.assertIsNone(SearchChain._ai_recommend_error)
def test_search_by_id_caches_replayable_search_params_when_caching(self):
chain = self._make_chain()
cached = []
chain.save_cache = lambda cache, filename: cached.append((filename, cache))
chain.recognize_media = lambda **_kwargs: SimpleNamespace(title="Test")
chain.process = lambda **_kwargs: [SimpleNamespace(title="Result")]
chain.search_by_id(
tmdbid=123,
mtype=MediaType.MOVIE,
area="title",
season=2,
sites=[1, 3],
cache_local=True,
)
self.assertIn(
(
"__search_params__",
{
"keyword": "tmdb:123",
"type": "电影",
"area": "title",
"title": "",
"year": "",
"season": "2",
"sites": "1,3",
},
),
cached,
)
self.assertTrue(any(filename == "__search_result__" for filename, _ in cached))
def test_tool_factory_excludes_message_tools_when_disabled(self):
with patch(
"app.agent.tools.factory.PluginManager.get_plugin_agent_tools",