feat: Add support for bulbul:v3 and bulbul:v3-beta
This commit is contained in:
@@ -20,9 +20,9 @@ from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
@@ -76,17 +76,29 @@ class SarvamHttpTTSService(TTSService):
|
||||
|
||||
Example::
|
||||
|
||||
tts = SarvamTTSService(
|
||||
tts = SarvamHttpTTSService(
|
||||
api_key="your-api-key",
|
||||
voice_id="anushka",
|
||||
model="bulbul:v2",
|
||||
aiohttp_session=session,
|
||||
params=SarvamTTSService.InputParams(
|
||||
params=SarvamHttpTTSService.InputParams(
|
||||
language=Language.HI,
|
||||
pitch=0.1,
|
||||
pace=1.2
|
||||
)
|
||||
)
|
||||
|
||||
# For bulbul v3 (or v3-beta) with any speaker:
|
||||
tts_v3 = SarvamHttpTTSService(
|
||||
api_key="your-api-key",
|
||||
voice_id="speaker_name",
|
||||
model="bulbul:v3",
|
||||
aiohttp_session=session,
|
||||
params=SarvamHttpTTSService.InputParams(
|
||||
language=Language.HI,
|
||||
temperature=0.8
|
||||
)
|
||||
)
|
||||
"""
|
||||
|
||||
class InputParams(BaseModel):
|
||||
@@ -105,6 +117,11 @@ class SarvamHttpTTSService(TTSService):
|
||||
pace: Optional[float] = Field(default=1.0, ge=0.3, le=3.0)
|
||||
loudness: Optional[float] = Field(default=1.0, ge=0.1, le=3.0)
|
||||
enable_preprocessing: Optional[bool] = False
|
||||
temperature: Optional[float] = Field(
|
||||
default=0.6,
|
||||
ge=0.01,
|
||||
le=1.0,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -124,7 +141,7 @@ class SarvamHttpTTSService(TTSService):
|
||||
api_key: Sarvam AI API subscription key.
|
||||
aiohttp_session: Shared aiohttp session for making requests.
|
||||
voice_id: Speaker voice ID (e.g., "anushka", "meera"). Defaults to "anushka".
|
||||
model: TTS model to use ("bulbul:v1" or "bulbul:v2"). Defaults to "bulbul:v2".
|
||||
model: TTS model to use ("bulbul:v2" or "bulbul:v3-beta" or "bulbul:v3"). Defaults to "bulbul:v2".
|
||||
base_url: Sarvam AI API base URL. Defaults to "https://api.sarvam.ai".
|
||||
sample_rate: Audio sample rate in Hz (8000, 16000, 22050, 24000). If None, uses default.
|
||||
params: Additional voice and preprocessing parameters. If None, uses defaults.
|
||||
@@ -138,15 +155,31 @@ class SarvamHttpTTSService(TTSService):
|
||||
self._base_url = base_url
|
||||
self._session = aiohttp_session
|
||||
|
||||
self._settings = {
|
||||
"language": (
|
||||
self.language_to_service_language(params.language) if params.language else "en-IN"
|
||||
),
|
||||
"pitch": params.pitch,
|
||||
"pace": params.pace,
|
||||
"loudness": params.loudness,
|
||||
"enable_preprocessing": params.enable_preprocessing,
|
||||
}
|
||||
if model == "bulbul:v3-beta" or model == "bulbul:v3":
|
||||
# For bulbul v3 and v3-beta: omit pitch, pace and loudness; send temperature and model instead
|
||||
self._settings = {
|
||||
"language": (
|
||||
self.language_to_service_language(params.language)
|
||||
if params.language
|
||||
else "en-IN"
|
||||
),
|
||||
"enable_preprocessing": params.enable_preprocessing,
|
||||
"temperature": getattr(params, "temperature", 0.6),
|
||||
"model": model, # Include model in settings for v3 beta
|
||||
}
|
||||
else:
|
||||
# For bulbul v2, include all parameters including pace and loudness
|
||||
self._settings = {
|
||||
"language": (
|
||||
self.language_to_service_language(params.language)
|
||||
if params.language
|
||||
else "en-IN"
|
||||
),
|
||||
"pitch": params.pitch,
|
||||
"pace": params.pace,
|
||||
"loudness": params.loudness,
|
||||
"enable_preprocessing": params.enable_preprocessing,
|
||||
}
|
||||
|
||||
self.set_model_name(model)
|
||||
self.set_voice(voice_id)
|
||||
@@ -275,6 +308,18 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
pace=1.2
|
||||
)
|
||||
)
|
||||
|
||||
# For bulbul v3 (or v3-beta) with any speaker and temperature:
|
||||
# Note: pitch, pace and loudness are omitted from the settings for bulbul:v3 and bulbul:v3-beta
|
||||
tts_v3 = SarvamTTSService(
|
||||
api_key="your-api-key",
|
||||
voice_id="speaker_name",
|
||||
model="bulbul:v3",
|
||||
params=SarvamTTSService.InputParams(
|
||||
language=Language.HI,
|
||||
temperature=0.8
|
||||
)
|
||||
)
|
||||
"""
|
||||
|
||||
class InputParams(BaseModel):
|
||||
@@ -310,6 +355,14 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
output_audio_codec: Optional[str] = "linear16"
|
||||
output_audio_bitrate: Optional[str] = "128k"
|
||||
language: Optional[Language] = Language.EN
|
||||
temperature: Optional[float] = Field(
|
||||
default=0.6,
|
||||
ge=0.01,
|
||||
le=1.0,
|
||||
description="Controls the randomness of the output for bulbul v3 beta. "
|
||||
"Lower values make the output more focused and deterministic, while "
|
||||
"higher values make it more random. Range: 0.01 to 1.0. Default: 0.6.",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -329,13 +382,12 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
Args:
|
||||
api_key: Sarvam API key for authenticating TTS requests.
|
||||
model: Identifier of the Sarvam speech model (default "bulbul:v2").
|
||||
Supports "bulbul:v2", "bulbul:v3-beta" and "bulbul:v3".
|
||||
voice_id: Voice identifier for synthesis (default "anushka").
|
||||
url: WebSocket URL for connecting to the TTS backend (default production URL).
|
||||
aiohttp_session: Optional shared aiohttp session. To maintain backward compatibility.
|
||||
|
||||
.. deprecated:: 0.0.81
|
||||
aiohttp_session is no longer used. This parameter will be removed in a future version.
|
||||
|
||||
aggregate_sentences: Whether to merge multiple sentences into one audio chunk (default True).
|
||||
sample_rate: Desired sample rate for the output audio in Hz (overrides default if set).
|
||||
params: Optional input parameters to override global configuration.
|
||||
@@ -372,21 +424,43 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
self.set_model_name(model)
|
||||
self.set_voice(voice_id)
|
||||
# Configuration parameters
|
||||
self._settings = {
|
||||
"target_language_code": (
|
||||
self.language_to_service_language(params.language) if params.language else "en-IN"
|
||||
),
|
||||
"pitch": params.pitch,
|
||||
"pace": params.pace,
|
||||
"speaker": voice_id,
|
||||
"loudness": params.loudness,
|
||||
"speech_sample_rate": 0,
|
||||
"enable_preprocessing": params.enable_preprocessing,
|
||||
"min_buffer_size": params.min_buffer_size,
|
||||
"max_chunk_length": params.max_chunk_length,
|
||||
"output_audio_codec": params.output_audio_codec,
|
||||
"output_audio_bitrate": params.output_audio_bitrate,
|
||||
}
|
||||
if model == "bulbul:v3-beta" or model == "bulbul:v3":
|
||||
# For bulbul v3 and v3-beta: omit pitch, pace and loudness; send temperature and model instead
|
||||
self._settings = {
|
||||
"target_language_code": (
|
||||
self.language_to_service_language(params.language)
|
||||
if params.language
|
||||
else "en-IN"
|
||||
),
|
||||
"speaker": voice_id,
|
||||
"speech_sample_rate": 0,
|
||||
"enable_preprocessing": params.enable_preprocessing,
|
||||
"min_buffer_size": params.min_buffer_size,
|
||||
"max_chunk_length": params.max_chunk_length,
|
||||
"output_audio_codec": params.output_audio_codec,
|
||||
"output_audio_bitrate": params.output_audio_bitrate,
|
||||
"temperature": getattr(params, "temperature", 0.6),
|
||||
"model": model, # Include model in settings for v3 beta
|
||||
}
|
||||
else:
|
||||
# For bulbul v2, include all parameters including pace and loudness
|
||||
self._settings = {
|
||||
"target_language_code": (
|
||||
self.language_to_service_language(params.language)
|
||||
if params.language
|
||||
else "en-IN"
|
||||
),
|
||||
"pitch": params.pitch,
|
||||
"pace": params.pace,
|
||||
"speaker": voice_id,
|
||||
"loudness": params.loudness,
|
||||
"speech_sample_rate": 0,
|
||||
"enable_preprocessing": params.enable_preprocessing,
|
||||
"min_buffer_size": params.min_buffer_size,
|
||||
"max_chunk_length": params.max_chunk_length,
|
||||
"output_audio_codec": params.output_audio_codec,
|
||||
"output_audio_bitrate": params.output_audio_bitrate,
|
||||
}
|
||||
self._started = False
|
||||
|
||||
self._receive_task = None
|
||||
@@ -455,7 +529,7 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
direction: The direction to push the frame.
|
||||
"""
|
||||
await super().push_frame(frame, direction)
|
||||
if isinstance(frame, (TTSStoppedFrame, InterruptionFrame)):
|
||||
if isinstance(frame, (TTSStoppedFrame, StartInterruptionFrame)):
|
||||
self._started = False
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
|
||||
Reference in New Issue
Block a user