diff --git a/CHANGELOG.md b/CHANGELOG.md index 5d30759a6..35df891ef 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added a `reconnect_on_error` parameter to websocket-based TTS services as well + as a `on_connection_error` event handler. The `reconnect_on_error` indicates + whether the TTS service should reconnect on error. The `on_connection_error` + will always get called if there's any error no matter the value of + `reconnect_on_error`. This allows, for example, to fallback to a different TTS + provider if something goes wrong with the current one. + - Added new `SkipTagsAggregator` that extends `BaseTextAggregator` to aggregate text and skips end of sentence matching if aggregated text is between start/end tags. diff --git a/src/pipecat/pipeline/task.py b/src/pipecat/pipeline/task.py index 12c56ef3f..e9744fd04 100644 --- a/src/pipecat/pipeline/task.py +++ b/src/pipecat/pipeline/task.py @@ -425,12 +425,14 @@ class PipelineTask(BaseTask): # Tell the task we should stop nicely. await self.queue_frame(StopFrame()) elif isinstance(frame, ErrorFrame): - logger.error(f"Error running app: {frame}") if frame.fatal: + logger.error(f"A fatal error occurred: {frame}") # Cancel all tasks downstream. await self.queue_frame(CancelFrame()) # Tell the task we should stop. await self.queue_frame(StopTaskFrame()) + else: + logger.warning(f"Something went wrong: {frame}") self._up_queue.task_done() async def _process_down_queue(self): diff --git a/src/pipecat/services/ai_services.py b/src/pipecat/services/ai_services.py index ba895c53d..d29e0e714 100644 --- a/src/pipecat/services/ai_services.py +++ b/src/pipecat/services/ai_services.py @@ -549,11 +549,25 @@ class WordTTSService(TTSService): class WebsocketTTSService(TTSService, WebsocketService): - """This is a base class for websocket-based TTS services.""" + """This is a base class for websocket-based TTS services. - def __init__(self, **kwargs): + If an error occurs with the websocket, an "on_connection_error" event will + be triggered: + + @tts.event_handler("on_connection_error") + async def on_connection_error(tts: TTSService, error: str): + ... + + """ + + def __init__(self, *, reconnect_on_error: bool = True, **kwargs): TTSService.__init__(self, **kwargs) - WebsocketService.__init__(self) + WebsocketService.__init__(self, reconnect_on_error=reconnect_on_error, **kwargs) + self._register_event_handler("on_connection_error") + + async def _report_error(self, error: ErrorFrame): + await self._call_event_handler("on_connection_error", error.error) + await self.push_error(error) class InterruptibleTTSService(WebsocketTTSService): @@ -590,11 +604,23 @@ class WebsocketWordTTSService(WordTTSService, WebsocketService): """This is a base class for websocket-based TTS services that support word timestamps. + If an error occurs with the websocket a "on_connection_error" event will be + triggered: + + @tts.event_handler("on_connection_error") + async def on_connection_error(tts: TTSService, error: str): + ... + """ - def __init__(self, **kwargs): + def __init__(self, *, reconnect_on_error: bool = True, **kwargs): WordTTSService.__init__(self, **kwargs) - WebsocketService.__init__(self) + WebsocketService.__init__(self, reconnect_on_error=reconnect_on_error, **kwargs) + self._register_event_handler("on_connection_error") + + async def _report_error(self, error: ErrorFrame): + await self._call_event_handler("on_connection_error", error.error) + await self.push_error(error) class InterruptibleWordTTSService(WebsocketWordTTSService): diff --git a/src/pipecat/services/cartesia.py b/src/pipecat/services/cartesia.py index 3b491d26b..11796464d 100644 --- a/src/pipecat/services/cartesia.py +++ b/src/pipecat/services/cartesia.py @@ -187,7 +187,7 @@ class CartesiaTTSService(AudioContextWordTTSService): async def _connect(self): await self._connect_websocket() if not self._receive_task: - self._receive_task = self.create_task(self._receive_task_handler(self.push_error)) + self._receive_task = self.create_task(self._receive_task_handler(self._report_error)) async def _disconnect(self): if self._receive_task: @@ -207,6 +207,7 @@ class CartesiaTTSService(AudioContextWordTTSService): except Exception as e: logger.error(f"{self} initialization error: {e}") self._websocket = None + await self._call_event_handler("on_connection_error", f"{e}") async def _disconnect_websocket(self): try: diff --git a/src/pipecat/services/elevenlabs.py b/src/pipecat/services/elevenlabs.py index c41e5de02..568b9eb64 100644 --- a/src/pipecat/services/elevenlabs.py +++ b/src/pipecat/services/elevenlabs.py @@ -309,7 +309,7 @@ class ElevenLabsTTSService(InterruptibleWordTTSService): await self._connect_websocket() if not self._receive_task: - self._receive_task = self.create_task(self._receive_task_handler(self.push_error)) + self._receive_task = self.create_task(self._receive_task_handler(self._report_error)) if not self._keepalive_task: self._keepalive_task = self.create_task(self._keepalive_task_handler()) @@ -364,6 +364,7 @@ class ElevenLabsTTSService(InterruptibleWordTTSService): except Exception as e: logger.error(f"{self} initialization error: {e}") self._websocket = None + await self._call_event_handler("on_connection_error", f"{e}") async def _disconnect_websocket(self): try: diff --git a/src/pipecat/services/fish.py b/src/pipecat/services/fish.py index 96968d6e9..9e6a8b91e 100644 --- a/src/pipecat/services/fish.py +++ b/src/pipecat/services/fish.py @@ -107,7 +107,7 @@ class FishAudioTTSService(InterruptibleTTSService): async def _connect(self): await self._connect_websocket() if not self._receive_task: - self._receive_task = self.create_task(self._receive_task_handler(self.push_error)) + self._receive_task = self.create_task(self._receive_task_handler(self._report_error)) async def _disconnect(self): if self._receive_task: @@ -132,6 +132,7 @@ class FishAudioTTSService(InterruptibleTTSService): except Exception as e: logger.error(f"Fish Audio initialization error: {e}") self._websocket = None + await self._call_event_handler("on_connection_error", f"{e}") async def _disconnect_websocket(self): try: diff --git a/src/pipecat/services/lmnt.py b/src/pipecat/services/lmnt.py index cfdf8e6cd..d3cc92603 100644 --- a/src/pipecat/services/lmnt.py +++ b/src/pipecat/services/lmnt.py @@ -112,7 +112,7 @@ class LmntTTSService(InterruptibleTTSService): await self._connect_websocket() if not self._receive_task: - self._receive_task = self.create_task(self._receive_task_handler(self.push_error)) + self._receive_task = self.create_task(self._receive_task_handler(self._report_error)) async def _disconnect(self): if self._receive_task: @@ -147,6 +147,7 @@ class LmntTTSService(InterruptibleTTSService): except Exception as e: logger.error(f"{self} initialization error: {e}") self._websocket = None + await self._call_event_handler("on_connection_error", f"{e}") async def _disconnect_websocket(self): """Disconnect from LMNT websocket.""" diff --git a/src/pipecat/services/neuphonic.py b/src/pipecat/services/neuphonic.py index b935885b6..407e54a83 100644 --- a/src/pipecat/services/neuphonic.py +++ b/src/pipecat/services/neuphonic.py @@ -162,7 +162,7 @@ class NeuphonicTTSService(InterruptibleTTSService): async def _connect(self): await self._connect_websocket() - self._receive_task = self.create_task(self._receive_task_handler(self.push_error)) + self._receive_task = self.create_task(self._receive_task_handler(self._report_error)) self._keepalive_task = self.create_task(self._keepalive_task_handler()) async def _disconnect(self): @@ -197,6 +197,7 @@ class NeuphonicTTSService(InterruptibleTTSService): except Exception as e: logger.error(f"{self} initialization error: {e}") self._websocket = None + await self._call_event_handler("on_connection_error", f"{e}") async def _disconnect_websocket(self): try: diff --git a/src/pipecat/services/playht.py b/src/pipecat/services/playht.py index 80d313765..75677876f 100644 --- a/src/pipecat/services/playht.py +++ b/src/pipecat/services/playht.py @@ -160,7 +160,7 @@ class PlayHTTTSService(InterruptibleTTSService): await self._connect_websocket() if not self._receive_task: - self._receive_task = self.create_task(self._receive_task_handler(self.push_error)) + self._receive_task = self.create_task(self._receive_task_handler(self._report_error)) async def _disconnect(self): if self._receive_task: @@ -183,12 +183,14 @@ class PlayHTTTSService(InterruptibleTTSService): raise ValueError("WebSocket URL is not a string") self._websocket = await websockets.connect(self._websocket_url) - except ValueError as ve: - logger.error(f"{self} initialization error: {ve}") + except ValueError as e: + logger.error(f"{self} initialization error: {e}") self._websocket = None + await self._call_event_handler("on_connection_error", f"{e}") except Exception as e: logger.error(f"{self} initialization error: {e}") self._websocket = None + await self._call_event_handler("on_connection_error", f"{e}") async def _disconnect_websocket(self): try: diff --git a/src/pipecat/services/rime.py b/src/pipecat/services/rime.py index c6fc50001..a1a455eb5 100644 --- a/src/pipecat/services/rime.py +++ b/src/pipecat/services/rime.py @@ -171,7 +171,7 @@ class RimeTTSService(AudioContextWordTTSService): await self._connect_websocket() if not self._receive_task: - self._receive_task = self.create_task(self._receive_task_handler(self.push_error)) + self._receive_task = self.create_task(self._receive_task_handler(self._report_error)) async def _disconnect(self): """Close websocket connection and clean up tasks.""" @@ -194,6 +194,7 @@ class RimeTTSService(AudioContextWordTTSService): except Exception as e: logger.error(f"{self} initialization error: {e}") self._websocket = None + await self._call_event_handler("on_connection_error", f"{e}") async def _disconnect_websocket(self): """Close websocket connection and reset state.""" diff --git a/src/pipecat/services/websocket_service.py b/src/pipecat/services/websocket_service.py index 19fb4ae01..2e82ddaba 100644 --- a/src/pipecat/services/websocket_service.py +++ b/src/pipecat/services/websocket_service.py @@ -19,9 +19,10 @@ from pipecat.utils.network import exponential_backoff_time class WebsocketService(ABC): """Base class for websocket-based services with reconnection logic.""" - def __init__(self): + def __init__(self, *, reconnect_on_error: bool = True, **kwargs): """Initialize websocket attributes.""" self._websocket: Optional[websockets.WebSocketClientProtocol] = None + self._reconnect_on_error = reconnect_on_error async def _verify_connection(self) -> bool: """Verify websocket connection is working. @@ -72,24 +73,29 @@ class WebsocketService(ABC): self._websocket.close_rcvd_then_sent, ) except Exception as e: - retry_count += 1 - if retry_count >= MAX_RETRIES: - message = f"{self} error receiving messages: {e}" - logger.error(message) - await report_error(ErrorFrame(message, fatal=True)) + message = f"{self} error receiving messages: {e}" + logger.error(message) + + if self._reconnect_on_error: + retry_count += 1 + if retry_count >= MAX_RETRIES: + await report_error(ErrorFrame(message, fatal=True)) + break + + logger.warning(f"{self} connection error, will retry: {e}") + await report_error(ErrorFrame(message)) + + try: + if await self._reconnect_websocket(retry_count): + retry_count = 0 # Reset counter on successful reconnection + wait_time = exponential_backoff_time(retry_count) + await asyncio.sleep(wait_time) + except Exception as reconnect_error: + logger.error(f"{self} reconnection failed: {reconnect_error}") + else: + await report_error(ErrorFrame(message)) break - logger.warning(f"{self} connection error, will retry: {e}") - - try: - if await self._reconnect_websocket(retry_count): - retry_count = 0 # Reset counter on successful reconnection - wait_time = exponential_backoff_time(retry_count) - await asyncio.sleep(wait_time) - except Exception as reconnect_error: - logger.error(f"{self} reconnection failed: {reconnect_error}") - continue - @abstractmethod async def _connect(self): """Implement service-specific connection logic. This function will