From 3c819955a2fb58c4f82a8b70c38507d4c65dc8c0 Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Thu, 22 May 2025 16:14:16 -0400 Subject: [PATCH] Add SarvamTTSService --- CHANGELOG.md | 8 +- dot-env.template | 3 + .../foundational/07z-interruptible-sarvam.py | 109 ++++++++++ src/pipecat/services/sarvam/__init__.py | 8 + src/pipecat/services/sarvam/tts.py | 195 ++++++++++++++++++ 5 files changed, 321 insertions(+), 2 deletions(-) create mode 100644 examples/foundational/07z-interruptible-sarvam.py create mode 100644 src/pipecat/services/sarvam/__init__.py create mode 100644 src/pipecat/services/sarvam/tts.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 9386e280e..6ed8f019a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added `SarvamTTSService`, which implements Sarvam AI's TTS API: + https://docs.sarvam.ai/api-reference-docs/text-to-speech/convert. + - Added `PipelineTask.add_observer()` and `PipelineTask.remove_observer()` to allow mangaging observers at runtime. This is useful for cases where the task is passed around to other code components that might want to observe the @@ -126,8 +129,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Other -- Added foundation example `07y-minimax-http.py` to show how to use the - `MiniMaxHttpTTSService`. +- Added foundation examples `07y-interruptible-minimax.py` and + `07z-interruptible-sarvam.py`to show how to use the `MiniMaxHttpTTSService` + and `SarvamTTSService`, respectively. - Added an `open-telemetry-tracing` example, showing how to setup tracing. The example also includes Jaeger as an open source OpenTelemetry client to review diff --git a/dot-env.template b/dot-env.template index aa8068451..20d73b3ad 100644 --- a/dot-env.template +++ b/dot-env.template @@ -105,3 +105,6 @@ TWILIO_AUTH_TOKEN=... # MiniMax MINIMAX_API_KEY=... MINIMAX_GROUP_ID=... + +# Sarvam AI +SARVAM_API_KEY=... \ No newline at end of file diff --git a/examples/foundational/07z-interruptible-sarvam.py b/examples/foundational/07z-interruptible-sarvam.py new file mode 100644 index 000000000..fafee5e93 --- /dev/null +++ b/examples/foundational/07z-interruptible-sarvam.py @@ -0,0 +1,109 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import argparse +import os + +import aiohttp +from dotenv import load_dotenv +from loguru import logger + +from pipecat.audio.vad.silero import SileroVADAnalyzer +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineParams, PipelineTask +from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext +from pipecat.services.deepgram.stt import DeepgramSTTService +from pipecat.services.openai.llm import OpenAILLMService +from pipecat.services.sarvam.tts import SarvamTTSService +from pipecat.transcriptions.language import Language +from pipecat.transports.base_transport import TransportParams +from pipecat.transports.network.small_webrtc import SmallWebRTCTransport +from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection + +load_dotenv(override=True) + + +async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespace): + logger.info(f"Starting bot") + + transport = SmallWebRTCTransport( + webrtc_connection=webrtc_connection, + params=TransportParams( + audio_in_enabled=True, + audio_out_enabled=True, + vad_analyzer=SileroVADAnalyzer(), + ), + ) + # Create an HTTP session + async with aiohttp.ClientSession() as session: + stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY")) + + tts = SarvamTTSService( + api_key=os.getenv("SARVAM_API_KEY"), + aiohttp_session=session, + params=SarvamTTSService.InputParams(language=Language.EN), + ) + + llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY")) + + messages = [ + { + "role": "system", + "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.", + }, + ] + + context = OpenAILLMContext(messages) + context_aggregator = llm.create_context_aggregator(context) + + pipeline = Pipeline( + [ + transport.input(), # Transport user input + stt, + context_aggregator.user(), # User responses + llm, # LLM + tts, # TTS + transport.output(), # Transport bot output + context_aggregator.assistant(), # Assistant spoken responses + ] + ) + + task = PipelineTask( + pipeline, + params=PipelineParams( + allow_interruptions=True, + enable_metrics=True, + enable_usage_metrics=True, + report_only_initial_ttfb=True, + ), + ) + + @transport.event_handler("on_client_connected") + async def on_client_connected(transport, client): + logger.info(f"Client connected") + # Kick off the conversation. + messages.append({"role": "system", "content": "Please introduce yourself to the user."}) + await task.queue_frames([context_aggregator.user().get_context_frame()]) + + @transport.event_handler("on_client_disconnected") + async def on_client_disconnected(transport, client): + logger.info(f"Client disconnected") + + @transport.event_handler("on_client_closed") + async def on_client_closed(transport, client): + logger.info(f"Client closed connection") + await task.cancel() + + runner = PipelineRunner(handle_sigint=False) + + await runner.run(task) + + +if __name__ == "__main__": + from run import main + + main() diff --git a/src/pipecat/services/sarvam/__init__.py b/src/pipecat/services/sarvam/__init__.py new file mode 100644 index 000000000..0d444e949 --- /dev/null +++ b/src/pipecat/services/sarvam/__init__.py @@ -0,0 +1,8 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + + +from .tts import * diff --git a/src/pipecat/services/sarvam/tts.py b/src/pipecat/services/sarvam/tts.py new file mode 100644 index 000000000..f9ce4e70f --- /dev/null +++ b/src/pipecat/services/sarvam/tts.py @@ -0,0 +1,195 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import base64 +from typing import AsyncGenerator, Optional + +import aiohttp +from loguru import logger +from pydantic import BaseModel, Field + +from pipecat.frames.frames import ( + ErrorFrame, + Frame, + StartFrame, + TTSAudioRawFrame, + TTSStartedFrame, + TTSStoppedFrame, +) +from pipecat.services.tts_service import TTSService +from pipecat.transcriptions.language import Language +from pipecat.utils.tracing.service_decorators import traced_tts + + +def language_to_sarvam_language(language: Language) -> Optional[str]: + """Convert Pipecat Language enum to Sarvam AI language codes.""" + LANGUAGE_MAP = { + Language.BN: "bn-IN", # Bengali + Language.EN: "en-IN", # English (India) + Language.GU: "gu-IN", # Gujarati + Language.HI: "hi-IN", # Hindi + Language.KN: "kn-IN", # Kannada + Language.ML: "ml-IN", # Malayalam + Language.MR: "mr-IN", # Marathi + Language.OR: "od-IN", # Odia + Language.PA: "pa-IN", # Punjabi + Language.TA: "ta-IN", # Tamil + Language.TE: "te-IN", # Telugu + } + + return LANGUAGE_MAP.get(language) + + +class SarvamTTSService(TTSService): + """Text-to-Speech service using Sarvam AI's API. + + Converts text to speech using Sarvam AI's TTS models with support for multiple + Indian languages. Provides control over voice characteristics like pitch, pace, + and loudness. + + Args: + api_key: Sarvam AI API subscription key. + voice_id: Speaker voice ID (e.g., "anushka", "meera"). + model: TTS model to use ("bulbul:v1" or "bulbul:v2"). + aiohttp_session: Shared aiohttp session for making requests. + base_url: Sarvam AI API base URL. + sample_rate: Audio sample rate in Hz (8000, 16000, 22050, 24000). + params: Additional voice and preprocessing parameters. + + Example: + ```python + tts = SarvamTTSService( + api_key="your-api-key", + voice_id="anushka", + model="bulbul:v2", + aiohttp_session=session, + params=SarvamTTSService.InputParams( + language=Language.HI, + pitch=0.1, + pace=1.2 + ) + ) + ``` + """ + + class InputParams(BaseModel): + language: Optional[Language] = Language.EN + pitch: Optional[float] = Field(default=0.0, ge=-0.75, le=0.75) + pace: Optional[float] = Field(default=1.0, ge=0.3, le=3.0) + loudness: Optional[float] = Field(default=1.0, ge=0.1, le=3.0) + enable_preprocessing: Optional[bool] = False + + def __init__( + self, + *, + api_key: str, + voice_id: str = "anushka", + model: str = "bulbul:v2", + aiohttp_session: aiohttp.ClientSession, + base_url: str = "https://api.sarvam.ai", + sample_rate: Optional[int] = None, + params: Optional[InputParams] = None, + **kwargs, + ): + super().__init__(sample_rate=sample_rate, **kwargs) + + params = params or SarvamTTSService.InputParams() + + self._api_key = api_key + self._base_url = base_url + self._session = aiohttp_session + + self._settings = { + "language": self.language_to_service_language(params.language) + if params.language + else "en-IN", + "pitch": params.pitch, + "pace": params.pace, + "loudness": params.loudness, + "enable_preprocessing": params.enable_preprocessing, + } + + self.set_model_name(model) + self.set_voice(voice_id) + + def can_generate_metrics(self) -> bool: + return True + + def language_to_service_language(self, language: Language) -> Optional[str]: + return language_to_sarvam_language(language) + + async def start(self, frame: StartFrame): + await super().start(frame) + self._settings["sample_rate"] = self.sample_rate + + @traced_tts + async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: + logger.debug(f"{self}: Generating TTS [{text}]") + + try: + await self.start_ttfb_metrics() + + payload = { + "text": text, + "target_language_code": self._settings["language"], + "speaker": self._voice_id, + "pitch": self._settings["pitch"], + "pace": self._settings["pace"], + "loudness": self._settings["loudness"], + "speech_sample_rate": self.sample_rate, + "enable_preprocessing": self._settings["enable_preprocessing"], + "model": self._model_name, + } + + headers = { + "api-subscription-key": self._api_key, + "Content-Type": "application/json", + } + + url = f"{self._base_url}/text-to-speech" + + yield TTSStartedFrame() + + async with self._session.post(url, json=payload, headers=headers) as response: + if response.status != 200: + error_text = await response.text() + logger.error(f"Sarvam API error: {error_text}") + await self.push_error(ErrorFrame(f"Sarvam API error: {error_text}")) + return + + response_data = await response.json() + + await self.start_tts_usage_metrics(text) + + # Decode base64 audio data + if "audios" not in response_data or not response_data["audios"]: + logger.error("No audio data received from Sarvam API") + await self.push_error(ErrorFrame("No audio data received")) + return + + # Get the first audio (there should be only one for single text input) + base64_audio = response_data["audios"][0] + audio_data = base64.b64decode(base64_audio) + + # Strip WAV header (first 44 bytes) if present + if audio_data.startswith(b"RIFF"): + logger.debug("Stripping WAV header from Sarvam audio data") + audio_data = audio_data[44:] + + frame = TTSAudioRawFrame( + audio=audio_data, + sample_rate=self.sample_rate, + num_channels=1, + ) + + yield frame + + except Exception as e: + logger.error(f"{self} exception: {e}") + await self.push_error(ErrorFrame(f"Error generating TTS: {e}")) + finally: + await self.stop_ttfb_metrics() + yield TTSStoppedFrame()