fix: bound long-lived cache state

This commit is contained in:
jxxghp
2026-05-24 18:03:42 +08:00
parent dc73d61682
commit 79539760da
15 changed files with 380 additions and 24 deletions

View File

@@ -4,7 +4,7 @@ import re
import traceback
import uuid
from dataclasses import dataclass
from datetime import datetime
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Callable, Dict, List, Optional
@@ -966,6 +966,11 @@ class AgentManager:
self._session_queues: Dict[str, asyncio.Queue] = {}
# 每个会话的worker任务
self._session_workers: Dict[str, asyncio.Task] = {}
# 每个会话最后活动时间,用于回收空闲 Agent 实例
self._session_last_used: Dict[str, tuple[str, datetime]] = {}
self._idle_cleanup_task: Optional[asyncio.Task] = None
self._idle_session_ttl = timedelta(hours=24)
self._idle_cleanup_interval = 60 * 60
def get_session_status(self, session_id: str) -> dict[str, Any]:
"""获取会话当前模型与 token 使用状态。"""
@@ -998,33 +1003,85 @@ class AgentManager:
)
return status
@staticmethod
async def initialize():
async def initialize(self):
"""
初始化管理器
"""
memory_manager.initialize()
if self._idle_cleanup_task and not self._idle_cleanup_task.done():
return
self._idle_cleanup_task = asyncio.create_task(self._cleanup_idle_sessions())
async def close(self):
"""
关闭管理器
"""
if self._idle_cleanup_task:
self._idle_cleanup_task.cancel()
try:
await self._idle_cleanup_task
except asyncio.CancelledError:
pass
self._idle_cleanup_task = None
await memory_manager.close()
# 取消所有会话worker
for task in self._session_workers.values():
for task in list(self._session_workers.values()):
task.cancel()
# 等待所有worker结束
for session_id, task in self._session_workers.items():
for session_id, task in list(self._session_workers.items()):
try:
await task
except asyncio.CancelledError:
pass
self._session_workers.clear()
self._session_queues.clear()
for agent in self.active_agents.values():
self._session_last_used.clear()
for agent in list(self.active_agents.values()):
await agent.cleanup()
self.active_agents.clear()
def _record_session_activity(self, session_id: str, user_id: str) -> None:
"""
记录会话最近活动时间,供空闲会话清理任务判断是否可释放资源。
"""
self._session_last_used[session_id] = (user_id, datetime.now())
def _is_session_busy(self, session_id: str) -> bool:
"""
判断会话是否仍有正在执行的 worker 或待处理消息,避免误清理活跃会话。
"""
worker = self._session_workers.get(session_id)
if worker and not worker.done():
return True
queue = self._session_queues.get(session_id)
return bool(queue and not queue.empty())
def _expired_idle_sessions(self) -> list[tuple[str, str]]:
"""
收集已经超过空闲时间且当前不忙的会话。
"""
expire_before = datetime.now() - self._idle_session_ttl
expired = []
for session_id, (user_id, last_used) in list(self._session_last_used.items()):
if last_used < expire_before and not self._is_session_busy(session_id):
expired.append((session_id, user_id))
return expired
async def _cleanup_idle_sessions(self) -> None:
"""
周期性清理长时间没有新消息的 Agent 会话,避免长期运行后实例持续累积。
"""
while True:
try:
await asyncio.sleep(self._idle_cleanup_interval)
for session_id, user_id in self._expired_idle_sessions():
await self.clear_session(session_id=session_id, user_id=user_id)
logger.info(f"已清理空闲Agent会话: session_id={session_id}")
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"清理空闲Agent会话失败: {e}")
async def process_message(
self,
session_id: str,
@@ -1056,6 +1113,7 @@ class AgentManager:
original_chat_id=original_chat_id,
reply_mode=reply_mode,
)
self._record_session_activity(session_id, user_id)
# 获取或创建会话队列
if session_id not in self._session_queues:
@@ -1221,6 +1279,7 @@ class AgentManager:
"""
清空会话
"""
self._session_last_used.pop(session_id, None)
# 取消该会话的worker
if session_id in self._session_workers:
self._session_workers[session_id].cancel()
@@ -1228,7 +1287,7 @@ class AgentManager:
await self._session_workers[session_id]
except asyncio.CancelledError:
pass
await self._session_workers.pop(session_id, None)
self._session_workers.pop(session_id, None)
# 清理队列
self._session_queues.pop(session_id, None)

View File

