diff --git a/src/pipecat/services/assemblyai/stt.py b/src/pipecat/services/assemblyai/stt.py index 5eda4b7f6..d6c663dda 100644 --- a/src/pipecat/services/assemblyai/stt.py +++ b/src/pipecat/services/assemblyai/stt.py @@ -586,9 +586,9 @@ class AssemblyAISTTService(WebsocketSTTService): await self._call_event_handler("on_connected") logger.debug(f"{self} Connected to AssemblyAI WebSocket") except Exception as e: + self._websocket = None self._connected = False await self.push_error(error_msg=f"Unable to connect to AssemblyAI: {e}", exception=e) - raise async def _disconnect_websocket(self): """Close the websocket connection to AssemblyAI.""" diff --git a/src/pipecat/services/aws/stt.py b/src/pipecat/services/aws/stt.py index 6a427d707..3e60ed110 100644 --- a/src/pipecat/services/aws/stt.py +++ b/src/pipecat/services/aws/stt.py @@ -339,10 +339,10 @@ class AWSTranscribeSTTService(WebsocketSTTService): await self._call_event_handler("on_connected") logger.info(f"{self} Successfully connected to AWS Transcribe") except Exception as e: + self._websocket = None await self.push_error( error_msg=f"Unable to connect to AWS Transcribe: {e}", exception=e ) - raise async def _disconnect_websocket(self): """Close the websocket connection to AWS Transcribe.""" diff --git a/src/pipecat/services/cartesia/stt.py b/src/pipecat/services/cartesia/stt.py index af3418edd..2d3eb8c96 100644 --- a/src/pipecat/services/cartesia/stt.py +++ b/src/pipecat/services/cartesia/stt.py @@ -354,7 +354,8 @@ class CartesiaSTTService(WebsocketSTTService): self._websocket = await websocket_connect(ws_url, additional_headers=headers) await self._call_event_handler("on_connected") except Exception as e: - await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e) + self._websocket = None + await self.push_error(error_msg=f"Unable to connect to Cartesia: {e}", exception=e) async def _disconnect_websocket(self): ws = self._websocket diff --git a/src/pipecat/services/elevenlabs/stt.py b/src/pipecat/services/elevenlabs/stt.py index 1e1d942cf..b7cdf8119 100644 --- a/src/pipecat/services/elevenlabs/stt.py +++ b/src/pipecat/services/elevenlabs/stt.py @@ -823,6 +823,7 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService): await self._call_event_handler("on_connected") logger.debug("Connected to ElevenLabs Realtime STT") except Exception as e: + self._websocket = None await self.push_error( error_msg=f"Unable to connect to ElevenLabs Realtime STT: {e}", exception=e ) diff --git a/src/pipecat/services/gladia/stt.py b/src/pipecat/services/gladia/stt.py index 26fea653c..0d788e234 100644 --- a/src/pipecat/services/gladia/stt.py +++ b/src/pipecat/services/gladia/stt.py @@ -558,8 +558,9 @@ class GladiaSTTService(WebsocketSTTService): logger.debug(f"{self} Connected to Gladia WebSocket") except Exception as e: + self._websocket = None + self._connection_active = False await self.push_error(error_msg=f"Unable to connect to Gladia: {e}", exception=e) - raise async def _disconnect_websocket(self): """Close the websocket connection to Gladia.""" diff --git a/src/pipecat/services/gradium/stt.py b/src/pipecat/services/gradium/stt.py index 3e5c954c2..64a77402d 100644 --- a/src/pipecat/services/gradium/stt.py +++ b/src/pipecat/services/gradium/stt.py @@ -423,8 +423,8 @@ class GradiumSTTService(WebsocketSTTService): logger.debug("Connected to Gradium STT") except Exception as e: - await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e) - raise + self._websocket = None + await self.push_error(error_msg=f"Unable to connect to Gradium: {e}", exception=e) async def _disconnect(self): await super()._disconnect() diff --git a/src/pipecat/services/soniox/stt.py b/src/pipecat/services/soniox/stt.py index 2a50581ce..823517aa1 100644 --- a/src/pipecat/services/soniox/stt.py +++ b/src/pipecat/services/soniox/stt.py @@ -537,8 +537,8 @@ class SonioxSTTService(WebsocketSTTService): await self._call_event_handler("on_connected") logger.debug("Connected to Soniox STT") except Exception as e: + self._websocket = None await self.push_error(error_msg=f"Unable to connect to Soniox: {e}", exception=e) - raise async def _disconnect_websocket(self): """Close the websocket connection to Soniox.""" diff --git a/src/pipecat/services/xai/stt.py b/src/pipecat/services/xai/stt.py index e219cfb6e..6bb32bb7e 100644 --- a/src/pipecat/services/xai/stt.py +++ b/src/pipecat/services/xai/stt.py @@ -293,8 +293,9 @@ class XAISTTService(WebsocketSTTService): await self._call_event_handler("on_connected") logger.debug(f"{self} connected to xAI STT WebSocket") except Exception as e: + self._websocket = None + self._session_ready.clear() await self.push_error(error_msg=f"Unable to connect to xAI STT: {e}", exception=e) - raise async def _disconnect_websocket(self): """Close the WebSocket connection.""" diff --git a/tests/test_cartesia_stt.py b/tests/test_cartesia_stt.py new file mode 100644 index 000000000..08563e711 --- /dev/null +++ b/tests/test_cartesia_stt.py @@ -0,0 +1,45 @@ +# +# Copyright (c) 2024-2026, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from unittest.mock import AsyncMock + +import pytest +from websockets.protocol import State + +from pipecat.services.cartesia.stt import CartesiaSTTService + + +class _FakeWebsocket: + def __init__(self, *, state=State.OPEN, send_side_effect=None): + self.state = state + self.send = AsyncMock(side_effect=send_side_effect) + + +@pytest.mark.asyncio +async def test_cartesia_connect_failure_clears_stale_websocket(monkeypatch): + async def fake_websocket_connect(*args, **kwargs): + raise RuntimeError("connection failed") + + monkeypatch.setattr("pipecat.services.cartesia.stt.websocket_connect", fake_websocket_connect) + + service = CartesiaSTTService(api_key="test-key", sample_rate=16000) + service._websocket = _FakeWebsocket(state=State.CLOSED) + + await service._connect_websocket() + + assert service._websocket is None + + +@pytest.mark.asyncio +async def test_cartesia_run_stt_logs_send_failure_without_clearing_websocket(): + service = CartesiaSTTService(api_key="test-key", sample_rate=16000) + websocket = _FakeWebsocket(send_side_effect=RuntimeError("websocket closed")) + service._websocket = websocket + + async for _ in service.run_stt(b"\x00" * 160): + pass + + assert service._websocket is websocket diff --git a/tests/test_soniox_stt.py b/tests/test_soniox_stt.py index e6d6713f5..99eaeb5c0 100644 --- a/tests/test_soniox_stt.py +++ b/tests/test_soniox_stt.py @@ -5,8 +5,10 @@ # import json +from unittest.mock import AsyncMock import pytest +from websockets.protocol import State from pipecat.frames.frames import TranscriptionFrame from pipecat.services.soniox.stt import END_TOKEN, SonioxSTTService, _language_from_tokens @@ -14,8 +16,10 @@ from pipecat.transcriptions.language import Language class _FakeWebsocket: - def __init__(self, messages): + def __init__(self, messages, *, state=State.OPEN, send_side_effect=None): self._messages = messages + self.state = state + self.send = AsyncMock(side_effect=send_side_effect) def __aiter__(self): return self._iter_messages() @@ -25,6 +29,21 @@ class _FakeWebsocket: yield message +@pytest.mark.asyncio +async def test_connect_failure_clears_stale_websocket_without_raising(monkeypatch): + async def fake_websocket_connect(*args, **kwargs): + raise RuntimeError("connection failed") + + monkeypatch.setattr("pipecat.services.soniox.stt.websocket_connect", fake_websocket_connect) + + service = SonioxSTTService(api_key="test-key") + service._websocket = _FakeWebsocket([], state=State.CLOSED) + + await service._connect_websocket() + + assert service._websocket is None + + def test_language_from_tokens_uses_single_recognized_language(): tokens = [ {"text": "Hello", "language": "en"},