From 1665ce181ae1400ad518f4bbfbe087c42cf9421b Mon Sep 17 00:00:00 2001 From: dhruvladia-sarvam Date: Tue, 3 Feb 2026 14:33:41 +0530 Subject: [PATCH] refactor(sarvam): centralize model configuration with dataclasses --- src/pipecat/services/sarvam/stt.py | 176 ++++++++++--------- src/pipecat/services/sarvam/tts.py | 266 ++++++++++++++--------------- 2 files changed, 227 insertions(+), 215 deletions(-) diff --git a/src/pipecat/services/sarvam/stt.py b/src/pipecat/services/sarvam/stt.py index ec5e42df0..799c79921 100644 --- a/src/pipecat/services/sarvam/stt.py +++ b/src/pipecat/services/sarvam/stt.py @@ -6,7 +6,8 @@ can handle multiple audio formats for Indian language speech recognition. """ import base64 -from typing import Literal, Optional +from dataclasses import dataclass +from typing import Dict, Literal, Optional from loguru import logger from pydantic import BaseModel @@ -68,6 +69,60 @@ def language_to_sarvam_language(language: Language) -> str: return resolve_language(language, LANGUAGE_MAP, use_base_code=False) +@dataclass(frozen=True) +class ModelConfig: + """Immutable configuration for a Sarvam STT model. + + Attributes: + supports_prompt: Whether the model accepts prompt parameter. + supports_mode: Whether the model accepts mode parameter. + supports_language: Whether the model accepts language parameter. + default_language: Default language code (None = auto-detect). + default_mode: Default mode (None = not applicable). + use_translate_endpoint: Whether to use speech_to_text_translate_streaming endpoint. + use_translate_method: Whether to use translate() method instead of transcribe(). + """ + + supports_prompt: bool + supports_mode: bool + supports_language: bool + default_language: Optional[str] + default_mode: Optional[str] + use_translate_endpoint: bool + use_translate_method: bool + + +MODEL_CONFIGS: Dict[str, ModelConfig] = { + "saarika:v2.5": ModelConfig( + supports_prompt=False, + supports_mode=False, + supports_language=True, + default_language="unknown", + default_mode=None, + use_translate_endpoint=False, + use_translate_method=False, + ), + "saaras:v2.5": ModelConfig( + supports_prompt=True, + supports_mode=False, + supports_language=False, + default_language=None, # Auto-detects language + default_mode=None, + use_translate_endpoint=True, + use_translate_method=True, + ), + "saaras:v3": ModelConfig( + supports_prompt=True, + supports_mode=True, + supports_language=True, + default_language="en-IN", + default_mode="transcribe", + use_translate_endpoint=False, + use_translate_method=False, + ), +} + + class SarvamSTTService(STTService): """Sarvam speech-to-text service. @@ -103,9 +158,6 @@ class SarvamSTTService(STTService): model: str = "saarika:v2.5", sample_rate: Optional[int] = None, input_audio_codec: str = "wav", - mode: Optional[ - Literal["transcribe", "translate", "verbatim", "translit", "codemix"] - ] = None, params: Optional[InputParams] = None, **kwargs, ): @@ -119,73 +171,44 @@ class SarvamSTTService(STTService): - "saaras:v3": Advanced STT model (supports mode and prompts) sample_rate: Audio sample rate. Defaults to 16000 if not specified. input_audio_codec: Audio codec/format of the input file. Defaults to "wav". - mode: Mode of operation for saaras:v3 models only. Options: transcribe, translate, - verbatim, translit, codemix. Defaults to "transcribe" for saaras:v3. params: Configuration parameters for Sarvam STT service. **kwargs: Additional arguments passed to the parent STTService. """ params = params or SarvamSTTService.InputParams() - # Allow mode to be passed directly or via params - if mode is not None and params.mode is None: - params = params.model_copy(update={"mode": mode}) - # Validate allowed models - allowed_models = {"saarika:v2.5", "saaras:v3", "saaras:v2.5"} - if model not in allowed_models: - allowed_models_list = ", ".join(sorted(allowed_models)) - raise ValueError(f"Unsupported model '{model}'. Allowed values: {allowed_models_list}.") + # Get model configuration (validates model exists) + if model not in MODEL_CONFIGS: + allowed = ", ".join(sorted(MODEL_CONFIGS.keys())) + raise ValueError(f"Unsupported model '{model}'. Allowed values: {allowed}.") - # Validate model-specific parameter restrictions - if "saarika" in model.lower(): - # saarika models don't accept prompt or mode - if params.prompt is not None: - raise ValueError( - f"Model '{model}' does not accept prompt parameter. " - "Prompts are only supported for saaras models (v2.5 and v3)." - ) - if params.mode is not None: - raise ValueError( - f"Model '{model}' does not accept mode parameter. " - "Mode is only supported for saaras:v3 model." - ) - elif model.lower() == "saaras:v2.5": - # saaras:v2.5 supports prompt but not mode - if params.mode is not None: - raise ValueError( - f"Model '{model}' does not accept mode parameter. " - "Mode is only supported for saaras:v3 model." - ) - if params.language is not None: - raise ValueError( - f"Model '{model}' does not accept language parameter. " - "saaras:v2.5 (STT-Translate) auto-detects language." - ) + self._config = MODEL_CONFIGS[model] + + # Validate parameters against model capabilities + if params.prompt is not None and not self._config.supports_prompt: + raise ValueError(f"Model '{model}' does not support prompt parameter.") + if params.mode is not None and not self._config.supports_mode: + raise ValueError(f"Model '{model}' does not support mode parameter.") + if params.language is not None and not self._config.supports_language: + raise ValueError( + f"Model '{model}' does not support language parameter (auto-detects language)." + ) super().__init__(sample_rate=sample_rate, **kwargs) self.set_model_name(model) self._api_key = api_key self._language_code: Optional[Language] = params.language - # Set language string based on model type - # - saarika:v2.5: uses language_code or defaults to "unknown" - # - saaras:v2.5: auto-detects language (no language_code needed) - # - saaras:v3: uses language_code or defaults to "en-IN" + + # Set language string: use provided language or model's default if params.language: self._language_string = language_to_sarvam_language(params.language) - elif "saarika" in model.lower(): - self._language_string = "unknown" - elif model.lower() == "saaras:v2.5": - self._language_string = None # STT-Translate auto-detects language - elif model.lower() == "saaras:v3": - self._language_string = "en-IN" else: - self._language_string = None + self._language_string = self._config.default_language + self._prompt = params.prompt - # Set mode for saaras:v3, default to "transcribe" - if model.lower() == "saaras:v3": - self._mode = params.mode if params.mode is not None else "transcribe" - else: - self._mode = None + + # Set mode: use provided mode or model's default + self._mode = params.mode if params.mode is not None else self._config.default_mode # Store connection parameters self._vad_signals = params.vad_signals @@ -250,13 +273,12 @@ class SarvamSTTService(STTService): language: The language to use for speech recognition. Raises: - ValueError: If called on saaras:v2.5 model which auto-detects language. + ValueError: If called on a model that auto-detects language. """ - # saaras:v2.5 (STT-Translate) auto-detects language - if self.model_name.lower() == "saaras:v2.5": + if not self._config.supports_language: raise ValueError( - f"Model '{self.model_name}' does not accept language parameter. " - "saaras:v2.5 (STT-Translate) auto-detects language." + f"Model '{self.model_name}' does not support language parameter " + "(auto-detects language)." ) logger.info(f"Switching STT language to: [{language}]") @@ -271,16 +293,12 @@ class SarvamSTTService(STTService): Args: prompt: Prompt text to guide transcription/translation style/context. Pass None to clear/disable prompt. - Only applicable to saaras models (v2.5 and v3). + Only applicable to models that support prompts. """ - # saarika models do not accept prompt parameter - if "saarika" in self.model_name.lower(): + if not self._config.supports_prompt: if prompt is not None: - raise ValueError( - f"Model '{self.model_name}' does not accept prompt parameter. " - "Prompts are only supported for saaras models (v2.5 and v3)." - ) - # If prompt is None and it's saarika, just silently return (no-op) + raise ValueError(f"Model '{self.model_name}' does not support prompt parameter.") + # If prompt is None and model doesn't support prompts, silently return (no-op) return logger.info(f"Updating {self.model_name} prompt.") @@ -347,12 +365,10 @@ class SarvamSTTService(STTService): "sample_rate": self.sample_rate, } - # Use appropriate method based on model type - if self.model_name.lower() == "saaras:v2.5": - # STT-Translate: auto-detects input language and returns translated text + # Use appropriate method based on model configuration + if self._config.use_translate_method: await self._socket_client.translate(**method_kwargs) else: - # saarika:v2.5 and saaras:v3 use transcribe await self._socket_client.transcribe(**method_kwargs) except Exception as e: @@ -377,12 +393,12 @@ class SarvamSTTService(STTService): "sample_rate": str(self.sample_rate), } - # Add language_code for models that require it (not saaras:v2.5 which auto-detects) + # Add language_code for models that support it if self._language_string is not None: connect_kwargs["language_code"] = self._language_string - # Add mode for saaras:v3 only - if self.model_name.lower() == "saaras:v3" and self._mode is not None: + # Add mode for models that support it + if self._config.supports_mode and self._mode is not None: connect_kwargs["mode"] = self._mode def _connect_with_sdk_headers(connect_fn, **kwargs): @@ -394,15 +410,13 @@ class SarvamSTTService(STTService): pass return connect_fn(**kwargs) - # Choose the appropriate endpoint based on model - if self.model_name.lower() == "saaras:v2.5": - # STT-Translate: auto-detects input language and returns translated text + # Choose the appropriate endpoint based on model configuration + if self._config.use_translate_endpoint: self._websocket_context = _connect_with_sdk_headers( self._sarvam_client.speech_to_text_translate_streaming.connect, **connect_kwargs, ) else: - # saarika:v2.5 and saaras:v3 use speech_to_text_streaming self._websocket_context = _connect_with_sdk_headers( self._sarvam_client.speech_to_text_streaming.connect, **connect_kwargs, @@ -411,8 +425,8 @@ class SarvamSTTService(STTService): # Enter the async context manager self._socket_client = await self._websocket_context.__aenter__() - # Set prompt if provided (only for saaras models, after connection) - if self._prompt is not None and "saaras" in self.model_name.lower(): + # Set prompt if provided (only for models that support prompts) + if self._prompt is not None and self._config.supports_prompt: await self._socket_client.set_prompt(self._prompt) # Register event handler for incoming messages diff --git a/src/pipecat/services/sarvam/tts.py b/src/pipecat/services/sarvam/tts.py index bd63558a8..b6819185a 100644 --- a/src/pipecat/services/sarvam/tts.py +++ b/src/pipecat/services/sarvam/tts.py @@ -31,8 +31,9 @@ See https://docs.sarvam.ai/api-reference-docs/text-to-speech/stream for full API import asyncio import base64 import json +from dataclasses import dataclass from enum import Enum -from typing import Any, AsyncGenerator, List, Mapping, Optional +from typing import Any, AsyncGenerator, Dict, List, Mapping, Optional, Tuple import aiohttp from loguru import logger @@ -133,16 +134,52 @@ class SarvamTTSSpeakerV3(str, Enum): SOPHIA = "sophia" -# Default sample rates per model -SARVAM_DEFAULT_SAMPLE_RATES = { - SarvamTTSModel.BULBUL_V2: 22050, - SarvamTTSModel.BULBUL_V3_BETA: 24000, -} +@dataclass(frozen=True) +class TTSModelConfig: + """Immutable configuration for a Sarvam TTS model. -# Default speakers per model -SARVAM_DEFAULT_SPEAKERS = { - SarvamTTSModel.BULBUL_V2: SarvamTTSSpeakerV2.ANUSHKA.value, - SarvamTTSModel.BULBUL_V3_BETA: SarvamTTSSpeakerV3.ADITYA.value, + Attributes: + supports_pitch: Whether the model accepts pitch parameter. + supports_loudness: Whether the model accepts loudness parameter. + supports_temperature: Whether the model accepts temperature parameter. + default_sample_rate: Default audio sample rate in Hz. + default_speaker: Default speaker voice ID. + pace_range: Valid range for pace parameter (min, max). + preprocessing_always_enabled: Whether preprocessing is always enabled. + speakers: Tuple of available speaker names for this model. + """ + + supports_pitch: bool + supports_loudness: bool + supports_temperature: bool + default_sample_rate: int + default_speaker: str + pace_range: Tuple[float, float] + preprocessing_always_enabled: bool + speakers: Tuple[str, ...] + + +TTS_MODEL_CONFIGS: Dict[str, TTSModelConfig] = { + "bulbul:v2": TTSModelConfig( + supports_pitch=True, + supports_loudness=True, + supports_temperature=False, + default_sample_rate=22050, + default_speaker="anushka", + pace_range=(0.3, 3.0), + preprocessing_always_enabled=False, + speakers=tuple(s.value for s in SarvamTTSSpeakerV2), + ), + "bulbul:v3-beta": TTSModelConfig( + supports_pitch=False, + supports_loudness=False, + supports_temperature=True, + default_sample_rate=24000, + default_speaker="aditya", + pace_range=(0.5, 2.0), + preprocessing_always_enabled=True, + speakers=tuple(s.value for s in SarvamTTSSpeakerV3), + ), } @@ -155,9 +192,10 @@ def get_speakers_for_model(model: str) -> List[str]: Returns: List of speaker names available for the model. """ - if model in (SarvamTTSModel.BULBUL_V3_BETA.value): - return [s.value for s in SarvamTTSSpeakerV3] - return [s.value for s in SarvamTTSSpeakerV2] + if model in TTS_MODEL_CONFIGS: + return list(TTS_MODEL_CONFIGS[model].speakers) + # Default to v2 speakers for unknown models + return list(TTS_MODEL_CONFIGS["bulbul:v2"].speakers) def language_to_sarvam_language(language: Language) -> Optional[str]: @@ -304,35 +342,26 @@ class SarvamHttpTTSService(TTSService): Args: api_key: Sarvam AI API subscription key. aiohttp_session: Shared aiohttp session for making requests. - voice_id: Speaker voice ID. If None, uses model-appropriate default: - - bulbul:v2: "anushka" - - bulbul:v3-beta/v3: "aditya" + voice_id: Speaker voice ID. If None, uses model-appropriate default. model: TTS model to use. Options: - "bulbul:v2" (default): Standard model with pitch/loudness support - "bulbul:v3-beta": Advanced model with temperature control - - "bulbul:v3": Alias for v3-beta base_url: Sarvam AI API base URL. Defaults to "https://api.sarvam.ai". sample_rate: Audio sample rate in Hz (8000, 16000, 22050, 24000). - If None, uses model-specific default (v2: 22050, v3: 24000). + If None, uses model-specific default. params: Additional voice and preprocessing parameters. If None, uses defaults. **kwargs: Additional arguments passed to parent TTSService. - - Note: - When using bulbul:v3-beta: - - pitch and loudness parameters are ignored - - pace range is limited to 0.5-2.0 - - preprocessing is always enabled - - use SarvamTTSSpeakerV3 speakers (e.g., "aditya", "ritu") """ - # Determine if using v3 model - is_v3_model = model in ( - SarvamTTSModel.BULBUL_V3_BETA.value, - "bulbul:v3-beta", - ) + # Get model configuration (validates model exists) + if model not in TTS_MODEL_CONFIGS: + allowed = ", ".join(sorted(TTS_MODEL_CONFIGS.keys())) + raise ValueError(f"Unsupported model '{model}'. Allowed values: {allowed}.") + + self._config = TTS_MODEL_CONFIGS[model] # Set default sample rate based on model if not specified if sample_rate is None: - sample_rate = 24000 if is_v3_model else 22050 + sample_rate = self._config.default_sample_rate super().__init__(sample_rate=sample_rate, **kwargs) @@ -340,55 +369,46 @@ class SarvamHttpTTSService(TTSService): # Set default voice based on model if not specified if voice_id is None: - voice_id = "aditya" if is_v3_model else "anushka" + voice_id = self._config.default_speaker self._api_key = api_key self._base_url = base_url self._session = aiohttp_session - self._is_v3_model = is_v3_model - # Build base settings common to all models + # Validate and clamp pace to model's valid range + pace = params.pace + pace_min, pace_max = self._config.pace_range + if pace is not None and (pace < pace_min or pace > pace_max): + logger.warning(f"Pace {pace} is outside model range ({pace_min}-{pace_max}). Clamping.") + pace = max(pace_min, min(pace_max, pace)) + + # Build base settings self._settings = { "language": ( self.language_to_service_language(params.language) if params.language else "en-IN" ), - "enable_preprocessing": params.enable_preprocessing if not is_v3_model else True, + "enable_preprocessing": ( + True if self._config.preprocessing_always_enabled else params.enable_preprocessing + ), + "pace": pace, + "model": model, } - # Add model-specific parameters - if is_v3_model: - # Validate pace for v3 (0.5-2.0) - pace = params.pace - if pace is not None and (pace < 0.5 or pace > 2.0): - logger.warning( - f"Pace {pace} is outside v3 model range (0.5-2.0). Clamping to valid range." - ) - pace = max(0.5, min(2.0, pace)) + # Add parameters based on model support + if self._config.supports_pitch: + self._settings["pitch"] = params.pitch + elif params.pitch != 0.0: + logger.warning(f"pitch parameter is ignored for {model}") - self._settings.update( - { - "temperature": params.temperature, - "pace": pace, - "model": model, - } - ) - # Log warning if v2-only parameters are set - if params.pitch != 0.0: - logger.warning(f"pitch parameter is ignored for {model}") - if params.loudness != 1.0: - logger.warning(f"loudness parameter is ignored for {model}") - else: - self._settings.update( - { - "pitch": params.pitch, - "pace": params.pace, - "loudness": params.loudness, - "model": model, - } - ) - # Log warning if v3-only parameters are set - if params.temperature != 0.6: - logger.warning(f"temperature parameter is ignored for {model}") + if self._config.supports_loudness: + self._settings["loudness"] = params.loudness + elif params.loudness != 1.0: + logger.warning(f"loudness parameter is ignored for {model}") + + if self._config.supports_temperature: + self._settings["temperature"] = params.temperature + elif params.temperature != 0.6: + logger.warning(f"temperature parameter is ignored for {model}") self.set_model_name(model) self.set_voice(voice_id) @@ -436,7 +456,7 @@ class SarvamHttpTTSService(TTSService): try: await self.start_ttfb_metrics() - # Build payload based on model type + # Build payload with common parameters payload = { "text": text, "target_language_code": self._settings["language"], @@ -444,19 +464,16 @@ class SarvamHttpTTSService(TTSService): "sample_rate": self.sample_rate, "enable_preprocessing": self._settings["enable_preprocessing"], "model": self._model_name, + "pace": self._settings.get("pace", 1.0), } - # Add model-specific parameters - if self._is_v3_model: - # v3 models use temperature and pace (0.5-2.0) - payload["temperature"] = self._settings.get("temperature", 0.6) - if "pace" in self._settings: - payload["pace"] = self._settings["pace"] - else: - # v2 models use pitch, pace, loudness + # Add model-specific parameters based on config + if self._config.supports_pitch: payload["pitch"] = self._settings.get("pitch", 0.0) - payload["pace"] = self._settings.get("pace", 1.0) + if self._config.supports_loudness: payload["loudness"] = self._settings.get("loudness", 1.0) + if self._config.supports_temperature: + payload["temperature"] = self._settings.get("temperature", 0.6) headers = { "api-subscription-key": self._api_key, @@ -669,35 +686,26 @@ class SarvamTTSService(InterruptibleTTSService): model: TTS model to use. Options: - "bulbul:v2" (default): Standard model with pitch/loudness support - "bulbul:v3-beta": Advanced model with temperature control - - "bulbul:v3": Alias for v3-beta - voice_id: Speaker voice ID. If None, uses model-appropriate default: - - bulbul:v2: "anushka" - - bulbul:v3-beta/v3: "aditya" + voice_id: Speaker voice ID. If None, uses model-appropriate default. url: WebSocket URL for the TTS backend (default production URL). aggregate_sentences: Merge multiple sentences into one audio chunk (default True). sample_rate: Output audio sample rate in Hz (8000, 16000, 22050, 24000). - If None, uses model-specific default (v2: 22050, v3: 24000). + If None, uses model-specific default. params: Optional input parameters to override defaults. **kwargs: Arguments forwarded to InterruptibleTTSService. - Note: - When using bulbul:v3-beta: - - pitch and loudness parameters are ignored - - pace range is limited to 0.5-2.0 - - preprocessing is always enabled - - use SarvamTTSSpeakerV3 speakers (e.g., "aditya", "ritu") - See https://docs.sarvam.ai/api-reference-docs/text-to-speech/stream """ - # Determine if using v3 model - is_v3_model = model in ( - SarvamTTSModel.BULBUL_V3_BETA.value, - "bulbul:v3-beta", - ) + # Get model configuration (validates model exists) + if model not in TTS_MODEL_CONFIGS: + allowed = ", ".join(sorted(TTS_MODEL_CONFIGS.keys())) + raise ValueError(f"Unsupported model '{model}'. Allowed values: {allowed}.") + + self._config = TTS_MODEL_CONFIGS[model] # Set default sample rate based on model if not specified if sample_rate is None: - sample_rate = 24000 if is_v3_model else 22050 + sample_rate = self._config.default_sample_rate # Initialize parent class first super().__init__( @@ -712,9 +720,7 @@ class SarvamTTSService(InterruptibleTTSService): # Set default voice based on model if not specified if voice_id is None: - voice_id = "aditya" if is_v3_model else "anushka" - - self._is_v3_model = is_v3_model + voice_id = self._config.default_speaker # WebSocket endpoint URL with model query parameter self._websocket_url = f"{url}?model={model}" @@ -722,54 +728,46 @@ class SarvamTTSService(InterruptibleTTSService): self.set_model_name(model) self.set_voice(voice_id) - # Build base settings common to all models + # Validate and clamp pace to model's valid range + pace = params.pace + pace_min, pace_max = self._config.pace_range + if pace is not None and (pace < pace_min or pace > pace_max): + logger.warning(f"Pace {pace} is outside model range ({pace_min}-{pace_max}). Clamping.") + pace = max(pace_min, min(pace_max, pace)) + + # Build base settings self._settings = { "target_language_code": ( self.language_to_service_language(params.language) if params.language else "en-IN" ), "speaker": voice_id, "speech_sample_rate": str(sample_rate), - "enable_preprocessing": params.enable_preprocessing if not is_v3_model else True, + "enable_preprocessing": ( + True if self._config.preprocessing_always_enabled else params.enable_preprocessing + ), "min_buffer_size": params.min_buffer_size, "max_chunk_length": params.max_chunk_length, "output_audio_codec": params.output_audio_codec, "output_audio_bitrate": params.output_audio_bitrate, + "pace": pace, + "model": model, } - # Add model-specific parameters - if is_v3_model: - # Validate pace for v3 (0.5-2.0) - pace = params.pace - if pace is not None and (pace < 0.5 or pace > 2.0): - logger.warning( - f"Pace {pace} is outside v3 model range (0.5-2.0). Clamping to valid range." - ) - pace = max(0.5, min(2.0, pace)) + # Add parameters based on model support + if self._config.supports_pitch: + self._settings["pitch"] = params.pitch + elif params.pitch != 0.0: + logger.warning(f"pitch parameter is ignored for {model}") - self._settings.update( - { - "temperature": params.temperature, - "pace": pace, - "model": model, - } - ) - # Log warning if v2-only parameters are set - if params.pitch != 0.0: - logger.warning(f"pitch parameter is ignored for {model}") - if params.loudness != 1.0: - logger.warning(f"loudness parameter is ignored for {model}") - else: - self._settings.update( - { - "pitch": params.pitch, - "pace": params.pace, - "loudness": params.loudness, - "model": model, - } - ) - # Log warning if v3-only parameters are set - if params.temperature != 0.6: - logger.warning(f"temperature parameter is ignored for {model}") + if self._config.supports_loudness: + self._settings["loudness"] = params.loudness + elif params.loudness != 1.0: + logger.warning(f"loudness parameter is ignored for {model}") + + if self._config.supports_temperature: + self._settings["temperature"] = params.temperature + elif params.temperature != 0.6: + logger.warning(f"temperature parameter is ignored for {model}") self._started = False self._receive_task = None