@@ -105,6 +105,7 @@ class LLMProviderManager(metaclass=Singleton):
_MODELS_DEV_URL = "https://models.dev/api.json"
_MODELS_DEV_BUNDLED_PATH = Path(__file__).with_name("models.json")
_MODELS_DEV_CACHE_TTL = 7 * 24 * 60 * 60
_AUTH_SESSION_DONE_RETENTION = 300
_CHATGPT_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
_CHATGPT_ISSUER = "https://auth.openai.com"
_CHATGPT_CODEX_BASE_URL = "https://chatgpt.com/backend-api/codex"
@@ -183,6 +184,33 @@ class LLMProviderManager(metaclass=Singleton):
Path(settings.TEMP_PATH) / "llm_provider_models_dev_cache.json"
)
def _cleanup_auth_sessions_locked(self, now: Optional[float] = None) -> None:
"""
清理过期或已完成一段时间的临时授权会话。
调用方必须已经持有 `_lock`,这样 `_pending_sessions` 与
`_oauth_state_index` 能保持一致,避免 state 残留。
"""
now = time.time() if now is None else now
expired_session_ids = []
for session_id, session in self._pending_sessions.items():
expires_at = session.expires_at or session.created_at + 600
if session.status == "pending":
if expires_at <= now:
expired_session_ids.append(session_id)
elif expires_at + self._AUTH_SESSION_DONE_RETENTION <= now:
expired_session_ids.append(session_id)
if not expired_session_ids:
return
expired_session_ids_set = set(expired_session_ids)
for session_id in expired_session_ids:
self._pending_sessions.pop(session_id, None)
for state, session_id in list(self._oauth_state_index.items()):
if session_id in expired_session_ids_set:
self._oauth_state_index.pop(state, None)
@staticmethod
def _builtin_provider_specs() -> tuple[ProviderSpec, ...]:
"""
@@ -2001,6 +2029,7 @@ class LLMProviderManager(metaclass=Singleton):
}
)
with self._lock:
self._cleanup_auth_sessions_locked()
self._pending_sessions[session.session_id] = session
self._oauth_state_index[state] = session.session_id
return {
@@ -2035,6 +2064,7 @@ class LLMProviderManager(metaclass=Singleton):
}
)
with self._lock:
self._cleanup_auth_sessions_locked()
self._pending_sessions[session.session_id] = session
return {
"session_id": session.session_id,
@@ -2073,6 +2103,7 @@ class LLMProviderManager(metaclass=Singleton):
}
)
with self._lock:
self._cleanup_auth_sessions_locked()
self._pending_sessions[session.session_id] = session
return {
"session_id": session.session_id,
@@ -2089,6 +2120,7 @@ class LLMProviderManager(metaclass=Singleton):
def get_session_status(self, session_id: str) -> dict[str, Any]:
"""读取临时授权会话状态。"""
with self._lock:
self._cleanup_auth_sessions_locked()
session = self._pending_sessions.get(session_id)
if not session:
raise LLMProviderAuthError("授权会话不存在或已过期")
@@ -2135,6 +2167,7 @@ class LLMProviderManager(metaclass=Singleton):
if error:
message = error_description or error
with self._lock:
self._cleanup_auth_sessions_locked()
session_id = self._oauth_state_index.pop(state or "", None)
if session_id and session_id in self._pending_sessions:
self._mark_session_error(self._pending_sessions[session_id], message)
@@ -2144,6 +2177,7 @@ class LLMProviderManager(metaclass=Singleton):
return False, "缺少授权码或 state 参数"
with self._lock:
self._cleanup_auth_sessions_locked()
session_id = self._oauth_state_index.pop(state, None)
session = self._pending_sessions.get(session_id or "")
@@ -2186,6 +2220,7 @@ class LLMProviderManager(metaclass=Singleton):
前端可按 interval_seconds 轮询,直到状态变为 authorized / failed。
"""
with self._lock:
self._cleanup_auth_sessions_locked()
session = self._pending_sessions.get(session_id)
if not session:
raise LLMProviderAuthError("授权会话不存在或已过期")

View File

