diff --git a/app/agent/llm/provider.py b/app/agent/llm/provider.py index 689e0fc8..3479790a 100644 --- a/app/agent/llm/provider.py +++ b/app/agent/llm/provider.py @@ -117,6 +117,58 @@ class LLMProviderManager(metaclass=Singleton): "gpt-5.4-mini", "gpt-5.5", } + _MODELS_DEV_DYNAMIC_SKIP_IDS = { + "aihubmix", + "amazon-bedrock", + "azure", + "azure-cognitive-services", + "cloudflare-ai-gateway", + "cohere", + "gitlab", + "google-vertex", + "google-vertex-anthropic", + "kiro", + "sap-ai-core", + "v0", + "vercel", + } + _MODELS_DEV_DYNAMIC_PROVIDER_OVERRIDES = { + "bailing": { + "runtime": "openai_compatible", + "default_base_url": "https://api.tbox.cn/api/llm/v1", + "description": "Bailing OpenAI-compatible 端点。", + }, + "cerebras": { + "runtime": "openai_compatible", + "default_base_url": "https://api.cerebras.ai/v1", + "description": "Cerebras 官方兼容端点。", + }, + "deepinfra": { + "runtime": "openai_compatible", + "default_base_url": "https://api.deepinfra.com/v1/openai", + "description": "DeepInfra 官方兼容端点。", + }, + "mistral": { + "runtime": "openai_compatible", + "default_base_url": "https://api.mistral.ai/v1", + "description": "Mistral 官方兼容端点。", + }, + "perplexity": { + "runtime": "openai_compatible", + "default_base_url": "https://api.perplexity.ai/v1", + "description": "Perplexity 官方兼容端点。", + }, + "togetherai": { + "runtime": "openai_compatible", + "default_base_url": "https://api.together.xyz/v1", + "description": "Together AI 官方兼容端点。", + }, + "venice": { + "runtime": "openai_compatible", + "default_base_url": "https://api.venice.ai/api/v1", + "description": "Venice AI 官方兼容端点。", + }, + } def __init__(self): self._lock = threading.RLock() @@ -130,7 +182,7 @@ class LLMProviderManager(metaclass=Singleton): ) @staticmethod - def _provider_specs() -> tuple[ProviderSpec, ...]: + def _builtin_provider_specs() -> tuple[ProviderSpec, ...]: """ 返回受支持的 provider 定义。 @@ -708,43 +760,253 @@ class LLMProviderManager(metaclass=Singleton): ) return tuple(providers) + def _cached_models_dev_payload(self) -> dict[str, Any]: + if isinstance(self._models_dev_data, dict): + return self._models_dev_data + + try: + if not self._models_dev_cache_path.exists(): + return {} + payload = json.loads(self._models_dev_cache_path.read_text(encoding="utf-8")) + except Exception as err: + logger.warning(f"读取 models.dev provider 缓存失败: {err}") + return {} + + if not isinstance(payload, dict): + return {} + + self._models_dev_data = payload + return payload + + @staticmethod + def _models_dev_env_names(payload: dict[str, Any]) -> tuple[str, ...]: + raw_env_names = payload.get("env") + if not isinstance(raw_env_names, list): + return () + env_names = [] + for item in raw_env_names: + value = str(item or "").strip() + if value: + env_names.append(value) + return tuple(env_names) + + @classmethod + def _models_dev_reserved_provider_ids( + cls, specs: tuple[ProviderSpec, ...] + ) -> set[str]: + reserved_ids: set[str] = set() + for spec in specs: + if spec.models_dev_provider_id: + reserved_ids.add(spec.models_dev_provider_id) + for preset in spec.base_url_presets: + if preset.models_dev_provider_id: + reserved_ids.add(preset.models_dev_provider_id) + return reserved_ids + + @staticmethod + def _dynamic_api_key_label(env_names: tuple[str, ...]) -> str: + first_env = env_names[0].upper() if env_names else "" + if "TOKEN" in first_env and "KEY" not in first_env: + return "API Token" + return "API Key" + + @classmethod + def _normalize_models_dev_base_url( + cls, runtime: str, base_url: Optional[str] + ) -> Optional[str]: + normalized = cls._sanitize_base_url(base_url) + if not normalized: + return None + + suffixes = { + "openai_compatible": ( + "/chat/completions", + "/completions", + "/responses", + "/embeddings", + "/audio/speech", + "/audio/transcriptions", + ), + "anthropic_compatible": ( + "/messages", + ), + } + + for suffix in suffixes.get(runtime, ()): + if normalized.endswith(suffix): + normalized = normalized[: -len(suffix)] + break + return cls._sanitize_base_url(normalized) + + @classmethod + def _models_dev_dynamic_provider_spec( + cls, + provider_id: str, + payload: dict[str, Any], + sort_order: int, + ) -> ProviderSpec | None: + normalized_id = str(provider_id or "").strip().lower() + if not normalized_id or normalized_id in cls._MODELS_DEV_DYNAMIC_SKIP_IDS: + return None + + override = cls._MODELS_DEV_DYNAMIC_PROVIDER_OVERRIDES.get(normalized_id, {}) + npm_package = str(payload.get("npm") or "").strip() + runtime = override.get("runtime") + if not runtime: + if npm_package == "@ai-sdk/openai-compatible": + runtime = "openai_compatible" + elif npm_package == "@ai-sdk/anthropic": + runtime = "anthropic_compatible" + else: + return None + + model_list_strategy = override.get("model_list_strategy") + if not model_list_strategy: + model_list_strategy = ( + "anthropic_compatible" + if runtime == "anthropic_compatible" + else "models_dev_only" + ) + + default_base_url = cls._normalize_models_dev_base_url( + runtime=runtime, + base_url=override.get("default_base_url") or payload.get("api"), + ) + requires_base_url = not bool(default_base_url) + env_names = cls._models_dev_env_names(payload) + api_key_label = override.get("api_key_label") or cls._dynamic_api_key_label( + env_names + ) + name = str(payload.get("name") or override.get("name") or normalized_id).strip() + description = override.get("description") + if not description: + transport_name = "Anthropic-compatible" if runtime == "anthropic_compatible" else "OpenAI-compatible" + description = f"{name} {transport_name} 端点(来自 models.dev 目录)。" + + api_key_hint = override.get("api_key_hint") + if not api_key_hint: + api_key_hint = f"填写 {name} {api_key_label}。" + if requires_base_url: + api_key_hint = f"填写 {name} {api_key_label},并手动填写 Base URL。" + + return ProviderSpec( + id=normalized_id, + name=name, + runtime=runtime, + models_dev_provider_id=normalized_id, + default_base_url=default_base_url, + base_url_editable=True, + requires_base_url=requires_base_url, + api_key_label=api_key_label, + api_key_hint=api_key_hint, + model_list_strategy=model_list_strategy, + description=description, + sort_order=sort_order, + ) + + def _dynamic_provider_specs( + self, builtin_specs: tuple[ProviderSpec, ...] + ) -> tuple[ProviderSpec, ...]: + payload = self._cached_models_dev_payload() + if not payload: + return () + + explicit_ids = {spec.id for spec in builtin_specs} + reserved_ids = self._models_dev_reserved_provider_ids(builtin_specs) + candidates: list[tuple[str, str, dict[str, Any]]] = [] + + for provider_id, provider_payload in payload.items(): + normalized_id = str(provider_id or "").strip().lower() + if not normalized_id or not isinstance(provider_payload, dict): + continue + if normalized_id in explicit_ids or normalized_id in reserved_ids: + continue + + spec = self._models_dev_dynamic_provider_spec( + provider_id=normalized_id, + payload=provider_payload, + sort_order=0, + ) + if not spec: + continue + candidates.append((spec.name.lower(), normalized_id, provider_payload)) + + dynamic_specs = [] + for sort_order, (_, provider_id, provider_payload) in enumerate( + sorted(candidates), + start=700, + ): + spec = self._models_dev_dynamic_provider_spec( + provider_id=provider_id, + payload=provider_payload, + sort_order=sort_order, + ) + if not spec: + continue + dynamic_specs.append(spec) + return tuple(dynamic_specs) + + def _provider_specs(self) -> tuple[ProviderSpec, ...]: + builtin_specs = self._builtin_provider_specs() + return builtin_specs + self._dynamic_provider_specs(builtin_specs) + + async def _get_provider_async( + self, provider_id: str, force_refresh: bool = False + ) -> ProviderSpec: + try: + return self.get_provider(provider_id) + except LLMProviderError: + await self.get_models_dev_data(force_refresh=force_refresh) + return self.get_provider(provider_id) + + def _serialize_provider(self, spec: ProviderSpec) -> dict[str, Any]: + return { + "id": spec.id, + "name": spec.name, + "runtime": spec.runtime, + "default_base_url": self._default_base_url_for_provider(spec) or "", + "base_url_presets": [ + { + "label": preset.label, + "value": self._sanitize_base_url(preset.value) or "", + } + for preset in spec.base_url_presets + ], + "base_url_editable": spec.base_url_editable, + "requires_base_url": spec.requires_base_url, + "supports_api_key": spec.supports_api_key, + "api_key_label": spec.api_key_label, + "api_key_hint": spec.api_key_hint, + "supports_model_refresh": spec.supports_model_refresh, + "oauth_methods": [ + { + "id": method.id, + "type": method.type, + "label": method.label, + "description": method.description, + } + for method in spec.oauth_methods + ], + "description": spec.description, + "auth_status": self.get_auth_status(spec.id), + } + + async def list_providers_async( + self, force_refresh: bool = False + ) -> list[dict[str, Any]]: + """返回前端可渲染的 provider 目录,并优先补齐 models.dev 动态平台。""" + try: + await self.get_models_dev_data(force_refresh=force_refresh) + except Exception as err: + logger.debug(f"加载 models.dev provider 目录失败,回退内置列表: {err}") + return self.list_providers() + def list_providers(self) -> list[dict[str, Any]]: """返回前端可渲染的 provider 目录。""" - providers = [] - for spec in sorted(self._provider_specs(), key=lambda item: item.sort_order): - providers.append( - { - "id": spec.id, - "name": spec.name, - "runtime": spec.runtime, - "default_base_url": self._default_base_url_for_provider(spec) or "", - "base_url_presets": [ - { - "label": preset.label, - "value": self._sanitize_base_url(preset.value) or "", - } - for preset in spec.base_url_presets - ], - "base_url_editable": spec.base_url_editable, - "requires_base_url": spec.requires_base_url, - "supports_api_key": spec.supports_api_key, - "api_key_label": spec.api_key_label, - "api_key_hint": spec.api_key_hint, - "supports_model_refresh": spec.supports_model_refresh, - "oauth_methods": [ - { - "id": method.id, - "type": method.type, - "label": method.label, - "description": method.description, - } - for method in spec.oauth_methods - ], - "description": spec.description, - "auth_status": self.get_auth_status(spec.id), - } - ) - return providers + return [ + self._serialize_provider(spec) + for spec in sorted(self._provider_specs(), key=lambda item: item.sort_order) + ] def get_provider(self, provider_id: str) -> ProviderSpec: """按 provider id 获取定义。""" @@ -973,7 +1235,7 @@ class LLMProviderManager(metaclass=Singleton): async def _models_dev_provider_payload( self, provider_id: str, base_url: Optional[str] = None ) -> dict[str, Any]: - spec = self.get_provider(provider_id) + spec = await self._get_provider_async(provider_id) models_dev_provider_id = self._resolve_provider_models_dev_provider_id( spec, base_url, @@ -1313,7 +1575,7 @@ class LLMProviderManager(metaclass=Singleton): force_refresh: bool = False, ) -> list[dict[str, Any]]: """返回标准化后的模型目录。""" - spec = self.get_provider(provider_id) + spec = await self._get_provider_async(provider_id, force_refresh=force_refresh) if self._resolve_provider_models_dev_provider_id(spec, base_url): # 对依赖 models.dev 的 provider 主动刷新一次缓存,保证“刷新模型列表” # 在使用目录型 provider 时也能拿到最新参数。 @@ -1449,7 +1711,7 @@ class LLMProviderManager(metaclass=Singleton): API Key 方式已经由普通设置表单覆盖,这里只处理需要交互式授权的 provider。 """ - provider = self.get_provider(provider_id) + provider = await self._get_provider_async(provider_id) method = next( (item for item in provider.oauth_methods if item.id == method_id), None, @@ -1844,7 +2106,7 @@ class LLMProviderManager(metaclass=Singleton): 返回统一结构,供 `LLMHelper` 创建具体 LangChain 模型实例时使用。 """ - spec = self.get_provider(provider_id) + spec = await self._get_provider_async(provider_id) normalized_api_key = str(api_key or "").strip() or None normalized_base_url = self._sanitize_base_url(base_url) model_record = None diff --git a/app/api/endpoints/llm.py b/app/api/endpoints/llm.py index ca3e0fbd..cd0cc525 100644 --- a/app/api/endpoints/llm.py +++ b/app/api/endpoints/llm.py @@ -98,7 +98,7 @@ async def get_llm_providers( 返回前端可直接渲染的 provider 目录。 """ try: - providers = LLMProviderManager().list_providers() + providers = await LLMProviderManager().list_providers_async() return schemas.Response(success=True, data=providers) except Exception as err: return schemas.Response(success=False, message=str(err)) diff --git a/app/core/config.py b/app/core/config.py index a8a71454..2d2f22e6 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -501,7 +501,7 @@ class ConfigModel(BaseModel): AI_AGENT_ENABLE: bool = False # 合局AI智能体 AI_AGENT_GLOBAL: bool = False - # LLM提供商 (openai/google/deepseek) + # LLM提供商(支持内置 provider,以及从 models.dev 动态补充的平台) LLM_PROVIDER: str = "deepseek" # LLM模型名称 LLM_MODEL: str = "deepseek-chat" diff --git a/scripts/local_setup.py b/scripts/local_setup.py index c71306e2..1b5368c6 100644 --- a/scripts/local_setup.py +++ b/scripts/local_setup.py @@ -5,6 +5,7 @@ from __future__ import annotations import argparse import asyncio import getpass +import importlib.util import json import os import platform @@ -66,6 +67,26 @@ LLM_PROVIDER_DEFAULTS = { "model": "gemini-2.5-flash", "base_url": "", }, + "anthropic": { + "model": "claude-sonnet-4-0", + "base_url": "https://api.anthropic.com/v1", + }, + "openrouter": { + "model": "openai/gpt-4.1-mini", + "base_url": "https://openrouter.ai/api/v1", + }, + "groq": { + "model": "llama-3.3-70b-versatile", + "base_url": "https://api.groq.com/openai/v1", + }, +} +LLM_PROVIDER_FALLBACK_CHOICES = { + "deepseek": "DeepSeek", + "openai": "OpenAI Compatible", + "google": "Google", + "anthropic": "Anthropic", + "openrouter": "OpenRouter", + "groq": "Groq", } RUNTIME_PACKAGE = { "name": "moviepilot-frontend-runtime", @@ -1063,6 +1084,273 @@ def _prompt_choice(label: str, choices: dict[str, str], default: str) -> str: print("请输入列表中的可选值。") +def _prompt_provider_choice(label: str, choices: dict[str, str], default: str) -> str: + labels = [] + normalized_map: dict[str, str] = {} + for key, desc in choices.items(): + labels.append(f"{key}({desc})") + normalized_map[_normalize_choice(key)] = key + + preview_limit = 12 + print("可用 LLM 提供商:") + for item in labels[:preview_limit]: + print(f" {item}") + if len(labels) > preview_limit: + print(f" ... 另有 {len(labels) - preview_limit} 个,可直接输入 provider id") + + while True: + raw = input(f"{label} (默认 {default},可直接输入 provider id): ").strip() + if not raw: + return default + normalized = _normalize_choice(raw) + if normalized in normalized_map: + return normalized_map[normalized] + + provider_id = raw.strip().lower() + if re.fullmatch(r"[a-z0-9][a-z0-9._-]*", provider_id): + return provider_id + print("请输入列表中的可选值,或合法的 provider id(小写字母/数字/.-_)。") + + +def _load_llm_provider_module(): + provider_path = ROOT / "app" / "agent" / "llm" / "provider.py" + module_name = f"moviepilot_local_llm_provider_{uuid.uuid4().hex}" + spec = importlib.util.spec_from_file_location(module_name, provider_path) + if not spec or not spec.loader: + raise RuntimeError("无法加载 LLM provider 模块") + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +def _load_llm_provider_definitions_inner() -> list[dict[str, Any]]: + provider_module = _load_llm_provider_module() + providers = asyncio.run(provider_module.LLMProviderManager().list_providers_async()) + return providers if isinstance(providers, list) else [] + + +def _load_llm_provider_definitions( + runtime_python: Optional[Path] = None, +) -> list[dict[str, Any]]: + try: + return _load_llm_provider_definitions_inner() + except Exception as exc: + if runtime_python and not _current_python_matches(runtime_python): + try: + with TemporaryDirectory() as temp_dir: + output_path = Path(temp_dir) / "llm-providers.json" + subprocess.run( + [ + str(runtime_python), + str(Path(__file__).resolve()), + "query-llm-providers", + "--output-json-file", + str(output_path), + ], + cwd=str(ROOT), + check=True, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + data = json.loads(output_path.read_text(encoding="utf-8")) + if isinstance(data, list): + return data + except Exception as runtime_exc: + print_step( + f"当前环境暂时无法读取 LLM 提供商目录,已回退到常用平台列表:{runtime_exc}" + ) + return [] + + print_step(f"当前环境暂时无法读取 LLM 提供商目录,已回退到常用平台列表:{exc}") + return [] + + +def _llm_provider_choice_map( + provider_definitions: list[dict[str, Any]], +) -> dict[str, str]: + choices: dict[str, str] = {} + for item in provider_definitions: + if not isinstance(item, dict): + continue + if item.get("supports_api_key") is False: + continue + provider_id = str(item.get("id") or "").strip().lower() + name = str(item.get("name") or provider_id).strip() + if not provider_id or not name: + continue + choices[provider_id] = name + if choices: + return choices + return dict(LLM_PROVIDER_FALLBACK_CHOICES) + + +def _llm_provider_defaults( + provider: str, + provider_definitions: list[dict[str, Any]], +) -> dict[str, str]: + normalized_provider = str(provider or "").strip().lower() + defaults = dict(LLM_PROVIDER_DEFAULTS.get(normalized_provider) or {}) + provider_meta = next( + ( + item + for item in provider_definitions + if isinstance(item, dict) + and str(item.get("id") or "").strip().lower() == normalized_provider + ), + None, + ) + if isinstance(provider_meta, dict): + default_base_url = str(provider_meta.get("default_base_url") or "").strip() + if default_base_url: + defaults["base_url"] = default_base_url + + defaults.setdefault("model", _env_default("LLM_MODEL", "")) + defaults.setdefault("base_url", _env_default("LLM_BASE_URL", "")) + return defaults + + +def _llm_provider_meta( + provider: str, + provider_definitions: list[dict[str, Any]], +) -> dict[str, Any]: + normalized_provider = str(provider or "").strip().lower() + provider_meta = next( + ( + item + for item in provider_definitions + if isinstance(item, dict) + and str(item.get("id") or "").strip().lower() == normalized_provider + ), + None, + ) + return dict(provider_meta) if isinstance(provider_meta, dict) else {} + + +def _load_llm_models_inner(payload: dict[str, Any]) -> list[dict[str, Any]]: + provider = str(payload.get("provider") or "").strip().lower() + if not provider: + raise RuntimeError("缺少 LLM provider") + + provider_module = _load_llm_provider_module() + api_key = str(payload.get("api_key") or "").strip() or None + base_url = str(payload.get("base_url") or "").strip() or None + models = asyncio.run( + provider_module.LLMProviderManager().list_models( + provider_id=provider, + api_key=api_key, + base_url=base_url, + force_refresh=False, + ) + ) + return models if isinstance(models, list) else [] + + +def _load_llm_models( + *, + provider: str, + api_key: Optional[str], + base_url: Optional[str], + runtime_python: Optional[Path] = None, +) -> list[dict[str, Any]]: + payload = { + "provider": str(provider or "").strip().lower(), + "api_key": str(api_key or "").strip(), + "base_url": str(base_url or "").strip(), + } + try: + return _load_llm_models_inner(payload) + except Exception as exc: + if runtime_python and not _current_python_matches(runtime_python): + try: + with TemporaryDirectory() as temp_dir: + request_path = Path(temp_dir) / "llm-models-request.json" + output_path = Path(temp_dir) / "llm-models.json" + request_path.write_text( + json.dumps(payload, ensure_ascii=False), encoding="utf-8" + ) + subprocess.run( + [ + str(runtime_python), + str(Path(__file__).resolve()), + "query-llm-models", + "--request-json-file", + str(request_path), + "--output-json-file", + str(output_path), + ], + cwd=str(ROOT), + check=True, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + data = json.loads(output_path.read_text(encoding="utf-8")) + if isinstance(data, list): + return data + except Exception as runtime_exc: + print_step( + f"当前环境暂时无法获取 {payload['provider']} 模型目录,已回退为手动输入模型名称:{runtime_exc}" + ) + return [] + + print_step( + f"当前环境暂时无法获取 {payload['provider']} 模型目录,已回退为手动输入模型名称:{exc}" + ) + return [] + + +def _print_llm_models(models: list[dict[str, Any]], limit: int = 20) -> None: + print("可用模型:") + for index, item in enumerate(models[:limit], start=1): + if not isinstance(item, dict): + continue + model_id = str(item.get("id") or "").strip() + if not model_id: + continue + model_name = str(item.get("name") or model_id).strip() + extras: list[str] = [] + if item.get("context_tokens_k"): + extras.append(f"{item['context_tokens_k']}K") + if item.get("supports_reasoning"): + extras.append("reasoning") + if item.get("supports_tools"): + extras.append("tools") + if item.get("supports_image_input"): + extras.append("vision") + extra_text = f" [{' / '.join(extras)}]" if extras else "" + if model_name != model_id: + print(f" {index}. {model_id} ({model_name}){extra_text}") + else: + print(f" {index}. {model_id}{extra_text}") + if len(models) > limit: + print(f" ... 共 {len(models)} 个模型,可输入编号或直接输入模型名称") + + +def _prompt_model_choice(models: list[dict[str, Any]], default: Optional[str] = None) -> str: + valid_models = [item for item in models if isinstance(item, dict) and item.get("id")] + if not valid_models: + return _prompt_text("LLM 模型名称", default=default) + + indexed_models = { + str(index): str(item.get("id")).strip() + for index, item in enumerate(valid_models, start=1) + } + default_model = str(default or indexed_models.get("1") or "").strip() + _print_llm_models(valid_models) + + while True: + raw = input( + f"LLM 模型名称/编号{' [' + default_model + ']' if default_model else ''}: " + ).strip() + if not raw and default_model: + return default_model + if raw in indexed_models: + return indexed_models[raw] + if raw: + return raw + print("请输入有效模型编号或模型名称。") + + def _env_llm_thinking_level_default() -> str: value = _normalize_choice(_env_default("LLM_THINKING_LEVEL", "")) alias_map = { @@ -1459,7 +1747,9 @@ def _collect_notification_config() -> Optional[dict[str, Any]]: } -def _collect_agent_config() -> dict[str, Any]: +def _collect_agent_config( + runtime_python: Optional[Path] = None, +) -> dict[str, Any]: print_step("AI Agent 配置") enabled = _prompt_yes_no( "是否启用 AI 智能体", @@ -1471,22 +1761,42 @@ def _collect_agent_config() -> dict[str, Any]: "AI_AGENT_GLOBAL": False, } + provider_definitions = _load_llm_provider_definitions(runtime_python=runtime_python) + provider_choices = _llm_provider_choice_map(provider_definitions) current_provider = _env_default("LLM_PROVIDER", "deepseek").lower() - if current_provider not in LLM_PROVIDER_DEFAULTS: + if current_provider not in provider_choices: current_provider = "deepseek" - provider = _prompt_choice( - "选择 LLM 提供商", - { - "deepseek": "DeepSeek", - "openai": "OpenAI", - "google": "Google", - }, - default=current_provider, - ) - defaults = LLM_PROVIDER_DEFAULTS[provider] + while True: + provider = _prompt_provider_choice( + "选择 LLM 提供商", + provider_choices, + default=current_provider, + ) + provider_meta = _llm_provider_meta(provider, provider_definitions) + if provider_meta.get("supports_api_key") is False: + print_step( + f"{provider_meta.get('name') or provider} 当前仅支持交互式授权,安装向导暂不支持,请改选可填写 API Key 的 provider。" + ) + current_provider = "deepseek" + continue + break + + defaults = _llm_provider_defaults(provider, provider_definitions) current_model = _env_default("LLM_MODEL", defaults["model"]) current_base_url = _env_default("LLM_BASE_URL", defaults["base_url"]) + api_key_label = str(provider_meta.get("api_key_label") or "API Key").strip() or "API Key" + api_key_hint = str(provider_meta.get("api_key_hint") or "").strip() + requires_base_url = bool(provider_meta.get("requires_base_url")) + base_url_label = ( + "自定义 Google API Base URL(可选)" + if provider == "google" + else "LLM Base URL(必填)" + if requires_base_url + else "LLM Base URL" + ) + if api_key_hint: + print_step(api_key_hint) config: dict[str, Any] = { "AI_AGENT_ENABLE": True, @@ -1495,12 +1805,8 @@ def _collect_agent_config() -> dict[str, Any]: default=_env_bool("AI_AGENT_GLOBAL", False), ), "LLM_PROVIDER": provider, - "LLM_MODEL": _prompt_text( - "LLM 模型名称", - default=current_model, - ), "LLM_API_KEY": _prompt_secret_text( - "LLM API Key", + f"LLM {api_key_label}", current_value=read_env_value("LLM_API_KEY"), required=True, ), @@ -1524,18 +1830,18 @@ def _collect_agent_config() -> dict[str, Any]: ), } - if provider == "google": - config["LLM_BASE_URL"] = _prompt_text( - "自定义 Google API Base URL(可选)", - default=current_base_url, - allow_empty=True, - ) - else: - config["LLM_BASE_URL"] = _prompt_text( - "LLM Base URL", - default=current_base_url, - allow_empty=True, - ) + config["LLM_BASE_URL"] = _prompt_text( + base_url_label, + default=current_base_url, + allow_empty=not requires_base_url, + ) + models = _load_llm_models( + provider=provider, + api_key=config["LLM_API_KEY"], + base_url=config["LLM_BASE_URL"], + runtime_python=runtime_python, + ) + config["LLM_MODEL"] = _prompt_model_choice(models, default=current_model) return config @@ -1744,7 +2050,7 @@ def run_setup_wizard( preset_password=preset_superuser_password, ), **_collect_database_config(), - **_collect_agent_config(), + **_collect_agent_config(runtime_python=runtime_python), }, "directories": [_collect_directory_config()], "downloader": _collect_downloader_config(), @@ -3276,6 +3582,23 @@ def build_parser() -> argparse.ArgumentParser: "--output-json-file", required=True, help=argparse.SUPPRESS ) + query_llm_providers_parser = subparsers.add_parser( + "query-llm-providers", help=argparse.SUPPRESS + ) + query_llm_providers_parser.add_argument( + "--output-json-file", required=True, help=argparse.SUPPRESS + ) + + query_llm_models_parser = subparsers.add_parser( + "query-llm-models", help=argparse.SUPPRESS + ) + query_llm_models_parser.add_argument( + "--request-json-file", required=True, help=argparse.SUPPRESS + ) + query_llm_models_parser.add_argument( + "--output-json-file", required=True, help=argparse.SUPPRESS + ) + return parser @@ -3461,6 +3784,25 @@ def main() -> int: encoding="utf-8", ) return 0 + + if args.command == "query-llm-providers": + payload = _load_llm_provider_definitions_inner() + Path(args.output_json_file).write_text( + json.dumps(payload, ensure_ascii=False, indent=2), + encoding="utf-8", + ) + return 0 + + if args.command == "query-llm-models": + payload = json.loads(Path(args.request_json_file).read_text(encoding="utf-8")) + if not isinstance(payload, dict): + raise RuntimeError("模型查询负载格式错误") + models = _load_llm_models_inner(payload) + Path(args.output_json_file).write_text( + json.dumps(models, ensure_ascii=False, indent=2), + encoding="utf-8", + ) + return 0 except subprocess.CalledProcessError as exc: print(f"命令执行失败,退出码:{exc.returncode}", file=sys.stderr) return exc.returncode diff --git a/tests/test_llm_provider_registry.py b/tests/test_llm_provider_registry.py new file mode 100644 index 00000000..76550e14 --- /dev/null +++ b/tests/test_llm_provider_registry.py @@ -0,0 +1,215 @@ +import asyncio +import importlib.util +import sys +import unittest +from pathlib import Path +from types import ModuleType, SimpleNamespace +from unittest.mock import AsyncMock, patch + + +def _stub_module(name: str, **attrs): + module = sys.modules.get(name) + if module is None: + module = ModuleType(name) + sys.modules[name] = module + for key, value in attrs.items(): + setattr(module, key, value) + return module + + +class _DummyLogger: + def __getattr__(self, _name): + return lambda *args, **kwargs: None + + +class _DummySystemConfigOper: + def get(self, _key): + return {} + + async def async_set(self, _key, _value): + return None + + +for _module_name in ("aiofiles", "jwt"): + _stub_module(_module_name) + +_stub_module( + "app.core.config", + settings=SimpleNamespace( + TEMP_PATH="/tmp", + PROXY_HOST=None, + LLM_MAX_CONTEXT_TOKENS=64, + ), +) +_stub_module("app.db.systemconfig_oper", SystemConfigOper=_DummySystemConfigOper) +_stub_module("app.log", logger=_DummyLogger()) +_stub_module("app.schemas.types", SystemConfigKey=SimpleNamespace(AIAgentConfig="agent")) + +provider_path = Path(__file__).resolve().parents[1] / "app" / "agent" / "llm" / "provider.py" +spec = importlib.util.spec_from_file_location("test_llm_provider_module", provider_path) +provider_module = importlib.util.module_from_spec(spec) +assert spec and spec.loader +sys.modules[spec.name] = provider_module +spec.loader.exec_module(provider_module) + +LLMProviderError = provider_module.LLMProviderError +LLMProviderManager = provider_module.LLMProviderManager + + +class LlmProviderRegistryTest(unittest.TestCase): + def setUp(self): + LLMProviderManager._instances.clear() + + def tearDown(self): + LLMProviderManager._instances.clear() + + def test_dynamic_provider_is_exposed_from_models_dev_cache(self): + manager = LLMProviderManager() + manager._models_dev_data = { + "frogbot": { + "id": "frogbot", + "name": "FrogBot", + "npm": "@ai-sdk/openai-compatible", + "env": ["FROGBOT_API_KEY"], + "api": "https://app.frogbot.ai/api/v1", + "models": {}, + } + } + + provider = manager.get_provider("frogbot") + + self.assertEqual(provider.id, "frogbot") + self.assertEqual(provider.runtime, "openai_compatible") + self.assertEqual(provider.default_base_url, "https://app.frogbot.ai/api/v1") + self.assertFalse(provider.requires_base_url) + self.assertTrue(provider.base_url_editable) + self.assertEqual(provider.model_list_strategy, "models_dev_only") + + def test_dynamic_provider_override_normalizes_chat_endpoint_base_url(self): + manager = LLMProviderManager() + manager._models_dev_data = { + "bailing": { + "id": "bailing", + "name": "Bailing", + "npm": "@ai-sdk/openai-compatible", + "env": ["BAILING_API_TOKEN"], + "api": "https://api.tbox.cn/api/llm/v1/chat/completions", + "models": {}, + } + } + + provider = manager.get_provider("bailing") + + self.assertEqual(provider.default_base_url, "https://api.tbox.cn/api/llm/v1") + self.assertEqual(provider.api_key_label, "API Token") + + def test_dynamic_provider_skips_alias_only_models_dev_ids(self): + manager = LLMProviderManager() + manager._models_dev_data = { + "moonshotai": { + "id": "moonshotai", + "name": "Moonshot AI Intl", + "npm": "@ai-sdk/openai-compatible", + "env": ["MOONSHOT_API_KEY"], + "api": "https://api.moonshot.ai/v1", + "models": {}, + } + } + + with self.assertRaises(LLMProviderError): + manager.get_provider("moonshotai") + + def test_dynamic_provider_skips_incompatible_models_dev_provider(self): + manager = LLMProviderManager() + manager._models_dev_data = { + "azure": { + "id": "azure", + "name": "Azure", + "npm": "@ai-sdk/azure", + "env": ["AZURE_API_KEY"], + "models": {}, + } + } + + with self.assertRaises(LLMProviderError): + manager.get_provider("azure") + + def test_dynamic_provider_without_known_base_url_requires_manual_input(self): + manager = LLMProviderManager() + manager._models_dev_data = { + "custom-anthropic": { + "id": "custom-anthropic", + "name": "Custom Anthropic", + "npm": "@ai-sdk/anthropic", + "env": ["CUSTOM_ANTHROPIC_KEY"], + "models": {}, + } + } + + provider = manager.get_provider("custom-anthropic") + + self.assertEqual(provider.runtime, "anthropic_compatible") + self.assertTrue(provider.requires_base_url) + self.assertTrue(provider.base_url_editable) + self.assertEqual(provider.model_list_strategy, "anthropic_compatible") + + def test_list_providers_async_loads_models_dev_before_serializing(self): + manager = LLMProviderManager() + payload = { + "frogbot": { + "id": "frogbot", + "name": "FrogBot", + "npm": "@ai-sdk/openai-compatible", + "env": ["FROGBOT_API_KEY"], + "api": "https://app.frogbot.ai/api/v1", + "models": {}, + } + } + + with patch.object( + manager, + "get_models_dev_data", + AsyncMock(side_effect=lambda force_refresh=False: manager.__dict__.update({"_models_dev_data": payload}) or payload), + ) as fetch_mock: + providers = asyncio.run(manager.list_providers_async()) + + fetch_mock.assert_awaited_once_with(force_refresh=False) + self.assertIn("frogbot", {item["id"] for item in providers}) + + def test_list_models_uses_dynamic_provider_after_refresh(self): + manager = LLMProviderManager() + payload = { + "frogbot": { + "id": "frogbot", + "name": "FrogBot", + "npm": "@ai-sdk/openai-compatible", + "env": ["FROGBOT_API_KEY"], + "api": "https://app.frogbot.ai/api/v1", + "models": { + "frog-1": { + "name": "Frog 1", + "limit": {"context": 131072}, + } + }, + } + } + + async def _load_models_dev(force_refresh: bool = False): + manager._models_dev_data = payload + return payload + + with patch.object(manager, "get_models_dev_data", AsyncMock(side_effect=_load_models_dev)): + models = asyncio.run( + manager.list_models( + provider_id="frogbot", + api_key="sk-test", + force_refresh=True, + ) + ) + + self.assertEqual([item["id"] for item in models], ["frog-1"]) + self.assertEqual(models[0]["source"], "models.dev") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_local_setup_llm_provider_prompt.py b/tests/test_local_setup_llm_provider_prompt.py new file mode 100644 index 00000000..d5805713 --- /dev/null +++ b/tests/test_local_setup_llm_provider_prompt.py @@ -0,0 +1,189 @@ +from __future__ import annotations + +import importlib.util +import unittest +import uuid +from pathlib import Path +from unittest.mock import patch + + +MODULE_PATH = Path(__file__).resolve().parents[1] / "scripts" / "local_setup.py" + + +def load_local_setup_module(): + module_name = f"moviepilot_local_setup_llm_{uuid.uuid4().hex}" + spec = importlib.util.spec_from_file_location(module_name, MODULE_PATH) + module = importlib.util.module_from_spec(spec) + assert spec and spec.loader + spec.loader.exec_module(module) + return module + + +class LocalSetupLlmProviderPromptTests(unittest.TestCase): + def test_collect_agent_config_prefers_loaded_provider_directory(self): + module = load_local_setup_module() + + provider_definitions = [ + { + "id": "frogbot", + "name": "FrogBot", + "default_base_url": "https://app.frogbot.ai/api/v1", + "api_key_label": "API Key", + } + ] + models = [ + {"id": "frog-1", "name": "Frog 1", "context_tokens_k": 128}, + {"id": "frog-2", "name": "Frog 2"}, + ] + + with patch.object(module, "print_step"), patch.object( + module, "_prompt_yes_no", side_effect=[True, False, True] + ), patch.object( + module, "_load_llm_provider_definitions", return_value=provider_definitions + ), patch.object( + module, "_prompt_provider_choice", return_value="frogbot" + ) as provider_prompt, patch.object( + module, "_prompt_text", side_effect=["https://override.example.com/v1"] + ), patch.object( + module, "_prompt_secret_text", return_value="sk-frog" + ), patch.object( + module, "_load_llm_models", return_value=models + ) as load_models, patch.object( + module, "_prompt_model_choice", return_value="frog-2" + ) as model_prompt, patch.object( + module, "read_env_value", return_value=None + ), patch.object( + module, "_env_default", side_effect=lambda key, default="": default + ), patch.object( + module, "_env_bool", side_effect=lambda key, default: default + ), patch.object( + module, "_env_llm_thinking_level_default", return_value="auto" + ), patch.object( + module, "_prompt_choice", return_value="auto" + ): + config = module._collect_agent_config(runtime_python=Path("/tmp/runtime-python")) + + provider_prompt.assert_called_once() + load_models.assert_called_once_with( + provider="frogbot", + api_key="sk-frog", + base_url="https://override.example.com/v1", + runtime_python=Path("/tmp/runtime-python"), + ) + model_prompt.assert_called_once_with(models, default="") + self.assertEqual(config["LLM_PROVIDER"], "frogbot") + self.assertEqual(config["LLM_MODEL"], "frog-2") + self.assertEqual(config["LLM_API_KEY"], "sk-frog") + self.assertEqual(config["LLM_BASE_URL"], "https://override.example.com/v1") + + def test_collect_agent_config_falls_back_to_common_provider_choices(self): + module = load_local_setup_module() + + with patch.object(module, "print_step"), patch.object( + module, "_prompt_yes_no", side_effect=[True, False, True] + ), patch.object( + module, "_load_llm_provider_definitions", return_value=[] + ), patch.object( + module, "_prompt_provider_choice", return_value="anthropic" + ), patch.object( + module, "_prompt_text", side_effect=["https://api.anthropic.com/v1"] + ), patch.object( + module, "_prompt_secret_text", return_value="sk-anthropic" + ), patch.object( + module, "_load_llm_models", return_value=[] + ), patch.object( + module, "_prompt_model_choice", return_value="claude-sonnet-4-0" + ), patch.object( + module, "read_env_value", return_value=None + ), patch.object( + module, "_env_default", side_effect=lambda key, default="": default + ), patch.object( + module, "_env_bool", side_effect=lambda key, default: default + ), patch.object( + module, "_env_llm_thinking_level_default", return_value="off" + ), patch.object( + module, "_prompt_choice", return_value="off" + ): + config = module._collect_agent_config() + + self.assertEqual(config["LLM_PROVIDER"], "anthropic") + self.assertEqual(config["LLM_MODEL"], "claude-sonnet-4-0") + self.assertEqual(config["LLM_BASE_URL"], "https://api.anthropic.com/v1") + + def test_prompt_model_choice_accepts_index_selection(self): + module = load_local_setup_module() + + with patch.object(module, "_print_llm_models") as print_models, patch( + "builtins.input", return_value="2" + ): + model = module._prompt_model_choice( + [ + {"id": "model-a", "name": "Model A"}, + {"id": "model-b", "name": "Model B"}, + ], + default="model-a", + ) + + print_models.assert_called_once() + self.assertEqual(model, "model-b") + + def test_prompt_model_choice_falls_back_to_text_input_when_empty(self): + module = load_local_setup_module() + + with patch.object(module, "_prompt_text", return_value="custom-model") as prompt_text: + model = module._prompt_model_choice([], default="") + + prompt_text.assert_called_once_with("LLM 模型名称", default="") + self.assertEqual(model, "custom-model") + + def test_load_llm_provider_definitions_inner_uses_direct_provider_module_loader(self): + module = load_local_setup_module() + + class _FakeManager: + async def list_providers_async(self, force_refresh: bool = False): + return [{"id": "frogbot", "name": "FrogBot"}] + + class _FakeProviderModule: + @staticmethod + def LLMProviderManager(): + return _FakeManager() + + fake_provider_module = _FakeProviderModule() + + with patch.object( + module, + "_load_llm_provider_module", + return_value=fake_provider_module, + ) as loader: + providers = module._load_llm_provider_definitions_inner() + + loader.assert_called_once_with() + self.assertEqual(providers, [{"id": "frogbot", "name": "FrogBot"}]) + + def test_llm_provider_choice_map_skips_oauth_only_provider(self): + module = load_local_setup_module() + + choices = module._llm_provider_choice_map( + [ + {"id": "chatgpt", "name": "ChatGPT", "supports_api_key": True}, + {"id": "github-copilot", "name": "GitHub Copilot", "supports_api_key": False}, + ] + ) + + self.assertEqual(choices, {"chatgpt": "ChatGPT"}) + + def test_prompt_provider_choice_accepts_custom_provider_id(self): + module = load_local_setup_module() + + with patch("builtins.input", return_value="my-provider_01"), patch("builtins.print"): + provider = module._prompt_provider_choice( + "选择 LLM 提供商", + {"deepseek": "DeepSeek", "google": "Google"}, + default="deepseek", + ) + + self.assertEqual(provider, "my-provider_01") + + +if __name__ == "__main__": + unittest.main()