diff --git a/src/pipecat/services/assemblyai/stt.py b/src/pipecat/services/assemblyai/stt.py index f54b4ff80..3fde9491c 100644 --- a/src/pipecat/services/assemblyai/stt.py +++ b/src/pipecat/services/assemblyai/stt.py @@ -198,6 +198,8 @@ class AssemblyAISTTService(WebsocketSTTService): Establishes websocket connection and starts receive task. """ + await super()._connect() + await self._connect_websocket() if self._websocket and not self._receive_task: @@ -208,6 +210,8 @@ class AssemblyAISTTService(WebsocketSTTService): Sends termination message, waits for acknowledgment, and cleans up. """ + await super()._disconnect() + if not self._connected or not self._websocket: return diff --git a/src/pipecat/services/asyncai/tts.py b/src/pipecat/services/asyncai/tts.py index a838e2465..303369205 100644 --- a/src/pipecat/services/asyncai/tts.py +++ b/src/pipecat/services/asyncai/tts.py @@ -201,6 +201,8 @@ class AsyncAITTSService(InterruptibleTTSService): await self._disconnect() async def _connect(self): + await super()._connect() + await self._connect_websocket() if self._websocket and not self._receive_task: @@ -210,6 +212,8 @@ class AsyncAITTSService(InterruptibleTTSService): self._keepalive_task = self.create_task(self._keepalive_task_handler()) async def _disconnect(self): + await super()._disconnect() + if self._receive_task: await self.cancel_task(self._receive_task) self._receive_task = None diff --git a/src/pipecat/services/aws/stt.py b/src/pipecat/services/aws/stt.py index 915213e51..2ad350a96 100644 --- a/src/pipecat/services/aws/stt.py +++ b/src/pipecat/services/aws/stt.py @@ -170,6 +170,8 @@ class AWSTranscribeSTTService(WebsocketSTTService): Establishes websocket connection and starts receive task. """ + await super()._connect() + await self._connect_websocket() if self._websocket and not self._receive_task: @@ -180,6 +182,8 @@ class AWSTranscribeSTTService(WebsocketSTTService): Sends end-stream message and cleans up. """ + await super()._disconnect() + if self._receive_task: await self.cancel_task(self._receive_task) self._receive_task = None diff --git a/src/pipecat/services/cartesia/stt.py b/src/pipecat/services/cartesia/stt.py index 386d8cbbc..625df6366 100644 --- a/src/pipecat/services/cartesia/stt.py +++ b/src/pipecat/services/cartesia/stt.py @@ -245,12 +245,16 @@ class CartesiaSTTService(WebsocketSTTService): yield None async def _connect(self): + await super()._connect() + await self._connect_websocket() if self._websocket and not self._receive_task: self._receive_task = self.create_task(self._receive_task_handler(self._report_error)) async def _disconnect(self): + await super()._disconnect() + if self._receive_task: await self.cancel_task(self._receive_task) self._receive_task = None diff --git a/src/pipecat/services/cartesia/tts.py b/src/pipecat/services/cartesia/tts.py index 3ed3ca556..6bfb8703d 100644 --- a/src/pipecat/services/cartesia/tts.py +++ b/src/pipecat/services/cartesia/tts.py @@ -483,12 +483,16 @@ class CartesiaTTSService(AudioContextWordTTSService): await self._disconnect() async def _connect(self): + await super()._connect() + await self._connect_websocket() if self._websocket and not self._receive_task: self._receive_task = self.create_task(self._receive_task_handler(self._report_error)) async def _disconnect(self): + await super()._disconnect() + if self._receive_task: await self.cancel_task(self._receive_task) self._receive_task = None diff --git a/src/pipecat/services/deepgram/flux/stt.py b/src/pipecat/services/deepgram/flux/stt.py index cc2df29e2..13b72bcf7 100644 --- a/src/pipecat/services/deepgram/flux/stt.py +++ b/src/pipecat/services/deepgram/flux/stt.py @@ -194,6 +194,8 @@ class DeepgramFluxSTTService(WebsocketSTTService): Establishes the WebSocket connection to the Deepgram Flux API and starts the background task for receiving transcription results. """ + await super()._connect() + await self._connect_websocket() async def _disconnect(self): @@ -202,6 +204,8 @@ class DeepgramFluxSTTService(WebsocketSTTService): Gracefully disconnects from the Deepgram Flux API, cancels background tasks, and cleans up resources to prevent memory leaks. """ + await super()._disconnect() + try: await self._disconnect_websocket() except Exception as e: diff --git a/src/pipecat/services/deepgram/tts.py b/src/pipecat/services/deepgram/tts.py index e1688a90c..ec41baf26 100644 --- a/src/pipecat/services/deepgram/tts.py +++ b/src/pipecat/services/deepgram/tts.py @@ -147,6 +147,8 @@ class DeepgramTTSService(WebsocketTTSService): async def _connect(self): """Connect to Deepgram WebSocket and start receive task.""" + await super()._connect() + await self._connect_websocket() if self._websocket and not self._receive_task: @@ -154,6 +156,8 @@ class DeepgramTTSService(WebsocketTTSService): async def _disconnect(self): """Disconnect from Deepgram WebSocket and clean up tasks.""" + await super()._disconnect() + if self._receive_task: await self.cancel_task(self._receive_task) self._receive_task = None diff --git a/src/pipecat/services/elevenlabs/stt.py b/src/pipecat/services/elevenlabs/stt.py index 4d26e2f81..8f9020aa7 100644 --- a/src/pipecat/services/elevenlabs/stt.py +++ b/src/pipecat/services/elevenlabs/stt.py @@ -605,6 +605,8 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService): async def _connect(self): """Establish WebSocket connection to ElevenLabs Realtime STT.""" + await super()._connect() + await self._connect_websocket() if self._websocket and not self._receive_task: @@ -612,6 +614,8 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService): async def _disconnect(self): """Close WebSocket connection and cleanup tasks.""" + await super()._disconnect() + if self._receive_task: await self.cancel_task(self._receive_task) self._receive_task = None diff --git a/src/pipecat/services/elevenlabs/tts.py b/src/pipecat/services/elevenlabs/tts.py index dca462ce4..02ccd0ab3 100644 --- a/src/pipecat/services/elevenlabs/tts.py +++ b/src/pipecat/services/elevenlabs/tts.py @@ -478,6 +478,8 @@ class ElevenLabsTTSService(AudioContextWordTTSService): await self.add_word_timestamps([("Reset", 0)]) async def _connect(self): + await super()._connect() + await self._connect_websocket() if self._websocket and not self._receive_task: @@ -487,6 +489,8 @@ class ElevenLabsTTSService(AudioContextWordTTSService): self._keepalive_task = self.create_task(self._keepalive_task_handler()) async def _disconnect(self): + await super()._disconnect() + if self._receive_task: await self.cancel_task(self._receive_task) self._receive_task = None diff --git a/src/pipecat/services/fish/tts.py b/src/pipecat/services/fish/tts.py index dfa161066..357b82346 100644 --- a/src/pipecat/services/fish/tts.py +++ b/src/pipecat/services/fish/tts.py @@ -199,12 +199,16 @@ class FishAudioTTSService(InterruptibleTTSService): await self._disconnect() async def _connect(self): + await super()._connect() + await self._connect_websocket() if self._websocket and not self._receive_task: self._receive_task = self.create_task(self._receive_task_handler(self._report_error)) async def _disconnect(self): + await super()._disconnect() + if self._receive_task: await self.cancel_task(self._receive_task) self._receive_task = None diff --git a/src/pipecat/services/gladia/stt.py b/src/pipecat/services/gladia/stt.py index 48334ef8c..4ba0a2ffd 100644 --- a/src/pipecat/services/gladia/stt.py +++ b/src/pipecat/services/gladia/stt.py @@ -404,6 +404,8 @@ class GladiaSTTService(WebsocketSTTService): Initializes the session if needed and establishes websocket connection. """ + await super()._connect() + # Initialize session if needed if not self._session_url: settings = self._prepare_settings() @@ -425,6 +427,8 @@ class GladiaSTTService(WebsocketSTTService): Cleans up tasks and closes websocket connection. """ + await super()._disconnect() + self._connection_active = False if self._keepalive_task: diff --git a/src/pipecat/services/gradium/stt.py b/src/pipecat/services/gradium/stt.py index f869983d3..b66b18070 100644 --- a/src/pipecat/services/gradium/stt.py +++ b/src/pipecat/services/gradium/stt.py @@ -141,6 +141,8 @@ class GradiumSTTService(WebsocketSTTService): pass async def _connect(self): + await super()._connect() + await self._connect_websocket() if self._websocket and not self._receive_task: @@ -179,6 +181,8 @@ class GradiumSTTService(WebsocketSTTService): raise async def _disconnect(self): + await super()._disconnect() + if self._receive_task: await self.cancel_task(self._receive_task) self._receive_task = None diff --git a/src/pipecat/services/gradium/tts.py b/src/pipecat/services/gradium/tts.py index 3baaa887c..14e093541 100644 --- a/src/pipecat/services/gradium/tts.py +++ b/src/pipecat/services/gradium/tts.py @@ -157,6 +157,8 @@ class GradiumTTSService(InterruptibleWordTTSService): async def _connect(self): """Establish websocket connection and start receive task.""" + await super()._connect() + logger.debug(f"{self}: connecting") # If the server disconnected, cancel the receive-task so that it can be reset below. @@ -173,6 +175,8 @@ class GradiumTTSService(InterruptibleWordTTSService): async def _disconnect(self): """Close websocket connection and clean up tasks.""" + await super()._disconnect() + logger.debug(f"{self}: disconnecting") if self._receive_task: await self.cancel_task(self._receive_task) diff --git a/src/pipecat/services/inworld/tts.py b/src/pipecat/services/inworld/tts.py index fddb96602..ffac22464 100644 --- a/src/pipecat/services/inworld/tts.py +++ b/src/pipecat/services/inworld/tts.py @@ -605,6 +605,8 @@ class InworldTTSService(AudioContextWordTTSService): Returns: The websocket. """ + await super()._connect() + await self._connect_websocket() if self._websocket and not self._receive_task: self._receive_task = self.create_task(self._receive_task_handler(self._report_error)) @@ -615,6 +617,8 @@ class InworldTTSService(AudioContextWordTTSService): Returns: The websocket. """ + await super()._disconnect() + if self._receive_task: await self.cancel_task(self._receive_task) self._receive_task = None diff --git a/src/pipecat/services/lmnt/tts.py b/src/pipecat/services/lmnt/tts.py index b6a50aa9a..911d13923 100644 --- a/src/pipecat/services/lmnt/tts.py +++ b/src/pipecat/services/lmnt/tts.py @@ -175,6 +175,8 @@ class LmntTTSService(InterruptibleTTSService): async def _connect(self): """Connect to LMNT WebSocket and start receive task.""" + await super()._connect() + await self._connect_websocket() if self._websocket and not self._receive_task: @@ -182,6 +184,8 @@ class LmntTTSService(InterruptibleTTSService): async def _disconnect(self): """Disconnect from LMNT WebSocket and clean up tasks.""" + await super()._disconnect() + if self._receive_task: await self.cancel_task(self._receive_task) self._receive_task = None diff --git a/src/pipecat/services/neuphonic/tts.py b/src/pipecat/services/neuphonic/tts.py index 44e00dd09..2666c0cfc 100644 --- a/src/pipecat/services/neuphonic/tts.py +++ b/src/pipecat/services/neuphonic/tts.py @@ -237,6 +237,8 @@ class NeuphonicTTSService(InterruptibleTTSService): async def _connect(self): """Connect to Neuphonic WebSocket and start background tasks.""" + await super()._connect() + await self._connect_websocket() if self._websocket and not self._receive_task: @@ -247,6 +249,8 @@ class NeuphonicTTSService(InterruptibleTTSService): async def _disconnect(self): """Disconnect from Neuphonic WebSocket and clean up tasks.""" + await super()._disconnect() + if self._receive_task: await self.cancel_task(self._receive_task) self._receive_task = None diff --git a/src/pipecat/services/playht/tts.py b/src/pipecat/services/playht/tts.py index 1e9f83500..bc9dd4859 100644 --- a/src/pipecat/services/playht/tts.py +++ b/src/pipecat/services/playht/tts.py @@ -231,6 +231,8 @@ class PlayHTTTSService(InterruptibleTTSService): async def _connect(self): """Connect to PlayHT WebSocket and start receive task.""" + await super()._connect() + await self._connect_websocket() if self._websocket and not self._receive_task: @@ -238,6 +240,8 @@ class PlayHTTTSService(InterruptibleTTSService): async def _disconnect(self): """Disconnect from PlayHT WebSocket and clean up tasks.""" + await super()._disconnect() + if self._receive_task: await self.cancel_task(self._receive_task) self._receive_task = None diff --git a/src/pipecat/services/rime/tts.py b/src/pipecat/services/rime/tts.py index 6018730b6..b6fe25e0e 100644 --- a/src/pipecat/services/rime/tts.py +++ b/src/pipecat/services/rime/tts.py @@ -278,6 +278,8 @@ class RimeTTSService(AudioContextWordTTSService): async def _connect(self): """Establish websocket connection and start receive task.""" + await super()._connect() + await self._connect_websocket() if self._websocket and not self._receive_task: @@ -285,6 +287,8 @@ class RimeTTSService(AudioContextWordTTSService): async def _disconnect(self): """Close websocket connection and clean up tasks.""" + await super()._disconnect() + if self._receive_task: await self.cancel_task(self._receive_task) self._receive_task = None @@ -767,12 +771,16 @@ class RimeNonJsonTTSService(InterruptibleTTSService): async def _connect(self): """Establish WebSocket connection and start receive task.""" + await super()._connect() + await self._connect_websocket() if self._websocket and not self._receive_task: self._receive_task = self.create_task(self._receive_task_handler(self._report_error)) async def _disconnect(self): """Close WebSocket connection and clean up tasks.""" + await super()._disconnect() + if self._receive_task: await self.cancel_task(self._receive_task) self._receive_task = None diff --git a/src/pipecat/services/sarvam/tts.py b/src/pipecat/services/sarvam/tts.py index 2837b3e20..cef228b84 100644 --- a/src/pipecat/services/sarvam/tts.py +++ b/src/pipecat/services/sarvam/tts.py @@ -532,6 +532,8 @@ class SarvamTTSService(InterruptibleTTSService): async def _connect(self): """Connect to Sarvam WebSocket and start background tasks.""" + await super()._connect() + await self._connect_websocket() if self._websocket and not self._receive_task: @@ -544,6 +546,8 @@ class SarvamTTSService(InterruptibleTTSService): async def _disconnect(self): """Disconnect from Sarvam WebSocket and clean up tasks.""" + await super()._disconnect() + try: # First, set a flag to prevent new operations self._disconnecting = True diff --git a/src/pipecat/services/soniox/stt.py b/src/pipecat/services/soniox/stt.py index 34b4bc396..476ae0762 100644 --- a/src/pipecat/services/soniox/stt.py +++ b/src/pipecat/services/soniox/stt.py @@ -264,6 +264,8 @@ class SonioxSTTService(WebsocketSTTService): Establishes websocket connection and starts receive and keepalive tasks. """ + await super()._connect() + await self._connect_websocket() if self._websocket and not self._receive_task: @@ -277,6 +279,8 @@ class SonioxSTTService(WebsocketSTTService): Cleans up tasks and closes websocket connection. """ + await super()._disconnect() + if self._keepalive_task: await self.cancel_task(self._keepalive_task) self._keepalive_task = None diff --git a/src/pipecat/services/websocket_service.py b/src/pipecat/services/websocket_service.py index e9b93af65..e86dee73f 100644 --- a/src/pipecat/services/websocket_service.py +++ b/src/pipecat/services/websocket_service.py @@ -36,7 +36,8 @@ class WebsocketService(ABC): """ self._websocket: Optional[websockets.WebSocketClientProtocol] = None self._reconnect_on_error = reconnect_on_error - self._reconnect_in_progress: bool = False # Add this flag + self._reconnect_in_progress: bool = False + self._disconnecting: bool = False async def _verify_connection(self) -> bool: """Verify the websocket connection is active and responsive. @@ -120,6 +121,39 @@ class WebsocketService(ABC): else: logger.error(f"{self} send failed; unable to reconnect") + async def _maybe_try_reconnect( + self, + error: Exception, + error_message: str, + report_error: Callable[[ErrorFrame], Awaitable[None]], + ) -> bool: + """Check if reconnection should be attempted and try if appropriate. + + Args: + error: The exception that occurred. + error_message: Human-readable error message for logging. + report_error: Callback function to report connection errors. + + Returns: + True if should continue the receive loop, False if should break. + """ + # Don't reconnect if we're intentionally disconnecting + if self._disconnecting: + logger.warning(f"{self} error during disconnect: {error}") + return False + + # Log the error + logger.warning(error_message) + + # Try to reconnect if enabled + if self._reconnect_on_error: + success = await self._try_reconnect(report_error=report_error) + return success + else: + # Reconnection disabled + await report_error(ErrorFrame(error_message)) + return False + async def _receive_task_handler(self, report_error: Callable[[ErrorFrame], Awaitable[None]]): """Handle websocket message receiving with automatic retry logic. @@ -138,38 +172,38 @@ class WebsocketService(ABC): logger.debug(f"{self} connection closed normally: {e}") break except ConnectionClosedError as e: - # Error closure, don't retry - logger.warning(f"{self} connection closed, but with an error: {e}") - break + # Connection closed with error (e.g., no close frame received/sent) + # This often indicates network issues, server problems, or abrupt disconnection + message = f"{self} connection closed, but with an error: {e}" + should_continue = await self._maybe_try_reconnect(e, message, report_error) + if not should_continue: + break except Exception as e: + # General error during message receiving message = f"{self} error receiving messages: {e}" - logger.error(message) - - if self._reconnect_on_error: - success = await self._try_reconnect(report_error=report_error) - if not success: - break - else: - await report_error(ErrorFrame(message)) + should_continue = await self._maybe_try_reconnect(e, message, report_error) + if not should_continue: break - @abstractmethod async def _connect(self): - """Connect to the service. + """Connect to the service and reset disconnecting flag. - Implement service-specific connection logic including websocket connection - via _connect_websocket() and any additional setup required. + Manages the disconnecting flag to enable reconnection. Subclasses should + call super()._connect() first, then implement their specific connection + logic including websocket connection via _connect_websocket() and any + additional setup required. """ - pass + self._disconnecting = False - @abstractmethod async def _disconnect(self): - """Disconnect from the service. + """Disconnect from the service and set disconnecting flag. - Implement service-specific disconnection logic including websocket + Manages the disconnecting flag to prevent reconnection during intentional + disconnect. Subclasses should call super()._disconnect() first, then + implement their specific disconnection logic including websocket disconnection via _disconnect_websocket() and any cleanup required. """ - pass + self._disconnecting = True @abstractmethod async def _connect_websocket(self):