@@ -27,6 +27,8 @@ class MemoryManager:
初始化记忆管理器
"""
try:
if self.cleanup_task and not self.cleanup_task.done():
return
# 启动内存缓存清理任务Redis通过TTL自动过期
self.cleanup_task = asyncio.create_task(
self._cleanup_expired_memories()
@@ -46,6 +48,7 @@ class MemoryManager:
await self.cleanup_task
except asyncio.CancelledError:
pass
self.cleanup_task = None
logger.info("对话记忆管理器已关闭")

View File

@@ -15,6 +15,7 @@ from app.db.models import User
from app.db.models.message import Message
from app.db.user_oper import get_current_active_superuser
from app.helper.service import ServiceConfigHelper
from app.helper.webpush import is_webpush_subscription_gone
from app.log import logger
from app.modules.wechat.WXBizMsgCrypt3 import WXBizMsgCrypt
from app.schemas.types import MessageChannel
@@ -218,8 +219,7 @@ async def subscribe(
客户端webpush通知订阅
"""
subinfo = subscription.model_dump()
if subinfo not in global_vars.get_subscriptions():
global_vars.push_subscription(subinfo)
global_vars.push_subscription(subinfo)
logger.debug(f"通知订阅成功: {subinfo}")
return schemas.Response(success=True)
@@ -244,5 +244,7 @@ def send_notification(
)
except WebPushException as err:
logger.error(f"WebPush发送失败: {str(err)}")
if is_webpush_subscription_gone(err) and global_vars.remove_subscription(sub):
logger.info(f"已移除失效WebPush订阅: {sub.get('endpoint')}")
continue
return schemas.Response(success=True)

View File

@@ -84,13 +84,12 @@ class ScrapingOption:
class ScrapingConfig:
"""媒体刮削配置"""
_policies: dict[tuple[str], ScrapingOption] = {}
def __init__(self, config_dict: dict[str, str] = None):
"""
初始化配置对象
:param config_dict: 用户配置字典(扁平化格式),为 None 时使用默认配置
"""
self._policies: dict[tuple[str, str], ScrapingOption] = {}
# 合并用户配置和默认配置
if config_dict is None:
config_dict = {}

View File

@@ -47,6 +47,36 @@ class MessageChain(ChainBase):
# 会话超时时间(分钟)
_session_timeout_minutes: int = 24 * 60
@staticmethod
def _schedule_agent_session_clear(session_id: str, userid: Union[str, int]) -> None:
"""
异步调度 Agent 会话清理,避免同步消息链阻塞在模型资源释放上。
"""
if not session_id:
return
clear_task = None
try:
clear_task = agent_manager.clear_session(session_id=session_id, user_id=str(userid))
asyncio.run_coroutine_threadsafe(
clear_task,
global_vars.loop,
)
except Exception as e:
if clear_task:
clear_task.close()
logger.warning(f"调度清理智能体会话失败: {e}")
def _cleanup_expired_user_sessions(self, current_time: datetime) -> None:
"""
清理超过复用窗口的用户会话映射,并同步释放旧 Agent 实例。
"""
timeout = timedelta(minutes=self._session_timeout_minutes)
for userid, (session_id, last_time) in list(self._user_sessions.items()):
if current_time - last_time <= timeout:
continue
self._user_sessions.pop(userid, None)
self._schedule_agent_session_clear(session_id, userid)
@dataclass
class _ProcessingStatus:
channel: MessageChannel
@@ -919,6 +949,7 @@ class MessageChain(ChainBase):
如果用户上次会话在15分钟内则复用相同的会话ID否则创建新的会话ID
"""
current_time = datetime.now()
self._cleanup_expired_user_sessions(current_time)
# 检查用户是否有已存在的会话
if userid in self._user_sessions:
@@ -946,6 +977,9 @@ class MessageChain(ChainBase):
"""
将用户会话绑定到指定的 session_id并刷新最后活动时间。
"""
old_session = self._user_sessions.get(userid)
if old_session and old_session[0] != session_id:
self._schedule_agent_session_clear(old_session[0], userid)
self._user_sessions[userid] = (session_id, datetime.now())
def _record_user_message(
@@ -1005,14 +1039,18 @@ class MessageChain(ChainBase):
# 如果有会话ID同时清除智能体的会话记忆
if session_id:
clear_task = None
try:
clear_task = agent_manager.clear_session(
session_id=session_id, user_id=str(userid)
)
asyncio.run_coroutine_threadsafe(
agent_manager.clear_session(
session_id=session_id, user_id=str(userid)
),
clear_task,
global_vars.loop,
)
except Exception as e:
if clear_task:
clear_task.close()
logger.warning(f"清除智能体会话记忆失败: {e}")
self.post_message(

View File

@@ -1130,6 +1130,8 @@ class GlobalVar(object):
STOP_EVENT: threading.Event = threading.Event()
# webpush订阅
SUBSCRIPTIONS: List[dict] = []
# webpush订阅读写锁
SUBSCRIPTIONS_LOCK: threading.Lock = threading.Lock()
# 需应急停止的工作流
EMERGENCY_STOP_WORKFLOWS: List[int] = []
# 需应急停止文件整理
@@ -1169,13 +1171,37 @@ class GlobalVar(object):
"""
获取webpush订阅
"""
return self.SUBSCRIPTIONS
with self.SUBSCRIPTIONS_LOCK:
return list(self.SUBSCRIPTIONS)
def push_subscription(self, subscription: dict):
"""
添加webpush订阅
添加或更新webpush订阅
"""
self.SUBSCRIPTIONS.append(subscription)
endpoint = subscription.get("endpoint") if subscription else None
if not endpoint:
return
with self.SUBSCRIPTIONS_LOCK:
for index, current in enumerate(self.SUBSCRIPTIONS):
if current.get("endpoint") == endpoint:
self.SUBSCRIPTIONS[index] = subscription
return
self.SUBSCRIPTIONS.append(subscription)
def remove_subscription(self, subscription: dict) -> bool:
"""
根据 endpoint 移除webpush订阅返回是否实际删除。
"""
endpoint = subscription.get("endpoint") if subscription else None
if not endpoint:
return False
with self.SUBSCRIPTIONS_LOCK:
before_count = len(self.SUBSCRIPTIONS)
self.SUBSCRIPTIONS[:] = [
current for current in self.SUBSCRIPTIONS
if current.get("endpoint") != endpoint
]
return len(self.SUBSCRIPTIONS) != before_count
def stop_workflow(self, workflow_id: int):
"""

