Merge pull request #4514 from pipecat-ai/mb/websocket-stt-service-exception-handling
Align websocket STT connection failures
This commit is contained in:
1
changelog/4514.fixed.md
Normal file
1
changelog/4514.fixed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Fixed websocket STT connection setup failures so services clear stale websocket state and emit non-fatal error frames, allowing `ServiceSwitcher` failover to keep agents running.
|
||||
@@ -586,9 +586,9 @@ class AssemblyAISTTService(WebsocketSTTService):
|
||||
await self._call_event_handler("on_connected")
|
||||
logger.debug(f"{self} Connected to AssemblyAI WebSocket")
|
||||
except Exception as e:
|
||||
self._websocket = None
|
||||
self._connected = False
|
||||
await self.push_error(error_msg=f"Unable to connect to AssemblyAI: {e}", exception=e)
|
||||
raise
|
||||
|
||||
async def _disconnect_websocket(self):
|
||||
"""Close the websocket connection to AssemblyAI."""
|
||||
|
||||
@@ -339,10 +339,10 @@ class AWSTranscribeSTTService(WebsocketSTTService):
|
||||
await self._call_event_handler("on_connected")
|
||||
logger.info(f"{self} Successfully connected to AWS Transcribe")
|
||||
except Exception as e:
|
||||
self._websocket = None
|
||||
await self.push_error(
|
||||
error_msg=f"Unable to connect to AWS Transcribe: {e}", exception=e
|
||||
)
|
||||
raise
|
||||
|
||||
async def _disconnect_websocket(self):
|
||||
"""Close the websocket connection to AWS Transcribe."""
|
||||
|
||||
@@ -354,7 +354,8 @@ class CartesiaSTTService(WebsocketSTTService):
|
||||
self._websocket = await websocket_connect(ws_url, additional_headers=headers)
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
self._websocket = None
|
||||
await self.push_error(error_msg=f"Unable to connect to Cartesia: {e}", exception=e)
|
||||
|
||||
async def _disconnect_websocket(self):
|
||||
ws = self._websocket
|
||||
|
||||
@@ -823,6 +823,7 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
await self._call_event_handler("on_connected")
|
||||
logger.debug("Connected to ElevenLabs Realtime STT")
|
||||
except Exception as e:
|
||||
self._websocket = None
|
||||
await self.push_error(
|
||||
error_msg=f"Unable to connect to ElevenLabs Realtime STT: {e}", exception=e
|
||||
)
|
||||
|
||||
@@ -558,8 +558,9 @@ class GladiaSTTService(WebsocketSTTService):
|
||||
|
||||
logger.debug(f"{self} Connected to Gladia WebSocket")
|
||||
except Exception as e:
|
||||
self._websocket = None
|
||||
self._connection_active = False
|
||||
await self.push_error(error_msg=f"Unable to connect to Gladia: {e}", exception=e)
|
||||
raise
|
||||
|
||||
async def _disconnect_websocket(self):
|
||||
"""Close the websocket connection to Gladia."""
|
||||
|
||||
@@ -423,8 +423,8 @@ class GradiumSTTService(WebsocketSTTService):
|
||||
logger.debug("Connected to Gradium STT")
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
raise
|
||||
self._websocket = None
|
||||
await self.push_error(error_msg=f"Unable to connect to Gradium: {e}", exception=e)
|
||||
|
||||
async def _disconnect(self):
|
||||
await super()._disconnect()
|
||||
|
||||
@@ -537,8 +537,8 @@ class SonioxSTTService(WebsocketSTTService):
|
||||
await self._call_event_handler("on_connected")
|
||||
logger.debug("Connected to Soniox STT")
|
||||
except Exception as e:
|
||||
self._websocket = None
|
||||
await self.push_error(error_msg=f"Unable to connect to Soniox: {e}", exception=e)
|
||||
raise
|
||||
|
||||
async def _disconnect_websocket(self):
|
||||
"""Close the websocket connection to Soniox."""
|
||||
|
||||
@@ -76,7 +76,9 @@ class WebsocketService(ABC):
|
||||
logger.warning(f"{self} reconnecting (attempt: {attempt_number})")
|
||||
await self._disconnect_websocket()
|
||||
await self._connect_websocket()
|
||||
return await self._verify_connection()
|
||||
if not await self._verify_connection():
|
||||
raise ConnectionError(f"{self} websocket reconnection failed verification")
|
||||
return True
|
||||
|
||||
async def _try_reconnect(
|
||||
self,
|
||||
|
||||
@@ -293,8 +293,9 @@ class XAISTTService(WebsocketSTTService):
|
||||
await self._call_event_handler("on_connected")
|
||||
logger.debug(f"{self} connected to xAI STT WebSocket")
|
||||
except Exception as e:
|
||||
self._websocket = None
|
||||
self._session_ready.clear()
|
||||
await self.push_error(error_msg=f"Unable to connect to xAI STT: {e}", exception=e)
|
||||
raise
|
||||
|
||||
async def _disconnect_websocket(self):
|
||||
"""Close the WebSocket connection."""
|
||||
|
||||
45
tests/test_cartesia_stt.py
Normal file
45
tests/test_cartesia_stt.py
Normal file
@@ -0,0 +1,45 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from websockets.protocol import State
|
||||
|
||||
from pipecat.services.cartesia.stt import CartesiaSTTService
|
||||
|
||||
|
||||
class _FakeWebsocket:
|
||||
def __init__(self, *, state=State.OPEN, send_side_effect=None):
|
||||
self.state = state
|
||||
self.send = AsyncMock(side_effect=send_side_effect)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cartesia_connect_failure_clears_stale_websocket(monkeypatch):
|
||||
async def fake_websocket_connect(*args, **kwargs):
|
||||
raise RuntimeError("connection failed")
|
||||
|
||||
monkeypatch.setattr("pipecat.services.cartesia.stt.websocket_connect", fake_websocket_connect)
|
||||
|
||||
service = CartesiaSTTService(api_key="test-key", sample_rate=16000)
|
||||
service._websocket = _FakeWebsocket(state=State.CLOSED)
|
||||
|
||||
await service._connect_websocket()
|
||||
|
||||
assert service._websocket is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cartesia_run_stt_logs_send_failure_without_clearing_websocket():
|
||||
service = CartesiaSTTService(api_key="test-key", sample_rate=16000)
|
||||
websocket = _FakeWebsocket(send_side_effect=RuntimeError("websocket closed"))
|
||||
service._websocket = websocket
|
||||
|
||||
async for _ in service.run_stt(b"\x00" * 160):
|
||||
pass
|
||||
|
||||
assert service._websocket is websocket
|
||||
@@ -5,8 +5,10 @@
|
||||
#
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from websockets.protocol import State
|
||||
|
||||
from pipecat.frames.frames import TranscriptionFrame
|
||||
from pipecat.services.soniox.stt import END_TOKEN, SonioxSTTService, _language_from_tokens
|
||||
@@ -14,8 +16,10 @@ from pipecat.transcriptions.language import Language
|
||||
|
||||
|
||||
class _FakeWebsocket:
|
||||
def __init__(self, messages):
|
||||
def __init__(self, messages, *, state=State.OPEN, send_side_effect=None):
|
||||
self._messages = messages
|
||||
self.state = state
|
||||
self.send = AsyncMock(side_effect=send_side_effect)
|
||||
|
||||
def __aiter__(self):
|
||||
return self._iter_messages()
|
||||
@@ -25,6 +29,21 @@ class _FakeWebsocket:
|
||||
yield message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_failure_clears_stale_websocket_without_raising(monkeypatch):
|
||||
async def fake_websocket_connect(*args, **kwargs):
|
||||
raise RuntimeError("connection failed")
|
||||
|
||||
monkeypatch.setattr("pipecat.services.soniox.stt.websocket_connect", fake_websocket_connect)
|
||||
|
||||
service = SonioxSTTService(api_key="test-key")
|
||||
service._websocket = _FakeWebsocket([], state=State.CLOSED)
|
||||
|
||||
await service._connect_websocket()
|
||||
|
||||
assert service._websocket is None
|
||||
|
||||
|
||||
def test_language_from_tokens_uses_single_recognized_language():
|
||||
tokens = [
|
||||
{"text": "Hello", "language": "en"},
|
||||
|
||||
@@ -165,6 +165,19 @@ async def test_reconnect_exhausted_emits_non_fatal_error(service, report_error):
|
||||
assert "Connection refused" in final_error.error
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconnect_exhausted_when_connect_does_not_raise(service, report_error):
|
||||
"""A non-raising failed connect is treated as a failed reconnect attempt."""
|
||||
result = await service._try_reconnect(report_error=report_error)
|
||||
|
||||
assert result is False
|
||||
assert report_error.call_count == 4
|
||||
final_error = report_error.call_args_list[-1][0][0]
|
||||
assert isinstance(final_error, ErrorFrame)
|
||||
assert final_error.fatal is False
|
||||
assert "websocket reconnection failed verification" in final_error.error
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Quick failure detection — accept then immediately close
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user