diff --git a/app/agent/__init__.py b/app/agent/__init__.py index c267a0b6..e3315a28 100644 --- a/app/agent/__init__.py +++ b/app/agent/__init__.py @@ -4,7 +4,7 @@ import re import traceback import uuid from dataclasses import dataclass -from datetime import datetime +from datetime import datetime, timedelta from enum import Enum from typing import Any, Callable, Dict, List, Optional @@ -966,6 +966,11 @@ class AgentManager: self._session_queues: Dict[str, asyncio.Queue] = {} # 每个会话的worker任务 self._session_workers: Dict[str, asyncio.Task] = {} + # 每个会话最后活动时间,用于回收空闲 Agent 实例 + self._session_last_used: Dict[str, tuple[str, datetime]] = {} + self._idle_cleanup_task: Optional[asyncio.Task] = None + self._idle_session_ttl = timedelta(hours=24) + self._idle_cleanup_interval = 60 * 60 def get_session_status(self, session_id: str) -> dict[str, Any]: """获取会话当前模型与 token 使用状态。""" @@ -998,33 +1003,85 @@ class AgentManager: ) return status - @staticmethod - async def initialize(): + async def initialize(self): """ 初始化管理器 """ memory_manager.initialize() + if self._idle_cleanup_task and not self._idle_cleanup_task.done(): + return + self._idle_cleanup_task = asyncio.create_task(self._cleanup_idle_sessions()) async def close(self): """ 关闭管理器 """ + if self._idle_cleanup_task: + self._idle_cleanup_task.cancel() + try: + await self._idle_cleanup_task + except asyncio.CancelledError: + pass + self._idle_cleanup_task = None await memory_manager.close() # 取消所有会话worker - for task in self._session_workers.values(): + for task in list(self._session_workers.values()): task.cancel() # 等待所有worker结束 - for session_id, task in self._session_workers.items(): + for session_id, task in list(self._session_workers.items()): try: await task except asyncio.CancelledError: pass self._session_workers.clear() self._session_queues.clear() - for agent in self.active_agents.values(): + self._session_last_used.clear() + for agent in list(self.active_agents.values()): await agent.cleanup() self.active_agents.clear() + def _record_session_activity(self, session_id: str, user_id: str) -> None: + """ + 记录会话最近活动时间,供空闲会话清理任务判断是否可释放资源。 + """ + self._session_last_used[session_id] = (user_id, datetime.now()) + + def _is_session_busy(self, session_id: str) -> bool: + """ + 判断会话是否仍有正在执行的 worker 或待处理消息,避免误清理活跃会话。 + """ + worker = self._session_workers.get(session_id) + if worker and not worker.done(): + return True + queue = self._session_queues.get(session_id) + return bool(queue and not queue.empty()) + + def _expired_idle_sessions(self) -> list[tuple[str, str]]: + """ + 收集已经超过空闲时间且当前不忙的会话。 + """ + expire_before = datetime.now() - self._idle_session_ttl + expired = [] + for session_id, (user_id, last_used) in list(self._session_last_used.items()): + if last_used < expire_before and not self._is_session_busy(session_id): + expired.append((session_id, user_id)) + return expired + + async def _cleanup_idle_sessions(self) -> None: + """ + 周期性清理长时间没有新消息的 Agent 会话,避免长期运行后实例持续累积。 + """ + while True: + try: + await asyncio.sleep(self._idle_cleanup_interval) + for session_id, user_id in self._expired_idle_sessions(): + await self.clear_session(session_id=session_id, user_id=user_id) + logger.info(f"已清理空闲Agent会话: session_id={session_id}") + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"清理空闲Agent会话失败: {e}") + async def process_message( self, session_id: str, @@ -1056,6 +1113,7 @@ class AgentManager: original_chat_id=original_chat_id, reply_mode=reply_mode, ) + self._record_session_activity(session_id, user_id) # 获取或创建会话队列 if session_id not in self._session_queues: @@ -1221,6 +1279,7 @@ class AgentManager: """ 清空会话 """ + self._session_last_used.pop(session_id, None) # 取消该会话的worker if session_id in self._session_workers: self._session_workers[session_id].cancel() @@ -1228,7 +1287,7 @@ class AgentManager: await self._session_workers[session_id] except asyncio.CancelledError: pass - await self._session_workers.pop(session_id, None) + self._session_workers.pop(session_id, None) # 清理队列 self._session_queues.pop(session_id, None) diff --git a/app/agent/llm/provider.py b/app/agent/llm/provider.py index e67a852b..963165b2 100644 --- a/app/agent/llm/provider.py +++ b/app/agent/llm/provider.py @@ -105,6 +105,7 @@ class LLMProviderManager(metaclass=Singleton): _MODELS_DEV_URL = "https://models.dev/api.json" _MODELS_DEV_BUNDLED_PATH = Path(__file__).with_name("models.json") _MODELS_DEV_CACHE_TTL = 7 * 24 * 60 * 60 + _AUTH_SESSION_DONE_RETENTION = 300 _CHATGPT_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann" _CHATGPT_ISSUER = "https://auth.openai.com" _CHATGPT_CODEX_BASE_URL = "https://chatgpt.com/backend-api/codex" @@ -183,6 +184,33 @@ class LLMProviderManager(metaclass=Singleton): Path(settings.TEMP_PATH) / "llm_provider_models_dev_cache.json" ) + def _cleanup_auth_sessions_locked(self, now: Optional[float] = None) -> None: + """ + 清理过期或已完成一段时间的临时授权会话。 + + 调用方必须已经持有 `_lock`,这样 `_pending_sessions` 与 + `_oauth_state_index` 能保持一致,避免 state 残留。 + """ + now = time.time() if now is None else now + expired_session_ids = [] + for session_id, session in self._pending_sessions.items(): + expires_at = session.expires_at or session.created_at + 600 + if session.status == "pending": + if expires_at <= now: + expired_session_ids.append(session_id) + elif expires_at + self._AUTH_SESSION_DONE_RETENTION <= now: + expired_session_ids.append(session_id) + + if not expired_session_ids: + return + + expired_session_ids_set = set(expired_session_ids) + for session_id in expired_session_ids: + self._pending_sessions.pop(session_id, None) + for state, session_id in list(self._oauth_state_index.items()): + if session_id in expired_session_ids_set: + self._oauth_state_index.pop(state, None) + @staticmethod def _builtin_provider_specs() -> tuple[ProviderSpec, ...]: """ @@ -2001,6 +2029,7 @@ class LLMProviderManager(metaclass=Singleton): } ) with self._lock: + self._cleanup_auth_sessions_locked() self._pending_sessions[session.session_id] = session self._oauth_state_index[state] = session.session_id return { @@ -2035,6 +2064,7 @@ class LLMProviderManager(metaclass=Singleton): } ) with self._lock: + self._cleanup_auth_sessions_locked() self._pending_sessions[session.session_id] = session return { "session_id": session.session_id, @@ -2073,6 +2103,7 @@ class LLMProviderManager(metaclass=Singleton): } ) with self._lock: + self._cleanup_auth_sessions_locked() self._pending_sessions[session.session_id] = session return { "session_id": session.session_id, @@ -2089,6 +2120,7 @@ class LLMProviderManager(metaclass=Singleton): def get_session_status(self, session_id: str) -> dict[str, Any]: """读取临时授权会话状态。""" with self._lock: + self._cleanup_auth_sessions_locked() session = self._pending_sessions.get(session_id) if not session: raise LLMProviderAuthError("授权会话不存在或已过期") @@ -2135,6 +2167,7 @@ class LLMProviderManager(metaclass=Singleton): if error: message = error_description or error with self._lock: + self._cleanup_auth_sessions_locked() session_id = self._oauth_state_index.pop(state or "", None) if session_id and session_id in self._pending_sessions: self._mark_session_error(self._pending_sessions[session_id], message) @@ -2144,6 +2177,7 @@ class LLMProviderManager(metaclass=Singleton): return False, "缺少授权码或 state 参数" with self._lock: + self._cleanup_auth_sessions_locked() session_id = self._oauth_state_index.pop(state, None) session = self._pending_sessions.get(session_id or "") @@ -2186,6 +2220,7 @@ class LLMProviderManager(metaclass=Singleton): 前端可按 interval_seconds 轮询,直到状态变为 authorized / failed。 """ with self._lock: + self._cleanup_auth_sessions_locked() session = self._pending_sessions.get(session_id) if not session: raise LLMProviderAuthError("授权会话不存在或已过期") diff --git a/app/agent/memory/__init__.py b/app/agent/memory/__init__.py index 3fba4f6e..ad9782b8 100644 --- a/app/agent/memory/__init__.py +++ b/app/agent/memory/__init__.py @@ -27,6 +27,8 @@ class MemoryManager: 初始化记忆管理器 """ try: + if self.cleanup_task and not self.cleanup_task.done(): + return # 启动内存缓存清理任务(Redis通过TTL自动过期) self.cleanup_task = asyncio.create_task( self._cleanup_expired_memories() @@ -46,6 +48,7 @@ class MemoryManager: await self.cleanup_task except asyncio.CancelledError: pass + self.cleanup_task = None logger.info("对话记忆管理器已关闭") diff --git a/app/api/endpoints/message.py b/app/api/endpoints/message.py index 47396f33..c0e29f93 100644 --- a/app/api/endpoints/message.py +++ b/app/api/endpoints/message.py @@ -15,6 +15,7 @@ from app.db.models import User from app.db.models.message import Message from app.db.user_oper import get_current_active_superuser from app.helper.service import ServiceConfigHelper +from app.helper.webpush import is_webpush_subscription_gone from app.log import logger from app.modules.wechat.WXBizMsgCrypt3 import WXBizMsgCrypt from app.schemas.types import MessageChannel @@ -218,8 +219,7 @@ async def subscribe( 客户端webpush通知订阅 """ subinfo = subscription.model_dump() - if subinfo not in global_vars.get_subscriptions(): - global_vars.push_subscription(subinfo) + global_vars.push_subscription(subinfo) logger.debug(f"通知订阅成功: {subinfo}") return schemas.Response(success=True) @@ -244,5 +244,7 @@ def send_notification( ) except WebPushException as err: logger.error(f"WebPush发送失败: {str(err)}") + if is_webpush_subscription_gone(err) and global_vars.remove_subscription(sub): + logger.info(f"已移除失效WebPush订阅: {sub.get('endpoint')}") continue return schemas.Response(success=True) diff --git a/app/chain/media.py b/app/chain/media.py index 5e17a716..956cf6cc 100644 --- a/app/chain/media.py +++ b/app/chain/media.py @@ -84,13 +84,12 @@ class ScrapingOption: class ScrapingConfig: """媒体刮削配置""" - _policies: dict[tuple[str], ScrapingOption] = {} - def __init__(self, config_dict: dict[str, str] = None): """ 初始化配置对象 :param config_dict: 用户配置字典(扁平化格式),为 None 时使用默认配置 """ + self._policies: dict[tuple[str, str], ScrapingOption] = {} # 合并用户配置和默认配置 if config_dict is None: config_dict = {} diff --git a/app/chain/message.py b/app/chain/message.py index 6b12d9ec..d3825823 100644 --- a/app/chain/message.py +++ b/app/chain/message.py @@ -47,6 +47,36 @@ class MessageChain(ChainBase): # 会话超时时间(分钟) _session_timeout_minutes: int = 24 * 60 + @staticmethod + def _schedule_agent_session_clear(session_id: str, userid: Union[str, int]) -> None: + """ + 异步调度 Agent 会话清理,避免同步消息链阻塞在模型资源释放上。 + """ + if not session_id: + return + clear_task = None + try: + clear_task = agent_manager.clear_session(session_id=session_id, user_id=str(userid)) + asyncio.run_coroutine_threadsafe( + clear_task, + global_vars.loop, + ) + except Exception as e: + if clear_task: + clear_task.close() + logger.warning(f"调度清理智能体会话失败: {e}") + + def _cleanup_expired_user_sessions(self, current_time: datetime) -> None: + """ + 清理超过复用窗口的用户会话映射,并同步释放旧 Agent 实例。 + """ + timeout = timedelta(minutes=self._session_timeout_minutes) + for userid, (session_id, last_time) in list(self._user_sessions.items()): + if current_time - last_time <= timeout: + continue + self._user_sessions.pop(userid, None) + self._schedule_agent_session_clear(session_id, userid) + @dataclass class _ProcessingStatus: channel: MessageChannel @@ -919,6 +949,7 @@ class MessageChain(ChainBase): 如果用户上次会话在15分钟内,则复用相同的会话ID;否则创建新的会话ID """ current_time = datetime.now() + self._cleanup_expired_user_sessions(current_time) # 检查用户是否有已存在的会话 if userid in self._user_sessions: @@ -946,6 +977,9 @@ class MessageChain(ChainBase): """ 将用户会话绑定到指定的 session_id,并刷新最后活动时间。 """ + old_session = self._user_sessions.get(userid) + if old_session and old_session[0] != session_id: + self._schedule_agent_session_clear(old_session[0], userid) self._user_sessions[userid] = (session_id, datetime.now()) def _record_user_message( @@ -1005,14 +1039,18 @@ class MessageChain(ChainBase): # 如果有会话ID,同时清除智能体的会话记忆 if session_id: + clear_task = None try: + clear_task = agent_manager.clear_session( + session_id=session_id, user_id=str(userid) + ) asyncio.run_coroutine_threadsafe( - agent_manager.clear_session( - session_id=session_id, user_id=str(userid) - ), + clear_task, global_vars.loop, ) except Exception as e: + if clear_task: + clear_task.close() logger.warning(f"清除智能体会话记忆失败: {e}") self.post_message( diff --git a/app/core/config.py b/app/core/config.py index b31a0544..832486d4 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -1130,6 +1130,8 @@ class GlobalVar(object): STOP_EVENT: threading.Event = threading.Event() # webpush订阅 SUBSCRIPTIONS: List[dict] = [] + # webpush订阅读写锁 + SUBSCRIPTIONS_LOCK: threading.Lock = threading.Lock() # 需应急停止的工作流 EMERGENCY_STOP_WORKFLOWS: List[int] = [] # 需应急停止文件整理 @@ -1169,13 +1171,37 @@ class GlobalVar(object): """ 获取webpush订阅 """ - return self.SUBSCRIPTIONS + with self.SUBSCRIPTIONS_LOCK: + return list(self.SUBSCRIPTIONS) def push_subscription(self, subscription: dict): """ - 添加webpush订阅 + 添加或更新webpush订阅。 """ - self.SUBSCRIPTIONS.append(subscription) + endpoint = subscription.get("endpoint") if subscription else None + if not endpoint: + return + with self.SUBSCRIPTIONS_LOCK: + for index, current in enumerate(self.SUBSCRIPTIONS): + if current.get("endpoint") == endpoint: + self.SUBSCRIPTIONS[index] = subscription + return + self.SUBSCRIPTIONS.append(subscription) + + def remove_subscription(self, subscription: dict) -> bool: + """ + 根据 endpoint 移除webpush订阅,返回是否实际删除。 + """ + endpoint = subscription.get("endpoint") if subscription else None + if not endpoint: + return False + with self.SUBSCRIPTIONS_LOCK: + before_count = len(self.SUBSCRIPTIONS) + self.SUBSCRIPTIONS[:] = [ + current for current in self.SUBSCRIPTIONS + if current.get("endpoint") != endpoint + ] + return len(self.SUBSCRIPTIONS) != before_count def stop_workflow(self, workflow_id: int): """ diff --git a/app/helper/webpush.py b/app/helper/webpush.py new file mode 100644 index 00000000..c45df3bb --- /dev/null +++ b/app/helper/webpush.py @@ -0,0 +1,12 @@ +from typing import Any + +from pywebpush import WebPushException + + +def is_webpush_subscription_gone(error: WebPushException) -> bool: + """ + 判断 WebPush 订阅是否已经在浏览器或推送服务侧失效。 + """ + response: Any = getattr(error, "response", None) + status_code = getattr(response, "status_code", None) or getattr(response, "status", None) + return status_code in {404, 410} diff --git a/app/modules/filemanager/storages/rclone.py b/app/modules/filemanager/storages/rclone.py index c4d34ade..5613ea68 100644 --- a/app/modules/filemanager/storages/rclone.py +++ b/app/modules/filemanager/storages/rclone.py @@ -2,6 +2,7 @@ import json import subprocess import threading import time +from collections import OrderedDict from pathlib import Path from typing import Optional, List, Union @@ -13,10 +14,24 @@ from app.schemas.types import StorageSchema from app.utils.string import StringUtils from app.utils.system import SystemUtils -_folder_locks: dict[str, threading.Lock] = {} +_MAX_FOLDER_LOCKS = 4096 +_folder_locks: OrderedDict[str, threading.Lock] = OrderedDict() _folder_locks_guard = threading.Lock() +def _evict_unused_folder_locks_locked() -> None: + """ + 在持有全局锁表互斥锁时淘汰旧路径锁,避免大量不同目录导致锁表无限增长。 + """ + while len(_folder_locks) >= _MAX_FOLDER_LOCKS: + for key, lock in list(_folder_locks.items()): + if not lock.locked(): + _folder_locks.pop(key, None) + break + else: + break + + class Rclone(StorageBase): """ rclone相关操作 @@ -144,9 +159,14 @@ class Rclone(StorageBase): """ normalized = Rclone.__normalize_remote_path(path) with _folder_locks_guard: - if normalized not in _folder_locks: - _folder_locks[normalized] = threading.Lock() - return _folder_locks[normalized] + lock = _folder_locks.get(normalized) + if lock: + _folder_locks.move_to_end(normalized) + return lock + _evict_unused_folder_locks_locked() + lock = threading.Lock() + _folder_locks[normalized] = lock + return lock def __wait_for_item( self, path: Path, retries: int = 3, delay: float = 0.2 diff --git a/app/modules/webpush/__init__.py b/app/modules/webpush/__init__.py index 3f06c2fa..3301d6e7 100644 --- a/app/modules/webpush/__init__.py +++ b/app/modules/webpush/__init__.py @@ -4,6 +4,7 @@ from typing import Union, Tuple from pywebpush import webpush, WebPushException from app.core.config import global_vars, settings +from app.helper.webpush import is_webpush_subscription_gone from app.log import logger from app.modules import _ModuleBase, _MessageBase from app.schemas import Notification @@ -97,6 +98,8 @@ class WebPushModule(_ModuleBase, _MessageBase): ) except WebPushException as err: logger.error(f"WebPush发送失败: {str(err)}") + if is_webpush_subscription_gone(err) and global_vars.remove_subscription(sub): + logger.info(f"已移除失效WebPush订阅: {sub.get('endpoint')}") except Exception as msg_e: logger.error(f"发送消息失败:{msg_e}") diff --git a/tests/test_agent_session_status.py b/tests/test_agent_session_status.py index 0fae74bb..338967ee 100644 --- a/tests/test_agent_session_status.py +++ b/tests/test_agent_session_status.py @@ -1,6 +1,6 @@ import asyncio import unittest -from datetime import datetime +from datetime import datetime, timedelta from types import SimpleNamespace from unittest.mock import patch @@ -8,11 +8,20 @@ from langchain.agents.middleware.types import ModelRequest, ModelResponse from langchain_core.messages import AIMessage from app.agent.middleware.usage import UsageMiddleware +from app.agent import AgentManager from app.chain.message import MessageChain from app.schemas.types import MessageChannel class TestAgentSessionStatus(unittest.TestCase): + def setUp(self): + """清理跨用例共享的用户会话状态。""" + MessageChain._user_sessions.clear() + + def tearDown(self): + """清理测试产生的用户会话状态。""" + MessageChain._user_sessions.clear() + def test_usage_middleware_records_usage_metadata(self): snapshots = [] middleware = UsageMiddleware(on_usage=snapshots.append) @@ -104,3 +113,34 @@ class TestAgentSessionStatus(unittest.TestCase): notification = post_message.call_args.args[0] self.assertEqual(notification.title, "您当前没有活跃的智能体会话") + + def test_get_or_create_session_cleans_expired_session(self): + """用户会话超过复用窗口时应调度清理旧 Agent 会话。""" + chain = MessageChain() + chain._user_sessions.clear() + chain._user_sessions["10001"] = ( + "old-session", + datetime.now() - timedelta(minutes=chain._session_timeout_minutes + 1), + ) + + with patch.object(chain, "_schedule_agent_session_clear") as clear_session: + session_id = chain._get_or_create_session_id("10001") + + self.assertNotEqual(session_id, "old-session") + self.assertEqual(chain._user_sessions["10001"][0], session_id) + clear_session.assert_called_once_with("old-session", "10001") + + def test_agent_manager_collects_idle_sessions(self): + """Agent 管理器应只回收超过空闲窗口且未忙碌的会话。""" + manager = AgentManager() + manager._idle_session_ttl = timedelta(seconds=1) + manager._session_last_used["idle-session"] = ( + "10001", + datetime.now() - timedelta(seconds=2), + ) + manager._session_last_used["fresh-session"] = ("10002", datetime.now()) + + self.assertEqual( + [("idle-session", "10001")], + manager._expired_idle_sessions(), + ) diff --git a/tests/test_llm_provider_registry.py b/tests/test_llm_provider_registry.py index 30b41830..340a3ca6 100644 --- a/tests/test_llm_provider_registry.py +++ b/tests/test_llm_provider_registry.py @@ -39,11 +39,23 @@ _stub_module( TEMP_PATH="/tmp", PROXY_HOST=None, LLM_MAX_CONTEXT_TOKENS=64, + RCLONE_SNAPSHOT_CHECK_FOLDER_MODTIME=True, + RMT_MEDIAEXT=[".mkv", ".mp4"], + RMT_SUBEXT=[".srt"], + RMT_AUDIOEXT=[".flac"], ), ) _stub_module("app.db.systemconfig_oper", SystemConfigOper=_DummySystemConfigOper) _stub_module("app.log", logger=_DummyLogger()) -_stub_module("app.schemas.types", SystemConfigKey=SimpleNamespace(AIAgentConfig="agent")) +_stub_module( + "app.schemas.types", + SystemConfigKey=SimpleNamespace( + AIAgentConfig="agent", + CustomReleaseGroups="custom_release_groups", + Customization="customization", + CustomIdentifiers="custom_identifiers", + ), +) provider_path = Path(__file__).resolve().parents[1] / "app" / "agent" / "llm" / "provider.py" spec = importlib.util.spec_from_file_location("test_llm_provider_module", provider_path) @@ -54,6 +66,7 @@ spec.loader.exec_module(provider_module) LLMProviderError = provider_module.LLMProviderError LLMProviderManager = provider_module.LLMProviderManager +PendingAuthSession = provider_module.PendingAuthSession class LlmProviderRegistryTest(unittest.TestCase): @@ -612,6 +625,24 @@ class LlmProviderRegistryTest(unittest.TestCase): self.assertEqual(models, []) + def test_expired_auth_session_cleanup_removes_state_index(self): + """过期授权会话应同时移除 session 与 OAuth state 索引。""" + manager = LLMProviderManager() + manager._pending_sessions["session-old"] = PendingAuthSession( + session_id="session-old", + provider_id="chatgpt", + method_id="browser_oauth", + flow_type="oauth", + expires_at=100, + ) + manager._oauth_state_index["state-old"] = "session-old" + + with manager._lock: + manager._cleanup_auth_sessions_locked(now=101) + + self.assertNotIn("session-old", manager._pending_sessions) + self.assertNotIn("state-old", manager._oauth_state_index) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_mediascrape.py b/tests/test_mediascrape.py index a80f237d..70e0ed74 100644 --- a/tests/test_mediascrape.py +++ b/tests/test_mediascrape.py @@ -8,7 +8,7 @@ sys.modules['app.db.systemconfig_oper'] = MagicMock() sys.modules['app.db.systemconfig_oper'].SystemConfigOper.return_value.get.return_value = None from app import schemas -from app.chain.media import MediaChain, ScrapingOption +from app.chain.media import MediaChain, ScrapingConfig, ScrapingOption from app.core.context import MediaInfo from app.core.event import Event from app.core.metainfo import MetaInfo @@ -42,6 +42,20 @@ class TestMediaScrapingPaths(unittest.TestCase): self.assertEqual(target_item, parent_item) self.assertEqual(target_path, Path("/movies/avatar.nfo")) + def test_scraping_config_does_not_share_policy_state_between_instances(self): + """刮削配置实例之间不应共享已删除或覆盖过的策略。""" + first_config = ScrapingConfig({"movie_nfo": ScrapingPolicy.SKIP}) + second_config = ScrapingConfig({}) + + self.assertEqual( + ScrapingPolicy.SKIP, + first_config.option(ScrapingTarget.MOVIE, ScrapingMetadata.NFO).policy, + ) + self.assertEqual( + ScrapingPolicy.MISSINGONLY, + second_config.option(ScrapingTarget.MOVIE, ScrapingMetadata.NFO).policy, + ) + def test_movie_dir_nfo_path(self): fileitem = schemas.FileItem(path="/movies/Avatar (2009)", name="Avatar (2009)", type="dir", storage="local") diff --git a/tests/test_rclone_storage.py b/tests/test_rclone_storage.py index a7474e2d..dcb97032 100644 --- a/tests/test_rclone_storage.py +++ b/tests/test_rclone_storage.py @@ -200,6 +200,19 @@ class RcloneStorageTest(unittest.TestCase): self.assertEqual("/Show/", folder.path) run_mock.assert_called_once() + def test_folder_lock_table_evicts_old_unlocked_paths(self): + """路径锁表超过上限时应优先淘汰未占用的旧锁。""" + with patch.object(rclone_module, "_MAX_FOLDER_LOCKS", 2): + first_lock = Rclone._Rclone__get_path_lock(Path("/A")) + second_lock = Rclone._Rclone__get_path_lock(Path("/B")) + third_lock = Rclone._Rclone__get_path_lock(Path("/C")) + + self.assertNotIn("/A", rclone_module._folder_locks) + self.assertIn("/B", rclone_module._folder_locks) + self.assertIn("/C", rclone_module._folder_locks) + self.assertIsNot(first_lock, third_lock) + self.assertIs(second_lock, rclone_module._folder_locks["/B"]) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_webpush_subscription.py b/tests/test_webpush_subscription.py new file mode 100644 index 00000000..b0a7503d --- /dev/null +++ b/tests/test_webpush_subscription.py @@ -0,0 +1,61 @@ +import unittest +from types import SimpleNamespace + +from app.core.config import global_vars +from app.helper.webpush import is_webpush_subscription_gone + + +class WebPushSubscriptionTest(unittest.TestCase): + def setUp(self): + """清理跨用例共享的 WebPush 订阅。""" + with global_vars.SUBSCRIPTIONS_LOCK: + global_vars.SUBSCRIPTIONS.clear() + + def tearDown(self): + """清理测试产生的 WebPush 订阅。""" + with global_vars.SUBSCRIPTIONS_LOCK: + global_vars.SUBSCRIPTIONS.clear() + + def test_push_subscription_upserts_by_endpoint(self): + """相同 endpoint 的 WebPush 订阅应更新而不是重复追加。""" + global_vars.push_subscription( + {"endpoint": "https://push.example/a", "keys": {"p256dh": "old"}} + ) + global_vars.push_subscription( + {"endpoint": "https://push.example/a", "keys": {"p256dh": "new"}} + ) + + subscriptions = global_vars.get_subscriptions() + + self.assertEqual(1, len(subscriptions)) + self.assertEqual("new", subscriptions[0]["keys"]["p256dh"]) + + def test_remove_subscription_deletes_by_endpoint(self): + """失效订阅应能按 endpoint 从全局订阅表删除。""" + subscription = {"endpoint": "https://push.example/a", "keys": {}} + global_vars.push_subscription(subscription) + + self.assertTrue(global_vars.remove_subscription(subscription)) + self.assertEqual([], global_vars.get_subscriptions()) + + def test_is_webpush_subscription_gone_matches_404_and_410(self): + """推送服务返回 404/410 时应识别为订阅已失效。""" + self.assertTrue( + is_webpush_subscription_gone( + SimpleNamespace(response=SimpleNamespace(status_code=410)) + ) + ) + self.assertTrue( + is_webpush_subscription_gone( + SimpleNamespace(response=SimpleNamespace(status=404)) + ) + ) + self.assertFalse( + is_webpush_subscription_gone( + SimpleNamespace(response=SimpleNamespace(status_code=500)) + ) + ) + + +if __name__ == "__main__": + unittest.main()