Merge pull request #3040 from pipecat-ai/mb/11labs-realtime-stt

Add ElevenLabsRealtimeSTTService
This commit is contained in:
Mark Backman
2025-11-13 09:53:34 -05:00
committed by GitHub
3 changed files with 483 additions and 4 deletions

View File

@@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Added `ElevenLabsRealtimeSTTService` which implements the Realtime STT
service from ElevenLabs.
- Added a `TTSService.includes_inter_frame_spaces` property getter, so that TTS
services that subclass `TTSService` can indicate whether the text in the
`TTSTextFrame`s they push already contain any necessary inter-frame spaces.

View File

@@ -22,7 +22,7 @@ from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.elevenlabs.stt import ElevenLabsRealtimeSTTService
from pipecat.services.elevenlabs.tts import ElevenLabsTTSService
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.transports.base_transport import BaseTransport, TransportParams
@@ -60,7 +60,7 @@ transport_params = {
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
logger.info(f"Starting bot")
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
stt = ElevenLabsRealtimeSTTService(api_key=os.getenv("ELEVENLABS_API_KEY"))
tts = ElevenLabsTTSService(
api_key=os.getenv("ELEVENLABS_API_KEY", ""),

View File

@@ -11,19 +11,43 @@ using segmented audio processing. The service uploads audio files and receives
transcription results directly.
"""
import base64
import io
import json
from enum import Enum
from typing import AsyncGenerator, Optional
import aiohttp
from loguru import logger
from pydantic import BaseModel
from pipecat.frames.frames import ErrorFrame, Frame, TranscriptionFrame
from pipecat.services.stt_service import SegmentedSTTService
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
InterimTranscriptionFrame,
StartFrame,
TranscriptionFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.stt_service import SegmentedSTTService, WebsocketSTTService
from pipecat.transcriptions.language import Language, resolve_language
from pipecat.utils.time import time_now_iso8601
from pipecat.utils.tracing.service_decorators import traced_stt
try:
from websockets.asyncio.client import connect as websocket_connect
from websockets.protocol import State
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error(
"In order to use ElevenLabs Realtime STT, you need to `pip install pipecat-ai[elevenlabs]`."
)
raise Exception(f"Missing module: {e}")
def language_to_elevenlabs_language(language: Language) -> Optional[str]:
"""Convert a Language enum to ElevenLabs language code.
@@ -329,3 +353,455 @@ class ElevenLabsSTTService(SegmentedSTTService):
except Exception as e:
logger.error(f"ElevenLabs STT error: {e}")
yield ErrorFrame(f"ElevenLabs STT error: {str(e)}")
def audio_format_from_sample_rate(sample_rate: int) -> str:
"""Get the appropriate audio format string for a given sample rate.
Args:
sample_rate: The audio sample rate in Hz.
Returns:
The ElevenLabs audio format string.
"""
match sample_rate:
case 8000:
return "pcm_8000"
case 16000:
return "pcm_16000"
case 22050:
return "pcm_22050"
case 24000:
return "pcm_24000"
case 44100:
return "pcm_44100"
case 48000:
return "pcm_48000"
logger.warning(
f"ElevenLabsRealtimeSTTService: No audio format available for {sample_rate} sample rate, using pcm_16000"
)
return "pcm_16000"
class CommitStrategy(str, Enum):
"""Commit strategies for transcript segmentation."""
MANUAL = "manual"
VAD = "vad"
class ElevenLabsRealtimeSTTService(WebsocketSTTService):
"""Speech-to-text service using ElevenLabs' Realtime WebSocket API.
This service uses ElevenLabs' Realtime Speech-to-Text API to perform transcription
with ultra-low latency. It supports both partial (interim) and committed (final)
transcripts, and can use either manual commit control or automatic Voice Activity
Detection (VAD) for segment boundaries.
By default, uses manual commit strategy where Pipecat's VAD controls when to
commit transcript segments, providing consistency with other STT services.
"""
class InputParams(BaseModel):
"""Configuration parameters for ElevenLabs Realtime STT API.
Parameters:
language_code: ISO-639-1 or ISO-639-3 language code. Leave None for auto-detection.
commit_strategy: How to segment speech - manual (Pipecat VAD) or vad (ElevenLabs VAD).
vad_silence_threshold_secs: Seconds of silence before VAD commits (0.3-3.0).
Only used when commit_strategy is VAD. None uses ElevenLabs default.
vad_threshold: VAD sensitivity (0.1-0.9, lower is more sensitive).
Only used when commit_strategy is VAD. None uses ElevenLabs default.
min_speech_duration_ms: Minimum speech duration for VAD (50-2000ms).
Only used when commit_strategy is VAD. None uses ElevenLabs default.
min_silence_duration_ms: Minimum silence duration for VAD (50-2000ms).
Only used when commit_strategy is VAD. None uses ElevenLabs default.
"""
language_code: Optional[str] = None
commit_strategy: CommitStrategy = CommitStrategy.MANUAL
vad_silence_threshold_secs: Optional[float] = None
vad_threshold: Optional[float] = None
min_speech_duration_ms: Optional[int] = None
min_silence_duration_ms: Optional[int] = None
def __init__(
self,
*,
api_key: str,
base_url: str = "api.elevenlabs.io",
model: str = "scribe_v2_realtime",
sample_rate: Optional[int] = None,
params: Optional[InputParams] = None,
**kwargs,
):
"""Initialize the ElevenLabs Realtime STT service.
Args:
api_key: ElevenLabs API key for authentication.
base_url: Base URL for ElevenLabs WebSocket API.
model: Model ID for transcription. Defaults to "scribe_v2_realtime".
sample_rate: Audio sample rate in Hz. If not provided, uses the pipeline's rate.
params: Configuration parameters for the STT service.
**kwargs: Additional arguments passed to WebsocketSTTService.
"""
super().__init__(
sample_rate=sample_rate,
**kwargs,
)
params = params or ElevenLabsRealtimeSTTService.InputParams()
self._api_key = api_key
self._base_url = base_url
self._model_id = model
self._params = params
self._audio_format = "" # initialized in start()
self._receive_task = None
def can_generate_metrics(self) -> bool:
"""Check if the service can generate processing metrics.
Returns:
True, as ElevenLabs Realtime STT service supports metrics generation.
"""
return True
async def set_language(self, language: Language):
"""Set the transcription language.
Args:
language: The language to use for speech-to-text transcription.
Note:
Changing language requires reconnecting to the WebSocket.
"""
logger.info(f"Switching STT language to: [{language}]")
self._params.language_code = language.value if isinstance(language, Language) else language
# Reconnect with new settings
await self._disconnect()
await self._connect()
async def set_model(self, model: str):
"""Set the STT model.
Args:
model: The model name to use for transcription.
Note:
Changing model requires reconnecting to the WebSocket.
"""
await super().set_model(model)
logger.info(f"Switching STT model to: [{model}]")
self._model_id = model
# Reconnect with new settings
await self._disconnect()
await self._connect()
async def start(self, frame: StartFrame):
"""Start the STT service and establish WebSocket connection.
Args:
frame: Frame indicating service should start.
"""
await super().start(frame)
self._audio_format = audio_format_from_sample_rate(self.sample_rate)
await self._connect()
async def stop(self, frame: EndFrame):
"""Stop the STT service and close WebSocket connection.
Args:
frame: Frame indicating service should stop.
"""
await super().stop(frame)
await self._disconnect()
async def cancel(self, frame: CancelFrame):
"""Cancel the STT service and close WebSocket connection.
Args:
frame: Frame indicating service should be cancelled.
"""
await super().cancel(frame)
await self._disconnect()
async def start_metrics(self):
"""Start performance metrics collection for transcription processing."""
await self.start_ttfb_metrics()
await self.start_processing_metrics()
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process incoming frames and handle speech events.
Args:
frame: The frame to process.
direction: Direction of frame flow in the pipeline.
"""
await super().process_frame(frame, direction)
if isinstance(frame, UserStartedSpeakingFrame):
# Start metrics when user starts speaking
await self.start_metrics()
elif isinstance(frame, UserStoppedSpeakingFrame):
# Send commit when user stops speaking (manual commit mode)
if self._params.commit_strategy == CommitStrategy.MANUAL:
if self._websocket and self._websocket.state is State.OPEN:
try:
commit_message = {
"message_type": "input_audio_chunk",
"audio_base_64": "",
"commit": True,
"sample_rate": self.sample_rate,
}
await self._websocket.send(json.dumps(commit_message))
logger.trace("Sent manual commit to ElevenLabs")
except Exception as e:
logger.warning(f"Failed to send commit: {e}")
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
"""Process audio data for speech-to-text transcription.
Args:
audio: Raw audio bytes to transcribe.
Yields:
None - transcription results are handled via WebSocket responses.
"""
# Reconnect if connection is closed
if not self._websocket or self._websocket.state is State.CLOSED:
await self._connect()
if self._websocket and self._websocket.state is State.OPEN:
try:
# Encode audio as base64
audio_base64 = base64.b64encode(audio).decode("utf-8")
# Send audio chunk
message = {
"message_type": "input_audio_chunk",
"audio_base_64": audio_base64,
"commit": False,
"sample_rate": self.sample_rate,
}
await self._websocket.send(json.dumps(message))
except Exception as e:
logger.error(f"Error sending audio: {e}")
yield ErrorFrame(f"ElevenLabs Realtime STT error: {str(e)}")
yield None
async def _connect(self):
"""Establish WebSocket connection to ElevenLabs Realtime STT."""
await self._connect_websocket()
if self._websocket and not self._receive_task:
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
async def _disconnect(self):
"""Close WebSocket connection and cleanup tasks."""
if self._receive_task:
await self.cancel_task(self._receive_task)
self._receive_task = None
await self._disconnect_websocket()
async def _connect_websocket(self):
"""Connect to ElevenLabs Realtime STT WebSocket endpoint."""
try:
if self._websocket and self._websocket.state is State.OPEN:
return
logger.debug("Connecting to ElevenLabs Realtime STT")
# Build query parameters
params = [f"model_id={self._model_id}"]
if self._params.language_code:
params.append(f"language_code={self._params.language_code}")
params.append(f"encoding={self._audio_format}")
params.append(f"sample_rate={self.sample_rate}")
params.append(f"commit_strategy={self._params.commit_strategy.value}")
# Add VAD parameters if using VAD commit strategy and values are specified
if self._params.commit_strategy == CommitStrategy.VAD:
if self._params.vad_silence_threshold_secs is not None:
params.append(
f"vad_silence_threshold_secs={self._params.vad_silence_threshold_secs}"
)
if self._params.vad_threshold is not None:
params.append(f"vad_threshold={self._params.vad_threshold}")
if self._params.min_speech_duration_ms is not None:
params.append(f"min_speech_duration_ms={self._params.min_speech_duration_ms}")
if self._params.min_silence_duration_ms is not None:
params.append(f"min_silence_duration_ms={self._params.min_silence_duration_ms}")
ws_url = f"wss://{self._base_url}/v1/speech-to-text/realtime?{'&'.join(params)}"
headers = {"xi-api-key": self._api_key}
self._websocket = await websocket_connect(ws_url, additional_headers=headers)
await self._call_event_handler("on_connected")
logger.debug("Connected to ElevenLabs Realtime STT")
except Exception as e:
logger.error(f"{self}: unable to connect to ElevenLabs Realtime STT: {e}")
await self.push_error(ErrorFrame(f"Connection error: {str(e)}"))
async def _disconnect_websocket(self):
"""Disconnect from ElevenLabs Realtime STT WebSocket."""
try:
if self._websocket and self._websocket.state is State.OPEN:
logger.debug("Disconnecting from ElevenLabs Realtime STT")
await self._websocket.close()
except Exception as e:
logger.error(f"{self} error closing websocket: {e}")
finally:
self._websocket = None
await self._call_event_handler("on_disconnected")
def _get_websocket(self):
"""Get the current WebSocket connection.
Returns:
The WebSocket connection.
Raises:
Exception: If WebSocket is not connected.
"""
if self._websocket:
return self._websocket
raise Exception("Websocket not connected")
async def _process_messages(self):
"""Process incoming WebSocket messages."""
async for message in self._get_websocket():
try:
data = json.loads(message)
await self._process_response(data)
except json.JSONDecodeError:
logger.warning(f"Received non-JSON message: {message}")
except Exception as e:
logger.error(f"Error processing message: {e}")
async def _receive_messages(self):
"""Continuously receive and process WebSocket messages."""
try:
await self._process_messages()
except Exception as e:
logger.warning(f"{self} WebSocket connection closed: {e}")
# Connection closed, will reconnect on next audio chunk
async def _process_response(self, data: dict):
"""Process a response message from ElevenLabs.
Args:
data: Parsed JSON response data.
"""
message_type = data.get("message_type")
if message_type == "session_started":
logger.debug(f"ElevenLabs session started: {data}")
elif message_type == "partial_transcript":
await self._on_partial_transcript(data)
elif message_type == "committed_transcript":
await self._on_committed_transcript(data)
elif message_type == "committed_transcript_with_timestamps":
await self._on_committed_transcript_with_timestamps(data)
elif message_type == "input_error":
error_msg = data.get("error", "Unknown input error")
logger.error(f"ElevenLabs input error: {error_msg}")
await self.push_error(ErrorFrame(f"Input error: {error_msg}"))
elif message_type in ["auth_error", "quota_exceeded", "transcriber_error", "error"]:
error_msg = data.get("error", data.get("message", "Unknown error"))
logger.error(f"ElevenLabs error ({message_type}): {error_msg}")
await self.push_error(ErrorFrame(f"{message_type}: {error_msg}"))
else:
logger.debug(f"Unknown message type: {message_type}")
async def _on_partial_transcript(self, data: dict):
"""Handle partial transcript (interim results).
Args:
data: Partial transcript data.
"""
text = data.get("text", "").strip()
if not text:
return
await self.stop_ttfb_metrics()
# Get language if provided
language = data.get("language_code")
logger.trace(f"Partial transcript: [{text}]")
await self.push_frame(
InterimTranscriptionFrame(
text,
self._user_id,
time_now_iso8601(),
language,
result=data,
)
)
@traced_stt
async def _handle_transcription(
self, transcript: str, is_final: bool, language: Optional[str] = None
):
"""Handle a transcription result with tracing."""
pass
async def _on_committed_transcript(self, data: dict):
"""Handle committed transcript (final results).
Args:
data: Committed transcript data.
"""
text = data.get("text", "").strip()
if not text:
return
await self.stop_ttfb_metrics()
await self.stop_processing_metrics()
# Get language if provided
language = data.get("language_code")
logger.debug(f"Committed transcript: [{text}]")
await self._handle_transcription(text, True, language)
await self.push_frame(
TranscriptionFrame(
text,
self._user_id,
time_now_iso8601(),
language,
result=data,
)
)
async def _on_committed_transcript_with_timestamps(self, data: dict):
"""Handle committed transcript with word-level timestamps.
Args:
data: Committed transcript data with timestamps.
"""
text = data.get("text", "").strip()
if not text:
return
logger.debug(f"Committed transcript with timestamps: [{text}]")
logger.trace(f"Timestamps: {data.get('words', [])}")
# This is sent after the committed_transcript, so we don't need to
# push another TranscriptionFrame, but we could use the timestamps
# for additional processing if needed in the future