from __future__ import annotations from collections.abc import AsyncGenerator from openai import BadRequestError from openai import NOT_GIVEN from pipecat.frames.frames import ErrorFrame, Frame, TTSAudioRawFrame from pipecat.services.openai._constants import OPENAI_SAMPLE_RATE from pipecat.services.openai.llm import OpenAILLMService from pipecat.services.openai.stt import OpenAISTTService from pipecat.services.openai.tts import VALID_VOICES, OpenAITTSService from pipecat.transcriptions.language import Language from .config import AudioConfig, LLMConfig, STTConfig, TTSConfig from .xfyun_asr import DEFAULT_XFYUN_ASR_URL, XfyunASRService from .xfyun_tts import DEFAULT_XFYUN_TTS_URL, XfyunTTSService def create_stt_service(config: STTConfig, audio: AudioConfig | None = None): if config.provider == "xfyun": sample_rate = audio.sample_rate_hz if audio else 16000 return XfyunASRService( app_id=config.app_id, api_key=config.api_key or "", api_secret=config.api_secret, url=config.base_url or DEFAULT_XFYUN_ASR_URL, language=config.language or "zh_cn", domain=config.domain, accent=config.accent, sample_rate=sample_rate, encoding=config.encoding, frame_size=config.frame_size, open_timeout=config.timeout_sec, dynamic_correction=config.dynamic_correction, ) _require_provider(config.provider, "openai", "stt") return OpenAISTTService( api_key=config.api_key or None, base_url=config.base_url, settings=OpenAISTTService.Settings( model=config.model, language=_language(config.language), ), ) def create_llm_service(config: LLMConfig): _require_provider(config.provider, "openai", "llm") return OpenAILLMService( api_key=config.api_key or None, base_url=config.base_url, settings=OpenAILLMService.Settings( model=config.model, temperature=config.temperature if config.temperature is not None else NOT_GIVEN, ), ) def create_tts_service(config: TTSConfig, audio: AudioConfig): if config.provider == "xfyun": source_sample_rate = config.source_sample_rate_hz or audio.sample_rate_hz if source_sample_rate not in (8000, 16000): raise ValueError("Xfyun TTS source_sample_rate_hz must be 8000 or 16000") return XfyunTTSService( app_id=config.app_id, api_key=config.api_key or "", api_secret=config.api_secret, voice=config.voice, url=config.base_url or DEFAULT_XFYUN_TTS_URL, sample_rate=audio.sample_rate_hz, source_sample_rate=source_sample_rate, encoding=config.aue, text_encoding=config.tte, speed=config.speed, volume=config.volume, pitch=config.pitch, timeout=config.timeout_sec, ) _require_provider(config.provider, "openai", "tts") service_class = OpenAITTSService if config.voice in VALID_VOICES else OpenAICompatibleTTSService return service_class( api_key=config.api_key or None, base_url=config.base_url, sample_rate=audio.sample_rate_hz, source_sample_rate=config.source_sample_rate_hz, settings=OpenAITTSService.Settings( model=config.model, voice=config.voice, ), ) class OpenAICompatibleTTSService(OpenAITTSService): """OpenAI-compatible TTS service that permits provider-specific voice ids.""" def __init__(self, *, source_sample_rate: int | None = None, **kwargs): super().__init__(**kwargs) self._source_sample_rate = source_sample_rate or OPENAI_SAMPLE_RATE async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]: voice = self._settings.voice if not voice: yield ErrorFrame(error="TTS voice must be specified") return try: create_params = { "input": text, "model": self._settings.model, "voice": voice, "response_format": "pcm", } if self._settings.instructions: create_params["instructions"] = self._settings.instructions if self._settings.speed: create_params["speed"] = self._settings.speed async with self._client.audio.speech.with_streaming_response.create( **create_params ) as response: if response.status_code != 200: error = await response.text() yield ErrorFrame( error=f"TTS request failed (status: {response.status_code}, error: {error})" ) return await self.start_tts_usage_metrics(text) async def audio_chunks(): async for chunk in response.iter_bytes(self.chunk_size): if chunk: yield chunk first_frame = True async for frame in self._stream_audio_frames_from_iterator( audio_chunks(), in_sample_rate=self._source_sample_rate, context_id=context_id, ): if first_frame: await self.stop_ttfb_metrics() first_frame = False yield frame except BadRequestError as exc: yield ErrorFrame(error=f"TTS request failed: {exc}") except Exception as exc: yield ErrorFrame(error=f"TTS request failed: {exc}") def _require_provider(actual: str, expected: str, service_name: str) -> None: if actual != expected: raise ValueError(f"Unsupported {service_name} provider {actual!r}; expected {expected!r}") def _language(value: str | None) -> Language | None: if value is None: return None normalized = value.replace("-", "_").upper() return getattr(Language, normalized, value)