Merge pull request #1872 from pipecat-ai/mb/add-sarvam-tts

Add SarvamTTSService
This commit is contained in:
Mark Backman
2025-05-22 18:02:36 -04:00
committed by GitHub
5 changed files with 321 additions and 2 deletions

View File

@@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Added `SarvamTTSService`, which implements Sarvam AI's TTS API:
https://docs.sarvam.ai/api-reference-docs/text-to-speech/convert.
- Added `PipelineTask.add_observer()` and `PipelineTask.remove_observer()` to
allow mangaging observers at runtime. This is useful for cases where the task
is passed around to other code components that might want to observe the
@@ -126,8 +129,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Other
- Added foundation example `07y-minimax-http.py` to show how to use the
`MiniMaxHttpTTSService`.
- Added foundation examples `07y-interruptible-minimax.py` and
`07z-interruptible-sarvam.py`to show how to use the `MiniMaxHttpTTSService`
and `SarvamTTSService`, respectively.
- Added an `open-telemetry-tracing` example, showing how to setup tracing. The
example also includes Jaeger as an open source OpenTelemetry client to review

View File

@@ -105,3 +105,6 @@ TWILIO_AUTH_TOKEN=...
# MiniMax
MINIMAX_API_KEY=...
MINIMAX_GROUP_ID=...
# Sarvam AI
SARVAM_API_KEY=...

View File

@@ -0,0 +1,109 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import argparse
import os
import aiohttp
from dotenv import load_dotenv
from loguru import logger
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.services.sarvam.tts import SarvamTTSService
from pipecat.transcriptions.language import Language
from pipecat.transports.base_transport import TransportParams
from pipecat.transports.network.small_webrtc import SmallWebRTCTransport
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
load_dotenv(override=True)
async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespace):
logger.info(f"Starting bot")
transport = SmallWebRTCTransport(
webrtc_connection=webrtc_connection,
params=TransportParams(
audio_in_enabled=True,
audio_out_enabled=True,
vad_analyzer=SileroVADAnalyzer(),
),
)
# Create an HTTP session
async with aiohttp.ClientSession() as session:
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
tts = SarvamTTSService(
api_key=os.getenv("SARVAM_API_KEY"),
aiohttp_session=session,
params=SarvamTTSService.InputParams(language=Language.EN),
)
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
messages = [
{
"role": "system",
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
},
]
context = OpenAILLMContext(messages)
context_aggregator = llm.create_context_aggregator(context)
pipeline = Pipeline(
[
transport.input(), # Transport user input
stt,
context_aggregator.user(), # User responses
llm, # LLM
tts, # TTS
transport.output(), # Transport bot output
context_aggregator.assistant(), # Assistant spoken responses
]
)
task = PipelineTask(
pipeline,
params=PipelineParams(
allow_interruptions=True,
enable_metrics=True,
enable_usage_metrics=True,
report_only_initial_ttfb=True,
),
)
@transport.event_handler("on_client_connected")
async def on_client_connected(transport, client):
logger.info(f"Client connected")
# Kick off the conversation.
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
await task.queue_frames([context_aggregator.user().get_context_frame()])
@transport.event_handler("on_client_disconnected")
async def on_client_disconnected(transport, client):
logger.info(f"Client disconnected")
@transport.event_handler("on_client_closed")
async def on_client_closed(transport, client):
logger.info(f"Client closed connection")
await task.cancel()
runner = PipelineRunner(handle_sigint=False)
await runner.run(task)
if __name__ == "__main__":
from run import main
main()

View File

@@ -0,0 +1,8 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
from .tts import *

View File

