mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-23 07:26:46 +00:00
feat: add extensible agent audio capabilities
This commit is contained in:
@@ -1,6 +1,14 @@
|
||||
"""Agent 内部使用的 LLM 适配层。"""
|
||||
|
||||
from app.agent.llm.helper import LLMHelper, LLMTestError, LLMTestTimeout
|
||||
from app.agent.llm.capability import (
|
||||
AgentCapabilityManager,
|
||||
AgentCapabilityProvider,
|
||||
AudioCapabilityProvider,
|
||||
MiMoAudioProvider,
|
||||
OpenAIChatAudioProvider,
|
||||
OpenAIAudioProvider,
|
||||
)
|
||||
from app.agent.llm.provider import (
|
||||
LLMProviderAuthError,
|
||||
LLMProviderError,
|
||||
@@ -10,10 +18,16 @@ from app.agent.llm.provider import (
|
||||
|
||||
__all__ = [
|
||||
"LLMHelper",
|
||||
"AgentCapabilityManager",
|
||||
"AgentCapabilityProvider",
|
||||
"AudioCapabilityProvider",
|
||||
"LLMProviderAuthError",
|
||||
"LLMProviderError",
|
||||
"LLMProviderManager",
|
||||
"LLMTestError",
|
||||
"LLMTestTimeout",
|
||||
"MiMoAudioProvider",
|
||||
"OpenAIChatAudioProvider",
|
||||
"OpenAIAudioProvider",
|
||||
"render_auth_result_html",
|
||||
]
|
||||
|
||||
528
app/agent/llm/capability.py
Normal file
528
app/agent/llm/capability.py
Normal file
@@ -0,0 +1,528 @@
|
||||
"""Agent 多模态能力 provider 与调度入口。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import mimetypes
|
||||
import shutil
|
||||
import subprocess
|
||||
from abc import ABC
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class AgentCapabilityProvider(ABC):
|
||||
"""Agent 能力 provider 基类,后续图片等能力可继续扩展到这里。"""
|
||||
|
||||
name: str
|
||||
|
||||
|
||||
class AudioCapabilityProvider(AgentCapabilityProvider):
|
||||
"""音频输入/输出能力 provider。"""
|
||||
|
||||
MAX_TRANSCRIBE_BYTES = 10 * 1024 * 1024
|
||||
|
||||
def is_available_for_audio_input(self) -> bool:
|
||||
"""是否可用于音频输入转写。"""
|
||||
return False
|
||||
|
||||
def is_available_for_audio_output(self) -> bool:
|
||||
"""是否可用于语音合成输出。"""
|
||||
return False
|
||||
|
||||
def transcribe_audio(self, content: bytes, filename: str = "input.ogg") -> Optional[str]:
|
||||
"""将音频字节转成文字。"""
|
||||
raise NotImplementedError
|
||||
|
||||
def synthesize_speech(self, text: str) -> Optional[Path]:
|
||||
"""将文字合成为可发送的音频文件。"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class OpenAIAudioProvider(AudioCapabilityProvider):
|
||||
"""OpenAI / OpenAI-compatible 音频 provider。"""
|
||||
|
||||
name = "openai"
|
||||
|
||||
@staticmethod
|
||||
def _build_client(api_key: str, base_url: Optional[str]):
|
||||
from openai import OpenAI
|
||||
|
||||
return OpenAI(api_key=api_key, base_url=base_url, max_retries=3)
|
||||
|
||||
@staticmethod
|
||||
def _input_credentials() -> tuple[Optional[str], Optional[str]]:
|
||||
return settings.AUDIO_INPUT_API_KEY, settings.AUDIO_INPUT_BASE_URL
|
||||
|
||||
@staticmethod
|
||||
def _output_credentials() -> tuple[Optional[str], Optional[str]]:
|
||||
return settings.AUDIO_OUTPUT_API_KEY, settings.AUDIO_OUTPUT_BASE_URL
|
||||
|
||||
def is_available_for_audio_input(self) -> bool:
|
||||
api_key, _ = self._input_credentials()
|
||||
return bool(api_key)
|
||||
|
||||
def is_available_for_audio_output(self) -> bool:
|
||||
api_key, _ = self._output_credentials()
|
||||
return bool(api_key)
|
||||
|
||||
def transcribe_audio(self, content: bytes, filename: str = "input.ogg") -> Optional[str]:
|
||||
if not content:
|
||||
return None
|
||||
if len(content) > self.MAX_TRANSCRIBE_BYTES:
|
||||
raise ValueError("语音文件超过 10MB,无法识别")
|
||||
|
||||
try:
|
||||
api_key, base_url = self._input_credentials()
|
||||
if not api_key:
|
||||
raise ValueError("音频输入 provider 未配置 API Key")
|
||||
client = self._build_client(api_key=api_key, base_url=base_url)
|
||||
audio_file = BytesIO(content)
|
||||
audio_file.name = filename
|
||||
response = client.audio.transcriptions.create(
|
||||
model=settings.AUDIO_INPUT_MODEL,
|
||||
file=audio_file,
|
||||
language=settings.AUDIO_INPUT_LANGUAGE or "zh",
|
||||
response_format="verbose_json",
|
||||
)
|
||||
text = getattr(response, "text", None)
|
||||
return text.strip() if text else None
|
||||
except Exception as err:
|
||||
logger.error(f"音频输入转写失败: provider={self.name}, error={err}")
|
||||
return None
|
||||
|
||||
def synthesize_speech(self, text: str) -> Optional[Path]:
|
||||
if not text:
|
||||
return None
|
||||
|
||||
try:
|
||||
api_key, base_url = self._output_credentials()
|
||||
if not api_key:
|
||||
raise ValueError("音频输出 provider 未配置 API Key")
|
||||
client = self._build_client(api_key=api_key, base_url=base_url)
|
||||
voice_dir = settings.TEMP_PATH / "voice"
|
||||
voice_dir.mkdir(parents=True, exist_ok=True)
|
||||
output_path = voice_dir / f"{uuid4().hex}.opus"
|
||||
response = client.audio.speech.create(
|
||||
model=settings.AUDIO_OUTPUT_MODEL,
|
||||
voice=settings.AUDIO_OUTPUT_VOICE,
|
||||
input=text,
|
||||
response_format="opus",
|
||||
)
|
||||
response.write_to_file(output_path)
|
||||
return output_path
|
||||
except Exception as err:
|
||||
logger.error(f"音频输出合成失败: provider={self.name}, error={err}")
|
||||
return None
|
||||
|
||||
|
||||
class OpenAIChatAudioProvider(AudioCapabilityProvider):
|
||||
"""通过 OpenAI Chat Completions 兼容接口传入/返回音频的 provider。"""
|
||||
|
||||
name = "openai_chat_audio"
|
||||
DISPLAY_NAME = "OpenAI Chat Audio"
|
||||
DEFAULT_BASE_URL: Optional[str] = None
|
||||
DEFAULT_STT_MODEL: Optional[str] = None
|
||||
DEFAULT_TTS_MODEL: Optional[str] = None
|
||||
DEFAULT_VOICE = "alloy"
|
||||
AUDIO_RESPONSE_FORMAT = "wav"
|
||||
AUDIO_INPUT_DATA_URL = False
|
||||
INCLUDE_AUDIO_MODALITIES = True
|
||||
TTS_MESSAGE_ROLE = "user"
|
||||
SUPPORTED_STT_MODELS: Optional[frozenset[str]] = None
|
||||
SUPPORTED_TTS_MODELS: Optional[frozenset[str]] = None
|
||||
UNSUPPORTED_TTS_MODELS = frozenset()
|
||||
SUPPORTED_AUDIO_MIME_TYPES = {
|
||||
".flac": "audio/flac",
|
||||
".m4a": "audio/mp4",
|
||||
".mp3": "audio/mpeg",
|
||||
".ogg": "audio/ogg",
|
||||
".opus": "audio/ogg",
|
||||
".wav": "audio/wav",
|
||||
}
|
||||
|
||||
def _build_client(self, api_key: str, base_url: Optional[str]):
|
||||
from openai import OpenAI
|
||||
|
||||
return OpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url or self.DEFAULT_BASE_URL,
|
||||
max_retries=3,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _input_credentials() -> tuple[Optional[str], Optional[str]]:
|
||||
return settings.AUDIO_INPUT_API_KEY, settings.AUDIO_INPUT_BASE_URL
|
||||
|
||||
@staticmethod
|
||||
def _output_credentials() -> tuple[Optional[str], Optional[str]]:
|
||||
return settings.AUDIO_OUTPUT_API_KEY, settings.AUDIO_OUTPUT_BASE_URL
|
||||
|
||||
def _normalize_stt_model(self) -> str:
|
||||
return self._normalize_model(
|
||||
model=settings.AUDIO_INPUT_MODEL,
|
||||
supported_models=self.SUPPORTED_STT_MODELS,
|
||||
default_model=self.DEFAULT_STT_MODEL,
|
||||
)
|
||||
|
||||
def _normalize_tts_model(self) -> str:
|
||||
return self._normalize_model(
|
||||
model=settings.AUDIO_OUTPUT_MODEL,
|
||||
supported_models=self.SUPPORTED_TTS_MODELS,
|
||||
default_model=self.DEFAULT_TTS_MODEL,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_model(
|
||||
model: Optional[str],
|
||||
supported_models: Optional[frozenset[str]],
|
||||
default_model: Optional[str],
|
||||
) -> str:
|
||||
model = (model or "").strip()
|
||||
if not model:
|
||||
return default_model or ""
|
||||
if supported_models is None:
|
||||
return model
|
||||
model_key = model.lower()
|
||||
if model_key in supported_models:
|
||||
return model_key
|
||||
return default_model or model
|
||||
|
||||
def _is_supported_tts_model(self) -> bool:
|
||||
model = self._normalize_tts_model()
|
||||
if not model:
|
||||
return False
|
||||
model_key = model.lower()
|
||||
if model_key in self.UNSUPPORTED_TTS_MODELS:
|
||||
return False
|
||||
return self.SUPPORTED_TTS_MODELS is None or model_key in self.SUPPORTED_TTS_MODELS
|
||||
|
||||
@classmethod
|
||||
def _guess_audio_mime_type(cls, filename: str) -> str:
|
||||
suffix = Path(filename or "").suffix.lower()
|
||||
if suffix in cls.SUPPORTED_AUDIO_MIME_TYPES:
|
||||
return cls.SUPPORTED_AUDIO_MIME_TYPES[suffix]
|
||||
mime_type, _ = mimetypes.guess_type(filename or "")
|
||||
return mime_type or "audio/ogg"
|
||||
|
||||
@staticmethod
|
||||
def _guess_audio_format(filename: str) -> str:
|
||||
suffix = Path(filename or "").suffix.lower().lstrip(".")
|
||||
if suffix == "opus":
|
||||
return "ogg"
|
||||
return suffix or "ogg"
|
||||
|
||||
def _build_audio_input_payload(self, content: bytes, filename: str) -> dict:
|
||||
"""按不同 Chat Audio 兼容形态构造 input_audio 内容。"""
|
||||
audio_data = base64.b64encode(content).decode("utf-8")
|
||||
if self.AUDIO_INPUT_DATA_URL:
|
||||
mime_type = self._guess_audio_mime_type(filename)
|
||||
return {"data": f"data:{mime_type};base64,{audio_data}"}
|
||||
return {
|
||||
"data": audio_data,
|
||||
"format": self._guess_audio_format(filename),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _extract_message_text(message) -> Optional[str]:
|
||||
"""兼容音频理解响应可能放在 content 或 reasoning_content 的情况。"""
|
||||
content = getattr(message, "content", None)
|
||||
if isinstance(content, str) and content.strip():
|
||||
return content.strip()
|
||||
|
||||
reasoning_content = getattr(message, "reasoning_content", None)
|
||||
if isinstance(reasoning_content, str) and reasoning_content.strip():
|
||||
return reasoning_content.strip()
|
||||
|
||||
extra = getattr(message, "model_extra", None)
|
||||
if isinstance(extra, dict):
|
||||
for key in ("content", "reasoning_content"):
|
||||
value = extra.get(key)
|
||||
if isinstance(value, str) and value.strip():
|
||||
return value.strip()
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_audio_data(message) -> Optional[str]:
|
||||
audio = getattr(message, "audio", None)
|
||||
if isinstance(audio, dict):
|
||||
return audio.get("data")
|
||||
if audio is not None:
|
||||
return getattr(audio, "data", None)
|
||||
|
||||
extra = getattr(message, "model_extra", None)
|
||||
if isinstance(extra, dict) and isinstance(extra.get("audio"), dict):
|
||||
return extra["audio"].get("data")
|
||||
return None
|
||||
|
||||
def _convert_wav_to_opus(self, wav_path: Path) -> Optional[Path]:
|
||||
"""将 Chat Audio 返回的 WAV 转成 OGG/Opus,便于各通知渠道发送语音。"""
|
||||
if not shutil.which("ffmpeg"):
|
||||
return None
|
||||
|
||||
output_path = wav_path.with_suffix(".opus")
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-i",
|
||||
str(wav_path),
|
||||
"-ar",
|
||||
"48000",
|
||||
"-ac",
|
||||
"1",
|
||||
"-c:a",
|
||||
"libopus",
|
||||
str(output_path),
|
||||
]
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, check=False)
|
||||
if result.returncode != 0 or not output_path.exists():
|
||||
logger.warning(
|
||||
"%s TTS 音频转 Opus 失败,将使用 WAV 原文件: returncode=%s, stderr=%s",
|
||||
self.DISPLAY_NAME,
|
||||
result.returncode,
|
||||
(result.stderr or "").strip()[:500],
|
||||
)
|
||||
return None
|
||||
return output_path
|
||||
|
||||
def is_available_for_audio_input(self) -> bool:
|
||||
api_key, _ = self._input_credentials()
|
||||
return bool(api_key)
|
||||
|
||||
def is_available_for_audio_output(self) -> bool:
|
||||
api_key, _ = self._output_credentials()
|
||||
return bool(api_key) and self._is_supported_tts_model()
|
||||
|
||||
def transcribe_audio(self, content: bytes, filename: str = "input.ogg") -> Optional[str]:
|
||||
if not content:
|
||||
return None
|
||||
if len(content) > self.MAX_TRANSCRIBE_BYTES:
|
||||
raise ValueError("语音文件超过 10MB,无法识别")
|
||||
|
||||
try:
|
||||
api_key, base_url = self._input_credentials()
|
||||
if not api_key:
|
||||
raise ValueError("音频输入 provider 未配置 API Key")
|
||||
client = self._build_client(api_key=api_key, base_url=base_url)
|
||||
language = (settings.AUDIO_INPUT_LANGUAGE or "").strip()
|
||||
prompt = "请将这段音频完整转写为文字,只输出转写结果,不要添加解释。"
|
||||
if language:
|
||||
prompt += f"音频主要语言是 {language}。"
|
||||
|
||||
completion = client.chat.completions.create(
|
||||
model=self._normalize_stt_model(),
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "input_audio",
|
||||
"input_audio": self._build_audio_input_payload(
|
||||
content=content, filename=filename
|
||||
),
|
||||
},
|
||||
{"type": "text", "text": prompt},
|
||||
],
|
||||
}
|
||||
],
|
||||
max_completion_tokens=2048,
|
||||
)
|
||||
return self._extract_message_text(completion.choices[0].message)
|
||||
except Exception as err:
|
||||
logger.error(f"音频输入转写失败: provider={self.name}, error={err}")
|
||||
return None
|
||||
|
||||
def synthesize_speech(self, text: str) -> Optional[Path]:
|
||||
if not text:
|
||||
return None
|
||||
if not self._is_supported_tts_model():
|
||||
logger.error(
|
||||
"%s TTS 当前不支持该模型或模型未配置: %s",
|
||||
self.DISPLAY_NAME,
|
||||
settings.AUDIO_OUTPUT_MODEL,
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
api_key, base_url = self._output_credentials()
|
||||
if not api_key:
|
||||
raise ValueError("音频输出 provider 未配置 API Key")
|
||||
client = self._build_client(api_key=api_key, base_url=base_url)
|
||||
voice_dir = settings.TEMP_PATH / "voice"
|
||||
voice_dir.mkdir(parents=True, exist_ok=True)
|
||||
wav_path = voice_dir / f"{uuid4().hex}.wav"
|
||||
request = {
|
||||
"model": self._normalize_tts_model(),
|
||||
"messages": [
|
||||
{
|
||||
"role": self.TTS_MESSAGE_ROLE,
|
||||
"content": text,
|
||||
}
|
||||
],
|
||||
"audio": {
|
||||
"format": self.AUDIO_RESPONSE_FORMAT,
|
||||
"voice": settings.AUDIO_OUTPUT_VOICE or self.DEFAULT_VOICE,
|
||||
},
|
||||
}
|
||||
if self.INCLUDE_AUDIO_MODALITIES:
|
||||
request["modalities"] = ["text", "audio"]
|
||||
completion = client.chat.completions.create(**request)
|
||||
audio_data = self._extract_audio_data(completion.choices[0].message)
|
||||
if not audio_data:
|
||||
raise ValueError(f"{self.DISPLAY_NAME} TTS 响应中没有音频数据")
|
||||
|
||||
wav_path.write_bytes(base64.b64decode(audio_data))
|
||||
return self._convert_wav_to_opus(wav_path) or wav_path
|
||||
except Exception as err:
|
||||
logger.error(f"音频输出合成失败: provider={self.name}, error={err}")
|
||||
return None
|
||||
|
||||
|
||||
class MiMoAudioProvider(OpenAIChatAudioProvider):
|
||||
"""Xiaomi MiMo Chat Audio 预设,仅接入普通 STT/TTS 能力。"""
|
||||
|
||||
name = "mimo"
|
||||
DISPLAY_NAME = "Xiaomi MiMo"
|
||||
DEFAULT_BASE_URL = "https://api.xiaomimimo.com/v1"
|
||||
DEFAULT_STT_MODEL = "mimo-v2.5"
|
||||
DEFAULT_TTS_MODEL = "mimo-v2.5-tts"
|
||||
DEFAULT_VOICE = "mimo_default"
|
||||
AUDIO_INPUT_DATA_URL = True
|
||||
INCLUDE_AUDIO_MODALITIES = False
|
||||
TTS_MESSAGE_ROLE = "assistant"
|
||||
SUPPORTED_STT_MODELS = frozenset({"mimo-v2.5", "mimo-v2-omni"})
|
||||
SUPPORTED_TTS_MODELS = frozenset({DEFAULT_TTS_MODEL})
|
||||
UNSUPPORTED_TTS_MODELS = frozenset(
|
||||
{
|
||||
"mimo-v2.5-tts-voiceclone",
|
||||
"mimo-v2.5-tts-voicedesign",
|
||||
}
|
||||
)
|
||||
|
||||
def _normalize_tts_model(self) -> str:
|
||||
model = (settings.AUDIO_OUTPUT_MODEL or "").strip().lower()
|
||||
if not model or not model.startswith("mimo-"):
|
||||
return self.DEFAULT_TTS_MODEL
|
||||
return model
|
||||
|
||||
|
||||
class AgentCapabilityManager:
|
||||
"""Agent 能力统一入口。"""
|
||||
|
||||
REPLY_MODE_NATIVE = "native_voice"
|
||||
REPLY_MODE_TEXT = "text"
|
||||
_audio_providers: Dict[str, AudioCapabilityProvider] = {
|
||||
OpenAIAudioProvider.name: OpenAIAudioProvider(),
|
||||
OpenAIChatAudioProvider.name: OpenAIChatAudioProvider(),
|
||||
MiMoAudioProvider.name: MiMoAudioProvider(),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def register_audio_provider(cls, provider: AudioCapabilityProvider) -> None:
|
||||
"""注册新的音频 provider。"""
|
||||
cls._audio_providers[provider.name.lower()] = provider
|
||||
|
||||
@classmethod
|
||||
def get_registered_audio_providers(cls) -> list[str]:
|
||||
"""返回已注册的音频 provider 名称。"""
|
||||
return sorted(cls._audio_providers.keys())
|
||||
|
||||
@staticmethod
|
||||
def _normalize_provider_name(provider: Optional[str]) -> str:
|
||||
return (provider or "openai").strip().lower()
|
||||
|
||||
@classmethod
|
||||
def get_audio_provider(cls, mode: str) -> Optional[AudioCapabilityProvider]:
|
||||
provider_name = cls._normalize_provider_name(
|
||||
settings.AUDIO_INPUT_PROVIDER
|
||||
if (mode or "").lower() == "input"
|
||||
else settings.AUDIO_OUTPUT_PROVIDER
|
||||
)
|
||||
provider = cls._audio_providers.get(provider_name)
|
||||
if provider:
|
||||
return provider
|
||||
logger.warning("未注册音频 provider: mode=%s, provider=%s", mode, provider_name)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def supports_image_input() -> bool:
|
||||
"""当前 Agent 是否启用图片输入能力。"""
|
||||
return bool(settings.LLM_SUPPORT_IMAGE_INPUT)
|
||||
|
||||
@staticmethod
|
||||
def supports_audio_input() -> bool:
|
||||
"""当前 Agent 是否启用音频输入能力。"""
|
||||
return bool(settings.LLM_SUPPORT_AUDIO_INPUT)
|
||||
|
||||
@staticmethod
|
||||
def supports_audio_output() -> bool:
|
||||
"""当前 Agent 是否启用音频输出能力。"""
|
||||
return bool(settings.LLM_SUPPORT_AUDIO_OUTPUT)
|
||||
|
||||
@classmethod
|
||||
def is_audio_input_available(cls) -> bool:
|
||||
if not cls.supports_audio_input():
|
||||
return False
|
||||
provider = cls.get_audio_provider("input")
|
||||
return bool(provider and provider.is_available_for_audio_input())
|
||||
|
||||
@classmethod
|
||||
def is_audio_output_available(cls) -> bool:
|
||||
if not cls.supports_audio_output():
|
||||
return False
|
||||
provider = cls.get_audio_provider("output")
|
||||
return bool(provider and provider.is_available_for_audio_output())
|
||||
|
||||
@classmethod
|
||||
def transcribe_audio(cls, content: bytes, filename: str = "input.ogg") -> Optional[str]:
|
||||
provider = cls.get_audio_provider("input")
|
||||
if not provider or not cls.is_audio_input_available():
|
||||
return None
|
||||
return provider.transcribe_audio(content=content, filename=filename)
|
||||
|
||||
@classmethod
|
||||
def synthesize_speech(cls, text: str) -> Optional[Path]:
|
||||
provider = cls.get_audio_provider("output")
|
||||
if not provider or not cls.is_audio_output_available():
|
||||
return None
|
||||
return provider.synthesize_speech(text=text)
|
||||
|
||||
@classmethod
|
||||
def resolve_reply_mode(cls, channel: Optional[str], source: Optional[str]) -> str:
|
||||
"""仅在支持原生语音回复的渠道上发送音频,其余渠道回退文字。"""
|
||||
if cls.supports_native_voice_reply(channel=channel, source=source):
|
||||
return cls.REPLY_MODE_NATIVE
|
||||
return cls.REPLY_MODE_TEXT
|
||||
|
||||
@classmethod
|
||||
def supports_native_voice_reply(
|
||||
cls, channel: Optional[str], source: Optional[str]
|
||||
) -> bool:
|
||||
"""判断当前渠道是否支持原生语音消息发送。"""
|
||||
if not channel:
|
||||
return False
|
||||
|
||||
from app.helper.service import ServiceConfigHelper
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
try:
|
||||
channel_enum = MessageChannel(channel)
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
|
||||
if channel_enum == MessageChannel.Telegram:
|
||||
return True
|
||||
if channel_enum != MessageChannel.Wechat:
|
||||
return False
|
||||
|
||||
# 企业微信 bot 模式不支持发送语音,只有应用模式可用。
|
||||
for config in ServiceConfigHelper.get_notification_configs():
|
||||
if config.name != source:
|
||||
continue
|
||||
return (config.config or {}).get("WECHAT_MODE", "app") != "bot"
|
||||
return False
|
||||
@@ -886,7 +886,6 @@ class LLMHelper:
|
||||
{"id": model_id, "name": model_id}
|
||||
for model_id in await self._get_google_models(api_key or "")
|
||||
]
|
||||
model_list_base_url = base_url
|
||||
try:
|
||||
from app.agent.llm.provider import LLMProviderManager
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ from typing import Any, Dict, Optional
|
||||
|
||||
import yaml
|
||||
|
||||
from app.agent.llm.capability import AgentCapabilityManager
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
from app.schemas import (
|
||||
@@ -327,10 +328,12 @@ class PromptManager:
|
||||
|
||||
@staticmethod
|
||||
def _generate_voice_reply_instructions() -> str:
|
||||
if not AgentCapabilityManager.supports_audio_output():
|
||||
return "Audio output is disabled; do not call `send_voice_message`."
|
||||
return (
|
||||
"- Voice replies: Use normal text replies by default. "
|
||||
"Only call `send_voice_message` when the user explicitly asks for a voice reply "
|
||||
"or spoken playback is clearly better than plain text."
|
||||
"Use normal text replies by default. Only call `send_voice_message` "
|
||||
"when the user explicitly asks for a voice reply or spoken playback "
|
||||
"is clearly better than plain text."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -77,6 +77,7 @@ from app.agent.tools.impl.query_custom_identifiers import QueryCustomIdentifiers
|
||||
from app.agent.tools.impl.update_custom_identifiers import UpdateCustomIdentifiersTool
|
||||
from app.agent.tools.impl.query_system_settings import QuerySystemSettingsTool
|
||||
from app.agent.tools.impl.update_system_settings import UpdateSystemSettingsTool
|
||||
from app.agent.llm.capability import AgentCapabilityManager
|
||||
from app.core.plugin import PluginManager
|
||||
from app.log import logger
|
||||
from app.schemas.message import ChannelCapabilityManager
|
||||
@@ -225,12 +226,9 @@ class MoviePilotToolFactory:
|
||||
]
|
||||
if MoviePilotToolFactory._should_enable_choice_tool(channel):
|
||||
tool_definitions.append(AskUserChoiceTool)
|
||||
tool_definitions.extend(
|
||||
[
|
||||
SendLocalFileTool,
|
||||
SendVoiceMessageTool,
|
||||
]
|
||||
)
|
||||
tool_definitions.append(SendLocalFileTool)
|
||||
if AgentCapabilityManager.supports_audio_output():
|
||||
tool_definitions.append(SendVoiceMessageTool)
|
||||
# 创建内置工具
|
||||
for ToolClass in tool_definitions:
|
||||
tool = ToolClass(session_id=session_id, user_id=user_id)
|
||||
|
||||
@@ -5,9 +5,9 @@ from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.llm.capability import AgentCapabilityManager
|
||||
from app.agent.tools.base import MoviePilotTool, ToolChain
|
||||
from app.core.config import settings
|
||||
from app.helper.voice import VoiceHelper
|
||||
from app.log import logger
|
||||
from app.schemas import Notification, NotificationType
|
||||
|
||||
@@ -50,22 +50,24 @@ class SendVoiceMessageTool(MoviePilotTool):
|
||||
voice_path = None
|
||||
used_voice = False
|
||||
channel = self._channel or ""
|
||||
reply_mode = VoiceHelper.resolve_reply_mode(
|
||||
reply_mode = AgentCapabilityManager.resolve_reply_mode(
|
||||
channel=channel,
|
||||
source=self._source,
|
||||
)
|
||||
fallback_reason = "当前渠道不支持语音回复"
|
||||
if not VoiceHelper.is_enabled():
|
||||
fallback_reason = "当前未启用音频输入输出"
|
||||
if not AgentCapabilityManager.supports_audio_output():
|
||||
fallback_reason = "当前未启用音频输出"
|
||||
if (
|
||||
reply_mode == VoiceHelper.REPLY_MODE_NATIVE
|
||||
and VoiceHelper.is_available("tts")
|
||||
reply_mode == AgentCapabilityManager.REPLY_MODE_NATIVE
|
||||
and AgentCapabilityManager.is_audio_output_available()
|
||||
):
|
||||
voice_file = await asyncio.to_thread(VoiceHelper.synthesize_speech, message)
|
||||
voice_file = await asyncio.to_thread(
|
||||
AgentCapabilityManager.synthesize_speech, message
|
||||
)
|
||||
if voice_file:
|
||||
voice_path = str(voice_file)
|
||||
used_voice = True
|
||||
elif reply_mode == VoiceHelper.REPLY_MODE_NATIVE:
|
||||
elif reply_mode == AgentCapabilityManager.REPLY_MODE_NATIVE:
|
||||
fallback_reason = "当前未配置可用的语音合成能力"
|
||||
|
||||
logger.info(
|
||||
@@ -87,7 +89,7 @@ class SendVoiceMessageTool(MoviePilotTool):
|
||||
voice_path=voice_path,
|
||||
voice_caption=(
|
||||
message
|
||||
if voice_path and settings.AI_VOICE_REPLY_WITH_TEXT
|
||||
if voice_path and settings.AUDIO_OUTPUT_INCLUDE_TEXT
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
@@ -476,7 +476,8 @@ async def get_user_global_setting(_: User = Depends(get_current_active_user_asyn
|
||||
info = settings.model_dump(
|
||||
include={
|
||||
"AI_AGENT_ENABLE",
|
||||
"LLM_SUPPORT_AUDIO_INPUT_OUTPUT",
|
||||
"LLM_SUPPORT_AUDIO_INPUT",
|
||||
"LLM_SUPPORT_AUDIO_OUTPUT",
|
||||
"RECOGNIZE_SOURCE",
|
||||
"SEARCH_SOURCE",
|
||||
"AI_RECOMMEND_ENABLED",
|
||||
@@ -486,7 +487,8 @@ async def get_user_global_setting(_: User = Depends(get_current_active_user_asyn
|
||||
# 智能助手总开关未开启,智能推荐状态强制返回False
|
||||
if not settings.AI_AGENT_ENABLE:
|
||||
info["AI_RECOMMEND_ENABLED"] = False
|
||||
info["LLM_SUPPORT_AUDIO_INPUT_OUTPUT"] = False
|
||||
info["LLM_SUPPORT_AUDIO_INPUT"] = False
|
||||
info["LLM_SUPPORT_AUDIO_OUTPUT"] = False
|
||||
|
||||
# 追加用户唯一ID和订阅分享管理权限
|
||||
share_admin = SubscribeHelper().is_admin_user()
|
||||
|
||||
@@ -12,7 +12,7 @@ from typing import Any, Optional, Dict, Union, List, Tuple
|
||||
from urllib.parse import unquote, urlparse
|
||||
|
||||
from app.agent import ReplyMode, agent_manager, prompt_manager
|
||||
from app.agent.llm import LLMHelper
|
||||
from app.agent.llm import AgentCapabilityManager, LLMHelper
|
||||
from app.chain import ChainBase
|
||||
from app.chain.download import DownloadChain
|
||||
from app.chain.media import MediaChain
|
||||
@@ -29,7 +29,6 @@ from app.db.transferhistory_oper import TransferHistoryOper
|
||||
from app.db.user_oper import UserOper
|
||||
from app.helper.interaction import agent_interaction_manager, media_interaction_manager, PendingMediaInteraction
|
||||
from app.helper.torrent import TorrentHelper
|
||||
from app.helper.voice import VoiceHelper
|
||||
from app.log import logger
|
||||
from app.schemas import Notification, CommingMessage, NotExistMediaInfo
|
||||
from app.schemas.message import ChannelCapabilityManager
|
||||
@@ -1196,8 +1195,8 @@ class MessageChain(ChainBase):
|
||||
"""
|
||||
if not audio_refs:
|
||||
return None
|
||||
if not VoiceHelper.is_available("stt"):
|
||||
logger.warning("语音能力未配置,跳过语音识别")
|
||||
if not AgentCapabilityManager.is_audio_input_available():
|
||||
logger.warning("音频输入能力未配置或未启用,跳过语音识别")
|
||||
return None
|
||||
|
||||
transcripts = []
|
||||
@@ -1303,7 +1302,7 @@ class MessageChain(ChainBase):
|
||||
)
|
||||
continue
|
||||
|
||||
transcript = VoiceHelper.transcribe_bytes(
|
||||
transcript = AgentCapabilityManager.transcribe_audio(
|
||||
content=content, filename=filename
|
||||
)
|
||||
if transcript:
|
||||
|
||||
@@ -541,8 +541,10 @@ class ConfigModel(BaseModel):
|
||||
LLM_THINKING_LEVEL: Optional[str] = "off"
|
||||
# LLM是否支持图片输入,开启后消息图片会按多模态输入发送给模型
|
||||
LLM_SUPPORT_IMAGE_INPUT: bool = True
|
||||
# LLM是否支持音频输入输出,开启后才会启用语音转写与语音回复
|
||||
LLM_SUPPORT_AUDIO_INPUT_OUTPUT: bool = False
|
||||
# 是否启用音频输入,开启后用户语音会先转写为文本再进入 Agent
|
||||
LLM_SUPPORT_AUDIO_INPUT: bool = False
|
||||
# 是否启用音频输出,开启后 Agent 可在支持渠道发送语音回复
|
||||
LLM_SUPPORT_AUDIO_OUTPUT: bool = False
|
||||
# LLM API密钥
|
||||
LLM_API_KEY: Optional[str] = None
|
||||
# LLM基础URL(用于自定义API端点)
|
||||
@@ -589,22 +591,28 @@ class ConfigModel(BaseModel):
|
||||
# AI智能体自动重试整理失败记录开关
|
||||
AI_AGENT_RETRY_TRANSFER: bool = False
|
||||
|
||||
# 语音能力提供商(当前仅支持 openai/openai-compatible)
|
||||
AI_VOICE_PROVIDER: str = "openai"
|
||||
# 语音能力共享 API 密钥,未设置且 LLM_PROVIDER=openai 时回退使用 LLM_API_KEY
|
||||
AI_VOICE_API_KEY: Optional[str] = None
|
||||
# 语音能力共享基础URL,未设置且 LLM_PROVIDER=openai 时回退使用 LLM_BASE_URL
|
||||
AI_VOICE_BASE_URL: Optional[str] = None
|
||||
# 语音转文字模型
|
||||
AI_VOICE_STT_MODEL: str = "gpt-4o-mini-transcribe"
|
||||
# 文字转语音模型
|
||||
AI_VOICE_TTS_MODEL: str = "gpt-4o-mini-tts"
|
||||
# TTS 发音人
|
||||
AI_VOICE_TTS_VOICE: str = "alloy"
|
||||
# 语音识别语言
|
||||
AI_VOICE_LANGUAGE: str = "zh"
|
||||
# 音频输入提供商:openai/openai_chat_audio/mimo
|
||||
AUDIO_INPUT_PROVIDER: str = "openai"
|
||||
# 音频输入 API 密钥
|
||||
AUDIO_INPUT_API_KEY: Optional[str] = None
|
||||
# 音频输入基础URL
|
||||
AUDIO_INPUT_BASE_URL: Optional[str] = None
|
||||
# 音频输入模型
|
||||
AUDIO_INPUT_MODEL: str = "gpt-4o-mini-transcribe"
|
||||
# 音频输入识别语言
|
||||
AUDIO_INPUT_LANGUAGE: str = "zh"
|
||||
# 音频输出提供商:openai/openai_chat_audio/mimo
|
||||
AUDIO_OUTPUT_PROVIDER: str = "openai"
|
||||
# 音频输出 API 密钥
|
||||
AUDIO_OUTPUT_API_KEY: Optional[str] = None
|
||||
# 音频输出基础URL
|
||||
AUDIO_OUTPUT_BASE_URL: Optional[str] = None
|
||||
# 音频输出模型
|
||||
AUDIO_OUTPUT_MODEL: str = "gpt-4o-mini-tts"
|
||||
# 音频输出音色/发音人
|
||||
AUDIO_OUTPUT_VOICE: str = "alloy"
|
||||
# 回复语音时是否同时附带文字说明
|
||||
AI_VOICE_REPLY_WITH_TEXT: bool = False
|
||||
AUDIO_OUTPUT_INCLUDE_TEXT: bool = False
|
||||
|
||||
|
||||
class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
@@ -824,7 +832,9 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
return False, f"配置项 '{key}' 不存在"
|
||||
|
||||
try:
|
||||
field = Settings.model_fields[key]
|
||||
field = Settings.model_fields.get(key)
|
||||
if not field:
|
||||
return False, f"配置项 '{key}' 不存在"
|
||||
original_value = getattr(self, key)
|
||||
if key == "API_TOKEN":
|
||||
converted_value, needs_update = self.validate_api_token(
|
||||
|
||||
@@ -1,234 +0,0 @@
|
||||
"""语音能力辅助功能。"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class VoiceProvider(ABC):
|
||||
"""语音 provider 抽象层。"""
|
||||
|
||||
MAX_TRANSCRIBE_BYTES = 10 * 1024 * 1024
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""provider 名称。"""
|
||||
|
||||
@abstractmethod
|
||||
def is_available_for_stt(self) -> bool:
|
||||
"""是否可用于语音识别。"""
|
||||
|
||||
@abstractmethod
|
||||
def is_available_for_tts(self) -> bool:
|
||||
"""是否可用于语音合成。"""
|
||||
|
||||
@abstractmethod
|
||||
def transcribe_bytes(self, content: bytes, filename: str = "input.ogg") -> Optional[str]:
|
||||
"""将音频字节转成文字。"""
|
||||
|
||||
@abstractmethod
|
||||
def synthesize_speech(self, text: str) -> Optional[Path]:
|
||||
"""将文字转成语音文件。"""
|
||||
|
||||
|
||||
class OpenAIVoiceProvider(VoiceProvider):
|
||||
"""OpenAI / OpenAI-compatible provider。"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "openai"
|
||||
|
||||
@staticmethod
|
||||
def _resolve_provider_name() -> str:
|
||||
provider = settings.AI_VOICE_PROVIDER or "openai"
|
||||
return provider.strip().lower()
|
||||
|
||||
def _resolve_credentials(self) -> tuple[Optional[str], Optional[str]]:
|
||||
provider = self._resolve_provider_name()
|
||||
api_key = settings.AI_VOICE_API_KEY
|
||||
base_url = settings.AI_VOICE_BASE_URL
|
||||
|
||||
if (
|
||||
not api_key
|
||||
and provider == "openai"
|
||||
and (settings.LLM_PROVIDER or "").strip().lower() == "openai"
|
||||
):
|
||||
api_key = settings.LLM_API_KEY
|
||||
base_url = base_url or settings.LLM_BASE_URL
|
||||
|
||||
return api_key, base_url
|
||||
|
||||
def _get_client(self, mode: str):
|
||||
from openai import OpenAI
|
||||
|
||||
api_key, base_url = self._resolve_credentials()
|
||||
if not api_key:
|
||||
raise ValueError(f"{mode.upper()} provider 未配置 API Key")
|
||||
return OpenAI(api_key=api_key, base_url=base_url, max_retries=3)
|
||||
|
||||
def is_available_for_stt(self) -> bool:
|
||||
api_key, _ = self._resolve_credentials()
|
||||
return bool(api_key)
|
||||
|
||||
def is_available_for_tts(self) -> bool:
|
||||
api_key, _ = self._resolve_credentials()
|
||||
return bool(api_key)
|
||||
|
||||
def transcribe_bytes(self, content: bytes, filename: str = "input.ogg") -> Optional[str]:
|
||||
if not content:
|
||||
return None
|
||||
if len(content) > self.MAX_TRANSCRIBE_BYTES:
|
||||
raise ValueError("语音文件超过 10MB,无法识别")
|
||||
|
||||
try:
|
||||
client = self._get_client("stt")
|
||||
audio_file = BytesIO(content)
|
||||
audio_file.name = filename
|
||||
response = client.audio.transcriptions.create(
|
||||
model=settings.AI_VOICE_STT_MODEL,
|
||||
file=audio_file,
|
||||
language=settings.AI_VOICE_LANGUAGE or "zh",
|
||||
response_format="verbose_json",
|
||||
)
|
||||
text = getattr(response, "text", None)
|
||||
return text.strip() if text else None
|
||||
except Exception as err:
|
||||
logger.error(f"语音转文字失败: provider={self.name}, error={err}")
|
||||
return None
|
||||
|
||||
def synthesize_speech(self, text: str) -> Optional[Path]:
|
||||
if not text:
|
||||
return None
|
||||
|
||||
try:
|
||||
client = self._get_client("tts")
|
||||
voice_dir = settings.TEMP_PATH / "voice"
|
||||
voice_dir.mkdir(parents=True, exist_ok=True)
|
||||
output_path = voice_dir / f"{uuid4().hex}.opus"
|
||||
response = client.audio.speech.create(
|
||||
model=settings.AI_VOICE_TTS_MODEL,
|
||||
voice=settings.AI_VOICE_TTS_VOICE,
|
||||
input=text,
|
||||
response_format="opus",
|
||||
)
|
||||
response.write_to_file(output_path)
|
||||
return output_path
|
||||
except Exception as err:
|
||||
logger.error(f"文字转语音失败: provider={self.name}, error={err}")
|
||||
return None
|
||||
|
||||
|
||||
class VoiceHelper:
|
||||
"""统一语音入口,负责音频能力判断与 STT/TTS provider 路由。"""
|
||||
|
||||
_providers: Dict[str, VoiceProvider] = {
|
||||
"openai": OpenAIVoiceProvider(),
|
||||
}
|
||||
REPLY_MODE_NATIVE = "native_voice"
|
||||
REPLY_MODE_TEXT = "text"
|
||||
|
||||
@classmethod
|
||||
def register_provider(cls, provider: VoiceProvider) -> None:
|
||||
cls._providers[provider.name.lower()] = provider
|
||||
|
||||
@staticmethod
|
||||
def is_enabled() -> bool:
|
||||
"""音频输入输出总开关,以显式配置为准。"""
|
||||
return bool(settings.LLM_SUPPORT_AUDIO_INPUT_OUTPUT)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_provider_name() -> str:
|
||||
"""标准化当前配置的语音 provider 名称。"""
|
||||
provider = settings.AI_VOICE_PROVIDER or "openai"
|
||||
return provider.strip().lower()
|
||||
|
||||
@classmethod
|
||||
def get_provider(cls, mode: str) -> Optional[VoiceProvider]:
|
||||
provider_name = cls._resolve_provider_name()
|
||||
provider = cls._providers.get(provider_name)
|
||||
if provider:
|
||||
return provider
|
||||
logger.warning(f"未注册语音 provider: mode={mode}, provider={provider_name}")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_registered_providers(cls) -> list[str]:
|
||||
return sorted(cls._providers.keys())
|
||||
|
||||
@classmethod
|
||||
def is_available(cls, mode: Optional[str] = None) -> bool:
|
||||
if not cls.is_enabled():
|
||||
return False
|
||||
if mode:
|
||||
provider = cls.get_provider(mode)
|
||||
if not provider:
|
||||
return False
|
||||
return (
|
||||
provider.is_available_for_stt()
|
||||
if mode.lower() == "stt"
|
||||
else provider.is_available_for_tts()
|
||||
)
|
||||
return cls.is_available("stt") or cls.is_available("tts")
|
||||
|
||||
@classmethod
|
||||
def supports_native_voice_reply(
|
||||
cls, channel: Optional[str], source: Optional[str]
|
||||
) -> bool:
|
||||
"""
|
||||
判断当前渠道是否支持原生语音消息发送。
|
||||
"""
|
||||
if not channel:
|
||||
return False
|
||||
|
||||
from app.helper.service import ServiceConfigHelper
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
try:
|
||||
channel_enum = MessageChannel(channel)
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
|
||||
if channel_enum == MessageChannel.Telegram:
|
||||
return True
|
||||
if channel_enum != MessageChannel.Wechat:
|
||||
return False
|
||||
|
||||
# 企业微信 bot 模式不支持发送语音,只有应用模式可用。
|
||||
for config in ServiceConfigHelper.get_notification_configs():
|
||||
if config.name != source:
|
||||
continue
|
||||
return (config.config or {}).get("WECHAT_MODE", "app") != "bot"
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def resolve_reply_mode(cls, channel: Optional[str], source: Optional[str]) -> str:
|
||||
"""
|
||||
仅在支持原生语音回复的渠道上发送音频,其余渠道统一回退文字。
|
||||
"""
|
||||
if cls.supports_native_voice_reply(channel=channel, source=source):
|
||||
return cls.REPLY_MODE_NATIVE
|
||||
return cls.REPLY_MODE_TEXT
|
||||
|
||||
@classmethod
|
||||
def transcribe_bytes(cls, content: bytes, filename: str = "input.ogg") -> Optional[str]:
|
||||
if not cls.is_enabled():
|
||||
return None
|
||||
provider = cls.get_provider("stt")
|
||||
if not provider:
|
||||
return None
|
||||
return provider.transcribe_bytes(content=content, filename=filename)
|
||||
|
||||
@classmethod
|
||||
def synthesize_speech(cls, text: str) -> Optional[Path]:
|
||||
if not cls.is_enabled():
|
||||
return None
|
||||
provider = cls.get_provider("tts")
|
||||
if not provider:
|
||||
return None
|
||||
return provider.synthesize_speech(text=text)
|
||||
@@ -12,10 +12,10 @@ from telebot import apihelper
|
||||
from app.agent.tools.impl.send_message import SendMessageInput
|
||||
from app.agent.tools.impl.send_local_file import SendLocalFileInput
|
||||
from app.agent import MoviePilotAgent, AgentChain
|
||||
from app.agent.llm import AgentCapabilityManager
|
||||
from app.chain.message import MessageChain
|
||||
from app.core.config import settings
|
||||
from app.agent.llm import LLMHelper
|
||||
from app.helper.voice import VoiceHelper
|
||||
from app.modules.discord import DiscordModule
|
||||
from app.modules.qqbot import QQBotModule
|
||||
from app.modules.qqbot.qqbot import QQBot
|
||||
@@ -284,13 +284,15 @@ class AgentImageSupportTest(unittest.TestCase):
|
||||
"feishu://file/om_audio/file_audio/voice.opus",
|
||||
]
|
||||
|
||||
with patch.object(VoiceHelper, "is_available", return_value=True), patch.object(
|
||||
with patch.object(
|
||||
AgentCapabilityManager, "is_audio_input_available", return_value=True
|
||||
), patch.object(
|
||||
chain,
|
||||
"run_module",
|
||||
side_effect=[b"slack", b"discord", b"qq", b"vocechat", b"synology", b"feishu"],
|
||||
) as run_module, patch.object(
|
||||
VoiceHelper,
|
||||
"transcribe_bytes",
|
||||
AgentCapabilityManager,
|
||||
"transcribe_audio",
|
||||
side_effect=[
|
||||
"slack text",
|
||||
"discord text",
|
||||
|
||||
235
tests/test_agent_llm_capability.py
Normal file
235
tests/test_agent_llm_capability.py
Normal file
@@ -0,0 +1,235 @@
|
||||
import sys
|
||||
import unittest
|
||||
import importlib.util
|
||||
from base64 import b64encode
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
sys.modules.setdefault("psutil", Mock())
|
||||
sys.modules.setdefault("pyquery", Mock())
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
module_path = Path(__file__).resolve().parents[1] / "app" / "agent" / "llm" / "capability.py"
|
||||
spec = importlib.util.spec_from_file_location("test_agent_llm_capability_module", module_path)
|
||||
capability_module = importlib.util.module_from_spec(spec)
|
||||
assert spec and spec.loader
|
||||
sys.modules[spec.name] = capability_module
|
||||
spec.loader.exec_module(capability_module)
|
||||
|
||||
AgentCapabilityManager = capability_module.AgentCapabilityManager
|
||||
MiMoAudioProvider = capability_module.MiMoAudioProvider
|
||||
OpenAIChatAudioProvider = capability_module.OpenAIChatAudioProvider
|
||||
OpenAIAudioProvider = capability_module.OpenAIAudioProvider
|
||||
|
||||
|
||||
class AgentCapabilityManagerTest(unittest.TestCase):
|
||||
def test_registered_audio_providers_contains_builtin_providers(self):
|
||||
self.assertIn("openai", AgentCapabilityManager.get_registered_audio_providers())
|
||||
self.assertIn(
|
||||
"openai_chat_audio", AgentCapabilityManager.get_registered_audio_providers()
|
||||
)
|
||||
self.assertIn("mimo", AgentCapabilityManager.get_registered_audio_providers())
|
||||
|
||||
def test_get_audio_provider_uses_separate_input_and_output_settings(self):
|
||||
with patch.object(settings, "AUDIO_INPUT_PROVIDER", "openai"), patch.object(
|
||||
settings, "AUDIO_OUTPUT_PROVIDER", "mimo"
|
||||
):
|
||||
self.assertIsInstance(
|
||||
AgentCapabilityManager.get_audio_provider("input"), OpenAIAudioProvider
|
||||
)
|
||||
self.assertIsInstance(
|
||||
AgentCapabilityManager.get_audio_provider("output"), MiMoAudioProvider
|
||||
)
|
||||
|
||||
def test_chat_audio_provider_keeps_arbitrary_compatible_models(self):
|
||||
provider = OpenAIChatAudioProvider()
|
||||
|
||||
with patch.object(
|
||||
settings, "AUDIO_INPUT_MODEL", "vendor-omni-audio"
|
||||
), patch.object(settings, "AUDIO_OUTPUT_MODEL", "vendor-tts-audio"):
|
||||
self.assertEqual(provider._normalize_stt_model(), "vendor-omni-audio")
|
||||
self.assertEqual(provider._normalize_tts_model(), "vendor-tts-audio")
|
||||
|
||||
def test_chat_audio_provider_uses_openai_audio_payload_shape(self):
|
||||
provider = OpenAIChatAudioProvider()
|
||||
fake_client = Mock()
|
||||
fake_client.chat.completions.create.return_value = SimpleNamespace(
|
||||
choices=[SimpleNamespace(message=SimpleNamespace(content="你好"))]
|
||||
)
|
||||
|
||||
with patch.object(provider, "_build_client", return_value=fake_client), patch.object(
|
||||
settings, "AUDIO_INPUT_MODEL", "gpt-4o-audio-preview"
|
||||
), patch.object(settings, "AUDIO_INPUT_LANGUAGE", "zh"), patch.object(
|
||||
settings, "AUDIO_INPUT_API_KEY", "sk-test"
|
||||
), patch.object(settings, "AUDIO_INPUT_BASE_URL", "https://example.com/v1"):
|
||||
result = provider.transcribe_audio(b"audio-bytes", filename="input.wav")
|
||||
|
||||
self.assertEqual(result, "你好")
|
||||
request = fake_client.chat.completions.create.call_args.kwargs
|
||||
content = request["messages"][0]["content"]
|
||||
self.assertEqual(
|
||||
content[0]["input_audio"],
|
||||
{"data": b64encode(b"audio-bytes").decode("utf-8"), "format": "wav"},
|
||||
)
|
||||
|
||||
def test_chat_audio_provider_requests_audio_modality_for_tts(self):
|
||||
provider = OpenAIChatAudioProvider()
|
||||
fake_client = Mock()
|
||||
audio_data = b64encode(b"wav-bytes").decode("utf-8")
|
||||
fake_client.chat.completions.create.return_value = SimpleNamespace(
|
||||
choices=[SimpleNamespace(message=SimpleNamespace(audio={"data": audio_data}))]
|
||||
)
|
||||
|
||||
with TemporaryDirectory() as temp_dir, patch.object(
|
||||
provider, "_build_client", return_value=fake_client
|
||||
), patch.object(
|
||||
capability_module,
|
||||
"settings",
|
||||
SimpleNamespace(
|
||||
TEMP_PATH=Path(temp_dir),
|
||||
AUDIO_OUTPUT_MODEL="gpt-4o-audio-preview",
|
||||
AUDIO_OUTPUT_VOICE="alloy",
|
||||
AUDIO_OUTPUT_API_KEY="sk-test",
|
||||
AUDIO_OUTPUT_BASE_URL="https://example.com/v1",
|
||||
),
|
||||
), patch.object(provider, "_convert_wav_to_opus", return_value=None):
|
||||
output_path = provider.synthesize_speech("你好")
|
||||
|
||||
self.assertIsNotNone(output_path)
|
||||
request = fake_client.chat.completions.create.call_args.kwargs
|
||||
self.assertEqual(request["messages"][0]["role"], "user")
|
||||
self.assertEqual(request["modalities"], ["text", "audio"])
|
||||
self.assertEqual(request["audio"], {"format": "wav", "voice": "alloy"})
|
||||
|
||||
def test_audio_input_and_output_switches_are_independent(self):
|
||||
provider = Mock()
|
||||
provider.is_available_for_audio_input.return_value = True
|
||||
provider.is_available_for_audio_output.return_value = True
|
||||
|
||||
with patch.object(
|
||||
settings, "LLM_SUPPORT_AUDIO_INPUT", True
|
||||
), patch.object(
|
||||
settings, "LLM_SUPPORT_AUDIO_OUTPUT", False
|
||||
), patch.object(
|
||||
AgentCapabilityManager, "get_audio_provider", return_value=provider
|
||||
):
|
||||
self.assertTrue(AgentCapabilityManager.is_audio_input_available())
|
||||
self.assertFalse(AgentCapabilityManager.is_audio_output_available())
|
||||
|
||||
with patch.object(
|
||||
settings, "LLM_SUPPORT_AUDIO_INPUT", False
|
||||
), patch.object(
|
||||
settings, "LLM_SUPPORT_AUDIO_OUTPUT", True
|
||||
), patch.object(
|
||||
AgentCapabilityManager, "get_audio_provider", return_value=provider
|
||||
):
|
||||
self.assertFalse(AgentCapabilityManager.is_audio_input_available())
|
||||
self.assertTrue(AgentCapabilityManager.is_audio_output_available())
|
||||
|
||||
def test_transcribe_audio_routes_to_input_provider(self):
|
||||
provider = Mock()
|
||||
provider.is_available_for_audio_input.return_value = True
|
||||
provider.transcribe_audio.return_value = "你好"
|
||||
|
||||
with patch.object(settings, "LLM_SUPPORT_AUDIO_INPUT", True), patch.object(
|
||||
AgentCapabilityManager, "get_audio_provider", return_value=provider
|
||||
):
|
||||
result = AgentCapabilityManager.transcribe_audio(b"audio")
|
||||
|
||||
self.assertEqual(result, "你好")
|
||||
provider.transcribe_audio.assert_called_once()
|
||||
|
||||
def test_synthesize_speech_routes_to_output_provider(self):
|
||||
provider = Mock()
|
||||
provider.is_available_for_audio_output.return_value = True
|
||||
provider.synthesize_speech.return_value = Path("/tmp/reply.opus")
|
||||
|
||||
with patch.object(settings, "LLM_SUPPORT_AUDIO_OUTPUT", True), patch.object(
|
||||
AgentCapabilityManager, "get_audio_provider", return_value=provider
|
||||
):
|
||||
result = AgentCapabilityManager.synthesize_speech("你好")
|
||||
|
||||
self.assertEqual(result, Path("/tmp/reply.opus"))
|
||||
provider.synthesize_speech.assert_called_once_with(text="你好")
|
||||
|
||||
def test_mimo_tts_uses_chat_completions_audio_payload(self):
|
||||
provider = MiMoAudioProvider()
|
||||
fake_client = Mock()
|
||||
audio_data = b64encode(b"wav-bytes").decode("utf-8")
|
||||
fake_client.chat.completions.create.return_value = SimpleNamespace(
|
||||
choices=[SimpleNamespace(message=SimpleNamespace(audio={"data": audio_data}))]
|
||||
)
|
||||
|
||||
with TemporaryDirectory() as temp_dir, patch.object(
|
||||
provider, "_build_client", return_value=fake_client
|
||||
), patch.object(
|
||||
capability_module,
|
||||
"settings",
|
||||
SimpleNamespace(
|
||||
TEMP_PATH=Path(temp_dir),
|
||||
AUDIO_OUTPUT_MODEL="mimo-v2.5-tts",
|
||||
AUDIO_OUTPUT_VOICE="冰糖",
|
||||
AUDIO_OUTPUT_API_KEY="sk-test",
|
||||
AUDIO_OUTPUT_BASE_URL="https://api.xiaomimimo.com/v1",
|
||||
),
|
||||
), patch.object(provider, "_convert_wav_to_opus", return_value=None):
|
||||
output_path = provider.synthesize_speech("你好")
|
||||
output_bytes = output_path.read_bytes() if output_path else None
|
||||
|
||||
self.assertIsNotNone(output_path)
|
||||
self.assertEqual(output_bytes, b"wav-bytes")
|
||||
fake_client.chat.completions.create.assert_called_once()
|
||||
request = fake_client.chat.completions.create.call_args.kwargs
|
||||
self.assertEqual(request["model"], "mimo-v2.5-tts")
|
||||
self.assertEqual(request["messages"][0]["role"], "assistant")
|
||||
self.assertEqual(request["messages"][0]["content"], "你好")
|
||||
self.assertEqual(request["audio"], {"format": "wav", "voice": "冰糖"})
|
||||
|
||||
def test_mimo_tts_rejects_voice_design_and_clone_models(self):
|
||||
provider = MiMoAudioProvider()
|
||||
|
||||
with patch.object(
|
||||
settings, "AUDIO_OUTPUT_MODEL", "mimo-v2.5-tts-voiceclone"
|
||||
), patch.object(provider, "_build_client") as build_client:
|
||||
result = provider.synthesize_speech("你好")
|
||||
|
||||
self.assertIsNone(result)
|
||||
build_client.assert_not_called()
|
||||
|
||||
def test_mimo_stt_rejects_non_audio_mimo_models_by_falling_back(self):
|
||||
provider = MiMoAudioProvider()
|
||||
|
||||
with patch.object(settings, "AUDIO_INPUT_MODEL", "mimo-v2.5-pro"):
|
||||
self.assertEqual(provider._normalize_stt_model(), "mimo-v2.5")
|
||||
|
||||
def test_mimo_stt_uses_base64_audio_input(self):
|
||||
provider = MiMoAudioProvider()
|
||||
fake_client = Mock()
|
||||
fake_client.chat.completions.create.return_value = SimpleNamespace(
|
||||
choices=[SimpleNamespace(message=SimpleNamespace(content="你好"))]
|
||||
)
|
||||
|
||||
with patch.object(provider, "_build_client", return_value=fake_client), patch.object(
|
||||
settings, "AUDIO_INPUT_MODEL", "mimo-v2.5"
|
||||
), patch.object(settings, "AUDIO_INPUT_LANGUAGE", "zh"), patch.object(
|
||||
settings, "AUDIO_INPUT_API_KEY", "sk-test"
|
||||
), patch.object(
|
||||
settings, "AUDIO_INPUT_BASE_URL", "https://api.xiaomimimo.com/v1"
|
||||
):
|
||||
result = provider.transcribe_audio(b"audio-bytes", filename="input.wav")
|
||||
|
||||
self.assertEqual(result, "你好")
|
||||
request = fake_client.chat.completions.create.call_args.kwargs
|
||||
content = request["messages"][0]["content"]
|
||||
self.assertEqual(request["model"], "mimo-v2.5")
|
||||
self.assertTrue(
|
||||
content[0]["input_audio"]["data"].startswith("data:audio/wav;base64,")
|
||||
)
|
||||
self.assertIn("只输出转写结果", content[1]["text"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import importlib.machinery
|
||||
import sys
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
@@ -18,6 +19,10 @@ def _stub_module(name: str, **attrs):
|
||||
|
||||
_stub_module("qbittorrentapi", TorrentFilesList=list)
|
||||
_stub_module("transmission_rpc", File=object)
|
||||
_stub_module(
|
||||
"psutil",
|
||||
__spec__=importlib.machinery.ModuleSpec("psutil", loader=None),
|
||||
)
|
||||
|
||||
from app.agent.tools.factory import MoviePilotToolFactory
|
||||
from app.agent import ReplyMode
|
||||
|
||||
@@ -1,69 +0,0 @@
|
||||
import unittest
|
||||
import sys
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
sys.modules.setdefault("psutil", Mock())
|
||||
sys.modules.setdefault("pyquery", Mock())
|
||||
|
||||
from app.core.config import settings
|
||||
from app.helper.voice import VoiceHelper, OpenAIVoiceProvider
|
||||
|
||||
|
||||
class VoiceHelperTest(unittest.TestCase):
|
||||
def test_registered_providers_contains_openai(self):
|
||||
self.assertIn("openai", VoiceHelper.get_registered_providers())
|
||||
|
||||
def test_get_provider_uses_single_audio_provider_setting(self):
|
||||
with patch.object(settings, "AI_VOICE_PROVIDER", "openai"):
|
||||
provider = VoiceHelper.get_provider("stt")
|
||||
|
||||
self.assertIsInstance(provider, OpenAIVoiceProvider)
|
||||
|
||||
def test_is_available_checks_stt_and_tts_separately(self):
|
||||
provider = Mock()
|
||||
provider.is_available_for_stt.return_value = True
|
||||
provider.is_available_for_tts.return_value = False
|
||||
|
||||
with patch.object(
|
||||
settings, "LLM_SUPPORT_AUDIO_INPUT_OUTPUT", True
|
||||
), patch.object(VoiceHelper, "get_provider", return_value=provider):
|
||||
self.assertTrue(VoiceHelper.is_available("stt"))
|
||||
self.assertFalse(VoiceHelper.is_available("tts"))
|
||||
|
||||
def test_is_available_returns_false_when_audio_switch_is_disabled(self):
|
||||
provider = Mock()
|
||||
provider.is_available_for_stt.return_value = True
|
||||
|
||||
with patch.object(
|
||||
settings, "LLM_SUPPORT_AUDIO_INPUT_OUTPUT", False
|
||||
), patch.object(VoiceHelper, "get_provider", return_value=provider):
|
||||
self.assertFalse(VoiceHelper.is_available("stt"))
|
||||
self.assertFalse(VoiceHelper.is_available())
|
||||
|
||||
def test_transcribe_bytes_routes_to_stt_provider(self):
|
||||
provider = Mock()
|
||||
provider.transcribe_bytes.return_value = "你好"
|
||||
|
||||
with patch.object(
|
||||
settings, "LLM_SUPPORT_AUDIO_INPUT_OUTPUT", True
|
||||
), patch.object(VoiceHelper, "get_provider", return_value=provider):
|
||||
result = VoiceHelper.transcribe_bytes(b"audio")
|
||||
|
||||
self.assertEqual(result, "你好")
|
||||
provider.transcribe_bytes.assert_called_once()
|
||||
|
||||
def test_synthesize_speech_routes_to_tts_provider(self):
|
||||
provider = Mock()
|
||||
provider.synthesize_speech.return_value = "/tmp/reply.opus"
|
||||
|
||||
with patch.object(
|
||||
settings, "LLM_SUPPORT_AUDIO_INPUT_OUTPUT", True
|
||||
), patch.object(VoiceHelper, "get_provider", return_value=provider):
|
||||
result = VoiceHelper.synthesize_speech("你好")
|
||||
|
||||
self.assertEqual(result, "/tmp/reply.opus")
|
||||
provider.synthesize_speech.assert_called_once_with(text="你好")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user