"""Default runtime service factory implementing core extension ports."""

from __future__ import annotations

from typing import Any

from loguru import logger

from core.ports import (
    ASRPort,
    ASRServiceSpec,
    LLMPort,
    LLMServiceSpec,
    RealtimeServiceFactory,
    TTSPort,
    TTSServiceSpec,
)
from services.asr import BufferedASRService
from services.dashscope_tts import DashScopeTTSService
from services.llm import MockLLMService, OpenAILLMService
from services.openai_compatible_asr import OpenAICompatibleASRService
from services.openai_compatible_tts import OpenAICompatibleTTSService
from services.tts import MockTTSService

# Provider identifiers that are all served by the OpenAI-compatible clients.
_OPENAI_COMPATIBLE_PROVIDERS = {"openai_compatible", "openai-compatible", "siliconflow"}
# Every provider the LLM factory knows how to construct a real client for.
_SUPPORTED_LLM_PROVIDERS = {"openai", *_OPENAI_COMPATIBLE_PROVIDERS}


class DefaultRealtimeServiceFactory(RealtimeServiceFactory):
    """Build concrete runtime services from normalized specs."""

    # Fallback endpoints/models used when a spec leaves the field unset.
    _DEFAULT_DASHSCOPE_TTS_REALTIME_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
    _DEFAULT_DASHSCOPE_TTS_MODEL = "qwen3-tts-flash-realtime"
    _DEFAULT_OPENAI_COMPATIBLE_TTS_MODEL = "FunAudioLLM/CosyVoice2-0.5B"
    _DEFAULT_OPENAI_COMPATIBLE_ASR_MODEL = "FunAudioLLM/SenseVoiceSmall"

    @staticmethod
    def _normalize_provider(provider: Any) -> str:
        """Coerce an arbitrary provider value to a trimmed, lowercase key ("" for falsy)."""
        return str(provider or "").strip().lower()

    @staticmethod
    def _resolve_dashscope_mode(raw_mode: Any) -> str:
        """Clamp a DashScope streaming mode to the supported set, defaulting to "commit"."""
        candidate = str(raw_mode or "commit").strip().lower()
        return candidate if candidate in {"commit", "server_commit"} else "commit"

    def create_llm_service(self, spec: LLMServiceSpec) -> LLMPort:
        """Return a real OpenAI-backed LLM when possible, otherwise a mock fallback.

        A real client requires both a supported provider name and a non-empty
        API key; anything else degrades to :class:`MockLLMService` with a warning.
        """
        provider = self._normalize_provider(spec.provider)
        if provider not in _SUPPORTED_LLM_PROVIDERS or not spec.api_key:
            logger.warning(
                "LLM provider unsupported or API key missing (provider={}); using mock LLM",
                provider or "-",
            )
            return MockLLMService()
        return OpenAILLMService(
            api_key=spec.api_key,
            base_url=spec.base_url,
            model=spec.model,
            system_prompt=spec.system_prompt,
            knowledge_config=spec.knowledge_config,
            knowledge_searcher=spec.knowledge_searcher,
        )

    def create_tts_service(self, spec: TTSServiceSpec) -> TTSPort:
        """Build a TTS service for the spec's provider, falling back to a mock.

        DashScope and the OpenAI-compatible family each get their dedicated
        client when an API key is present; every other case yields
        :class:`MockTTSService` at the requested sample rate.
        """
        provider = self._normalize_provider(spec.provider)
        has_key = bool(spec.api_key)

        if has_key and provider == "dashscope":
            return DashScopeTTSService(
                api_key=spec.api_key,
                api_url=spec.api_url or self._DEFAULT_DASHSCOPE_TTS_REALTIME_URL,
                voice=spec.voice,
                model=spec.model or self._DEFAULT_DASHSCOPE_TTS_MODEL,
                mode=self._resolve_dashscope_mode(spec.mode),
                sample_rate=spec.sample_rate,
                speed=spec.speed,
            )

        if has_key and provider in _OPENAI_COMPATIBLE_PROVIDERS:
            return OpenAICompatibleTTSService(
                api_key=spec.api_key,
                api_url=spec.api_url,
                voice=spec.voice,
                model=spec.model or self._DEFAULT_OPENAI_COMPATIBLE_TTS_MODEL,
                sample_rate=spec.sample_rate,
                speed=spec.speed,
            )

        logger.warning(
            "TTS provider unsupported or API key missing (provider={}); using mock TTS",
            provider or "-",
        )
        return MockTTSService(sample_rate=spec.sample_rate)

    def create_asr_service(self, spec: ASRServiceSpec) -> ASRPort:
        """Build an ASR service: OpenAI-compatible streaming client or buffered fallback."""
        provider = self._normalize_provider(spec.provider)
        if spec.api_key and provider in _OPENAI_COMPATIBLE_PROVIDERS:
            return OpenAICompatibleASRService(
                api_key=spec.api_key,
                api_url=spec.api_url,
                model=spec.model or self._DEFAULT_OPENAI_COMPATIBLE_ASR_MODEL,
                sample_rate=spec.sample_rate,
                language=spec.language,
                interim_interval_ms=spec.interim_interval_ms,
                min_audio_for_interim_ms=spec.min_audio_for_interim_ms,
                on_transcript=spec.on_transcript,
            )
        # Buffered ASR needs no credentials; it is the universal fallback.
        logger.info("Using buffered ASR service (provider={})", provider or "-")
        return BufferedASRService(sample_rate=spec.sample_rate, language=spec.language)