diff --git a/src/pipecat/services/sarvam/stt.py b/src/pipecat/services/sarvam/stt.py index 164ad289e..799c79921 100644 --- a/src/pipecat/services/sarvam/stt.py +++ b/src/pipecat/services/sarvam/stt.py @@ -6,7 +6,8 @@ can handle multiple audio formats for Indian language speech recognition. """ import base64 -from typing import Optional +from dataclasses import dataclass +from typing import Dict, Literal, Optional from loguru import logger from pydantic import BaseModel @@ -68,6 +69,60 @@ def language_to_sarvam_language(language: Language) -> str: return resolve_language(language, LANGUAGE_MAP, use_base_code=False) +@dataclass(frozen=True) +class ModelConfig: + """Immutable configuration for a Sarvam STT model. + + Attributes: + supports_prompt: Whether the model accepts prompt parameter. + supports_mode: Whether the model accepts mode parameter. + supports_language: Whether the model accepts language parameter. + default_language: Default language code (None = auto-detect). + default_mode: Default mode (None = not applicable). + use_translate_endpoint: Whether to use speech_to_text_translate_streaming endpoint. + use_translate_method: Whether to use translate() method instead of transcribe(). + """ + + supports_prompt: bool + supports_mode: bool + supports_language: bool + default_language: Optional[str] + default_mode: Optional[str] + use_translate_endpoint: bool + use_translate_method: bool + + +MODEL_CONFIGS: Dict[str, ModelConfig] = { + "saarika:v2.5": ModelConfig( + supports_prompt=False, + supports_mode=False, + supports_language=True, + default_language="unknown", + default_mode=None, + use_translate_endpoint=False, + use_translate_method=False, + ), + "saaras:v2.5": ModelConfig( + supports_prompt=True, + supports_mode=False, + supports_language=False, + default_language=None, # Auto-detects language + default_mode=None, + use_translate_endpoint=True, + use_translate_method=True, + ), + "saaras:v3": ModelConfig( + supports_prompt=True, + supports_mode=True, + supports_language=True, + default_language="en-IN", + default_mode="transcribe", + use_translate_endpoint=False, + use_translate_method=False, + ), +} + + class SarvamSTTService(STTService): """Sarvam speech-to-text service. @@ -78,15 +133,21 @@ class SarvamSTTService(STTService): """Configuration parameters for Sarvam STT service. Parameters: - language: Target language for transcription. Defaults to None (required for saarika models). - prompt: Optional prompt to guide translation style/context for STT-Translate models. - Only applicable to saaras (STT-Translate) models. Defaults to None. + language: Target language for transcription. + - saarika:v2.5: Defaults to "unknown" (auto-detect supported) + - saaras:v2.5: Not used (auto-detects language) + - saaras:v3: Defaults to "en-IN" + prompt: Optional prompt to guide transcription/translation style/context. + Only applicable to saaras models (v2.5 and v3). Defaults to None. + mode: Mode of operation for saaras:v3 models only. Options: transcribe, translate, + verbatim, translit, codemix. Defaults to "transcribe" for saaras:v3. vad_signals: Enable VAD signals in response. Defaults to None. high_vad_sensitivity: Enable high VAD (Voice Activity Detection) sensitivity. Defaults to None. """ language: Optional[Language] = None prompt: Optional[str] = None + mode: Optional[Literal["transcribe", "translate", "verbatim", "translit", "codemix"]] = None vad_signals: bool = None high_vad_sensitivity: bool = None @@ -104,7 +165,10 @@ class SarvamSTTService(STTService): Args: api_key: Sarvam API key for authentication. - model: Sarvam model to use for transcription. + model: Sarvam model to use for transcription. Allowed values: + - "saarika:v2.5": Standard STT model + - "saaras:v2.5": STT-Translate model (auto-detects language, supports prompts) + - "saaras:v3": Advanced STT model (supports mode and prompts) sample_rate: Audio sample rate. Defaults to 16000 if not specified. input_audio_codec: Audio codec/format of the input file. Defaults to "wav". params: Configuration parameters for Sarvam STT service. @@ -112,36 +176,40 @@ class SarvamSTTService(STTService): """ params = params or SarvamSTTService.InputParams() - # Validate that saaras models don't accept language parameter - if "saaras" in model.lower(): - if params.language is not None: - raise ValueError( - f"Model '{model}' does not accept language parameter. " - "STT-Translate models auto-detect language." - ) + # Get model configuration (validates model exists) + if model not in MODEL_CONFIGS: + allowed = ", ".join(sorted(MODEL_CONFIGS.keys())) + raise ValueError(f"Unsupported model '{model}'. Allowed values: {allowed}.") - # Validate that saarika models don't accept prompt parameter - if "saarika" in model.lower(): - if params.prompt is not None: - raise ValueError( - f"Model '{model}' does not accept prompt parameter. " - "Prompts are only supported for STT-Translate models" - ) + self._config = MODEL_CONFIGS[model] + + # Validate parameters against model capabilities + if params.prompt is not None and not self._config.supports_prompt: + raise ValueError(f"Model '{model}' does not support prompt parameter.") + if params.mode is not None and not self._config.supports_mode: + raise ValueError(f"Model '{model}' does not support mode parameter.") + if params.language is not None and not self._config.supports_language: + raise ValueError( + f"Model '{model}' does not support language parameter (auto-detects language)." + ) super().__init__(sample_rate=sample_rate, **kwargs) self.set_model_name(model) self._api_key = api_key self._language_code: Optional[Language] = params.language - # For saarika models, default to "unknown" if language is not provided + + # Set language string: use provided language or model's default if params.language: self._language_string = language_to_sarvam_language(params.language) - elif "saarika" in model.lower(): - self._language_string = "unknown" else: - self._language_string = None + self._language_string = self._config.default_language + self._prompt = params.prompt + # Set mode: use provided mode or model's default + self._mode = params.mode if params.mode is not None else self._config.default_mode + # Store connection parameters self._vad_signals = params.vad_signals self._high_vad_sensitivity = params.high_vad_sensitivity @@ -203,12 +271,14 @@ class SarvamSTTService(STTService): Args: language: The language to use for speech recognition. + + Raises: + ValueError: If called on a model that auto-detects language. """ - # saaras models do not accept a language parameter - if "saaras" in self.model_name.lower(): + if not self._config.supports_language: raise ValueError( - f"Model '{self.model_name}' (saaras) does not accept language parameter. " - "saaras models auto-detect language." + f"Model '{self.model_name}' does not support language parameter " + "(auto-detects language)." ) logger.info(f"Switching STT language to: [{language}]") @@ -218,24 +288,20 @@ class SarvamSTTService(STTService): await self._connect() async def set_prompt(self, prompt: Optional[str]): - """Set the translation prompt and reconnect. + """Set the transcription/translation prompt and reconnect. Args: - prompt: Prompt text to guide translation style/context. + prompt: Prompt text to guide transcription/translation style/context. Pass None to clear/disable prompt. - Only applicable to STT-Translate models, not STT models. + Only applicable to models that support prompts. """ - # saarika models do not accept prompt parameter - if "saarika" in self.model_name.lower(): + if not self._config.supports_prompt: if prompt is not None: - raise ValueError( - f"Model '{self.model_name}' does not accept prompt parameter. " - "Prompts are only supported for STT-Translate models." - ) - # If prompt is None and it's saarika, just silently return (no-op) + raise ValueError(f"Model '{self.model_name}' does not support prompt parameter.") + # If prompt is None and model doesn't support prompts, silently return (no-op) return - logger.info("Updating STT-Translate prompt.") + logger.info(f"Updating {self.model_name} prompt.") self._prompt = prompt await self._disconnect() await self._connect() @@ -299,13 +365,11 @@ class SarvamSTTService(STTService): "sample_rate": self.sample_rate, } - # Use appropriate method based on service type - if "saarika" in self.model_name.lower(): - # STT service - await self._socket_client.transcribe(**method_kwargs) - else: - # STT-Translate service - auto-detects input language and returns translated text + # Use appropriate method based on model configuration + if self._config.use_translate_method: await self._socket_client.translate(**method_kwargs) + else: + await self._socket_client.transcribe(**method_kwargs) except Exception as e: yield ErrorFrame(error=f"Error sending audio to Sarvam: {e}", exception=e) @@ -326,10 +390,17 @@ class SarvamSTTService(STTService): "model": self.model_name, "vad_signals": vad_signals_str, "high_vad_sensitivity": high_vad_sensitivity_str, - "input_audio_codec": self._input_audio_codec, "sample_rate": str(self.sample_rate), } + # Add language_code for models that support it + if self._language_string is not None: + connect_kwargs["language_code"] = self._language_string + + # Add mode for models that support it + if self._config.supports_mode and self._mode is not None: + connect_kwargs["mode"] = self._mode + def _connect_with_sdk_headers(connect_fn, **kwargs): # Different SDK versions may use different kwarg names. for header_kw in ("headers", "additional_headers", "extra_headers"): @@ -339,26 +410,23 @@ class SarvamSTTService(STTService): pass return connect_fn(**kwargs) - # Choose the appropriate service based on model - if "saarika" in self.model_name.lower(): - # STT service - requires language_code - connect_kwargs["language_code"] = self._language_string + # Choose the appropriate endpoint based on model configuration + if self._config.use_translate_endpoint: self._websocket_context = _connect_with_sdk_headers( - self._sarvam_client.speech_to_text_streaming.connect, + self._sarvam_client.speech_to_text_translate_streaming.connect, **connect_kwargs, ) else: - # STT-Translate service - auto-detects input language and returns translated text self._websocket_context = _connect_with_sdk_headers( - self._sarvam_client.speech_to_text_translate_streaming.connect, + self._sarvam_client.speech_to_text_streaming.connect, **connect_kwargs, ) # Enter the async context manager self._socket_client = await self._websocket_context.__aenter__() - # Set prompt if provided (only for STT-Translate models, after connection) - if self._prompt is not None and "saaras" in self.model_name.lower(): + # Set prompt if provided (only for models that support prompts) + if self._prompt is not None and self._config.supports_prompt: await self._socket_client.set_prompt(self._prompt) # Register event handler for incoming messages diff --git a/src/pipecat/services/sarvam/tts.py b/src/pipecat/services/sarvam/tts.py index cef228b84..5feeffd72 100644 --- a/src/pipecat/services/sarvam/tts.py +++ b/src/pipecat/services/sarvam/tts.py @@ -4,12 +4,36 @@ # SPDX-License-Identifier: BSD 2-Clause License # -"""Sarvam AI text-to-speech service implementation.""" +"""Sarvam AI text-to-speech service implementation. + +This module provides TTS services using Sarvam AI's API with support for multiple +Indian languages and two model variants: + +**Model Variants:** + +- **bulbul:v2** (default): Standard TTS model + - Supports: pitch, loudness, pace (0.3-3.0) + - Default sample rate: 22050 Hz + - Speakers: anushka (default), abhilash, manisha, vidya, arya, karun, hitesh + +- **bulbul:v3-beta**: Advanced TTS model with temperature control + - Does NOT support: pitch, loudness + - Supports: pace (0.5-2.0), temperature (0.01-1.0) + - Default sample rate: 24000 Hz + - Preprocessing is always enabled + - Speakers: aditya (default), ritu, priya, neha, rahul, pooja, rohan, simran, + kavya, amit, dev, ishita, shreya, ratan, varun, manan, sumit, roopa, kabir, + aayan, shubh, ashutosh, advait, amelia, sophia + +See https://docs.sarvam.ai/api-reference-docs/text-to-speech/stream for full API details. +""" import asyncio import base64 import json -from typing import Any, AsyncGenerator, Mapping, Optional +from dataclasses import dataclass +from enum import Enum +from typing import Any, AsyncGenerator, Dict, List, Mapping, Optional, Tuple import aiohttp from loguru import logger @@ -42,6 +66,138 @@ except ModuleNotFoundError as e: raise Exception(f"Missing module: {e}") +class SarvamTTSModel(str, Enum): + """Available Sarvam TTS models. + + Attributes: + BULBUL_V2: Standard TTS model with pitch/loudness control. + - Supports pitch, loudness, pace (0.3-3.0) + - Default sample rate: 22050 Hz + BULBUL_V3_BETA: Advanced model with temperature control. + - Does NOT support pitch/loudness + - Pace range: 0.5-2.0 + - Supports temperature parameter + - Default sample rate: 24000 Hz + - Preprocessing is always enabled + """ + + BULBUL_V2 = "bulbul:v2" + BULBUL_V3_BETA = "bulbul:v3-beta" + + +class SarvamTTSSpeakerV2(str, Enum): + """Available speakers for bulbul:v2 model. + + Female voices: anushka, manisha, vidya, arya + Male voices: abhilash, karun, hitesh + """ + + ANUSHKA = "anushka" + ABHILASH = "abhilash" + MANISHA = "manisha" + VIDYA = "vidya" + ARYA = "arya" + KARUN = "karun" + HITESH = "hitesh" + + +class SarvamTTSSpeakerV3(str, Enum): + """Available speakers for bulbul:v3-beta model. + + Includes a wider variety of voices with different characteristics. + """ + + ADITYA = "aditya" + RITU = "ritu" + PRIYA = "priya" + NEHA = "neha" + RAHUL = "rahul" + POOJA = "pooja" + ROHAN = "rohan" + SIMRAN = "simran" + KAVYA = "kavya" + AMIT = "amit" + DEV = "dev" + ISHITA = "ishita" + SHREYA = "shreya" + RATAN = "ratan" + VARUN = "varun" + MANAN = "manan" + SUMIT = "sumit" + ROOPA = "roopa" + KABIR = "kabir" + AAYAN = "aayan" + SHUBH = "shubh" + ASHUTOSH = "ashutosh" + ADVAIT = "advait" + AMELIA = "amelia" + SOPHIA = "sophia" + + +@dataclass(frozen=True) +class TTSModelConfig: + """Immutable configuration for a Sarvam TTS model. + + Attributes: + supports_pitch: Whether the model accepts pitch parameter. + supports_loudness: Whether the model accepts loudness parameter. + supports_temperature: Whether the model accepts temperature parameter. + default_sample_rate: Default audio sample rate in Hz. + default_speaker: Default speaker voice ID. + pace_range: Valid range for pace parameter (min, max). + preprocessing_always_enabled: Whether preprocessing is always enabled. + speakers: Tuple of available speaker names for this model. + """ + + supports_pitch: bool + supports_loudness: bool + supports_temperature: bool + default_sample_rate: int + default_speaker: str + pace_range: Tuple[float, float] + preprocessing_always_enabled: bool + speakers: Tuple[str, ...] + + +TTS_MODEL_CONFIGS: Dict[str, TTSModelConfig] = { + "bulbul:v2": TTSModelConfig( + supports_pitch=True, + supports_loudness=True, + supports_temperature=False, + default_sample_rate=22050, + default_speaker="anushka", + pace_range=(0.3, 3.0), + preprocessing_always_enabled=False, + speakers=tuple(s.value for s in SarvamTTSSpeakerV2), + ), + "bulbul:v3-beta": TTSModelConfig( + supports_pitch=False, + supports_loudness=False, + supports_temperature=True, + default_sample_rate=24000, + default_speaker="shubh", + pace_range=(0.5, 2.0), + preprocessing_always_enabled=True, + speakers=tuple(s.value for s in SarvamTTSSpeakerV3), + ), +} + + +def get_speakers_for_model(model: str) -> List[str]: + """Get the list of available speakers for a given model. + + Args: + model: The model name (e.g., "bulbul:v2" or "bulbul:v3-beta"). + + Returns: + List of speaker names available for the model. + """ + if model in TTS_MODEL_CONFIGS: + return list(TTS_MODEL_CONFIGS[model].speakers) + # Default to v2 speakers for unknown models + return list(TTS_MODEL_CONFIGS["bulbul:v2"].speakers) + + def language_to_sarvam_language(language: Language) -> Optional[str]: """Convert Pipecat Language enum to Sarvam AI language codes. @@ -72,11 +228,27 @@ class SarvamHttpTTSService(TTSService): """Text-to-Speech service using Sarvam AI's API. Converts text to speech using Sarvam AI's TTS models with support for multiple - Indian languages. Provides control over voice characteristics like pitch, pace, - and loudness. + Indian languages. Provides control over voice characteristics. + + **Model Differences:** + + - **bulbul:v2** (default): + - Supports: pitch (-0.75 to 0.75), loudness (0.3 to 3.0), pace (0.3 to 3.0) + - Default sample rate: 22050 Hz + - Speakers: anushka, abhilash, manisha, vidya, arya, karun, hitesh + + - **bulbul:v3-beta**: + - Does NOT support: pitch, loudness (will be ignored) + - Supports: pace (0.5 to 2.0), temperature (0.01 to 1.0) + - Default sample rate: 24000 Hz + - Preprocessing is always enabled + - Speakers: aditya, ritu, priya, neha, rahul, pooja, rohan, simran, kavya, + amit, dev, ishita, shreya, ratan, varun, manan, sumit, roopa, kabir, + aayan, shubh, ashutosh, advait, amelia, sophia Example:: + # Using bulbul:v2 (default) tts = SarvamHttpTTSService( api_key="your-api-key", voice_id="anushka", @@ -85,18 +257,20 @@ class SarvamHttpTTSService(TTSService): params=SarvamHttpTTSService.InputParams( language=Language.HI, pitch=0.1, - pace=1.2 + pace=1.2, + loudness=1.5 ) ) - # For bulbul v3 beta with any speaker: + # Using bulbul:v3-beta with temperature control tts_v3 = SarvamHttpTTSService( api_key="your-api-key", - voice_id="speaker_name", - model="bulbul:v3, + voice_id="aditya", # Use v3 speaker + model="bulbul:v3-beta", aiohttp_session=session, params=SarvamHttpTTSService.InputParams( language=Language.HI, + pace=1.2, # Range: 0.5-2.0 for v3 temperature=0.8 ) ) @@ -108,23 +282,47 @@ class SarvamHttpTTSService(TTSService): Parameters: language: Language for synthesis. Defaults to English (India). pitch: Voice pitch adjustment (-0.75 to 0.75). Defaults to 0.0. - pace: Speech pace multiplier (0.3 to 3.0). Defaults to 1.0. - loudness: Volume multiplier (0.1 to 3.0). Defaults to 1.0. + **Note:** Only supported for bulbul:v2. Ignored for v3 models. + pace: Speech pace multiplier. Defaults to 1.0. + - bulbul:v2: Range 0.3 to 3.0 + - bulbul:v3-beta: Range 0.5 to 2.0 + loudness: Volume multiplier (0.3 to 3.0). Defaults to 1.0. + **Note:** Only supported for bulbul:v2. Ignored for v3 models. enable_preprocessing: Whether to enable text preprocessing. Defaults to False. + **Note:** Always enabled for bulbul:v3-beta (cannot be disabled). + temperature: Controls output randomness for bulbul:v3-beta (0.01 to 1.0). + Lower values = more deterministic, higher = more random. Defaults to 0.6. + **Note:** Only supported for bulbul:v3-beta. Ignored for v2. """ language: Optional[Language] = Language.EN - pitch: Optional[float] = Field(default=0.0, ge=-0.75, le=0.75) - pace: Optional[float] = Field(default=1.0, ge=0.3, le=3.0) - loudness: Optional[float] = Field(default=1.0, ge=0.1, le=3.0) - enable_preprocessing: Optional[bool] = False + pitch: Optional[float] = Field( + default=0.0, + ge=-0.75, + le=0.75, + description="Voice pitch adjustment. Only for bulbul:v2.", + ) + pace: Optional[float] = Field( + default=1.0, + ge=0.3, + le=3.0, + description="Speech pace. v2: 0.3-3.0, v3: 0.5-2.0.", + ) + loudness: Optional[float] = Field( + default=1.0, + ge=0.3, + le=3.0, + description="Volume multiplier. Only for bulbul:v2.", + ) + enable_preprocessing: Optional[bool] = Field( + default=False, + description="Enable text preprocessing. Always enabled for v3-beta model.", + ) temperature: Optional[float] = Field( default=0.6, ge=0.01, le=1.0, - description="Controls the randomness of the output for bulbul v3 beta. " - "Lower values make the output more focused and deterministic, while " - "higher values make it more random. Range: 0.01 to 1.0. Default: 0.6.", + description="Output randomness for bulbul:v3-beta only. Range: 0.01-1.0.", ) def __init__( @@ -132,7 +330,7 @@ class SarvamHttpTTSService(TTSService): *, api_key: str, aiohttp_session: aiohttp.ClientSession, - voice_id: str = "anushka", + voice_id: Optional[str] = None, model: str = "bulbul:v2", base_url: str = "https://api.sarvam.ai", sample_rate: Optional[int] = None, @@ -144,46 +342,73 @@ class SarvamHttpTTSService(TTSService): Args: api_key: Sarvam AI API subscription key. aiohttp_session: Shared aiohttp session for making requests. - voice_id: Speaker voice ID (e.g., "anushka", "meera"). Defaults to "anushka". - model: TTS model to use ("bulbul:v2" or "bulbul:v3-beta" or "bulbul:v3"). Defaults to "bulbul:v2". + voice_id: Speaker voice ID. If None, uses model-appropriate default. + model: TTS model to use. Options: + - "bulbul:v2" (default): Standard model with pitch/loudness support + - "bulbul:v3-beta": Advanced model with temperature control base_url: Sarvam AI API base URL. Defaults to "https://api.sarvam.ai". - sample_rate: Audio sample rate in Hz (8000, 16000, 22050, 24000). If None, uses default. + sample_rate: Audio sample rate in Hz (8000, 16000, 22050, 24000). + If None, uses model-specific default. params: Additional voice and preprocessing parameters. If None, uses defaults. **kwargs: Additional arguments passed to parent TTSService. """ + # Get model configuration (validates model exists) + if model not in TTS_MODEL_CONFIGS: + allowed = ", ".join(sorted(TTS_MODEL_CONFIGS.keys())) + raise ValueError(f"Unsupported model '{model}'. Allowed values: {allowed}.") + + self._config = TTS_MODEL_CONFIGS[model] + + # Set default sample rate based on model if not specified + if sample_rate is None: + sample_rate = self._config.default_sample_rate + super().__init__(sample_rate=sample_rate, **kwargs) params = params or SarvamHttpTTSService.InputParams() + # Set default voice based on model if not specified + if voice_id is None: + voice_id = self._config.default_speaker + self._api_key = api_key self._base_url = base_url self._session = aiohttp_session - # Build base settings common to all models + # Validate and clamp pace to model's valid range + pace = params.pace + pace_min, pace_max = self._config.pace_range + if pace is not None and (pace < pace_min or pace > pace_max): + logger.warning(f"Pace {pace} is outside model range ({pace_min}-{pace_max}). Clamping.") + pace = max(pace_min, min(pace_max, pace)) + + # Build base settings self._settings = { "language": ( self.language_to_service_language(params.language) if params.language else "en-IN" ), - "enable_preprocessing": params.enable_preprocessing, + "enable_preprocessing": ( + True if self._config.preprocessing_always_enabled else params.enable_preprocessing + ), + "pace": pace, + "model": model, } - # Add model-specific parameters - if model in ("bulbul:v3-beta", "bulbul:v3"): - self._settings.update( - { - "temperature": getattr(params, "temperature", 0.6), - "model": model, - } - ) - else: - self._settings.update( - { - "pitch": params.pitch, - "pace": params.pace, - "loudness": params.loudness, - "model": model, - } - ) + # Add parameters based on model support + if self._config.supports_pitch: + self._settings["pitch"] = params.pitch + elif params.pitch != 0.0: + logger.warning(f"pitch parameter is ignored for {model}") + + if self._config.supports_loudness: + self._settings["loudness"] = params.loudness + elif params.loudness != 1.0: + logger.warning(f"loudness parameter is ignored for {model}") + + if self._config.supports_temperature: + self._settings["temperature"] = params.temperature + elif params.temperature != 0.6: + logger.warning(f"temperature parameter is ignored for {model}") self.set_model_name(model) self.set_voice(voice_id) @@ -231,18 +456,25 @@ class SarvamHttpTTSService(TTSService): try: await self.start_ttfb_metrics() + # Build payload with common parameters payload = { "text": text, "target_language_code": self._settings["language"], "speaker": self._voice_id, - "pitch": self._settings["pitch"], - "pace": self._settings["pace"], - "loudness": self._settings["loudness"], "sample_rate": self.sample_rate, "enable_preprocessing": self._settings["enable_preprocessing"], "model": self._model_name, + "pace": self._settings.get("pace", 1.0), } + # Add model-specific parameters based on config + if self._config.supports_pitch: + payload["pitch"] = self._settings.get("pitch", 0.0) + if self._config.supports_loudness: + payload["loudness"] = self._settings.get("loudness", 1.0) + if self._config.supports_temperature: + payload["temperature"] = self._settings.get("temperature", 0.6) + headers = { "api-subscription-key": self._api_key, "Content-Type": "application/json", @@ -296,10 +528,34 @@ class SarvamTTSService(InterruptibleTTSService): """WebSocket-based text-to-speech service using Sarvam AI. Provides streaming TTS with real-time audio generation for multiple Indian languages. - Supports voice control parameters like pitch, pace, and loudness adjustment. + Uses WebSocket for low-latency streaming audio synthesis. + + **Model Differences:** + + - **bulbul:v2** (default): + - Supports: pitch (-0.75 to 0.75), loudness (0.3 to 3.0), pace (0.3 to 3.0) + - Default sample rate: 22050 Hz + - Speakers: anushka, abhilash, manisha, vidya, arya, karun, hitesh + + - **bulbul:v3-beta** / **bulbul:v3**: + - Does NOT support: pitch, loudness (will be ignored) + - Supports: pace (0.5 to 2.0), temperature (0.01 to 1.0) + - Default sample rate: 24000 Hz + - Preprocessing is always enabled + - Speakers: aditya, ritu, priya, neha, rahul, pooja, rohan, simran, kavya, + amit, dev, ishita, shreya, ratan, varun, manan, sumit, roopa, kabir, + aayan, shubh, ashutosh, advait, amelia, sophia + + **WebSocket Protocol:** + The service uses a WebSocket connection for real-time streaming. Messages include: + - config: Initial configuration with voice settings + - text: Text chunks for synthesis + - flush: Signal to process remaining buffered text + - ping: Keepalive signal Example:: + # Using bulbul:v2 (default) tts = SarvamTTSService( api_key="your-api-key", voice_id="anushka", @@ -307,63 +563,108 @@ class SarvamTTSService(InterruptibleTTSService): params=SarvamTTSService.InputParams( language=Language.HI, pitch=0.1, - pace=1.2 + pace=1.2, + loudness=1.5 ) ) - # For bulbul v3 beta with any speaker and temperature: - # Note: pace and loudness are not supported for bulbul v3 and bulbul v3 beta + # Using bulbul:v3-beta with temperature control tts_v3 = SarvamTTSService( api_key="your-api-key", - voice_id="speaker_name", - model="bulbul:v3", + voice_id="aditya", # Use v3 speaker + model="bulbul:v3-beta", params=SarvamTTSService.InputParams( language=Language.HI, + pace=1.2, # Range: 0.5-2.0 for v3 temperature=0.8 ) ) + + See https://docs.sarvam.ai/api-reference-docs/text-to-speech/stream for API details. """ class InputParams(BaseModel): - """Configuration parameters for Sarvam TTS. + """Configuration parameters for Sarvam TTS WebSocket service. Parameters: pitch: Voice pitch adjustment (-0.75 to 0.75). Defaults to 0.0. - pace: Speech pace multiplier (0.3 to 3.0). Defaults to 1.0. - loudness: Volume multiplier (0.1 to 3.0). Defaults to 1.0. + **Note:** Only supported for bulbul:v2. Ignored for v3 models. + pace: Speech pace multiplier. Defaults to 1.0. + - bulbul:v2: Range 0.3 to 3.0 + - bulbul:v3-beta: Range 0.5 to 2.0 + loudness: Volume multiplier (0.3 to 3.0). Defaults to 1.0. + **Note:** Only supported for bulbul:v2. Ignored for v3 models. enable_preprocessing: Enable text preprocessing. Defaults to False. - min_buffer_size: Minimum number of characters to buffer before generating audio. + **Note:** Always enabled for bulbul:v3-beta. + min_buffer_size: Minimum characters to buffer before generating audio. Lower values reduce latency but may affect quality. Defaults to 50. - max_chunk_length: Maximum number of characters processed in a single chunk. - Controls memory usage and processing efficiency. Defaults to 200. - output_audio_codec: Audio codec format. Defaults to "linear16". - output_audio_bitrate: Audio bitrate. Defaults to "128k". - language: Target language for synthesis. Supports Bengali (bn-IN), English (en-IN), - Gujarati (gu-IN), Hindi (hi-IN), Kannada (kn-IN), Malayalam (ml-IN), - Marathi (mr-IN), Odia (od-IN), Punjabi (pa-IN), Tamil (ta-IN), - Telugu (te-IN). Defaults to en-IN. + max_chunk_length: Maximum characters processed in a single chunk. + Controls memory usage and processing efficiency. Defaults to 150. + output_audio_codec: Audio codec format. Options: linear16, mulaw, alaw, + opus, flac, aac, wav, mp3. Defaults to "linear16". + output_audio_bitrate: Audio bitrate (32k, 64k, 96k, 128k, 192k). + Defaults to "128k". + language: Target language for synthesis. Supports Indian languages. + temperature: Controls output randomness for bulbul:v3-beta (0.01 to 1.0). + Lower = more deterministic, higher = more random. Defaults to 0.6. + **Note:** Only supported for bulbul:v3-beta. Ignored for v2. - Available Speakers: - Female: anushka, manisha, vidya, arya - Male: abhilash, karun, hitesh + **Speakers by Model:** + + bulbul:v2: + - Female: anushka (default), manisha, vidya, arya + - Male: abhilash, karun, hitesh + + bulbul:v3-beta: + - aditya (default), ritu, priya, neha, rahul, pooja, rohan, simran, + kavya, amit, dev, ishita, shreya, ratan, varun, manan, sumit, + roopa, kabir, aayan, shubh, ashutosh, advait, amelia, sophia """ - pitch: Optional[float] = Field(default=0.0, ge=-0.75, le=0.75) - pace: Optional[float] = Field(default=1.0, ge=0.3, le=3.0) - loudness: Optional[float] = Field(default=1.0, ge=0.1, le=3.0) - enable_preprocessing: Optional[bool] = False - min_buffer_size: Optional[int] = 50 - max_chunk_length: Optional[int] = 200 - output_audio_codec: Optional[str] = "linear16" - output_audio_bitrate: Optional[str] = "128k" + pitch: Optional[float] = Field( + default=0.0, + ge=-0.75, + le=0.75, + description="Voice pitch adjustment. Only for bulbul:v2.", + ) + pace: Optional[float] = Field( + default=1.0, + ge=0.3, + le=3.0, + description="Speech pace. v2: 0.3-3.0, v3: 0.5-2.0.", + ) + loudness: Optional[float] = Field( + default=1.0, + ge=0.3, + le=3.0, + description="Volume multiplier. Only for bulbul:v2.", + ) + enable_preprocessing: Optional[bool] = Field( + default=False, + description="Enable text preprocessing. Always enabled for v3 models.", + ) + min_buffer_size: Optional[int] = Field( + default=50, + description="Minimum characters to buffer before TTS processing.", + ) + max_chunk_length: Optional[int] = Field( + default=150, + description="Maximum length for sentence splitting.", + ) + output_audio_codec: Optional[str] = Field( + default="linear16", + description="Audio codec: linear16, mulaw, alaw, opus, flac, aac, wav, mp3.", + ) + output_audio_bitrate: Optional[str] = Field( + default="128k", + description="Audio bitrate: 32k, 64k, 96k, 128k, 192k.", + ) language: Optional[Language] = Language.EN temperature: Optional[float] = Field( default=0.6, ge=0.01, le=1.0, - description="Controls the randomness of the output for bulbul v3 beta. " - "Lower values make the output more focused and deterministic, while " - "higher values make it more random. Range: 0.01 to 1.0. Default: 0.6.", + description="Output randomness for bulbul:v3-beta only. Range: 0.01-1.0.", ) def __init__( @@ -371,7 +672,7 @@ class SarvamTTSService(InterruptibleTTSService): *, api_key: str, model: str = "bulbul:v2", - voice_id: str = "anushka", + voice_id: Optional[str] = None, url: str = "wss://api.sarvam.ai/text-to-speech/ws", aggregate_sentences: Optional[bool] = True, sample_rate: Optional[int] = None, @@ -382,20 +683,30 @@ class SarvamTTSService(InterruptibleTTSService): Args: api_key: Sarvam API key for authenticating TTS requests. - model: Identifier of the Sarvam speech model (default "bulbul:v2"). - Supports "bulbul:v2", "bulbul:v3-beta" and "bulbul:v3". - voice_id: Voice identifier for synthesis (default "anushka"). - url: WebSocket URL for connecting to the TTS backend (default production URL). - aggregate_sentences: Whether to merge multiple sentences into one audio chunk (default True). - sample_rate: Desired sample rate for the output audio in Hz (overrides default if set). - params: Optional input parameters to override global configuration. - **kwargs: Optional keyword arguments forwarded to InterruptibleTTSService (such as - `push_stop_frames`, `sample_rate`, task manager parameters, event hooks, etc.) - to customize transport behavior or enable metrics support. + model: TTS model to use. Options: + - "bulbul:v2" (default): Standard model with pitch/loudness support + - "bulbul:v3-beta": Advanced model with temperature control + voice_id: Speaker voice ID. If None, uses model-appropriate default. + url: WebSocket URL for the TTS backend (default production URL). + aggregate_sentences: Merge multiple sentences into one audio chunk (default True). + sample_rate: Output audio sample rate in Hz (8000, 16000, 22050, 24000). + If None, uses model-specific default. + params: Optional input parameters to override defaults. + **kwargs: Arguments forwarded to InterruptibleTTSService. - This method sets up the internal TTS configuration mapping, constructs the WebSocket - URL based on the chosen model, and initializes state flags before connecting. + See https://docs.sarvam.ai/api-reference-docs/text-to-speech/stream """ + # Get model configuration (validates model exists) + if model not in TTS_MODEL_CONFIGS: + allowed = ", ".join(sorted(TTS_MODEL_CONFIGS.keys())) + raise ValueError(f"Unsupported model '{model}'. Allowed values: {allowed}.") + + self._config = TTS_MODEL_CONFIGS[model] + + # Set default sample rate based on model if not specified + if sample_rate is None: + sample_rate = self._config.default_sample_rate + # Initialize parent class first super().__init__( aggregate_sentences=aggregate_sentences, @@ -407,44 +718,58 @@ class SarvamTTSService(InterruptibleTTSService): ) params = params or SarvamTTSService.InputParams() - # WebSocket endpoint URL + # Set default voice based on model if not specified + if voice_id is None: + voice_id = self._config.default_speaker + + # WebSocket endpoint URL with model query parameter self._websocket_url = f"{url}?model={model}" self._api_key = api_key self.set_model_name(model) self.set_voice(voice_id) - # Build base settings common to all models + + # Validate and clamp pace to model's valid range + pace = params.pace + pace_min, pace_max = self._config.pace_range + if pace is not None and (pace < pace_min or pace > pace_max): + logger.warning(f"Pace {pace} is outside model range ({pace_min}-{pace_max}). Clamping.") + pace = max(pace_min, min(pace_max, pace)) + + # Build base settings self._settings = { "target_language_code": ( self.language_to_service_language(params.language) if params.language else "en-IN" ), "speaker": voice_id, - "speech_sample_rate": 0, - "enable_preprocessing": params.enable_preprocessing, + "speech_sample_rate": str(sample_rate), + "enable_preprocessing": ( + True if self._config.preprocessing_always_enabled else params.enable_preprocessing + ), "min_buffer_size": params.min_buffer_size, "max_chunk_length": params.max_chunk_length, "output_audio_codec": params.output_audio_codec, "output_audio_bitrate": params.output_audio_bitrate, + "pace": pace, + "model": model, } - # Add model-specific parameters - if model in ("bulbul:v3-beta", "bulbul:v3"): - self._settings.update( - { - "temperature": getattr(params, "temperature", 0.6), - "model": model, - } - ) - else: - self._settings.update( - { - "pitch": params.pitch, - "pace": params.pace, - "loudness": params.loudness, - "model": model, - } - ) - self._started = False + # Add parameters based on model support + if self._config.supports_pitch: + self._settings["pitch"] = params.pitch + elif params.pitch != 0.0: + logger.warning(f"pitch parameter is ignored for {model}") + if self._config.supports_loudness: + self._settings["loudness"] = params.loudness + elif params.loudness != 1.0: + logger.warning(f"loudness parameter is ignored for {model}") + + if self._config.supports_temperature: + self._settings["temperature"] = params.temperature + elif params.temperature != 0.6: + logger.warning(f"temperature parameter is ignored for {model}") + + self._started = False self._receive_task = None self._keepalive_task = None self._disconnecting = False @@ -476,7 +801,8 @@ class SarvamTTSService(InterruptibleTTSService): """ await super().start(frame) - self._settings["speech_sample_rate"] = self.sample_rate + # WebSocket API expects sample rate as string + self._settings["speech_sample_rate"] = str(self.sample_rate) await self._connect() async def stop(self, frame: EndFrame):