WebsocketTTSService: add on_connection_error and reconnect_on_error

This commit is contained in:
Aleix Conchillo Flaqué
2025-03-19 16:00:57 -07:00
parent afb26be0ad
commit a3b5e4413a
11 changed files with 81 additions and 32 deletions

View File

@@ -9,6 +9,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Added a `reconnect_on_error` parameter to websocket-based TTS services as well
as a `on_connection_error` event handler. The `reconnect_on_error` indicates
whether the TTS service should reconnect on error. The `on_connection_error`
will always get called if there's any error no matter the value of
`reconnect_on_error`. This allows, for example, to fallback to a different TTS
provider if something goes wrong with the current one.
- Added new `SkipTagsAggregator` that extends `BaseTextAggregator` to aggregate
text and skips end of sentence matching if aggregated text is between
start/end tags.

View File

@@ -425,12 +425,14 @@ class PipelineTask(BaseTask):
# Tell the task we should stop nicely.
await self.queue_frame(StopFrame())
elif isinstance(frame, ErrorFrame):
logger.error(f"Error running app: {frame}")
if frame.fatal:
logger.error(f"A fatal error occurred: {frame}")
# Cancel all tasks downstream.
await self.queue_frame(CancelFrame())
# Tell the task we should stop.
await self.queue_frame(StopTaskFrame())
else:
logger.warning(f"Something went wrong: {frame}")
self._up_queue.task_done()
async def _process_down_queue(self):

View File

@@ -549,11 +549,25 @@ class WordTTSService(TTSService):
class WebsocketTTSService(TTSService, WebsocketService):
"""This is a base class for websocket-based TTS services."""
"""This is a base class for websocket-based TTS services.
def __init__(self, **kwargs):
If an error occurs with the websocket, an "on_connection_error" event will
be triggered:
@tts.event_handler("on_connection_error")
async def on_connection_error(tts: TTSService, error: str):
...
"""
def __init__(self, *, reconnect_on_error: bool = True, **kwargs):
TTSService.__init__(self, **kwargs)
WebsocketService.__init__(self)
WebsocketService.__init__(self, reconnect_on_error=reconnect_on_error, **kwargs)
self._register_event_handler("on_connection_error")
async def _report_error(self, error: ErrorFrame):
await self._call_event_handler("on_connection_error", error.error)
await self.push_error(error)
class InterruptibleTTSService(WebsocketTTSService):
@@ -590,11 +604,23 @@ class WebsocketWordTTSService(WordTTSService, WebsocketService):
"""This is a base class for websocket-based TTS services that support word
timestamps.
If an error occurs with the websocket a "on_connection_error" event will be
triggered:
@tts.event_handler("on_connection_error")
async def on_connection_error(tts: TTSService, error: str):
...
"""
def __init__(self, **kwargs):
def __init__(self, *, reconnect_on_error: bool = True, **kwargs):
WordTTSService.__init__(self, **kwargs)
WebsocketService.__init__(self)
WebsocketService.__init__(self, reconnect_on_error=reconnect_on_error, **kwargs)
self._register_event_handler("on_connection_error")
async def _report_error(self, error: ErrorFrame):
await self._call_event_handler("on_connection_error", error.error)
await self.push_error(error)
class InterruptibleWordTTSService(WebsocketWordTTSService):

View File

@@ -187,7 +187,7 @@ class CartesiaTTSService(AudioContextWordTTSService):
async def _connect(self):
await self._connect_websocket()
if not self._receive_task:
self._receive_task = self.create_task(self._receive_task_handler(self.push_error))
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
async def _disconnect(self):
if self._receive_task:
@@ -207,6 +207,7 @@ class CartesiaTTSService(AudioContextWordTTSService):
except Exception as e:
logger.error(f"{self} initialization error: {e}")
self._websocket = None
await self._call_event_handler("on_connection_error", f"{e}")
async def _disconnect_websocket(self):
try:

View File

@@ -309,7 +309,7 @@ class ElevenLabsTTSService(InterruptibleWordTTSService):
await self._connect_websocket()
if not self._receive_task:
self._receive_task = self.create_task(self._receive_task_handler(self.push_error))
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
if not self._keepalive_task:
self._keepalive_task = self.create_task(self._keepalive_task_handler())
@@ -364,6 +364,7 @@ class ElevenLabsTTSService(InterruptibleWordTTSService):
except Exception as e:
logger.error(f"{self} initialization error: {e}")
self._websocket = None
await self._call_event_handler("on_connection_error", f"{e}")
async def _disconnect_websocket(self):
try:

View File

@@ -107,7 +107,7 @@ class FishAudioTTSService(InterruptibleTTSService):
async def _connect(self):
await self._connect_websocket()
if not self._receive_task:
self._receive_task = self.create_task(self._receive_task_handler(self.push_error))
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
async def _disconnect(self):
if self._receive_task:
@@ -132,6 +132,7 @@ class FishAudioTTSService(InterruptibleTTSService):
except Exception as e:
logger.error(f"Fish Audio initialization error: {e}")
self._websocket = None
await self._call_event_handler("on_connection_error", f"{e}")
async def _disconnect_websocket(self):
try:

View File

@@ -112,7 +112,7 @@ class LmntTTSService(InterruptibleTTSService):
await self._connect_websocket()
if not self._receive_task:
self._receive_task = self.create_task(self._receive_task_handler(self.push_error))
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
async def _disconnect(self):
if self._receive_task:
@@ -147,6 +147,7 @@ class LmntTTSService(InterruptibleTTSService):
except Exception as e:
logger.error(f"{self} initialization error: {e}")
self._websocket = None
await self._call_event_handler("on_connection_error", f"{e}")
async def _disconnect_websocket(self):
"""Disconnect from LMNT websocket."""

View File

@@ -162,7 +162,7 @@ class NeuphonicTTSService(InterruptibleTTSService):
async def _connect(self):
await self._connect_websocket()
self._receive_task = self.create_task(self._receive_task_handler(self.push_error))
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
self._keepalive_task = self.create_task(self._keepalive_task_handler())
async def _disconnect(self):
@@ -197,6 +197,7 @@ class NeuphonicTTSService(InterruptibleTTSService):
except Exception as e:
logger.error(f"{self} initialization error: {e}")
self._websocket = None
await self._call_event_handler("on_connection_error", f"{e}")
async def _disconnect_websocket(self):
try:

View File

@@ -160,7 +160,7 @@ class PlayHTTTSService(InterruptibleTTSService):
await self._connect_websocket()
if not self._receive_task:
self._receive_task = self.create_task(self._receive_task_handler(self.push_error))
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
async def _disconnect(self):
if self._receive_task:
@@ -183,12 +183,14 @@ class PlayHTTTSService(InterruptibleTTSService):
raise ValueError("WebSocket URL is not a string")
self._websocket = await websockets.connect(self._websocket_url)
except ValueError as ve:
logger.error(f"{self} initialization error: {ve}")
except ValueError as e:
logger.error(f"{self} initialization error: {e}")
self._websocket = None
await self._call_event_handler("on_connection_error", f"{e}")
except Exception as e:
logger.error(f"{self} initialization error: {e}")
self._websocket = None
await self._call_event_handler("on_connection_error", f"{e}")
async def _disconnect_websocket(self):
try:

View File

@@ -171,7 +171,7 @@ class RimeTTSService(AudioContextWordTTSService):
await self._connect_websocket()
if not self._receive_task:
self._receive_task = self.create_task(self._receive_task_handler(self.push_error))
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
async def _disconnect(self):
"""Close websocket connection and clean up tasks."""
@@ -194,6 +194,7 @@ class RimeTTSService(AudioContextWordTTSService):
except Exception as e:
logger.error(f"{self} initialization error: {e}")
self._websocket = None
await self._call_event_handler("on_connection_error", f"{e}")
async def _disconnect_websocket(self):
"""Close websocket connection and reset state."""

View File

@@ -19,9 +19,10 @@ from pipecat.utils.network import exponential_backoff_time
class WebsocketService(ABC):
"""Base class for websocket-based services with reconnection logic."""
def __init__(self):
def __init__(self, *, reconnect_on_error: bool = True, **kwargs):
"""Initialize websocket attributes."""
self._websocket: Optional[websockets.WebSocketClientProtocol] = None
self._reconnect_on_error = reconnect_on_error
async def _verify_connection(self) -> bool:
"""Verify websocket connection is working.
@@ -72,24 +73,29 @@ class WebsocketService(ABC):
self._websocket.close_rcvd_then_sent,
)
except Exception as e:
retry_count += 1
if retry_count >= MAX_RETRIES:
message = f"{self} error receiving messages: {e}"
logger.error(message)
await report_error(ErrorFrame(message, fatal=True))
message = f"{self} error receiving messages: {e}"
logger.error(message)
if self._reconnect_on_error:
retry_count += 1
if retry_count >= MAX_RETRIES:
await report_error(ErrorFrame(message, fatal=True))
break
logger.warning(f"{self} connection error, will retry: {e}")
await report_error(ErrorFrame(message))
try:
if await self._reconnect_websocket(retry_count):
retry_count = 0 # Reset counter on successful reconnection
wait_time = exponential_backoff_time(retry_count)
await asyncio.sleep(wait_time)
except Exception as reconnect_error:
logger.error(f"{self} reconnection failed: {reconnect_error}")
else:
await report_error(ErrorFrame(message))
break
logger.warning(f"{self} connection error, will retry: {e}")
try:
if await self._reconnect_websocket(retry_count):
retry_count = 0 # Reset counter on successful reconnection
wait_time = exponential_backoff_time(retry_count)
await asyncio.sleep(wait_time)
except Exception as reconnect_error:
logger.error(f"{self} reconnection failed: {reconnect_error}")
continue
@abstractmethod
async def _connect(self):
"""Implement service-specific connection logic. This function will