@@ -0,0 +1,195 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import base64
from typing import AsyncGenerator, Optional
import aiohttp
from loguru import logger
from pydantic import BaseModel, Field
from pipecat.frames.frames import (
ErrorFrame,
Frame,
StartFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.tts_service import TTSService
from pipecat.transcriptions.language import Language
from pipecat.utils.tracing.service_decorators import traced_tts
def language_to_sarvam_language(language: Language) -> Optional[str]:
"""Convert Pipecat Language enum to Sarvam AI language codes."""
LANGUAGE_MAP = {
Language.BN: "bn-IN", # Bengali
Language.EN: "en-IN", # English (India)
Language.GU: "gu-IN", # Gujarati
Language.HI: "hi-IN", # Hindi
Language.KN: "kn-IN", # Kannada
Language.ML: "ml-IN", # Malayalam
Language.MR: "mr-IN", # Marathi
Language.OR: "od-IN", # Odia
Language.PA: "pa-IN", # Punjabi
Language.TA: "ta-IN", # Tamil
Language.TE: "te-IN", # Telugu
}
return LANGUAGE_MAP.get(language)
class SarvamTTSService(TTSService):
"""Text-to-Speech service using Sarvam AI's API.
Converts text to speech using Sarvam AI's TTS models with support for multiple
Indian languages. Provides control over voice characteristics like pitch, pace,
and loudness.
Args:
api_key: Sarvam AI API subscription key.
voice_id: Speaker voice ID (e.g., "anushka", "meera").
model: TTS model to use ("bulbul:v1" or "bulbul:v2").
aiohttp_session: Shared aiohttp session for making requests.
base_url: Sarvam AI API base URL.
sample_rate: Audio sample rate in Hz (8000, 16000, 22050, 24000).
params: Additional voice and preprocessing parameters.
Example:
```python
tts = SarvamTTSService(
api_key="your-api-key",
voice_id="anushka",
model="bulbul:v2",
aiohttp_session=session,
params=SarvamTTSService.InputParams(
language=Language.HI,
pitch=0.1,
pace=1.2
)
)
```
"""
class InputParams(BaseModel):
language: Optional[Language] = Language.EN
pitch: Optional[float] = Field(default=0.0, ge=-0.75, le=0.75)
pace: Optional[float] = Field(default=1.0, ge=0.3, le=3.0)
loudness: Optional[float] = Field(default=1.0, ge=0.1, le=3.0)
enable_preprocessing: Optional[bool] = False
def __init__(
self,
*,
api_key: str,
voice_id: str = "anushka",
model: str = "bulbul:v2",
aiohttp_session: aiohttp.ClientSession,
base_url: str = "https://api.sarvam.ai",
sample_rate: Optional[int] = None,
params: Optional[InputParams] = None,
**kwargs,
):
super().__init__(sample_rate=sample_rate, **kwargs)
params = params or SarvamTTSService.InputParams()
self._api_key = api_key
self._base_url = base_url
self._session = aiohttp_session
self._settings = {
"language": self.language_to_service_language(params.language)
if params.language
else "en-IN",
"pitch": params.pitch,
"pace": params.pace,
"loudness": params.loudness,
"enable_preprocessing": params.enable_preprocessing,
}
self.set_model_name(model)
self.set_voice(voice_id)
def can_generate_metrics(self) -> bool:
return True
def language_to_service_language(self, language: Language) -> Optional[str]:
return language_to_sarvam_language(language)
async def start(self, frame: StartFrame):
await super().start(frame)
self._settings["sample_rate"] = self.sample_rate
@traced_tts
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
logger.debug(f"{self}: Generating TTS [{text}]")
try:
await self.start_ttfb_metrics()
payload = {
"text": text,
"target_language_code": self._settings["language"],
"speaker": self._voice_id,
"pitch": self._settings["pitch"],
"pace": self._settings["pace"],
"loudness": self._settings["loudness"],
"speech_sample_rate": self.sample_rate,
"enable_preprocessing": self._settings["enable_preprocessing"],
"model": self._model_name,
}
headers = {
"api-subscription-key": self._api_key,
"Content-Type": "application/json",
}
url = f"{self._base_url}/text-to-speech"
yield TTSStartedFrame()
async with self._session.post(url, json=payload, headers=headers) as response:
if response.status != 200:
error_text = await response.text()
logger.error(f"Sarvam API error: {error_text}")
await self.push_error(ErrorFrame(f"Sarvam API error: {error_text}"))
return
response_data = await response.json()
await self.start_tts_usage_metrics(text)
# Decode base64 audio data
if "audios" not in response_data or not response_data["audios"]:
logger.error("No audio data received from Sarvam API")
await self.push_error(ErrorFrame("No audio data received"))
return
# Get the first audio (there should be only one for single text input)
base64_audio = response_data["audios"][0]
audio_data = base64.b64decode(base64_audio)
# Strip WAV header (first 44 bytes) if present
if audio_data.startswith(b"RIFF"):
logger.debug("Stripping WAV header from Sarvam audio data")
audio_data = audio_data[44:]
frame = TTSAudioRawFrame(
audio=audio_data,
sample_rate=self.sample_rate,
num_channels=1,
)
yield frame
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(f"Error generating TTS: {e}"))
finally:
await self.stop_ttfb_metrics()
yield TTSStoppedFrame()