feat(agent): expand LLM provider and wizard support

jxxghp
2026-05-08 08:09:50 +08:00
parent 2d2c2a01eb
commit 05d720d81f
6 changed files with 1080 additions and 72 deletions

View File

@@ -117,6 +117,58 @@ class LLMProviderManager(metaclass=Singleton):
"gpt-5.4-mini",
"gpt-5.5",
}
_MODELS_DEV_DYNAMIC_SKIP_IDS = {
"aihubmix",
"amazon-bedrock",
"azure",
"azure-cognitive-services",
"cloudflare-ai-gateway",
"cohere",
"gitlab",
"google-vertex",
"google-vertex-anthropic",
"kiro",
"sap-ai-core",
"v0",
"vercel",
}
_MODELS_DEV_DYNAMIC_PROVIDER_OVERRIDES = {
"bailing": {
"runtime": "openai_compatible",
"default_base_url": "https://api.tbox.cn/api/llm/v1",
"description": "Bailing OpenAI-compatible 端点。",
},
"cerebras": {
"runtime": "openai_compatible",
"default_base_url": "https://api.cerebras.ai/v1",
"description": "Cerebras 官方兼容端点。",
},
"deepinfra": {
"runtime": "openai_compatible",
"default_base_url": "https://api.deepinfra.com/v1/openai",
"description": "DeepInfra 官方兼容端点。",
},
"mistral": {
"runtime": "openai_compatible",
"default_base_url": "https://api.mistral.ai/v1",
"description": "Mistral 官方兼容端点。",
},
"perplexity": {
"runtime": "openai_compatible",
"default_base_url": "https://api.perplexity.ai/v1",
"description": "Perplexity 官方兼容端点。",
},
"togetherai": {
"runtime": "openai_compatible",
"default_base_url": "https://api.together.xyz/v1",
"description": "Together AI 官方兼容端点。",
},
"venice": {
"runtime": "openai_compatible",
"default_base_url": "https://api.venice.ai/api/v1",
"description": "Venice AI 官方兼容端点。",
},
}
def __init__(self):
self._lock = threading.RLock()
@@ -130,7 +182,7 @@ class LLMProviderManager(metaclass=Singleton):
)
@staticmethod
def _provider_specs() -> tuple[ProviderSpec, ...]:
def _builtin_provider_specs() -> tuple[ProviderSpec, ...]:
"""
Return the supported provider definitions.
@@ -708,43 +760,253 @@ class LLMProviderManager(metaclass=Singleton):
)
return tuple(providers)
def _cached_models_dev_payload(self) -> dict[str, Any]:
if isinstance(self._models_dev_data, dict):
return self._models_dev_data
try:
if not self._models_dev_cache_path.exists():
return {}
payload = json.loads(self._models_dev_cache_path.read_text(encoding="utf-8"))
except Exception as err:
logger.warning(f"读取 models.dev provider 缓存失败: {err}")
return {}
if not isinstance(payload, dict):
return {}
self._models_dev_data = payload
return payload
@staticmethod
def _models_dev_env_names(payload: dict[str, Any]) -> tuple[str, ...]:
raw_env_names = payload.get("env")
if not isinstance(raw_env_names, list):
return ()
env_names = []
for item in raw_env_names:
value = str(item or "").strip()
if value:
env_names.append(value)
return tuple(env_names)
@classmethod
def _models_dev_reserved_provider_ids(
cls, specs: tuple[ProviderSpec, ...]
) -> set[str]:
reserved_ids: set[str] = set()
for spec in specs:
if spec.models_dev_provider_id:
reserved_ids.add(spec.models_dev_provider_id)
for preset in spec.base_url_presets:
if preset.models_dev_provider_id:
reserved_ids.add(preset.models_dev_provider_id)
return reserved_ids
@staticmethod
def _dynamic_api_key_label(env_names: tuple[str, ...]) -> str:
first_env = env_names[0].upper() if env_names else ""
if "TOKEN" in first_env and "KEY" not in first_env:
return "API Token"
return "API Key"
@classmethod
def _normalize_models_dev_base_url(
cls, runtime: str, base_url: Optional[str]
) -> Optional[str]:
normalized = cls._sanitize_base_url(base_url)
if not normalized:
return None
suffixes = {
"openai_compatible": (
"/chat/completions",
"/completions",
"/responses",
"/embeddings",
"/audio/speech",
"/audio/transcriptions",
),
"anthropic_compatible": (
"/messages",
),
}
for suffix in suffixes.get(runtime, ()):
if normalized.endswith(suffix):
normalized = normalized[: -len(suffix)]
break
return cls._sanitize_base_url(normalized)
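# Sketch of the suffix stripping, matching the bailing regression test below:
#   _normalize_models_dev_base_url(
#       runtime="openai_compatible",
#       base_url="https://api.tbox.cn/api/llm/v1/chat/completions",
#   )  -> "https://api.tbox.cn/api/llm/v1"
# Anthropic-compatible URLs only shed a trailing "/messages".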
@classmethod
def _models_dev_dynamic_provider_spec(
cls,
provider_id: str,
payload: dict[str, Any],
sort_order: int,
) -> ProviderSpec | None:
normalized_id = str(provider_id or "").strip().lower()
if not normalized_id or normalized_id in cls._MODELS_DEV_DYNAMIC_SKIP_IDS:
return None
override = cls._MODELS_DEV_DYNAMIC_PROVIDER_OVERRIDES.get(normalized_id, {})
npm_package = str(payload.get("npm") or "").strip()
runtime = override.get("runtime")
if not runtime:
if npm_package == "@ai-sdk/openai-compatible":
runtime = "openai_compatible"
elif npm_package == "@ai-sdk/anthropic":
runtime = "anthropic_compatible"
else:
return None
model_list_strategy = override.get("model_list_strategy")
if not model_list_strategy:
model_list_strategy = (
"anthropic_compatible"
if runtime == "anthropic_compatible"
else "models_dev_only"
)
default_base_url = cls._normalize_models_dev_base_url(
runtime=runtime,
base_url=override.get("default_base_url") or payload.get("api"),
)
requires_base_url = not bool(default_base_url)
env_names = cls._models_dev_env_names(payload)
api_key_label = override.get("api_key_label") or cls._dynamic_api_key_label(
env_names
)
name = str(payload.get("name") or override.get("name") or normalized_id).strip()
description = override.get("description")
if not description:
transport_name = "Anthropic-compatible" if runtime == "anthropic_compatible" else "OpenAI-compatible"
description = f"{name} {transport_name} 端点(来自 models.dev 目录)。"
api_key_hint = override.get("api_key_hint")
if not api_key_hint:
api_key_hint = f"填写 {name} {api_key_label}"
if requires_base_url:
api_key_hint = f"填写 {name} {api_key_label},并手动填写 Base URL。"
return ProviderSpec(
id=normalized_id,
name=name,
runtime=runtime,
models_dev_provider_id=normalized_id,
default_base_url=default_base_url,
base_url_editable=True,
requires_base_url=requires_base_url,
api_key_label=api_key_label,
api_key_hint=api_key_hint,
model_list_strategy=model_list_strategy,
description=description,
sort_order=sort_order,
)
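# For example, a models.dev entry published against the generic OpenAI SDK
# becomes a dynamic spec (payload shape as exercised by the tests in this commit):
#   _models_dev_dynamic_provider_spec(
#       provider_id="frogbot",
#       payload={
#           "name": "FrogBot",
#           "npm": "@ai-sdk/openai-compatible",
#           "env": ["FROGBOT_API_KEY"],
#           "api": "https://app.frogbot.ai/api/v1",
#       },
#       sort_order=700,
#   )
# yields ProviderSpec(id="frogbot", runtime="openai_compatible",
# default_base_url="https://app.frogbot.ai/api/v1", requires_base_url=False).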
def _dynamic_provider_specs(
self, builtin_specs: tuple[ProviderSpec, ...]
) -> tuple[ProviderSpec, ...]:
payload = self._cached_models_dev_payload()
if not payload:
return ()
explicit_ids = {spec.id for spec in builtin_specs}
reserved_ids = self._models_dev_reserved_provider_ids(builtin_specs)
candidates: list[tuple[str, str, dict[str, Any]]] = []
for provider_id, provider_payload in payload.items():
normalized_id = str(provider_id or "").strip().lower()
if not normalized_id or not isinstance(provider_payload, dict):
continue
if normalized_id in explicit_ids or normalized_id in reserved_ids:
continue
spec = self._models_dev_dynamic_provider_spec(
provider_id=normalized_id,
payload=provider_payload,
sort_order=0,
)
if not spec:
continue
candidates.append((spec.name.lower(), normalized_id, provider_payload))
dynamic_specs = []
for sort_order, (_, provider_id, provider_payload) in enumerate(
sorted(candidates),
start=700,
):
spec = self._models_dev_dynamic_provider_spec(
provider_id=provider_id,
payload=provider_payload,
sort_order=sort_order,
)
if not spec:
continue
dynamic_specs.append(spec)
return tuple(dynamic_specs)
def _provider_specs(self) -> tuple[ProviderSpec, ...]:
builtin_specs = self._builtin_provider_specs()
return builtin_specs + self._dynamic_provider_specs(builtin_specs)
async def _get_provider_async(
self, provider_id: str, force_refresh: bool = False
) -> ProviderSpec:
try:
return self.get_provider(provider_id)
except LLMProviderError:
await self.get_models_dev_data(force_refresh=force_refresh)
return self.get_provider(provider_id)
def _serialize_provider(self, spec: ProviderSpec) -> dict[str, Any]:
return {
"id": spec.id,
"name": spec.name,
"runtime": spec.runtime,
"default_base_url": self._default_base_url_for_provider(spec) or "",
"base_url_presets": [
{
"label": preset.label,
"value": self._sanitize_base_url(preset.value) or "",
}
for preset in spec.base_url_presets
],
"base_url_editable": spec.base_url_editable,
"requires_base_url": spec.requires_base_url,
"supports_api_key": spec.supports_api_key,
"api_key_label": spec.api_key_label,
"api_key_hint": spec.api_key_hint,
"supports_model_refresh": spec.supports_model_refresh,
"oauth_methods": [
{
"id": method.id,
"type": method.type,
"label": method.label,
"description": method.description,
}
for method in spec.oauth_methods
],
"description": spec.description,
"auth_status": self.get_auth_status(spec.id),
}
async def list_providers_async(
self, force_refresh: bool = False
) -> list[dict[str, Any]]:
"""返回前端可渲染的 provider 目录,并优先补齐 models.dev 动态平台。"""
try:
await self.get_models_dev_data(force_refresh=force_refresh)
except Exception as err:
logger.debug(f"加载 models.dev provider 目录失败,回退内置列表: {err}")
return self.list_providers()
def list_providers(self) -> list[dict[str, Any]]:
"""返回前端可渲染的 provider 目录。"""
providers = []
for spec in sorted(self._provider_specs(), key=lambda item: item.sort_order):
providers.append(
{
"id": spec.id,
"name": spec.name,
"runtime": spec.runtime,
"default_base_url": self._default_base_url_for_provider(spec) or "",
"base_url_presets": [
{
"label": preset.label,
"value": self._sanitize_base_url(preset.value) or "",
}
for preset in spec.base_url_presets
],
"base_url_editable": spec.base_url_editable,
"requires_base_url": spec.requires_base_url,
"supports_api_key": spec.supports_api_key,
"api_key_label": spec.api_key_label,
"api_key_hint": spec.api_key_hint,
"supports_model_refresh": spec.supports_model_refresh,
"oauth_methods": [
{
"id": method.id,
"type": method.type,
"label": method.label,
"description": method.description,
}
for method in spec.oauth_methods
],
"description": spec.description,
"auth_status": self.get_auth_status(spec.id),
}
)
return providers
return [
self._serialize_provider(spec)
for spec in sorted(self._provider_specs(), key=lambda item: item.sort_order)
]
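# Minimal usage sketch (mirrors the FastAPI route change further below):
#   import asyncio
#   providers = asyncio.run(LLMProviderManager().list_providers_async())
#   ids = {item["id"] for item in providers}  # built-ins plus models.dev platforms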
def get_provider(self, provider_id: str) -> ProviderSpec:
"""按 provider id 获取定义。"""
@@ -973,7 +1235,7 @@ class LLMProviderManager(metaclass=Singleton):
async def _models_dev_provider_payload(
self, provider_id: str, base_url: Optional[str] = None
) -> dict[str, Any]:
spec = self.get_provider(provider_id)
spec = await self._get_provider_async(provider_id)
models_dev_provider_id = self._resolve_provider_models_dev_provider_id(
spec,
base_url,
@@ -1313,7 +1575,7 @@ class LLMProviderManager(metaclass=Singleton):
force_refresh: bool = False,
) -> list[dict[str, Any]]:
"""返回标准化后的模型目录。"""
spec = self.get_provider(provider_id)
spec = await self._get_provider_async(provider_id, force_refresh=force_refresh)
if self._resolve_provider_models_dev_provider_id(spec, base_url):
# Proactively refresh the cache once for providers backed by models.dev, so that
# "Refresh model list" also picks up the latest parameters for catalog-based providers.
@@ -1449,7 +1711,7 @@ class LLMProviderManager(metaclass=Singleton):
The API Key path is already covered by the regular settings form; this only handles providers that need interactive authorization.
"""
provider = self.get_provider(provider_id)
provider = await self._get_provider_async(provider_id)
method = next(
(item for item in provider.oauth_methods if item.id == method_id),
None,
@@ -1844,7 +2106,7 @@ class LLMProviderManager(metaclass=Singleton):
Returns a unified structure for `LLMHelper` to consume when creating concrete LangChain model instances.
"""
spec = self.get_provider(provider_id)
spec = await self._get_provider_async(provider_id)
normalized_api_key = str(api_key or "").strip() or None
normalized_base_url = self._sanitize_base_url(base_url)
model_record = None

View File

@@ -98,7 +98,7 @@ async def get_llm_providers(
Return a provider catalog the frontend can render directly.
"""
try:
providers = LLMProviderManager().list_providers()
providers = await LLMProviderManager().list_providers_async()
return schemas.Response(success=True, data=providers)
except Exception as err:
return schemas.Response(success=False, message=str(err))

View File

@@ -501,7 +501,7 @@ class ConfigModel(BaseModel):
AI_AGENT_ENABLE: bool = False
# Global AI agent
AI_AGENT_GLOBAL: bool = False
# LLM provider (openai/google/deepseek)
# LLM provider (built-in providers, plus platforms dynamically added from models.dev)
LLM_PROVIDER: str = "deepseek"
# LLM模型名称
LLM_MODEL: str = "deepseek-chat"

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
import argparse
import asyncio
import getpass
import importlib.util
import json
import os
import platform
@@ -66,6 +67,26 @@ LLM_PROVIDER_DEFAULTS = {
"model": "gemini-2.5-flash",
"base_url": "",
},
"anthropic": {
"model": "claude-sonnet-4-0",
"base_url": "https://api.anthropic.com/v1",
},
"openrouter": {
"model": "openai/gpt-4.1-mini",
"base_url": "https://openrouter.ai/api/v1",
},
"groq": {
"model": "llama-3.3-70b-versatile",
"base_url": "https://api.groq.com/openai/v1",
},
}
LLM_PROVIDER_FALLBACK_CHOICES = {
"deepseek": "DeepSeek",
"openai": "OpenAI Compatible",
"google": "Google",
"anthropic": "Anthropic",
"openrouter": "OpenRouter",
"groq": "Groq",
}
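# When the live catalog cannot be loaded, _llm_provider_choice_map() below
# falls back to a copy of this table, so the wizard always has providers to offer:
#   _llm_provider_choice_map([])  -> dict(LLM_PROVIDER_FALLBACK_CHOICES)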
RUNTIME_PACKAGE = {
"name": "moviepilot-frontend-runtime",
@@ -1063,6 +1084,273 @@ def _prompt_choice(label: str, choices: dict[str, str], default: str) -> str:
print("请输入列表中的可选值。")
def _prompt_provider_choice(label: str, choices: dict[str, str], default: str) -> str:
labels = []
normalized_map: dict[str, str] = {}
for key, desc in choices.items():
labels.append(f"{key}({desc})")
normalized_map[_normalize_choice(key)] = key
preview_limit = 12
print("可用 LLM 提供商:")
for item in labels[:preview_limit]:
print(f" {item}")
if len(labels) > preview_limit:
print(f" ... 另有 {len(labels) - preview_limit} 个,可直接输入 provider id")
while True:
raw = input(f"{label} (默认 {default},可直接输入 provider id): ").strip()
if not raw:
return default
normalized = _normalize_choice(raw)
if normalized in normalized_map:
return normalized_map[normalized]
provider_id = raw.strip().lower()
if re.fullmatch(r"[a-z0-9][a-z0-9._-]*", provider_id):
return provider_id
print("请输入列表中的可选值,或合法的 provider id小写字母/数字/.-_")
def _load_llm_provider_module():
provider_path = ROOT / "app" / "agent" / "llm" / "provider.py"
module_name = f"moviepilot_local_llm_provider_{uuid.uuid4().hex}"
spec = importlib.util.spec_from_file_location(module_name, provider_path)
if not spec or not spec.loader:
raise RuntimeError("无法加载 LLM provider 模块")
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
return module
def _load_llm_provider_definitions_inner() -> list[dict[str, Any]]:
provider_module = _load_llm_provider_module()
providers = asyncio.run(provider_module.LLMProviderManager().list_providers_async())
return providers if isinstance(providers, list) else []
def _load_llm_provider_definitions(
runtime_python: Optional[Path] = None,
) -> list[dict[str, Any]]:
try:
return _load_llm_provider_definitions_inner()
except Exception as exc:
if runtime_python and not _current_python_matches(runtime_python):
try:
with TemporaryDirectory() as temp_dir:
output_path = Path(temp_dir) / "llm-providers.json"
subprocess.run(
[
str(runtime_python),
str(Path(__file__).resolve()),
"query-llm-providers",
"--output-json-file",
str(output_path),
],
cwd=str(ROOT),
check=True,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
data = json.loads(output_path.read_text(encoding="utf-8"))
if isinstance(data, list):
return data
except Exception as runtime_exc:
print_step(
f"当前环境暂时无法读取 LLM 提供商目录,已回退到常用平台列表:{runtime_exc}"
)
return []
print_step(f"当前环境暂时无法读取 LLM 提供商目录,已回退到常用平台列表:{exc}")
return []
def _llm_provider_choice_map(
provider_definitions: list[dict[str, Any]],
) -> dict[str, str]:
choices: dict[str, str] = {}
for item in provider_definitions:
if not isinstance(item, dict):
continue
if item.get("supports_api_key") is False:
continue
provider_id = str(item.get("id") or "").strip().lower()
name = str(item.get("name") or provider_id).strip()
if not provider_id or not name:
continue
choices[provider_id] = name
if choices:
return choices
return dict(LLM_PROVIDER_FALLBACK_CHOICES)
def _llm_provider_defaults(
provider: str,
provider_definitions: list[dict[str, Any]],
) -> dict[str, str]:
normalized_provider = str(provider or "").strip().lower()
defaults = dict(LLM_PROVIDER_DEFAULTS.get(normalized_provider) or {})
provider_meta = next(
(
item
for item in provider_definitions
if isinstance(item, dict)
and str(item.get("id") or "").strip().lower() == normalized_provider
),
None,
)
if isinstance(provider_meta, dict):
default_base_url = str(provider_meta.get("default_base_url") or "").strip()
if default_base_url:
defaults["base_url"] = default_base_url
defaults.setdefault("model", _env_default("LLM_MODEL", ""))
defaults.setdefault("base_url", _env_default("LLM_BASE_URL", ""))
return defaults
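# Resolution order, illustrated with the wizard test's hypothetical provider:
# hard-coded LLM_PROVIDER_DEFAULTS first, then the catalog's default_base_url
# (which overrides it), then env LLM_MODEL / LLM_BASE_URL fill remaining gaps:
#   _llm_provider_defaults("frogbot", [{"id": "frogbot",
#       "default_base_url": "https://app.frogbot.ai/api/v1"}])
#   -> {"base_url": "https://app.frogbot.ai/api/v1", "model": ""}  # no env set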
def _llm_provider_meta(
provider: str,
provider_definitions: list[dict[str, Any]],
) -> dict[str, Any]:
normalized_provider = str(provider or "").strip().lower()
provider_meta = next(
(
item
for item in provider_definitions
if isinstance(item, dict)
and str(item.get("id") or "").strip().lower() == normalized_provider
),
None,
)
return dict(provider_meta) if isinstance(provider_meta, dict) else {}
def _load_llm_models_inner(payload: dict[str, Any]) -> list[dict[str, Any]]:
provider = str(payload.get("provider") or "").strip().lower()
if not provider:
raise RuntimeError("缺少 LLM provider")
provider_module = _load_llm_provider_module()
api_key = str(payload.get("api_key") or "").strip() or None
base_url = str(payload.get("base_url") or "").strip() or None
models = asyncio.run(
provider_module.LLMProviderManager().list_models(
provider_id=provider,
api_key=api_key,
base_url=base_url,
force_refresh=False,
)
)
return models if isinstance(models, list) else []
def _load_llm_models(
*,
provider: str,
api_key: Optional[str],
base_url: Optional[str],
runtime_python: Optional[Path] = None,
) -> list[dict[str, Any]]:
payload = {
"provider": str(provider or "").strip().lower(),
"api_key": str(api_key or "").strip(),
"base_url": str(base_url or "").strip(),
}
try:
return _load_llm_models_inner(payload)
except Exception as exc:
if runtime_python and not _current_python_matches(runtime_python):
try:
with TemporaryDirectory() as temp_dir:
request_path = Path(temp_dir) / "llm-models-request.json"
output_path = Path(temp_dir) / "llm-models.json"
request_path.write_text(
json.dumps(payload, ensure_ascii=False), encoding="utf-8"
)
subprocess.run(
[
str(runtime_python),
str(Path(__file__).resolve()),
"query-llm-models",
"--request-json-file",
str(request_path),
"--output-json-file",
str(output_path),
],
cwd=str(ROOT),
check=True,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
data = json.loads(output_path.read_text(encoding="utf-8"))
if isinstance(data, list):
return data
except Exception as runtime_exc:
print_step(
f"当前环境暂时无法获取 {payload['provider']} 模型目录,已回退为手动输入模型名称:{runtime_exc}"
)
return []
print_step(
f"当前环境暂时无法获取 {payload['provider']} 模型目录,已回退为手动输入模型名称:{exc}"
)
return []
def _print_llm_models(models: list[dict[str, Any]], limit: int = 20) -> None:
print("可用模型:")
for index, item in enumerate(models[:limit], start=1):
if not isinstance(item, dict):
continue
model_id = str(item.get("id") or "").strip()
if not model_id:
continue
model_name = str(item.get("name") or model_id).strip()
extras: list[str] = []
if item.get("context_tokens_k"):
extras.append(f"{item['context_tokens_k']}K")
if item.get("supports_reasoning"):
extras.append("reasoning")
if item.get("supports_tools"):
extras.append("tools")
if item.get("supports_image_input"):
extras.append("vision")
extra_text = f" [{' / '.join(extras)}]" if extras else ""
if model_name != model_id:
print(f" {index}. {model_id} ({model_name}){extra_text}")
else:
print(f" {index}. {model_id}{extra_text}")
if len(models) > limit:
print(f" ... 共 {len(models)} 个模型,可输入编号或直接输入模型名称")
def _prompt_model_choice(models: list[dict[str, Any]], default: Optional[str] = None) -> str:
valid_models = [item for item in models if isinstance(item, dict) and item.get("id")]
if not valid_models:
return _prompt_text("LLM 模型名称", default=default)
indexed_models = {
str(index): str(item.get("id")).strip()
for index, item in enumerate(valid_models, start=1)
}
default_model = str(default or indexed_models.get("1") or "").strip()
_print_llm_models(valid_models)
while True:
raw = input(
f"LLM 模型名称/编号{' [' + default_model + ']' if default_model else ''}: "
).strip()
if not raw and default_model:
return default_model
if raw in indexed_models:
return indexed_models[raw]
if raw:
return raw
print("请输入有效模型编号或模型名称。")
def _env_llm_thinking_level_default() -> str:
value = _normalize_choice(_env_default("LLM_THINKING_LEVEL", ""))
alias_map = {
@@ -1459,7 +1747,9 @@ def _collect_notification_config() -> Optional[dict[str, Any]]:
}
def _collect_agent_config() -> dict[str, Any]:
def _collect_agent_config(
runtime_python: Optional[Path] = None,
) -> dict[str, Any]:
print_step("AI Agent 配置")
enabled = _prompt_yes_no(
"是否启用 AI 智能体",
@@ -1471,22 +1761,42 @@ def _collect_agent_config() -> dict[str, Any]:
"AI_AGENT_GLOBAL": False,
}
provider_definitions = _load_llm_provider_definitions(runtime_python=runtime_python)
provider_choices = _llm_provider_choice_map(provider_definitions)
current_provider = _env_default("LLM_PROVIDER", "deepseek").lower()
if current_provider not in LLM_PROVIDER_DEFAULTS:
if current_provider not in provider_choices:
current_provider = "deepseek"
provider = _prompt_choice(
"选择 LLM 提供商",
{
"deepseek": "DeepSeek",
"openai": "OpenAI",
"google": "Google",
},
default=current_provider,
)
defaults = LLM_PROVIDER_DEFAULTS[provider]
while True:
provider = _prompt_provider_choice(
"选择 LLM 提供商",
provider_choices,
default=current_provider,
)
provider_meta = _llm_provider_meta(provider, provider_definitions)
if provider_meta.get("supports_api_key") is False:
print_step(
f"{provider_meta.get('name') or provider} 当前仅支持交互式授权,安装向导暂不支持,请改选可填写 API Key 的 provider"
)
current_provider = "deepseek"
continue
break
defaults = _llm_provider_defaults(provider, provider_definitions)
current_model = _env_default("LLM_MODEL", defaults["model"])
current_base_url = _env_default("LLM_BASE_URL", defaults["base_url"])
api_key_label = str(provider_meta.get("api_key_label") or "API Key").strip() or "API Key"
api_key_hint = str(provider_meta.get("api_key_hint") or "").strip()
requires_base_url = bool(provider_meta.get("requires_base_url"))
base_url_label = (
"自定义 Google API Base URL可选"
if provider == "google"
else "LLM Base URL必填"
if requires_base_url
else "LLM Base URL"
)
if api_key_hint:
print_step(api_key_hint)
config: dict[str, Any] = {
"AI_AGENT_ENABLE": True,
@@ -1495,12 +1805,8 @@ def _collect_agent_config() -> dict[str, Any]:
default=_env_bool("AI_AGENT_GLOBAL", False),
),
"LLM_PROVIDER": provider,
"LLM_MODEL": _prompt_text(
"LLM 模型名称",
default=current_model,
),
"LLM_API_KEY": _prompt_secret_text(
"LLM API Key",
f"LLM {api_key_label}",
current_value=read_env_value("LLM_API_KEY"),
required=True,
),
@@ -1524,18 +1830,18 @@ def _collect_agent_config() -> dict[str, Any]:
),
}
if provider == "google":
config["LLM_BASE_URL"] = _prompt_text(
"自定义 Google API Base URL可选",
default=current_base_url,
allow_empty=True,
)
else:
config["LLM_BASE_URL"] = _prompt_text(
"LLM Base URL",
default=current_base_url,
allow_empty=True,
)
config["LLM_BASE_URL"] = _prompt_text(
base_url_label,
default=current_base_url,
allow_empty=not requires_base_url,
)
models = _load_llm_models(
provider=provider,
api_key=config["LLM_API_KEY"],
base_url=config["LLM_BASE_URL"],
runtime_python=runtime_python,
)
config["LLM_MODEL"] = _prompt_model_choice(models, default=current_model)
return config
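# A resulting config, using the hypothetical answers from the wizard tests below:
#   {"AI_AGENT_ENABLE": True, "AI_AGENT_GLOBAL": False,
#    "LLM_PROVIDER": "frogbot", "LLM_MODEL": "frog-2",
#    "LLM_API_KEY": "sk-frog",
#    "LLM_BASE_URL": "https://override.example.com/v1", ...}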
@@ -1744,7 +2050,7 @@ def run_setup_wizard(
preset_password=preset_superuser_password,
),
**_collect_database_config(),
**_collect_agent_config(),
**_collect_agent_config(runtime_python=runtime_python),
},
"directories": [_collect_directory_config()],
"downloader": _collect_downloader_config(),
@@ -3276,6 +3582,23 @@ def build_parser() -> argparse.ArgumentParser:
"--output-json-file", required=True, help=argparse.SUPPRESS
)
query_llm_providers_parser = subparsers.add_parser(
"query-llm-providers", help=argparse.SUPPRESS
)
query_llm_providers_parser.add_argument(
"--output-json-file", required=True, help=argparse.SUPPRESS
)
query_llm_models_parser = subparsers.add_parser(
"query-llm-models", help=argparse.SUPPRESS
)
query_llm_models_parser.add_argument(
"--request-json-file", required=True, help=argparse.SUPPRESS
)
query_llm_models_parser.add_argument(
"--output-json-file", required=True, help=argparse.SUPPRESS
)
return parser
@@ -3461,6 +3784,25 @@ def main() -> int:
encoding="utf-8",
)
return 0
if args.command == "query-llm-providers":
payload = _load_llm_provider_definitions_inner()
Path(args.output_json_file).write_text(
json.dumps(payload, ensure_ascii=False, indent=2),
encoding="utf-8",
)
return 0
if args.command == "query-llm-models":
payload = json.loads(Path(args.request_json_file).read_text(encoding="utf-8"))
if not isinstance(payload, dict):
raise RuntimeError("模型查询负载格式错误")
models = _load_llm_models_inner(payload)
Path(args.output_json_file).write_text(
json.dumps(models, ensure_ascii=False, indent=2),
encoding="utf-8",
)
return 0
except subprocess.CalledProcessError as exc:
print(f"命令执行失败,退出码:{exc.returncode}", file=sys.stderr)
return exc.returncode
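# The wizard shells out to these hidden subcommands when the current
# interpreter cannot import the provider module directly, e.g. (paths
# hypothetical):
#   /opt/venv/bin/python scripts/local_setup.py query-llm-providers \
#       --output-json-file /tmp/llm-providers.json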

View File

@@ -0,0 +1,215 @@
import asyncio
import importlib.util
import sys
import unittest
from pathlib import Path
from types import ModuleType, SimpleNamespace
from unittest.mock import AsyncMock, patch
def _stub_module(name: str, **attrs):
module = sys.modules.get(name)
if module is None:
module = ModuleType(name)
sys.modules[name] = module
for key, value in attrs.items():
setattr(module, key, value)
return module
class _DummyLogger:
def __getattr__(self, _name):
return lambda *args, **kwargs: None
class _DummySystemConfigOper:
def get(self, _key):
return {}
async def async_set(self, _key, _value):
return None
for _module_name in ("aiofiles", "jwt"):
_stub_module(_module_name)
_stub_module(
"app.core.config",
settings=SimpleNamespace(
TEMP_PATH="/tmp",
PROXY_HOST=None,
LLM_MAX_CONTEXT_TOKENS=64,
),
)
_stub_module("app.db.systemconfig_oper", SystemConfigOper=_DummySystemConfigOper)
_stub_module("app.log", logger=_DummyLogger())
_stub_module("app.schemas.types", SystemConfigKey=SimpleNamespace(AIAgentConfig="agent"))
provider_path = Path(__file__).resolve().parents[1] / "app" / "agent" / "llm" / "provider.py"
spec = importlib.util.spec_from_file_location("test_llm_provider_module", provider_path)
provider_module = importlib.util.module_from_spec(spec)
assert spec and spec.loader
sys.modules[spec.name] = provider_module
spec.loader.exec_module(provider_module)
LLMProviderError = provider_module.LLMProviderError
LLMProviderManager = provider_module.LLMProviderManager
class LlmProviderRegistryTest(unittest.TestCase):
def setUp(self):
LLMProviderManager._instances.clear()
def tearDown(self):
LLMProviderManager._instances.clear()
def test_dynamic_provider_is_exposed_from_models_dev_cache(self):
manager = LLMProviderManager()
manager._models_dev_data = {
"frogbot": {
"id": "frogbot",
"name": "FrogBot",
"npm": "@ai-sdk/openai-compatible",
"env": ["FROGBOT_API_KEY"],
"api": "https://app.frogbot.ai/api/v1",
"models": {},
}
}
provider = manager.get_provider("frogbot")
self.assertEqual(provider.id, "frogbot")
self.assertEqual(provider.runtime, "openai_compatible")
self.assertEqual(provider.default_base_url, "https://app.frogbot.ai/api/v1")
self.assertFalse(provider.requires_base_url)
self.assertTrue(provider.base_url_editable)
self.assertEqual(provider.model_list_strategy, "models_dev_only")
def test_dynamic_provider_override_normalizes_chat_endpoint_base_url(self):
manager = LLMProviderManager()
manager._models_dev_data = {
"bailing": {
"id": "bailing",
"name": "Bailing",
"npm": "@ai-sdk/openai-compatible",
"env": ["BAILING_API_TOKEN"],
"api": "https://api.tbox.cn/api/llm/v1/chat/completions",
"models": {},
}
}
provider = manager.get_provider("bailing")
self.assertEqual(provider.default_base_url, "https://api.tbox.cn/api/llm/v1")
self.assertEqual(provider.api_key_label, "API Token")
def test_dynamic_provider_skips_alias_only_models_dev_ids(self):
manager = LLMProviderManager()
manager._models_dev_data = {
"moonshotai": {
"id": "moonshotai",
"name": "Moonshot AI Intl",
"npm": "@ai-sdk/openai-compatible",
"env": ["MOONSHOT_API_KEY"],
"api": "https://api.moonshot.ai/v1",
"models": {},
}
}
with self.assertRaises(LLMProviderError):
manager.get_provider("moonshotai")
def test_dynamic_provider_skips_incompatible_models_dev_provider(self):
manager = LLMProviderManager()
manager._models_dev_data = {
"azure": {
"id": "azure",
"name": "Azure",
"npm": "@ai-sdk/azure",
"env": ["AZURE_API_KEY"],
"models": {},
}
}
with self.assertRaises(LLMProviderError):
manager.get_provider("azure")
def test_dynamic_provider_without_known_base_url_requires_manual_input(self):
manager = LLMProviderManager()
manager._models_dev_data = {
"custom-anthropic": {
"id": "custom-anthropic",
"name": "Custom Anthropic",
"npm": "@ai-sdk/anthropic",
"env": ["CUSTOM_ANTHROPIC_KEY"],
"models": {},
}
}
provider = manager.get_provider("custom-anthropic")
self.assertEqual(provider.runtime, "anthropic_compatible")
self.assertTrue(provider.requires_base_url)
self.assertTrue(provider.base_url_editable)
self.assertEqual(provider.model_list_strategy, "anthropic_compatible")
def test_list_providers_async_loads_models_dev_before_serializing(self):
manager = LLMProviderManager()
payload = {
"frogbot": {
"id": "frogbot",
"name": "FrogBot",
"npm": "@ai-sdk/openai-compatible",
"env": ["FROGBOT_API_KEY"],
"api": "https://app.frogbot.ai/api/v1",
"models": {},
}
}
with patch.object(
manager,
"get_models_dev_data",
AsyncMock(side_effect=lambda force_refresh=False: manager.__dict__.update({"_models_dev_data": payload}) or payload),
) as fetch_mock:
providers = asyncio.run(manager.list_providers_async())
fetch_mock.assert_awaited_once_with(force_refresh=False)
self.assertIn("frogbot", {item["id"] for item in providers})
def test_list_models_uses_dynamic_provider_after_refresh(self):
manager = LLMProviderManager()
payload = {
"frogbot": {
"id": "frogbot",
"name": "FrogBot",
"npm": "@ai-sdk/openai-compatible",
"env": ["FROGBOT_API_KEY"],
"api": "https://app.frogbot.ai/api/v1",
"models": {
"frog-1": {
"name": "Frog 1",
"limit": {"context": 131072},
}
},
}
}
async def _load_models_dev(force_refresh: bool = False):
manager._models_dev_data = payload
return payload
with patch.object(manager, "get_models_dev_data", AsyncMock(side_effect=_load_models_dev)):
models = asyncio.run(
manager.list_models(
provider_id="frogbot",
api_key="sk-test",
force_refresh=True,
)
)
self.assertEqual([item["id"] for item in models], ["frog-1"])
self.assertEqual(models[0]["source"], "models.dev")
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,189 @@
from __future__ import annotations
import importlib.util
import unittest
import uuid
from pathlib import Path
from unittest.mock import patch
MODULE_PATH = Path(__file__).resolve().parents[1] / "scripts" / "local_setup.py"
def load_local_setup_module():
module_name = f"moviepilot_local_setup_llm_{uuid.uuid4().hex}"
spec = importlib.util.spec_from_file_location(module_name, MODULE_PATH)
module = importlib.util.module_from_spec(spec)
assert spec and spec.loader
spec.loader.exec_module(module)
return module
class LocalSetupLlmProviderPromptTests(unittest.TestCase):
def test_collect_agent_config_prefers_loaded_provider_directory(self):
module = load_local_setup_module()
provider_definitions = [
{
"id": "frogbot",
"name": "FrogBot",
"default_base_url": "https://app.frogbot.ai/api/v1",
"api_key_label": "API Key",
}
]
models = [
{"id": "frog-1", "name": "Frog 1", "context_tokens_k": 128},
{"id": "frog-2", "name": "Frog 2"},
]
with patch.object(module, "print_step"), patch.object(
module, "_prompt_yes_no", side_effect=[True, False, True]
), patch.object(
module, "_load_llm_provider_definitions", return_value=provider_definitions
), patch.object(
module, "_prompt_provider_choice", return_value="frogbot"
) as provider_prompt, patch.object(
module, "_prompt_text", side_effect=["https://override.example.com/v1"]
), patch.object(
module, "_prompt_secret_text", return_value="sk-frog"
), patch.object(
module, "_load_llm_models", return_value=models
) as load_models, patch.object(
module, "_prompt_model_choice", return_value="frog-2"
) as model_prompt, patch.object(
module, "read_env_value", return_value=None
), patch.object(
module, "_env_default", side_effect=lambda key, default="": default
), patch.object(
module, "_env_bool", side_effect=lambda key, default: default
), patch.object(
module, "_env_llm_thinking_level_default", return_value="auto"
), patch.object(
module, "_prompt_choice", return_value="auto"
):
config = module._collect_agent_config(runtime_python=Path("/tmp/runtime-python"))
provider_prompt.assert_called_once()
load_models.assert_called_once_with(
provider="frogbot",
api_key="sk-frog",
base_url="https://override.example.com/v1",
runtime_python=Path("/tmp/runtime-python"),
)
model_prompt.assert_called_once_with(models, default="")
self.assertEqual(config["LLM_PROVIDER"], "frogbot")
self.assertEqual(config["LLM_MODEL"], "frog-2")
self.assertEqual(config["LLM_API_KEY"], "sk-frog")
self.assertEqual(config["LLM_BASE_URL"], "https://override.example.com/v1")
def test_collect_agent_config_falls_back_to_common_provider_choices(self):
module = load_local_setup_module()
with patch.object(module, "print_step"), patch.object(
module, "_prompt_yes_no", side_effect=[True, False, True]
), patch.object(
module, "_load_llm_provider_definitions", return_value=[]
), patch.object(
module, "_prompt_provider_choice", return_value="anthropic"
), patch.object(
module, "_prompt_text", side_effect=["https://api.anthropic.com/v1"]
), patch.object(
module, "_prompt_secret_text", return_value="sk-anthropic"
), patch.object(
module, "_load_llm_models", return_value=[]
), patch.object(
module, "_prompt_model_choice", return_value="claude-sonnet-4-0"
), patch.object(
module, "read_env_value", return_value=None
), patch.object(
module, "_env_default", side_effect=lambda key, default="": default
), patch.object(
module, "_env_bool", side_effect=lambda key, default: default
), patch.object(
module, "_env_llm_thinking_level_default", return_value="off"
), patch.object(
module, "_prompt_choice", return_value="off"
):
config = module._collect_agent_config()
self.assertEqual(config["LLM_PROVIDER"], "anthropic")
self.assertEqual(config["LLM_MODEL"], "claude-sonnet-4-0")
self.assertEqual(config["LLM_BASE_URL"], "https://api.anthropic.com/v1")
def test_prompt_model_choice_accepts_index_selection(self):
module = load_local_setup_module()
with patch.object(module, "_print_llm_models") as print_models, patch(
"builtins.input", return_value="2"
):
model = module._prompt_model_choice(
[
{"id": "model-a", "name": "Model A"},
{"id": "model-b", "name": "Model B"},
],
default="model-a",
)
print_models.assert_called_once()
self.assertEqual(model, "model-b")
def test_prompt_model_choice_falls_back_to_text_input_when_empty(self):
module = load_local_setup_module()
with patch.object(module, "_prompt_text", return_value="custom-model") as prompt_text:
model = module._prompt_model_choice([], default="")
prompt_text.assert_called_once_with("LLM model name", default="")
self.assertEqual(model, "custom-model")
def test_load_llm_provider_definitions_inner_uses_direct_provider_module_loader(self):
module = load_local_setup_module()
class _FakeManager:
async def list_providers_async(self, force_refresh: bool = False):
return [{"id": "frogbot", "name": "FrogBot"}]
class _FakeProviderModule:
@staticmethod
def LLMProviderManager():
return _FakeManager()
fake_provider_module = _FakeProviderModule()
with patch.object(
module,
"_load_llm_provider_module",
return_value=fake_provider_module,
) as loader:
providers = module._load_llm_provider_definitions_inner()
loader.assert_called_once_with()
self.assertEqual(providers, [{"id": "frogbot", "name": "FrogBot"}])
def test_llm_provider_choice_map_skips_oauth_only_provider(self):
module = load_local_setup_module()
choices = module._llm_provider_choice_map(
[
{"id": "chatgpt", "name": "ChatGPT", "supports_api_key": True},
{"id": "github-copilot", "name": "GitHub Copilot", "supports_api_key": False},
]
)
self.assertEqual(choices, {"chatgpt": "ChatGPT"})
def test_prompt_provider_choice_accepts_custom_provider_id(self):
module = load_local_setup_module()
with patch("builtins.input", return_value="my-provider_01"), patch("builtins.print"):
provider = module._prompt_provider_choice(
"选择 LLM 提供商",
{"deepseek": "DeepSeek", "google": "Google"},
default="deepseek",
)
self.assertEqual(provider, "my-provider_01")
if __name__ == "__main__":
unittest.main()