Track websocket disconnecting status to improve error handling
This commit is contained in:
@@ -198,6 +198,8 @@ class AssemblyAISTTService(WebsocketSTTService):
|
||||
|
||||
Establishes websocket connection and starts receive task.
|
||||
"""
|
||||
await super()._connect()
|
||||
|
||||
await self._connect_websocket()
|
||||
|
||||
if self._websocket and not self._receive_task:
|
||||
@@ -208,6 +210,8 @@ class AssemblyAISTTService(WebsocketSTTService):
|
||||
|
||||
Sends termination message, waits for acknowledgment, and cleans up.
|
||||
"""
|
||||
await super()._disconnect()
|
||||
|
||||
if not self._connected or not self._websocket:
|
||||
return
|
||||
|
||||
|
||||
@@ -201,6 +201,8 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
await self._disconnect()
|
||||
|
||||
async def _connect(self):
|
||||
await super()._connect()
|
||||
|
||||
await self._connect_websocket()
|
||||
|
||||
if self._websocket and not self._receive_task:
|
||||
@@ -210,6 +212,8 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
self._keepalive_task = self.create_task(self._keepalive_task_handler())
|
||||
|
||||
async def _disconnect(self):
|
||||
await super()._disconnect()
|
||||
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
|
||||
@@ -170,6 +170,8 @@ class AWSTranscribeSTTService(WebsocketSTTService):
|
||||
|
||||
Establishes websocket connection and starts receive task.
|
||||
"""
|
||||
await super()._connect()
|
||||
|
||||
await self._connect_websocket()
|
||||
|
||||
if self._websocket and not self._receive_task:
|
||||
@@ -180,6 +182,8 @@ class AWSTranscribeSTTService(WebsocketSTTService):
|
||||
|
||||
Sends end-stream message and cleans up.
|
||||
"""
|
||||
await super()._disconnect()
|
||||
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
|
||||
@@ -245,12 +245,16 @@ class CartesiaSTTService(WebsocketSTTService):
|
||||
yield None
|
||||
|
||||
async def _connect(self):
|
||||
await super()._connect()
|
||||
|
||||
await self._connect_websocket()
|
||||
|
||||
if self._websocket and not self._receive_task:
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
|
||||
|
||||
async def _disconnect(self):
|
||||
await super()._disconnect()
|
||||
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
|
||||
@@ -483,12 +483,16 @@ class CartesiaTTSService(AudioContextWordTTSService):
|
||||
await self._disconnect()
|
||||
|
||||
async def _connect(self):
|
||||
await super()._connect()
|
||||
|
||||
await self._connect_websocket()
|
||||
|
||||
if self._websocket and not self._receive_task:
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
|
||||
|
||||
async def _disconnect(self):
|
||||
await super()._disconnect()
|
||||
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
|
||||
@@ -194,6 +194,8 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
Establishes the WebSocket connection to the Deepgram Flux API and starts
|
||||
the background task for receiving transcription results.
|
||||
"""
|
||||
await super()._connect()
|
||||
|
||||
await self._connect_websocket()
|
||||
|
||||
async def _disconnect(self):
|
||||
@@ -202,6 +204,8 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
Gracefully disconnects from the Deepgram Flux API, cancels background tasks,
|
||||
and cleans up resources to prevent memory leaks.
|
||||
"""
|
||||
await super()._disconnect()
|
||||
|
||||
try:
|
||||
await self._disconnect_websocket()
|
||||
except Exception as e:
|
||||
|
||||
@@ -147,6 +147,8 @@ class DeepgramTTSService(WebsocketTTSService):
|
||||
|
||||
async def _connect(self):
|
||||
"""Connect to Deepgram WebSocket and start receive task."""
|
||||
await super()._connect()
|
||||
|
||||
await self._connect_websocket()
|
||||
|
||||
if self._websocket and not self._receive_task:
|
||||
@@ -154,6 +156,8 @@ class DeepgramTTSService(WebsocketTTSService):
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Disconnect from Deepgram WebSocket and clean up tasks."""
|
||||
await super()._disconnect()
|
||||
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
|
||||
@@ -605,6 +605,8 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
|
||||
async def _connect(self):
|
||||
"""Establish WebSocket connection to ElevenLabs Realtime STT."""
|
||||
await super()._connect()
|
||||
|
||||
await self._connect_websocket()
|
||||
|
||||
if self._websocket and not self._receive_task:
|
||||
@@ -612,6 +614,8 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Close WebSocket connection and cleanup tasks."""
|
||||
await super()._disconnect()
|
||||
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
|
||||
@@ -478,6 +478,8 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
await self.add_word_timestamps([("Reset", 0)])
|
||||
|
||||
async def _connect(self):
|
||||
await super()._connect()
|
||||
|
||||
await self._connect_websocket()
|
||||
|
||||
if self._websocket and not self._receive_task:
|
||||
@@ -487,6 +489,8 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
self._keepalive_task = self.create_task(self._keepalive_task_handler())
|
||||
|
||||
async def _disconnect(self):
|
||||
await super()._disconnect()
|
||||
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
|
||||
@@ -199,12 +199,16 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
await self._disconnect()
|
||||
|
||||
async def _connect(self):
|
||||
await super()._connect()
|
||||
|
||||
await self._connect_websocket()
|
||||
|
||||
if self._websocket and not self._receive_task:
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
|
||||
|
||||
async def _disconnect(self):
|
||||
await super()._disconnect()
|
||||
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
|
||||
@@ -404,6 +404,8 @@ class GladiaSTTService(WebsocketSTTService):
|
||||
|
||||
Initializes the session if needed and establishes websocket connection.
|
||||
"""
|
||||
await super()._connect()
|
||||
|
||||
# Initialize session if needed
|
||||
if not self._session_url:
|
||||
settings = self._prepare_settings()
|
||||
@@ -425,6 +427,8 @@ class GladiaSTTService(WebsocketSTTService):
|
||||
|
||||
Cleans up tasks and closes websocket connection.
|
||||
"""
|
||||
await super()._disconnect()
|
||||
|
||||
self._connection_active = False
|
||||
|
||||
if self._keepalive_task:
|
||||
|
||||
@@ -141,6 +141,8 @@ class GradiumSTTService(WebsocketSTTService):
|
||||
pass
|
||||
|
||||
async def _connect(self):
|
||||
await super()._connect()
|
||||
|
||||
await self._connect_websocket()
|
||||
|
||||
if self._websocket and not self._receive_task:
|
||||
@@ -179,6 +181,8 @@ class GradiumSTTService(WebsocketSTTService):
|
||||
raise
|
||||
|
||||
async def _disconnect(self):
|
||||
await super()._disconnect()
|
||||
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
|
||||
@@ -157,6 +157,8 @@ class GradiumTTSService(InterruptibleWordTTSService):
|
||||
|
||||
async def _connect(self):
|
||||
"""Establish websocket connection and start receive task."""
|
||||
await super()._connect()
|
||||
|
||||
logger.debug(f"{self}: connecting")
|
||||
|
||||
# If the server disconnected, cancel the receive-task so that it can be reset below.
|
||||
@@ -173,6 +175,8 @@ class GradiumTTSService(InterruptibleWordTTSService):
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Close websocket connection and clean up tasks."""
|
||||
await super()._disconnect()
|
||||
|
||||
logger.debug(f"{self}: disconnecting")
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
|
||||
@@ -605,6 +605,8 @@ class InworldTTSService(AudioContextWordTTSService):
|
||||
Returns:
|
||||
The websocket.
|
||||
"""
|
||||
await super()._connect()
|
||||
|
||||
await self._connect_websocket()
|
||||
if self._websocket and not self._receive_task:
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
|
||||
@@ -615,6 +617,8 @@ class InworldTTSService(AudioContextWordTTSService):
|
||||
Returns:
|
||||
The websocket.
|
||||
"""
|
||||
await super()._disconnect()
|
||||
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
|
||||
@@ -175,6 +175,8 @@ class LmntTTSService(InterruptibleTTSService):
|
||||
|
||||
async def _connect(self):
|
||||
"""Connect to LMNT WebSocket and start receive task."""
|
||||
await super()._connect()
|
||||
|
||||
await self._connect_websocket()
|
||||
|
||||
if self._websocket and not self._receive_task:
|
||||
@@ -182,6 +184,8 @@ class LmntTTSService(InterruptibleTTSService):
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Disconnect from LMNT WebSocket and clean up tasks."""
|
||||
await super()._disconnect()
|
||||
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
|
||||
@@ -237,6 +237,8 @@ class NeuphonicTTSService(InterruptibleTTSService):
|
||||
|
||||
async def _connect(self):
|
||||
"""Connect to Neuphonic WebSocket and start background tasks."""
|
||||
await super()._connect()
|
||||
|
||||
await self._connect_websocket()
|
||||
|
||||
if self._websocket and not self._receive_task:
|
||||
@@ -247,6 +249,8 @@ class NeuphonicTTSService(InterruptibleTTSService):
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Disconnect from Neuphonic WebSocket and clean up tasks."""
|
||||
await super()._disconnect()
|
||||
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
|
||||
@@ -231,6 +231,8 @@ class PlayHTTTSService(InterruptibleTTSService):
|
||||
|
||||
async def _connect(self):
|
||||
"""Connect to PlayHT WebSocket and start receive task."""
|
||||
await super()._connect()
|
||||
|
||||
await self._connect_websocket()
|
||||
|
||||
if self._websocket and not self._receive_task:
|
||||
@@ -238,6 +240,8 @@ class PlayHTTTSService(InterruptibleTTSService):
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Disconnect from PlayHT WebSocket and clean up tasks."""
|
||||
await super()._disconnect()
|
||||
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
|
||||
@@ -278,6 +278,8 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
|
||||
async def _connect(self):
|
||||
"""Establish websocket connection and start receive task."""
|
||||
await super()._connect()
|
||||
|
||||
await self._connect_websocket()
|
||||
|
||||
if self._websocket and not self._receive_task:
|
||||
@@ -285,6 +287,8 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Close websocket connection and clean up tasks."""
|
||||
await super()._disconnect()
|
||||
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
@@ -767,12 +771,16 @@ class RimeNonJsonTTSService(InterruptibleTTSService):
|
||||
|
||||
async def _connect(self):
|
||||
"""Establish WebSocket connection and start receive task."""
|
||||
await super()._connect()
|
||||
|
||||
await self._connect_websocket()
|
||||
if self._websocket and not self._receive_task:
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Close WebSocket connection and clean up tasks."""
|
||||
await super()._disconnect()
|
||||
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
|
||||
@@ -532,6 +532,8 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
|
||||
async def _connect(self):
|
||||
"""Connect to Sarvam WebSocket and start background tasks."""
|
||||
await super()._connect()
|
||||
|
||||
await self._connect_websocket()
|
||||
|
||||
if self._websocket and not self._receive_task:
|
||||
@@ -544,6 +546,8 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Disconnect from Sarvam WebSocket and clean up tasks."""
|
||||
await super()._disconnect()
|
||||
|
||||
try:
|
||||
# First, set a flag to prevent new operations
|
||||
self._disconnecting = True
|
||||
|
||||
@@ -264,6 +264,8 @@ class SonioxSTTService(WebsocketSTTService):
|
||||
|
||||
Establishes websocket connection and starts receive and keepalive tasks.
|
||||
"""
|
||||
await super()._connect()
|
||||
|
||||
await self._connect_websocket()
|
||||
|
||||
if self._websocket and not self._receive_task:
|
||||
@@ -277,6 +279,8 @@ class SonioxSTTService(WebsocketSTTService):
|
||||
|
||||
Cleans up tasks and closes websocket connection.
|
||||
"""
|
||||
await super()._disconnect()
|
||||
|
||||
if self._keepalive_task:
|
||||
await self.cancel_task(self._keepalive_task)
|
||||
self._keepalive_task = None
|
||||
|
||||
@@ -36,7 +36,8 @@ class WebsocketService(ABC):
|
||||
"""
|
||||
self._websocket: Optional[websockets.WebSocketClientProtocol] = None
|
||||
self._reconnect_on_error = reconnect_on_error
|
||||
self._reconnect_in_progress: bool = False # Add this flag
|
||||
self._reconnect_in_progress: bool = False
|
||||
self._disconnecting: bool = False
|
||||
|
||||
async def _verify_connection(self) -> bool:
|
||||
"""Verify the websocket connection is active and responsive.
|
||||
@@ -138,6 +139,11 @@ class WebsocketService(ABC):
|
||||
logger.debug(f"{self} connection closed normally: {e}")
|
||||
break
|
||||
except ConnectionClosedError as e:
|
||||
# Don't reconnect if we're intentionally disconnecting
|
||||
if self._disconnecting:
|
||||
logger.warning(f"{self} connection closed with an error during disconnect: {e}")
|
||||
break
|
||||
|
||||
# Connection closed with error (e.g., no close frame received/sent)
|
||||
# This often indicates network issues, server problems, or abrupt disconnection
|
||||
message = f"{self} connection closed, but with an error: {e}"
|
||||
@@ -162,23 +168,25 @@ class WebsocketService(ABC):
|
||||
await report_error(ErrorFrame(message))
|
||||
break
|
||||
|
||||
@abstractmethod
|
||||
async def _connect(self):
|
||||
"""Connect to the service.
|
||||
"""Connect to the service and reset disconnecting flag.
|
||||
|
||||
Implement service-specific connection logic including websocket connection
|
||||
via _connect_websocket() and any additional setup required.
|
||||
Manages the disconnecting flag to enable reconnection. Subclasses should
|
||||
call super()._connect() first, then implement their specific connection
|
||||
logic including websocket connection via _connect_websocket() and any
|
||||
additional setup required.
|
||||
"""
|
||||
pass
|
||||
self._disconnecting = False
|
||||
|
||||
@abstractmethod
|
||||
async def _disconnect(self):
|
||||
"""Disconnect from the service.
|
||||
"""Disconnect from the service and set disconnecting flag.
|
||||
|
||||
Implement service-specific disconnection logic including websocket
|
||||
Manages the disconnecting flag to prevent reconnection during intentional
|
||||
disconnect. Subclasses should call super()._disconnect() first, then
|
||||
implement their specific disconnection logic including websocket
|
||||
disconnection via _disconnect_websocket() and any cleanup required.
|
||||
"""
|
||||
pass
|
||||
self._disconnecting = True
|
||||
|
||||
@abstractmethod
|
||||
async def _connect_websocket(self):
|
||||
|
||||
Reference in New Issue
Block a user