refactor(sarvam): centralize model configuration with dataclasses
This commit is contained in:
@@ -6,7 +6,8 @@ can handle multiple audio formats for Indian language speech recognition.
|
||||
"""
|
||||
|
||||
import base64
|
||||
from typing import Literal, Optional
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Literal, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
@@ -68,6 +69,60 @@ def language_to_sarvam_language(language: Language) -> str:
|
||||
return resolve_language(language, LANGUAGE_MAP, use_base_code=False)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ModelConfig:
    """Immutable capability flags and defaults for one Sarvam STT model.

    Attributes:
        supports_prompt: Whether the model accepts a prompt parameter.
        supports_mode: Whether the model accepts a mode parameter.
        supports_language: Whether the model accepts a language parameter.
        default_language: Language code used when the caller supplies none
            (None = the model auto-detects the language).
        default_mode: Mode used when the caller supplies none
            (None = not applicable).
        use_translate_endpoint: Whether to connect through the
            speech_to_text_translate_streaming endpoint.
        use_translate_method: Whether to send audio with translate()
            instead of transcribe().

    Raises:
        ValueError: If a default is configured for a parameter the model
            is declared not to support.
    """

    supports_prompt: bool
    supports_mode: bool
    supports_language: bool
    default_language: Optional[str]
    default_mode: Optional[str]
    use_translate_endpoint: bool
    use_translate_method: bool

    def __post_init__(self) -> None:
        """Reject defaults for parameters the model does not accept."""
        # A default is applied whenever the caller passes no value, so a
        # default on an unsupported parameter would contradict the flag
        # that says the model does not take that parameter at all.
        if self.default_language is not None and not self.supports_language:
            raise ValueError(
                "default_language must be None when supports_language is False."
            )
        if self.default_mode is not None and not self.supports_mode:
            raise ValueError("default_mode must be None when supports_mode is False.")
|
||||
|
||||
|
||||
# Registry of the supported Sarvam STT models, keyed by model identifier.
# Membership in this mapping is what makes a model name valid for the service.
MODEL_CONFIGS: Dict[str, ModelConfig] = {
    # Language only: no prompt, no mode; the default language code is "unknown".
    "saarika:v2.5": ModelConfig(
        supports_language=True,
        supports_prompt=False,
        supports_mode=False,
        default_language="unknown",
        default_mode=None,
        use_translate_endpoint=False,
        use_translate_method=False,
    ),
    # Prompt only: language is auto-detected, so no default language is set,
    # and both the translate endpoint and translate() method are used.
    "saaras:v2.5": ModelConfig(
        supports_language=False,
        supports_prompt=True,
        supports_mode=False,
        default_language=None,
        default_mode=None,
        use_translate_endpoint=True,
        use_translate_method=True,
    ),
    # Prompt, mode and language are all accepted; defaults are "en-IN" and
    # "transcribe".
    "saaras:v3": ModelConfig(
        supports_language=True,
        supports_prompt=True,
        supports_mode=True,
        default_language="en-IN",
        default_mode="transcribe",
        use_translate_endpoint=False,
        use_translate_method=False,
    ),
}
|
||||
|
||||
|
||||
class SarvamSTTService(STTService):
|
||||
"""Sarvam speech-to-text service.
|
||||
|
||||
@@ -103,9 +158,6 @@ class SarvamSTTService(STTService):
|
||||
model: str = "saarika:v2.5",
|
||||
sample_rate: Optional[int] = None,
|
||||
input_audio_codec: str = "wav",
|
||||
mode: Optional[
|
||||
Literal["transcribe", "translate", "verbatim", "translit", "codemix"]
|
||||
] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -119,73 +171,44 @@ class SarvamSTTService(STTService):
|
||||
- "saaras:v3": Advanced STT model (supports mode and prompts)
|
||||
sample_rate: Audio sample rate. Defaults to 16000 if not specified.
|
||||
input_audio_codec: Audio codec/format of the input file. Defaults to "wav".
|
||||
mode: Mode of operation for saaras:v3 models only. Options: transcribe, translate,
|
||||
verbatim, translit, codemix. Defaults to "transcribe" for saaras:v3.
|
||||
params: Configuration parameters for Sarvam STT service.
|
||||
**kwargs: Additional arguments passed to the parent STTService.
|
||||
"""
|
||||
params = params or SarvamSTTService.InputParams()
|
||||
# Allow mode to be passed directly or via params
|
||||
if mode is not None and params.mode is None:
|
||||
params = params.model_copy(update={"mode": mode})
|
||||
|
||||
# Validate allowed models
|
||||
allowed_models = {"saarika:v2.5", "saaras:v3", "saaras:v2.5"}
|
||||
if model not in allowed_models:
|
||||
allowed_models_list = ", ".join(sorted(allowed_models))
|
||||
raise ValueError(f"Unsupported model '{model}'. Allowed values: {allowed_models_list}.")
|
||||
# Get model configuration (validates model exists)
|
||||
if model not in MODEL_CONFIGS:
|
||||
allowed = ", ".join(sorted(MODEL_CONFIGS.keys()))
|
||||
raise ValueError(f"Unsupported model '{model}'. Allowed values: {allowed}.")
|
||||
|
||||
# Validate model-specific parameter restrictions
|
||||
if "saarika" in model.lower():
|
||||
# saarika models don't accept prompt or mode
|
||||
if params.prompt is not None:
|
||||
raise ValueError(
|
||||
f"Model '{model}' does not accept prompt parameter. "
|
||||
"Prompts are only supported for saaras models (v2.5 and v3)."
|
||||
)
|
||||
if params.mode is not None:
|
||||
raise ValueError(
|
||||
f"Model '{model}' does not accept mode parameter. "
|
||||
"Mode is only supported for saaras:v3 model."
|
||||
)
|
||||
elif model.lower() == "saaras:v2.5":
|
||||
# saaras:v2.5 supports prompt but not mode
|
||||
if params.mode is not None:
|
||||
raise ValueError(
|
||||
f"Model '{model}' does not accept mode parameter. "
|
||||
"Mode is only supported for saaras:v3 model."
|
||||
)
|
||||
if params.language is not None:
|
||||
raise ValueError(
|
||||
f"Model '{model}' does not accept language parameter. "
|
||||
"saaras:v2.5 (STT-Translate) auto-detects language."
|
||||
)
|
||||
self._config = MODEL_CONFIGS[model]
|
||||
|
||||
# Validate parameters against model capabilities
|
||||
if params.prompt is not None and not self._config.supports_prompt:
|
||||
raise ValueError(f"Model '{model}' does not support prompt parameter.")
|
||||
if params.mode is not None and not self._config.supports_mode:
|
||||
raise ValueError(f"Model '{model}' does not support mode parameter.")
|
||||
if params.language is not None and not self._config.supports_language:
|
||||
raise ValueError(
|
||||
f"Model '{model}' does not support language parameter (auto-detects language)."
|
||||
)
|
||||
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
self.set_model_name(model)
|
||||
self._api_key = api_key
|
||||
self._language_code: Optional[Language] = params.language
|
||||
# Set language string based on model type
|
||||
# - saarika:v2.5: uses language_code or defaults to "unknown"
|
||||
# - saaras:v2.5: auto-detects language (no language_code needed)
|
||||
# - saaras:v3: uses language_code or defaults to "en-IN"
|
||||
|
||||
# Set language string: use provided language or model's default
|
||||
if params.language:
|
||||
self._language_string = language_to_sarvam_language(params.language)
|
||||
elif "saarika" in model.lower():
|
||||
self._language_string = "unknown"
|
||||
elif model.lower() == "saaras:v2.5":
|
||||
self._language_string = None # STT-Translate auto-detects language
|
||||
elif model.lower() == "saaras:v3":
|
||||
self._language_string = "en-IN"
|
||||
else:
|
||||
self._language_string = None
|
||||
self._language_string = self._config.default_language
|
||||
|
||||
self._prompt = params.prompt
|
||||
# Set mode for saaras:v3, default to "transcribe"
|
||||
if model.lower() == "saaras:v3":
|
||||
self._mode = params.mode if params.mode is not None else "transcribe"
|
||||
else:
|
||||
self._mode = None
|
||||
|
||||
# Set mode: use provided mode or model's default
|
||||
self._mode = params.mode if params.mode is not None else self._config.default_mode
|
||||
|
||||
# Store connection parameters
|
||||
self._vad_signals = params.vad_signals
|
||||
@@ -250,13 +273,12 @@ class SarvamSTTService(STTService):
|
||||
language: The language to use for speech recognition.
|
||||
|
||||
Raises:
|
||||
ValueError: If called on saaras:v2.5 model which auto-detects language.
|
||||
ValueError: If called on a model that auto-detects language.
|
||||
"""
|
||||
# saaras:v2.5 (STT-Translate) auto-detects language
|
||||
if self.model_name.lower() == "saaras:v2.5":
|
||||
if not self._config.supports_language:
|
||||
raise ValueError(
|
||||
f"Model '{self.model_name}' does not accept language parameter. "
|
||||
"saaras:v2.5 (STT-Translate) auto-detects language."
|
||||
f"Model '{self.model_name}' does not support language parameter "
|
||||
"(auto-detects language)."
|
||||
)
|
||||
|
||||
logger.info(f"Switching STT language to: [{language}]")
|
||||
@@ -271,16 +293,12 @@ class SarvamSTTService(STTService):
|
||||
Args:
|
||||
prompt: Prompt text to guide transcription/translation style/context.
|
||||
Pass None to clear/disable prompt.
|
||||
Only applicable to saaras models (v2.5 and v3).
|
||||
Only applicable to models that support prompts.
|
||||
"""
|
||||
# saarika models do not accept prompt parameter
|
||||
if "saarika" in self.model_name.lower():
|
||||
if not self._config.supports_prompt:
|
||||
if prompt is not None:
|
||||
raise ValueError(
|
||||
f"Model '{self.model_name}' does not accept prompt parameter. "
|
||||
"Prompts are only supported for saaras models (v2.5 and v3)."
|
||||
)
|
||||
# If prompt is None and it's saarika, just silently return (no-op)
|
||||
raise ValueError(f"Model '{self.model_name}' does not support prompt parameter.")
|
||||
# If prompt is None and model doesn't support prompts, silently return (no-op)
|
||||
return
|
||||
|
||||
logger.info(f"Updating {self.model_name} prompt.")
|
||||
@@ -347,12 +365,10 @@ class SarvamSTTService(STTService):
|
||||
"sample_rate": self.sample_rate,
|
||||
}
|
||||
|
||||
# Use appropriate method based on model type
|
||||
if self.model_name.lower() == "saaras:v2.5":
|
||||
# STT-Translate: auto-detects input language and returns translated text
|
||||
# Use appropriate method based on model configuration
|
||||
if self._config.use_translate_method:
|
||||
await self._socket_client.translate(**method_kwargs)
|
||||
else:
|
||||
# saarika:v2.5 and saaras:v3 use transcribe
|
||||
await self._socket_client.transcribe(**method_kwargs)
|
||||
|
||||
except Exception as e:
|
||||
@@ -377,12 +393,12 @@ class SarvamSTTService(STTService):
|
||||
"sample_rate": str(self.sample_rate),
|
||||
}
|
||||
|
||||
# Add language_code for models that require it (not saaras:v2.5 which auto-detects)
|
||||
# Add language_code for models that support it
|
||||
if self._language_string is not None:
|
||||
connect_kwargs["language_code"] = self._language_string
|
||||
|
||||
# Add mode for saaras:v3 only
|
||||
if self.model_name.lower() == "saaras:v3" and self._mode is not None:
|
||||
# Add mode for models that support it
|
||||
if self._config.supports_mode and self._mode is not None:
|
||||
connect_kwargs["mode"] = self._mode
|
||||
|
||||
def _connect_with_sdk_headers(connect_fn, **kwargs):
|
||||
@@ -394,15 +410,13 @@ class SarvamSTTService(STTService):
|
||||
pass
|
||||
return connect_fn(**kwargs)
|
||||
|
||||
# Choose the appropriate endpoint based on model
|
||||
if self.model_name.lower() == "saaras:v2.5":
|
||||
# STT-Translate: auto-detects input language and returns translated text
|
||||
# Choose the appropriate endpoint based on model configuration
|
||||
if self._config.use_translate_endpoint:
|
||||
self._websocket_context = _connect_with_sdk_headers(
|
||||
self._sarvam_client.speech_to_text_translate_streaming.connect,
|
||||
**connect_kwargs,
|
||||
)
|
||||
else:
|
||||
# saarika:v2.5 and saaras:v3 use speech_to_text_streaming
|
||||
self._websocket_context = _connect_with_sdk_headers(
|
||||
self._sarvam_client.speech_to_text_streaming.connect,
|
||||
**connect_kwargs,
|
||||
@@ -411,8 +425,8 @@ class SarvamSTTService(STTService):
|
||||
# Enter the async context manager
|
||||
self._socket_client = await self._websocket_context.__aenter__()
|
||||
|
||||
# Set prompt if provided (only for saaras models, after connection)
|
||||
if self._prompt is not None and "saaras" in self.model_name.lower():
|
||||
# Set prompt if provided (only for models that support prompts)
|
||||
if self._prompt is not None and self._config.supports_prompt:
|
||||
await self._socket_client.set_prompt(self._prompt)
|
||||
|
||||
# Register event handler for incoming messages
|
||||
|
||||
@@ -31,8 +31,9 @@ See https://docs.sarvam.ai/api-reference-docs/text-to-speech/stream for full API
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, AsyncGenerator, List, Mapping, Optional
|
||||
from typing import Any, AsyncGenerator, Dict, List, Mapping, Optional, Tuple
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
@@ -133,16 +134,52 @@ class SarvamTTSSpeakerV3(str, Enum):
|
||||
SOPHIA = "sophia"
|
||||
|
||||
|
||||
# Default sample rates per model
|
||||
SARVAM_DEFAULT_SAMPLE_RATES = {
|
||||
SarvamTTSModel.BULBUL_V2: 22050,
|
||||
SarvamTTSModel.BULBUL_V3_BETA: 24000,
|
||||
}
|
||||
@dataclass(frozen=True)
class TTSModelConfig:
    """Immutable capability flags and defaults for one Sarvam TTS model.

    Attributes:
        supports_pitch: Whether the model accepts a pitch parameter.
        supports_loudness: Whether the model accepts a loudness parameter.
        supports_temperature: Whether the model accepts a temperature parameter.
        default_sample_rate: Default audio sample rate in Hz.
        default_speaker: Default speaker voice ID.
        pace_range: Valid range for the pace parameter as (min, max).
        preprocessing_always_enabled: Whether preprocessing is always enabled.
        speakers: Tuple of available speaker names for this model.

    Raises:
        ValueError: If ``default_sample_rate`` is not positive or
            ``pace_range`` has its minimum above its maximum.
    """

    supports_pitch: bool
    supports_loudness: bool
    supports_temperature: bool
    default_sample_rate: int
    default_speaker: str
    pace_range: Tuple[float, float]
    preprocessing_always_enabled: bool
    speakers: Tuple[str, ...]

    def __post_init__(self) -> None:
        """Reject values that would make the configuration unusable."""
        if self.default_sample_rate <= 0:
            raise ValueError("default_sample_rate must be a positive number of Hz.")
        pace_min, pace_max = self.pace_range
        # Clamping with max(min_bound, min(max_bound, value)) only works when
        # the bounds are ordered; an inverted range would silently clamp
        # every value to the wrong bound.
        if pace_min > pace_max:
            raise ValueError("pace_range must be ordered as (min, max).")
|
||||
|
||||
|
||||
# Registry of the supported Sarvam TTS models, keyed by model identifier.
TTS_MODEL_CONFIGS: Dict[str, TTSModelConfig] = {
    # bulbul:v2 takes pitch and loudness but not temperature.
    "bulbul:v2": TTSModelConfig(
        default_speaker="anushka",
        default_sample_rate=22050,
        pace_range=(0.3, 3.0),
        supports_pitch=True,
        supports_loudness=True,
        supports_temperature=False,
        preprocessing_always_enabled=False,
        speakers=tuple(voice.value for voice in SarvamTTSSpeakerV2),
    ),
    # bulbul:v3-beta takes temperature but not pitch or loudness, has a
    # narrower pace range, and always runs with preprocessing enabled.
    "bulbul:v3-beta": TTSModelConfig(
        default_speaker="aditya",
        default_sample_rate=24000,
        pace_range=(0.5, 2.0),
        supports_pitch=False,
        supports_loudness=False,
        supports_temperature=True,
        preprocessing_always_enabled=True,
        speakers=tuple(voice.value for voice in SarvamTTSSpeakerV3),
    ),
}
|
||||
|
||||
|
||||
@@ -155,9 +192,10 @@ def get_speakers_for_model(model: str) -> List[str]:
|
||||
Returns:
|
||||
List of speaker names available for the model.
|
||||
"""
|
||||
if model in (SarvamTTSModel.BULBUL_V3_BETA.value):
|
||||
return [s.value for s in SarvamTTSSpeakerV3]
|
||||
return [s.value for s in SarvamTTSSpeakerV2]
|
||||
if model in TTS_MODEL_CONFIGS:
|
||||
return list(TTS_MODEL_CONFIGS[model].speakers)
|
||||
# Default to v2 speakers for unknown models
|
||||
return list(TTS_MODEL_CONFIGS["bulbul:v2"].speakers)
|
||||
|
||||
|
||||
def language_to_sarvam_language(language: Language) -> Optional[str]:
|
||||
@@ -304,35 +342,26 @@ class SarvamHttpTTSService(TTSService):
|
||||
Args:
|
||||
api_key: Sarvam AI API subscription key.
|
||||
aiohttp_session: Shared aiohttp session for making requests.
|
||||
voice_id: Speaker voice ID. If None, uses model-appropriate default:
|
||||
- bulbul:v2: "anushka"
|
||||
- bulbul:v3-beta/v3: "aditya"
|
||||
voice_id: Speaker voice ID. If None, uses model-appropriate default.
|
||||
model: TTS model to use. Options:
|
||||
- "bulbul:v2" (default): Standard model with pitch/loudness support
|
||||
- "bulbul:v3-beta": Advanced model with temperature control
|
||||
- "bulbul:v3": Alias for v3-beta
|
||||
base_url: Sarvam AI API base URL. Defaults to "https://api.sarvam.ai".
|
||||
sample_rate: Audio sample rate in Hz (8000, 16000, 22050, 24000).
|
||||
If None, uses model-specific default (v2: 22050, v3: 24000).
|
||||
If None, uses model-specific default.
|
||||
params: Additional voice and preprocessing parameters. If None, uses defaults.
|
||||
**kwargs: Additional arguments passed to parent TTSService.
|
||||
|
||||
Note:
|
||||
When using bulbul:v3-beta:
|
||||
- pitch and loudness parameters are ignored
|
||||
- pace range is limited to 0.5-2.0
|
||||
- preprocessing is always enabled
|
||||
- use SarvamTTSSpeakerV3 speakers (e.g., "aditya", "ritu")
|
||||
"""
|
||||
# Determine if using v3 model
|
||||
is_v3_model = model in (
|
||||
SarvamTTSModel.BULBUL_V3_BETA.value,
|
||||
"bulbul:v3-beta",
|
||||
)
|
||||
# Get model configuration (validates model exists)
|
||||
if model not in TTS_MODEL_CONFIGS:
|
||||
allowed = ", ".join(sorted(TTS_MODEL_CONFIGS.keys()))
|
||||
raise ValueError(f"Unsupported model '{model}'. Allowed values: {allowed}.")
|
||||
|
||||
self._config = TTS_MODEL_CONFIGS[model]
|
||||
|
||||
# Set default sample rate based on model if not specified
|
||||
if sample_rate is None:
|
||||
sample_rate = 24000 if is_v3_model else 22050
|
||||
sample_rate = self._config.default_sample_rate
|
||||
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
@@ -340,55 +369,46 @@ class SarvamHttpTTSService(TTSService):
|
||||
|
||||
# Set default voice based on model if not specified
|
||||
if voice_id is None:
|
||||
voice_id = "aditya" if is_v3_model else "anushka"
|
||||
voice_id = self._config.default_speaker
|
||||
|
||||
self._api_key = api_key
|
||||
self._base_url = base_url
|
||||
self._session = aiohttp_session
|
||||
self._is_v3_model = is_v3_model
|
||||
|
||||
# Build base settings common to all models
|
||||
# Validate and clamp pace to model's valid range
|
||||
pace = params.pace
|
||||
pace_min, pace_max = self._config.pace_range
|
||||
if pace is not None and (pace < pace_min or pace > pace_max):
|
||||
logger.warning(f"Pace {pace} is outside model range ({pace_min}-{pace_max}). Clamping.")
|
||||
pace = max(pace_min, min(pace_max, pace))
|
||||
|
||||
# Build base settings
|
||||
self._settings = {
|
||||
"language": (
|
||||
self.language_to_service_language(params.language) if params.language else "en-IN"
|
||||
),
|
||||
"enable_preprocessing": params.enable_preprocessing if not is_v3_model else True,
|
||||
"enable_preprocessing": (
|
||||
True if self._config.preprocessing_always_enabled else params.enable_preprocessing
|
||||
),
|
||||
"pace": pace,
|
||||
"model": model,
|
||||
}
|
||||
|
||||
# Add model-specific parameters
|
||||
if is_v3_model:
|
||||
# Validate pace for v3 (0.5-2.0)
|
||||
pace = params.pace
|
||||
if pace is not None and (pace < 0.5 or pace > 2.0):
|
||||
logger.warning(
|
||||
f"Pace {pace} is outside v3 model range (0.5-2.0). Clamping to valid range."
|
||||
)
|
||||
pace = max(0.5, min(2.0, pace))
|
||||
# Add parameters based on model support
|
||||
if self._config.supports_pitch:
|
||||
self._settings["pitch"] = params.pitch
|
||||
elif params.pitch != 0.0:
|
||||
logger.warning(f"pitch parameter is ignored for {model}")
|
||||
|
||||
self._settings.update(
|
||||
{
|
||||
"temperature": params.temperature,
|
||||
"pace": pace,
|
||||
"model": model,
|
||||
}
|
||||
)
|
||||
# Log warning if v2-only parameters are set
|
||||
if params.pitch != 0.0:
|
||||
logger.warning(f"pitch parameter is ignored for {model}")
|
||||
if params.loudness != 1.0:
|
||||
logger.warning(f"loudness parameter is ignored for {model}")
|
||||
else:
|
||||
self._settings.update(
|
||||
{
|
||||
"pitch": params.pitch,
|
||||
"pace": params.pace,
|
||||
"loudness": params.loudness,
|
||||
"model": model,
|
||||
}
|
||||
)
|
||||
# Log warning if v3-only parameters are set
|
||||
if params.temperature != 0.6:
|
||||
logger.warning(f"temperature parameter is ignored for {model}")
|
||||
if self._config.supports_loudness:
|
||||
self._settings["loudness"] = params.loudness
|
||||
elif params.loudness != 1.0:
|
||||
logger.warning(f"loudness parameter is ignored for {model}")
|
||||
|
||||
if self._config.supports_temperature:
|
||||
self._settings["temperature"] = params.temperature
|
||||
elif params.temperature != 0.6:
|
||||
logger.warning(f"temperature parameter is ignored for {model}")
|
||||
|
||||
self.set_model_name(model)
|
||||
self.set_voice(voice_id)
|
||||
@@ -436,7 +456,7 @@ class SarvamHttpTTSService(TTSService):
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
# Build payload based on model type
|
||||
# Build payload with common parameters
|
||||
payload = {
|
||||
"text": text,
|
||||
"target_language_code": self._settings["language"],
|
||||
@@ -444,19 +464,16 @@ class SarvamHttpTTSService(TTSService):
|
||||
"sample_rate": self.sample_rate,
|
||||
"enable_preprocessing": self._settings["enable_preprocessing"],
|
||||
"model": self._model_name,
|
||||
"pace": self._settings.get("pace", 1.0),
|
||||
}
|
||||
|
||||
# Add model-specific parameters
|
||||
if self._is_v3_model:
|
||||
# v3 models use temperature and pace (0.5-2.0)
|
||||
payload["temperature"] = self._settings.get("temperature", 0.6)
|
||||
if "pace" in self._settings:
|
||||
payload["pace"] = self._settings["pace"]
|
||||
else:
|
||||
# v2 models use pitch, pace, loudness
|
||||
# Add model-specific parameters based on config
|
||||
if self._config.supports_pitch:
|
||||
payload["pitch"] = self._settings.get("pitch", 0.0)
|
||||
payload["pace"] = self._settings.get("pace", 1.0)
|
||||
if self._config.supports_loudness:
|
||||
payload["loudness"] = self._settings.get("loudness", 1.0)
|
||||
if self._config.supports_temperature:
|
||||
payload["temperature"] = self._settings.get("temperature", 0.6)
|
||||
|
||||
headers = {
|
||||
"api-subscription-key": self._api_key,
|
||||
@@ -669,35 +686,26 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
model: TTS model to use. Options:
|
||||
- "bulbul:v2" (default): Standard model with pitch/loudness support
|
||||
- "bulbul:v3-beta": Advanced model with temperature control
|
||||
- "bulbul:v3": Alias for v3-beta
|
||||
voice_id: Speaker voice ID. If None, uses model-appropriate default:
|
||||
- bulbul:v2: "anushka"
|
||||
- bulbul:v3-beta/v3: "aditya"
|
||||
voice_id: Speaker voice ID. If None, uses model-appropriate default.
|
||||
url: WebSocket URL for the TTS backend (default production URL).
|
||||
aggregate_sentences: Merge multiple sentences into one audio chunk (default True).
|
||||
sample_rate: Output audio sample rate in Hz (8000, 16000, 22050, 24000).
|
||||
If None, uses model-specific default (v2: 22050, v3: 24000).
|
||||
If None, uses model-specific default.
|
||||
params: Optional input parameters to override defaults.
|
||||
**kwargs: Arguments forwarded to InterruptibleTTSService.
|
||||
|
||||
Note:
|
||||
When using bulbul:v3-beta:
|
||||
- pitch and loudness parameters are ignored
|
||||
- pace range is limited to 0.5-2.0
|
||||
- preprocessing is always enabled
|
||||
- use SarvamTTSSpeakerV3 speakers (e.g., "aditya", "ritu")
|
||||
|
||||
See https://docs.sarvam.ai/api-reference-docs/text-to-speech/stream
|
||||
"""
|
||||
# Determine if using v3 model
|
||||
is_v3_model = model in (
|
||||
SarvamTTSModel.BULBUL_V3_BETA.value,
|
||||
"bulbul:v3-beta",
|
||||
)
|
||||
# Get model configuration (validates model exists)
|
||||
if model not in TTS_MODEL_CONFIGS:
|
||||
allowed = ", ".join(sorted(TTS_MODEL_CONFIGS.keys()))
|
||||
raise ValueError(f"Unsupported model '{model}'. Allowed values: {allowed}.")
|
||||
|
||||
self._config = TTS_MODEL_CONFIGS[model]
|
||||
|
||||
# Set default sample rate based on model if not specified
|
||||
if sample_rate is None:
|
||||
sample_rate = 24000 if is_v3_model else 22050
|
||||
sample_rate = self._config.default_sample_rate
|
||||
|
||||
# Initialize parent class first
|
||||
super().__init__(
|
||||
@@ -712,9 +720,7 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
|
||||
# Set default voice based on model if not specified
|
||||
if voice_id is None:
|
||||
voice_id = "aditya" if is_v3_model else "anushka"
|
||||
|
||||
self._is_v3_model = is_v3_model
|
||||
voice_id = self._config.default_speaker
|
||||
|
||||
# WebSocket endpoint URL with model query parameter
|
||||
self._websocket_url = f"{url}?model={model}"
|
||||
@@ -722,54 +728,46 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
self.set_model_name(model)
|
||||
self.set_voice(voice_id)
|
||||
|
||||
# Build base settings common to all models
|
||||
# Validate and clamp pace to model's valid range
|
||||
pace = params.pace
|
||||
pace_min, pace_max = self._config.pace_range
|
||||
if pace is not None and (pace < pace_min or pace > pace_max):
|
||||
logger.warning(f"Pace {pace} is outside model range ({pace_min}-{pace_max}). Clamping.")
|
||||
pace = max(pace_min, min(pace_max, pace))
|
||||
|
||||
# Build base settings
|
||||
self._settings = {
|
||||
"target_language_code": (
|
||||
self.language_to_service_language(params.language) if params.language else "en-IN"
|
||||
),
|
||||
"speaker": voice_id,
|
||||
"speech_sample_rate": str(sample_rate),
|
||||
"enable_preprocessing": params.enable_preprocessing if not is_v3_model else True,
|
||||
"enable_preprocessing": (
|
||||
True if self._config.preprocessing_always_enabled else params.enable_preprocessing
|
||||
),
|
||||
"min_buffer_size": params.min_buffer_size,
|
||||
"max_chunk_length": params.max_chunk_length,
|
||||
"output_audio_codec": params.output_audio_codec,
|
||||
"output_audio_bitrate": params.output_audio_bitrate,
|
||||
"pace": pace,
|
||||
"model": model,
|
||||
}
|
||||
|
||||
# Add model-specific parameters
|
||||
if is_v3_model:
|
||||
# Validate pace for v3 (0.5-2.0)
|
||||
pace = params.pace
|
||||
if pace is not None and (pace < 0.5 or pace > 2.0):
|
||||
logger.warning(
|
||||
f"Pace {pace} is outside v3 model range (0.5-2.0). Clamping to valid range."
|
||||
)
|
||||
pace = max(0.5, min(2.0, pace))
|
||||
# Add parameters based on model support
|
||||
if self._config.supports_pitch:
|
||||
self._settings["pitch"] = params.pitch
|
||||
elif params.pitch != 0.0:
|
||||
logger.warning(f"pitch parameter is ignored for {model}")
|
||||
|
||||
self._settings.update(
|
||||
{
|
||||
"temperature": params.temperature,
|
||||
"pace": pace,
|
||||
"model": model,
|
||||
}
|
||||
)
|
||||
# Log warning if v2-only parameters are set
|
||||
if params.pitch != 0.0:
|
||||
logger.warning(f"pitch parameter is ignored for {model}")
|
||||
if params.loudness != 1.0:
|
||||
logger.warning(f"loudness parameter is ignored for {model}")
|
||||
else:
|
||||
self._settings.update(
|
||||
{
|
||||
"pitch": params.pitch,
|
||||
"pace": params.pace,
|
||||
"loudness": params.loudness,
|
||||
"model": model,
|
||||
}
|
||||
)
|
||||
# Log warning if v3-only parameters are set
|
||||
if params.temperature != 0.6:
|
||||
logger.warning(f"temperature parameter is ignored for {model}")
|
||||
if self._config.supports_loudness:
|
||||
self._settings["loudness"] = params.loudness
|
||||
elif params.loudness != 1.0:
|
||||
logger.warning(f"loudness parameter is ignored for {model}")
|
||||
|
||||
if self._config.supports_temperature:
|
||||
self._settings["temperature"] = params.temperature
|
||||
elif params.temperature != 0.6:
|
||||
logger.warning(f"temperature parameter is ignored for {model}")
|
||||
|
||||
self._started = False
|
||||
self._receive_task = None
|
||||
|
||||
Reference in New Issue
Block a user