Merge pull request #4311 from pipecat-ai/filipi/reconnect_websocket

New approach to reconnect STT services after updating settings.
This commit is contained in:
Filipi da Silva Fuchter
2026-04-15 14:39:24 -03:00
committed by GitHub
5 changed files with 132 additions and 18 deletions

View File

@@ -0,0 +1 @@
- STT services now reconnect safely when settings change: reconnection is deferred until the current user turn ends (i.e., until `UserStoppedSpeakingFrame` is received) rather than interrupting an active speech session. Audio frames received while the reconnect is in progress are buffered and replayed once the new connection is ready. `CartesiaSTTService` and `DeepgramSTTService` both use this new behavior.

1
changelog/4311.fixed.md Normal file
View File

@@ -0,0 +1 @@
- Fixed audio loss and potential errors when STT settings were updated mid-speech. Previously, `CartesiaSTTService` and `DeepgramSTTService` would immediately disconnect and reconnect when settings changed, dropping any in-flight audio. Reconnection is now deferred until the user stops speaking, and audio arriving during the reconnect window is buffered and replayed.

View File

@@ -323,9 +323,7 @@ class CartesiaSTTService(WebsocketSTTService):
if not changed:
return changed
if self._websocket:
await self._disconnect()
await self._connect()
await self._request_reconnect()
return changed

View File

