mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-05 07:26:48 +00:00
fix: bound long-lived cache state
This commit is contained in:
@@ -4,7 +4,7 @@ import re
|
||||
import traceback
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
@@ -966,6 +966,11 @@ class AgentManager:
|
||||
self._session_queues: Dict[str, asyncio.Queue] = {}
|
||||
# 每个会话的worker任务
|
||||
self._session_workers: Dict[str, asyncio.Task] = {}
|
||||
# 每个会话最后活动时间,用于回收空闲 Agent 实例
|
||||
self._session_last_used: Dict[str, tuple[str, datetime]] = {}
|
||||
self._idle_cleanup_task: Optional[asyncio.Task] = None
|
||||
self._idle_session_ttl = timedelta(hours=24)
|
||||
self._idle_cleanup_interval = 60 * 60
|
||||
|
||||
def get_session_status(self, session_id: str) -> dict[str, Any]:
|
||||
"""获取会话当前模型与 token 使用状态。"""
|
||||
@@ -998,33 +1003,85 @@ class AgentManager:
|
||||
)
|
||||
return status
|
||||
|
||||
@staticmethod
|
||||
async def initialize():
|
||||
async def initialize(self):
|
||||
"""
|
||||
初始化管理器
|
||||
"""
|
||||
memory_manager.initialize()
|
||||
if self._idle_cleanup_task and not self._idle_cleanup_task.done():
|
||||
return
|
||||
self._idle_cleanup_task = asyncio.create_task(self._cleanup_idle_sessions())
|
||||
|
||||
async def close(self):
|
||||
"""
|
||||
关闭管理器
|
||||
"""
|
||||
if self._idle_cleanup_task:
|
||||
self._idle_cleanup_task.cancel()
|
||||
try:
|
||||
await self._idle_cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._idle_cleanup_task = None
|
||||
await memory_manager.close()
|
||||
# 取消所有会话worker
|
||||
for task in self._session_workers.values():
|
||||
for task in list(self._session_workers.values()):
|
||||
task.cancel()
|
||||
# 等待所有worker结束
|
||||
for session_id, task in self._session_workers.items():
|
||||
for session_id, task in list(self._session_workers.items()):
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._session_workers.clear()
|
||||
self._session_queues.clear()
|
||||
for agent in self.active_agents.values():
|
||||
self._session_last_used.clear()
|
||||
for agent in list(self.active_agents.values()):
|
||||
await agent.cleanup()
|
||||
self.active_agents.clear()
|
||||
|
||||
def _record_session_activity(self, session_id: str, user_id: str) -> None:
|
||||
"""
|
||||
记录会话最近活动时间,供空闲会话清理任务判断是否可释放资源。
|
||||
"""
|
||||
self._session_last_used[session_id] = (user_id, datetime.now())
|
||||
|
||||
def _is_session_busy(self, session_id: str) -> bool:
|
||||
"""
|
||||
判断会话是否仍有正在执行的 worker 或待处理消息,避免误清理活跃会话。
|
||||
"""
|
||||
worker = self._session_workers.get(session_id)
|
||||
if worker and not worker.done():
|
||||
return True
|
||||
queue = self._session_queues.get(session_id)
|
||||
return bool(queue and not queue.empty())
|
||||
|
||||
def _expired_idle_sessions(self) -> list[tuple[str, str]]:
|
||||
"""
|
||||
收集已经超过空闲时间且当前不忙的会话。
|
||||
"""
|
||||
expire_before = datetime.now() - self._idle_session_ttl
|
||||
expired = []
|
||||
for session_id, (user_id, last_used) in list(self._session_last_used.items()):
|
||||
if last_used < expire_before and not self._is_session_busy(session_id):
|
||||
expired.append((session_id, user_id))
|
||||
return expired
|
||||
|
||||
async def _cleanup_idle_sessions(self) -> None:
|
||||
"""
|
||||
周期性清理长时间没有新消息的 Agent 会话,避免长期运行后实例持续累积。
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(self._idle_cleanup_interval)
|
||||
for session_id, user_id in self._expired_idle_sessions():
|
||||
await self.clear_session(session_id=session_id, user_id=user_id)
|
||||
logger.info(f"已清理空闲Agent会话: session_id={session_id}")
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"清理空闲Agent会话失败: {e}")
|
||||
|
||||
async def process_message(
|
||||
self,
|
||||
session_id: str,
|
||||
@@ -1056,6 +1113,7 @@ class AgentManager:
|
||||
original_chat_id=original_chat_id,
|
||||
reply_mode=reply_mode,
|
||||
)
|
||||
self._record_session_activity(session_id, user_id)
|
||||
|
||||
# 获取或创建会话队列
|
||||
if session_id not in self._session_queues:
|
||||
@@ -1221,6 +1279,7 @@ class AgentManager:
|
||||
"""
|
||||
清空会话
|
||||
"""
|
||||
self._session_last_used.pop(session_id, None)
|
||||
# 取消该会话的worker
|
||||
if session_id in self._session_workers:
|
||||
self._session_workers[session_id].cancel()
|
||||
@@ -1228,7 +1287,7 @@ class AgentManager:
|
||||
await self._session_workers[session_id]
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
await self._session_workers.pop(session_id, None)
|
||||
self._session_workers.pop(session_id, None)
|
||||
|
||||
# 清理队列
|
||||
self._session_queues.pop(session_id, None)
|
||||
|
||||
@@ -105,6 +105,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
_MODELS_DEV_URL = "https://models.dev/api.json"
|
||||
_MODELS_DEV_BUNDLED_PATH = Path(__file__).with_name("models.json")
|
||||
_MODELS_DEV_CACHE_TTL = 7 * 24 * 60 * 60
|
||||
_AUTH_SESSION_DONE_RETENTION = 300
|
||||
_CHATGPT_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
_CHATGPT_ISSUER = "https://auth.openai.com"
|
||||
_CHATGPT_CODEX_BASE_URL = "https://chatgpt.com/backend-api/codex"
|
||||
@@ -183,6 +184,33 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
Path(settings.TEMP_PATH) / "llm_provider_models_dev_cache.json"
|
||||
)
|
||||
|
||||
def _cleanup_auth_sessions_locked(self, now: Optional[float] = None) -> None:
|
||||
"""
|
||||
清理过期或已完成一段时间的临时授权会话。
|
||||
|
||||
调用方必须已经持有 `_lock`,这样 `_pending_sessions` 与
|
||||
`_oauth_state_index` 能保持一致,避免 state 残留。
|
||||
"""
|
||||
now = time.time() if now is None else now
|
||||
expired_session_ids = []
|
||||
for session_id, session in self._pending_sessions.items():
|
||||
expires_at = session.expires_at or session.created_at + 600
|
||||
if session.status == "pending":
|
||||
if expires_at <= now:
|
||||
expired_session_ids.append(session_id)
|
||||
elif expires_at + self._AUTH_SESSION_DONE_RETENTION <= now:
|
||||
expired_session_ids.append(session_id)
|
||||
|
||||
if not expired_session_ids:
|
||||
return
|
||||
|
||||
expired_session_ids_set = set(expired_session_ids)
|
||||
for session_id in expired_session_ids:
|
||||
self._pending_sessions.pop(session_id, None)
|
||||
for state, session_id in list(self._oauth_state_index.items()):
|
||||
if session_id in expired_session_ids_set:
|
||||
self._oauth_state_index.pop(state, None)
|
||||
|
||||
@staticmethod
|
||||
def _builtin_provider_specs() -> tuple[ProviderSpec, ...]:
|
||||
"""
|
||||
@@ -2001,6 +2029,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
}
|
||||
)
|
||||
with self._lock:
|
||||
self._cleanup_auth_sessions_locked()
|
||||
self._pending_sessions[session.session_id] = session
|
||||
self._oauth_state_index[state] = session.session_id
|
||||
return {
|
||||
@@ -2035,6 +2064,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
}
|
||||
)
|
||||
with self._lock:
|
||||
self._cleanup_auth_sessions_locked()
|
||||
self._pending_sessions[session.session_id] = session
|
||||
return {
|
||||
"session_id": session.session_id,
|
||||
@@ -2073,6 +2103,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
}
|
||||
)
|
||||
with self._lock:
|
||||
self._cleanup_auth_sessions_locked()
|
||||
self._pending_sessions[session.session_id] = session
|
||||
return {
|
||||
"session_id": session.session_id,
|
||||
@@ -2089,6 +2120,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
def get_session_status(self, session_id: str) -> dict[str, Any]:
|
||||
"""读取临时授权会话状态。"""
|
||||
with self._lock:
|
||||
self._cleanup_auth_sessions_locked()
|
||||
session = self._pending_sessions.get(session_id)
|
||||
if not session:
|
||||
raise LLMProviderAuthError("授权会话不存在或已过期")
|
||||
@@ -2135,6 +2167,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
if error:
|
||||
message = error_description or error
|
||||
with self._lock:
|
||||
self._cleanup_auth_sessions_locked()
|
||||
session_id = self._oauth_state_index.pop(state or "", None)
|
||||
if session_id and session_id in self._pending_sessions:
|
||||
self._mark_session_error(self._pending_sessions[session_id], message)
|
||||
@@ -2144,6 +2177,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
return False, "缺少授权码或 state 参数"
|
||||
|
||||
with self._lock:
|
||||
self._cleanup_auth_sessions_locked()
|
||||
session_id = self._oauth_state_index.pop(state, None)
|
||||
session = self._pending_sessions.get(session_id or "")
|
||||
|
||||
@@ -2186,6 +2220,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
前端可按 interval_seconds 轮询,直到状态变为 authorized / failed。
|
||||
"""
|
||||
with self._lock:
|
||||
self._cleanup_auth_sessions_locked()
|
||||
session = self._pending_sessions.get(session_id)
|
||||
if not session:
|
||||
raise LLMProviderAuthError("授权会话不存在或已过期")
|
||||
|
||||
@@ -27,6 +27,8 @@ class MemoryManager:
|
||||
初始化记忆管理器
|
||||
"""
|
||||
try:
|
||||
if self.cleanup_task and not self.cleanup_task.done():
|
||||
return
|
||||
# 启动内存缓存清理任务(Redis通过TTL自动过期)
|
||||
self.cleanup_task = asyncio.create_task(
|
||||
self._cleanup_expired_memories()
|
||||
@@ -46,6 +48,7 @@ class MemoryManager:
|
||||
await self.cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self.cleanup_task = None
|
||||
|
||||
logger.info("对话记忆管理器已关闭")
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from app.db.models import User
|
||||
from app.db.models.message import Message
|
||||
from app.db.user_oper import get_current_active_superuser
|
||||
from app.helper.service import ServiceConfigHelper
|
||||
from app.helper.webpush import is_webpush_subscription_gone
|
||||
from app.log import logger
|
||||
from app.modules.wechat.WXBizMsgCrypt3 import WXBizMsgCrypt
|
||||
from app.schemas.types import MessageChannel
|
||||
@@ -218,8 +219,7 @@ async def subscribe(
|
||||
客户端webpush通知订阅
|
||||
"""
|
||||
subinfo = subscription.model_dump()
|
||||
if subinfo not in global_vars.get_subscriptions():
|
||||
global_vars.push_subscription(subinfo)
|
||||
global_vars.push_subscription(subinfo)
|
||||
logger.debug(f"通知订阅成功: {subinfo}")
|
||||
return schemas.Response(success=True)
|
||||
|
||||
@@ -244,5 +244,7 @@ def send_notification(
|
||||
)
|
||||
except WebPushException as err:
|
||||
logger.error(f"WebPush发送失败: {str(err)}")
|
||||
if is_webpush_subscription_gone(err) and global_vars.remove_subscription(sub):
|
||||
logger.info(f"已移除失效WebPush订阅: {sub.get('endpoint')}")
|
||||
continue
|
||||
return schemas.Response(success=True)
|
||||
|
||||
@@ -84,13 +84,12 @@ class ScrapingOption:
|
||||
class ScrapingConfig:
|
||||
"""媒体刮削配置"""
|
||||
|
||||
_policies: dict[tuple[str], ScrapingOption] = {}
|
||||
|
||||
def __init__(self, config_dict: dict[str, str] = None):
|
||||
"""
|
||||
初始化配置对象
|
||||
:param config_dict: 用户配置字典(扁平化格式),为 None 时使用默认配置
|
||||
"""
|
||||
self._policies: dict[tuple[str, str], ScrapingOption] = {}
|
||||
# 合并用户配置和默认配置
|
||||
if config_dict is None:
|
||||
config_dict = {}
|
||||
|
||||
@@ -47,6 +47,36 @@ class MessageChain(ChainBase):
|
||||
# 会话超时时间(分钟)
|
||||
_session_timeout_minutes: int = 24 * 60
|
||||
|
||||
@staticmethod
|
||||
def _schedule_agent_session_clear(session_id: str, userid: Union[str, int]) -> None:
|
||||
"""
|
||||
异步调度 Agent 会话清理,避免同步消息链阻塞在模型资源释放上。
|
||||
"""
|
||||
if not session_id:
|
||||
return
|
||||
clear_task = None
|
||||
try:
|
||||
clear_task = agent_manager.clear_session(session_id=session_id, user_id=str(userid))
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
clear_task,
|
||||
global_vars.loop,
|
||||
)
|
||||
except Exception as e:
|
||||
if clear_task:
|
||||
clear_task.close()
|
||||
logger.warning(f"调度清理智能体会话失败: {e}")
|
||||
|
||||
def _cleanup_expired_user_sessions(self, current_time: datetime) -> None:
|
||||
"""
|
||||
清理超过复用窗口的用户会话映射,并同步释放旧 Agent 实例。
|
||||
"""
|
||||
timeout = timedelta(minutes=self._session_timeout_minutes)
|
||||
for userid, (session_id, last_time) in list(self._user_sessions.items()):
|
||||
if current_time - last_time <= timeout:
|
||||
continue
|
||||
self._user_sessions.pop(userid, None)
|
||||
self._schedule_agent_session_clear(session_id, userid)
|
||||
|
||||
@dataclass
|
||||
class _ProcessingStatus:
|
||||
channel: MessageChannel
|
||||
@@ -919,6 +949,7 @@ class MessageChain(ChainBase):
|
||||
如果用户上次会话在15分钟内,则复用相同的会话ID;否则创建新的会话ID
|
||||
"""
|
||||
current_time = datetime.now()
|
||||
self._cleanup_expired_user_sessions(current_time)
|
||||
|
||||
# 检查用户是否有已存在的会话
|
||||
if userid in self._user_sessions:
|
||||
@@ -946,6 +977,9 @@ class MessageChain(ChainBase):
|
||||
"""
|
||||
将用户会话绑定到指定的 session_id,并刷新最后活动时间。
|
||||
"""
|
||||
old_session = self._user_sessions.get(userid)
|
||||
if old_session and old_session[0] != session_id:
|
||||
self._schedule_agent_session_clear(old_session[0], userid)
|
||||
self._user_sessions[userid] = (session_id, datetime.now())
|
||||
|
||||
def _record_user_message(
|
||||
@@ -1005,14 +1039,18 @@ class MessageChain(ChainBase):
|
||||
|
||||
# 如果有会话ID,同时清除智能体的会话记忆
|
||||
if session_id:
|
||||
clear_task = None
|
||||
try:
|
||||
clear_task = agent_manager.clear_session(
|
||||
session_id=session_id, user_id=str(userid)
|
||||
)
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
agent_manager.clear_session(
|
||||
session_id=session_id, user_id=str(userid)
|
||||
),
|
||||
clear_task,
|
||||
global_vars.loop,
|
||||
)
|
||||
except Exception as e:
|
||||
if clear_task:
|
||||
clear_task.close()
|
||||
logger.warning(f"清除智能体会话记忆失败: {e}")
|
||||
|
||||
self.post_message(
|
||||
|
||||
@@ -1130,6 +1130,8 @@ class GlobalVar(object):
|
||||
STOP_EVENT: threading.Event = threading.Event()
|
||||
# webpush订阅
|
||||
SUBSCRIPTIONS: List[dict] = []
|
||||
# webpush订阅读写锁
|
||||
SUBSCRIPTIONS_LOCK: threading.Lock = threading.Lock()
|
||||
# 需应急停止的工作流
|
||||
EMERGENCY_STOP_WORKFLOWS: List[int] = []
|
||||
# 需应急停止文件整理
|
||||
@@ -1169,13 +1171,37 @@ class GlobalVar(object):
|
||||
"""
|
||||
获取webpush订阅
|
||||
"""
|
||||
return self.SUBSCRIPTIONS
|
||||
with self.SUBSCRIPTIONS_LOCK:
|
||||
return list(self.SUBSCRIPTIONS)
|
||||
|
||||
def push_subscription(self, subscription: dict):
|
||||
"""
|
||||
添加webpush订阅
|
||||
添加或更新webpush订阅。
|
||||
"""
|
||||
self.SUBSCRIPTIONS.append(subscription)
|
||||
endpoint = subscription.get("endpoint") if subscription else None
|
||||
if not endpoint:
|
||||
return
|
||||
with self.SUBSCRIPTIONS_LOCK:
|
||||
for index, current in enumerate(self.SUBSCRIPTIONS):
|
||||
if current.get("endpoint") == endpoint:
|
||||
self.SUBSCRIPTIONS[index] = subscription
|
||||
return
|
||||
self.SUBSCRIPTIONS.append(subscription)
|
||||
|
||||
def remove_subscription(self, subscription: dict) -> bool:
|
||||
"""
|
||||
根据 endpoint 移除webpush订阅,返回是否实际删除。
|
||||
"""
|
||||
endpoint = subscription.get("endpoint") if subscription else None
|
||||
if not endpoint:
|
||||
return False
|
||||
with self.SUBSCRIPTIONS_LOCK:
|
||||
before_count = len(self.SUBSCRIPTIONS)
|
||||
self.SUBSCRIPTIONS[:] = [
|
||||
current for current in self.SUBSCRIPTIONS
|
||||
if current.get("endpoint") != endpoint
|
||||
]
|
||||
return len(self.SUBSCRIPTIONS) != before_count
|
||||
|
||||
def stop_workflow(self, workflow_id: int):
|
||||
"""
|
||||
|
||||
12
app/helper/webpush.py
Normal file
12
app/helper/webpush.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from typing import Any
|
||||
|
||||
from pywebpush import WebPushException
|
||||
|
||||
|
||||
def is_webpush_subscription_gone(error: WebPushException) -> bool:
|
||||
"""
|
||||
判断 WebPush 订阅是否已经在浏览器或推送服务侧失效。
|
||||
"""
|
||||
response: Any = getattr(error, "response", None)
|
||||
status_code = getattr(response, "status_code", None) or getattr(response, "status", None)
|
||||
return status_code in {404, 410}
|
||||
@@ -2,6 +2,7 @@ import json
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Union
|
||||
|
||||
@@ -13,10 +14,24 @@ from app.schemas.types import StorageSchema
|
||||
from app.utils.string import StringUtils
|
||||
from app.utils.system import SystemUtils
|
||||
|
||||
_folder_locks: dict[str, threading.Lock] = {}
|
||||
_MAX_FOLDER_LOCKS = 4096
|
||||
_folder_locks: OrderedDict[str, threading.Lock] = OrderedDict()
|
||||
_folder_locks_guard = threading.Lock()
|
||||
|
||||
|
||||
def _evict_unused_folder_locks_locked() -> None:
|
||||
"""
|
||||
在持有全局锁表互斥锁时淘汰旧路径锁,避免大量不同目录导致锁表无限增长。
|
||||
"""
|
||||
while len(_folder_locks) >= _MAX_FOLDER_LOCKS:
|
||||
for key, lock in list(_folder_locks.items()):
|
||||
if not lock.locked():
|
||||
_folder_locks.pop(key, None)
|
||||
break
|
||||
else:
|
||||
break
|
||||
|
||||
|
||||
class Rclone(StorageBase):
|
||||
"""
|
||||
rclone相关操作
|
||||
@@ -144,9 +159,14 @@ class Rclone(StorageBase):
|
||||
"""
|
||||
normalized = Rclone.__normalize_remote_path(path)
|
||||
with _folder_locks_guard:
|
||||
if normalized not in _folder_locks:
|
||||
_folder_locks[normalized] = threading.Lock()
|
||||
return _folder_locks[normalized]
|
||||
lock = _folder_locks.get(normalized)
|
||||
if lock:
|
||||
_folder_locks.move_to_end(normalized)
|
||||
return lock
|
||||
_evict_unused_folder_locks_locked()
|
||||
lock = threading.Lock()
|
||||
_folder_locks[normalized] = lock
|
||||
return lock
|
||||
|
||||
def __wait_for_item(
|
||||
self, path: Path, retries: int = 3, delay: float = 0.2
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Union, Tuple
|
||||
from pywebpush import webpush, WebPushException
|
||||
|
||||
from app.core.config import global_vars, settings
|
||||
from app.helper.webpush import is_webpush_subscription_gone
|
||||
from app.log import logger
|
||||
from app.modules import _ModuleBase, _MessageBase
|
||||
from app.schemas import Notification
|
||||
@@ -97,6 +98,8 @@ class WebPushModule(_ModuleBase, _MessageBase):
|
||||
)
|
||||
except WebPushException as err:
|
||||
logger.error(f"WebPush发送失败: {str(err)}")
|
||||
if is_webpush_subscription_gone(err) and global_vars.remove_subscription(sub):
|
||||
logger.info(f"已移除失效WebPush订阅: {sub.get('endpoint')}")
|
||||
|
||||
except Exception as msg_e:
|
||||
logger.error(f"发送消息失败:{msg_e}")
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
import unittest
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -8,11 +8,20 @@ from langchain.agents.middleware.types import ModelRequest, ModelResponse
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from app.agent.middleware.usage import UsageMiddleware
|
||||
from app.agent import AgentManager
|
||||
from app.chain.message import MessageChain
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
class TestAgentSessionStatus(unittest.TestCase):
|
||||
def setUp(self):
|
||||
"""清理跨用例共享的用户会话状态。"""
|
||||
MessageChain._user_sessions.clear()
|
||||
|
||||
def tearDown(self):
|
||||
"""清理测试产生的用户会话状态。"""
|
||||
MessageChain._user_sessions.clear()
|
||||
|
||||
def test_usage_middleware_records_usage_metadata(self):
|
||||
snapshots = []
|
||||
middleware = UsageMiddleware(on_usage=snapshots.append)
|
||||
@@ -104,3 +113,34 @@ class TestAgentSessionStatus(unittest.TestCase):
|
||||
|
||||
notification = post_message.call_args.args[0]
|
||||
self.assertEqual(notification.title, "您当前没有活跃的智能体会话")
|
||||
|
||||
def test_get_or_create_session_cleans_expired_session(self):
|
||||
"""用户会话超过复用窗口时应调度清理旧 Agent 会话。"""
|
||||
chain = MessageChain()
|
||||
chain._user_sessions.clear()
|
||||
chain._user_sessions["10001"] = (
|
||||
"old-session",
|
||||
datetime.now() - timedelta(minutes=chain._session_timeout_minutes + 1),
|
||||
)
|
||||
|
||||
with patch.object(chain, "_schedule_agent_session_clear") as clear_session:
|
||||
session_id = chain._get_or_create_session_id("10001")
|
||||
|
||||
self.assertNotEqual(session_id, "old-session")
|
||||
self.assertEqual(chain._user_sessions["10001"][0], session_id)
|
||||
clear_session.assert_called_once_with("old-session", "10001")
|
||||
|
||||
def test_agent_manager_collects_idle_sessions(self):
|
||||
"""Agent 管理器应只回收超过空闲窗口且未忙碌的会话。"""
|
||||
manager = AgentManager()
|
||||
manager._idle_session_ttl = timedelta(seconds=1)
|
||||
manager._session_last_used["idle-session"] = (
|
||||
"10001",
|
||||
datetime.now() - timedelta(seconds=2),
|
||||
)
|
||||
manager._session_last_used["fresh-session"] = ("10002", datetime.now())
|
||||
|
||||
self.assertEqual(
|
||||
[("idle-session", "10001")],
|
||||
manager._expired_idle_sessions(),
|
||||
)
|
||||
|
||||
@@ -39,11 +39,23 @@ _stub_module(
|
||||
TEMP_PATH="/tmp",
|
||||
PROXY_HOST=None,
|
||||
LLM_MAX_CONTEXT_TOKENS=64,
|
||||
RCLONE_SNAPSHOT_CHECK_FOLDER_MODTIME=True,
|
||||
RMT_MEDIAEXT=[".mkv", ".mp4"],
|
||||
RMT_SUBEXT=[".srt"],
|
||||
RMT_AUDIOEXT=[".flac"],
|
||||
),
|
||||
)
|
||||
_stub_module("app.db.systemconfig_oper", SystemConfigOper=_DummySystemConfigOper)
|
||||
_stub_module("app.log", logger=_DummyLogger())
|
||||
_stub_module("app.schemas.types", SystemConfigKey=SimpleNamespace(AIAgentConfig="agent"))
|
||||
_stub_module(
|
||||
"app.schemas.types",
|
||||
SystemConfigKey=SimpleNamespace(
|
||||
AIAgentConfig="agent",
|
||||
CustomReleaseGroups="custom_release_groups",
|
||||
Customization="customization",
|
||||
CustomIdentifiers="custom_identifiers",
|
||||
),
|
||||
)
|
||||
|
||||
provider_path = Path(__file__).resolve().parents[1] / "app" / "agent" / "llm" / "provider.py"
|
||||
spec = importlib.util.spec_from_file_location("test_llm_provider_module", provider_path)
|
||||
@@ -54,6 +66,7 @@ spec.loader.exec_module(provider_module)
|
||||
|
||||
LLMProviderError = provider_module.LLMProviderError
|
||||
LLMProviderManager = provider_module.LLMProviderManager
|
||||
PendingAuthSession = provider_module.PendingAuthSession
|
||||
|
||||
|
||||
class LlmProviderRegistryTest(unittest.TestCase):
|
||||
@@ -612,6 +625,24 @@ class LlmProviderRegistryTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(models, [])
|
||||
|
||||
def test_expired_auth_session_cleanup_removes_state_index(self):
|
||||
"""过期授权会话应同时移除 session 与 OAuth state 索引。"""
|
||||
manager = LLMProviderManager()
|
||||
manager._pending_sessions["session-old"] = PendingAuthSession(
|
||||
session_id="session-old",
|
||||
provider_id="chatgpt",
|
||||
method_id="browser_oauth",
|
||||
flow_type="oauth",
|
||||
expires_at=100,
|
||||
)
|
||||
manager._oauth_state_index["state-old"] = "session-old"
|
||||
|
||||
with manager._lock:
|
||||
manager._cleanup_auth_sessions_locked(now=101)
|
||||
|
||||
self.assertNotIn("session-old", manager._pending_sessions)
|
||||
self.assertNotIn("state-old", manager._oauth_state_index)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -8,7 +8,7 @@ sys.modules['app.db.systemconfig_oper'] = MagicMock()
|
||||
sys.modules['app.db.systemconfig_oper'].SystemConfigOper.return_value.get.return_value = None
|
||||
|
||||
from app import schemas
|
||||
from app.chain.media import MediaChain, ScrapingOption
|
||||
from app.chain.media import MediaChain, ScrapingConfig, ScrapingOption
|
||||
from app.core.context import MediaInfo
|
||||
from app.core.event import Event
|
||||
from app.core.metainfo import MetaInfo
|
||||
@@ -42,6 +42,20 @@ class TestMediaScrapingPaths(unittest.TestCase):
|
||||
self.assertEqual(target_item, parent_item)
|
||||
self.assertEqual(target_path, Path("/movies/avatar.nfo"))
|
||||
|
||||
def test_scraping_config_does_not_share_policy_state_between_instances(self):
|
||||
"""刮削配置实例之间不应共享已删除或覆盖过的策略。"""
|
||||
first_config = ScrapingConfig({"movie_nfo": ScrapingPolicy.SKIP})
|
||||
second_config = ScrapingConfig({})
|
||||
|
||||
self.assertEqual(
|
||||
ScrapingPolicy.SKIP,
|
||||
first_config.option(ScrapingTarget.MOVIE, ScrapingMetadata.NFO).policy,
|
||||
)
|
||||
self.assertEqual(
|
||||
ScrapingPolicy.MISSINGONLY,
|
||||
second_config.option(ScrapingTarget.MOVIE, ScrapingMetadata.NFO).policy,
|
||||
)
|
||||
|
||||
def test_movie_dir_nfo_path(self):
|
||||
fileitem = schemas.FileItem(path="/movies/Avatar (2009)", name="Avatar (2009)", type="dir", storage="local")
|
||||
|
||||
|
||||
@@ -200,6 +200,19 @@ class RcloneStorageTest(unittest.TestCase):
|
||||
self.assertEqual("/Show/", folder.path)
|
||||
run_mock.assert_called_once()
|
||||
|
||||
def test_folder_lock_table_evicts_old_unlocked_paths(self):
|
||||
"""路径锁表超过上限时应优先淘汰未占用的旧锁。"""
|
||||
with patch.object(rclone_module, "_MAX_FOLDER_LOCKS", 2):
|
||||
first_lock = Rclone._Rclone__get_path_lock(Path("/A"))
|
||||
second_lock = Rclone._Rclone__get_path_lock(Path("/B"))
|
||||
third_lock = Rclone._Rclone__get_path_lock(Path("/C"))
|
||||
|
||||
self.assertNotIn("/A", rclone_module._folder_locks)
|
||||
self.assertIn("/B", rclone_module._folder_locks)
|
||||
self.assertIn("/C", rclone_module._folder_locks)
|
||||
self.assertIsNot(first_lock, third_lock)
|
||||
self.assertIs(second_lock, rclone_module._folder_locks["/B"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
61
tests/test_webpush_subscription.py
Normal file
61
tests/test_webpush_subscription.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
|
||||
from app.core.config import global_vars
|
||||
from app.helper.webpush import is_webpush_subscription_gone
|
||||
|
||||
|
||||
class WebPushSubscriptionTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
"""清理跨用例共享的 WebPush 订阅。"""
|
||||
with global_vars.SUBSCRIPTIONS_LOCK:
|
||||
global_vars.SUBSCRIPTIONS.clear()
|
||||
|
||||
def tearDown(self):
|
||||
"""清理测试产生的 WebPush 订阅。"""
|
||||
with global_vars.SUBSCRIPTIONS_LOCK:
|
||||
global_vars.SUBSCRIPTIONS.clear()
|
||||
|
||||
def test_push_subscription_upserts_by_endpoint(self):
|
||||
"""相同 endpoint 的 WebPush 订阅应更新而不是重复追加。"""
|
||||
global_vars.push_subscription(
|
||||
{"endpoint": "https://push.example/a", "keys": {"p256dh": "old"}}
|
||||
)
|
||||
global_vars.push_subscription(
|
||||
{"endpoint": "https://push.example/a", "keys": {"p256dh": "new"}}
|
||||
)
|
||||
|
||||
subscriptions = global_vars.get_subscriptions()
|
||||
|
||||
self.assertEqual(1, len(subscriptions))
|
||||
self.assertEqual("new", subscriptions[0]["keys"]["p256dh"])
|
||||
|
||||
def test_remove_subscription_deletes_by_endpoint(self):
|
||||
"""失效订阅应能按 endpoint 从全局订阅表删除。"""
|
||||
subscription = {"endpoint": "https://push.example/a", "keys": {}}
|
||||
global_vars.push_subscription(subscription)
|
||||
|
||||
self.assertTrue(global_vars.remove_subscription(subscription))
|
||||
self.assertEqual([], global_vars.get_subscriptions())
|
||||
|
||||
def test_is_webpush_subscription_gone_matches_404_and_410(self):
|
||||
"""推送服务返回 404/410 时应识别为订阅已失效。"""
|
||||
self.assertTrue(
|
||||
is_webpush_subscription_gone(
|
||||
SimpleNamespace(response=SimpleNamespace(status_code=410))
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
is_webpush_subscription_gone(
|
||||
SimpleNamespace(response=SimpleNamespace(status=404))
|
||||
)
|
||||
)
|
||||
self.assertFalse(
|
||||
is_webpush_subscription_gone(
|
||||
SimpleNamespace(response=SimpleNamespace(status_code=500))
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user