WebsocketTTSService: add on_connection_error and reconnect_on_error
This commit is contained in:
@@ -9,6 +9,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Added
|
||||
|
||||
- Added a `reconnect_on_error` parameter to websocket-based TTS services as well
|
||||
as a `on_connection_error` event handler. The `reconnect_on_error` indicates
|
||||
whether the TTS service should reconnect on error. The `on_connection_error`
|
||||
will always get called if there's any error no matter the value of
|
||||
`reconnect_on_error`. This allows, for example, to fallback to a different TTS
|
||||
provider if something goes wrong with the current one.
|
||||
|
||||
- Added new `SkipTagsAggregator` that extends `BaseTextAggregator` to aggregate
|
||||
text and skips end of sentence matching if aggregated text is between
|
||||
start/end tags.
|
||||
|
||||
@@ -425,12 +425,14 @@ class PipelineTask(BaseTask):
|
||||
# Tell the task we should stop nicely.
|
||||
await self.queue_frame(StopFrame())
|
||||
elif isinstance(frame, ErrorFrame):
|
||||
logger.error(f"Error running app: {frame}")
|
||||
if frame.fatal:
|
||||
logger.error(f"A fatal error occurred: {frame}")
|
||||
# Cancel all tasks downstream.
|
||||
await self.queue_frame(CancelFrame())
|
||||
# Tell the task we should stop.
|
||||
await self.queue_frame(StopTaskFrame())
|
||||
else:
|
||||
logger.warning(f"Something went wrong: {frame}")
|
||||
self._up_queue.task_done()
|
||||
|
||||
async def _process_down_queue(self):
|
||||
|
||||
@@ -549,11 +549,25 @@ class WordTTSService(TTSService):
|
||||
|
||||
|
||||
class WebsocketTTSService(TTSService, WebsocketService):
|
||||
"""This is a base class for websocket-based TTS services."""
|
||||
"""This is a base class for websocket-based TTS services.
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
If an error occurs with the websocket, an "on_connection_error" event will
|
||||
be triggered:
|
||||
|
||||
@tts.event_handler("on_connection_error")
|
||||
async def on_connection_error(tts: TTSService, error: str):
|
||||
...
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, *, reconnect_on_error: bool = True, **kwargs):
|
||||
TTSService.__init__(self, **kwargs)
|
||||
WebsocketService.__init__(self)
|
||||
WebsocketService.__init__(self, reconnect_on_error=reconnect_on_error, **kwargs)
|
||||
self._register_event_handler("on_connection_error")
|
||||
|
||||
async def _report_error(self, error: ErrorFrame):
|
||||
await self._call_event_handler("on_connection_error", error.error)
|
||||
await self.push_error(error)
|
||||
|
||||
|
||||
class InterruptibleTTSService(WebsocketTTSService):
|
||||
@@ -590,11 +604,23 @@ class WebsocketWordTTSService(WordTTSService, WebsocketService):
|
||||
"""This is a base class for websocket-based TTS services that support word
|
||||
timestamps.
|
||||
|
||||
If an error occurs with the websocket a "on_connection_error" event will be
|
||||
triggered:
|
||||
|
||||
@tts.event_handler("on_connection_error")
|
||||
async def on_connection_error(tts: TTSService, error: str):
|
||||
...
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
def __init__(self, *, reconnect_on_error: bool = True, **kwargs):
|
||||
WordTTSService.__init__(self, **kwargs)
|
||||
WebsocketService.__init__(self)
|
||||
WebsocketService.__init__(self, reconnect_on_error=reconnect_on_error, **kwargs)
|
||||
self._register_event_handler("on_connection_error")
|
||||
|
||||
async def _report_error(self, error: ErrorFrame):
|
||||
await self._call_event_handler("on_connection_error", error.error)
|
||||
await self.push_error(error)
|
||||
|
||||
|
||||
class InterruptibleWordTTSService(WebsocketWordTTSService):
|
||||
|
||||
@@ -187,7 +187,7 @@ class CartesiaTTSService(AudioContextWordTTSService):
|
||||
async def _connect(self):
|
||||
await self._connect_websocket()
|
||||
if not self._receive_task:
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self.push_error))
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
|
||||
|
||||
async def _disconnect(self):
|
||||
if self._receive_task:
|
||||
@@ -207,6 +207,7 @@ class CartesiaTTSService(AudioContextWordTTSService):
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
async def _disconnect_websocket(self):
|
||||
try:
|
||||
|
||||
@@ -309,7 +309,7 @@ class ElevenLabsTTSService(InterruptibleWordTTSService):
|
||||
await self._connect_websocket()
|
||||
|
||||
if not self._receive_task:
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self.push_error))
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
|
||||
|
||||
if not self._keepalive_task:
|
||||
self._keepalive_task = self.create_task(self._keepalive_task_handler())
|
||||
@@ -364,6 +364,7 @@ class ElevenLabsTTSService(InterruptibleWordTTSService):
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
async def _disconnect_websocket(self):
|
||||
try:
|
||||
|
||||
@@ -107,7 +107,7 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
async def _connect(self):
|
||||
await self._connect_websocket()
|
||||
if not self._receive_task:
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self.push_error))
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
|
||||
|
||||
async def _disconnect(self):
|
||||
if self._receive_task:
|
||||
@@ -132,6 +132,7 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
except Exception as e:
|
||||
logger.error(f"Fish Audio initialization error: {e}")
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
async def _disconnect_websocket(self):
|
||||
try:
|
||||
|
||||
@@ -112,7 +112,7 @@ class LmntTTSService(InterruptibleTTSService):
|
||||
await self._connect_websocket()
|
||||
|
||||
if not self._receive_task:
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self.push_error))
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
|
||||
|
||||
async def _disconnect(self):
|
||||
if self._receive_task:
|
||||
@@ -147,6 +147,7 @@ class LmntTTSService(InterruptibleTTSService):
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
async def _disconnect_websocket(self):
|
||||
"""Disconnect from LMNT websocket."""
|
||||
|
||||
@@ -162,7 +162,7 @@ class NeuphonicTTSService(InterruptibleTTSService):
|
||||
async def _connect(self):
|
||||
await self._connect_websocket()
|
||||
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self.push_error))
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
|
||||
self._keepalive_task = self.create_task(self._keepalive_task_handler())
|
||||
|
||||
async def _disconnect(self):
|
||||
@@ -197,6 +197,7 @@ class NeuphonicTTSService(InterruptibleTTSService):
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
async def _disconnect_websocket(self):
|
||||
try:
|
||||
|
||||
@@ -160,7 +160,7 @@ class PlayHTTTSService(InterruptibleTTSService):
|
||||
await self._connect_websocket()
|
||||
|
||||
if not self._receive_task:
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self.push_error))
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
|
||||
|
||||
async def _disconnect(self):
|
||||
if self._receive_task:
|
||||
@@ -183,12 +183,14 @@ class PlayHTTTSService(InterruptibleTTSService):
|
||||
raise ValueError("WebSocket URL is not a string")
|
||||
|
||||
self._websocket = await websockets.connect(self._websocket_url)
|
||||
except ValueError as ve:
|
||||
logger.error(f"{self} initialization error: {ve}")
|
||||
except ValueError as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
async def _disconnect_websocket(self):
|
||||
try:
|
||||
|
||||
@@ -171,7 +171,7 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
await self._connect_websocket()
|
||||
|
||||
if not self._receive_task:
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self.push_error))
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Close websocket connection and clean up tasks."""
|
||||
@@ -194,6 +194,7 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
async def _disconnect_websocket(self):
|
||||
"""Close websocket connection and reset state."""
|
||||
|
||||
@@ -19,9 +19,10 @@ from pipecat.utils.network import exponential_backoff_time
|
||||
class WebsocketService(ABC):
|
||||
"""Base class for websocket-based services with reconnection logic."""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, *, reconnect_on_error: bool = True, **kwargs):
|
||||
"""Initialize websocket attributes."""
|
||||
self._websocket: Optional[websockets.WebSocketClientProtocol] = None
|
||||
self._reconnect_on_error = reconnect_on_error
|
||||
|
||||
async def _verify_connection(self) -> bool:
|
||||
"""Verify websocket connection is working.
|
||||
@@ -72,24 +73,29 @@ class WebsocketService(ABC):
|
||||
self._websocket.close_rcvd_then_sent,
|
||||
)
|
||||
except Exception as e:
|
||||
retry_count += 1
|
||||
if retry_count >= MAX_RETRIES:
|
||||
message = f"{self} error receiving messages: {e}"
|
||||
logger.error(message)
|
||||
await report_error(ErrorFrame(message, fatal=True))
|
||||
message = f"{self} error receiving messages: {e}"
|
||||
logger.error(message)
|
||||
|
||||
if self._reconnect_on_error:
|
||||
retry_count += 1
|
||||
if retry_count >= MAX_RETRIES:
|
||||
await report_error(ErrorFrame(message, fatal=True))
|
||||
break
|
||||
|
||||
logger.warning(f"{self} connection error, will retry: {e}")
|
||||
await report_error(ErrorFrame(message))
|
||||
|
||||
try:
|
||||
if await self._reconnect_websocket(retry_count):
|
||||
retry_count = 0 # Reset counter on successful reconnection
|
||||
wait_time = exponential_backoff_time(retry_count)
|
||||
await asyncio.sleep(wait_time)
|
||||
except Exception as reconnect_error:
|
||||
logger.error(f"{self} reconnection failed: {reconnect_error}")
|
||||
else:
|
||||
await report_error(ErrorFrame(message))
|
||||
break
|
||||
|
||||
logger.warning(f"{self} connection error, will retry: {e}")
|
||||
|
||||
try:
|
||||
if await self._reconnect_websocket(retry_count):
|
||||
retry_count = 0 # Reset counter on successful reconnection
|
||||
wait_time = exponential_backoff_time(retry_count)
|
||||
await asyncio.sleep(wait_time)
|
||||
except Exception as reconnect_error:
|
||||
logger.error(f"{self} reconnection failed: {reconnect_error}")
|
||||
continue
|
||||
|
||||
@abstractmethod
|
||||
async def _connect(self):
|
||||
"""Implement service-specific connection logic. This function will
|
||||
|
||||
Reference in New Issue
Block a user