diff --git a/CHANGELOG.md b/CHANGELOG.md index f8f8bed9b..de9cf3328 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added `ElevenLabsRealtimeSTTService` which implements the Realtime STT + service from ElevenLabs. + - Added a `TTSService.includes_inter_frame_spaces` property getter, so that TTS services that subclass `TTSService` can indicate whether the text in the `TTSTextFrame`s they push already contain any necessary inter-frame spaces. diff --git a/examples/foundational/07d-interruptible-elevenlabs.py b/examples/foundational/07d-interruptible-elevenlabs.py index da2a8eb00..2d14a4d77 100644 --- a/examples/foundational/07d-interruptible-elevenlabs.py +++ b/examples/foundational/07d-interruptible-elevenlabs.py @@ -22,7 +22,7 @@ from pipecat.processors.aggregators.llm_context import LLMContext from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair from pipecat.runner.types import RunnerArguments from pipecat.runner.utils import create_transport -from pipecat.services.deepgram.stt import DeepgramSTTService +from pipecat.services.elevenlabs.stt import ElevenLabsRealtimeSTTService from pipecat.services.elevenlabs.tts import ElevenLabsTTSService from pipecat.services.openai.llm import OpenAILLMService from pipecat.transports.base_transport import BaseTransport, TransportParams @@ -60,7 +60,7 @@ transport_params = { async def run_bot(transport: BaseTransport, runner_args: RunnerArguments): logger.info(f"Starting bot") - stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY")) + stt = ElevenLabsRealtimeSTTService(api_key=os.getenv("ELEVENLABS_API_KEY")) tts = ElevenLabsTTSService( api_key=os.getenv("ELEVENLABS_API_KEY", ""), diff --git a/src/pipecat/services/elevenlabs/stt.py b/src/pipecat/services/elevenlabs/stt.py index bbc86d97e..3b3349ba2 100644 --- a/src/pipecat/services/elevenlabs/stt.py +++ b/src/pipecat/services/elevenlabs/stt.py @@ -11,19 +11,43 @@ using segmented audio processing. The service uploads audio files and receives transcription results directly. """ +import base64 import io +import json +from enum import Enum from typing import AsyncGenerator, Optional import aiohttp from loguru import logger from pydantic import BaseModel -from pipecat.frames.frames import ErrorFrame, Frame, TranscriptionFrame -from pipecat.services.stt_service import SegmentedSTTService +from pipecat.frames.frames import ( + CancelFrame, + EndFrame, + ErrorFrame, + Frame, + InterimTranscriptionFrame, + StartFrame, + TranscriptionFrame, + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, +) +from pipecat.processors.frame_processor import FrameDirection +from pipecat.services.stt_service import SegmentedSTTService, WebsocketSTTService from pipecat.transcriptions.language import Language, resolve_language from pipecat.utils.time import time_now_iso8601 from pipecat.utils.tracing.service_decorators import traced_stt +try: + from websockets.asyncio.client import connect as websocket_connect + from websockets.protocol import State +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error( + "In order to use ElevenLabs Realtime STT, you need to `pip install pipecat-ai[elevenlabs]`." + ) + raise Exception(f"Missing module: {e}") + def language_to_elevenlabs_language(language: Language) -> Optional[str]: """Convert a Language enum to ElevenLabs language code. @@ -329,3 +353,455 @@ class ElevenLabsSTTService(SegmentedSTTService): except Exception as e: logger.error(f"ElevenLabs STT error: {e}") yield ErrorFrame(f"ElevenLabs STT error: {str(e)}") + + +def audio_format_from_sample_rate(sample_rate: int) -> str: + """Get the appropriate audio format string for a given sample rate. + + Args: + sample_rate: The audio sample rate in Hz. + + Returns: + The ElevenLabs audio format string. + """ + match sample_rate: + case 8000: + return "pcm_8000" + case 16000: + return "pcm_16000" + case 22050: + return "pcm_22050" + case 24000: + return "pcm_24000" + case 44100: + return "pcm_44100" + case 48000: + return "pcm_48000" + logger.warning( + f"ElevenLabsRealtimeSTTService: No audio format available for {sample_rate} sample rate, using pcm_16000" + ) + return "pcm_16000" + + +class CommitStrategy(str, Enum): + """Commit strategies for transcript segmentation.""" + + MANUAL = "manual" + VAD = "vad" + + +class ElevenLabsRealtimeSTTService(WebsocketSTTService): + """Speech-to-text service using ElevenLabs' Realtime WebSocket API. + + This service uses ElevenLabs' Realtime Speech-to-Text API to perform transcription + with ultra-low latency. It supports both partial (interim) and committed (final) + transcripts, and can use either manual commit control or automatic Voice Activity + Detection (VAD) for segment boundaries. + + By default, uses manual commit strategy where Pipecat's VAD controls when to + commit transcript segments, providing consistency with other STT services. + """ + + class InputParams(BaseModel): + """Configuration parameters for ElevenLabs Realtime STT API. + + Parameters: + language_code: ISO-639-1 or ISO-639-3 language code. Leave None for auto-detection. + commit_strategy: How to segment speech - manual (Pipecat VAD) or vad (ElevenLabs VAD). + vad_silence_threshold_secs: Seconds of silence before VAD commits (0.3-3.0). + Only used when commit_strategy is VAD. None uses ElevenLabs default. + vad_threshold: VAD sensitivity (0.1-0.9, lower is more sensitive). + Only used when commit_strategy is VAD. None uses ElevenLabs default. + min_speech_duration_ms: Minimum speech duration for VAD (50-2000ms). + Only used when commit_strategy is VAD. None uses ElevenLabs default. + min_silence_duration_ms: Minimum silence duration for VAD (50-2000ms). + Only used when commit_strategy is VAD. None uses ElevenLabs default. + """ + + language_code: Optional[str] = None + commit_strategy: CommitStrategy = CommitStrategy.MANUAL + vad_silence_threshold_secs: Optional[float] = None + vad_threshold: Optional[float] = None + min_speech_duration_ms: Optional[int] = None + min_silence_duration_ms: Optional[int] = None + + def __init__( + self, + *, + api_key: str, + base_url: str = "api.elevenlabs.io", + model: str = "scribe_v2_realtime", + sample_rate: Optional[int] = None, + params: Optional[InputParams] = None, + **kwargs, + ): + """Initialize the ElevenLabs Realtime STT service. + + Args: + api_key: ElevenLabs API key for authentication. + base_url: Base URL for ElevenLabs WebSocket API. + model: Model ID for transcription. Defaults to "scribe_v2_realtime". + sample_rate: Audio sample rate in Hz. If not provided, uses the pipeline's rate. + params: Configuration parameters for the STT service. + **kwargs: Additional arguments passed to WebsocketSTTService. + """ + super().__init__( + sample_rate=sample_rate, + **kwargs, + ) + + params = params or ElevenLabsRealtimeSTTService.InputParams() + + self._api_key = api_key + self._base_url = base_url + self._model_id = model + self._params = params + self._audio_format = "" # initialized in start() + self._receive_task = None + + def can_generate_metrics(self) -> bool: + """Check if the service can generate processing metrics. + + Returns: + True, as ElevenLabs Realtime STT service supports metrics generation. + """ + return True + + async def set_language(self, language: Language): + """Set the transcription language. + + Args: + language: The language to use for speech-to-text transcription. + + Note: + Changing language requires reconnecting to the WebSocket. + """ + logger.info(f"Switching STT language to: [{language}]") + self._params.language_code = language.value if isinstance(language, Language) else language + # Reconnect with new settings + await self._disconnect() + await self._connect() + + async def set_model(self, model: str): + """Set the STT model. + + Args: + model: The model name to use for transcription. + + Note: + Changing model requires reconnecting to the WebSocket. + """ + await super().set_model(model) + logger.info(f"Switching STT model to: [{model}]") + self._model_id = model + # Reconnect with new settings + await self._disconnect() + await self._connect() + + async def start(self, frame: StartFrame): + """Start the STT service and establish WebSocket connection. + + Args: + frame: Frame indicating service should start. + """ + await super().start(frame) + self._audio_format = audio_format_from_sample_rate(self.sample_rate) + await self._connect() + + async def stop(self, frame: EndFrame): + """Stop the STT service and close WebSocket connection. + + Args: + frame: Frame indicating service should stop. + """ + await super().stop(frame) + await self._disconnect() + + async def cancel(self, frame: CancelFrame): + """Cancel the STT service and close WebSocket connection. + + Args: + frame: Frame indicating service should be cancelled. + """ + await super().cancel(frame) + await self._disconnect() + + async def start_metrics(self): + """Start performance metrics collection for transcription processing.""" + await self.start_ttfb_metrics() + await self.start_processing_metrics() + + async def process_frame(self, frame: Frame, direction: FrameDirection): + """Process incoming frames and handle speech events. + + Args: + frame: The frame to process. + direction: Direction of frame flow in the pipeline. + """ + await super().process_frame(frame, direction) + + if isinstance(frame, UserStartedSpeakingFrame): + # Start metrics when user starts speaking + await self.start_metrics() + elif isinstance(frame, UserStoppedSpeakingFrame): + # Send commit when user stops speaking (manual commit mode) + if self._params.commit_strategy == CommitStrategy.MANUAL: + if self._websocket and self._websocket.state is State.OPEN: + try: + commit_message = { + "message_type": "input_audio_chunk", + "audio_base_64": "", + "commit": True, + "sample_rate": self.sample_rate, + } + await self._websocket.send(json.dumps(commit_message)) + logger.trace("Sent manual commit to ElevenLabs") + except Exception as e: + logger.warning(f"Failed to send commit: {e}") + + async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]: + """Process audio data for speech-to-text transcription. + + Args: + audio: Raw audio bytes to transcribe. + + Yields: + None - transcription results are handled via WebSocket responses. + """ + # Reconnect if connection is closed + if not self._websocket or self._websocket.state is State.CLOSED: + await self._connect() + + if self._websocket and self._websocket.state is State.OPEN: + try: + # Encode audio as base64 + audio_base64 = base64.b64encode(audio).decode("utf-8") + + # Send audio chunk + message = { + "message_type": "input_audio_chunk", + "audio_base_64": audio_base64, + "commit": False, + "sample_rate": self.sample_rate, + } + await self._websocket.send(json.dumps(message)) + except Exception as e: + logger.error(f"Error sending audio: {e}") + yield ErrorFrame(f"ElevenLabs Realtime STT error: {str(e)}") + + yield None + + async def _connect(self): + """Establish WebSocket connection to ElevenLabs Realtime STT.""" + await self._connect_websocket() + + if self._websocket and not self._receive_task: + self._receive_task = self.create_task(self._receive_task_handler(self._report_error)) + + async def _disconnect(self): + """Close WebSocket connection and cleanup tasks.""" + if self._receive_task: + await self.cancel_task(self._receive_task) + self._receive_task = None + + await self._disconnect_websocket() + + async def _connect_websocket(self): + """Connect to ElevenLabs Realtime STT WebSocket endpoint.""" + try: + if self._websocket and self._websocket.state is State.OPEN: + return + + logger.debug("Connecting to ElevenLabs Realtime STT") + + # Build query parameters + params = [f"model_id={self._model_id}"] + + if self._params.language_code: + params.append(f"language_code={self._params.language_code}") + + params.append(f"encoding={self._audio_format}") + params.append(f"sample_rate={self.sample_rate}") + params.append(f"commit_strategy={self._params.commit_strategy.value}") + + # Add VAD parameters if using VAD commit strategy and values are specified + if self._params.commit_strategy == CommitStrategy.VAD: + if self._params.vad_silence_threshold_secs is not None: + params.append( + f"vad_silence_threshold_secs={self._params.vad_silence_threshold_secs}" + ) + if self._params.vad_threshold is not None: + params.append(f"vad_threshold={self._params.vad_threshold}") + if self._params.min_speech_duration_ms is not None: + params.append(f"min_speech_duration_ms={self._params.min_speech_duration_ms}") + if self._params.min_silence_duration_ms is not None: + params.append(f"min_silence_duration_ms={self._params.min_silence_duration_ms}") + + ws_url = f"wss://{self._base_url}/v1/speech-to-text/realtime?{'&'.join(params)}" + + headers = {"xi-api-key": self._api_key} + + self._websocket = await websocket_connect(ws_url, additional_headers=headers) + await self._call_event_handler("on_connected") + logger.debug("Connected to ElevenLabs Realtime STT") + except Exception as e: + logger.error(f"{self}: unable to connect to ElevenLabs Realtime STT: {e}") + await self.push_error(ErrorFrame(f"Connection error: {str(e)}")) + + async def _disconnect_websocket(self): + """Disconnect from ElevenLabs Realtime STT WebSocket.""" + try: + if self._websocket and self._websocket.state is State.OPEN: + logger.debug("Disconnecting from ElevenLabs Realtime STT") + await self._websocket.close() + except Exception as e: + logger.error(f"{self} error closing websocket: {e}") + finally: + self._websocket = None + await self._call_event_handler("on_disconnected") + + def _get_websocket(self): + """Get the current WebSocket connection. + + Returns: + The WebSocket connection. + + Raises: + Exception: If WebSocket is not connected. + """ + if self._websocket: + return self._websocket + raise Exception("Websocket not connected") + + async def _process_messages(self): + """Process incoming WebSocket messages.""" + async for message in self._get_websocket(): + try: + data = json.loads(message) + await self._process_response(data) + except json.JSONDecodeError: + logger.warning(f"Received non-JSON message: {message}") + except Exception as e: + logger.error(f"Error processing message: {e}") + + async def _receive_messages(self): + """Continuously receive and process WebSocket messages.""" + try: + await self._process_messages() + except Exception as e: + logger.warning(f"{self} WebSocket connection closed: {e}") + # Connection closed, will reconnect on next audio chunk + + async def _process_response(self, data: dict): + """Process a response message from ElevenLabs. + + Args: + data: Parsed JSON response data. + """ + message_type = data.get("message_type") + + if message_type == "session_started": + logger.debug(f"ElevenLabs session started: {data}") + + elif message_type == "partial_transcript": + await self._on_partial_transcript(data) + + elif message_type == "committed_transcript": + await self._on_committed_transcript(data) + + elif message_type == "committed_transcript_with_timestamps": + await self._on_committed_transcript_with_timestamps(data) + + elif message_type == "input_error": + error_msg = data.get("error", "Unknown input error") + logger.error(f"ElevenLabs input error: {error_msg}") + await self.push_error(ErrorFrame(f"Input error: {error_msg}")) + + elif message_type in ["auth_error", "quota_exceeded", "transcriber_error", "error"]: + error_msg = data.get("error", data.get("message", "Unknown error")) + logger.error(f"ElevenLabs error ({message_type}): {error_msg}") + await self.push_error(ErrorFrame(f"{message_type}: {error_msg}")) + + else: + logger.debug(f"Unknown message type: {message_type}") + + async def _on_partial_transcript(self, data: dict): + """Handle partial transcript (interim results). + + Args: + data: Partial transcript data. + """ + text = data.get("text", "").strip() + if not text: + return + + await self.stop_ttfb_metrics() + + # Get language if provided + language = data.get("language_code") + + logger.trace(f"Partial transcript: [{text}]") + + await self.push_frame( + InterimTranscriptionFrame( + text, + self._user_id, + time_now_iso8601(), + language, + result=data, + ) + ) + + @traced_stt + async def _handle_transcription( + self, transcript: str, is_final: bool, language: Optional[str] = None + ): + """Handle a transcription result with tracing.""" + pass + + async def _on_committed_transcript(self, data: dict): + """Handle committed transcript (final results). + + Args: + data: Committed transcript data. + """ + text = data.get("text", "").strip() + if not text: + return + + await self.stop_ttfb_metrics() + await self.stop_processing_metrics() + + # Get language if provided + language = data.get("language_code") + + logger.debug(f"Committed transcript: [{text}]") + + await self._handle_transcription(text, True, language) + + await self.push_frame( + TranscriptionFrame( + text, + self._user_id, + time_now_iso8601(), + language, + result=data, + ) + ) + + async def _on_committed_transcript_with_timestamps(self, data: dict): + """Handle committed transcript with word-level timestamps. + + Args: + data: Committed transcript data with timestamps. + """ + text = data.get("text", "").strip() + if not text: + return + + logger.debug(f"Committed transcript with timestamps: [{text}]") + logger.trace(f"Timestamps: {data.get('words', [])}") + + # This is sent after the committed_transcript, so we don't need to + # push another TranscriptionFrame, but we could use the timestamps + # for additional processing if needed in the future