From 02cc6f3d5652a5ec93cee3158578aed316ac7007 Mon Sep 17 00:00:00 2001
From: jqueguiner
Date: Tue, 3 Jun 2025 03:16:57 -0700
Subject: [PATCH] Enhance GladiaSTTService with reconnection and audio buffer
 management features

- Added parameters for maximum reconnection attempts, reconnection delay, and maximum audio buffer size.
- Implemented automatic reconnection logic with exponential backoff.
- Introduced audio buffer management to handle audio data efficiently, including trimming excess data.
- Updated connection handling to ensure proper cleanup and management of WebSocket connections.
- Enhanced audio sending logic to support buffered audio transmission after reconnections.
---
Review notes:
- _send_buffered_audio replays the whole unacknowledged buffer; _bytes_sent is a
  stream offset, not an index into the already-trimmed buffer.
- Handler tasks are cancelled on every exit path and backoff lives in one helper,
  which also bounds retries when session setup keeps failing.
- Line breaks were rebuilt from a whitespace-collapsed copy: whitespace-only
  context lines follow the hunk counts, the From: address is missing and the
  post-image blob id is stale. Check with `git apply --check` before use.

 src/pipecat/services/gladia/stt.py | 215 +++++++++++++++++++++++++----
 1 file changed, 185 insertions(+), 30 deletions(-)

diff --git a/src/pipecat/services/gladia/stt.py b/src/pipecat/services/gladia/stt.py
index 6ac5edad9..b07fd0345 100644
--- a/src/pipecat/services/gladia/stt.py
+++ b/src/pipecat/services/gladia/stt.py
@@ -195,6 +195,9 @@ class GladiaSTTService(STTService):
         sample_rate: Optional[int] = None,
         model: str = "solaria-1",
         params: Optional[GladiaInputParams] = None,
+        max_reconnection_attempts: int = 5,
+        reconnection_delay: float = 1.0,
+        max_buffer_size: int = 1024 * 1024 * 5,  # 5MB default buffer
         **kwargs,
     ):
         """Initialize the Gladia STT service.
@@ -207,6 +210,9 @@ class GladiaSTTService(STTService):

             model: Model to use ("solaria-1", "solaria-mini-1", "fast", or "accurate")
             params: Additional configuration parameters
+            max_reconnection_attempts: Maximum number of reconnection attempts
+            reconnection_delay: Initial delay between reconnection attempts (exponential backoff)
+            max_buffer_size: Maximum size of audio buffer in bytes
             **kwargs: Additional arguments passed to the STTService
         """
         super().__init__(sample_rate=sample_rate, **kwargs)
@@ -232,6 +238,23 @@ class GladiaSTTService(STTService):
         self._keepalive_task = None
         self._settings = {}

+        # Reconnection settings
+        self._max_reconnection_attempts = max_reconnection_attempts
+        self._reconnection_delay = reconnection_delay
+        self._reconnection_attempts = 0
+        self._session_url = None
+        self._connection_active = False
+
+        # Audio buffer management
+        self._audio_buffer = bytearray()
+        self._bytes_sent = 0  # Stream offset of the first buffered byte
+        self._max_buffer_size = max_buffer_size
+        self._buffer_lock = asyncio.Lock()
+
+        # Connection management
+        self._connection_task = None
+        self._should_reconnect = True
+
     def can_generate_metrics(self) -> bool:
         return True

@@ -293,36 +316,147 @@ class GladiaSTTService(STTService):
     async def start(self, frame: StartFrame):
         """Start the Gladia STT websocket connection."""
         await super().start(frame)
-        if self._websocket:
+        if self._connection_task:
             return
-        settings = self._prepare_settings()
-        response = await self._setup_gladia(settings)
-        self._websocket = await websockets.connect(response["url"])
-        if self._websocket and not self._receive_task:
-            self._receive_task = self.create_task(self._receive_task_handler())
-        if self._websocket and not self._keepalive_task:
-            self._keepalive_task = self.create_task(self._keepalive_task_handler())
+
+        self._should_reconnect = True
+        self._connection_task = self.create_task(self._connection_handler())

     async def stop(self, frame: EndFrame):
         """Stop the Gladia STT websocket connection."""
         await super().stop(frame)
