"""Default runtime service factory implementing core extension ports."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Any
|
|
|
|
from loguru import logger
|
|
|
|
from runtime.ports import (
|
|
ASRPort,
|
|
ASRServiceSpec,
|
|
LLMPort,
|
|
LLMServiceSpec,
|
|
RealtimeServiceFactory,
|
|
TTSPort,
|
|
TTSServiceSpec,
|
|
)
|
|
from providers.asr.buffered import BufferedASRService
|
|
from providers.asr.dashscope import DashScopeRealtimeASRService
|
|
from providers.asr.volcengine import VolcengineRealtimeASRService
|
|
from providers.tts.dashscope import DashScopeTTSService
|
|
from providers.llm.openai import MockLLMService, OpenAILLMService
|
|
from providers.asr.openai_compatible import OpenAICompatibleASRService
|
|
from providers.tts.openai_compatible import OpenAICompatibleTTSService
|
|
from providers.tts.mock import MockTTSService
|
|
from providers.tts.volcengine import VolcengineTTSService
|
|
|
|
# Provider aliases accepted for each backend family. The factory compares these
# against the output of ``DefaultRealtimeServiceFactory._normalize_provider``
# (stripped, lower-cased). ``frozenset`` keeps the module-level constants
# immutable so no caller can accidentally change provider routing at runtime.
_OPENAI_COMPATIBLE_PROVIDERS = frozenset(
    {"openai_compatible", "openai-compatible", "siliconflow"}
)
_DASHSCOPE_PROVIDERS = frozenset({"dashscope"})
_VOLCENGINE_PROVIDERS = frozenset({"volcengine"})

# LLM providers the factory knows how to build. "fastgpt" has a dedicated
# branch in ``create_llm_service``; the rest are served by ``OpenAILLMService``.
_SUPPORTED_LLM_PROVIDERS = frozenset(
    {"openai", "fastgpt", *_OPENAI_COMPATIBLE_PROVIDERS}
)
|
|
|
|
|
|
class DefaultRealtimeServiceFactory(RealtimeServiceFactory):
    """Build concrete runtime services from normalized specs.

    Each ``create_*`` method normalizes ``spec.provider`` (strip + lower-case)
    and returns the matching provider implementation when the credentials it
    needs are present. Otherwise it logs why and falls back to a local
    implementation: ``MockLLMService``, ``MockTTSService`` or
    ``BufferedASRService``.
    """

    # Fallback values used when the corresponding spec field is empty/falsy.
    _DEFAULT_DASHSCOPE_TTS_REALTIME_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
    _DEFAULT_DASHSCOPE_TTS_MODEL = "qwen3-tts-flash-realtime"
    _DEFAULT_DASHSCOPE_ASR_REALTIME_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
    _DEFAULT_DASHSCOPE_ASR_MODEL = "qwen3-asr-flash-realtime"
    _DEFAULT_OPENAI_COMPATIBLE_TTS_MODEL = "FunAudioLLM/CosyVoice2-0.5B"
    _DEFAULT_OPENAI_COMPATIBLE_ASR_MODEL = "FunAudioLLM/SenseVoiceSmall"
    _DEFAULT_VOLCENGINE_TTS_URL = "https://openspeech.bytedance.com/api/v3/tts/unidirectional"
    _DEFAULT_VOLCENGINE_TTS_RESOURCE_ID = "seed-tts-2.0"
    _DEFAULT_VOLCENGINE_ASR_REALTIME_URL = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel"
    _DEFAULT_VOLCENGINE_ASR_MODEL = "bigmodel"

    @staticmethod
    def _normalize_provider(provider: Any) -> str:
        """Return ``provider`` as a stripped, lower-cased string ("" if falsy)."""
        return str(provider or "").strip().lower()

    @staticmethod
    def _resolve_dashscope_mode(raw_mode: Any) -> str:
        """Return the DashScope mode to use: ``"commit"`` or ``"server_commit"``.

        Empty values default to ``"commit"``. Any other unrecognized value also
        falls back to ``"commit"``, but is logged so a misconfiguration is
        visible instead of being silently ignored.
        """
        mode = str(raw_mode or "").strip().lower()
        if not mode:
            return "commit"
        if mode in {"commit", "server_commit"}:
            return mode
        logger.warning(
            "Unsupported DashScope mode {!r}; falling back to 'commit'",
            raw_mode,
        )
        return "commit"

    def create_llm_service(self, spec: LLMServiceSpec) -> LLMPort:
        """Build the LLM service described by ``spec``.

        * ``fastgpt``: ``FastGPTLLMService``; requires ``api_key`` and ``base_url``.
        * any other provider in ``_SUPPORTED_LLM_PROVIDERS`` (``openai`` and the
          OpenAI-compatible aliases): ``OpenAILLMService``; requires ``api_key``.
        * anything else, or missing credentials: ``MockLLMService`` plus a warning.
        """
        provider = self._normalize_provider(spec.provider)

        if provider == "fastgpt":
            if spec.api_key and spec.base_url:
                # Local import: the FastGPT module is only loaded when selected.
                from providers.llm.fastgpt import FastGPTLLMService

                return FastGPTLLMService(
                    api_key=spec.api_key,
                    base_url=spec.base_url,
                    app_id=spec.app_id,
                    model=spec.model,
                    system_prompt=spec.system_prompt,
                )
            # Report the actual reason rather than the generic "unsupported" text.
            logger.warning("FastGPT LLM requires both api_key and base_url; using mock LLM")
            return MockLLMService()

        if provider in _SUPPORTED_LLM_PROVIDERS and spec.api_key:
            return OpenAILLMService(
                api_key=spec.api_key,
                base_url=spec.base_url,
                model=spec.model,
                system_prompt=spec.system_prompt,
                knowledge_config=spec.knowledge_config,
                knowledge_searcher=spec.knowledge_searcher,
            )

        logger.warning(
            "LLM provider unsupported or API key missing (provider={}); using mock LLM",
            provider or "-",
        )
        return MockLLMService()

    def create_tts_service(self, spec: TTSServiceSpec) -> TTSPort:
        """Build the TTS service described by ``spec``.

        DashScope, Volcengine and OpenAI-compatible providers each require
        ``spec.api_key``; empty URL/model/resource-id fields are filled from the
        ``_DEFAULT_*`` class constants where one is defined. Anything else falls
        back to ``MockTTSService`` with a warning.
        """
        provider = self._normalize_provider(spec.provider)

        # Use the shared alias set, as create_asr_service does.
        if provider in _DASHSCOPE_PROVIDERS and spec.api_key:
            return DashScopeTTSService(
                api_key=spec.api_key,
                api_url=spec.api_url or self._DEFAULT_DASHSCOPE_TTS_REALTIME_URL,
                voice=spec.voice,
                model=spec.model or self._DEFAULT_DASHSCOPE_TTS_MODEL,
                mode=self._resolve_dashscope_mode(spec.mode),
                sample_rate=spec.sample_rate,
                speed=spec.speed,
            )

        if provider in _VOLCENGINE_PROVIDERS and spec.api_key:
            return VolcengineTTSService(
                api_key=spec.api_key,
                api_url=spec.api_url or self._DEFAULT_VOLCENGINE_TTS_URL,
                voice=spec.voice,
                model=spec.model,
                app_id=spec.app_id,
                resource_id=spec.resource_id or self._DEFAULT_VOLCENGINE_TTS_RESOURCE_ID,
                uid=spec.uid,
                sample_rate=spec.sample_rate,
                speed=spec.speed,
            )

        if provider in _OPENAI_COMPATIBLE_PROVIDERS and spec.api_key:
            return OpenAICompatibleTTSService(
                api_key=spec.api_key,
                api_url=spec.api_url,
                voice=spec.voice,
                model=spec.model or self._DEFAULT_OPENAI_COMPATIBLE_TTS_MODEL,
                sample_rate=spec.sample_rate,
                speed=spec.speed,
            )

        logger.warning(
            "TTS provider unsupported or API key missing (provider={}); using mock TTS",
            provider or "-",
        )
        return MockTTSService(sample_rate=spec.sample_rate)

    def create_asr_service(self, spec: ASRServiceSpec) -> ASRPort:
        """Build the ASR service described by ``spec``.

        DashScope, Volcengine and OpenAI-compatible providers each require
        ``spec.api_key``; empty URL/model fields are filled from the ``_DEFAULT_*``
        class constants where one is defined. Otherwise ``BufferedASRService`` is
        returned. This is logged as a warning when a known provider was requested
        without an API key, and at info level otherwise.
        """
        provider = self._normalize_provider(spec.provider)

        if provider in _DASHSCOPE_PROVIDERS and spec.api_key:
            return DashScopeRealtimeASRService(
                api_key=spec.api_key,
                api_url=spec.api_url or self._DEFAULT_DASHSCOPE_ASR_REALTIME_URL,
                model=spec.model or self._DEFAULT_DASHSCOPE_ASR_MODEL,
                sample_rate=spec.sample_rate,
                language=spec.language,
                on_transcript=spec.on_transcript,
            )

        if provider in _VOLCENGINE_PROVIDERS and spec.api_key:
            return VolcengineRealtimeASRService(
                api_key=spec.api_key,
                api_url=spec.api_url or self._DEFAULT_VOLCENGINE_ASR_REALTIME_URL,
                model=spec.model or self._DEFAULT_VOLCENGINE_ASR_MODEL,
                sample_rate=spec.sample_rate,
                language=spec.language,
                app_id=spec.app_id,
                resource_id=spec.resource_id,
                uid=spec.uid,
                request_params=spec.request_params,
                on_transcript=spec.on_transcript,
            )

        if provider in _OPENAI_COMPATIBLE_PROVIDERS and spec.api_key:
            return OpenAICompatibleASRService(
                api_key=spec.api_key,
                api_url=spec.api_url,
                model=spec.model or self._DEFAULT_OPENAI_COMPATIBLE_ASR_MODEL,
                sample_rate=spec.sample_rate,
                language=spec.language,
                enable_interim=spec.enable_interim,
                interim_interval_ms=spec.interim_interval_ms,
                min_audio_for_interim_ms=spec.min_audio_for_interim_ms,
                on_transcript=spec.on_transcript,
            )

        known_providers = _DASHSCOPE_PROVIDERS | _VOLCENGINE_PROVIDERS | _OPENAI_COMPATIBLE_PROVIDERS
        if provider in known_providers:
            # Every provider branch above requires an API key, so reaching this
            # point with a known provider means the key was missing.
            logger.warning(
                "ASR provider {} requires an API key; using buffered ASR service",
                provider,
            )
        else:
            logger.info("Using buffered ASR service (provider={})", provider or "-")
        return BufferedASRService(sample_rate=spec.sample_rate, language=spec.language)
|