_calculate_wait_times as private, add and use WebsocketServiceException
This commit is contained in:
@@ -187,7 +187,9 @@ class CartesiaTTSService(WordTTSService, WebsocketService):
|
||||
async def _connect(self):
|
||||
await self._connect_websocket()
|
||||
|
||||
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
|
||||
self._receive_task = self.get_event_loop().create_task(
|
||||
self._receive_task_handler(self.push_error)
|
||||
)
|
||||
|
||||
async def _disconnect(self):
|
||||
await self._disconnect_websocket()
|
||||
|
||||
@@ -296,7 +296,9 @@ class ElevenLabsTTSService(WordTTSService, WebsocketService):
|
||||
async def _connect(self):
|
||||
await self._connect_websocket()
|
||||
|
||||
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
|
||||
self._receive_task = self.get_event_loop().create_task(
|
||||
self._receive_task_handler(self.push_error)
|
||||
)
|
||||
self._keepalive_task = self.get_event_loop().create_task(self._keepalive_task_handler())
|
||||
|
||||
async def _disconnect(self):
|
||||
|
||||
@@ -113,7 +113,9 @@ class LmntTTSService(TTSService, WebsocketService):
|
||||
async def _connect(self):
|
||||
await self._connect_websocket()
|
||||
|
||||
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
|
||||
self._receive_task = self.get_event_loop().create_task(
|
||||
self._receive_task_handler(self.push_error)
|
||||
)
|
||||
|
||||
async def _disconnect(self):
|
||||
await self._disconnect_websocket()
|
||||
|
||||
@@ -165,7 +165,9 @@ class PlayHTTTSService(TTSService, WebsocketService):
|
||||
async def _connect(self):
|
||||
await self._connect_websocket()
|
||||
|
||||
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
|
||||
self._receive_task = self.get_event_loop().create_task(
|
||||
self._receive_task_handler(self.push_error)
|
||||
)
|
||||
|
||||
async def _disconnect(self):
|
||||
await self._disconnect_websocket()
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
from typing import Awaitable, Callable, Optional
|
||||
|
||||
import websockets
|
||||
from loguru import logger
|
||||
@@ -50,7 +50,7 @@ class WebsocketService(ABC):
|
||||
await self._connect_websocket()
|
||||
return await self._verify_connection()
|
||||
|
||||
def calculate_wait_time(
|
||||
def _calculate_wait_time(
|
||||
self, attempt: int, min_wait: float = 4, max_wait: float = 10, multiplier: float = 1
|
||||
) -> float:
|
||||
"""Calculate exponential backoff wait time.
|
||||
@@ -71,8 +71,12 @@ class WebsocketService(ABC):
|
||||
except (ValueError, ArithmeticError):
|
||||
return max_wait
|
||||
|
||||
async def _receive_task_handler(self):
|
||||
"""Handles WebSocket message receiving with automatic retry logic."""
|
||||
async def _receive_task_handler(self, report_error: Callable[[ErrorFrame], Awaitable[None]]):
|
||||
"""Handles WebSocket message receiving with automatic retry logic.
|
||||
|
||||
Args:
|
||||
report_error: Callback to report errors
|
||||
"""
|
||||
retry_count = 0
|
||||
MAX_RETRIES = 3
|
||||
|
||||
@@ -90,7 +94,7 @@ class WebsocketService(ABC):
|
||||
if retry_count >= MAX_RETRIES:
|
||||
message = f"{self} error receiving messages: {e}"
|
||||
logger.error(message)
|
||||
await self.push_error(ErrorFrame(message, fatal=True))
|
||||
await report_error(ErrorFrame(message, fatal=True))
|
||||
break
|
||||
|
||||
logger.warning(f"{self} connection error, will retry: {e}")
|
||||
@@ -98,7 +102,7 @@ class WebsocketService(ABC):
|
||||
try:
|
||||
if await self._reconnect_websocket(retry_count):
|
||||
retry_count = 0 # Reset counter on successful reconnection
|
||||
wait_time = self.calculate_wait_time(retry_count)
|
||||
wait_time = self._calculate_wait_time(retry_count)
|
||||
await asyncio.sleep(wait_time)
|
||||
except Exception as reconnect_error:
|
||||
logger.error(f"{self} reconnection failed: {reconnect_error}")
|
||||
@@ -118,8 +122,3 @@ class WebsocketService(ABC):
|
||||
async def _receive_messages(self):
|
||||
"""Implement service-specific message receiving logic."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def push_error(self, error: ErrorFrame):
|
||||
"""Implement service-specific error handling."""
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user