12
app/helper/webpush.py Normal file
View File

@@ -0,0 +1,12 @@
from typing import Any
from pywebpush import WebPushException
def is_webpush_subscription_gone(error: WebPushException) -> bool:
"""
判断 WebPush 订阅是否已经在浏览器或推送服务侧失效。
"""
response: Any = getattr(error, "response", None)
status_code = getattr(response, "status_code", None) or getattr(response, "status", None)
return status_code in {404, 410}

View File

@@ -2,6 +2,7 @@ import json
import subprocess
import threading
import time
from collections import OrderedDict
from pathlib import Path
from typing import Optional, List, Union
@@ -13,10 +14,24 @@ from app.schemas.types import StorageSchema
from app.utils.string import StringUtils
from app.utils.system import SystemUtils
_folder_locks: dict[str, threading.Lock] = {}
_MAX_FOLDER_LOCKS = 4096
_folder_locks: OrderedDict[str, threading.Lock] = OrderedDict()
_folder_locks_guard = threading.Lock()
def _evict_unused_folder_locks_locked() -> None:
"""
在持有全局锁表互斥锁时淘汰旧路径锁,避免大量不同目录导致锁表无限增长。
"""
while len(_folder_locks) >= _MAX_FOLDER_LOCKS:
for key, lock in list(_folder_locks.items()):
if not lock.locked():
_folder_locks.pop(key, None)
break
else:
break
class Rclone(StorageBase):
"""
rclone相关操作
@@ -144,9 +159,14 @@ class Rclone(StorageBase):
"""
normalized = Rclone.__normalize_remote_path(path)
with _folder_locks_guard:
if normalized not in _folder_locks:
_folder_locks[normalized] = threading.Lock()
return _folder_locks[normalized]
lock = _folder_locks.get(normalized)
if lock:
_folder_locks.move_to_end(normalized)
return lock
_evict_unused_folder_locks_locked()
lock = threading.Lock()
_folder_locks[normalized] = lock
return lock
def __wait_for_item(
self, path: Path, retries: int = 3, delay: float = 0.2

View File

@@ -4,6 +4,7 @@ from typing import Union, Tuple
from pywebpush import webpush, WebPushException
from app.core.config import global_vars, settings
from app.helper.webpush import is_webpush_subscription_gone
from app.log import logger
from app.modules import _ModuleBase, _MessageBase
from app.schemas import Notification
@@ -97,6 +98,8 @@ class WebPushModule(_ModuleBase, _MessageBase):
)
except WebPushException as err:
logger.error(f"WebPush发送失败: {str(err)}")
if is_webpush_subscription_gone(err) and global_vars.remove_subscription(sub):
logger.info(f"已移除失效WebPush订阅: {sub.get('endpoint')}")
except Exception as msg_e:
logger.error(f"发送消息失败:{msg_e}")

View File

