Merge pull request #2821 from shreyas-sarvam/sarvam/stt

Sarvam STT/STTT WS implementation
This commit is contained in:
Mark Backman
2025-10-31 13:47:27 -04:00
committed by GitHub
6 changed files with 500 additions and 5 deletions

View File

@@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Added supprt for Sarvam Speech-to-Text service (`SarvamSTTService`) with streaming WebSocket
support for `saarika` (STT) and `saaras` (STT-translate) models.
- Added a new `DeepgramHttpTTSService`, which delivers a meaningful reduction
in latency when compared to the `DeepgramTTSService`.

View File

@@ -22,8 +22,8 @@ from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.services.sarvam.stt import SarvamSTTService
from pipecat.services.sarvam.tts import SarvamHttpTTSService
from pipecat.transcriptions.language import Language
from pipecat.transports.base_transport import BaseTransport, TransportParams
@@ -63,7 +63,10 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
# Create an HTTP session
async with aiohttp.ClientSession() as session:
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
stt = SarvamSTTService(
api_key=os.getenv("SARVAM_API_KEY"),
model="saarika:v2.5",
)
tts = SarvamHttpTTSService(
api_key=os.getenv("SARVAM_API_KEY"),

View File

@@ -24,8 +24,8 @@ from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.services.sarvam.stt import SarvamSTTService
from pipecat.services.sarvam.tts import SarvamTTSService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
@@ -62,7 +62,10 @@ transport_params = {
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
logger.info(f"Starting bot")
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
stt = SarvamSTTService(
api_key=os.getenv("SARVAM_API_KEY"),
model="saarika:v2.5",
)
tts = SarvamTTSService(
api_key=os.getenv("SARVAM_API_KEY"),

View File

@@ -93,7 +93,7 @@ rime = [ "pipecat-ai[websockets-base]" ]
riva = [ "nvidia-riva-client~=2.21.1" ]
runner = [ "python-dotenv>=1.0.0,<2.0.0", "uvicorn>=0.32.0,<1.0.0", "fastapi>=0.115.6,<0.117.0", "pipecat-ai-small-webrtc-prebuilt>=1.0.0"]
sambanova = []
sarvam = [ "pipecat-ai[websockets-base]" ]
sarvam = [ "sarvamai==0.1.21", "pipecat-ai[websockets-base]" ]
sentry = [ "sentry-sdk>=2.28.0,<3" ]
local-smart-turn = [ "coremltools>=8.0", "transformers", "torch>=2.5.0,<3", "torchaudio>=2.5.0,<3" ]
local-smart-turn-v3 = [ "transformers", "onnxruntime>=1.20.1,<2" ]

View File

@@ -0,0 +1,468 @@
"""Sarvam AI Speech-to-Text service implementation.
This module provides a streaming Speech-to-Text service using Sarvam AI's WebSocket-based
API. It supports real-time transcription with Voice Activity Detection (VAD) and
can handle multiple audio formats for Indian language speech recognition.
"""
import base64
from typing import Optional
from loguru import logger
from pydantic import BaseModel
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
ErrorFrame,
StartFrame,
TranscriptionFrame,
)
from pipecat.services.stt_service import STTService
from pipecat.transcriptions.language import Language
from pipecat.utils.time import time_now_iso8601
from pipecat.utils.tracing.service_decorators import traced_stt
try:
from sarvamai import AsyncSarvamAI
from sarvamai.core.api_error import ApiError
from sarvamai.core.events import EventType
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use Sarvam, you need to `pip install pipecat-ai[sarvam]`.")
raise Exception(f"Missing module: {e}")
def language_to_sarvam_language(language: Language) -> str:
"""Convert a Language enum to Sarvam's language code format.
Args:
language: The Language enum value to convert.
Returns:
The Sarvam language code string.
"""
# Mapping of pipecat Language enum to Sarvam language codes
SARVAM_LANGUAGES = {
Language.BN_IN: "bn-IN",
Language.GU_IN: "gu-IN",
Language.HI_IN: "hi-IN",
Language.KN_IN: "kn-IN",
Language.ML_IN: "ml-IN",
Language.MR_IN: "mr-IN",
Language.TA_IN: "ta-IN",
Language.TE_IN: "te-IN",
Language.PA_IN: "pa-IN",
Language.OR_IN: "od-IN",
Language.EN_IN: "en-IN",
Language.AS_IN: "as-IN",
}
return SARVAM_LANGUAGES.get(
language, "unknown"
) # Default to unknown (Sarvam models auto-detect the language)
class SarvamSTTService(STTService):
"""Sarvam speech-to-text service.
Provides real-time speech recognition using Sarvam's WebSocket API.
"""
class InputParams(BaseModel):
"""Configuration parameters for Sarvam STT service.
Parameters:
language: Target language for transcription. Defaults to None (required for saarika models).
prompt: Optional prompt to guide translation style/context for STT-Translate models.
Only applicable to saaras (STT-Translate) models. Defaults to None.
vad_signals: Enable VAD signals in response. Defaults to True.
high_vad_sensitivity: Enable high VAD (Voice Activity Detection) sensitivity. Defaults to False.
"""
language: Optional[Language] = None
prompt: Optional[str] = None
vad_signals: bool = True
high_vad_sensitivity: bool = False
def __init__(
self,
*,
api_key: str,
model: str = "saarika:v2.5",
sample_rate: Optional[int] = None,
input_audio_codec: str = "wav",
params: Optional[InputParams] = None,
**kwargs,
):
"""Initialize the Sarvam STT service.
Args:
api_key: Sarvam API key for authentication.
model: Sarvam model to use for transcription.
sample_rate: Audio sample rate. Defaults to 16000 if not specified.
input_audio_codec: Audio codec/format of the input file. Defaults to "wav".
params: Configuration parameters for Sarvam STT service.
**kwargs: Additional arguments passed to the parent STTService.
"""
params = params or SarvamSTTService.InputParams()
# Validate that saaras models don't accept language parameter
if "saaras" in model.lower():
if params.language is not None:
raise ValueError(
f"Model '{model}' does not accept language parameter. "
"STT-Translate models auto-detect language."
)
# Validate that saarika models don't accept prompt parameter
if "saarika" in model.lower():
if params.prompt is not None:
raise ValueError(
f"Model '{model}' does not accept prompt parameter. "
"Prompts are only supported for STT-Translate models"
)
super().__init__(sample_rate=sample_rate, **kwargs)
self.set_model_name(model)
self._api_key = api_key
self._language_code = params.language
# For saarika models, default to "unknown" if language is not provided
if params.language:
self._language_string = language_to_sarvam_language(params.language)
elif "saarika" in model.lower():
self._language_string = "unknown"
else:
self._language_string = None
self._prompt = params.prompt
# Store connection parameters
self._vad_signals = params.vad_signals
self._high_vad_sensitivity = params.high_vad_sensitivity
self._input_audio_codec = input_audio_codec
# Initialize Sarvam SDK client
self._sarvam_client = AsyncSarvamAI(api_subscription_key=api_key)
self._websocket_context = None
self._socket_client = None
self._receive_task = None
def language_to_service_language(self, language: Language) -> str:
"""Convert pipecat Language enum to Sarvam's language code.
Args:
language: The Language enum value to convert.
Returns:
The Sarvam language code string.
"""
return language_to_sarvam_language(language)
def can_generate_metrics(self) -> bool:
"""Check if this service can generate processing metrics.
Returns:
True, as Sarvam service supports metrics generation.
"""
return True
async def set_language(self, language: Language):
"""Set the recognition language and reconnect.
Args:
language: The language to use for speech recognition.
"""
# saaras models do not accept a language parameter
if "saaras" in self.model_name.lower():
raise ValueError(
f"Model '{self.model_name}' (saaras) does not accept language parameter. "
"saaras models auto-detect language."
)
logger.info(f"Switching STT language to: [{language}]")
self._language_code = language
self._language_string = language_to_sarvam_language(language)
await self._disconnect()
await self._connect()
async def set_prompt(self, prompt: Optional[str]):
"""Set the translation prompt and reconnect.
Args:
prompt: Prompt text to guide translation style/context.
Pass None to clear/disable prompt.
Only applicable to STT-Translate models, not STT models.
"""
# saarika models do not accept prompt parameter
if "saarika" in self.model_name.lower():
if prompt is not None:
raise ValueError(
f"Model '{self.model_name}' does not accept prompt parameter. "
"Prompts are only supported for STT-Translate models."
)
# If prompt is None and it's saarika, just silently return (no-op)
return
logger.info("Updating STT-Translate prompt.")
self._prompt = prompt
await self._disconnect()
await self._connect()
async def start(self, frame: StartFrame):
"""Start the Sarvam STT service.
Args:
frame: The start frame containing initialization parameters.
"""
await super().start(frame)
await self._connect()
async def stop(self, frame: EndFrame):
"""Stop the Sarvam STT service.
Args:
frame: The end frame.
"""
await super().stop(frame)
await self._disconnect()
async def cancel(self, frame: CancelFrame):
"""Cancel the Sarvam STT service.
Args:
frame: The cancel frame.
"""
await super().cancel(frame)
await self._disconnect()
async def run_stt(self, audio: bytes):
"""Send audio data to Sarvam for transcription.
Args:
audio: Raw audio bytes to transcribe.
Yields:
Frame: None (transcription results come via WebSocket callbacks).
"""
if not self._socket_client:
logger.warning("WebSocket not connected, cannot process audio")
yield None
return
try:
# Convert audio bytes to base64 for Sarvam API
audio_base64 = base64.b64encode(audio).decode("utf-8")
# Convert input_audio_codec to encoding format (prepend "audio/" if needed)
encoding = (
self._input_audio_codec
if self._input_audio_codec.startswith("audio/")
else f"audio/{self._input_audio_codec}"
)
# Build method arguments
method_kwargs = {
"audio": audio_base64,
"encoding": encoding,
"sample_rate": self.sample_rate,
}
# Use appropriate method based on service type
if "saarika" in self.model_name.lower():
# STT service
await self._socket_client.transcribe(**method_kwargs)
else:
# STT-Translate service - auto-detects input language and returns translated text
await self._socket_client.translate(**method_kwargs)
except Exception as e:
logger.error(f"Error sending audio to Sarvam: {e}")
await self.push_error(ErrorFrame(f"Failed to send audio: {e}"))
yield None
async def _connect(self):
"""Connect to Sarvam WebSocket API using the SDK."""
logger.debug("Connecting to Sarvam")
try:
# Convert boolean parameters to string for SDK
vad_signals_str = "true" if self._vad_signals else "false"
high_vad_sensitivity_str = "true" if self._high_vad_sensitivity else "false"
# Build common connection parameters
connect_kwargs = {
"model": self.model_name,
"vad_signals": vad_signals_str,
"high_vad_sensitivity": high_vad_sensitivity_str,
"input_audio_codec": self._input_audio_codec,
"sample_rate": str(self.sample_rate),
}
# Choose the appropriate service based on model
if "saarika" in self.model_name.lower():
# STT service - requires language_code
connect_kwargs["language_code"] = self._language_string
self._websocket_context = self._sarvam_client.speech_to_text_streaming.connect(
**connect_kwargs
)
else:
# STT-Translate service - auto-detects input language and returns translated text
self._websocket_context = (
self._sarvam_client.speech_to_text_translate_streaming.connect(**connect_kwargs)
)
# Enter the async context manager
self._socket_client = await self._websocket_context.__aenter__()
# Set prompt if provided (only for STT-Translate models, after connection)
if self._prompt is not None and "saaras" in self.model_name.lower():
await self._socket_client.set_prompt(self._prompt)
# Register event handler for incoming messages
def _message_handler(message):
"""Wrapper to handle async response handler."""
# Use Pipecat's built-in task management
self.create_task(self._handle_message(message))
self._socket_client.on(EventType.MESSAGE, _message_handler)
# Start receive task using Pipecat's task management
self._receive_task = self.create_task(self._receive_task_handler())
logger.info("Connected to Sarvam successfully")
except ApiError as e:
logger.error(f"Sarvam API error: {e}")
await self.push_error(ErrorFrame(f"Sarvam API error: {e}"))
except Exception as e:
logger.error(f"Failed to connect to Sarvam: {e}")
self._socket_client = None
self._websocket_context = None
await self.push_error(ErrorFrame(f"Failed to connect to Sarvam: {e}"))
async def _disconnect(self):
"""Disconnect from Sarvam WebSocket API using SDK."""
if self._receive_task:
await self.cancel_task(self._receive_task)
self._receive_task = None
if self._websocket_context and self._socket_client:
try:
# Exit the async context manager
await self._websocket_context.__aexit__(None, None, None)
except Exception as e:
logger.error(f"Error closing WebSocket connection: {e}")
finally:
logger.debug("Disconnected from Sarvam WebSocket")
self._socket_client = None
self._websocket_context = None
async def _receive_task_handler(self):
"""Handle incoming messages from Sarvam WebSocket.
This task wraps the SDK's start_listening() method which processes
messages via the registered event handler callback.
"""
if not self._socket_client:
return
try:
# Start listening for messages from the Sarvam SDK
# Messages will be handled via the _message_handler callback
await self._socket_client.start_listening()
except Exception as e:
logger.error(f"Error in Sarvam receive task: {e}")
await self.push_error(ErrorFrame(f"Sarvam receive task error: {e}"))
async def _handle_message(self, message):
"""Handle incoming WebSocket message from Sarvam SDK.
Processes transcription data and VAD events from the Sarvam service.
Args:
message: The parsed response object from Sarvam WebSocket.
"""
logger.debug(f"Received response: {message}")
try:
if message.type == "events":
# VAD event
signal = message.data.signal_type
timestamp = message.data.occured_at
logger.debug(f"VAD Signal: {signal}, Occurred at: {timestamp}")
if signal == "START_SPEECH":
await self.start_metrics()
logger.debug("User started speaking")
await self._call_event_handler("on_speech_started")
elif message.type == "data":
await self.stop_ttfb_metrics()
transcript = message.data.transcript
language_code = message.data.language_code
# Prefer language from message (auto-detected for translate models). Fallback to configured.
if language_code:
language = self._map_language_code_to_enum(language_code)
elif self._language_string:
language = self._map_language_code_to_enum(self._language_string)
else:
language = Language.HI_IN
# Emit utterance end event
await self._call_event_handler("on_utterance_end")
if transcript and transcript.strip():
# Record tracing for this transcription event
await self._handle_transcription(transcript, True, language)
await self.push_frame(
TranscriptionFrame(
transcript,
self._user_id,
time_now_iso8601(),
language,
result=(message.dict() if hasattr(message, "dict") else str(message)),
)
)
await self.stop_processing_metrics()
except Exception as e:
logger.error(f"Error handling Sarvam message: {e}")
await self.push_error(ErrorFrame(f"Failed to handle message: {e}"))
await self.stop_all_metrics()
@traced_stt
async def _handle_transcription(
self, transcript: str, is_final: bool, language: Optional[Language] = None
):
"""Handle a transcription result with tracing.
This method is decorated with @traced_stt for observability.
"""
pass
def _map_language_code_to_enum(self, language_code: str) -> Language:
"""Map Sarvam language code to pipecat Language enum."""
mapping = {
"bn-IN": Language.BN_IN,
"gu-IN": Language.GU_IN,
"hi-IN": Language.HI_IN,
"kn-IN": Language.KN_IN,
"ml-IN": Language.ML_IN,
"mr-IN": Language.MR_IN,
"ta-IN": Language.TA_IN,
"te-IN": Language.TE_IN,
"pa-IN": Language.PA_IN,
"od-IN": Language.OR_IN,
"en-US": Language.EN_US,
"en-IN": Language.EN_IN,
"as-IN": Language.AS_IN,
}
return mapping.get(language_code, Language.HI_IN)
async def start_metrics(self):
"""Start TTFB and processing metrics collection."""
await self.start_ttfb_metrics()
await self.start_processing_metrics()

18
uv.lock generated
View File

@@ -4550,6 +4550,7 @@ runner = [
{ name = "uvicorn" },
]
sarvam = [
{ name = "sarvamai" },
{ name = "websockets" },
]
sentry = [
@@ -4704,6 +4705,7 @@ requires-dist = [
{ name = "python-dotenv", marker = "extra == 'runner'", specifier = ">=1.0.0,<2.0.0" },
{ name = "pyvips", extras = ["binary"], marker = "extra == 'moondream'", specifier = "~=3.0.0" },
{ name = "resampy", specifier = "~=0.4.3" },
{ name = "sarvamai", marker = "extra == 'sarvam'", specifier = "==0.1.21" },
{ name = "sentry-sdk", marker = "extra == 'sentry'", specifier = ">=2.28.0,<3" },
{ name = "simli-ai", marker = "extra == 'simli'", specifier = "~=0.1.10" },
{ name = "soundfile", marker = "extra == 'soundfile'", specifier = "~=0.13.0" },
@@ -6212,6 +6214,22 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/2c/c3/c0be1135726618dc1e28d181b8c442403d8dbb9e273fd791de2d4384bcdd/safetensors-0.6.2-cp38-abi3-win_amd64.whl", hash = "sha256:c7b214870df923cbc1593c3faee16bec59ea462758699bd3fee399d00aac072c", size = 320192, upload-time = "2025-08-08T13:13:59.467Z" },
]
[[package]]
name = "sarvamai"
version = "0.1.21"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "httpx" },
{ name = "pydantic" },
{ name = "pydantic-core" },
{ name = "typing-extensions" },
{ name = "websockets" },
]
sdist = { url = "https://files.pythonhosted.org/packages/e9/08/e5efcb30818ed220b818319255c22fd91e379489ebaa93efd6f444fb4987/sarvamai-0.1.21.tar.gz", hash = "sha256:865065635b2b99d40f5519308832954015627938e06a6333b5f62ae9c36278bb", size = 87386, upload-time = "2025-10-07T07:37:47.085Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/2e/4e/b9933f72681b7aed91b86913337dd3981fad97027881fbc66c3c5eb03568/sarvamai-0.1.21-py3-none-any.whl", hash = "sha256:daa4e5d16635fe434f5f270cee416849249285369141d77132a17f0bf670f120", size = 175204, upload-time = "2025-10-07T07:37:46.024Z" },
]
[[package]]
name = "scipy"
version = "1.15.3"