From b53bc8a879cd5ae58e7d86ca134e1969a5d5eb2c Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Tue, 14 Jan 2025 13:09:41 -0500 Subject: [PATCH] _calculate_wait_times as private, add and use WebsocketServiceException --- src/pipecat/services/cartesia.py | 4 +++- src/pipecat/services/elevenlabs.py | 4 +++- src/pipecat/services/lmnt.py | 4 +++- src/pipecat/services/playht.py | 4 +++- src/pipecat/services/websocket_service.py | 21 ++++++++++----------- 5 files changed, 22 insertions(+), 15 deletions(-) diff --git a/src/pipecat/services/cartesia.py b/src/pipecat/services/cartesia.py index a37c7f323..9712ce600 100644 --- a/src/pipecat/services/cartesia.py +++ b/src/pipecat/services/cartesia.py @@ -187,7 +187,9 @@ class CartesiaTTSService(WordTTSService, WebsocketService): async def _connect(self): await self._connect_websocket() - self._receive_task = self.get_event_loop().create_task(self._receive_task_handler()) + self._receive_task = self.get_event_loop().create_task( + self._receive_task_handler(self.push_error) + ) async def _disconnect(self): await self._disconnect_websocket() diff --git a/src/pipecat/services/elevenlabs.py b/src/pipecat/services/elevenlabs.py index 5978cd612..c1d326dfc 100644 --- a/src/pipecat/services/elevenlabs.py +++ b/src/pipecat/services/elevenlabs.py @@ -296,7 +296,9 @@ class ElevenLabsTTSService(WordTTSService, WebsocketService): async def _connect(self): await self._connect_websocket() - self._receive_task = self.get_event_loop().create_task(self._receive_task_handler()) + self._receive_task = self.get_event_loop().create_task( + self._receive_task_handler(self.push_error) + ) self._keepalive_task = self.get_event_loop().create_task(self._keepalive_task_handler()) async def _disconnect(self): diff --git a/src/pipecat/services/lmnt.py b/src/pipecat/services/lmnt.py index d58dbc34e..633c24265 100644 --- a/src/pipecat/services/lmnt.py +++ b/src/pipecat/services/lmnt.py @@ -113,7 +113,9 @@ class LmntTTSService(TTSService, WebsocketService): async def _connect(self): await self._connect_websocket() - self._receive_task = self.get_event_loop().create_task(self._receive_task_handler()) + self._receive_task = self.get_event_loop().create_task( + self._receive_task_handler(self.push_error) + ) async def _disconnect(self): await self._disconnect_websocket() diff --git a/src/pipecat/services/playht.py b/src/pipecat/services/playht.py index 912bce94c..a511e2456 100644 --- a/src/pipecat/services/playht.py +++ b/src/pipecat/services/playht.py @@ -165,7 +165,9 @@ class PlayHTTTSService(TTSService, WebsocketService): async def _connect(self): await self._connect_websocket() - self._receive_task = self.get_event_loop().create_task(self._receive_task_handler()) + self._receive_task = self.get_event_loop().create_task( + self._receive_task_handler(self.push_error) + ) async def _disconnect(self): await self._disconnect_websocket() diff --git a/src/pipecat/services/websocket_service.py b/src/pipecat/services/websocket_service.py index 2ceeb2a8f..365f5a7c8 100644 --- a/src/pipecat/services/websocket_service.py +++ b/src/pipecat/services/websocket_service.py @@ -6,7 +6,7 @@ import asyncio from abc import ABC, abstractmethod -from typing import Optional +from typing import Awaitable, Callable, Optional import websockets from loguru import logger @@ -50,7 +50,7 @@ class WebsocketService(ABC): await self._connect_websocket() return await self._verify_connection() - def calculate_wait_time( + def _calculate_wait_time( self, attempt: int, min_wait: float = 4, max_wait: float = 10, multiplier: float = 1 ) -> float: """Calculate exponential backoff wait time. @@ -71,8 +71,12 @@ class WebsocketService(ABC): except (ValueError, ArithmeticError): return max_wait - async def _receive_task_handler(self): - """Handles WebSocket message receiving with automatic retry logic.""" + async def _receive_task_handler(self, report_error: Callable[[ErrorFrame], Awaitable[None]]): + """Handles WebSocket message receiving with automatic retry logic. + + Args: + report_error: Callback to report errors + """ retry_count = 0 MAX_RETRIES = 3 @@ -90,7 +94,7 @@ class WebsocketService(ABC): if retry_count >= MAX_RETRIES: message = f"{self} error receiving messages: {e}" logger.error(message) - await self.push_error(ErrorFrame(message, fatal=True)) + await report_error(ErrorFrame(message, fatal=True)) break logger.warning(f"{self} connection error, will retry: {e}") @@ -98,7 +102,7 @@ class WebsocketService(ABC): try: if await self._reconnect_websocket(retry_count): retry_count = 0 # Reset counter on successful reconnection - wait_time = self.calculate_wait_time(retry_count) + wait_time = self._calculate_wait_time(retry_count) await asyncio.sleep(wait_time) except Exception as reconnect_error: logger.error(f"{self} reconnection failed: {reconnect_error}") @@ -118,8 +122,3 @@ class WebsocketService(ABC): async def _receive_messages(self): """Implement service-specific message receiving logic.""" pass - - @abstractmethod - async def push_error(self, error: ErrorFrame): - """Implement service-specific error handling.""" - pass