feat(agent): expand LLM provider and wizard support

jxxghp committed 2026-05-08 08:09:50 +08:00
parent 2d2c2a01eb
commit 05d720d81f
6 changed files with 1080 additions and 72 deletions


@@ -5,6 +5,7 @@ from __future__ import annotations
import argparse
import asyncio
import getpass
import importlib.util
import json
import os
import platform
@@ -66,6 +67,26 @@ LLM_PROVIDER_DEFAULTS = {
"model": "gemini-2.5-flash",
"base_url": "",
},
"anthropic": {
"model": "claude-sonnet-4-0",
"base_url": "https://api.anthropic.com/v1",
},
"openrouter": {
"model": "openai/gpt-4.1-mini",
"base_url": "https://openrouter.ai/api/v1",
},
"groq": {
"model": "llama-3.3-70b-versatile",
"base_url": "https://api.groq.com/openai/v1",
},
}
LLM_PROVIDER_FALLBACK_CHOICES = {
"deepseek": "DeepSeek",
"openai": "OpenAI Compatible",
"google": "Google",
"anthropic": "Anthropic",
"openrouter": "OpenRouter",
"groq": "Groq",
}
RUNTIME_PACKAGE = {
"name": "moviepilot-frontend-runtime",
@@ -1063,6 +1084,273 @@ def _prompt_choice(label: str, choices: dict[str, str], default: str) -> str:
print("请输入列表中的可选值。")
def _prompt_provider_choice(label: str, choices: dict[str, str], default: str) -> str:
labels = []
normalized_map: dict[str, str] = {}
for key, desc in choices.items():
labels.append(f"{key}({desc})")
normalized_map[_normalize_choice(key)] = key
preview_limit = 12
print("可用 LLM 提供商:")
for item in labels[:preview_limit]:
print(f" {item}")
if len(labels) > preview_limit:
print(f" ... 另有 {len(labels) - preview_limit} 个,可直接输入 provider id")
while True:
raw = input(f"{label} (默认 {default},可直接输入 provider id): ").strip()
if not raw:
return default
normalized = _normalize_choice(raw)
if normalized in normalized_map:
return normalized_map[normalized]
provider_id = raw.strip().lower()
if re.fullmatch(r"[a-z0-9][a-z0-9._-]*", provider_id):
return provider_id
print("请输入列表中的可选值,或合法的 provider id小写字母/数字/.-_")
def _load_llm_provider_module():
provider_path = ROOT / "app" / "agent" / "llm" / "provider.py"
module_name = f"moviepilot_local_llm_provider_{uuid.uuid4().hex}"
spec = importlib.util.spec_from_file_location(module_name, provider_path)
if not spec or not spec.loader:
raise RuntimeError("无法加载 LLM provider 模块")
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
return module
def _load_llm_provider_definitions_inner() -> list[dict[str, Any]]:
provider_module = _load_llm_provider_module()
providers = asyncio.run(provider_module.LLMProviderManager().list_providers_async())
return providers if isinstance(providers, list) else []
def _load_llm_provider_definitions(
runtime_python: Optional[Path] = None,
) -> list[dict[str, Any]]:
try:
return _load_llm_provider_definitions_inner()
except Exception as exc:
if runtime_python and not _current_python_matches(runtime_python):
try:
with TemporaryDirectory() as temp_dir:
output_path = Path(temp_dir) / "llm-providers.json"
subprocess.run(
[
str(runtime_python),
str(Path(__file__).resolve()),
"query-llm-providers",
"--output-json-file",
str(output_path),
],
cwd=str(ROOT),
check=True,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
data = json.loads(output_path.read_text(encoding="utf-8"))
if isinstance(data, list):
return data
except Exception as runtime_exc:
print_step(
f"当前环境暂时无法读取 LLM 提供商目录,已回退到常用平台列表:{runtime_exc}"
)
return []
print_step(f"当前环境暂时无法读取 LLM 提供商目录,已回退到常用平台列表:{exc}")
return []
def _llm_provider_choice_map(
provider_definitions: list[dict[str, Any]],
) -> dict[str, str]:
choices: dict[str, str] = {}
for item in provider_definitions:
if not isinstance(item, dict):
continue
if item.get("supports_api_key") is False:
continue
provider_id = str(item.get("id") or "").strip().lower()
name = str(item.get("name") or provider_id).strip()
if not provider_id or not name:
continue
choices[provider_id] = name
if choices:
return choices
return dict(LLM_PROVIDER_FALLBACK_CHOICES)
def _llm_provider_defaults(
provider: str,
provider_definitions: list[dict[str, Any]],
) -> dict[str, str]:
normalized_provider = str(provider or "").strip().lower()
defaults = dict(LLM_PROVIDER_DEFAULTS.get(normalized_provider) or {})
provider_meta = next(
(
item
for item in provider_definitions
if isinstance(item, dict)
and str(item.get("id") or "").strip().lower() == normalized_provider
),
None,
)
if isinstance(provider_meta, dict):
default_base_url = str(provider_meta.get("default_base_url") or "").strip()
if default_base_url:
defaults["base_url"] = default_base_url
defaults.setdefault("model", _env_default("LLM_MODEL", ""))
defaults.setdefault("base_url", _env_default("LLM_BASE_URL", ""))
return defaults
def _llm_provider_meta(
provider: str,
provider_definitions: list[dict[str, Any]],
) -> dict[str, Any]:
normalized_provider = str(provider or "").strip().lower()
provider_meta = next(
(
item
for item in provider_definitions
if isinstance(item, dict)
and str(item.get("id") or "").strip().lower() == normalized_provider
),
None,
)
return dict(provider_meta) if isinstance(provider_meta, dict) else {}
def _load_llm_models_inner(payload: dict[str, Any]) -> list[dict[str, Any]]:
provider = str(payload.get("provider") or "").strip().lower()
if not provider:
raise RuntimeError("缺少 LLM provider")
provider_module = _load_llm_provider_module()
api_key = str(payload.get("api_key") or "").strip() or None
base_url = str(payload.get("base_url") or "").strip() or None
models = asyncio.run(
provider_module.LLMProviderManager().list_models(
provider_id=provider,
api_key=api_key,
base_url=base_url,
force_refresh=False,
)
)
return models if isinstance(models, list) else []
def _load_llm_models(
*,
provider: str,
api_key: Optional[str],
base_url: Optional[str],
runtime_python: Optional[Path] = None,
) -> list[dict[str, Any]]:
payload = {
"provider": str(provider or "").strip().lower(),
"api_key": str(api_key or "").strip(),
"base_url": str(base_url or "").strip(),
}
try:
return _load_llm_models_inner(payload)
except Exception as exc:
if runtime_python and not _current_python_matches(runtime_python):
try:
with TemporaryDirectory() as temp_dir:
request_path = Path(temp_dir) / "llm-models-request.json"
output_path = Path(temp_dir) / "llm-models.json"
request_path.write_text(
json.dumps(payload, ensure_ascii=False), encoding="utf-8"
)
subprocess.run(
[
str(runtime_python),
str(Path(__file__).resolve()),
"query-llm-models",
"--request-json-file",
str(request_path),
"--output-json-file",
str(output_path),
],
cwd=str(ROOT),
check=True,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
data = json.loads(output_path.read_text(encoding="utf-8"))
if isinstance(data, list):
return data
except Exception as runtime_exc:
print_step(
f"当前环境暂时无法获取 {payload['provider']} 模型目录,已回退为手动输入模型名称:{runtime_exc}"
)
return []
print_step(
f"当前环境暂时无法获取 {payload['provider']} 模型目录,已回退为手动输入模型名称:{exc}"
)
return []
def _print_llm_models(models: list[dict[str, Any]], limit: int = 20) -> None:
print("可用模型:")
for index, item in enumerate(models[:limit], start=1):
if not isinstance(item, dict):
continue
model_id = str(item.get("id") or "").strip()
if not model_id:
continue
model_name = str(item.get("name") or model_id).strip()
extras: list[str] = []
if item.get("context_tokens_k"):
extras.append(f"{item['context_tokens_k']}K")
if item.get("supports_reasoning"):
extras.append("reasoning")
if item.get("supports_tools"):
extras.append("tools")
if item.get("supports_image_input"):
extras.append("vision")
extra_text = f" [{' / '.join(extras)}]" if extras else ""
if model_name != model_id:
print(f" {index}. {model_id} ({model_name}){extra_text}")
else:
print(f" {index}. {model_id}{extra_text}")
if len(models) > limit:
print(f" ... 共 {len(models)} 个模型,可输入编号或直接输入模型名称")
def _prompt_model_choice(models: list[dict[str, Any]], default: Optional[str] = None) -> str:
valid_models = [item for item in models if isinstance(item, dict) and item.get("id")]
if not valid_models:
return _prompt_text("LLM 模型名称", default=default)
indexed_models = {
str(index): str(item.get("id")).strip()
for index, item in enumerate(valid_models, start=1)
}
default_model = str(default or indexed_models.get("1") or "").strip()
_print_llm_models(valid_models)
while True:
raw = input(
f"LLM 模型名称/编号{' [' + default_model + ']' if default_model else ''}: "
).strip()
if not raw and default_model:
return default_model
if raw in indexed_models:
return indexed_models[raw]
if raw:
return raw
print("请输入有效模型编号或模型名称。")
def _env_llm_thinking_level_default() -> str:
value = _normalize_choice(_env_default("LLM_THINKING_LEVEL", ""))
alias_map = {
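A note on _load_llm_provider_module above: it imports app/agent/llm/provider.py straight from its file path under a uuid-suffixed module name, so repeated wizard runs never collide in sys.modules. A minimal self-contained sketch of that pattern (the helper name here is illustrative, not part of the codebase):

    import importlib.util
    import sys
    import uuid
    from pathlib import Path

    def load_module_from_path(path: Path):
        # Unique name per load: every call gets a fresh module object.
        name = f"adhoc_{uuid.uuid4().hex}"
        spec = importlib.util.spec_from_file_location(name, path)
        if spec is None or spec.loader is None:
            raise RuntimeError(f"cannot load module from {path}")
        module = importlib.util.module_from_spec(spec)
        # Register before exec_module so the module can resolve itself
        # if anything inside it touches sys.modules during import.
        sys.modules[name] = module
        spec.loader.exec_module(module)
        return module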
@@ -1459,7 +1747,9 @@ def _collect_notification_config() -> Optional[dict[str, Any]]:
}
-def _collect_agent_config() -> dict[str, Any]:
+def _collect_agent_config(
+    runtime_python: Optional[Path] = None,
+) -> dict[str, Any]:
print_step("AI Agent 配置")
enabled = _prompt_yes_no(
"是否启用 AI 智能体",
@@ -1471,22 +1761,42 @@ def _collect_agent_config() -> dict[str, Any]:
"AI_AGENT_GLOBAL": False,
}
provider_definitions = _load_llm_provider_definitions(runtime_python=runtime_python)
provider_choices = _llm_provider_choice_map(provider_definitions)
current_provider = _env_default("LLM_PROVIDER", "deepseek").lower()
-    if current_provider not in LLM_PROVIDER_DEFAULTS:
+    if current_provider not in provider_choices:
current_provider = "deepseek"
-    provider = _prompt_choice(
-        "Select an LLM provider",
-        {
-            "deepseek": "DeepSeek",
-            "openai": "OpenAI",
-            "google": "Google",
-        },
-        default=current_provider,
-    )
-    defaults = LLM_PROVIDER_DEFAULTS[provider]
+    while True:
+        provider = _prompt_provider_choice(
+            "Select an LLM provider",
+            provider_choices,
+            default=current_provider,
+        )
+        provider_meta = _llm_provider_meta(provider, provider_definitions)
+        if provider_meta.get("supports_api_key") is False:
+            print_step(
+                f"{provider_meta.get('name') or provider} only supports interactive authorization, which the setup wizard does not support yet; please choose a provider that accepts an API Key"
+            )
+            current_provider = "deepseek"
+            continue
+        break
+    defaults = _llm_provider_defaults(provider, provider_definitions)
current_model = _env_default("LLM_MODEL", defaults["model"])
current_base_url = _env_default("LLM_BASE_URL", defaults["base_url"])
api_key_label = str(provider_meta.get("api_key_label") or "API Key").strip() or "API Key"
api_key_hint = str(provider_meta.get("api_key_hint") or "").strip()
requires_base_url = bool(provider_meta.get("requires_base_url"))
    base_url_label = (
        "Custom Google API base URL (optional)"
        if provider == "google"
        else "LLM Base URL (required)"
        if requires_base_url
        else "LLM Base URL"
    )
if api_key_hint:
print_step(api_key_hint)
config: dict[str, Any] = {
"AI_AGENT_ENABLE": True,
@@ -1495,12 +1805,8 @@ def _collect_agent_config() -> dict[str, Any]:
default=_env_bool("AI_AGENT_GLOBAL", False),
),
"LLM_PROVIDER": provider,
"LLM_MODEL": _prompt_text(
"LLM 模型名称",
default=current_model,
),
"LLM_API_KEY": _prompt_secret_text(
"LLM API Key",
f"LLM {api_key_label}",
current_value=read_env_value("LLM_API_KEY"),
required=True,
),
@@ -1524,18 +1830,18 @@ def _collect_agent_config() -> dict[str, Any]:
),
}
if provider == "google":
config["LLM_BASE_URL"] = _prompt_text(
"自定义 Google API Base URL可选",
default=current_base_url,
allow_empty=True,
)
else:
config["LLM_BASE_URL"] = _prompt_text(
"LLM Base URL",
default=current_base_url,
allow_empty=True,
)
config["LLM_BASE_URL"] = _prompt_text(
base_url_label,
default=current_base_url,
allow_empty=not requires_base_url,
)
models = _load_llm_models(
provider=provider,
api_key=config["LLM_API_KEY"],
base_url=config["LLM_BASE_URL"],
runtime_python=runtime_python,
)
config["LLM_MODEL"] = _prompt_model_choice(models, default=current_model)
return config
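One consequence of the wiring above is that every failure mode degrades rather than aborting the wizard. A hedged sketch of the degraded path (argument values here are illustrative):

    # If the in-process catalog call fails and no runtime interpreter is
    # available for the handoff, _load_llm_models returns [] ...
    models = _load_llm_models(
        provider="groq",
        api_key=None,          # illustrative; the wizard passes the collected key
        base_url=None,
        runtime_python=None,   # no runtime handoff available
    )
    # ... and _prompt_model_choice with an empty list falls back to a plain
    # free-text prompt for the model name.
    model = _prompt_model_choice(models, default="llama-3.3-70b-versatile")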
@@ -1744,7 +2050,7 @@ def run_setup_wizard(
preset_password=preset_superuser_password,
),
**_collect_database_config(),
-        **_collect_agent_config(),
+        **_collect_agent_config(runtime_python=runtime_python),
},
"directories": [_collect_directory_config()],
"downloader": _collect_downloader_config(),
@@ -3276,6 +3582,23 @@ def build_parser() -> argparse.ArgumentParser:
"--output-json-file", required=True, help=argparse.SUPPRESS
)
query_llm_providers_parser = subparsers.add_parser(
"query-llm-providers", help=argparse.SUPPRESS
)
query_llm_providers_parser.add_argument(
"--output-json-file", required=True, help=argparse.SUPPRESS
)
query_llm_models_parser = subparsers.add_parser(
"query-llm-models", help=argparse.SUPPRESS
)
query_llm_models_parser.add_argument(
"--request-json-file", required=True, help=argparse.SUPPRESS
)
query_llm_models_parser.add_argument(
"--output-json-file", required=True, help=argparse.SUPPRESS
)
return parser
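These hidden subcommands (help=argparse.SUPPRESS marks them as internal) implement the interpreter handoff used by _load_llm_provider_definitions and _load_llm_models: when the wizard runs under a different Python than the runtime, it re-executes this same script with the runtime interpreter and exchanges JSON files. A condensed sketch of the calling side, mirroring the code earlier in this diff (the function name is illustrative):

    import json
    import subprocess
    from pathlib import Path
    from tempfile import TemporaryDirectory

    def query_models_via(runtime_python: Path, script: Path, request: dict) -> list:
        with TemporaryDirectory() as tmp:
            req = Path(tmp) / "request.json"
            out = Path(tmp) / "models.json"
            req.write_text(json.dumps(request, ensure_ascii=False), encoding="utf-8")
            # The hidden subcommand runs inside the runtime environment,
            # which has the agent dependencies installed.
            subprocess.run(
                [
                    str(runtime_python),
                    str(script),
                    "query-llm-models",
                    "--request-json-file", str(req),
                    "--output-json-file", str(out),
                ],
                check=True,
            )
            data = json.loads(out.read_text(encoding="utf-8"))
            return data if isinstance(data, list) else []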
@@ -3461,6 +3784,25 @@ def main() -> int:
encoding="utf-8",
)
return 0
if args.command == "query-llm-providers":
payload = _load_llm_provider_definitions_inner()
Path(args.output_json_file).write_text(
json.dumps(payload, ensure_ascii=False, indent=2),
encoding="utf-8",
)
return 0
if args.command == "query-llm-models":
payload = json.loads(Path(args.request_json_file).read_text(encoding="utf-8"))
if not isinstance(payload, dict):
raise RuntimeError("模型查询负载格式错误")
models = _load_llm_models_inner(payload)
Path(args.output_json_file).write_text(
json.dumps(models, ensure_ascii=False, indent=2),
encoding="utf-8",
)
return 0
except subprocess.CalledProcessError as exc:
print(f"命令执行失败,退出码:{exc.returncode}", file=sys.stderr)
return exc.returncode