@@ -443,6 +443,7 @@ class DeepgramSTTService(STTService):
self._connection = None
self._connection_task = None
self._connection_ready = asyncio.Event()
def can_generate_metrics(self) -> bool:
"""Check if this service can generate processing metrics.
@@ -452,6 +453,24 @@ class DeepgramSTTService(STTService):
"""
return True
async def _do_reconnect(self):
"""Disconnect and reconnect to Deepgram, waiting until ready.
Called by ``STTService._reconnect()`` inside the reconnecting guard.
Unlike ``WebsocketSTTService``, Deepgram's ``_connect()`` only
launches a background task — the actual WebSocket handshake happens
asynchronously. This method waits for ``_connection_ready`` to be set
before returning so that buffered audio frames are replayed only after
the new connection can accept them.
Raises:
asyncio.TimeoutError: If the connection is not established within
05 seconds.
"""
await self._disconnect()
await self._connect()
await asyncio.wait_for(self._connection_ready.wait(), timeout=5.0)
async def _update_settings(self, delta: STTSettings) -> dict[str, Any]:
"""Apply a settings delta and reconnect if anything changed."""
changed = await super()._update_settings(delta)
@@ -463,9 +482,7 @@ class DeepgramSTTService(STTService):
if isinstance(self._settings, self.Settings):
self._settings._sync_extra_to_fields()
if self._connection:
await self._disconnect()
await self._connect()
await self._request_reconnect()
return changed
@@ -581,8 +598,11 @@ class DeepgramSTTService(STTService):
return
logger.debug("Disconnecting from Deepgram")
# Clear self._connection first to prevent run_stt from sending audio
# during the close handshake, then close gracefully on the saved ref.
# Clear _connection and _connection_ready first to prevent run_stt
# from sending audio during the close handshake, and to ensure any
# concurrent _do_reconnect() waiter sees a clean state before the
# new connection is established.
self._connection_ready.clear()
connection = self._connection
self._connection = None
@@ -603,6 +623,7 @@ class DeepgramSTTService(STTService):
try:
async with self._client.listen.v1.connect(**connect_kwargs) as connection:
self._connection = connection
self._connection_ready.set()
connection.on(EventType.MESSAGE, self._on_message)
connection.on(EventType.ERROR, self._on_error)
@@ -611,16 +632,13 @@ class DeepgramSTTService(STTService):
keepalive_task = self.create_task(
self._keepalive_handler(), f"{self}::keepalive"
)
try:
await connection.start_listening()
finally:
await self.cancel_task(keepalive_task)
except asyncio.CancelledError:
raise
await connection.start_listening()
except Exception as e:
logger.warning(f"{self}: Connection lost, will retry: {e}")
finally:
self._connection_ready.clear()
self._connection = None
await self.cancel_task(keepalive_task)
async def _keepalive_handler(self):
"""Periodically send KeepAlive frames to prevent server-side timeout.

View File

@@ -28,6 +28,7 @@ from pipecat.frames.frames import (
STTMuteFrame,
STTUpdateSettingsFrame,
TranscriptionFrame,
UserStoppedSpeakingFrame,
VADUserStartedSpeakingFrame,
VADUserStoppedSpeakingFrame,
)
@@ -163,6 +164,16 @@ class STTService(AIService):
self._keepalive_task: Optional[asyncio.Task] = None
self._last_audio_time: float = 0
# VAD-aware reconnect state
# Whether it is safe to reconnect right now (False while the user is speaking).
self._can_reconnect: bool = True
# Whether a reconnect has been requested but deferred until speaking ends.
self._need_reconnect: bool = False
# Whether a reconnect cycle is currently in progress.
self._reconnecting: bool = False
# Audio frames received while _reconnecting is True, replayed after reconnect.
self._reconnect_audio_buffer: list[tuple[AudioRawFrame, FrameDirection]] = []
self._register_event_handler("on_connected")
self._register_event_handler("on_disconnected")
self._register_event_handler("on_connection_error")
@@ -290,6 +301,7 @@ class STTService(AIService):
await super().cleanup()
await self._cancel_ttfb_timeout()
await self._cancel_keepalive_task()
self._reconnect_audio_buffer.clear()
async def _update_settings(self, delta: STTSettings) -> dict[str, Any]:
"""Apply an STT settings delta.
@@ -331,15 +343,19 @@ class STTService(AIService):
async def process_audio_frame(self, frame: AudioRawFrame, direction: FrameDirection):
"""Process an audio frame for speech recognition.
If the service is muted, this method does nothing. Otherwise, it
processes the audio frame and runs speech-to-text on it, yielding
transcription results. If the frame has a user_id, it is stored
for later use in transcription.
If a reconnect is in progress, the frame is buffered and replayed
once the connection is restored. If the service is muted, the frame
is dropped. Otherwise the frame is sent to the STT service and, if
a user_id is present, it is stored for use in transcription results.
Args:
frame: The audio frame to process.
direction: The direction of frame processing.
"""
if self._reconnecting:
self._reconnect_audio_buffer.append((frame, direction))
return
if self._muted:
return
@@ -390,6 +406,9 @@ class STTService(AIService):
elif isinstance(frame, VADUserStoppedSpeakingFrame):
await self._handle_vad_user_stopped_speaking(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, UserStoppedSpeakingFrame):
await self._handle_user_stopped_speaking(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, STTUpdateSettingsFrame):
if frame.service is not None and frame.service is not self:
await self.push_frame(frame, direction)
@@ -483,10 +502,25 @@ class STTService(AIService):
"""
await self._reset_stt_ttfb_state()
self._user_speaking = True
self._can_reconnect = False
self._finalize_requested = False
self._finalize_pending = False
self._last_transcript_time = 0
async def _handle_user_stopped_speaking(self, frame: UserStoppedSpeakingFrame):
"""Handle user stopped speaking frame.
Called when the user's full turn has ended and the transcription has been
received. Re-enables reconnection and triggers any deferred reconnect that
was requested while the user was speaking.
Args:
frame: The user stopped speaking frame.
"""
self._can_reconnect = True
if self._need_reconnect:
await self._reconnect()
async def _handle_vad_user_stopped_speaking(self, frame: VADUserStoppedSpeakingFrame):
"""Handle VAD user stopped speaking frame.
@@ -546,6 +580,58 @@ class STTService(AIService):
await self.cancel_task(self._keepalive_task)
self._keepalive_task = None
async def _reconnect(self):
"""Perform a full reconnect cycle with audio buffering.
Sets ``_reconnecting`` so incoming audio frames are buffered rather than
sent to a dead connection. Delegates the actual connection reset to
``_do_reconnect()``. After the new connection is established all buffered
frames are replayed. On failure the error is reported via ``push_error``
and the ``on_connection_error`` event handler.
"""
logger.info(f"{self} reconnecting...")
self._reconnect_audio_buffer.clear()
self._reconnecting = True
self._need_reconnect = False
try:
await self._do_reconnect()
except Exception as e:
logger.error(f"{self} reconnect failed: {e}")
await self._call_event_handler("on_connection_error", str(e))
await self.push_error(f"{self} reconnect failed: {e}", exception=e)
return
finally:
self._reconnecting = False
# Replay audio frames that arrived while the connection was down.
for buffered_frame, buffered_direction in self._reconnect_audio_buffer:
await self.process_audio_frame(buffered_frame, buffered_direction)
self._reconnect_audio_buffer.clear()
async def _do_reconnect(self):
"""Perform the service-specific connection reset.
Called by ``_reconnect()`` inside the reconnecting guard. The default
implementation is a no-op. Subclasses that support explicit reconnection
(e.g. ``WebsocketSTTService``) should override this to tear down and
re-establish their connection.
"""
pass
async def _request_reconnect(self):
"""Reconnect immediately if safe, or defer until after the current user turn.
Reconnection is unsafe while the user is speaking because the service is
actively receiving audio. Calling this method while the user is speaking
schedules a reconnect that fires as soon as ``UserStoppedSpeakingFrame``
is received.
"""
logger.debug(f"{self} requesting to reconnect!")
if self._can_reconnect:
await self._reconnect()
else:
self._need_reconnect = True
async def _keepalive_task_handler(self):
"""Send periodic silent audio to prevent the server from closing the connection.
@@ -737,6 +823,16 @@ class WebsocketSTTService(STTService, WebsocketService):
await super()._disconnect()
await self._cancel_keepalive_task()
async def _do_reconnect(self):
"""Disconnect and reconnect the websocket.
Called by ``STTService._reconnect()`` inside the reconnecting guard.
Tears down the current websocket connection and re-establishes it.
Keepalive management is handled by ``_connect`` / ``_disconnect``.
"""
await self._disconnect()
await self._connect()
async def _reconnect_websocket(self, attempt_number: int) -> bool:
"""Reconnect and restart keepalive task.