Track websocket disconnecting status to improve error handling

This commit is contained in:
Mark Backman
2026-01-09 16:48:51 -05:00
parent 4fe0836cf9
commit 9c81acb159
21 changed files with 102 additions and 10 deletions

View File

@@ -198,6 +198,8 @@ class AssemblyAISTTService(WebsocketSTTService):
Establishes websocket connection and starts receive task.
"""
await super()._connect()
await self._connect_websocket()
if self._websocket and not self._receive_task:
@@ -208,6 +210,8 @@ class AssemblyAISTTService(WebsocketSTTService):
Sends termination message, waits for acknowledgment, and cleans up.
"""
await super()._disconnect()
if not self._connected or not self._websocket:
return

View File

@@ -201,6 +201,8 @@ class AsyncAITTSService(InterruptibleTTSService):
await self._disconnect()
async def _connect(self):
await super()._connect()
await self._connect_websocket()
if self._websocket and not self._receive_task:
@@ -210,6 +212,8 @@ class AsyncAITTSService(InterruptibleTTSService):
self._keepalive_task = self.create_task(self._keepalive_task_handler())
async def _disconnect(self):
await super()._disconnect()
if self._receive_task:
await self.cancel_task(self._receive_task)
self._receive_task = None

View File

@@ -170,6 +170,8 @@ class AWSTranscribeSTTService(WebsocketSTTService):
Establishes websocket connection and starts receive task.
"""
await super()._connect()
await self._connect_websocket()
if self._websocket and not self._receive_task:
@@ -180,6 +182,8 @@ class AWSTranscribeSTTService(WebsocketSTTService):
Sends end-stream message and cleans up.
"""
await super()._disconnect()
if self._receive_task:
await self.cancel_task(self._receive_task)
self._receive_task = None

View File

@@ -245,12 +245,16 @@ class CartesiaSTTService(WebsocketSTTService):
yield None
async def _connect(self):
await super()._connect()
await self._connect_websocket()
if self._websocket and not self._receive_task:
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
async def _disconnect(self):
await super()._disconnect()
if self._receive_task:
await self.cancel_task(self._receive_task)
self._receive_task = None

View File

@@ -483,12 +483,16 @@ class CartesiaTTSService(AudioContextWordTTSService):
await self._disconnect()
async def _connect(self):
await super()._connect()
await self._connect_websocket()
if self._websocket and not self._receive_task:
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
async def _disconnect(self):
await super()._disconnect()
if self._receive_task:
await self.cancel_task(self._receive_task)
self._receive_task = None

View File

@@ -194,6 +194,8 @@ class DeepgramFluxSTTService(WebsocketSTTService):
Establishes the WebSocket connection to the Deepgram Flux API and starts
the background task for receiving transcription results.
"""
await super()._connect()
await self._connect_websocket()
async def _disconnect(self):
@@ -202,6 +204,8 @@ class DeepgramFluxSTTService(WebsocketSTTService):
Gracefully disconnects from the Deepgram Flux API, cancels background tasks,
and cleans up resources to prevent memory leaks.
"""
await super()._disconnect()
try:
await self._disconnect_websocket()
except Exception as e:

View File

@@ -147,6 +147,8 @@ class DeepgramTTSService(WebsocketTTSService):
async def _connect(self):
"""Connect to Deepgram WebSocket and start receive task."""
await super()._connect()
await self._connect_websocket()
if self._websocket and not self._receive_task:
@@ -154,6 +156,8 @@ class DeepgramTTSService(WebsocketTTSService):
async def _disconnect(self):
"""Disconnect from Deepgram WebSocket and clean up tasks."""
await super()._disconnect()
if self._receive_task:
await self.cancel_task(self._receive_task)
self._receive_task = None

View File

@@ -605,6 +605,8 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
async def _connect(self):
"""Establish WebSocket connection to ElevenLabs Realtime STT."""
await super()._connect()
await self._connect_websocket()
if self._websocket and not self._receive_task:
@@ -612,6 +614,8 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
async def _disconnect(self):
"""Close WebSocket connection and cleanup tasks."""
await super()._disconnect()
if self._receive_task:
await self.cancel_task(self._receive_task)
self._receive_task = None

View File