@@ -1,6 +1,6 @@
import asyncio
import unittest
from datetime import datetime
from datetime import datetime, timedelta
from types import SimpleNamespace
from unittest.mock import patch
@@ -8,11 +8,20 @@ from langchain.agents.middleware.types import ModelRequest, ModelResponse
from langchain_core.messages import AIMessage
from app.agent.middleware.usage import UsageMiddleware
from app.agent import AgentManager
from app.chain.message import MessageChain
from app.schemas.types import MessageChannel
class TestAgentSessionStatus(unittest.TestCase):
def setUp(self):
"""清理跨用例共享的用户会话状态。"""
MessageChain._user_sessions.clear()
def tearDown(self):
"""清理测试产生的用户会话状态。"""
MessageChain._user_sessions.clear()
def test_usage_middleware_records_usage_metadata(self):
snapshots = []
middleware = UsageMiddleware(on_usage=snapshots.append)
@@ -104,3 +113,34 @@ class TestAgentSessionStatus(unittest.TestCase):
notification = post_message.call_args.args[0]
self.assertEqual(notification.title, "您当前没有活跃的智能体会话")
def test_get_or_create_session_cleans_expired_session(self):
"""用户会话超过复用窗口时应调度清理旧 Agent 会话。"""
chain = MessageChain()
chain._user_sessions.clear()
chain._user_sessions["10001"] = (
"old-session",
datetime.now() - timedelta(minutes=chain._session_timeout_minutes + 1),
)
with patch.object(chain, "_schedule_agent_session_clear") as clear_session:
session_id = chain._get_or_create_session_id("10001")
self.assertNotEqual(session_id, "old-session")
self.assertEqual(chain._user_sessions["10001"][0], session_id)
clear_session.assert_called_once_with("old-session", "10001")
def test_agent_manager_collects_idle_sessions(self):
"""Agent 管理器应只回收超过空闲窗口且未忙碌的会话。"""
manager = AgentManager()
manager._idle_session_ttl = timedelta(seconds=1)
manager._session_last_used["idle-session"] = (
"10001",
datetime.now() - timedelta(seconds=2),
)
manager._session_last_used["fresh-session"] = ("10002", datetime.now())
self.assertEqual(
[("idle-session", "10001")],
manager._expired_idle_sessions(),
)

View File

@@ -39,11 +39,23 @@ _stub_module(
TEMP_PATH="/tmp",
PROXY_HOST=None,
LLM_MAX_CONTEXT_TOKENS=64,
RCLONE_SNAPSHOT_CHECK_FOLDER_MODTIME=True,
RMT_MEDIAEXT=[".mkv", ".mp4"],
RMT_SUBEXT=[".srt"],
RMT_AUDIOEXT=[".flac"],
),
)
_stub_module("app.db.systemconfig_oper", SystemConfigOper=_DummySystemConfigOper)
_stub_module("app.log", logger=_DummyLogger())
_stub_module("app.schemas.types", SystemConfigKey=SimpleNamespace(AIAgentConfig="agent"))
_stub_module(
"app.schemas.types",
SystemConfigKey=SimpleNamespace(
AIAgentConfig="agent",
CustomReleaseGroups="custom_release_groups",
Customization="customization",
CustomIdentifiers="custom_identifiers",
),
)
provider_path = Path(__file__).resolve().parents[1] / "app" / "agent" / "llm" / "provider.py"
spec = importlib.util.spec_from_file_location("test_llm_provider_module", provider_path)
@@ -54,6 +66,7 @@ spec.loader.exec_module(provider_module)
LLMProviderError = provider_module.LLMProviderError
LLMProviderManager = provider_module.LLMProviderManager
PendingAuthSession = provider_module.PendingAuthSession
class LlmProviderRegistryTest(unittest.TestCase):
@@ -612,6 +625,24 @@ class LlmProviderRegistryTest(unittest.TestCase):
self.assertEqual(models, [])
def test_expired_auth_session_cleanup_removes_state_index(self):
"""过期授权会话应同时移除 session 与 OAuth state 索引。"""
manager = LLMProviderManager()
manager._pending_sessions["session-old"] = PendingAuthSession(
session_id="session-old",
provider_id="chatgpt",
method_id="browser_oauth",
flow_type="oauth",
expires_at=100,
)
manager._oauth_state_index["state-old"] = "session-old"
with manager._lock:
manager._cleanup_auth_sessions_locked(now=101)
self.assertNotIn("session-old", manager._pending_sessions)
self.assertNotIn("state-old", manager._oauth_state_index)
if __name__ == "__main__":
unittest.main()

View File

