Merge pull request #4311 from pipecat-ai/filipi/reconnect_websocket
New approach to reconnect STT services after updating settings.
This commit is contained in:
1
changelog/4311.changed.md
Normal file
1
changelog/4311.changed.md
Normal file
@@ -0,0 +1 @@
|
||||
- STT services now reconnect safely when settings change: reconnection is deferred until the current user turn ends (i.e., until `UserStoppedSpeakingFrame` is received) rather than interrupting an active speech session. Audio frames received while the reconnect is in progress are buffered and replayed once the new connection is ready. `CartesiaSTTService` and `DeepgramSTTService` both use this new behavior.
|
||||
1
changelog/4311.fixed.md
Normal file
1
changelog/4311.fixed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Fixed audio loss and potential errors when STT settings were updated mid-speech. Previously, `CartesiaSTTService` and `DeepgramSTTService` would immediately disconnect and reconnect when settings changed, dropping any in-flight audio. Reconnection is now deferred until the user stops speaking, and audio arriving during the reconnect window is buffered and replayed.
|
||||
@@ -323,9 +323,7 @@ class CartesiaSTTService(WebsocketSTTService):
|
||||
if not changed:
|
||||
return changed
|
||||
|
||||
if self._websocket:
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
await self._request_reconnect()
|
||||
|
||||
return changed
|
||||
|
||||
|
||||
@@ -443,6 +443,7 @@ class DeepgramSTTService(STTService):
|
||||
|
||||
self._connection = None
|
||||
self._connection_task = None
|
||||
self._connection_ready = asyncio.Event()
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
@@ -452,6 +453,24 @@ class DeepgramSTTService(STTService):
|
||||
"""
|
||||
return True
|
||||
|
||||
async def _do_reconnect(self):
    """Disconnect and reconnect to Deepgram, waiting until ready.

    Called by ``STTService._reconnect()`` inside the reconnecting guard.
    Unlike ``WebsocketSTTService``, Deepgram's ``_connect()`` only
    launches a background task — the actual WebSocket handshake happens
    asynchronously. This method waits for ``_connection_ready`` to be set
    before returning so that buffered audio frames are replayed only after
    the new connection can accept them.

    Raises:
        asyncio.TimeoutError: If the connection is not established within
            5 seconds.
    """
    await self._disconnect()
    await self._connect()
    # ``_connect()`` returns before the handshake completes, so block here
    # until the connection signals readiness (bounded at 5 seconds so a dead
    # endpoint cannot stall the reconnect cycle forever).
    await asyncio.wait_for(self._connection_ready.wait(), timeout=5.0)
|
||||
|
||||
async def _update_settings(self, delta: STTSettings) -> dict[str, Any]:
|
||||
"""Apply a settings delta and reconnect if anything changed."""
|
||||
changed = await super()._update_settings(delta)
|
||||
@@ -463,9 +482,7 @@ class DeepgramSTTService(STTService):
|
||||
if isinstance(self._settings, self.Settings):
|
||||
self._settings._sync_extra_to_fields()
|
||||
|
||||
if self._connection:
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
await self._request_reconnect()
|
||||
|
||||
return changed
|
||||
|
||||
@@ -581,8 +598,11 @@ class DeepgramSTTService(STTService):
|
||||
return
|
||||
|
||||
logger.debug("Disconnecting from Deepgram")
|
||||
# Clear self._connection first to prevent run_stt from sending audio
|
||||
# during the close handshake, then close gracefully on the saved ref.
|
||||
# Clear _connection and _connection_ready first to prevent run_stt
|
||||
# from sending audio during the close handshake, and to ensure any
|
||||
# concurrent _do_reconnect() waiter sees a clean state before the
|
||||
# new connection is established.
|
||||
self._connection_ready.clear()
|
||||
connection = self._connection
|
||||
self._connection = None
|
||||
|
||||
@@ -603,6 +623,7 @@ class DeepgramSTTService(STTService):
|
||||
try:
|
||||
async with self._client.listen.v1.connect(**connect_kwargs) as connection:
|
||||
self._connection = connection
|
||||
self._connection_ready.set()
|
||||
connection.on(EventType.MESSAGE, self._on_message)
|
||||
connection.on(EventType.ERROR, self._on_error)
|
||||
|
||||
@@ -611,16 +632,13 @@ class DeepgramSTTService(STTService):
|
||||
keepalive_task = self.create_task(
|
||||
self._keepalive_handler(), f"{self}::keepalive"
|
||||
)
|
||||
try:
|
||||
await connection.start_listening()
|
||||
finally:
|
||||
await self.cancel_task(keepalive_task)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
await connection.start_listening()
|
||||
except Exception as e:
|
||||
logger.warning(f"{self}: Connection lost, will retry: {e}")
|
||||
finally:
|
||||
self._connection_ready.clear()
|
||||
self._connection = None
|
||||
await self.cancel_task(keepalive_task)
|
||||
|
||||
async def _keepalive_handler(self):
|
||||
"""Periodically send KeepAlive frames to prevent server-side timeout.
|
||||
|
||||
@@ -28,6 +28,7 @@ from pipecat.frames.frames import (
|
||||
STTMuteFrame,
|
||||
STTUpdateSettingsFrame,
|
||||
TranscriptionFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
@@ -163,6 +164,16 @@ class STTService(AIService):
|
||||
self._keepalive_task: Optional[asyncio.Task] = None
|
||||
self._last_audio_time: float = 0
|
||||
|
||||
# VAD-aware reconnect state
|
||||
# Whether it is safe to reconnect right now (False while the user is speaking).
|
||||
self._can_reconnect: bool = True
|
||||
# Whether a reconnect has been requested but deferred until speaking ends.
|
||||
self._need_reconnect: bool = False
|
||||
# Whether a reconnect cycle is currently in progress.
|
||||
self._reconnecting: bool = False
|
||||
# Audio frames received while _reconnecting is True, replayed after reconnect.
|
||||
self._reconnect_audio_buffer: list[tuple[AudioRawFrame, FrameDirection]] = []
|
||||
|
||||
self._register_event_handler("on_connected")
|
||||
self._register_event_handler("on_disconnected")
|
||||
self._register_event_handler("on_connection_error")
|
||||
@@ -290,6 +301,7 @@ class STTService(AIService):
|
||||
await super().cleanup()
|
||||
await self._cancel_ttfb_timeout()
|
||||
await self._cancel_keepalive_task()
|
||||
self._reconnect_audio_buffer.clear()
|
||||
|
||||
async def _update_settings(self, delta: STTSettings) -> dict[str, Any]:
|
||||
"""Apply an STT settings delta.
|
||||
@@ -331,15 +343,19 @@ class STTService(AIService):
|
||||
async def process_audio_frame(self, frame: AudioRawFrame, direction: FrameDirection):
|
||||
"""Process an audio frame for speech recognition.
|
||||
|
||||
If the service is muted, this method does nothing. Otherwise, it
|
||||
processes the audio frame and runs speech-to-text on it, yielding
|
||||
transcription results. If the frame has a user_id, it is stored
|
||||
for later use in transcription.
|
||||
If a reconnect is in progress, the frame is buffered and replayed
|
||||
once the connection is restored. If the service is muted, the frame
|
||||
is dropped. Otherwise the frame is sent to the STT service and, if
|
||||
a user_id is present, it is stored for use in transcription results.
|
||||
|
||||
Args:
|
||||
frame: The audio frame to process.
|
||||
direction: The direction of frame processing.
|
||||
"""
|
||||
if self._reconnecting:
|
||||
self._reconnect_audio_buffer.append((frame, direction))
|
||||
return
|
||||
|
||||
if self._muted:
|
||||
return
|
||||
|
||||
@@ -390,6 +406,9 @@ class STTService(AIService):
|
||||
elif isinstance(frame, VADUserStoppedSpeakingFrame):
|
||||
await self._handle_vad_user_stopped_speaking(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, UserStoppedSpeakingFrame):
|
||||
await self._handle_user_stopped_speaking(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, STTUpdateSettingsFrame):
|
||||
if frame.service is not None and frame.service is not self:
|
||||
await self.push_frame(frame, direction)
|
||||
@@ -483,10 +502,25 @@ class STTService(AIService):
|
||||
"""
|
||||
await self._reset_stt_ttfb_state()
|
||||
self._user_speaking = True
|
||||
self._can_reconnect = False
|
||||
self._finalize_requested = False
|
||||
self._finalize_pending = False
|
||||
self._last_transcript_time = 0
|
||||
|
||||
async def _handle_user_stopped_speaking(self, frame: UserStoppedSpeakingFrame):
    """React to the end of the user's turn.

    Marks reconnection as safe again and, when a reconnect was requested
    while the user was still speaking, performs that deferred reconnect now.

    Args:
        frame: The user stopped speaking frame.
    """
    # The turn is over, so tearing down the connection no longer risks
    # cutting off in-flight speech.
    self._can_reconnect = True

    if not self._need_reconnect:
        return

    await self._reconnect()
|
||||
|
||||
async def _handle_vad_user_stopped_speaking(self, frame: VADUserStoppedSpeakingFrame):
|
||||
"""Handle VAD user stopped speaking frame.
|
||||
|
||||
@@ -546,6 +580,58 @@ class STTService(AIService):
|
||||
await self.cancel_task(self._keepalive_task)
|
||||
self._keepalive_task = None
|
||||
|
||||
async def _reconnect(self):
    """Perform a full reconnect cycle with audio buffering.

    Sets ``_reconnecting`` so incoming audio frames are buffered rather than
    sent to a dead connection. Delegates the actual connection reset to
    ``_do_reconnect()``. After the new connection is established all buffered
    frames are replayed. On failure the error is reported via ``push_error``
    and the ``on_connection_error`` event handler, and any frames buffered
    during the failed attempt are discarded instead of being retained.
    """
    logger.info(f"{self} reconnecting...")
    # Start from an empty buffer so frames from an earlier cycle are never replayed.
    self._reconnect_audio_buffer.clear()
    self._reconnecting = True
    self._need_reconnect = False
    try:
        await self._do_reconnect()
    except Exception as e:
        # There is no working connection to replay into; drop the frames
        # captured during the failed attempt so they do not linger in memory.
        self._reconnect_audio_buffer.clear()
        logger.error(f"{self} reconnect failed: {e}")
        await self._call_event_handler("on_connection_error", str(e))
        await self.push_error(f"{self} reconnect failed: {e}", exception=e)
        return
    finally:
        # Always leave buffering mode, including on cancellation.
        self._reconnecting = False

    # Replay audio frames that arrived while the connection was down. Work on
    # a snapshot and empty the buffer first so it is left clean even if a
    # replayed frame raises.
    pending = list(self._reconnect_audio_buffer)
    self._reconnect_audio_buffer.clear()
    for buffered_frame, buffered_direction in pending:
        await self.process_audio_frame(buffered_frame, buffered_direction)
|
||||
|
||||
async def _do_reconnect(self):
    """Reset the underlying connection in a service-specific way.

    Invoked from ``_reconnect()`` while the reconnecting guard is active.
    This base implementation intentionally does nothing; subclasses that
    support explicit reconnection (e.g. ``WebsocketSTTService``) override
    it to tear down and re-establish their connection.
    """
|
||||
|
||||
async def _request_reconnect(self):
    """Reconnect now when it is safe, otherwise defer until the turn ends.

    While the user is speaking the service is actively receiving audio, so
    reconnecting is unsafe. In that case the request is only recorded and
    the reconnect fires once ``UserStoppedSpeakingFrame`` is received.
    """
    logger.debug(f"{self} requesting to reconnect!")

    if not self._can_reconnect:
        # Remember the request; it is honoured when the user stops speaking.
        self._need_reconnect = True
        return

    await self._reconnect()
|
||||
|
||||
async def _keepalive_task_handler(self):
|
||||
"""Send periodic silent audio to prevent the server from closing the connection.
|
||||
|
||||
@@ -737,6 +823,16 @@ class WebsocketSTTService(STTService, WebsocketService):
|
||||
await super()._disconnect()
|
||||
await self._cancel_keepalive_task()
|
||||
|
||||
async def _do_reconnect(self):
    """Tear down the websocket and bring up a fresh one.

    Invoked from ``STTService._reconnect()`` while the reconnecting guard
    is active. Keepalive management is left to ``_connect`` and
    ``_disconnect``, so nothing extra is needed here.
    """
    # Order matters: the old socket must be closed before a new one is opened.
    await self._disconnect()
    await self._connect()
|
||||
|
||||
async def _reconnect_websocket(self, attempt_number: int) -> bool:
|
||||
"""Reconnect and restart keepalive task.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user