@@ -478,6 +478,8 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
await self.add_word_timestamps([("Reset", 0)])
async def _connect(self):
await super()._connect()
await self._connect_websocket()
if self._websocket and not self._receive_task:
@@ -487,6 +489,8 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
self._keepalive_task = self.create_task(self._keepalive_task_handler())
async def _disconnect(self):
await super()._disconnect()
if self._receive_task:
await self.cancel_task(self._receive_task)
self._receive_task = None

View File

@@ -199,12 +199,16 @@ class FishAudioTTSService(InterruptibleTTSService):
await self._disconnect()
async def _connect(self):
await super()._connect()
await self._connect_websocket()
if self._websocket and not self._receive_task:
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
async def _disconnect(self):
await super()._disconnect()
if self._receive_task:
await self.cancel_task(self._receive_task)
self._receive_task = None

View File

@@ -404,6 +404,8 @@ class GladiaSTTService(WebsocketSTTService):
Initializes the session if needed and establishes websocket connection.
"""
await super()._connect()
# Initialize session if needed
if not self._session_url:
settings = self._prepare_settings()
@@ -425,6 +427,8 @@ class GladiaSTTService(WebsocketSTTService):
Cleans up tasks and closes websocket connection.
"""
await super()._disconnect()
self._connection_active = False
if self._keepalive_task:

View File

@@ -141,6 +141,8 @@ class GradiumSTTService(WebsocketSTTService):
pass
async def _connect(self):
await super()._connect()
await self._connect_websocket()
if self._websocket and not self._receive_task:
@@ -179,6 +181,8 @@ class GradiumSTTService(WebsocketSTTService):
raise
async def _disconnect(self):
await super()._disconnect()
if self._receive_task:
await self.cancel_task(self._receive_task)
self._receive_task = None

View File

@@ -157,6 +157,8 @@ class GradiumTTSService(InterruptibleWordTTSService):
async def _connect(self):
"""Establish websocket connection and start receive task."""
await super()._connect()
logger.debug(f"{self}: connecting")
# If the server disconnected, cancel the receive-task so that it can be reset below.
@@ -173,6 +175,8 @@ class GradiumTTSService(InterruptibleWordTTSService):
async def _disconnect(self):
"""Close websocket connection and clean up tasks."""
await super()._disconnect()
logger.debug(f"{self}: disconnecting")
if self._receive_task:
await self.cancel_task(self._receive_task)

View File

@@ -605,6 +605,8 @@ class InworldTTSService(AudioContextWordTTSService):
Returns:
The websocket.
"""
await super()._connect()
await self._connect_websocket()
if self._websocket and not self._receive_task:
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
@@ -615,6 +617,8 @@ class InworldTTSService(AudioContextWordTTSService):
Returns:
The websocket.
"""
await super()._disconnect()
if self._receive_task:
await self.cancel_task(self._receive_task)
self._receive_task = None

View File

@@ -175,6 +175,8 @@ class LmntTTSService(InterruptibleTTSService):
async def _connect(self):
"""Connect to LMNT WebSocket and start receive task."""
await super()._connect()
await self._connect_websocket()
if self._websocket and not self._receive_task:
@@ -182,6 +184,8 @@ class LmntTTSService(InterruptibleTTSService):
async def _disconnect(self):
"""Disconnect from LMNT WebSocket and clean up tasks."""
await super()._disconnect()
if self._receive_task:
await self.cancel_task(self._receive_task)
self._receive_task = None

View File

@@ -237,6 +237,8 @@ class NeuphonicTTSService(InterruptibleTTSService):
async def _connect(self):
"""Connect to Neuphonic WebSocket and start background tasks."""
await super()._connect()
await self._connect_websocket()
if self._websocket and not self._receive_task:
@@ -247,6 +249,8 @@ class NeuphonicTTSService(InterruptibleTTSService):
async def _disconnect(self):
"""Disconnect from Neuphonic WebSocket and clean up tasks."""
await super()._disconnect()
if self._receive_task:
await self.cancel_task(self._receive_task)
self._receive_task = None

View File

@@ -231,6 +231,8 @@ class PlayHTTTSService(InterruptibleTTSService):
async def _connect(self):
"""Connect to PlayHT WebSocket and start receive task."""
await super()._connect()
await self._connect_websocket()
if self._websocket and not self._receive_task:
@@ -238,6 +240,8 @@ class PlayHTTTSService(InterruptibleTTSService):
async def _disconnect(self):
"""Disconnect from PlayHT WebSocket and clean up tasks."""
await super()._disconnect()
if self._receive_task:
await self.cancel_task(self._receive_task)
self._receive_task = None

