"""Default runtime service factory implementing core extension ports."""

from __future__ import annotations

from typing import Any

from loguru import logger

from runtime.ports import (
    ASRPort,
    ASRServiceSpec,
    LLMPort,
    LLMServiceSpec,
    RealtimeServiceFactory,
    TTSPort,
    TTSServiceSpec,
)
from providers.asr.buffered import BufferedASRService
from providers.asr.dashscope import DashScopeRealtimeASRService
from providers.asr.volcengine import VolcengineRealtimeASRService
from providers.tts.dashscope import DashScopeTTSService
from providers.llm.openai import MockLLMService, OpenAILLMService
from providers.asr.openai_compatible import OpenAICompatibleASRService
from providers.tts.openai_compatible import OpenAICompatibleTTSService
from providers.tts.mock import MockTTSService
from providers.tts.volcengine import VolcengineTTSService

# Provider-name aliases, matched against the normalized (stripped, lower-cased)
# ``spec.provider`` value.
_OPENAI_COMPATIBLE_PROVIDERS = {"openai_compatible", "openai-compatible", "siliconflow"}
_DASHSCOPE_PROVIDERS = {"dashscope"}
_VOLCENGINE_PROVIDERS = {"volcengine"}
# LLM providers served by OpenAILLMService (OpenAI plus OpenAI-compatible aliases).
_OPENAI_STYLE_LLM_PROVIDERS = {"openai", *_OPENAI_COMPATIBLE_PROVIDERS}
# LLM providers with a dedicated service class that needs both api_key and base_url.
# They must never fall through to OpenAILLMService.
_APP_PLATFORM_LLM_PROVIDERS = {"dify", "fastgpt"}
# Same membership as before: openai, dify, fastgpt and the OpenAI-compatible aliases.
_SUPPORTED_LLM_PROVIDERS = {*_APP_PLATFORM_LLM_PROVIDERS, *_OPENAI_STYLE_LLM_PROVIDERS}


