_calculate_wait_times as private, add and use WebsocketServiceException

This commit is contained in:
Mark Backman
2025-01-14 13:09:41 -05:00
parent e3d8910814
commit b53bc8a879
5 changed files with 22 additions and 15 deletions

View File

@@ -187,7 +187,9 @@ class CartesiaTTSService(WordTTSService, WebsocketService):
async def _connect(self):
await self._connect_websocket()
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
self._receive_task = self.get_event_loop().create_task(
self._receive_task_handler(self.push_error)
)
async def _disconnect(self):
await self._disconnect_websocket()

View File

@@ -296,7 +296,9 @@ class ElevenLabsTTSService(WordTTSService, WebsocketService):
async def _connect(self):
await self._connect_websocket()
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
self._receive_task = self.get_event_loop().create_task(
self._receive_task_handler(self.push_error)
)
self._keepalive_task = self.get_event_loop().create_task(self._keepalive_task_handler())
async def _disconnect(self):

View File

@@ -113,7 +113,9 @@ class LmntTTSService(TTSService, WebsocketService):
async def _connect(self):
await self._connect_websocket()
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
self._receive_task = self.get_event_loop().create_task(
self._receive_task_handler(self.push_error)
)
async def _disconnect(self):
await self._disconnect_websocket()

View File

@@ -165,7 +165,9 @@ class PlayHTTTSService(TTSService, WebsocketService):
async def _connect(self):
await self._connect_websocket()
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
self._receive_task = self.get_event_loop().create_task(
self._receive_task_handler(self.push_error)
)
async def _disconnect(self):
await self._disconnect_websocket()

View File

@@ -6,7 +6,7 @@
import asyncio
from abc import ABC, abstractmethod
from typing import Optional
from typing import Awaitable, Callable, Optional
import websockets
from loguru import logger
@@ -50,7 +50,7 @@ class WebsocketService(ABC):
await self._connect_websocket()
return await self._verify_connection()
def calculate_wait_time(
def _calculate_wait_time(
self, attempt: int, min_wait: float = 4, max_wait: float = 10, multiplier: float = 1
) -> float:
"""Calculate exponential backoff wait time.
@@ -71,8 +71,12 @@ class WebsocketService(ABC):
except (ValueError, ArithmeticError):
return max_wait
async def _receive_task_handler(self):
"""Handles WebSocket message receiving with automatic retry logic."""
async def _receive_task_handler(self, report_error: Callable[[ErrorFrame], Awaitable[None]]):
"""Handles WebSocket message receiving with automatic retry logic.
Args:
report_error: Callback to report errors
"""
retry_count = 0
MAX_RETRIES = 3
@@ -90,7 +94,7 @@ class WebsocketService(ABC):
if retry_count >= MAX_RETRIES:
message = f"{self} error receiving messages: {e}"
logger.error(message)
await self.push_error(ErrorFrame(message, fatal=True))
await report_error(ErrorFrame(message, fatal=True))
break
logger.warning(f"{self} connection error, will retry: {e}")
@@ -98,7 +102,7 @@ class WebsocketService(ABC):
try:
if await self._reconnect_websocket(retry_count):
retry_count = 0 # Reset counter on successful reconnection
wait_time = self.calculate_wait_time(retry_count)
wait_time = self._calculate_wait_time(retry_count)
await asyncio.sleep(wait_time)
except Exception as reconnect_error:
logger.error(f"{self} reconnection failed: {reconnect_error}")
@@ -118,8 +122,3 @@ class WebsocketService(ABC):
async def _receive_messages(self):
"""Implement service-specific message receiving logic."""
pass
@abstractmethod
async def push_error(self, error: ErrorFrame):
"""Implement service-specific error handling."""
pass