+        self._should_reconnect = False
         await self._send_stop_recording()

-        if self._keepalive_task:
-            await self.cancel_task(self._keepalive_task)
-            self._keepalive_task = None
+        if self._connection_task:
+            await self.cancel_task(self._connection_task)
+            self._connection_task = None

-        if self._websocket:
-            await self._websocket.close()
-            self._websocket = None
-
-        if self._receive_task:
-            await self.wait_for_task(self._receive_task)
-            self._receive_task = None
+        await self._cleanup_connection()

     async def cancel(self, frame: CancelFrame):
         """Cancel the Gladia STT websocket connection."""
         await super().cancel(frame)
+        self._should_reconnect = False
+
+        if self._connection_task:
+            await self.cancel_task(self._connection_task)
+            self._connection_task = None
+
+        await self._cleanup_connection()
+
+    async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
+        """Run speech-to-text on audio data."""
+        await self.start_ttfb_metrics()
+        await self.start_processing_metrics()
+
+        # Keep a copy until it is acknowledged so it can be replayed after a reconnect
+        async with self._buffer_lock:
+            self._audio_buffer.extend(audio)
+            # Drop the oldest audio if the buffer exceeds max size
+            if len(self._audio_buffer) > self._max_buffer_size:
+                trim_size = len(self._audio_buffer) - self._max_buffer_size
+                del self._audio_buffer[:trim_size]
+                # The buffer now starts trim_size bytes further into the stream
+                self._bytes_sent += trim_size
+                logger.warning(f"Audio buffer exceeded max size, trimmed {trim_size} bytes")
+
+        # Send audio if connected
+        if self._connection_active and self._websocket and not self._websocket.closed:
+            await self._send_audio(audio)
+
+        yield None
+
+    async def _connection_handler(self):
+        """Handle WebSocket connection with automatic reconnection."""
+        while self._should_reconnect:
+            try:
+                # Initialize session if needed
+                if not self._session_url:
+                    settings = self._prepare_settings()
+                    response = await self._setup_gladia(settings)
+                    self._session_url = response["url"]
+                    self._reconnection_attempts = 0
+
+                # Connect with automatic reconnection
+                async for websocket in websockets.connect(self._session_url):
+                    receive_task = None
+                    keepalive_task = None
+                    try:
+                        self._websocket = websocket
+                        self._connection_active = True
+                        logger.info("Connected to Gladia WebSocket")
+
+                        # Send buffered audio if any
+                        await self._send_buffered_audio()
+
+                        # Start tasks
+                        receive_task = asyncio.create_task(self._receive_task_handler())
+                        keepalive_task = asyncio.create_task(self._keepalive_task_handler())
+
+                        # Either handler finishing means the connection is gone, so do
+                        # not wait out the keepalive sleep before reconnecting
+                        done, _ = await asyncio.wait(
+                            {receive_task, keepalive_task},
+                            return_when=asyncio.FIRST_COMPLETED,
+                        )
+                        for task in done:
+                            if not task.cancelled():
+                                task.result()  # Re-raise a handler failure to log it
+
+                    except websockets.exceptions.ConnectionClosed as e:
+                        logger.warning(f"WebSocket connection closed: {e}")
+
+                    except Exception as e:
+                        logger.error(f"Error in WebSocket connection: {e}")
+
+                    finally:
+                        # Runs on every exit path, cancellation included, so no
+                        # handler task outlives its connection
+                        self._connection_active = False
+                        for task in (receive_task, keepalive_task):
+                            if task is not None:
+                                task.cancel()
+
+                    if not await self._wait_before_reconnect():
+                        break
+
+            except Exception as e:
+                logger.error(f"Error in connection handler: {e}")
+                self._connection_active = False
+
+                # Reset session URL to get a new one
+                self._session_url = None
+                if not await self._wait_before_reconnect():
+                    break
+
+    async def _wait_before_reconnect(self) -> bool:
+        """Sleep with exponential backoff before the next connection attempt.
+
+        Returns:
+            False when reconnection is disabled or the attempt limit is reached.
+        """
+        if not self._should_reconnect:
+            return False
+
+        self._reconnection_attempts += 1
+        if self._reconnection_attempts > self._max_reconnection_attempts:
+            logger.error(
+                f"Max reconnection attempts ({self._max_reconnection_attempts}) reached"
+            )
+            self._should_reconnect = False
+            return False
+
+        delay = self._reconnection_delay * (2 ** (self._reconnection_attempts - 1))
+        logger.info(
+            f"Reconnecting in {delay} seconds (attempt {self._reconnection_attempts}/{self._max_reconnection_attempts})"
+        )
+        await asyncio.sleep(delay)
+        return True
+
+    async def _cleanup_connection(self):
+        """Clean up connection resources."""
+        self._connection_active = False

         if self._keepalive_task:
             await self.cancel_task(self._keepalive_task)