@@ -8,7 +8,7 @@ sys.modules['app.db.systemconfig_oper'] = MagicMock()
sys.modules['app.db.systemconfig_oper'].SystemConfigOper.return_value.get.return_value = None
from app import schemas
from app.chain.media import MediaChain, ScrapingOption
from app.chain.media import MediaChain, ScrapingConfig, ScrapingOption
from app.core.context import MediaInfo
from app.core.event import Event
from app.core.metainfo import MetaInfo
@@ -42,6 +42,20 @@ class TestMediaScrapingPaths(unittest.TestCase):
self.assertEqual(target_item, parent_item)
self.assertEqual(target_path, Path("/movies/avatar.nfo"))
def test_scraping_config_does_not_share_policy_state_between_instances(self):
"""刮削配置实例之间不应共享已删除或覆盖过的策略。"""
first_config = ScrapingConfig({"movie_nfo": ScrapingPolicy.SKIP})
second_config = ScrapingConfig({})
self.assertEqual(
ScrapingPolicy.SKIP,
first_config.option(ScrapingTarget.MOVIE, ScrapingMetadata.NFO).policy,
)
self.assertEqual(
ScrapingPolicy.MISSINGONLY,
second_config.option(ScrapingTarget.MOVIE, ScrapingMetadata.NFO).policy,
)
def test_movie_dir_nfo_path(self):
fileitem = schemas.FileItem(path="/movies/Avatar (2009)", name="Avatar (2009)", type="dir", storage="local")

View File

@@ -200,6 +200,19 @@ class RcloneStorageTest(unittest.TestCase):
self.assertEqual("/Show/", folder.path)
run_mock.assert_called_once()
def test_folder_lock_table_evicts_old_unlocked_paths(self):
"""路径锁表超过上限时应优先淘汰未占用的旧锁。"""
with patch.object(rclone_module, "_MAX_FOLDER_LOCKS", 2):
first_lock = Rclone._Rclone__get_path_lock(Path("/A"))
second_lock = Rclone._Rclone__get_path_lock(Path("/B"))
third_lock = Rclone._Rclone__get_path_lock(Path("/C"))
self.assertNotIn("/A", rclone_module._folder_locks)
self.assertIn("/B", rclone_module._folder_locks)
self.assertIn("/C", rclone_module._folder_locks)
self.assertIsNot(first_lock, third_lock)
self.assertIs(second_lock, rclone_module._folder_locks["/B"])
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,61 @@
import unittest
from types import SimpleNamespace
from app.core.config import global_vars
from app.helper.webpush import is_webpush_subscription_gone
class WebPushSubscriptionTest(unittest.TestCase):
def setUp(self):
"""清理跨用例共享的 WebPush 订阅。"""
with global_vars.SUBSCRIPTIONS_LOCK:
global_vars.SUBSCRIPTIONS.clear()
def tearDown(self):
"""清理测试产生的 WebPush 订阅。"""
with global_vars.SUBSCRIPTIONS_LOCK:
global_vars.SUBSCRIPTIONS.clear()
def test_push_subscription_upserts_by_endpoint(self):
"""相同 endpoint 的 WebPush 订阅应更新而不是重复追加。"""
global_vars.push_subscription(
{"endpoint": "https://push.example/a", "keys": {"p256dh": "old"}}
)
global_vars.push_subscription(
{"endpoint": "https://push.example/a", "keys": {"p256dh": "new"}}
)
subscriptions = global_vars.get_subscriptions()
self.assertEqual(1, len(subscriptions))
self.assertEqual("new", subscriptions[0]["keys"]["p256dh"])
def test_remove_subscription_deletes_by_endpoint(self):
"""失效订阅应能按 endpoint 从全局订阅表删除。"""
subscription = {"endpoint": "https://push.example/a", "keys": {}}
global_vars.push_subscription(subscription)
self.assertTrue(global_vars.remove_subscription(subscription))
self.assertEqual([], global_vars.get_subscriptions())
def test_is_webpush_subscription_gone_matches_404_and_410(self):
"""推送服务返回 404/410 时应识别为订阅已失效。"""
self.assertTrue(
is_webpush_subscription_gone(
SimpleNamespace(response=SimpleNamespace(status_code=410))
)
)
self.assertTrue(
is_webpush_subscription_gone(
SimpleNamespace(response=SimpleNamespace(status=404))
)
)
self.assertFalse(
is_webpush_subscription_gone(
SimpleNamespace(response=SimpleNamespace(status_code=500))
)
)
if __name__ == "__main__":
unittest.main()