mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-13 07:26:45 +00:00
feat(agent): expand LLM provider and wizard support
This commit is contained in:
@@ -117,6 +117,58 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
"gpt-5.4-mini",
|
||||
"gpt-5.5",
|
||||
}
|
||||
_MODELS_DEV_DYNAMIC_SKIP_IDS = {
|
||||
"aihubmix",
|
||||
"amazon-bedrock",
|
||||
"azure",
|
||||
"azure-cognitive-services",
|
||||
"cloudflare-ai-gateway",
|
||||
"cohere",
|
||||
"gitlab",
|
||||
"google-vertex",
|
||||
"google-vertex-anthropic",
|
||||
"kiro",
|
||||
"sap-ai-core",
|
||||
"v0",
|
||||
"vercel",
|
||||
}
|
||||
_MODELS_DEV_DYNAMIC_PROVIDER_OVERRIDES = {
|
||||
"bailing": {
|
||||
"runtime": "openai_compatible",
|
||||
"default_base_url": "https://api.tbox.cn/api/llm/v1",
|
||||
"description": "Bailing OpenAI-compatible 端点。",
|
||||
},
|
||||
"cerebras": {
|
||||
"runtime": "openai_compatible",
|
||||
"default_base_url": "https://api.cerebras.ai/v1",
|
||||
"description": "Cerebras 官方兼容端点。",
|
||||
},
|
||||
"deepinfra": {
|
||||
"runtime": "openai_compatible",
|
||||
"default_base_url": "https://api.deepinfra.com/v1/openai",
|
||||
"description": "DeepInfra 官方兼容端点。",
|
||||
},
|
||||
"mistral": {
|
||||
"runtime": "openai_compatible",
|
||||
"default_base_url": "https://api.mistral.ai/v1",
|
||||
"description": "Mistral 官方兼容端点。",
|
||||
},
|
||||
"perplexity": {
|
||||
"runtime": "openai_compatible",
|
||||
"default_base_url": "https://api.perplexity.ai/v1",
|
||||
"description": "Perplexity 官方兼容端点。",
|
||||
},
|
||||
"togetherai": {
|
||||
"runtime": "openai_compatible",
|
||||
"default_base_url": "https://api.together.xyz/v1",
|
||||
"description": "Together AI 官方兼容端点。",
|
||||
},
|
||||
"venice": {
|
||||
"runtime": "openai_compatible",
|
||||
"default_base_url": "https://api.venice.ai/api/v1",
|
||||
"description": "Venice AI 官方兼容端点。",
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
self._lock = threading.RLock()
|
||||
@@ -130,7 +182,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _provider_specs() -> tuple[ProviderSpec, ...]:
|
||||
def _builtin_provider_specs() -> tuple[ProviderSpec, ...]:
|
||||
"""
|
||||
返回受支持的 provider 定义。
|
||||
|
||||
@@ -708,43 +760,253 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
)
|
||||
return tuple(providers)
|
||||
|
||||
def _cached_models_dev_payload(self) -> dict[str, Any]:
|
||||
if isinstance(self._models_dev_data, dict):
|
||||
return self._models_dev_data
|
||||
|
||||
try:
|
||||
if not self._models_dev_cache_path.exists():
|
||||
return {}
|
||||
payload = json.loads(self._models_dev_cache_path.read_text(encoding="utf-8"))
|
||||
except Exception as err:
|
||||
logger.warning(f"读取 models.dev provider 缓存失败: {err}")
|
||||
return {}
|
||||
|
||||
if not isinstance(payload, dict):
|
||||
return {}
|
||||
|
||||
self._models_dev_data = payload
|
||||
return payload
|
||||
|
||||
@staticmethod
|
||||
def _models_dev_env_names(payload: dict[str, Any]) -> tuple[str, ...]:
|
||||
raw_env_names = payload.get("env")
|
||||
if not isinstance(raw_env_names, list):
|
||||
return ()
|
||||
env_names = []
|
||||
for item in raw_env_names:
|
||||
value = str(item or "").strip()
|
||||
if value:
|
||||
env_names.append(value)
|
||||
return tuple(env_names)
|
||||
|
||||
@classmethod
|
||||
def _models_dev_reserved_provider_ids(
|
||||
cls, specs: tuple[ProviderSpec, ...]
|
||||
) -> set[str]:
|
||||
reserved_ids: set[str] = set()
|
||||
for spec in specs:
|
||||
if spec.models_dev_provider_id:
|
||||
reserved_ids.add(spec.models_dev_provider_id)
|
||||
for preset in spec.base_url_presets:
|
||||
if preset.models_dev_provider_id:
|
||||
reserved_ids.add(preset.models_dev_provider_id)
|
||||
return reserved_ids
|
||||
|
||||
@staticmethod
|
||||
def _dynamic_api_key_label(env_names: tuple[str, ...]) -> str:
|
||||
first_env = env_names[0].upper() if env_names else ""
|
||||
if "TOKEN" in first_env and "KEY" not in first_env:
|
||||
return "API Token"
|
||||
return "API Key"
|
||||
|
||||
@classmethod
|
||||
def _normalize_models_dev_base_url(
|
||||
cls, runtime: str, base_url: Optional[str]
|
||||
) -> Optional[str]:
|
||||
normalized = cls._sanitize_base_url(base_url)
|
||||
if not normalized:
|
||||
return None
|
||||
|
||||
suffixes = {
|
||||
"openai_compatible": (
|
||||
"/chat/completions",
|
||||
"/completions",
|
||||
"/responses",
|
||||
"/embeddings",
|
||||
"/audio/speech",
|
||||
"/audio/transcriptions",
|
||||
),
|
||||
"anthropic_compatible": (
|
||||
"/messages",
|
||||
),
|
||||
}
|
||||
|
||||
for suffix in suffixes.get(runtime, ()):
|
||||
if normalized.endswith(suffix):
|
||||
normalized = normalized[: -len(suffix)]
|
||||
break
|
||||
return cls._sanitize_base_url(normalized)
|
||||
|
||||
@classmethod
|
||||
def _models_dev_dynamic_provider_spec(
|
||||
cls,
|
||||
provider_id: str,
|
||||
payload: dict[str, Any],
|
||||
sort_order: int,
|
||||
) -> ProviderSpec | None:
|
||||
normalized_id = str(provider_id or "").strip().lower()
|
||||
if not normalized_id or normalized_id in cls._MODELS_DEV_DYNAMIC_SKIP_IDS:
|
||||
return None
|
||||
|
||||
override = cls._MODELS_DEV_DYNAMIC_PROVIDER_OVERRIDES.get(normalized_id, {})
|
||||
npm_package = str(payload.get("npm") or "").strip()
|
||||
runtime = override.get("runtime")
|
||||
if not runtime:
|
||||
if npm_package == "@ai-sdk/openai-compatible":
|
||||
runtime = "openai_compatible"
|
||||
elif npm_package == "@ai-sdk/anthropic":
|
||||
runtime = "anthropic_compatible"
|
||||
else:
|
||||
return None
|
||||
|
||||
model_list_strategy = override.get("model_list_strategy")
|
||||
if not model_list_strategy:
|
||||
model_list_strategy = (
|
||||
"anthropic_compatible"
|
||||
if runtime == "anthropic_compatible"
|
||||
else "models_dev_only"
|
||||
)
|
||||
|
||||
default_base_url = cls._normalize_models_dev_base_url(
|
||||
runtime=runtime,
|
||||
base_url=override.get("default_base_url") or payload.get("api"),
|
||||
)
|
||||
requires_base_url = not bool(default_base_url)
|
||||
env_names = cls._models_dev_env_names(payload)
|
||||
api_key_label = override.get("api_key_label") or cls._dynamic_api_key_label(
|
||||
env_names
|
||||
)
|
||||
name = str(payload.get("name") or override.get("name") or normalized_id).strip()
|
||||
description = override.get("description")
|
||||
if not description:
|
||||
transport_name = "Anthropic-compatible" if runtime == "anthropic_compatible" else "OpenAI-compatible"
|
||||
description = f"{name} {transport_name} 端点(来自 models.dev 目录)。"
|
||||
|
||||
api_key_hint = override.get("api_key_hint")
|
||||
if not api_key_hint:
|
||||
api_key_hint = f"填写 {name} {api_key_label}。"
|
||||
if requires_base_url:
|
||||
api_key_hint = f"填写 {name} {api_key_label},并手动填写 Base URL。"
|
||||
|
||||
return ProviderSpec(
|
||||
id=normalized_id,
|
||||
name=name,
|
||||
runtime=runtime,
|
||||
models_dev_provider_id=normalized_id,
|
||||
default_base_url=default_base_url,
|
||||
base_url_editable=True,
|
||||
requires_base_url=requires_base_url,
|
||||
api_key_label=api_key_label,
|
||||
api_key_hint=api_key_hint,
|
||||
model_list_strategy=model_list_strategy,
|
||||
description=description,
|
||||
sort_order=sort_order,
|
||||
)
|
||||
|
||||
def _dynamic_provider_specs(
|
||||
self, builtin_specs: tuple[ProviderSpec, ...]
|
||||
) -> tuple[ProviderSpec, ...]:
|
||||
payload = self._cached_models_dev_payload()
|
||||
if not payload:
|
||||
return ()
|
||||
|
||||
explicit_ids = {spec.id for spec in builtin_specs}
|
||||
reserved_ids = self._models_dev_reserved_provider_ids(builtin_specs)
|
||||
candidates: list[tuple[str, str, dict[str, Any]]] = []
|
||||
|
||||
for provider_id, provider_payload in payload.items():
|
||||
normalized_id = str(provider_id or "").strip().lower()
|
||||
if not normalized_id or not isinstance(provider_payload, dict):
|
||||
continue
|
||||
if normalized_id in explicit_ids or normalized_id in reserved_ids:
|
||||
continue
|
||||
|
||||
spec = self._models_dev_dynamic_provider_spec(
|
||||
provider_id=normalized_id,
|
||||
payload=provider_payload,
|
||||
sort_order=0,
|
||||
)
|
||||
if not spec:
|
||||
continue
|
||||
candidates.append((spec.name.lower(), normalized_id, provider_payload))
|
||||
|
||||
dynamic_specs = []
|
||||
for sort_order, (_, provider_id, provider_payload) in enumerate(
|
||||
sorted(candidates),
|
||||
start=700,
|
||||
):
|
||||
spec = self._models_dev_dynamic_provider_spec(
|
||||
provider_id=provider_id,
|
||||
payload=provider_payload,
|
||||
sort_order=sort_order,
|
||||
)
|
||||
if not spec:
|
||||
continue
|
||||
dynamic_specs.append(spec)
|
||||
return tuple(dynamic_specs)
|
||||
|
||||
def _provider_specs(self) -> tuple[ProviderSpec, ...]:
|
||||
builtin_specs = self._builtin_provider_specs()
|
||||
return builtin_specs + self._dynamic_provider_specs(builtin_specs)
|
||||
|
||||
async def _get_provider_async(
|
||||
self, provider_id: str, force_refresh: bool = False
|
||||
) -> ProviderSpec:
|
||||
try:
|
||||
return self.get_provider(provider_id)
|
||||
except LLMProviderError:
|
||||
await self.get_models_dev_data(force_refresh=force_refresh)
|
||||
return self.get_provider(provider_id)
|
||||
|
||||
def _serialize_provider(self, spec: ProviderSpec) -> dict[str, Any]:
|
||||
return {
|
||||
"id": spec.id,
|
||||
"name": spec.name,
|
||||
"runtime": spec.runtime,
|
||||
"default_base_url": self._default_base_url_for_provider(spec) or "",
|
||||
"base_url_presets": [
|
||||
{
|
||||
"label": preset.label,
|
||||
"value": self._sanitize_base_url(preset.value) or "",
|
||||
}
|
||||
for preset in spec.base_url_presets
|
||||
],
|
||||
"base_url_editable": spec.base_url_editable,
|
||||
"requires_base_url": spec.requires_base_url,
|
||||
"supports_api_key": spec.supports_api_key,
|
||||
"api_key_label": spec.api_key_label,
|
||||
"api_key_hint": spec.api_key_hint,
|
||||
"supports_model_refresh": spec.supports_model_refresh,
|
||||
"oauth_methods": [
|
||||
{
|
||||
"id": method.id,
|
||||
"type": method.type,
|
||||
"label": method.label,
|
||||
"description": method.description,
|
||||
}
|
||||
for method in spec.oauth_methods
|
||||
],
|
||||
"description": spec.description,
|
||||
"auth_status": self.get_auth_status(spec.id),
|
||||
}
|
||||
|
||||
async def list_providers_async(
|
||||
self, force_refresh: bool = False
|
||||
) -> list[dict[str, Any]]:
|
||||
"""返回前端可渲染的 provider 目录,并优先补齐 models.dev 动态平台。"""
|
||||
try:
|
||||
await self.get_models_dev_data(force_refresh=force_refresh)
|
||||
except Exception as err:
|
||||
logger.debug(f"加载 models.dev provider 目录失败,回退内置列表: {err}")
|
||||
return self.list_providers()
|
||||
|
||||
def list_providers(self) -> list[dict[str, Any]]:
|
||||
"""返回前端可渲染的 provider 目录。"""
|
||||
providers = []
|
||||
for spec in sorted(self._provider_specs(), key=lambda item: item.sort_order):
|
||||
providers.append(
|
||||
{
|
||||
"id": spec.id,
|
||||
"name": spec.name,
|
||||
"runtime": spec.runtime,
|
||||
"default_base_url": self._default_base_url_for_provider(spec) or "",
|
||||
"base_url_presets": [
|
||||
{
|
||||
"label": preset.label,
|
||||
"value": self._sanitize_base_url(preset.value) or "",
|
||||
}
|
||||
for preset in spec.base_url_presets
|
||||
],
|
||||
"base_url_editable": spec.base_url_editable,
|
||||
"requires_base_url": spec.requires_base_url,
|
||||
"supports_api_key": spec.supports_api_key,
|
||||
"api_key_label": spec.api_key_label,
|
||||
"api_key_hint": spec.api_key_hint,
|
||||
"supports_model_refresh": spec.supports_model_refresh,
|
||||
"oauth_methods": [
|
||||
{
|
||||
"id": method.id,
|
||||
"type": method.type,
|
||||
"label": method.label,
|
||||
"description": method.description,
|
||||
}
|
||||
for method in spec.oauth_methods
|
||||
],
|
||||
"description": spec.description,
|
||||
"auth_status": self.get_auth_status(spec.id),
|
||||
}
|
||||
)
|
||||
return providers
|
||||
return [
|
||||
self._serialize_provider(spec)
|
||||
for spec in sorted(self._provider_specs(), key=lambda item: item.sort_order)
|
||||
]
|
||||
|
||||
def get_provider(self, provider_id: str) -> ProviderSpec:
|
||||
"""按 provider id 获取定义。"""
|
||||
@@ -973,7 +1235,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
async def _models_dev_provider_payload(
|
||||
self, provider_id: str, base_url: Optional[str] = None
|
||||
) -> dict[str, Any]:
|
||||
spec = self.get_provider(provider_id)
|
||||
spec = await self._get_provider_async(provider_id)
|
||||
models_dev_provider_id = self._resolve_provider_models_dev_provider_id(
|
||||
spec,
|
||||
base_url,
|
||||
@@ -1313,7 +1575,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
force_refresh: bool = False,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""返回标准化后的模型目录。"""
|
||||
spec = self.get_provider(provider_id)
|
||||
spec = await self._get_provider_async(provider_id, force_refresh=force_refresh)
|
||||
if self._resolve_provider_models_dev_provider_id(spec, base_url):
|
||||
# 对依赖 models.dev 的 provider 主动刷新一次缓存,保证“刷新模型列表”
|
||||
# 在使用目录型 provider 时也能拿到最新参数。
|
||||
@@ -1449,7 +1711,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
|
||||
API Key 方式已经由普通设置表单覆盖,这里只处理需要交互式授权的 provider。
|
||||
"""
|
||||
provider = self.get_provider(provider_id)
|
||||
provider = await self._get_provider_async(provider_id)
|
||||
method = next(
|
||||
(item for item in provider.oauth_methods if item.id == method_id),
|
||||
None,
|
||||
@@ -1844,7 +2106,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
|
||||
返回统一结构,供 `LLMHelper` 创建具体 LangChain 模型实例时使用。
|
||||
"""
|
||||
spec = self.get_provider(provider_id)
|
||||
spec = await self._get_provider_async(provider_id)
|
||||
normalized_api_key = str(api_key or "").strip() or None
|
||||
normalized_base_url = self._sanitize_base_url(base_url)
|
||||
model_record = None
|
||||
|
||||
@@ -98,7 +98,7 @@ async def get_llm_providers(
|
||||
返回前端可直接渲染的 provider 目录。
|
||||
"""
|
||||
try:
|
||||
providers = LLMProviderManager().list_providers()
|
||||
providers = await LLMProviderManager().list_providers_async()
|
||||
return schemas.Response(success=True, data=providers)
|
||||
except Exception as err:
|
||||
return schemas.Response(success=False, message=str(err))
|
||||
|
||||
@@ -501,7 +501,7 @@ class ConfigModel(BaseModel):
|
||||
AI_AGENT_ENABLE: bool = False
|
||||
# 合局AI智能体
|
||||
AI_AGENT_GLOBAL: bool = False
|
||||
# LLM提供商 (openai/google/deepseek)
|
||||
# LLM提供商(支持内置 provider,以及从 models.dev 动态补充的平台)
|
||||
LLM_PROVIDER: str = "deepseek"
|
||||
# LLM模型名称
|
||||
LLM_MODEL: str = "deepseek-chat"
|
||||
|
||||
@@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
import argparse
|
||||
import asyncio
|
||||
import getpass
|
||||
import importlib.util
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
@@ -66,6 +67,26 @@ LLM_PROVIDER_DEFAULTS = {
|
||||
"model": "gemini-2.5-flash",
|
||||
"base_url": "",
|
||||
},
|
||||
"anthropic": {
|
||||
"model": "claude-sonnet-4-0",
|
||||
"base_url": "https://api.anthropic.com/v1",
|
||||
},
|
||||
"openrouter": {
|
||||
"model": "openai/gpt-4.1-mini",
|
||||
"base_url": "https://openrouter.ai/api/v1",
|
||||
},
|
||||
"groq": {
|
||||
"model": "llama-3.3-70b-versatile",
|
||||
"base_url": "https://api.groq.com/openai/v1",
|
||||
},
|
||||
}
|
||||
LLM_PROVIDER_FALLBACK_CHOICES = {
|
||||
"deepseek": "DeepSeek",
|
||||
"openai": "OpenAI Compatible",
|
||||
"google": "Google",
|
||||
"anthropic": "Anthropic",
|
||||
"openrouter": "OpenRouter",
|
||||
"groq": "Groq",
|
||||
}
|
||||
RUNTIME_PACKAGE = {
|
||||
"name": "moviepilot-frontend-runtime",
|
||||
@@ -1063,6 +1084,273 @@ def _prompt_choice(label: str, choices: dict[str, str], default: str) -> str:
|
||||
print("请输入列表中的可选值。")
|
||||
|
||||
|
||||
def _prompt_provider_choice(label: str, choices: dict[str, str], default: str) -> str:
|
||||
labels = []
|
||||
normalized_map: dict[str, str] = {}
|
||||
for key, desc in choices.items():
|
||||
labels.append(f"{key}({desc})")
|
||||
normalized_map[_normalize_choice(key)] = key
|
||||
|
||||
preview_limit = 12
|
||||
print("可用 LLM 提供商:")
|
||||
for item in labels[:preview_limit]:
|
||||
print(f" {item}")
|
||||
if len(labels) > preview_limit:
|
||||
print(f" ... 另有 {len(labels) - preview_limit} 个,可直接输入 provider id")
|
||||
|
||||
while True:
|
||||
raw = input(f"{label} (默认 {default},可直接输入 provider id): ").strip()
|
||||
if not raw:
|
||||
return default
|
||||
normalized = _normalize_choice(raw)
|
||||
if normalized in normalized_map:
|
||||
return normalized_map[normalized]
|
||||
|
||||
provider_id = raw.strip().lower()
|
||||
if re.fullmatch(r"[a-z0-9][a-z0-9._-]*", provider_id):
|
||||
return provider_id
|
||||
print("请输入列表中的可选值,或合法的 provider id(小写字母/数字/.-_)。")
|
||||
|
||||
|
||||
def _load_llm_provider_module():
|
||||
provider_path = ROOT / "app" / "agent" / "llm" / "provider.py"
|
||||
module_name = f"moviepilot_local_llm_provider_{uuid.uuid4().hex}"
|
||||
spec = importlib.util.spec_from_file_location(module_name, provider_path)
|
||||
if not spec or not spec.loader:
|
||||
raise RuntimeError("无法加载 LLM provider 模块")
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[module_name] = module
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
def _load_llm_provider_definitions_inner() -> list[dict[str, Any]]:
|
||||
provider_module = _load_llm_provider_module()
|
||||
providers = asyncio.run(provider_module.LLMProviderManager().list_providers_async())
|
||||
return providers if isinstance(providers, list) else []
|
||||
|
||||
|
||||
def _load_llm_provider_definitions(
|
||||
runtime_python: Optional[Path] = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
try:
|
||||
return _load_llm_provider_definitions_inner()
|
||||
except Exception as exc:
|
||||
if runtime_python and not _current_python_matches(runtime_python):
|
||||
try:
|
||||
with TemporaryDirectory() as temp_dir:
|
||||
output_path = Path(temp_dir) / "llm-providers.json"
|
||||
subprocess.run(
|
||||
[
|
||||
str(runtime_python),
|
||||
str(Path(__file__).resolve()),
|
||||
"query-llm-providers",
|
||||
"--output-json-file",
|
||||
str(output_path),
|
||||
],
|
||||
cwd=str(ROOT),
|
||||
check=True,
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
)
|
||||
data = json.loads(output_path.read_text(encoding="utf-8"))
|
||||
if isinstance(data, list):
|
||||
return data
|
||||
except Exception as runtime_exc:
|
||||
print_step(
|
||||
f"当前环境暂时无法读取 LLM 提供商目录,已回退到常用平台列表:{runtime_exc}"
|
||||
)
|
||||
return []
|
||||
|
||||
print_step(f"当前环境暂时无法读取 LLM 提供商目录,已回退到常用平台列表:{exc}")
|
||||
return []
|
||||
|
||||
|
||||
def _llm_provider_choice_map(
|
||||
provider_definitions: list[dict[str, Any]],
|
||||
) -> dict[str, str]:
|
||||
choices: dict[str, str] = {}
|
||||
for item in provider_definitions:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
if item.get("supports_api_key") is False:
|
||||
continue
|
||||
provider_id = str(item.get("id") or "").strip().lower()
|
||||
name = str(item.get("name") or provider_id).strip()
|
||||
if not provider_id or not name:
|
||||
continue
|
||||
choices[provider_id] = name
|
||||
if choices:
|
||||
return choices
|
||||
return dict(LLM_PROVIDER_FALLBACK_CHOICES)
|
||||
|
||||
|
||||
def _llm_provider_defaults(
|
||||
provider: str,
|
||||
provider_definitions: list[dict[str, Any]],
|
||||
) -> dict[str, str]:
|
||||
normalized_provider = str(provider or "").strip().lower()
|
||||
defaults = dict(LLM_PROVIDER_DEFAULTS.get(normalized_provider) or {})
|
||||
provider_meta = next(
|
||||
(
|
||||
item
|
||||
for item in provider_definitions
|
||||
if isinstance(item, dict)
|
||||
and str(item.get("id") or "").strip().lower() == normalized_provider
|
||||
),
|
||||
None,
|
||||
)
|
||||
if isinstance(provider_meta, dict):
|
||||
default_base_url = str(provider_meta.get("default_base_url") or "").strip()
|
||||
if default_base_url:
|
||||
defaults["base_url"] = default_base_url
|
||||
|
||||
defaults.setdefault("model", _env_default("LLM_MODEL", ""))
|
||||
defaults.setdefault("base_url", _env_default("LLM_BASE_URL", ""))
|
||||
return defaults
|
||||
|
||||
|
||||
def _llm_provider_meta(
|
||||
provider: str,
|
||||
provider_definitions: list[dict[str, Any]],
|
||||
) -> dict[str, Any]:
|
||||
normalized_provider = str(provider or "").strip().lower()
|
||||
provider_meta = next(
|
||||
(
|
||||
item
|
||||
for item in provider_definitions
|
||||
if isinstance(item, dict)
|
||||
and str(item.get("id") or "").strip().lower() == normalized_provider
|
||||
),
|
||||
None,
|
||||
)
|
||||
return dict(provider_meta) if isinstance(provider_meta, dict) else {}
|
||||
|
||||
|
||||
def _load_llm_models_inner(payload: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
provider = str(payload.get("provider") or "").strip().lower()
|
||||
if not provider:
|
||||
raise RuntimeError("缺少 LLM provider")
|
||||
|
||||
provider_module = _load_llm_provider_module()
|
||||
api_key = str(payload.get("api_key") or "").strip() or None
|
||||
base_url = str(payload.get("base_url") or "").strip() or None
|
||||
models = asyncio.run(
|
||||
provider_module.LLMProviderManager().list_models(
|
||||
provider_id=provider,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
force_refresh=False,
|
||||
)
|
||||
)
|
||||
return models if isinstance(models, list) else []
|
||||
|
||||
|
||||
def _load_llm_models(
|
||||
*,
|
||||
provider: str,
|
||||
api_key: Optional[str],
|
||||
base_url: Optional[str],
|
||||
runtime_python: Optional[Path] = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
payload = {
|
||||
"provider": str(provider or "").strip().lower(),
|
||||
"api_key": str(api_key or "").strip(),
|
||||
"base_url": str(base_url or "").strip(),
|
||||
}
|
||||
try:
|
||||
return _load_llm_models_inner(payload)
|
||||
except Exception as exc:
|
||||
if runtime_python and not _current_python_matches(runtime_python):
|
||||
try:
|
||||
with TemporaryDirectory() as temp_dir:
|
||||
request_path = Path(temp_dir) / "llm-models-request.json"
|
||||
output_path = Path(temp_dir) / "llm-models.json"
|
||||
request_path.write_text(
|
||||
json.dumps(payload, ensure_ascii=False), encoding="utf-8"
|
||||
)
|
||||
subprocess.run(
|
||||
[
|
||||
str(runtime_python),
|
||||
str(Path(__file__).resolve()),
|
||||
"query-llm-models",
|
||||
"--request-json-file",
|
||||
str(request_path),
|
||||
"--output-json-file",
|
||||
str(output_path),
|
||||
],
|
||||
cwd=str(ROOT),
|
||||
check=True,
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
)
|
||||
data = json.loads(output_path.read_text(encoding="utf-8"))
|
||||
if isinstance(data, list):
|
||||
return data
|
||||
except Exception as runtime_exc:
|
||||
print_step(
|
||||
f"当前环境暂时无法获取 {payload['provider']} 模型目录,已回退为手动输入模型名称:{runtime_exc}"
|
||||
)
|
||||
return []
|
||||
|
||||
print_step(
|
||||
f"当前环境暂时无法获取 {payload['provider']} 模型目录,已回退为手动输入模型名称:{exc}"
|
||||
)
|
||||
return []
|
||||
|
||||
|
||||
def _print_llm_models(models: list[dict[str, Any]], limit: int = 20) -> None:
|
||||
print("可用模型:")
|
||||
for index, item in enumerate(models[:limit], start=1):
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
model_id = str(item.get("id") or "").strip()
|
||||
if not model_id:
|
||||
continue
|
||||
model_name = str(item.get("name") or model_id).strip()
|
||||
extras: list[str] = []
|
||||
if item.get("context_tokens_k"):
|
||||
extras.append(f"{item['context_tokens_k']}K")
|
||||
if item.get("supports_reasoning"):
|
||||
extras.append("reasoning")
|
||||
if item.get("supports_tools"):
|
||||
extras.append("tools")
|
||||
if item.get("supports_image_input"):
|
||||
extras.append("vision")
|
||||
extra_text = f" [{' / '.join(extras)}]" if extras else ""
|
||||
if model_name != model_id:
|
||||
print(f" {index}. {model_id} ({model_name}){extra_text}")
|
||||
else:
|
||||
print(f" {index}. {model_id}{extra_text}")
|
||||
if len(models) > limit:
|
||||
print(f" ... 共 {len(models)} 个模型,可输入编号或直接输入模型名称")
|
||||
|
||||
|
||||
def _prompt_model_choice(models: list[dict[str, Any]], default: Optional[str] = None) -> str:
|
||||
valid_models = [item for item in models if isinstance(item, dict) and item.get("id")]
|
||||
if not valid_models:
|
||||
return _prompt_text("LLM 模型名称", default=default)
|
||||
|
||||
indexed_models = {
|
||||
str(index): str(item.get("id")).strip()
|
||||
for index, item in enumerate(valid_models, start=1)
|
||||
}
|
||||
default_model = str(default or indexed_models.get("1") or "").strip()
|
||||
_print_llm_models(valid_models)
|
||||
|
||||
while True:
|
||||
raw = input(
|
||||
f"LLM 模型名称/编号{' [' + default_model + ']' if default_model else ''}: "
|
||||
).strip()
|
||||
if not raw and default_model:
|
||||
return default_model
|
||||
if raw in indexed_models:
|
||||
return indexed_models[raw]
|
||||
if raw:
|
||||
return raw
|
||||
print("请输入有效模型编号或模型名称。")
|
||||
|
||||
|
||||
def _env_llm_thinking_level_default() -> str:
|
||||
value = _normalize_choice(_env_default("LLM_THINKING_LEVEL", ""))
|
||||
alias_map = {
|
||||
@@ -1459,7 +1747,9 @@ def _collect_notification_config() -> Optional[dict[str, Any]]:
|
||||
}
|
||||
|
||||
|
||||
def _collect_agent_config() -> dict[str, Any]:
|
||||
def _collect_agent_config(
|
||||
runtime_python: Optional[Path] = None,
|
||||
) -> dict[str, Any]:
|
||||
print_step("AI Agent 配置")
|
||||
enabled = _prompt_yes_no(
|
||||
"是否启用 AI 智能体",
|
||||
@@ -1471,22 +1761,42 @@ def _collect_agent_config() -> dict[str, Any]:
|
||||
"AI_AGENT_GLOBAL": False,
|
||||
}
|
||||
|
||||
provider_definitions = _load_llm_provider_definitions(runtime_python=runtime_python)
|
||||
provider_choices = _llm_provider_choice_map(provider_definitions)
|
||||
current_provider = _env_default("LLM_PROVIDER", "deepseek").lower()
|
||||
if current_provider not in LLM_PROVIDER_DEFAULTS:
|
||||
if current_provider not in provider_choices:
|
||||
current_provider = "deepseek"
|
||||
|
||||
provider = _prompt_choice(
|
||||
"选择 LLM 提供商",
|
||||
{
|
||||
"deepseek": "DeepSeek",
|
||||
"openai": "OpenAI",
|
||||
"google": "Google",
|
||||
},
|
||||
default=current_provider,
|
||||
)
|
||||
defaults = LLM_PROVIDER_DEFAULTS[provider]
|
||||
while True:
|
||||
provider = _prompt_provider_choice(
|
||||
"选择 LLM 提供商",
|
||||
provider_choices,
|
||||
default=current_provider,
|
||||
)
|
||||
provider_meta = _llm_provider_meta(provider, provider_definitions)
|
||||
if provider_meta.get("supports_api_key") is False:
|
||||
print_step(
|
||||
f"{provider_meta.get('name') or provider} 当前仅支持交互式授权,安装向导暂不支持,请改选可填写 API Key 的 provider。"
|
||||
)
|
||||
current_provider = "deepseek"
|
||||
continue
|
||||
break
|
||||
|
||||
defaults = _llm_provider_defaults(provider, provider_definitions)
|
||||
current_model = _env_default("LLM_MODEL", defaults["model"])
|
||||
current_base_url = _env_default("LLM_BASE_URL", defaults["base_url"])
|
||||
api_key_label = str(provider_meta.get("api_key_label") or "API Key").strip() or "API Key"
|
||||
api_key_hint = str(provider_meta.get("api_key_hint") or "").strip()
|
||||
requires_base_url = bool(provider_meta.get("requires_base_url"))
|
||||
base_url_label = (
|
||||
"自定义 Google API Base URL(可选)"
|
||||
if provider == "google"
|
||||
else "LLM Base URL(必填)"
|
||||
if requires_base_url
|
||||
else "LLM Base URL"
|
||||
)
|
||||
if api_key_hint:
|
||||
print_step(api_key_hint)
|
||||
|
||||
config: dict[str, Any] = {
|
||||
"AI_AGENT_ENABLE": True,
|
||||
@@ -1495,12 +1805,8 @@ def _collect_agent_config() -> dict[str, Any]:
|
||||
default=_env_bool("AI_AGENT_GLOBAL", False),
|
||||
),
|
||||
"LLM_PROVIDER": provider,
|
||||
"LLM_MODEL": _prompt_text(
|
||||
"LLM 模型名称",
|
||||
default=current_model,
|
||||
),
|
||||
"LLM_API_KEY": _prompt_secret_text(
|
||||
"LLM API Key",
|
||||
f"LLM {api_key_label}",
|
||||
current_value=read_env_value("LLM_API_KEY"),
|
||||
required=True,
|
||||
),
|
||||
@@ -1524,18 +1830,18 @@ def _collect_agent_config() -> dict[str, Any]:
|
||||
),
|
||||
}
|
||||
|
||||
if provider == "google":
|
||||
config["LLM_BASE_URL"] = _prompt_text(
|
||||
"自定义 Google API Base URL(可选)",
|
||||
default=current_base_url,
|
||||
allow_empty=True,
|
||||
)
|
||||
else:
|
||||
config["LLM_BASE_URL"] = _prompt_text(
|
||||
"LLM Base URL",
|
||||
default=current_base_url,
|
||||
allow_empty=True,
|
||||
)
|
||||
config["LLM_BASE_URL"] = _prompt_text(
|
||||
base_url_label,
|
||||
default=current_base_url,
|
||||
allow_empty=not requires_base_url,
|
||||
)
|
||||
models = _load_llm_models(
|
||||
provider=provider,
|
||||
api_key=config["LLM_API_KEY"],
|
||||
base_url=config["LLM_BASE_URL"],
|
||||
runtime_python=runtime_python,
|
||||
)
|
||||
config["LLM_MODEL"] = _prompt_model_choice(models, default=current_model)
|
||||
|
||||
return config
|
||||
|
||||
@@ -1744,7 +2050,7 @@ def run_setup_wizard(
|
||||
preset_password=preset_superuser_password,
|
||||
),
|
||||
**_collect_database_config(),
|
||||
**_collect_agent_config(),
|
||||
**_collect_agent_config(runtime_python=runtime_python),
|
||||
},
|
||||
"directories": [_collect_directory_config()],
|
||||
"downloader": _collect_downloader_config(),
|
||||
@@ -3276,6 +3582,23 @@ def build_parser() -> argparse.ArgumentParser:
|
||||
"--output-json-file", required=True, help=argparse.SUPPRESS
|
||||
)
|
||||
|
||||
query_llm_providers_parser = subparsers.add_parser(
|
||||
"query-llm-providers", help=argparse.SUPPRESS
|
||||
)
|
||||
query_llm_providers_parser.add_argument(
|
||||
"--output-json-file", required=True, help=argparse.SUPPRESS
|
||||
)
|
||||
|
||||
query_llm_models_parser = subparsers.add_parser(
|
||||
"query-llm-models", help=argparse.SUPPRESS
|
||||
)
|
||||
query_llm_models_parser.add_argument(
|
||||
"--request-json-file", required=True, help=argparse.SUPPRESS
|
||||
)
|
||||
query_llm_models_parser.add_argument(
|
||||
"--output-json-file", required=True, help=argparse.SUPPRESS
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@@ -3461,6 +3784,25 @@ def main() -> int:
|
||||
encoding="utf-8",
|
||||
)
|
||||
return 0
|
||||
|
||||
if args.command == "query-llm-providers":
|
||||
payload = _load_llm_provider_definitions_inner()
|
||||
Path(args.output_json_file).write_text(
|
||||
json.dumps(payload, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8",
|
||||
)
|
||||
return 0
|
||||
|
||||
if args.command == "query-llm-models":
|
||||
payload = json.loads(Path(args.request_json_file).read_text(encoding="utf-8"))
|
||||
if not isinstance(payload, dict):
|
||||
raise RuntimeError("模型查询负载格式错误")
|
||||
models = _load_llm_models_inner(payload)
|
||||
Path(args.output_json_file).write_text(
|
||||
json.dumps(models, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8",
|
||||
)
|
||||
return 0
|
||||
except subprocess.CalledProcessError as exc:
|
||||
print(f"命令执行失败,退出码:{exc.returncode}", file=sys.stderr)
|
||||
return exc.returncode
|
||||
|
||||
215
tests/test_llm_provider_registry.py
Normal file
215
tests/test_llm_provider_registry.py
Normal file
@@ -0,0 +1,215 @@
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
|
||||
def _stub_module(name: str, **attrs):
|
||||
module = sys.modules.get(name)
|
||||
if module is None:
|
||||
module = ModuleType(name)
|
||||
sys.modules[name] = module
|
||||
for key, value in attrs.items():
|
||||
setattr(module, key, value)
|
||||
return module
|
||||
|
||||
|
||||
class _DummyLogger:
|
||||
def __getattr__(self, _name):
|
||||
return lambda *args, **kwargs: None
|
||||
|
||||
|
||||
class _DummySystemConfigOper:
|
||||
def get(self, _key):
|
||||
return {}
|
||||
|
||||
async def async_set(self, _key, _value):
|
||||
return None
|
||||
|
||||
|
||||
for _module_name in ("aiofiles", "jwt"):
|
||||
_stub_module(_module_name)
|
||||
|
||||
_stub_module(
|
||||
"app.core.config",
|
||||
settings=SimpleNamespace(
|
||||
TEMP_PATH="/tmp",
|
||||
PROXY_HOST=None,
|
||||
LLM_MAX_CONTEXT_TOKENS=64,
|
||||
),
|
||||
)
|
||||
_stub_module("app.db.systemconfig_oper", SystemConfigOper=_DummySystemConfigOper)
|
||||
_stub_module("app.log", logger=_DummyLogger())
|
||||
_stub_module("app.schemas.types", SystemConfigKey=SimpleNamespace(AIAgentConfig="agent"))
|
||||
|
||||
provider_path = Path(__file__).resolve().parents[1] / "app" / "agent" / "llm" / "provider.py"
|
||||
spec = importlib.util.spec_from_file_location("test_llm_provider_module", provider_path)
|
||||
provider_module = importlib.util.module_from_spec(spec)
|
||||
assert spec and spec.loader
|
||||
sys.modules[spec.name] = provider_module
|
||||
spec.loader.exec_module(provider_module)
|
||||
|
||||
LLMProviderError = provider_module.LLMProviderError
|
||||
LLMProviderManager = provider_module.LLMProviderManager
|
||||
|
||||
|
||||
class LlmProviderRegistryTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
LLMProviderManager._instances.clear()
|
||||
|
||||
def tearDown(self):
|
||||
LLMProviderManager._instances.clear()
|
||||
|
||||
def test_dynamic_provider_is_exposed_from_models_dev_cache(self):
|
||||
manager = LLMProviderManager()
|
||||
manager._models_dev_data = {
|
||||
"frogbot": {
|
||||
"id": "frogbot",
|
||||
"name": "FrogBot",
|
||||
"npm": "@ai-sdk/openai-compatible",
|
||||
"env": ["FROGBOT_API_KEY"],
|
||||
"api": "https://app.frogbot.ai/api/v1",
|
||||
"models": {},
|
||||
}
|
||||
}
|
||||
|
||||
provider = manager.get_provider("frogbot")
|
||||
|
||||
self.assertEqual(provider.id, "frogbot")
|
||||
self.assertEqual(provider.runtime, "openai_compatible")
|
||||
self.assertEqual(provider.default_base_url, "https://app.frogbot.ai/api/v1")
|
||||
self.assertFalse(provider.requires_base_url)
|
||||
self.assertTrue(provider.base_url_editable)
|
||||
self.assertEqual(provider.model_list_strategy, "models_dev_only")
|
||||
|
||||
def test_dynamic_provider_override_normalizes_chat_endpoint_base_url(self):
|
||||
manager = LLMProviderManager()
|
||||
manager._models_dev_data = {
|
||||
"bailing": {
|
||||
"id": "bailing",
|
||||
"name": "Bailing",
|
||||
"npm": "@ai-sdk/openai-compatible",
|
||||
"env": ["BAILING_API_TOKEN"],
|
||||
"api": "https://api.tbox.cn/api/llm/v1/chat/completions",
|
||||
"models": {},
|
||||
}
|
||||
}
|
||||
|
||||
provider = manager.get_provider("bailing")
|
||||
|
||||
self.assertEqual(provider.default_base_url, "https://api.tbox.cn/api/llm/v1")
|
||||
self.assertEqual(provider.api_key_label, "API Token")
|
||||
|
||||
def test_dynamic_provider_skips_alias_only_models_dev_ids(self):
|
||||
manager = LLMProviderManager()
|
||||
manager._models_dev_data = {
|
||||
"moonshotai": {
|
||||
"id": "moonshotai",
|
||||
"name": "Moonshot AI Intl",
|
||||
"npm": "@ai-sdk/openai-compatible",
|
||||
"env": ["MOONSHOT_API_KEY"],
|
||||
"api": "https://api.moonshot.ai/v1",
|
||||
"models": {},
|
||||
}
|
||||
}
|
||||
|
||||
with self.assertRaises(LLMProviderError):
|
||||
manager.get_provider("moonshotai")
|
||||
|
||||
def test_dynamic_provider_skips_incompatible_models_dev_provider(self):
|
||||
manager = LLMProviderManager()
|
||||
manager._models_dev_data = {
|
||||
"azure": {
|
||||
"id": "azure",
|
||||
"name": "Azure",
|
||||
"npm": "@ai-sdk/azure",
|
||||
"env": ["AZURE_API_KEY"],
|
||||
"models": {},
|
||||
}
|
||||
}
|
||||
|
||||
with self.assertRaises(LLMProviderError):
|
||||
manager.get_provider("azure")
|
||||
|
||||
def test_dynamic_provider_without_known_base_url_requires_manual_input(self):
|
||||
manager = LLMProviderManager()
|
||||
manager._models_dev_data = {
|
||||
"custom-anthropic": {
|
||||
"id": "custom-anthropic",
|
||||
"name": "Custom Anthropic",
|
||||
"npm": "@ai-sdk/anthropic",
|
||||
"env": ["CUSTOM_ANTHROPIC_KEY"],
|
||||
"models": {},
|
||||
}
|
||||
}
|
||||
|
||||
provider = manager.get_provider("custom-anthropic")
|
||||
|
||||
self.assertEqual(provider.runtime, "anthropic_compatible")
|
||||
self.assertTrue(provider.requires_base_url)
|
||||
self.assertTrue(provider.base_url_editable)
|
||||
self.assertEqual(provider.model_list_strategy, "anthropic_compatible")
|
||||
|
||||
def test_list_providers_async_loads_models_dev_before_serializing(self):
|
||||
manager = LLMProviderManager()
|
||||
payload = {
|
||||
"frogbot": {
|
||||
"id": "frogbot",
|
||||
"name": "FrogBot",
|
||||
"npm": "@ai-sdk/openai-compatible",
|
||||
"env": ["FROGBOT_API_KEY"],
|
||||
"api": "https://app.frogbot.ai/api/v1",
|
||||
"models": {},
|
||||
}
|
||||
}
|
||||
|
||||
with patch.object(
|
||||
manager,
|
||||
"get_models_dev_data",
|
||||
AsyncMock(side_effect=lambda force_refresh=False: manager.__dict__.update({"_models_dev_data": payload}) or payload),
|
||||
) as fetch_mock:
|
||||
providers = asyncio.run(manager.list_providers_async())
|
||||
|
||||
fetch_mock.assert_awaited_once_with(force_refresh=False)
|
||||
self.assertIn("frogbot", {item["id"] for item in providers})
|
||||
|
||||
def test_list_models_uses_dynamic_provider_after_refresh(self):
|
||||
manager = LLMProviderManager()
|
||||
payload = {
|
||||
"frogbot": {
|
||||
"id": "frogbot",
|
||||
"name": "FrogBot",
|
||||
"npm": "@ai-sdk/openai-compatible",
|
||||
"env": ["FROGBOT_API_KEY"],
|
||||
"api": "https://app.frogbot.ai/api/v1",
|
||||
"models": {
|
||||
"frog-1": {
|
||||
"name": "Frog 1",
|
||||
"limit": {"context": 131072},
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
async def _load_models_dev(force_refresh: bool = False):
|
||||
manager._models_dev_data = payload
|
||||
return payload
|
||||
|
||||
with patch.object(manager, "get_models_dev_data", AsyncMock(side_effect=_load_models_dev)):
|
||||
models = asyncio.run(
|
||||
manager.list_models(
|
||||
provider_id="frogbot",
|
||||
api_key="sk-test",
|
||||
force_refresh=True,
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual([item["id"] for item in models], ["frog-1"])
|
||||
self.assertEqual(models[0]["source"], "models.dev")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
189
tests/test_local_setup_llm_provider_prompt.py
Normal file
189
tests/test_local_setup_llm_provider_prompt.py
Normal file
@@ -0,0 +1,189 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import unittest
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
MODULE_PATH = Path(__file__).resolve().parents[1] / "scripts" / "local_setup.py"
|
||||
|
||||
|
||||
def load_local_setup_module():
|
||||
module_name = f"moviepilot_local_setup_llm_{uuid.uuid4().hex}"
|
||||
spec = importlib.util.spec_from_file_location(module_name, MODULE_PATH)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
assert spec and spec.loader
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
class LocalSetupLlmProviderPromptTests(unittest.TestCase):
|
||||
def test_collect_agent_config_prefers_loaded_provider_directory(self):
|
||||
module = load_local_setup_module()
|
||||
|
||||
provider_definitions = [
|
||||
{
|
||||
"id": "frogbot",
|
||||
"name": "FrogBot",
|
||||
"default_base_url": "https://app.frogbot.ai/api/v1",
|
||||
"api_key_label": "API Key",
|
||||
}
|
||||
]
|
||||
models = [
|
||||
{"id": "frog-1", "name": "Frog 1", "context_tokens_k": 128},
|
||||
{"id": "frog-2", "name": "Frog 2"},
|
||||
]
|
||||
|
||||
with patch.object(module, "print_step"), patch.object(
|
||||
module, "_prompt_yes_no", side_effect=[True, False, True]
|
||||
), patch.object(
|
||||
module, "_load_llm_provider_definitions", return_value=provider_definitions
|
||||
), patch.object(
|
||||
module, "_prompt_provider_choice", return_value="frogbot"
|
||||
) as provider_prompt, patch.object(
|
||||
module, "_prompt_text", side_effect=["https://override.example.com/v1"]
|
||||
), patch.object(
|
||||
module, "_prompt_secret_text", return_value="sk-frog"
|
||||
), patch.object(
|
||||
module, "_load_llm_models", return_value=models
|
||||
) as load_models, patch.object(
|
||||
module, "_prompt_model_choice", return_value="frog-2"
|
||||
) as model_prompt, patch.object(
|
||||
module, "read_env_value", return_value=None
|
||||
), patch.object(
|
||||
module, "_env_default", side_effect=lambda key, default="": default
|
||||
), patch.object(
|
||||
module, "_env_bool", side_effect=lambda key, default: default
|
||||
), patch.object(
|
||||
module, "_env_llm_thinking_level_default", return_value="auto"
|
||||
), patch.object(
|
||||
module, "_prompt_choice", return_value="auto"
|
||||
):
|
||||
config = module._collect_agent_config(runtime_python=Path("/tmp/runtime-python"))
|
||||
|
||||
provider_prompt.assert_called_once()
|
||||
load_models.assert_called_once_with(
|
||||
provider="frogbot",
|
||||
api_key="sk-frog",
|
||||
base_url="https://override.example.com/v1",
|
||||
runtime_python=Path("/tmp/runtime-python"),
|
||||
)
|
||||
model_prompt.assert_called_once_with(models, default="")
|
||||
self.assertEqual(config["LLM_PROVIDER"], "frogbot")
|
||||
self.assertEqual(config["LLM_MODEL"], "frog-2")
|
||||
self.assertEqual(config["LLM_API_KEY"], "sk-frog")
|
||||
self.assertEqual(config["LLM_BASE_URL"], "https://override.example.com/v1")
|
||||
|
||||
def test_collect_agent_config_falls_back_to_common_provider_choices(self):
|
||||
module = load_local_setup_module()
|
||||
|
||||
with patch.object(module, "print_step"), patch.object(
|
||||
module, "_prompt_yes_no", side_effect=[True, False, True]
|
||||
), patch.object(
|
||||
module, "_load_llm_provider_definitions", return_value=[]
|
||||
), patch.object(
|
||||
module, "_prompt_provider_choice", return_value="anthropic"
|
||||
), patch.object(
|
||||
module, "_prompt_text", side_effect=["https://api.anthropic.com/v1"]
|
||||
), patch.object(
|
||||
module, "_prompt_secret_text", return_value="sk-anthropic"
|
||||
), patch.object(
|
||||
module, "_load_llm_models", return_value=[]
|
||||
), patch.object(
|
||||
module, "_prompt_model_choice", return_value="claude-sonnet-4-0"
|
||||
), patch.object(
|
||||
module, "read_env_value", return_value=None
|
||||
), patch.object(
|
||||
module, "_env_default", side_effect=lambda key, default="": default
|
||||
), patch.object(
|
||||
module, "_env_bool", side_effect=lambda key, default: default
|
||||
), patch.object(
|
||||
module, "_env_llm_thinking_level_default", return_value="off"
|
||||
), patch.object(
|
||||
module, "_prompt_choice", return_value="off"
|
||||
):
|
||||
config = module._collect_agent_config()
|
||||
|
||||
self.assertEqual(config["LLM_PROVIDER"], "anthropic")
|
||||
self.assertEqual(config["LLM_MODEL"], "claude-sonnet-4-0")
|
||||
self.assertEqual(config["LLM_BASE_URL"], "https://api.anthropic.com/v1")
|
||||
|
||||
def test_prompt_model_choice_accepts_index_selection(self):
|
||||
module = load_local_setup_module()
|
||||
|
||||
with patch.object(module, "_print_llm_models") as print_models, patch(
|
||||
"builtins.input", return_value="2"
|
||||
):
|
||||
model = module._prompt_model_choice(
|
||||
[
|
||||
{"id": "model-a", "name": "Model A"},
|
||||
{"id": "model-b", "name": "Model B"},
|
||||
],
|
||||
default="model-a",
|
||||
)
|
||||
|
||||
print_models.assert_called_once()
|
||||
self.assertEqual(model, "model-b")
|
||||
|
||||
def test_prompt_model_choice_falls_back_to_text_input_when_empty(self):
|
||||
module = load_local_setup_module()
|
||||
|
||||
with patch.object(module, "_prompt_text", return_value="custom-model") as prompt_text:
|
||||
model = module._prompt_model_choice([], default="")
|
||||
|
||||
prompt_text.assert_called_once_with("LLM 模型名称", default="")
|
||||
self.assertEqual(model, "custom-model")
|
||||
|
||||
def test_load_llm_provider_definitions_inner_uses_direct_provider_module_loader(self):
|
||||
module = load_local_setup_module()
|
||||
|
||||
class _FakeManager:
|
||||
async def list_providers_async(self, force_refresh: bool = False):
|
||||
return [{"id": "frogbot", "name": "FrogBot"}]
|
||||
|
||||
class _FakeProviderModule:
|
||||
@staticmethod
|
||||
def LLMProviderManager():
|
||||
return _FakeManager()
|
||||
|
||||
fake_provider_module = _FakeProviderModule()
|
||||
|
||||
with patch.object(
|
||||
module,
|
||||
"_load_llm_provider_module",
|
||||
return_value=fake_provider_module,
|
||||
) as loader:
|
||||
providers = module._load_llm_provider_definitions_inner()
|
||||
|
||||
loader.assert_called_once_with()
|
||||
self.assertEqual(providers, [{"id": "frogbot", "name": "FrogBot"}])
|
||||
|
||||
def test_llm_provider_choice_map_skips_oauth_only_provider(self):
|
||||
module = load_local_setup_module()
|
||||
|
||||
choices = module._llm_provider_choice_map(
|
||||
[
|
||||
{"id": "chatgpt", "name": "ChatGPT", "supports_api_key": True},
|
||||
{"id": "github-copilot", "name": "GitHub Copilot", "supports_api_key": False},
|
||||
]
|
||||
)
|
||||
|
||||
self.assertEqual(choices, {"chatgpt": "ChatGPT"})
|
||||
|
||||
def test_prompt_provider_choice_accepts_custom_provider_id(self):
|
||||
module = load_local_setup_module()
|
||||
|
||||
with patch("builtins.input", return_value="my-provider_01"), patch("builtins.print"):
|
||||
provider = module._prompt_provider_choice(
|
||||
"选择 LLM 提供商",
|
||||
{"deepseek": "DeepSeek", "google": "Google"},
|
||||
default="deepseek",
|
||||
)
|
||||
|
||||
self.assertEqual(provider, "my-provider_01")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user