@@ -336,13 +470,6 @@ class GladiaSTTService(STTService):
             await self.cancel_task(self._receive_task)
             self._receive_task = None

-    async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
-        """Run speech-to-text on audio data."""
-        await self.start_ttfb_metrics()
-        await self.start_processing_metrics()
-        await self._send_audio(audio)
-        yield None
-
     async def _setup_gladia(self, settings: Dict[str, Any]):
         async with aiohttp.ClientSession() as session:
             async with session.post(
@@ -369,9 +496,25 @@ class GladiaSTTService(STTService):
         await self.stop_processing_metrics()

     async def _send_audio(self, audio: bytes):
-        data = base64.b64encode(audio).decode("utf-8")
-        message = {"type": "audio_chunk", "data": {"chunk": data}}
-        await self._websocket.send(json.dumps(message))
+        """Send audio chunk with proper message format."""
+        if self._websocket and not self._websocket.closed:
+            data = base64.b64encode(audio).decode("utf-8")
+            message = {"type": "audio_chunk", "data": {"chunk": data}}
+            await self._websocket.send(json.dumps(message))
+
+    async def _send_buffered_audio(self):
+        """Send any buffered audio after reconnection."""
+        async with self._buffer_lock:
+            # Acknowledged audio is trimmed off the front as acks arrive, so
+            # everything still buffered is unacknowledged and is replayed
+            buffered_data = bytes(self._audio_buffer)
+            if buffered_data:
+                logger.info(f"Sending {len(buffered_data)} bytes of buffered audio")
+                # Send in chunks to avoid overwhelming the connection
+                chunk_size = 16384  # 16KB chunks
+                for i in range(0, len(buffered_data), chunk_size):
+                    await self._send_audio(buffered_data[i : i + chunk_size])
+                    await asyncio.sleep(0.01)  # Small delay between chunks

     async def _send_stop_recording(self):
         if self._websocket and not self._websocket.closed:
@@ -380,7 +523,7 @@ class GladiaSTTService(STTService):
     async def _keepalive_task_handler(self):
         """Send periodic empty audio chunks to keep the connection alive."""
         try:
-            while True:
+            while self._connection_active:
                 # Send keepalive every 20 seconds (Gladia times out after 30 seconds)
                 await asyncio.sleep(20)
                 if self._websocket and not self._websocket.closed:
@@ -399,7 +542,19 @@ class GladiaSTTService(STTService):
         try:
             async for message in self._websocket:
                 content = json.loads(message)
-                if content["type"] == "transcript":
+
+                # Handle audio chunk acknowledgments
+                if content["type"] == "audio_chunk" and content.get("acknowledged"):
+                    byte_range = content["data"]["byte_range"]
+                    async with self._buffer_lock:
+                        # Update bytes sent and trim acknowledged data from buffer
+                        end_byte = byte_range[1]
+                        if end_byte > self._bytes_sent:
+                            trim_size = end_byte - self._bytes_sent
+                            self._audio_buffer = self._audio_buffer[trim_size:]
+                            self._bytes_sent = end_byte
+
+                elif content["type"] == "transcript":
                     utterance = content["data"]["utterance"]
                     confidence = utterance.get("confidence", 0)
                     language = utterance["language"]