View File

@@ -278,6 +278,8 @@ class RimeTTSService(AudioContextWordTTSService):
async def _connect(self):
"""Establish websocket connection and start receive task."""
await super()._connect()
await self._connect_websocket()
if self._websocket and not self._receive_task:
@@ -285,6 +287,8 @@ class RimeTTSService(AudioContextWordTTSService):
async def _disconnect(self):
"""Close websocket connection and clean up tasks."""
await super()._disconnect()
if self._receive_task:
await self.cancel_task(self._receive_task)
self._receive_task = None
@@ -767,12 +771,16 @@ class RimeNonJsonTTSService(InterruptibleTTSService):
async def _connect(self):
"""Establish WebSocket connection and start receive task."""
await super()._connect()
await self._connect_websocket()
if self._websocket and not self._receive_task:
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
async def _disconnect(self):
"""Close WebSocket connection and clean up tasks."""
await super()._disconnect()
if self._receive_task:
await self.cancel_task(self._receive_task)
self._receive_task = None

View File

@@ -532,6 +532,8 @@ class SarvamTTSService(InterruptibleTTSService):
async def _connect(self):
"""Connect to Sarvam WebSocket and start background tasks."""
await super()._connect()
await self._connect_websocket()
if self._websocket and not self._receive_task:
@@ -544,6 +546,8 @@ class SarvamTTSService(InterruptibleTTSService):
async def _disconnect(self):
"""Disconnect from Sarvam WebSocket and clean up tasks."""
await super()._disconnect()
try:
# First, set a flag to prevent new operations
self._disconnecting = True

View File

@@ -264,6 +264,8 @@ class SonioxSTTService(WebsocketSTTService):
Establishes websocket connection and starts receive and keepalive tasks.
"""
await super()._connect()
await self._connect_websocket()
if self._websocket and not self._receive_task:
@@ -277,6 +279,8 @@ class SonioxSTTService(WebsocketSTTService):
Cleans up tasks and closes websocket connection.
"""
await super()._disconnect()
if self._keepalive_task:
await self.cancel_task(self._keepalive_task)
self._keepalive_task = None

View File

@@ -36,7 +36,8 @@ class WebsocketService(ABC):
"""
self._websocket: Optional[websockets.WebSocketClientProtocol] = None
self._reconnect_on_error = reconnect_on_error
self._reconnect_in_progress: bool = False # Add this flag
self._reconnect_in_progress: bool = False
self._disconnecting: bool = False
async def _verify_connection(self) -> bool:
"""Verify the websocket connection is active and responsive.
@@ -138,6 +139,11 @@ class WebsocketService(ABC):
logger.debug(f"{self} connection closed normally: {e}")
break
except ConnectionClosedError as e:
# Don't reconnect if we're intentionally disconnecting
if self._disconnecting:
logger.warning(f"{self} connection closed with an error during disconnect: {e}")
break
# Connection closed with error (e.g., no close frame received/sent)
# This often indicates network issues, server problems, or abrupt disconnection
message = f"{self} connection closed, but with an error: {e}"
@@ -162,23 +168,25 @@ class WebsocketService(ABC):
await report_error(ErrorFrame(message))
break
@abstractmethod
async def _connect(self):
"""Connect to the service.
"""Connect to the service and reset disconnecting flag.
Implement service-specific connection logic including websocket connection
via _connect_websocket() and any additional setup required.
Manages the disconnecting flag to enable reconnection. Subclasses should
call super()._connect() first, then implement their specific connection
logic including websocket connection via _connect_websocket() and any
additional setup required.
"""
pass
self._disconnecting = False
@abstractmethod
async def _disconnect(self):
"""Disconnect from the service.
"""Disconnect from the service and set disconnecting flag.
Implement service-specific disconnection logic including websocket
Manages the disconnecting flag to prevent reconnection during intentional
disconnect. Subclasses should call super()._disconnect() first, then
implement their specific disconnection logic including websocket
disconnection via _disconnect_websocket() and any cleanup required.
"""
pass
self._disconnecting = True
@abstractmethod
async def _connect_websocket(self):