class DefaultRealtimeServiceFactory(RealtimeServiceFactory):
    """Build concrete runtime services from normalized specs.

    Each ``create_*`` method selects an implementation from the normalized
    ``spec.provider`` name. When the provider is unrecognized or required
    credentials are missing, it returns a fallback instead of raising:
    a mock service for LLM and TTS, and a buffered service for ASR.
    """

    # Defaults applied when the corresponding spec field is empty.
    _DEFAULT_DASHSCOPE_TTS_REALTIME_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
    _DEFAULT_DASHSCOPE_TTS_MODEL = "qwen3-tts-flash-realtime"
    _DEFAULT_DASHSCOPE_ASR_REALTIME_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
    _DEFAULT_DASHSCOPE_ASR_MODEL = "qwen3-asr-flash-realtime"
    _DEFAULT_OPENAI_COMPATIBLE_TTS_MODEL = "FunAudioLLM/CosyVoice2-0.5B"
    _DEFAULT_OPENAI_COMPATIBLE_ASR_MODEL = "FunAudioLLM/SenseVoiceSmall"
    _DEFAULT_VOLCENGINE_TTS_URL = "https://openspeech.bytedance.com/api/v3/tts/unidirectional"
    _DEFAULT_VOLCENGINE_TTS_RESOURCE_ID = "seed-tts-2.0"
    _DEFAULT_VOLCENGINE_ASR_REALTIME_URL = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel"
    _DEFAULT_VOLCENGINE_ASR_MODEL = "bigmodel"

    @staticmethod
    def _normalize_provider(provider: Any) -> str:
        """Return *provider* as a stripped, lower-cased string ("" for None/empty)."""
        return str(provider or "").strip().lower()

    @staticmethod
    def _resolve_dashscope_mode(raw_mode: Any) -> str:
        """Map a raw DashScope mode value to ``"commit"`` or ``"server_commit"``.

        Empty or unrecognized values fall back to ``"commit"``.
        """
        mode = str(raw_mode or "commit").strip().lower()
        if mode in {"commit", "server_commit"}:
            return mode
        return "commit"

    def create_llm_service(self, spec: LLMServiceSpec) -> LLMPort:
        """Create the LLM service described by *spec*.

        Args:
            spec: Normalized LLM service spec.

        Returns:
            A Dify or FastGPT service when the provider matches and both
            ``api_key`` and ``base_url`` are set. For OpenAI and
            OpenAI-compatible providers with an ``api_key``, an
            ``OpenAILLMService``. Otherwise a ``MockLLMService``.
        """
        provider = self._normalize_provider(spec.provider)

        # Dify/FastGPT service modules are imported only when selected.
        if provider == "dify" and spec.api_key and spec.base_url:
            from providers.llm.dify import DifyLLMService

            return DifyLLMService(
                api_key=spec.api_key,
                base_url=spec.base_url,
                model=spec.model,
                system_prompt=spec.system_prompt,
            )

        if provider == "fastgpt" and spec.api_key and spec.base_url:
            from providers.llm.fastgpt import FastGPTLLMService

            return FastGPTLLMService(
                api_key=spec.api_key,
                base_url=spec.base_url,
                app_id=spec.app_id,
                model=spec.model,
                system_prompt=spec.system_prompt,
            )

        if provider in _OPENAI_STYLE_LLM_PROVIDERS and spec.api_key:
            return OpenAILLMService(
                api_key=spec.api_key,
                base_url=spec.base_url,
                model=spec.model,
                system_prompt=spec.system_prompt,
                knowledge_config=spec.knowledge_config,
                knowledge_searcher=spec.knowledge_searcher,
            )

        # A Dify/FastGPT key without base_url must not be routed to the OpenAI
        # service (previously Dify was, sending its key to a foreign endpoint).
        if provider in _APP_PLATFORM_LLM_PROVIDERS and spec.api_key:
            logger.warning(
                "LLM provider {} requires base_url; using mock LLM",
                provider,
            )
            return MockLLMService()

        logger.warning(
            "LLM provider unsupported or API key missing (provider={}); using mock LLM",
            provider or "-",
        )
        return MockLLMService()

    def create_tts_service(self, spec: TTSServiceSpec) -> TTSPort:
        """Create the TTS service described by *spec*.

        Args:
            spec: Normalized TTS service spec.

        Returns:
            A DashScope, Volcengine or OpenAI-compatible TTS service when the
            provider matches and ``api_key`` is set. Otherwise a
            ``MockTTSService`` at ``spec.sample_rate``.
        """
        provider = self._normalize_provider(spec.provider)

        if provider in _DASHSCOPE_PROVIDERS and spec.api_key:
            return DashScopeTTSService(
                api_key=spec.api_key,
                api_url=spec.api_url or self._DEFAULT_DASHSCOPE_TTS_REALTIME_URL,
                voice=spec.voice,
                model=spec.model or self._DEFAULT_DASHSCOPE_TTS_MODEL,
                mode=self._resolve_dashscope_mode(spec.mode),
                sample_rate=spec.sample_rate,
                speed=spec.speed,
            )

        if provider in _VOLCENGINE_PROVIDERS and spec.api_key:
            return VolcengineTTSService(
                api_key=spec.api_key,
                api_url=spec.api_url or self._DEFAULT_VOLCENGINE_TTS_URL,
                voice=spec.voice,
                model=spec.model,
                app_id=spec.app_id,
                resource_id=spec.resource_id or self._DEFAULT_VOLCENGINE_TTS_RESOURCE_ID,
                uid=spec.uid,
                sample_rate=spec.sample_rate,
                speed=spec.speed,
            )

        if provider in _OPENAI_COMPATIBLE_PROVIDERS and spec.api_key:
            # api_url is passed through unchanged; no default is applied here.
            return OpenAICompatibleTTSService(
                api_key=spec.api_key,
                api_url=spec.api_url,
                voice=spec.voice,
                model=spec.model or self._DEFAULT_OPENAI_COMPATIBLE_TTS_MODEL,
                sample_rate=spec.sample_rate,
                speed=spec.speed,
            )

        logger.warning(
            "TTS provider unsupported or API key missing (provider={}); using mock TTS",
            provider or "-",
        )
        return MockTTSService(sample_rate=spec.sample_rate)

    def create_asr_service(self, spec: ASRServiceSpec) -> ASRPort:
        """Create the ASR service described by *spec*.

        Args:
            spec: Normalized ASR service spec. ``spec.on_transcript`` is
                forwarded to every non-buffered implementation.

        Returns:
            A DashScope, Volcengine or OpenAI-compatible ASR service when the
            provider matches and ``api_key`` is set. Otherwise a
            ``BufferedASRService``.
        """
        provider = self._normalize_provider(spec.provider)

        if provider in _DASHSCOPE_PROVIDERS and spec.api_key:
            return DashScopeRealtimeASRService(
                api_key=spec.api_key,
                api_url=spec.api_url or self._DEFAULT_DASHSCOPE_ASR_REALTIME_URL,
                model=spec.model or self._DEFAULT_DASHSCOPE_ASR_MODEL,
                sample_rate=spec.sample_rate,
                language=spec.language,
                on_transcript=spec.on_transcript,
            )

        if provider in _VOLCENGINE_PROVIDERS and spec.api_key:
            # NOTE(review): unlike Volcengine TTS, no default resource_id is
            # applied here; confirm the ASR service handles None itself.
            return VolcengineRealtimeASRService(
                api_key=spec.api_key,
                api_url=spec.api_url or self._DEFAULT_VOLCENGINE_ASR_REALTIME_URL,
                model=spec.model or self._DEFAULT_VOLCENGINE_ASR_MODEL,
                sample_rate=spec.sample_rate,
                language=spec.language,
                app_id=spec.app_id,
                resource_id=spec.resource_id,
                uid=spec.uid,
                request_params=spec.request_params,
                on_transcript=spec.on_transcript,
            )

        if provider in _OPENAI_COMPATIBLE_PROVIDERS and spec.api_key:
            return OpenAICompatibleASRService(
                api_key=spec.api_key,
                api_url=spec.api_url,
                model=spec.model or self._DEFAULT_OPENAI_COMPATIBLE_ASR_MODEL,
                sample_rate=spec.sample_rate,
                language=spec.language,
                enable_interim=spec.enable_interim,
                interim_interval_ms=spec.interim_interval_ms,
                min_audio_for_interim_ms=spec.min_audio_for_interim_ms,
                on_transcript=spec.on_transcript,
            )

        logger.info("Using buffered ASR service (provider={})", provider or "-")
        return BufferedASRService(sample_rate=spec.sample_rate, language=spec.language)