mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-13 07:26:45 +00:00
feat(agent): expand LLM provider and wizard support
This commit is contained in:
@@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
import argparse
|
||||
import asyncio
|
||||
import getpass
|
||||
import importlib.util
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
@@ -66,6 +67,26 @@ LLM_PROVIDER_DEFAULTS = {
|
||||
"model": "gemini-2.5-flash",
|
||||
"base_url": "",
|
||||
},
|
||||
"anthropic": {
|
||||
"model": "claude-sonnet-4-0",
|
||||
"base_url": "https://api.anthropic.com/v1",
|
||||
},
|
||||
"openrouter": {
|
||||
"model": "openai/gpt-4.1-mini",
|
||||
"base_url": "https://openrouter.ai/api/v1",
|
||||
},
|
||||
"groq": {
|
||||
"model": "llama-3.3-70b-versatile",
|
||||
"base_url": "https://api.groq.com/openai/v1",
|
||||
},
|
||||
}
|
||||
LLM_PROVIDER_FALLBACK_CHOICES = {
|
||||
"deepseek": "DeepSeek",
|
||||
"openai": "OpenAI Compatible",
|
||||
"google": "Google",
|
||||
"anthropic": "Anthropic",
|
||||
"openrouter": "OpenRouter",
|
||||
"groq": "Groq",
|
||||
}
|
||||
RUNTIME_PACKAGE = {
|
||||
"name": "moviepilot-frontend-runtime",
|
||||
@@ -1063,6 +1084,273 @@ def _prompt_choice(label: str, choices: dict[str, str], default: str) -> str:
|
||||
print("请输入列表中的可选值。")
|
||||
|
||||
|
||||
def _prompt_provider_choice(label: str, choices: dict[str, str], default: str) -> str:
|
||||
labels = []
|
||||
normalized_map: dict[str, str] = {}
|
||||
for key, desc in choices.items():
|
||||
labels.append(f"{key}({desc})")
|
||||
normalized_map[_normalize_choice(key)] = key
|
||||
|
||||
preview_limit = 12
|
||||
print("可用 LLM 提供商:")
|
||||
for item in labels[:preview_limit]:
|
||||
print(f" {item}")
|
||||
if len(labels) > preview_limit:
|
||||
print(f" ... 另有 {len(labels) - preview_limit} 个,可直接输入 provider id")
|
||||
|
||||
while True:
|
||||
raw = input(f"{label} (默认 {default},可直接输入 provider id): ").strip()
|
||||
if not raw:
|
||||
return default
|
||||
normalized = _normalize_choice(raw)
|
||||
if normalized in normalized_map:
|
||||
return normalized_map[normalized]
|
||||
|
||||
provider_id = raw.strip().lower()
|
||||
if re.fullmatch(r"[a-z0-9][a-z0-9._-]*", provider_id):
|
||||
return provider_id
|
||||
print("请输入列表中的可选值,或合法的 provider id(小写字母/数字/.-_)。")
|
||||
|
||||
|
||||
def _load_llm_provider_module():
|
||||
provider_path = ROOT / "app" / "agent" / "llm" / "provider.py"
|
||||
module_name = f"moviepilot_local_llm_provider_{uuid.uuid4().hex}"
|
||||
spec = importlib.util.spec_from_file_location(module_name, provider_path)
|
||||
if not spec or not spec.loader:
|
||||
raise RuntimeError("无法加载 LLM provider 模块")
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[module_name] = module
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
def _load_llm_provider_definitions_inner() -> list[dict[str, Any]]:
|
||||
provider_module = _load_llm_provider_module()
|
||||
providers = asyncio.run(provider_module.LLMProviderManager().list_providers_async())
|
||||
return providers if isinstance(providers, list) else []
|
||||
|
||||
|
||||
def _load_llm_provider_definitions(
|
||||
runtime_python: Optional[Path] = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
try:
|
||||
return _load_llm_provider_definitions_inner()
|
||||
except Exception as exc:
|
||||
if runtime_python and not _current_python_matches(runtime_python):
|
||||
try:
|
||||
with TemporaryDirectory() as temp_dir:
|
||||
output_path = Path(temp_dir) / "llm-providers.json"
|
||||
subprocess.run(
|
||||
[
|
||||
str(runtime_python),
|
||||
str(Path(__file__).resolve()),
|
||||
"query-llm-providers",
|
||||
"--output-json-file",
|
||||
str(output_path),
|
||||
],
|
||||
cwd=str(ROOT),
|
||||
check=True,
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
)
|
||||
data = json.loads(output_path.read_text(encoding="utf-8"))
|
||||
if isinstance(data, list):
|
||||
return data
|
||||
except Exception as runtime_exc:
|
||||
print_step(
|
||||
f"当前环境暂时无法读取 LLM 提供商目录,已回退到常用平台列表:{runtime_exc}"
|
||||
)
|
||||
return []
|
||||
|
||||
print_step(f"当前环境暂时无法读取 LLM 提供商目录,已回退到常用平台列表:{exc}")
|
||||
return []
|
||||
|
||||
|
||||
def _llm_provider_choice_map(
|
||||
provider_definitions: list[dict[str, Any]],
|
||||
) -> dict[str, str]:
|
||||
choices: dict[str, str] = {}
|
||||
for item in provider_definitions:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
if item.get("supports_api_key") is False:
|
||||
continue
|
||||
provider_id = str(item.get("id") or "").strip().lower()
|
||||
name = str(item.get("name") or provider_id).strip()
|
||||
if not provider_id or not name:
|
||||
continue
|
||||
choices[provider_id] = name
|
||||
if choices:
|
||||
return choices
|
||||
return dict(LLM_PROVIDER_FALLBACK_CHOICES)
|
||||
|
||||
|
||||
def _llm_provider_defaults(
|
||||
provider: str,
|
||||
provider_definitions: list[dict[str, Any]],
|
||||
) -> dict[str, str]:
|
||||
normalized_provider = str(provider or "").strip().lower()
|
||||
defaults = dict(LLM_PROVIDER_DEFAULTS.get(normalized_provider) or {})
|
||||
provider_meta = next(
|
||||
(
|
||||
item
|
||||
for item in provider_definitions
|
||||
if isinstance(item, dict)
|
||||
and str(item.get("id") or "").strip().lower() == normalized_provider
|
||||
),
|
||||
None,
|
||||
)
|
||||
if isinstance(provider_meta, dict):
|
||||
default_base_url = str(provider_meta.get("default_base_url") or "").strip()
|
||||
if default_base_url:
|
||||
defaults["base_url"] = default_base_url
|
||||
|
||||
defaults.setdefault("model", _env_default("LLM_MODEL", ""))
|
||||
defaults.setdefault("base_url", _env_default("LLM_BASE_URL", ""))
|
||||
return defaults
|
||||
|
||||
|
||||
def _llm_provider_meta(
|
||||
provider: str,
|
||||
provider_definitions: list[dict[str, Any]],
|
||||
) -> dict[str, Any]:
|
||||
normalized_provider = str(provider or "").strip().lower()
|
||||
provider_meta = next(
|
||||
(
|
||||
item
|
||||
for item in provider_definitions
|
||||
if isinstance(item, dict)
|
||||
and str(item.get("id") or "").strip().lower() == normalized_provider
|
||||
),
|
||||
None,
|
||||
)
|
||||
return dict(provider_meta) if isinstance(provider_meta, dict) else {}
|
||||
|
||||
|
||||
def _load_llm_models_inner(payload: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
provider = str(payload.get("provider") or "").strip().lower()
|
||||
if not provider:
|
||||
raise RuntimeError("缺少 LLM provider")
|
||||
|
||||
provider_module = _load_llm_provider_module()
|
||||
api_key = str(payload.get("api_key") or "").strip() or None
|
||||
base_url = str(payload.get("base_url") or "").strip() or None
|
||||
models = asyncio.run(
|
||||
provider_module.LLMProviderManager().list_models(
|
||||
provider_id=provider,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
force_refresh=False,
|
||||
)
|
||||
)
|
||||
return models if isinstance(models, list) else []
|
||||
|
||||
|
||||
def _load_llm_models(
|
||||
*,
|
||||
provider: str,
|
||||
api_key: Optional[str],
|
||||
base_url: Optional[str],
|
||||
runtime_python: Optional[Path] = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
payload = {
|
||||
"provider": str(provider or "").strip().lower(),
|
||||
"api_key": str(api_key or "").strip(),
|
||||
"base_url": str(base_url or "").strip(),
|
||||
}
|
||||
try:
|
||||
return _load_llm_models_inner(payload)
|
||||
except Exception as exc:
|
||||
if runtime_python and not _current_python_matches(runtime_python):
|
||||
try:
|
||||
with TemporaryDirectory() as temp_dir:
|
||||
request_path = Path(temp_dir) / "llm-models-request.json"
|
||||
output_path = Path(temp_dir) / "llm-models.json"
|
||||
request_path.write_text(
|
||||
json.dumps(payload, ensure_ascii=False), encoding="utf-8"
|
||||
)
|
||||
subprocess.run(
|
||||
[
|
||||
str(runtime_python),
|
||||
str(Path(__file__).resolve()),
|
||||
"query-llm-models",
|
||||
"--request-json-file",
|
||||
str(request_path),
|
||||
"--output-json-file",
|
||||
str(output_path),
|
||||
],
|
||||
cwd=str(ROOT),
|
||||
check=True,
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
)
|
||||
data = json.loads(output_path.read_text(encoding="utf-8"))
|
||||
if isinstance(data, list):
|
||||
return data
|
||||
except Exception as runtime_exc:
|
||||
print_step(
|
||||
f"当前环境暂时无法获取 {payload['provider']} 模型目录,已回退为手动输入模型名称:{runtime_exc}"
|
||||
)
|
||||
return []
|
||||
|
||||
print_step(
|
||||
f"当前环境暂时无法获取 {payload['provider']} 模型目录,已回退为手动输入模型名称:{exc}"
|
||||
)
|
||||
return []
|
||||
|
||||
|
||||
def _print_llm_models(models: list[dict[str, Any]], limit: int = 20) -> None:
|
||||
print("可用模型:")
|
||||
for index, item in enumerate(models[:limit], start=1):
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
model_id = str(item.get("id") or "").strip()
|
||||
if not model_id:
|
||||
continue
|
||||
model_name = str(item.get("name") or model_id).strip()
|
||||
extras: list[str] = []
|
||||
if item.get("context_tokens_k"):
|
||||
extras.append(f"{item['context_tokens_k']}K")
|
||||
if item.get("supports_reasoning"):
|
||||
extras.append("reasoning")
|
||||
if item.get("supports_tools"):
|
||||
extras.append("tools")
|
||||
if item.get("supports_image_input"):
|
||||
extras.append("vision")
|
||||
extra_text = f" [{' / '.join(extras)}]" if extras else ""
|
||||
if model_name != model_id:
|
||||
print(f" {index}. {model_id} ({model_name}){extra_text}")
|
||||
else:
|
||||
print(f" {index}. {model_id}{extra_text}")
|
||||
if len(models) > limit:
|
||||
print(f" ... 共 {len(models)} 个模型,可输入编号或直接输入模型名称")
|
||||
|
||||
|
||||
def _prompt_model_choice(models: list[dict[str, Any]], default: Optional[str] = None) -> str:
|
||||
valid_models = [item for item in models if isinstance(item, dict) and item.get("id")]
|
||||
if not valid_models:
|
||||
return _prompt_text("LLM 模型名称", default=default)
|
||||
|
||||
indexed_models = {
|
||||
str(index): str(item.get("id")).strip()
|
||||
for index, item in enumerate(valid_models, start=1)
|
||||
}
|
||||
default_model = str(default or indexed_models.get("1") or "").strip()
|
||||
_print_llm_models(valid_models)
|
||||
|
||||
while True:
|
||||
raw = input(
|
||||
f"LLM 模型名称/编号{' [' + default_model + ']' if default_model else ''}: "
|
||||
).strip()
|
||||
if not raw and default_model:
|
||||
return default_model
|
||||
if raw in indexed_models:
|
||||
return indexed_models[raw]
|
||||
if raw:
|
||||
return raw
|
||||
print("请输入有效模型编号或模型名称。")
|
||||
|
||||
|
||||
def _env_llm_thinking_level_default() -> str:
|
||||
value = _normalize_choice(_env_default("LLM_THINKING_LEVEL", ""))
|
||||
alias_map = {
|
||||
@@ -1459,7 +1747,9 @@ def _collect_notification_config() -> Optional[dict[str, Any]]:
|
||||
}
|
||||
|
||||
|
||||
def _collect_agent_config() -> dict[str, Any]:
|
||||
def _collect_agent_config(
|
||||
runtime_python: Optional[Path] = None,
|
||||
) -> dict[str, Any]:
|
||||
print_step("AI Agent 配置")
|
||||
enabled = _prompt_yes_no(
|
||||
"是否启用 AI 智能体",
|
||||
@@ -1471,22 +1761,42 @@ def _collect_agent_config() -> dict[str, Any]:
|
||||
"AI_AGENT_GLOBAL": False,
|
||||
}
|
||||
|
||||
provider_definitions = _load_llm_provider_definitions(runtime_python=runtime_python)
|
||||
provider_choices = _llm_provider_choice_map(provider_definitions)
|
||||
current_provider = _env_default("LLM_PROVIDER", "deepseek").lower()
|
||||
if current_provider not in LLM_PROVIDER_DEFAULTS:
|
||||
if current_provider not in provider_choices:
|
||||
current_provider = "deepseek"
|
||||
|
||||
provider = _prompt_choice(
|
||||
"选择 LLM 提供商",
|
||||
{
|
||||
"deepseek": "DeepSeek",
|
||||
"openai": "OpenAI",
|
||||
"google": "Google",
|
||||
},
|
||||
default=current_provider,
|
||||
)
|
||||
defaults = LLM_PROVIDER_DEFAULTS[provider]
|
||||
while True:
|
||||
provider = _prompt_provider_choice(
|
||||
"选择 LLM 提供商",
|
||||
provider_choices,
|
||||
default=current_provider,
|
||||
)
|
||||
provider_meta = _llm_provider_meta(provider, provider_definitions)
|
||||
if provider_meta.get("supports_api_key") is False:
|
||||
print_step(
|
||||
f"{provider_meta.get('name') or provider} 当前仅支持交互式授权,安装向导暂不支持,请改选可填写 API Key 的 provider。"
|
||||
)
|
||||
current_provider = "deepseek"
|
||||
continue
|
||||
break
|
||||
|
||||
defaults = _llm_provider_defaults(provider, provider_definitions)
|
||||
current_model = _env_default("LLM_MODEL", defaults["model"])
|
||||
current_base_url = _env_default("LLM_BASE_URL", defaults["base_url"])
|
||||
api_key_label = str(provider_meta.get("api_key_label") or "API Key").strip() or "API Key"
|
||||
api_key_hint = str(provider_meta.get("api_key_hint") or "").strip()
|
||||
requires_base_url = bool(provider_meta.get("requires_base_url"))
|
||||
base_url_label = (
|
||||
"自定义 Google API Base URL(可选)"
|
||||
if provider == "google"
|
||||
else "LLM Base URL(必填)"
|
||||
if requires_base_url
|
||||
else "LLM Base URL"
|
||||
)
|
||||
if api_key_hint:
|
||||
print_step(api_key_hint)
|
||||
|
||||
config: dict[str, Any] = {
|
||||
"AI_AGENT_ENABLE": True,
|
||||
@@ -1495,12 +1805,8 @@ def _collect_agent_config() -> dict[str, Any]:
|
||||
default=_env_bool("AI_AGENT_GLOBAL", False),
|
||||
),
|
||||
"LLM_PROVIDER": provider,
|
||||
"LLM_MODEL": _prompt_text(
|
||||
"LLM 模型名称",
|
||||
default=current_model,
|
||||
),
|
||||
"LLM_API_KEY": _prompt_secret_text(
|
||||
"LLM API Key",
|
||||
f"LLM {api_key_label}",
|
||||
current_value=read_env_value("LLM_API_KEY"),
|
||||
required=True,
|
||||
),
|
||||
@@ -1524,18 +1830,18 @@ def _collect_agent_config() -> dict[str, Any]:
|
||||
),
|
||||
}
|
||||
|
||||
if provider == "google":
|
||||
config["LLM_BASE_URL"] = _prompt_text(
|
||||
"自定义 Google API Base URL(可选)",
|
||||
default=current_base_url,
|
||||
allow_empty=True,
|
||||
)
|
||||
else:
|
||||
config["LLM_BASE_URL"] = _prompt_text(
|
||||
"LLM Base URL",
|
||||
default=current_base_url,
|
||||
allow_empty=True,
|
||||
)
|
||||
config["LLM_BASE_URL"] = _prompt_text(
|
||||
base_url_label,
|
||||
default=current_base_url,
|
||||
allow_empty=not requires_base_url,
|
||||
)
|
||||
models = _load_llm_models(
|
||||
provider=provider,
|
||||
api_key=config["LLM_API_KEY"],
|
||||
base_url=config["LLM_BASE_URL"],
|
||||
runtime_python=runtime_python,
|
||||
)
|
||||
config["LLM_MODEL"] = _prompt_model_choice(models, default=current_model)
|
||||
|
||||
return config
|
||||
|
||||
@@ -1744,7 +2050,7 @@ def run_setup_wizard(
|
||||
preset_password=preset_superuser_password,
|
||||
),
|
||||
**_collect_database_config(),
|
||||
**_collect_agent_config(),
|
||||
**_collect_agent_config(runtime_python=runtime_python),
|
||||
},
|
||||
"directories": [_collect_directory_config()],
|
||||
"downloader": _collect_downloader_config(),
|
||||
@@ -3276,6 +3582,23 @@ def build_parser() -> argparse.ArgumentParser:
|
||||
"--output-json-file", required=True, help=argparse.SUPPRESS
|
||||
)
|
||||
|
||||
query_llm_providers_parser = subparsers.add_parser(
|
||||
"query-llm-providers", help=argparse.SUPPRESS
|
||||
)
|
||||
query_llm_providers_parser.add_argument(
|
||||
"--output-json-file", required=True, help=argparse.SUPPRESS
|
||||
)
|
||||
|
||||
query_llm_models_parser = subparsers.add_parser(
|
||||
"query-llm-models", help=argparse.SUPPRESS
|
||||
)
|
||||
query_llm_models_parser.add_argument(
|
||||
"--request-json-file", required=True, help=argparse.SUPPRESS
|
||||
)
|
||||
query_llm_models_parser.add_argument(
|
||||
"--output-json-file", required=True, help=argparse.SUPPRESS
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@@ -3461,6 +3784,25 @@ def main() -> int:
|
||||
encoding="utf-8",
|
||||
)
|
||||
return 0
|
||||
|
||||
if args.command == "query-llm-providers":
|
||||
payload = _load_llm_provider_definitions_inner()
|
||||
Path(args.output_json_file).write_text(
|
||||
json.dumps(payload, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8",
|
||||
)
|
||||
return 0
|
||||
|
||||
if args.command == "query-llm-models":
|
||||
payload = json.loads(Path(args.request_json_file).read_text(encoding="utf-8"))
|
||||
if not isinstance(payload, dict):
|
||||
raise RuntimeError("模型查询负载格式错误")
|
||||
models = _load_llm_models_inner(payload)
|
||||
Path(args.output_json_file).write_text(
|
||||
json.dumps(models, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8",
|
||||
)
|
||||
return 0
|
||||
except subprocess.CalledProcessError as exc:
|
||||
print(f"命令执行失败,退出码:{exc.returncode}", file=sys.stderr)
|
||||
return exc.returncode
|
||||
|
||||
Reference in New Issue
Block a user