Merge pull request #1872 from pipecat-ai/mb/add-sarvam-tts
Add SarvamTTSService
This commit is contained in:
@@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Added
|
||||
|
||||
- Added `SarvamTTSService`, which implements Sarvam AI's TTS API:
|
||||
https://docs.sarvam.ai/api-reference-docs/text-to-speech/convert.
|
||||
|
||||
- Added `PipelineTask.add_observer()` and `PipelineTask.remove_observer()` to
|
||||
allow mangaging observers at runtime. This is useful for cases where the task
|
||||
is passed around to other code components that might want to observe the
|
||||
@@ -126,8 +129,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Other
|
||||
|
||||
- Added foundation example `07y-minimax-http.py` to show how to use the
|
||||
`MiniMaxHttpTTSService`.
|
||||
- Added foundation examples `07y-interruptible-minimax.py` and
|
||||
`07z-interruptible-sarvam.py`to show how to use the `MiniMaxHttpTTSService`
|
||||
and `SarvamTTSService`, respectively.
|
||||
|
||||
- Added an `open-telemetry-tracing` example, showing how to setup tracing. The
|
||||
example also includes Jaeger as an open source OpenTelemetry client to review
|
||||
|
||||
@@ -105,3 +105,6 @@ TWILIO_AUTH_TOKEN=...
|
||||
# MiniMax
|
||||
MINIMAX_API_KEY=...
|
||||
MINIMAX_GROUP_ID=...
|
||||
|
||||
# Sarvam AI
|
||||
SARVAM_API_KEY=...
|
||||
109
examples/foundational/07z-interruptible-sarvam.py
Normal file
109
examples/foundational/07z-interruptible-sarvam.py
Normal file
@@ -0,0 +1,109 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.services.sarvam.tts import SarvamTTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
from pipecat.transports.network.small_webrtc import SmallWebRTCTransport
|
||||
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespace):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
transport = SmallWebRTCTransport(
|
||||
webrtc_connection=webrtc_connection,
|
||||
params=TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
)
|
||||
# Create an HTTP session
|
||||
async with aiohttp.ClientSession() as session:
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
tts = SarvamTTSService(
|
||||
api_key=os.getenv("SARVAM_API_KEY"),
|
||||
aiohttp_session=session,
|
||||
params=SarvamTTSService.InputParams(language=Language.EN),
|
||||
)
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
|
||||
},
|
||||
]
|
||||
|
||||
context = OpenAILLMContext(messages)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt,
|
||||
context_aggregator.user(), # User responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
allow_interruptions=True,
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
report_only_initial_ttfb=True,
|
||||
),
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
# Kick off the conversation.
|
||||
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
|
||||
await task.queue_frames([context_aggregator.user().get_context_frame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
|
||||
@transport.event_handler("on_client_closed")
|
||||
async def on_client_closed(transport, client):
|
||||
logger.info(f"Client closed connection")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=False)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from run import main
|
||||
|
||||
main()
|
||||
8
src/pipecat/services/sarvam/__init__.py
Normal file
8
src/pipecat/services/sarvam/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
|
||||
from .tts import *
|
||||
195
src/pipecat/services/sarvam/tts.py
Normal file
195
src/pipecat/services/sarvam/tts.py
Normal file
@@ -0,0 +1,195 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import base64
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.tts_service import TTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
|
||||
def language_to_sarvam_language(language: Language) -> Optional[str]:
|
||||
"""Convert Pipecat Language enum to Sarvam AI language codes."""
|
||||
LANGUAGE_MAP = {
|
||||
Language.BN: "bn-IN", # Bengali
|
||||
Language.EN: "en-IN", # English (India)
|
||||
Language.GU: "gu-IN", # Gujarati
|
||||
Language.HI: "hi-IN", # Hindi
|
||||
Language.KN: "kn-IN", # Kannada
|
||||
Language.ML: "ml-IN", # Malayalam
|
||||
Language.MR: "mr-IN", # Marathi
|
||||
Language.OR: "od-IN", # Odia
|
||||
Language.PA: "pa-IN", # Punjabi
|
||||
Language.TA: "ta-IN", # Tamil
|
||||
Language.TE: "te-IN", # Telugu
|
||||
}
|
||||
|
||||
return LANGUAGE_MAP.get(language)
|
||||
|
||||
|
||||
class SarvamTTSService(TTSService):
|
||||
"""Text-to-Speech service using Sarvam AI's API.
|
||||
|
||||
Converts text to speech using Sarvam AI's TTS models with support for multiple
|
||||
Indian languages. Provides control over voice characteristics like pitch, pace,
|
||||
and loudness.
|
||||
|
||||
Args:
|
||||
api_key: Sarvam AI API subscription key.
|
||||
voice_id: Speaker voice ID (e.g., "anushka", "meera").
|
||||
model: TTS model to use ("bulbul:v1" or "bulbul:v2").
|
||||
aiohttp_session: Shared aiohttp session for making requests.
|
||||
base_url: Sarvam AI API base URL.
|
||||
sample_rate: Audio sample rate in Hz (8000, 16000, 22050, 24000).
|
||||
params: Additional voice and preprocessing parameters.
|
||||
|
||||
Example:
|
||||
```python
|
||||
tts = SarvamTTSService(
|
||||
api_key="your-api-key",
|
||||
voice_id="anushka",
|
||||
model="bulbul:v2",
|
||||
aiohttp_session=session,
|
||||
params=SarvamTTSService.InputParams(
|
||||
language=Language.HI,
|
||||
pitch=0.1,
|
||||
pace=1.2
|
||||
)
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
class InputParams(BaseModel):
|
||||
language: Optional[Language] = Language.EN
|
||||
pitch: Optional[float] = Field(default=0.0, ge=-0.75, le=0.75)
|
||||
pace: Optional[float] = Field(default=1.0, ge=0.3, le=3.0)
|
||||
loudness: Optional[float] = Field(default=1.0, ge=0.1, le=3.0)
|
||||
enable_preprocessing: Optional[bool] = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
voice_id: str = "anushka",
|
||||
model: str = "bulbul:v2",
|
||||
aiohttp_session: aiohttp.ClientSession,
|
||||
base_url: str = "https://api.sarvam.ai",
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
params = params or SarvamTTSService.InputParams()
|
||||
|
||||
self._api_key = api_key
|
||||
self._base_url = base_url
|
||||
self._session = aiohttp_session
|
||||
|
||||
self._settings = {
|
||||
"language": self.language_to_service_language(params.language)
|
||||
if params.language
|
||||
else "en-IN",
|
||||
"pitch": params.pitch,
|
||||
"pace": params.pace,
|
||||
"loudness": params.loudness,
|
||||
"enable_preprocessing": params.enable_preprocessing,
|
||||
}
|
||||
|
||||
self.set_model_name(model)
|
||||
self.set_voice(voice_id)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
return language_to_sarvam_language(language)
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
self._settings["sample_rate"] = self.sample_rate
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
payload = {
|
||||
"text": text,
|
||||
"target_language_code": self._settings["language"],
|
||||
"speaker": self._voice_id,
|
||||
"pitch": self._settings["pitch"],
|
||||
"pace": self._settings["pace"],
|
||||
"loudness": self._settings["loudness"],
|
||||
"speech_sample_rate": self.sample_rate,
|
||||
"enable_preprocessing": self._settings["enable_preprocessing"],
|
||||
"model": self._model_name,
|
||||
}
|
||||
|
||||
headers = {
|
||||
"api-subscription-key": self._api_key,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
url = f"{self._base_url}/text-to-speech"
|
||||
|
||||
yield TTSStartedFrame()
|
||||
|
||||
async with self._session.post(url, json=payload, headers=headers) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
logger.error(f"Sarvam API error: {error_text}")
|
||||
await self.push_error(ErrorFrame(f"Sarvam API error: {error_text}"))
|
||||
return
|
||||
|
||||
response_data = await response.json()
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
# Decode base64 audio data
|
||||
if "audios" not in response_data or not response_data["audios"]:
|
||||
logger.error("No audio data received from Sarvam API")
|
||||
await self.push_error(ErrorFrame("No audio data received"))
|
||||
return
|
||||
|
||||
# Get the first audio (there should be only one for single text input)
|
||||
base64_audio = response_data["audios"][0]
|
||||
audio_data = base64.b64decode(base64_audio)
|
||||
|
||||
# Strip WAV header (first 44 bytes) if present
|
||||
if audio_data.startswith(b"RIFF"):
|
||||
logger.debug("Stripping WAV header from Sarvam audio data")
|
||||
audio_data = audio_data[44:]
|
||||
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=audio_data,
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
)
|
||||
|
||||
yield frame
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(f"Error generating TTS: {e}"))
|
||||
finally:
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame()
|
||||
Reference in New Issue
Block a user