refactor(sarvam): centralize model configuration with dataclasses

This commit is contained in:
dhruvladia-sarvam
2026-02-03 14:33:41 +05:30
parent 57821cf709
commit 1665ce181a
2 changed files with 227 additions and 215 deletions

View File

@@ -6,7 +6,8 @@ can handle multiple audio formats for Indian language speech recognition.
"""
import base64
from typing import Literal, Optional
from dataclasses import dataclass
from typing import Dict, Literal, Optional
from loguru import logger
from pydantic import BaseModel
@@ -68,6 +69,60 @@ def language_to_sarvam_language(language: Language) -> str:
return resolve_language(language, LANGUAGE_MAP, use_base_code=False)
@dataclass(frozen=True)
class ModelConfig:
"""Immutable configuration for a Sarvam STT model.
Attributes:
supports_prompt: Whether the model accepts prompt parameter.
supports_mode: Whether the model accepts mode parameter.
supports_language: Whether the model accepts language parameter.
default_language: Default language code (None = auto-detect).
default_mode: Default mode (None = not applicable).
use_translate_endpoint: Whether to use speech_to_text_translate_streaming endpoint.
use_translate_method: Whether to use translate() method instead of transcribe().
"""
supports_prompt: bool
supports_mode: bool
supports_language: bool
default_language: Optional[str]
default_mode: Optional[str]
use_translate_endpoint: bool
use_translate_method: bool
MODEL_CONFIGS: Dict[str, ModelConfig] = {
"saarika:v2.5": ModelConfig(
supports_prompt=False,
supports_mode=False,
supports_language=True,
default_language="unknown",
default_mode=None,
use_translate_endpoint=False,
use_translate_method=False,
),
"saaras:v2.5": ModelConfig(
supports_prompt=True,
supports_mode=False,
supports_language=False,
default_language=None, # Auto-detects language
default_mode=None,
use_translate_endpoint=True,
use_translate_method=True,
),
"saaras:v3": ModelConfig(
supports_prompt=True,
supports_mode=True,
supports_language=True,
default_language="en-IN",
default_mode="transcribe",
use_translate_endpoint=False,
use_translate_method=False,
),
}
class SarvamSTTService(STTService):
"""Sarvam speech-to-text service.
@@ -103,9 +158,6 @@ class SarvamSTTService(STTService):
model: str = "saarika:v2.5",
sample_rate: Optional[int] = None,
input_audio_codec: str = "wav",
mode: Optional[
Literal["transcribe", "translate", "verbatim", "translit", "codemix"]
] = None,
params: Optional[InputParams] = None,
**kwargs,
):
@@ -119,73 +171,44 @@ class SarvamSTTService(STTService):
- "saaras:v3": Advanced STT model (supports mode and prompts)
sample_rate: Audio sample rate. Defaults to 16000 if not specified.
input_audio_codec: Audio codec/format of the input file. Defaults to "wav".
mode: Mode of operation for saaras:v3 models only. Options: transcribe, translate,
verbatim, translit, codemix. Defaults to "transcribe" for saaras:v3.
params: Configuration parameters for Sarvam STT service.
**kwargs: Additional arguments passed to the parent STTService.
"""
params = params or SarvamSTTService.InputParams()
# Allow mode to be passed directly or via params
if mode is not None and params.mode is None:
params = params.model_copy(update={"mode": mode})
# Validate allowed models
allowed_models = {"saarika:v2.5", "saaras:v3", "saaras:v2.5"}
if model not in allowed_models:
allowed_models_list = ", ".join(sorted(allowed_models))
raise ValueError(f"Unsupported model '{model}'. Allowed values: {allowed_models_list}.")
# Get model configuration (validates model exists)
if model not in MODEL_CONFIGS:
allowed = ", ".join(sorted(MODEL_CONFIGS.keys()))
raise ValueError(f"Unsupported model '{model}'. Allowed values: {allowed}.")
# Validate model-specific parameter restrictions
if "saarika" in model.lower():
# saarika models don't accept prompt or mode
if params.prompt is not None:
raise ValueError(
f"Model '{model}' does not accept prompt parameter. "
"Prompts are only supported for saaras models (v2.5 and v3)."
)
if params.mode is not None:
raise ValueError(
f"Model '{model}' does not accept mode parameter. "
"Mode is only supported for saaras:v3 model."
)
elif model.lower() == "saaras:v2.5":
# saaras:v2.5 supports prompt but not mode
if params.mode is not None:
raise ValueError(
f"Model '{model}' does not accept mode parameter. "
"Mode is only supported for saaras:v3 model."
)
if params.language is not None:
raise ValueError(
f"Model '{model}' does not accept language parameter. "
"saaras:v2.5 (STT-Translate) auto-detects language."
)
self._config = MODEL_CONFIGS[model]
# Validate parameters against model capabilities
if params.prompt is not None and not self._config.supports_prompt:
raise ValueError(f"Model '{model}' does not support prompt parameter.")
if params.mode is not None and not self._config.supports_mode:
raise ValueError(f"Model '{model}' does not support mode parameter.")
if params.language is not None and not self._config.supports_language:
raise ValueError(
f"Model '{model}' does not support language parameter (auto-detects language)."
)
super().__init__(sample_rate=sample_rate, **kwargs)
self.set_model_name(model)
self._api_key = api_key
self._language_code: Optional[Language] = params.language
# Set language string based on model type
# - saarika:v2.5: uses language_code or defaults to "unknown"
# - saaras:v2.5: auto-detects language (no language_code needed)
# - saaras:v3: uses language_code or defaults to "en-IN"
# Set language string: use provided language or model's default
if params.language:
self._language_string = language_to_sarvam_language(params.language)
elif "saarika" in model.lower():
self._language_string = "unknown"
elif model.lower() == "saaras:v2.5":
self._language_string = None # STT-Translate auto-detects language
elif model.lower() == "saaras:v3":
self._language_string = "en-IN"
else:
self._language_string = None
self._language_string = self._config.default_language
self._prompt = params.prompt
# Set mode for saaras:v3, default to "transcribe"
if model.lower() == "saaras:v3":
self._mode = params.mode if params.mode is not None else "transcribe"
else:
self._mode = None
# Set mode: use provided mode or model's default
self._mode = params.mode if params.mode is not None else self._config.default_mode
# Store connection parameters
self._vad_signals = params.vad_signals
@@ -250,13 +273,12 @@ class SarvamSTTService(STTService):
language: The language to use for speech recognition.
Raises:
ValueError: If called on saaras:v2.5 model which auto-detects language.
ValueError: If called on a model that auto-detects language.
"""
# saaras:v2.5 (STT-Translate) auto-detects language
if self.model_name.lower() == "saaras:v2.5":
if not self._config.supports_language:
raise ValueError(
f"Model '{self.model_name}' does not accept language parameter. "
"saaras:v2.5 (STT-Translate) auto-detects language."
f"Model '{self.model_name}' does not support language parameter "
"(auto-detects language)."
)
logger.info(f"Switching STT language to: [{language}]")
@@ -271,16 +293,12 @@ class SarvamSTTService(STTService):
Args:
prompt: Prompt text to guide transcription/translation style/context.
Pass None to clear/disable prompt.
Only applicable to saaras models (v2.5 and v3).
Only applicable to models that support prompts.
"""
# saarika models do not accept prompt parameter
if "saarika" in self.model_name.lower():
if not self._config.supports_prompt:
if prompt is not None:
raise ValueError(
f"Model '{self.model_name}' does not accept prompt parameter. "
"Prompts are only supported for saaras models (v2.5 and v3)."
)
# If prompt is None and it's saarika, just silently return (no-op)
raise ValueError(f"Model '{self.model_name}' does not support prompt parameter.")
# If prompt is None and model doesn't support prompts, silently return (no-op)
return
logger.info(f"Updating {self.model_name} prompt.")
@@ -347,12 +365,10 @@ class SarvamSTTService(STTService):
"sample_rate": self.sample_rate,
}
# Use appropriate method based on model type
if self.model_name.lower() == "saaras:v2.5":
# STT-Translate: auto-detects input language and returns translated text
# Use appropriate method based on model configuration
if self._config.use_translate_method:
await self._socket_client.translate(**method_kwargs)
else:
# saarika:v2.5 and saaras:v3 use transcribe
await self._socket_client.transcribe(**method_kwargs)
except Exception as e:
@@ -377,12 +393,12 @@ class SarvamSTTService(STTService):
"sample_rate": str(self.sample_rate),
}
# Add language_code for models that require it (not saaras:v2.5 which auto-detects)
# Add language_code for models that support it
if self._language_string is not None:
connect_kwargs["language_code"] = self._language_string
# Add mode for saaras:v3 only
if self.model_name.lower() == "saaras:v3" and self._mode is not None:
# Add mode for models that support it
if self._config.supports_mode and self._mode is not None:
connect_kwargs["mode"] = self._mode
def _connect_with_sdk_headers(connect_fn, **kwargs):
@@ -394,15 +410,13 @@ class SarvamSTTService(STTService):
pass
return connect_fn(**kwargs)
# Choose the appropriate endpoint based on model
if self.model_name.lower() == "saaras:v2.5":
# STT-Translate: auto-detects input language and returns translated text
# Choose the appropriate endpoint based on model configuration
if self._config.use_translate_endpoint:
self._websocket_context = _connect_with_sdk_headers(
self._sarvam_client.speech_to_text_translate_streaming.connect,
**connect_kwargs,
)
else:
# saarika:v2.5 and saaras:v3 use speech_to_text_streaming
self._websocket_context = _connect_with_sdk_headers(
self._sarvam_client.speech_to_text_streaming.connect,
**connect_kwargs,
@@ -411,8 +425,8 @@ class SarvamSTTService(STTService):
# Enter the async context manager
self._socket_client = await self._websocket_context.__aenter__()
# Set prompt if provided (only for saaras models, after connection)
if self._prompt is not None and "saaras" in self.model_name.lower():
# Set prompt if provided (only for models that support prompts)
if self._prompt is not None and self._config.supports_prompt:
await self._socket_client.set_prompt(self._prompt)
# Register event handler for incoming messages

View File

@@ -31,8 +31,9 @@ See https://docs.sarvam.ai/api-reference-docs/text-to-speech/stream for full API
import asyncio
import base64
import json
from dataclasses import dataclass
from enum import Enum
from typing import Any, AsyncGenerator, List, Mapping, Optional
from typing import Any, AsyncGenerator, Dict, List, Mapping, Optional, Tuple
import aiohttp
from loguru import logger
@@ -133,16 +134,52 @@ class SarvamTTSSpeakerV3(str, Enum):
SOPHIA = "sophia"
# Default sample rates per model
SARVAM_DEFAULT_SAMPLE_RATES = {
SarvamTTSModel.BULBUL_V2: 22050,
SarvamTTSModel.BULBUL_V3_BETA: 24000,
}
@dataclass(frozen=True)
class TTSModelConfig:
"""Immutable configuration for a Sarvam TTS model.
# Default speakers per model
SARVAM_DEFAULT_SPEAKERS = {
SarvamTTSModel.BULBUL_V2: SarvamTTSSpeakerV2.ANUSHKA.value,
SarvamTTSModel.BULBUL_V3_BETA: SarvamTTSSpeakerV3.ADITYA.value,
Attributes:
supports_pitch: Whether the model accepts pitch parameter.
supports_loudness: Whether the model accepts loudness parameter.
supports_temperature: Whether the model accepts temperature parameter.
default_sample_rate: Default audio sample rate in Hz.
default_speaker: Default speaker voice ID.
pace_range: Valid range for pace parameter (min, max).
preprocessing_always_enabled: Whether preprocessing is always enabled.
speakers: Tuple of available speaker names for this model.
"""
supports_pitch: bool
supports_loudness: bool
supports_temperature: bool
default_sample_rate: int
default_speaker: str
pace_range: Tuple[float, float]
preprocessing_always_enabled: bool
speakers: Tuple[str, ...]
TTS_MODEL_CONFIGS: Dict[str, TTSModelConfig] = {
"bulbul:v2": TTSModelConfig(
supports_pitch=True,
supports_loudness=True,
supports_temperature=False,
default_sample_rate=22050,
default_speaker="anushka",
pace_range=(0.3, 3.0),
preprocessing_always_enabled=False,
speakers=tuple(s.value for s in SarvamTTSSpeakerV2),
),
"bulbul:v3-beta": TTSModelConfig(
supports_pitch=False,
supports_loudness=False,
supports_temperature=True,
default_sample_rate=24000,
default_speaker="aditya",
pace_range=(0.5, 2.0),
preprocessing_always_enabled=True,
speakers=tuple(s.value for s in SarvamTTSSpeakerV3),
),
}
@@ -155,9 +192,10 @@ def get_speakers_for_model(model: str) -> List[str]:
Returns:
List of speaker names available for the model.
"""
if model in (SarvamTTSModel.BULBUL_V3_BETA.value):
return [s.value for s in SarvamTTSSpeakerV3]
return [s.value for s in SarvamTTSSpeakerV2]
if model in TTS_MODEL_CONFIGS:
return list(TTS_MODEL_CONFIGS[model].speakers)
# Default to v2 speakers for unknown models
return list(TTS_MODEL_CONFIGS["bulbul:v2"].speakers)
def language_to_sarvam_language(language: Language) -> Optional[str]:
@@ -304,35 +342,26 @@ class SarvamHttpTTSService(TTSService):
Args:
api_key: Sarvam AI API subscription key.
aiohttp_session: Shared aiohttp session for making requests.
voice_id: Speaker voice ID. If None, uses model-appropriate default:
- bulbul:v2: "anushka"
- bulbul:v3-beta/v3: "aditya"
voice_id: Speaker voice ID. If None, uses model-appropriate default.
model: TTS model to use. Options:
- "bulbul:v2" (default): Standard model with pitch/loudness support
- "bulbul:v3-beta": Advanced model with temperature control
- "bulbul:v3": Alias for v3-beta
base_url: Sarvam AI API base URL. Defaults to "https://api.sarvam.ai".
sample_rate: Audio sample rate in Hz (8000, 16000, 22050, 24000).
If None, uses model-specific default (v2: 22050, v3: 24000).
If None, uses model-specific default.
params: Additional voice and preprocessing parameters. If None, uses defaults.
**kwargs: Additional arguments passed to parent TTSService.
Note:
When using bulbul:v3-beta:
- pitch and loudness parameters are ignored
- pace range is limited to 0.5-2.0
- preprocessing is always enabled
- use SarvamTTSSpeakerV3 speakers (e.g., "aditya", "ritu")
"""
# Determine if using v3 model
is_v3_model = model in (
SarvamTTSModel.BULBUL_V3_BETA.value,
"bulbul:v3-beta",
)
# Get model configuration (validates model exists)
if model not in TTS_MODEL_CONFIGS:
allowed = ", ".join(sorted(TTS_MODEL_CONFIGS.keys()))
raise ValueError(f"Unsupported model '{model}'. Allowed values: {allowed}.")
self._config = TTS_MODEL_CONFIGS[model]
# Set default sample rate based on model if not specified
if sample_rate is None:
sample_rate = 24000 if is_v3_model else 22050
sample_rate = self._config.default_sample_rate
super().__init__(sample_rate=sample_rate, **kwargs)
@@ -340,55 +369,46 @@ class SarvamHttpTTSService(TTSService):
# Set default voice based on model if not specified
if voice_id is None:
voice_id = "aditya" if is_v3_model else "anushka"
voice_id = self._config.default_speaker
self._api_key = api_key
self._base_url = base_url
self._session = aiohttp_session
self._is_v3_model = is_v3_model
# Build base settings common to all models
# Validate and clamp pace to model's valid range
pace = params.pace
pace_min, pace_max = self._config.pace_range
if pace is not None and (pace < pace_min or pace > pace_max):
logger.warning(f"Pace {pace} is outside model range ({pace_min}-{pace_max}). Clamping.")
pace = max(pace_min, min(pace_max, pace))
# Build base settings
self._settings = {
"language": (
self.language_to_service_language(params.language) if params.language else "en-IN"
),
"enable_preprocessing": params.enable_preprocessing if not is_v3_model else True,
"enable_preprocessing": (
True if self._config.preprocessing_always_enabled else params.enable_preprocessing
),
"pace": pace,
"model": model,
}
# Add model-specific parameters
if is_v3_model:
# Validate pace for v3 (0.5-2.0)
pace = params.pace
if pace is not None and (pace < 0.5 or pace > 2.0):
logger.warning(
f"Pace {pace} is outside v3 model range (0.5-2.0). Clamping to valid range."
)
pace = max(0.5, min(2.0, pace))
# Add parameters based on model support
if self._config.supports_pitch:
self._settings["pitch"] = params.pitch
elif params.pitch != 0.0:
logger.warning(f"pitch parameter is ignored for {model}")
self._settings.update(
{
"temperature": params.temperature,
"pace": pace,
"model": model,
}
)
# Log warning if v2-only parameters are set
if params.pitch != 0.0:
logger.warning(f"pitch parameter is ignored for {model}")
if params.loudness != 1.0:
logger.warning(f"loudness parameter is ignored for {model}")
else:
self._settings.update(
{
"pitch": params.pitch,
"pace": params.pace,
"loudness": params.loudness,
"model": model,
}
)
# Log warning if v3-only parameters are set
if params.temperature != 0.6:
logger.warning(f"temperature parameter is ignored for {model}")
if self._config.supports_loudness:
self._settings["loudness"] = params.loudness
elif params.loudness != 1.0:
logger.warning(f"loudness parameter is ignored for {model}")
if self._config.supports_temperature:
self._settings["temperature"] = params.temperature
elif params.temperature != 0.6:
logger.warning(f"temperature parameter is ignored for {model}")
self.set_model_name(model)
self.set_voice(voice_id)
@@ -436,7 +456,7 @@ class SarvamHttpTTSService(TTSService):
try:
await self.start_ttfb_metrics()
# Build payload based on model type
# Build payload with common parameters
payload = {
"text": text,
"target_language_code": self._settings["language"],
@@ -444,19 +464,16 @@ class SarvamHttpTTSService(TTSService):
"sample_rate": self.sample_rate,
"enable_preprocessing": self._settings["enable_preprocessing"],
"model": self._model_name,
"pace": self._settings.get("pace", 1.0),
}
# Add model-specific parameters
if self._is_v3_model:
# v3 models use temperature and pace (0.5-2.0)
payload["temperature"] = self._settings.get("temperature", 0.6)
if "pace" in self._settings:
payload["pace"] = self._settings["pace"]
else:
# v2 models use pitch, pace, loudness
# Add model-specific parameters based on config
if self._config.supports_pitch:
payload["pitch"] = self._settings.get("pitch", 0.0)
payload["pace"] = self._settings.get("pace", 1.0)
if self._config.supports_loudness:
payload["loudness"] = self._settings.get("loudness", 1.0)
if self._config.supports_temperature:
payload["temperature"] = self._settings.get("temperature", 0.6)
headers = {
"api-subscription-key": self._api_key,
@@ -669,35 +686,26 @@ class SarvamTTSService(InterruptibleTTSService):
model: TTS model to use. Options:
- "bulbul:v2" (default): Standard model with pitch/loudness support
- "bulbul:v3-beta": Advanced model with temperature control
- "bulbul:v3": Alias for v3-beta
voice_id: Speaker voice ID. If None, uses model-appropriate default:
- bulbul:v2: "anushka"
- bulbul:v3-beta/v3: "aditya"
voice_id: Speaker voice ID. If None, uses model-appropriate default.
url: WebSocket URL for the TTS backend (default production URL).
aggregate_sentences: Merge multiple sentences into one audio chunk (default True).
sample_rate: Output audio sample rate in Hz (8000, 16000, 22050, 24000).
If None, uses model-specific default (v2: 22050, v3: 24000).
If None, uses model-specific default.
params: Optional input parameters to override defaults.
**kwargs: Arguments forwarded to InterruptibleTTSService.
Note:
When using bulbul:v3-beta:
- pitch and loudness parameters are ignored
- pace range is limited to 0.5-2.0
- preprocessing is always enabled
- use SarvamTTSSpeakerV3 speakers (e.g., "aditya", "ritu")
See https://docs.sarvam.ai/api-reference-docs/text-to-speech/stream
"""
# Determine if using v3 model
is_v3_model = model in (
SarvamTTSModel.BULBUL_V3_BETA.value,
"bulbul:v3-beta",
)
# Get model configuration (validates model exists)
if model not in TTS_MODEL_CONFIGS:
allowed = ", ".join(sorted(TTS_MODEL_CONFIGS.keys()))
raise ValueError(f"Unsupported model '{model}'. Allowed values: {allowed}.")
self._config = TTS_MODEL_CONFIGS[model]
# Set default sample rate based on model if not specified
if sample_rate is None:
sample_rate = 24000 if is_v3_model else 22050
sample_rate = self._config.default_sample_rate
# Initialize parent class first
super().__init__(
@@ -712,9 +720,7 @@ class SarvamTTSService(InterruptibleTTSService):
# Set default voice based on model if not specified
if voice_id is None:
voice_id = "aditya" if is_v3_model else "anushka"
self._is_v3_model = is_v3_model
voice_id = self._config.default_speaker
# WebSocket endpoint URL with model query parameter
self._websocket_url = f"{url}?model={model}"
@@ -722,54 +728,46 @@ class SarvamTTSService(InterruptibleTTSService):
self.set_model_name(model)
self.set_voice(voice_id)
# Build base settings common to all models
# Validate and clamp pace to model's valid range
pace = params.pace
pace_min, pace_max = self._config.pace_range
if pace is not None and (pace < pace_min or pace > pace_max):
logger.warning(f"Pace {pace} is outside model range ({pace_min}-{pace_max}). Clamping.")
pace = max(pace_min, min(pace_max, pace))
# Build base settings
self._settings = {
"target_language_code": (
self.language_to_service_language(params.language) if params.language else "en-IN"
),
"speaker": voice_id,
"speech_sample_rate": str(sample_rate),
"enable_preprocessing": params.enable_preprocessing if not is_v3_model else True,
"enable_preprocessing": (
True if self._config.preprocessing_always_enabled else params.enable_preprocessing
),
"min_buffer_size": params.min_buffer_size,
"max_chunk_length": params.max_chunk_length,
"output_audio_codec": params.output_audio_codec,
"output_audio_bitrate": params.output_audio_bitrate,
"pace": pace,
"model": model,
}
# Add model-specific parameters
if is_v3_model:
# Validate pace for v3 (0.5-2.0)
pace = params.pace
if pace is not None and (pace < 0.5 or pace > 2.0):
logger.warning(
f"Pace {pace} is outside v3 model range (0.5-2.0). Clamping to valid range."
)
pace = max(0.5, min(2.0, pace))
# Add parameters based on model support
if self._config.supports_pitch:
self._settings["pitch"] = params.pitch
elif params.pitch != 0.0:
logger.warning(f"pitch parameter is ignored for {model}")
self._settings.update(
{
"temperature": params.temperature,
"pace": pace,
"model": model,
}
)
# Log warning if v2-only parameters are set
if params.pitch != 0.0:
logger.warning(f"pitch parameter is ignored for {model}")
if params.loudness != 1.0:
logger.warning(f"loudness parameter is ignored for {model}")
else:
self._settings.update(
{
"pitch": params.pitch,
"pace": params.pace,
"loudness": params.loudness,
"model": model,
}
)
# Log warning if v3-only parameters are set
if params.temperature != 0.6:
logger.warning(f"temperature parameter is ignored for {model}")
if self._config.supports_loudness:
self._settings["loudness"] = params.loudness
elif params.loudness != 1.0:
logger.warning(f"loudness parameter is ignored for {model}")
if self._config.supports_temperature:
self._settings["temperature"] = params.temperature
elif params.temperature != 0.6:
logger.warning(f"temperature parameter is ignored for {model}")
self._started = False
self._receive_task = None