diff --git a/app/agent/llm/__init__.py b/app/agent/llm/__init__.py index bad76183..488aba3e 100644 --- a/app/agent/llm/__init__.py +++ b/app/agent/llm/__init__.py @@ -1,6 +1,14 @@ """Agent 内部使用的 LLM 适配层。""" from app.agent.llm.helper import LLMHelper, LLMTestError, LLMTestTimeout +from app.agent.llm.capability import ( + AgentCapabilityManager, + AgentCapabilityProvider, + AudioCapabilityProvider, + MiMoAudioProvider, + OpenAIChatAudioProvider, + OpenAIAudioProvider, +) from app.agent.llm.provider import ( LLMProviderAuthError, LLMProviderError, @@ -10,10 +18,16 @@ from app.agent.llm.provider import ( __all__ = [ "LLMHelper", + "AgentCapabilityManager", + "AgentCapabilityProvider", + "AudioCapabilityProvider", "LLMProviderAuthError", "LLMProviderError", "LLMProviderManager", "LLMTestError", "LLMTestTimeout", + "MiMoAudioProvider", + "OpenAIChatAudioProvider", + "OpenAIAudioProvider", "render_auth_result_html", ] diff --git a/app/agent/llm/capability.py b/app/agent/llm/capability.py new file mode 100644 index 00000000..9374e873 --- /dev/null +++ b/app/agent/llm/capability.py @@ -0,0 +1,528 @@ +"""Agent 多模态能力 provider 与调度入口。""" + +from __future__ import annotations + +import base64 +import mimetypes +import shutil +import subprocess +from abc import ABC +from io import BytesIO +from pathlib import Path +from typing import Dict, Optional +from uuid import uuid4 + +from app.core.config import settings +from app.log import logger + + +class AgentCapabilityProvider(ABC): + """Agent 能力 provider 基类,后续图片等能力可继续扩展到这里。""" + + name: str + + +class AudioCapabilityProvider(AgentCapabilityProvider): + """音频输入/输出能力 provider。""" + + MAX_TRANSCRIBE_BYTES = 10 * 1024 * 1024 + + def is_available_for_audio_input(self) -> bool: + """是否可用于音频输入转写。""" + return False + + def is_available_for_audio_output(self) -> bool: + """是否可用于语音合成输出。""" + return False + + def transcribe_audio(self, content: bytes, filename: str = "input.ogg") -> Optional[str]: + """将音频字节转成文字。""" + raise NotImplementedError + + def synthesize_speech(self, text: str) -> Optional[Path]: + """将文字合成为可发送的音频文件。""" + raise NotImplementedError + + +class OpenAIAudioProvider(AudioCapabilityProvider): + """OpenAI / OpenAI-compatible 音频 provider。""" + + name = "openai" + + @staticmethod + def _build_client(api_key: str, base_url: Optional[str]): + from openai import OpenAI + + return OpenAI(api_key=api_key, base_url=base_url, max_retries=3) + + @staticmethod + def _input_credentials() -> tuple[Optional[str], Optional[str]]: + return settings.AUDIO_INPUT_API_KEY, settings.AUDIO_INPUT_BASE_URL + + @staticmethod + def _output_credentials() -> tuple[Optional[str], Optional[str]]: + return settings.AUDIO_OUTPUT_API_KEY, settings.AUDIO_OUTPUT_BASE_URL + + def is_available_for_audio_input(self) -> bool: + api_key, _ = self._input_credentials() + return bool(api_key) + + def is_available_for_audio_output(self) -> bool: + api_key, _ = self._output_credentials() + return bool(api_key) + + def transcribe_audio(self, content: bytes, filename: str = "input.ogg") -> Optional[str]: + if not content: + return None + if len(content) > self.MAX_TRANSCRIBE_BYTES: + raise ValueError("语音文件超过 10MB,无法识别") + + try: + api_key, base_url = self._input_credentials() + if not api_key: + raise ValueError("音频输入 provider 未配置 API Key") + client = self._build_client(api_key=api_key, base_url=base_url) + audio_file = BytesIO(content) + audio_file.name = filename + response = client.audio.transcriptions.create( + model=settings.AUDIO_INPUT_MODEL, + file=audio_file, + language=settings.AUDIO_INPUT_LANGUAGE or "zh", + response_format="verbose_json", + ) + text = getattr(response, "text", None) + return text.strip() if text else None + except Exception as err: + logger.error(f"音频输入转写失败: provider={self.name}, error={err}") + return None + + def synthesize_speech(self, text: str) -> Optional[Path]: + if not text: + return None + + try: + api_key, base_url = self._output_credentials() + if not api_key: + raise ValueError("音频输出 provider 未配置 API Key") + client = self._build_client(api_key=api_key, base_url=base_url) + voice_dir = settings.TEMP_PATH / "voice" + voice_dir.mkdir(parents=True, exist_ok=True) + output_path = voice_dir / f"{uuid4().hex}.opus" + response = client.audio.speech.create( + model=settings.AUDIO_OUTPUT_MODEL, + voice=settings.AUDIO_OUTPUT_VOICE, + input=text, + response_format="opus", + ) + response.write_to_file(output_path) + return output_path + except Exception as err: + logger.error(f"音频输出合成失败: provider={self.name}, error={err}") + return None + + +class OpenAIChatAudioProvider(AudioCapabilityProvider): + """通过 OpenAI Chat Completions 兼容接口传入/返回音频的 provider。""" + + name = "openai_chat_audio" + DISPLAY_NAME = "OpenAI Chat Audio" + DEFAULT_BASE_URL: Optional[str] = None + DEFAULT_STT_MODEL: Optional[str] = None + DEFAULT_TTS_MODEL: Optional[str] = None + DEFAULT_VOICE = "alloy" + AUDIO_RESPONSE_FORMAT = "wav" + AUDIO_INPUT_DATA_URL = False + INCLUDE_AUDIO_MODALITIES = True + TTS_MESSAGE_ROLE = "user" + SUPPORTED_STT_MODELS: Optional[frozenset[str]] = None + SUPPORTED_TTS_MODELS: Optional[frozenset[str]] = None + UNSUPPORTED_TTS_MODELS = frozenset() + SUPPORTED_AUDIO_MIME_TYPES = { + ".flac": "audio/flac", + ".m4a": "audio/mp4", + ".mp3": "audio/mpeg", + ".ogg": "audio/ogg", + ".opus": "audio/ogg", + ".wav": "audio/wav", + } + + def _build_client(self, api_key: str, base_url: Optional[str]): + from openai import OpenAI + + return OpenAI( + api_key=api_key, + base_url=base_url or self.DEFAULT_BASE_URL, + max_retries=3, + ) + + @staticmethod + def _input_credentials() -> tuple[Optional[str], Optional[str]]: + return settings.AUDIO_INPUT_API_KEY, settings.AUDIO_INPUT_BASE_URL + + @staticmethod + def _output_credentials() -> tuple[Optional[str], Optional[str]]: + return settings.AUDIO_OUTPUT_API_KEY, settings.AUDIO_OUTPUT_BASE_URL + + def _normalize_stt_model(self) -> str: + return self._normalize_model( + model=settings.AUDIO_INPUT_MODEL, + supported_models=self.SUPPORTED_STT_MODELS, + default_model=self.DEFAULT_STT_MODEL, + ) + + def _normalize_tts_model(self) -> str: + return self._normalize_model( + model=settings.AUDIO_OUTPUT_MODEL, + supported_models=self.SUPPORTED_TTS_MODELS, + default_model=self.DEFAULT_TTS_MODEL, + ) + + @staticmethod + def _normalize_model( + model: Optional[str], + supported_models: Optional[frozenset[str]], + default_model: Optional[str], + ) -> str: + model = (model or "").strip() + if not model: + return default_model or "" + if supported_models is None: + return model + model_key = model.lower() + if model_key in supported_models: + return model_key + return default_model or model + + def _is_supported_tts_model(self) -> bool: + model = self._normalize_tts_model() + if not model: + return False + model_key = model.lower() + if model_key in self.UNSUPPORTED_TTS_MODELS: + return False + return self.SUPPORTED_TTS_MODELS is None or model_key in self.SUPPORTED_TTS_MODELS + + @classmethod + def _guess_audio_mime_type(cls, filename: str) -> str: + suffix = Path(filename or "").suffix.lower() + if suffix in cls.SUPPORTED_AUDIO_MIME_TYPES: + return cls.SUPPORTED_AUDIO_MIME_TYPES[suffix] + mime_type, _ = mimetypes.guess_type(filename or "") + return mime_type or "audio/ogg" + + @staticmethod + def _guess_audio_format(filename: str) -> str: + suffix = Path(filename or "").suffix.lower().lstrip(".") + if suffix == "opus": + return "ogg" + return suffix or "ogg" + + def _build_audio_input_payload(self, content: bytes, filename: str) -> dict: + """按不同 Chat Audio 兼容形态构造 input_audio 内容。""" + audio_data = base64.b64encode(content).decode("utf-8") + if self.AUDIO_INPUT_DATA_URL: + mime_type = self._guess_audio_mime_type(filename) + return {"data": f"data:{mime_type};base64,{audio_data}"} + return { + "data": audio_data, + "format": self._guess_audio_format(filename), + } + + @staticmethod + def _extract_message_text(message) -> Optional[str]: + """兼容音频理解响应可能放在 content 或 reasoning_content 的情况。""" + content = getattr(message, "content", None) + if isinstance(content, str) and content.strip(): + return content.strip() + + reasoning_content = getattr(message, "reasoning_content", None) + if isinstance(reasoning_content, str) and reasoning_content.strip(): + return reasoning_content.strip() + + extra = getattr(message, "model_extra", None) + if isinstance(extra, dict): + for key in ("content", "reasoning_content"): + value = extra.get(key) + if isinstance(value, str) and value.strip(): + return value.strip() + return None + + @staticmethod + def _extract_audio_data(message) -> Optional[str]: + audio = getattr(message, "audio", None) + if isinstance(audio, dict): + return audio.get("data") + if audio is not None: + return getattr(audio, "data", None) + + extra = getattr(message, "model_extra", None) + if isinstance(extra, dict) and isinstance(extra.get("audio"), dict): + return extra["audio"].get("data") + return None + + def _convert_wav_to_opus(self, wav_path: Path) -> Optional[Path]: + """将 Chat Audio 返回的 WAV 转成 OGG/Opus,便于各通知渠道发送语音。""" + if not shutil.which("ffmpeg"): + return None + + output_path = wav_path.with_suffix(".opus") + cmd = [ + "ffmpeg", + "-y", + "-i", + str(wav_path), + "-ar", + "48000", + "-ac", + "1", + "-c:a", + "libopus", + str(output_path), + ] + result = subprocess.run(cmd, capture_output=True, text=True, check=False) + if result.returncode != 0 or not output_path.exists(): + logger.warning( + "%s TTS 音频转 Opus 失败,将使用 WAV 原文件: returncode=%s, stderr=%s", + self.DISPLAY_NAME, + result.returncode, + (result.stderr or "").strip()[:500], + ) + return None + return output_path + + def is_available_for_audio_input(self) -> bool: + api_key, _ = self._input_credentials() + return bool(api_key) + + def is_available_for_audio_output(self) -> bool: + api_key, _ = self._output_credentials() + return bool(api_key) and self._is_supported_tts_model() + + def transcribe_audio(self, content: bytes, filename: str = "input.ogg") -> Optional[str]: + if not content: + return None + if len(content) > self.MAX_TRANSCRIBE_BYTES: + raise ValueError("语音文件超过 10MB,无法识别") + + try: + api_key, base_url = self._input_credentials() + if not api_key: + raise ValueError("音频输入 provider 未配置 API Key") + client = self._build_client(api_key=api_key, base_url=base_url) + language = (settings.AUDIO_INPUT_LANGUAGE or "").strip() + prompt = "请将这段音频完整转写为文字,只输出转写结果,不要添加解释。" + if language: + prompt += f"音频主要语言是 {language}。" + + completion = client.chat.completions.create( + model=self._normalize_stt_model(), + messages=[ + { + "role": "user", + "content": [ + { + "type": "input_audio", + "input_audio": self._build_audio_input_payload( + content=content, filename=filename + ), + }, + {"type": "text", "text": prompt}, + ], + } + ], + max_completion_tokens=2048, + ) + return self._extract_message_text(completion.choices[0].message) + except Exception as err: + logger.error(f"音频输入转写失败: provider={self.name}, error={err}") + return None + + def synthesize_speech(self, text: str) -> Optional[Path]: + if not text: + return None + if not self._is_supported_tts_model(): + logger.error( + "%s TTS 当前不支持该模型或模型未配置: %s", + self.DISPLAY_NAME, + settings.AUDIO_OUTPUT_MODEL, + ) + return None + + try: + api_key, base_url = self._output_credentials() + if not api_key: + raise ValueError("音频输出 provider 未配置 API Key") + client = self._build_client(api_key=api_key, base_url=base_url) + voice_dir = settings.TEMP_PATH / "voice" + voice_dir.mkdir(parents=True, exist_ok=True) + wav_path = voice_dir / f"{uuid4().hex}.wav" + request = { + "model": self._normalize_tts_model(), + "messages": [ + { + "role": self.TTS_MESSAGE_ROLE, + "content": text, + } + ], + "audio": { + "format": self.AUDIO_RESPONSE_FORMAT, + "voice": settings.AUDIO_OUTPUT_VOICE or self.DEFAULT_VOICE, + }, + } + if self.INCLUDE_AUDIO_MODALITIES: + request["modalities"] = ["text", "audio"] + completion = client.chat.completions.create(**request) + audio_data = self._extract_audio_data(completion.choices[0].message) + if not audio_data: + raise ValueError(f"{self.DISPLAY_NAME} TTS 响应中没有音频数据") + + wav_path.write_bytes(base64.b64decode(audio_data)) + return self._convert_wav_to_opus(wav_path) or wav_path + except Exception as err: + logger.error(f"音频输出合成失败: provider={self.name}, error={err}") + return None + + +class MiMoAudioProvider(OpenAIChatAudioProvider): + """Xiaomi MiMo Chat Audio 预设,仅接入普通 STT/TTS 能力。""" + + name = "mimo" + DISPLAY_NAME = "Xiaomi MiMo" + DEFAULT_BASE_URL = "https://api.xiaomimimo.com/v1" + DEFAULT_STT_MODEL = "mimo-v2.5" + DEFAULT_TTS_MODEL = "mimo-v2.5-tts" + DEFAULT_VOICE = "mimo_default" + AUDIO_INPUT_DATA_URL = True + INCLUDE_AUDIO_MODALITIES = False + TTS_MESSAGE_ROLE = "assistant" + SUPPORTED_STT_MODELS = frozenset({"mimo-v2.5", "mimo-v2-omni"}) + SUPPORTED_TTS_MODELS = frozenset({DEFAULT_TTS_MODEL}) + UNSUPPORTED_TTS_MODELS = frozenset( + { + "mimo-v2.5-tts-voiceclone", + "mimo-v2.5-tts-voicedesign", + } + ) + + def _normalize_tts_model(self) -> str: + model = (settings.AUDIO_OUTPUT_MODEL or "").strip().lower() + if not model or not model.startswith("mimo-"): + return self.DEFAULT_TTS_MODEL + return model + + +class AgentCapabilityManager: + """Agent 能力统一入口。""" + + REPLY_MODE_NATIVE = "native_voice" + REPLY_MODE_TEXT = "text" + _audio_providers: Dict[str, AudioCapabilityProvider] = { + OpenAIAudioProvider.name: OpenAIAudioProvider(), + OpenAIChatAudioProvider.name: OpenAIChatAudioProvider(), + MiMoAudioProvider.name: MiMoAudioProvider(), + } + + @classmethod + def register_audio_provider(cls, provider: AudioCapabilityProvider) -> None: + """注册新的音频 provider。""" + cls._audio_providers[provider.name.lower()] = provider + + @classmethod + def get_registered_audio_providers(cls) -> list[str]: + """返回已注册的音频 provider 名称。""" + return sorted(cls._audio_providers.keys()) + + @staticmethod + def _normalize_provider_name(provider: Optional[str]) -> str: + return (provider or "openai").strip().lower() + + @classmethod + def get_audio_provider(cls, mode: str) -> Optional[AudioCapabilityProvider]: + provider_name = cls._normalize_provider_name( + settings.AUDIO_INPUT_PROVIDER + if (mode or "").lower() == "input" + else settings.AUDIO_OUTPUT_PROVIDER + ) + provider = cls._audio_providers.get(provider_name) + if provider: + return provider + logger.warning("未注册音频 provider: mode=%s, provider=%s", mode, provider_name) + return None + + @staticmethod + def supports_image_input() -> bool: + """当前 Agent 是否启用图片输入能力。""" + return bool(settings.LLM_SUPPORT_IMAGE_INPUT) + + @staticmethod + def supports_audio_input() -> bool: + """当前 Agent 是否启用音频输入能力。""" + return bool(settings.LLM_SUPPORT_AUDIO_INPUT) + + @staticmethod + def supports_audio_output() -> bool: + """当前 Agent 是否启用音频输出能力。""" + return bool(settings.LLM_SUPPORT_AUDIO_OUTPUT) + + @classmethod + def is_audio_input_available(cls) -> bool: + if not cls.supports_audio_input(): + return False + provider = cls.get_audio_provider("input") + return bool(provider and provider.is_available_for_audio_input()) + + @classmethod + def is_audio_output_available(cls) -> bool: + if not cls.supports_audio_output(): + return False + provider = cls.get_audio_provider("output") + return bool(provider and provider.is_available_for_audio_output()) + + @classmethod + def transcribe_audio(cls, content: bytes, filename: str = "input.ogg") -> Optional[str]: + provider = cls.get_audio_provider("input") + if not provider or not cls.is_audio_input_available(): + return None + return provider.transcribe_audio(content=content, filename=filename) + + @classmethod + def synthesize_speech(cls, text: str) -> Optional[Path]: + provider = cls.get_audio_provider("output") + if not provider or not cls.is_audio_output_available(): + return None + return provider.synthesize_speech(text=text) + + @classmethod + def resolve_reply_mode(cls, channel: Optional[str], source: Optional[str]) -> str: + """仅在支持原生语音回复的渠道上发送音频,其余渠道回退文字。""" + if cls.supports_native_voice_reply(channel=channel, source=source): + return cls.REPLY_MODE_NATIVE + return cls.REPLY_MODE_TEXT + + @classmethod + def supports_native_voice_reply( + cls, channel: Optional[str], source: Optional[str] + ) -> bool: + """判断当前渠道是否支持原生语音消息发送。""" + if not channel: + return False + + from app.helper.service import ServiceConfigHelper + from app.schemas.types import MessageChannel + + try: + channel_enum = MessageChannel(channel) + except (TypeError, ValueError): + return False + + if channel_enum == MessageChannel.Telegram: + return True + if channel_enum != MessageChannel.Wechat: + return False + + # 企业微信 bot 模式不支持发送语音,只有应用模式可用。 + for config in ServiceConfigHelper.get_notification_configs(): + if config.name != source: + continue + return (config.config or {}).get("WECHAT_MODE", "app") != "bot" + return False diff --git a/app/agent/llm/helper.py b/app/agent/llm/helper.py index a49fda4f..f3ecc427 100644 --- a/app/agent/llm/helper.py +++ b/app/agent/llm/helper.py @@ -886,7 +886,6 @@ class LLMHelper: {"id": model_id, "name": model_id} for model_id in await self._get_google_models(api_key or "") ] - model_list_base_url = base_url try: from app.agent.llm.provider import LLMProviderManager diff --git a/app/agent/prompt/__init__.py b/app/agent/prompt/__init__.py index 6cc44ae0..1de0129d 100644 --- a/app/agent/prompt/__init__.py +++ b/app/agent/prompt/__init__.py @@ -9,6 +9,7 @@ from typing import Any, Dict, Optional import yaml +from app.agent.llm.capability import AgentCapabilityManager from app.core.config import settings from app.log import logger from app.schemas import ( @@ -327,10 +328,12 @@ class PromptManager: @staticmethod def _generate_voice_reply_instructions() -> str: + if not AgentCapabilityManager.supports_audio_output(): + return "Audio output is disabled; do not call `send_voice_message`." return ( - "- Voice replies: Use normal text replies by default. " - "Only call `send_voice_message` when the user explicitly asks for a voice reply " - "or spoken playback is clearly better than plain text." + "Use normal text replies by default. Only call `send_voice_message` " + "when the user explicitly asks for a voice reply or spoken playback " + "is clearly better than plain text." ) @staticmethod diff --git a/app/agent/tools/factory.py b/app/agent/tools/factory.py index 1503827a..e44810ef 100644 --- a/app/agent/tools/factory.py +++ b/app/agent/tools/factory.py @@ -77,6 +77,7 @@ from app.agent.tools.impl.query_custom_identifiers import QueryCustomIdentifiers from app.agent.tools.impl.update_custom_identifiers import UpdateCustomIdentifiersTool from app.agent.tools.impl.query_system_settings import QuerySystemSettingsTool from app.agent.tools.impl.update_system_settings import UpdateSystemSettingsTool +from app.agent.llm.capability import AgentCapabilityManager from app.core.plugin import PluginManager from app.log import logger from app.schemas.message import ChannelCapabilityManager @@ -225,12 +226,9 @@ class MoviePilotToolFactory: ] if MoviePilotToolFactory._should_enable_choice_tool(channel): tool_definitions.append(AskUserChoiceTool) - tool_definitions.extend( - [ - SendLocalFileTool, - SendVoiceMessageTool, - ] - ) + tool_definitions.append(SendLocalFileTool) + if AgentCapabilityManager.supports_audio_output(): + tool_definitions.append(SendVoiceMessageTool) # 创建内置工具 for ToolClass in tool_definitions: tool = ToolClass(session_id=session_id, user_id=user_id) diff --git a/app/agent/tools/impl/send_voice_message.py b/app/agent/tools/impl/send_voice_message.py index dafdf65b..45d38a39 100644 --- a/app/agent/tools/impl/send_voice_message.py +++ b/app/agent/tools/impl/send_voice_message.py @@ -5,9 +5,9 @@ from typing import Optional, Type from pydantic import BaseModel, Field +from app.agent.llm.capability import AgentCapabilityManager from app.agent.tools.base import MoviePilotTool, ToolChain from app.core.config import settings -from app.helper.voice import VoiceHelper from app.log import logger from app.schemas import Notification, NotificationType @@ -50,22 +50,24 @@ class SendVoiceMessageTool(MoviePilotTool): voice_path = None used_voice = False channel = self._channel or "" - reply_mode = VoiceHelper.resolve_reply_mode( + reply_mode = AgentCapabilityManager.resolve_reply_mode( channel=channel, source=self._source, ) fallback_reason = "当前渠道不支持语音回复" - if not VoiceHelper.is_enabled(): - fallback_reason = "当前未启用音频输入输出" + if not AgentCapabilityManager.supports_audio_output(): + fallback_reason = "当前未启用音频输出" if ( - reply_mode == VoiceHelper.REPLY_MODE_NATIVE - and VoiceHelper.is_available("tts") + reply_mode == AgentCapabilityManager.REPLY_MODE_NATIVE + and AgentCapabilityManager.is_audio_output_available() ): - voice_file = await asyncio.to_thread(VoiceHelper.synthesize_speech, message) + voice_file = await asyncio.to_thread( + AgentCapabilityManager.synthesize_speech, message + ) if voice_file: voice_path = str(voice_file) used_voice = True - elif reply_mode == VoiceHelper.REPLY_MODE_NATIVE: + elif reply_mode == AgentCapabilityManager.REPLY_MODE_NATIVE: fallback_reason = "当前未配置可用的语音合成能力" logger.info( @@ -87,7 +89,7 @@ class SendVoiceMessageTool(MoviePilotTool): voice_path=voice_path, voice_caption=( message - if voice_path and settings.AI_VOICE_REPLY_WITH_TEXT + if voice_path and settings.AUDIO_OUTPUT_INCLUDE_TEXT else None ), ) diff --git a/app/api/endpoints/system.py b/app/api/endpoints/system.py index 75af0205..3901895c 100644 --- a/app/api/endpoints/system.py +++ b/app/api/endpoints/system.py @@ -476,7 +476,8 @@ async def get_user_global_setting(_: User = Depends(get_current_active_user_asyn info = settings.model_dump( include={ "AI_AGENT_ENABLE", - "LLM_SUPPORT_AUDIO_INPUT_OUTPUT", + "LLM_SUPPORT_AUDIO_INPUT", + "LLM_SUPPORT_AUDIO_OUTPUT", "RECOGNIZE_SOURCE", "SEARCH_SOURCE", "AI_RECOMMEND_ENABLED", @@ -486,7 +487,8 @@ async def get_user_global_setting(_: User = Depends(get_current_active_user_asyn # 智能助手总开关未开启,智能推荐状态强制返回False if not settings.AI_AGENT_ENABLE: info["AI_RECOMMEND_ENABLED"] = False - info["LLM_SUPPORT_AUDIO_INPUT_OUTPUT"] = False + info["LLM_SUPPORT_AUDIO_INPUT"] = False + info["LLM_SUPPORT_AUDIO_OUTPUT"] = False # 追加用户唯一ID和订阅分享管理权限 share_admin = SubscribeHelper().is_admin_user() diff --git a/app/chain/message.py b/app/chain/message.py index d5d49f95..098f3de2 100644 --- a/app/chain/message.py +++ b/app/chain/message.py @@ -12,7 +12,7 @@ from typing import Any, Optional, Dict, Union, List, Tuple from urllib.parse import unquote, urlparse from app.agent import ReplyMode, agent_manager, prompt_manager -from app.agent.llm import LLMHelper +from app.agent.llm import AgentCapabilityManager, LLMHelper from app.chain import ChainBase from app.chain.download import DownloadChain from app.chain.media import MediaChain @@ -29,7 +29,6 @@ from app.db.transferhistory_oper import TransferHistoryOper from app.db.user_oper import UserOper from app.helper.interaction import agent_interaction_manager, media_interaction_manager, PendingMediaInteraction from app.helper.torrent import TorrentHelper -from app.helper.voice import VoiceHelper from app.log import logger from app.schemas import Notification, CommingMessage, NotExistMediaInfo from app.schemas.message import ChannelCapabilityManager @@ -1196,8 +1195,8 @@ class MessageChain(ChainBase): """ if not audio_refs: return None - if not VoiceHelper.is_available("stt"): - logger.warning("语音能力未配置,跳过语音识别") + if not AgentCapabilityManager.is_audio_input_available(): + logger.warning("音频输入能力未配置或未启用,跳过语音识别") return None transcripts = [] @@ -1303,7 +1302,7 @@ class MessageChain(ChainBase): ) continue - transcript = VoiceHelper.transcribe_bytes( + transcript = AgentCapabilityManager.transcribe_audio( content=content, filename=filename ) if transcript: diff --git a/app/core/config.py b/app/core/config.py index 90a1a3bb..381bd043 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -541,8 +541,10 @@ class ConfigModel(BaseModel): LLM_THINKING_LEVEL: Optional[str] = "off" # LLM是否支持图片输入,开启后消息图片会按多模态输入发送给模型 LLM_SUPPORT_IMAGE_INPUT: bool = True - # LLM是否支持音频输入输出,开启后才会启用语音转写与语音回复 - LLM_SUPPORT_AUDIO_INPUT_OUTPUT: bool = False + # 是否启用音频输入,开启后用户语音会先转写为文本再进入 Agent + LLM_SUPPORT_AUDIO_INPUT: bool = False + # 是否启用音频输出,开启后 Agent 可在支持渠道发送语音回复 + LLM_SUPPORT_AUDIO_OUTPUT: bool = False # LLM API密钥 LLM_API_KEY: Optional[str] = None # LLM基础URL(用于自定义API端点) @@ -589,22 +591,28 @@ class ConfigModel(BaseModel): # AI智能体自动重试整理失败记录开关 AI_AGENT_RETRY_TRANSFER: bool = False - # 语音能力提供商(当前仅支持 openai/openai-compatible) - AI_VOICE_PROVIDER: str = "openai" - # 语音能力共享 API 密钥,未设置且 LLM_PROVIDER=openai 时回退使用 LLM_API_KEY - AI_VOICE_API_KEY: Optional[str] = None - # 语音能力共享基础URL,未设置且 LLM_PROVIDER=openai 时回退使用 LLM_BASE_URL - AI_VOICE_BASE_URL: Optional[str] = None - # 语音转文字模型 - AI_VOICE_STT_MODEL: str = "gpt-4o-mini-transcribe" - # 文字转语音模型 - AI_VOICE_TTS_MODEL: str = "gpt-4o-mini-tts" - # TTS 发音人 - AI_VOICE_TTS_VOICE: str = "alloy" - # 语音识别语言 - AI_VOICE_LANGUAGE: str = "zh" + # 音频输入提供商:openai/openai_chat_audio/mimo + AUDIO_INPUT_PROVIDER: str = "openai" + # 音频输入 API 密钥 + AUDIO_INPUT_API_KEY: Optional[str] = None + # 音频输入基础URL + AUDIO_INPUT_BASE_URL: Optional[str] = None + # 音频输入模型 + AUDIO_INPUT_MODEL: str = "gpt-4o-mini-transcribe" + # 音频输入识别语言 + AUDIO_INPUT_LANGUAGE: str = "zh" + # 音频输出提供商:openai/openai_chat_audio/mimo + AUDIO_OUTPUT_PROVIDER: str = "openai" + # 音频输出 API 密钥 + AUDIO_OUTPUT_API_KEY: Optional[str] = None + # 音频输出基础URL + AUDIO_OUTPUT_BASE_URL: Optional[str] = None + # 音频输出模型 + AUDIO_OUTPUT_MODEL: str = "gpt-4o-mini-tts" + # 音频输出音色/发音人 + AUDIO_OUTPUT_VOICE: str = "alloy" # 回复语音时是否同时附带文字说明 - AI_VOICE_REPLY_WITH_TEXT: bool = False + AUDIO_OUTPUT_INCLUDE_TEXT: bool = False class Settings(BaseSettings, ConfigModel, LogConfigModel): @@ -824,7 +832,9 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel): return False, f"配置项 '{key}' 不存在" try: - field = Settings.model_fields[key] + field = Settings.model_fields.get(key) + if not field: + return False, f"配置项 '{key}' 不存在" original_value = getattr(self, key) if key == "API_TOKEN": converted_value, needs_update = self.validate_api_token( diff --git a/app/helper/voice.py b/app/helper/voice.py deleted file mode 100644 index 089512d7..00000000 --- a/app/helper/voice.py +++ /dev/null @@ -1,234 +0,0 @@ -"""语音能力辅助功能。""" - -from abc import ABC, abstractmethod -from io import BytesIO -from pathlib import Path -from typing import Dict, Optional -from uuid import uuid4 - -from app.core.config import settings -from app.log import logger - - -class VoiceProvider(ABC): - """语音 provider 抽象层。""" - - MAX_TRANSCRIBE_BYTES = 10 * 1024 * 1024 - - @property - @abstractmethod - def name(self) -> str: - """provider 名称。""" - - @abstractmethod - def is_available_for_stt(self) -> bool: - """是否可用于语音识别。""" - - @abstractmethod - def is_available_for_tts(self) -> bool: - """是否可用于语音合成。""" - - @abstractmethod - def transcribe_bytes(self, content: bytes, filename: str = "input.ogg") -> Optional[str]: - """将音频字节转成文字。""" - - @abstractmethod - def synthesize_speech(self, text: str) -> Optional[Path]: - """将文字转成语音文件。""" - - -class OpenAIVoiceProvider(VoiceProvider): - """OpenAI / OpenAI-compatible provider。""" - - @property - def name(self) -> str: - return "openai" - - @staticmethod - def _resolve_provider_name() -> str: - provider = settings.AI_VOICE_PROVIDER or "openai" - return provider.strip().lower() - - def _resolve_credentials(self) -> tuple[Optional[str], Optional[str]]: - provider = self._resolve_provider_name() - api_key = settings.AI_VOICE_API_KEY - base_url = settings.AI_VOICE_BASE_URL - - if ( - not api_key - and provider == "openai" - and (settings.LLM_PROVIDER or "").strip().lower() == "openai" - ): - api_key = settings.LLM_API_KEY - base_url = base_url or settings.LLM_BASE_URL - - return api_key, base_url - - def _get_client(self, mode: str): - from openai import OpenAI - - api_key, base_url = self._resolve_credentials() - if not api_key: - raise ValueError(f"{mode.upper()} provider 未配置 API Key") - return OpenAI(api_key=api_key, base_url=base_url, max_retries=3) - - def is_available_for_stt(self) -> bool: - api_key, _ = self._resolve_credentials() - return bool(api_key) - - def is_available_for_tts(self) -> bool: - api_key, _ = self._resolve_credentials() - return bool(api_key) - - def transcribe_bytes(self, content: bytes, filename: str = "input.ogg") -> Optional[str]: - if not content: - return None - if len(content) > self.MAX_TRANSCRIBE_BYTES: - raise ValueError("语音文件超过 10MB,无法识别") - - try: - client = self._get_client("stt") - audio_file = BytesIO(content) - audio_file.name = filename - response = client.audio.transcriptions.create( - model=settings.AI_VOICE_STT_MODEL, - file=audio_file, - language=settings.AI_VOICE_LANGUAGE or "zh", - response_format="verbose_json", - ) - text = getattr(response, "text", None) - return text.strip() if text else None - except Exception as err: - logger.error(f"语音转文字失败: provider={self.name}, error={err}") - return None - - def synthesize_speech(self, text: str) -> Optional[Path]: - if not text: - return None - - try: - client = self._get_client("tts") - voice_dir = settings.TEMP_PATH / "voice" - voice_dir.mkdir(parents=True, exist_ok=True) - output_path = voice_dir / f"{uuid4().hex}.opus" - response = client.audio.speech.create( - model=settings.AI_VOICE_TTS_MODEL, - voice=settings.AI_VOICE_TTS_VOICE, - input=text, - response_format="opus", - ) - response.write_to_file(output_path) - return output_path - except Exception as err: - logger.error(f"文字转语音失败: provider={self.name}, error={err}") - return None - - -class VoiceHelper: - """统一语音入口,负责音频能力判断与 STT/TTS provider 路由。""" - - _providers: Dict[str, VoiceProvider] = { - "openai": OpenAIVoiceProvider(), - } - REPLY_MODE_NATIVE = "native_voice" - REPLY_MODE_TEXT = "text" - - @classmethod - def register_provider(cls, provider: VoiceProvider) -> None: - cls._providers[provider.name.lower()] = provider - - @staticmethod - def is_enabled() -> bool: - """音频输入输出总开关,以显式配置为准。""" - return bool(settings.LLM_SUPPORT_AUDIO_INPUT_OUTPUT) - - @staticmethod - def _resolve_provider_name() -> str: - """标准化当前配置的语音 provider 名称。""" - provider = settings.AI_VOICE_PROVIDER or "openai" - return provider.strip().lower() - - @classmethod - def get_provider(cls, mode: str) -> Optional[VoiceProvider]: - provider_name = cls._resolve_provider_name() - provider = cls._providers.get(provider_name) - if provider: - return provider - logger.warning(f"未注册语音 provider: mode={mode}, provider={provider_name}") - return None - - @classmethod - def get_registered_providers(cls) -> list[str]: - return sorted(cls._providers.keys()) - - @classmethod - def is_available(cls, mode: Optional[str] = None) -> bool: - if not cls.is_enabled(): - return False - if mode: - provider = cls.get_provider(mode) - if not provider: - return False - return ( - provider.is_available_for_stt() - if mode.lower() == "stt" - else provider.is_available_for_tts() - ) - return cls.is_available("stt") or cls.is_available("tts") - - @classmethod - def supports_native_voice_reply( - cls, channel: Optional[str], source: Optional[str] - ) -> bool: - """ - 判断当前渠道是否支持原生语音消息发送。 - """ - if not channel: - return False - - from app.helper.service import ServiceConfigHelper - from app.schemas.types import MessageChannel - - try: - channel_enum = MessageChannel(channel) - except (TypeError, ValueError): - return False - - if channel_enum == MessageChannel.Telegram: - return True - if channel_enum != MessageChannel.Wechat: - return False - - # 企业微信 bot 模式不支持发送语音,只有应用模式可用。 - for config in ServiceConfigHelper.get_notification_configs(): - if config.name != source: - continue - return (config.config or {}).get("WECHAT_MODE", "app") != "bot" - return False - - @classmethod - def resolve_reply_mode(cls, channel: Optional[str], source: Optional[str]) -> str: - """ - 仅在支持原生语音回复的渠道上发送音频,其余渠道统一回退文字。 - """ - if cls.supports_native_voice_reply(channel=channel, source=source): - return cls.REPLY_MODE_NATIVE - return cls.REPLY_MODE_TEXT - - @classmethod - def transcribe_bytes(cls, content: bytes, filename: str = "input.ogg") -> Optional[str]: - if not cls.is_enabled(): - return None - provider = cls.get_provider("stt") - if not provider: - return None - return provider.transcribe_bytes(content=content, filename=filename) - - @classmethod - def synthesize_speech(cls, text: str) -> Optional[Path]: - if not cls.is_enabled(): - return None - provider = cls.get_provider("tts") - if not provider: - return None - return provider.synthesize_speech(text=text) diff --git a/tests/test_agent_image_support.py b/tests/test_agent_image_support.py index a691df0c..e0b62365 100644 --- a/tests/test_agent_image_support.py +++ b/tests/test_agent_image_support.py @@ -12,10 +12,10 @@ from telebot import apihelper from app.agent.tools.impl.send_message import SendMessageInput from app.agent.tools.impl.send_local_file import SendLocalFileInput from app.agent import MoviePilotAgent, AgentChain +from app.agent.llm import AgentCapabilityManager from app.chain.message import MessageChain from app.core.config import settings from app.agent.llm import LLMHelper -from app.helper.voice import VoiceHelper from app.modules.discord import DiscordModule from app.modules.qqbot import QQBotModule from app.modules.qqbot.qqbot import QQBot @@ -284,13 +284,15 @@ class AgentImageSupportTest(unittest.TestCase): "feishu://file/om_audio/file_audio/voice.opus", ] - with patch.object(VoiceHelper, "is_available", return_value=True), patch.object( + with patch.object( + AgentCapabilityManager, "is_audio_input_available", return_value=True + ), patch.object( chain, "run_module", side_effect=[b"slack", b"discord", b"qq", b"vocechat", b"synology", b"feishu"], ) as run_module, patch.object( - VoiceHelper, - "transcribe_bytes", + AgentCapabilityManager, + "transcribe_audio", side_effect=[ "slack text", "discord text", diff --git a/tests/test_agent_llm_capability.py b/tests/test_agent_llm_capability.py new file mode 100644 index 00000000..12465874 --- /dev/null +++ b/tests/test_agent_llm_capability.py @@ -0,0 +1,235 @@ +import sys +import unittest +import importlib.util +from base64 import b64encode +from pathlib import Path +from tempfile import TemporaryDirectory +from types import SimpleNamespace +from unittest.mock import Mock, patch + +sys.modules.setdefault("psutil", Mock()) +sys.modules.setdefault("pyquery", Mock()) + +from app.core.config import settings + +module_path = Path(__file__).resolve().parents[1] / "app" / "agent" / "llm" / "capability.py" +spec = importlib.util.spec_from_file_location("test_agent_llm_capability_module", module_path) +capability_module = importlib.util.module_from_spec(spec) +assert spec and spec.loader +sys.modules[spec.name] = capability_module +spec.loader.exec_module(capability_module) + +AgentCapabilityManager = capability_module.AgentCapabilityManager +MiMoAudioProvider = capability_module.MiMoAudioProvider +OpenAIChatAudioProvider = capability_module.OpenAIChatAudioProvider +OpenAIAudioProvider = capability_module.OpenAIAudioProvider + + +class AgentCapabilityManagerTest(unittest.TestCase): + def test_registered_audio_providers_contains_builtin_providers(self): + self.assertIn("openai", AgentCapabilityManager.get_registered_audio_providers()) + self.assertIn( + "openai_chat_audio", AgentCapabilityManager.get_registered_audio_providers() + ) + self.assertIn("mimo", AgentCapabilityManager.get_registered_audio_providers()) + + def test_get_audio_provider_uses_separate_input_and_output_settings(self): + with patch.object(settings, "AUDIO_INPUT_PROVIDER", "openai"), patch.object( + settings, "AUDIO_OUTPUT_PROVIDER", "mimo" + ): + self.assertIsInstance( + AgentCapabilityManager.get_audio_provider("input"), OpenAIAudioProvider + ) + self.assertIsInstance( + AgentCapabilityManager.get_audio_provider("output"), MiMoAudioProvider + ) + + def test_chat_audio_provider_keeps_arbitrary_compatible_models(self): + provider = OpenAIChatAudioProvider() + + with patch.object( + settings, "AUDIO_INPUT_MODEL", "vendor-omni-audio" + ), patch.object(settings, "AUDIO_OUTPUT_MODEL", "vendor-tts-audio"): + self.assertEqual(provider._normalize_stt_model(), "vendor-omni-audio") + self.assertEqual(provider._normalize_tts_model(), "vendor-tts-audio") + + def test_chat_audio_provider_uses_openai_audio_payload_shape(self): + provider = OpenAIChatAudioProvider() + fake_client = Mock() + fake_client.chat.completions.create.return_value = SimpleNamespace( + choices=[SimpleNamespace(message=SimpleNamespace(content="你好"))] + ) + + with patch.object(provider, "_build_client", return_value=fake_client), patch.object( + settings, "AUDIO_INPUT_MODEL", "gpt-4o-audio-preview" + ), patch.object(settings, "AUDIO_INPUT_LANGUAGE", "zh"), patch.object( + settings, "AUDIO_INPUT_API_KEY", "sk-test" + ), patch.object(settings, "AUDIO_INPUT_BASE_URL", "https://example.com/v1"): + result = provider.transcribe_audio(b"audio-bytes", filename="input.wav") + + self.assertEqual(result, "你好") + request = fake_client.chat.completions.create.call_args.kwargs + content = request["messages"][0]["content"] + self.assertEqual( + content[0]["input_audio"], + {"data": b64encode(b"audio-bytes").decode("utf-8"), "format": "wav"}, + ) + + def test_chat_audio_provider_requests_audio_modality_for_tts(self): + provider = OpenAIChatAudioProvider() + fake_client = Mock() + audio_data = b64encode(b"wav-bytes").decode("utf-8") + fake_client.chat.completions.create.return_value = SimpleNamespace( + choices=[SimpleNamespace(message=SimpleNamespace(audio={"data": audio_data}))] + ) + + with TemporaryDirectory() as temp_dir, patch.object( + provider, "_build_client", return_value=fake_client + ), patch.object( + capability_module, + "settings", + SimpleNamespace( + TEMP_PATH=Path(temp_dir), + AUDIO_OUTPUT_MODEL="gpt-4o-audio-preview", + AUDIO_OUTPUT_VOICE="alloy", + AUDIO_OUTPUT_API_KEY="sk-test", + AUDIO_OUTPUT_BASE_URL="https://example.com/v1", + ), + ), patch.object(provider, "_convert_wav_to_opus", return_value=None): + output_path = provider.synthesize_speech("你好") + + self.assertIsNotNone(output_path) + request = fake_client.chat.completions.create.call_args.kwargs + self.assertEqual(request["messages"][0]["role"], "user") + self.assertEqual(request["modalities"], ["text", "audio"]) + self.assertEqual(request["audio"], {"format": "wav", "voice": "alloy"}) + + def test_audio_input_and_output_switches_are_independent(self): + provider = Mock() + provider.is_available_for_audio_input.return_value = True + provider.is_available_for_audio_output.return_value = True + + with patch.object( + settings, "LLM_SUPPORT_AUDIO_INPUT", True + ), patch.object( + settings, "LLM_SUPPORT_AUDIO_OUTPUT", False + ), patch.object( + AgentCapabilityManager, "get_audio_provider", return_value=provider + ): + self.assertTrue(AgentCapabilityManager.is_audio_input_available()) + self.assertFalse(AgentCapabilityManager.is_audio_output_available()) + + with patch.object( + settings, "LLM_SUPPORT_AUDIO_INPUT", False + ), patch.object( + settings, "LLM_SUPPORT_AUDIO_OUTPUT", True + ), patch.object( + AgentCapabilityManager, "get_audio_provider", return_value=provider + ): + self.assertFalse(AgentCapabilityManager.is_audio_input_available()) + self.assertTrue(AgentCapabilityManager.is_audio_output_available()) + + def test_transcribe_audio_routes_to_input_provider(self): + provider = Mock() + provider.is_available_for_audio_input.return_value = True + provider.transcribe_audio.return_value = "你好" + + with patch.object(settings, "LLM_SUPPORT_AUDIO_INPUT", True), patch.object( + AgentCapabilityManager, "get_audio_provider", return_value=provider + ): + result = AgentCapabilityManager.transcribe_audio(b"audio") + + self.assertEqual(result, "你好") + provider.transcribe_audio.assert_called_once() + + def test_synthesize_speech_routes_to_output_provider(self): + provider = Mock() + provider.is_available_for_audio_output.return_value = True + provider.synthesize_speech.return_value = Path("/tmp/reply.opus") + + with patch.object(settings, "LLM_SUPPORT_AUDIO_OUTPUT", True), patch.object( + AgentCapabilityManager, "get_audio_provider", return_value=provider + ): + result = AgentCapabilityManager.synthesize_speech("你好") + + self.assertEqual(result, Path("/tmp/reply.opus")) + provider.synthesize_speech.assert_called_once_with(text="你好") + + def test_mimo_tts_uses_chat_completions_audio_payload(self): + provider = MiMoAudioProvider() + fake_client = Mock() + audio_data = b64encode(b"wav-bytes").decode("utf-8") + fake_client.chat.completions.create.return_value = SimpleNamespace( + choices=[SimpleNamespace(message=SimpleNamespace(audio={"data": audio_data}))] + ) + + with TemporaryDirectory() as temp_dir, patch.object( + provider, "_build_client", return_value=fake_client + ), patch.object( + capability_module, + "settings", + SimpleNamespace( + TEMP_PATH=Path(temp_dir), + AUDIO_OUTPUT_MODEL="mimo-v2.5-tts", + AUDIO_OUTPUT_VOICE="冰糖", + AUDIO_OUTPUT_API_KEY="sk-test", + AUDIO_OUTPUT_BASE_URL="https://api.xiaomimimo.com/v1", + ), + ), patch.object(provider, "_convert_wav_to_opus", return_value=None): + output_path = provider.synthesize_speech("你好") + output_bytes = output_path.read_bytes() if output_path else None + + self.assertIsNotNone(output_path) + self.assertEqual(output_bytes, b"wav-bytes") + fake_client.chat.completions.create.assert_called_once() + request = fake_client.chat.completions.create.call_args.kwargs + self.assertEqual(request["model"], "mimo-v2.5-tts") + self.assertEqual(request["messages"][0]["role"], "assistant") + self.assertEqual(request["messages"][0]["content"], "你好") + self.assertEqual(request["audio"], {"format": "wav", "voice": "冰糖"}) + + def test_mimo_tts_rejects_voice_design_and_clone_models(self): + provider = MiMoAudioProvider() + + with patch.object( + settings, "AUDIO_OUTPUT_MODEL", "mimo-v2.5-tts-voiceclone" + ), patch.object(provider, "_build_client") as build_client: + result = provider.synthesize_speech("你好") + + self.assertIsNone(result) + build_client.assert_not_called() + + def test_mimo_stt_rejects_non_audio_mimo_models_by_falling_back(self): + provider = MiMoAudioProvider() + + with patch.object(settings, "AUDIO_INPUT_MODEL", "mimo-v2.5-pro"): + self.assertEqual(provider._normalize_stt_model(), "mimo-v2.5") + + def test_mimo_stt_uses_base64_audio_input(self): + provider = MiMoAudioProvider() + fake_client = Mock() + fake_client.chat.completions.create.return_value = SimpleNamespace( + choices=[SimpleNamespace(message=SimpleNamespace(content="你好"))] + ) + + with patch.object(provider, "_build_client", return_value=fake_client), patch.object( + settings, "AUDIO_INPUT_MODEL", "mimo-v2.5" + ), patch.object(settings, "AUDIO_INPUT_LANGUAGE", "zh"), patch.object( + settings, "AUDIO_INPUT_API_KEY", "sk-test" + ), patch.object( + settings, "AUDIO_INPUT_BASE_URL", "https://api.xiaomimimo.com/v1" + ): + result = provider.transcribe_audio(b"audio-bytes", filename="input.wav") + + self.assertEqual(result, "你好") + request = fake_client.chat.completions.create.call_args.kwargs + content = request["messages"][0]["content"] + self.assertEqual(request["model"], "mimo-v2.5") + self.assertTrue( + content[0]["input_audio"]["data"].startswith("data:audio/wav;base64,") + ) + self.assertIn("只输出转写结果", content[1]["text"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_search_ai_recommend.py b/tests/test_search_ai_recommend.py index b3d8eadb..2faad8a8 100644 --- a/tests/test_search_ai_recommend.py +++ b/tests/test_search_ai_recommend.py @@ -1,4 +1,5 @@ import asyncio +import importlib.machinery import sys import unittest from types import SimpleNamespace @@ -18,6 +19,10 @@ def _stub_module(name: str, **attrs): _stub_module("qbittorrentapi", TorrentFilesList=list) _stub_module("transmission_rpc", File=object) +_stub_module( + "psutil", + __spec__=importlib.machinery.ModuleSpec("psutil", loader=None), +) from app.agent.tools.factory import MoviePilotToolFactory from app.agent import ReplyMode diff --git a/tests/test_voice_helper.py b/tests/test_voice_helper.py deleted file mode 100644 index 832eb5a9..00000000 --- a/tests/test_voice_helper.py +++ /dev/null @@ -1,69 +0,0 @@ -import unittest -import sys -from unittest.mock import Mock, patch - -sys.modules.setdefault("psutil", Mock()) -sys.modules.setdefault("pyquery", Mock()) - -from app.core.config import settings -from app.helper.voice import VoiceHelper, OpenAIVoiceProvider - - -class VoiceHelperTest(unittest.TestCase): - def test_registered_providers_contains_openai(self): - self.assertIn("openai", VoiceHelper.get_registered_providers()) - - def test_get_provider_uses_single_audio_provider_setting(self): - with patch.object(settings, "AI_VOICE_PROVIDER", "openai"): - provider = VoiceHelper.get_provider("stt") - - self.assertIsInstance(provider, OpenAIVoiceProvider) - - def test_is_available_checks_stt_and_tts_separately(self): - provider = Mock() - provider.is_available_for_stt.return_value = True - provider.is_available_for_tts.return_value = False - - with patch.object( - settings, "LLM_SUPPORT_AUDIO_INPUT_OUTPUT", True - ), patch.object(VoiceHelper, "get_provider", return_value=provider): - self.assertTrue(VoiceHelper.is_available("stt")) - self.assertFalse(VoiceHelper.is_available("tts")) - - def test_is_available_returns_false_when_audio_switch_is_disabled(self): - provider = Mock() - provider.is_available_for_stt.return_value = True - - with patch.object( - settings, "LLM_SUPPORT_AUDIO_INPUT_OUTPUT", False - ), patch.object(VoiceHelper, "get_provider", return_value=provider): - self.assertFalse(VoiceHelper.is_available("stt")) - self.assertFalse(VoiceHelper.is_available()) - - def test_transcribe_bytes_routes_to_stt_provider(self): - provider = Mock() - provider.transcribe_bytes.return_value = "你好" - - with patch.object( - settings, "LLM_SUPPORT_AUDIO_INPUT_OUTPUT", True - ), patch.object(VoiceHelper, "get_provider", return_value=provider): - result = VoiceHelper.transcribe_bytes(b"audio") - - self.assertEqual(result, "你好") - provider.transcribe_bytes.assert_called_once() - - def test_synthesize_speech_routes_to_tts_provider(self): - provider = Mock() - provider.synthesize_speech.return_value = "/tmp/reply.opus" - - with patch.object( - settings, "LLM_SUPPORT_AUDIO_INPUT_OUTPUT", True - ), patch.object(VoiceHelper, "get_provider", return_value=provider): - result = VoiceHelper.synthesize_speech("你好") - - self.assertEqual(result, "/tmp/reply.opus") - provider.synthesize_speech.assert_called_once_with(text="你好") - - -if __name__ == "__main__": - unittest.main()