From 86a16d53bca19d1119cabbbcabda2571236ad4c9 Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Mon, 30 Mar 2026 12:16:29 -0400 Subject: [PATCH 1/4] Detect quick connection failures in WebsocketService to prevent infinite reconnection loops When a WebSocket server accepts the handshake but immediately closes the connection (e.g. invalid API key returning close code 1008), the existing exponential backoff does not help because the handshake keeps succeeding. This tracks how long each connection survives and emits a non-fatal ErrorFrame after 3 consecutive sub-5s failures, allowing ServiceSwitcher failover instead of killing the pipeline. Fixes #3711 --- src/pipecat/services/websocket_service.py | 38 +++++ tests/test_websocket_service.py | 183 ++++++++++++++++++++++ 2 files changed, 221 insertions(+) create mode 100644 tests/test_websocket_service.py diff --git a/src/pipecat/services/websocket_service.py b/src/pipecat/services/websocket_service.py index 85e5b2db7..e5449a539 100644 --- a/src/pipecat/services/websocket_service.py +++ b/src/pipecat/services/websocket_service.py @@ -7,6 +7,7 @@ """Base websocket service with automatic reconnection and error handling.""" import asyncio +import time from abc import ABC, abstractmethod from typing import Awaitable, Callable, Optional @@ -27,6 +28,13 @@ class WebsocketService(ABC): Subclasses implement service-specific connection and message handling logic. """ + # Rapid failure detection: when a server accepts the WebSocket handshake but + # immediately closes the connection (e.g. invalid API key, policy rejection), + # exponential backoff won't help because the handshake keeps succeeding. We + # detect this by tracking how long the connection survives after being established. + _MIN_STABLE_CONNECTION_DURATION = 5.0 # seconds + _MAX_CONSECUTIVE_QUICK_FAILURES = 3 + def __init__(self, *, reconnect_on_error: bool = True, **kwargs): """Initialize the websocket service. @@ -38,6 +46,8 @@ class WebsocketService(ABC): self._reconnect_on_error = reconnect_on_error self._reconnect_in_progress: bool = False self._disconnecting: bool = False + self._quick_failure_count: int = 0 + self._last_connect_time: float = 0.0 async def _verify_connection(self) -> bool: """Verify the websocket connection is active and responsive. @@ -86,6 +96,7 @@ class WebsocketService(ABC): logger.warning(f"{self} reconnecting, attempt {attempt}") if await self._reconnect_websocket(attempt): logger.info(f"{self} reconnected successfully on attempt {attempt}") + self._last_connect_time = time.monotonic() return True except Exception as e: last_exception = e @@ -145,6 +156,31 @@ class WebsocketService(ABC): logger.debug(f"{self} receive loop ended during disconnect") return False + # Check if the connection died too quickly after being established. This + # catches cases where the handshake succeeds but the server immediately + # closes (e.g. invalid API key). Exponential backoff won't help here + # because the handshake keeps succeeding — we need to stop the loop. + if self._last_connect_time > 0: + connection_duration = time.monotonic() - self._last_connect_time + if connection_duration < self._MIN_STABLE_CONNECTION_DURATION: + self._quick_failure_count += 1 + logger.warning( + f"{self} connection lasted only {connection_duration:.1f}s " + f"({self._quick_failure_count}/{self._MAX_CONSECUTIVE_QUICK_FAILURES} " + f"consecutive quick failures)" + ) + if self._quick_failure_count >= self._MAX_CONSECUTIVE_QUICK_FAILURES: + msg = ( + f"{self} connection failed {self._MAX_CONSECUTIVE_QUICK_FAILURES} " + f"times immediately after connecting" + ) + logger.error(msg) + await report_error(ErrorFrame(msg)) + return False + else: + # Connection was stable — reset the counter. + self._quick_failure_count = 0 + # Log the message logger.warning(error_message) @@ -168,6 +204,7 @@ class WebsocketService(ABC): report_error: Callback function to report connection errors. """ while True: + self._last_connect_time = time.monotonic() try: await self._receive_messages() # _receive_messages() returned normally. This happens when the websocket @@ -205,6 +242,7 @@ class WebsocketService(ABC): additional setup required. """ self._disconnecting = False + self._quick_failure_count = 0 async def _disconnect(self): """Disconnect from the service and set disconnecting flag. diff --git a/tests/test_websocket_service.py b/tests/test_websocket_service.py new file mode 100644 index 000000000..2f180b987 --- /dev/null +++ b/tests/test_websocket_service.py @@ -0,0 +1,183 @@ +# +# Copyright (c) 2024-2026, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +"""Tests for WebsocketService quick failure detection.""" + +import time +from unittest.mock import AsyncMock, patch + +import pytest +from websockets.exceptions import ConnectionClosedError +from websockets.frames import Close + +from pipecat.frames.frames import ErrorFrame +from pipecat.services.websocket_service import WebsocketService + + +class ConcreteWebsocketService(WebsocketService): + """Minimal concrete implementation for testing.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._receive_messages_impl: AsyncMock | None = None + + async def _connect_websocket(self): + pass + + async def _disconnect_websocket(self): + pass + + async def _receive_messages(self): + if self._receive_messages_impl: + await self._receive_messages_impl() + + +@pytest.fixture +def service(): + return ConcreteWebsocketService() + + +@pytest.fixture +def report_error(): + return AsyncMock() + + +@pytest.mark.asyncio +async def test_quick_failures_emit_error(service, report_error): + """Connections that fail immediately after being established should emit an error + after MAX_CONSECUTIVE_QUICK_FAILURES consecutive quick failures.""" + call_count = 0 + + async def fail_immediately(): + nonlocal call_count + call_count += 1 + raise ConnectionClosedError(Close(1008, "Invalid API key"), None) + + service._receive_messages_impl = AsyncMock(side_effect=fail_immediately) + # Mock _try_reconnect to succeed (handshake passes but connection dies right away) + service._try_reconnect = AsyncMock(return_value=True) + + await service._receive_task_handler(report_error) + + # Should have called _receive_messages MAX_RAPID_FAILURES times + assert call_count == service._MAX_CONSECUTIVE_QUICK_FAILURES + # Should have emitted a fatal error + report_error.assert_called_once() + error_frame = report_error.call_args[0][0] + assert isinstance(error_frame, ErrorFrame) + assert error_frame.fatal is False + assert "failed 3 times immediately after connecting" in error_frame.error + + +@pytest.mark.asyncio +async def test_stable_connection_resets_quick_failure_counter(service, report_error): + """A connection that survives beyond the threshold should reset the quick failure counter.""" + call_count = 0 + + async def fail_then_stable_then_fail(): + nonlocal call_count + call_count += 1 + if call_count <= 2: + # First two calls: quick failures + raise ConnectionClosedError(Close(1006, "Abnormal closure"), None) + elif call_count == 3: + # Third call: simulate a stable connection by advancing time past threshold + raise ConnectionClosedError(Close(1006, "Abnormal closure"), None) + else: + # Fourth and beyond: quick failures again + raise ConnectionClosedError(Close(1006, "Abnormal closure"), None) + + service._receive_messages_impl = AsyncMock(side_effect=fail_then_stable_then_fail) + service._try_reconnect = AsyncMock(return_value=True) + + # Patch time.monotonic to control timing + base_time = 1000.0 + time_values = iter( + [ + # Call 1: set _last_connect_time + base_time, + # Call 1: check in _maybe_try_reconnect (rapid: 0s elapsed) + base_time, + # Call 2: set _last_connect_time + base_time + 1.0, + # Call 2: check in _maybe_try_reconnect (rapid: 0s elapsed) + base_time + 1.0, + # Call 3: set _last_connect_time + base_time + 2.0, + # Call 3: check in _maybe_try_reconnect (stable: 10s elapsed) + base_time + 12.0, + # Call 4: set _last_connect_time + base_time + 13.0, + # Call 4: check in _maybe_try_reconnect (rapid: 0s elapsed) + base_time + 13.0, + # Call 5: set _last_connect_time + base_time + 14.0, + # Call 5: check in _maybe_try_reconnect (rapid: 0s elapsed) + base_time + 14.0, + # Call 6: set _last_connect_time + base_time + 15.0, + # Call 6: check in _maybe_try_reconnect (rapid: 0s elapsed) + base_time + 15.0, + ] + ) + + with patch("pipecat.services.websocket_service.time") as mock_time: + mock_time.monotonic = lambda: next(time_values) + + await service._receive_task_handler(report_error) + + # After the stable connection (call 3), counter resets to 0. + # Then calls 4, 5, 6 are quick failures (counter: 1, 2, 3) -> error emitted + assert call_count == 6 + report_error.assert_called_once() + error_frame = report_error.call_args[0][0] + assert error_frame.fatal is False + + +@pytest.mark.asyncio +async def test_graceful_close_counts_toward_quick_failures(service, report_error): + """A _receive_messages that returns normally (graceful close) should also count + toward quick failures if it happens immediately.""" + call_count = 0 + + async def return_immediately(): + nonlocal call_count + call_count += 1 + + service._receive_messages_impl = AsyncMock(side_effect=return_immediately) + service._try_reconnect = AsyncMock(return_value=True) + + await service._receive_task_handler(report_error) + + assert call_count == service._MAX_CONSECUTIVE_QUICK_FAILURES + report_error.assert_called_once() + error_frame = report_error.call_args[0][0] + assert isinstance(error_frame, ErrorFrame) + assert error_frame.fatal is False + + +@pytest.mark.asyncio +async def test_connect_resets_quick_failure_counter(service): + """Calling _connect() should reset the quick failure counter.""" + service._quick_failure_count = 5 + await service._connect() + assert service._quick_failure_count == 0 + + +@pytest.mark.asyncio +async def test_intentional_disconnect_skips_quick_failure_logic(service, report_error): + """When _disconnecting is True, quick failure detection should not run.""" + service._disconnecting = True + service._quick_failure_count = 0 + service._last_connect_time = time.monotonic() + + result = await service._maybe_try_reconnect("test error", report_error) + + assert result is False + # Counter should not have been incremented + assert service._quick_failure_count == 0 + # No error frame should have been emitted + report_error.assert_not_called() From f37bf989dd80e57f96b6662a296624249ebe11a1 Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Mon, 30 Mar 2026 12:29:46 -0400 Subject: [PATCH 2/4] Make reconnection failure error non-fatal to allow service failover A single service failing to reconnect should not kill the entire pipeline. Non-fatal errors flow through the pipeline so application code (e.g. ServiceSwitcher) can handle failover to a backup service. --- src/pipecat/services/websocket_service.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/pipecat/services/websocket_service.py b/src/pipecat/services/websocket_service.py index e5449a539..9258aa90c 100644 --- a/src/pipecat/services/websocket_service.py +++ b/src/pipecat/services/websocket_service.py @@ -107,12 +107,12 @@ class WebsocketService(ABC): ) wait_time = exponential_backoff_time(attempt) await asyncio.sleep(wait_time) - fatal_msg = f"{self} failed to reconnect after {max_retries} attempts" + msg = f"{self} failed to reconnect after {max_retries} attempts" if last_exception: - fatal_msg += f": {last_exception}" - logger.error(fatal_msg) + msg += f": {last_exception}" + logger.error(msg) if report_error: - await report_error(ErrorFrame(fatal_msg, fatal=True)) + await report_error(ErrorFrame(msg)) return False finally: self._reconnect_in_progress = False From 3af93ed257228eb0d815937f1de81c946648970e Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Mon, 30 Mar 2026 12:31:22 -0400 Subject: [PATCH 3/4] Add changelog for #4201 --- changelog/4201.changed.md | 1 + changelog/4201.fixed.md | 1 + 2 files changed, 2 insertions(+) create mode 100644 changelog/4201.changed.md create mode 100644 changelog/4201.fixed.md diff --git a/changelog/4201.changed.md b/changelog/4201.changed.md new file mode 100644 index 000000000..4143e9c66 --- /dev/null +++ b/changelog/4201.changed.md @@ -0,0 +1 @@ +- `WebsocketService` reconnection errors are now non-fatal. When a websocket service exhausts its reconnection attempts (either via exponential backoff or quick failure detection), it emits a non-fatal `ErrorFrame` instead of a fatal one. This allows application-level failover (e.g. `ServiceSwitcher`) to handle the failure instead of killing the entire pipeline. diff --git a/changelog/4201.fixed.md b/changelog/4201.fixed.md new file mode 100644 index 000000000..fb44f7344 --- /dev/null +++ b/changelog/4201.fixed.md @@ -0,0 +1 @@ +- Fixed `WebsocketService` entering an infinite reconnection loop when a server accepts the WebSocket handshake but immediately closes the connection (e.g. invalid API key, close code 1008). The service now detects connections that fail repeatedly within seconds of being established and stops retrying after 3 consecutive quick failures. From f6a3678f93df332e3b26a896d03c5bfe07728caa Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Mon, 30 Mar 2026 12:46:30 -0400 Subject: [PATCH 4/4] Improve tests --- tests/test_websocket_service.py | 238 ++++++++++++++++++++++---------- 1 file changed, 164 insertions(+), 74 deletions(-) diff --git a/tests/test_websocket_service.py b/tests/test_websocket_service.py index 2f180b987..2cffbce97 100644 --- a/tests/test_websocket_service.py +++ b/tests/test_websocket_service.py @@ -4,13 +4,12 @@ # SPDX-License-Identifier: BSD 2-Clause License # -"""Tests for WebsocketService quick failure detection.""" +"""Tests for WebsocketService reconnection and lifecycle behavior.""" -import time from unittest.mock import AsyncMock, patch import pytest -from websockets.exceptions import ConnectionClosedError +from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK from websockets.frames import Close from pipecat.frames.frames import ErrorFrame @@ -45,10 +44,135 @@ def report_error(): return AsyncMock() +@pytest.fixture(autouse=True) +def _no_sleep(): + """Patch asyncio.sleep globally to avoid real backoff waits.""" + with patch("pipecat.services.websocket_service.asyncio.sleep", new_callable=AsyncMock): + yield + + +# --------------------------------------------------------------------------- +# Receive loop — how each exception type is handled +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_connection_closed_ok_exits_cleanly(service, report_error): + """ConnectionClosedOK exits the loop with no error and no reconnection.""" + service._receive_messages_impl = AsyncMock( + side_effect=ConnectionClosedOK(Close(1000, "Normal closure"), None) + ) + service._try_reconnect = AsyncMock() + + await service._receive_task_handler(report_error) + + report_error.assert_not_called() + service._try_reconnect.assert_not_called() + + +@pytest.mark.asyncio +async def test_connection_closed_error_triggers_reconnect(service, report_error): + """ConnectionClosedError triggers reconnection; loop continues after success.""" + call_count = 0 + + async def fail_then_exit(): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise ConnectionClosedError(Close(1006, "Abnormal closure"), None) + service._disconnecting = True + + service._receive_messages_impl = AsyncMock(side_effect=fail_then_exit) + service._try_reconnect = AsyncMock(return_value=True) + + await service._receive_task_handler(report_error) + + assert call_count == 2 + service._try_reconnect.assert_called_once() + + +@pytest.mark.asyncio +async def test_graceful_server_close_triggers_reconnect(service, report_error): + """Normal return from _receive_messages (server close frame) triggers reconnection.""" + call_count = 0 + + async def return_then_exit(): + nonlocal call_count + call_count += 1 + if call_count > 1: + service._disconnecting = True + + service._receive_messages_impl = AsyncMock(side_effect=return_then_exit) + service._try_reconnect = AsyncMock(return_value=True) + + await service._receive_task_handler(report_error) + + assert call_count == 2 + service._try_reconnect.assert_called_once() + + +@pytest.mark.asyncio +async def test_general_exception_triggers_reconnect(service, report_error): + """A general exception in _receive_messages triggers reconnection.""" + call_count = 0 + + async def fail_then_exit(): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise RuntimeError("something broke") + service._disconnecting = True + + service._receive_messages_impl = AsyncMock(side_effect=fail_then_exit) + service._try_reconnect = AsyncMock(return_value=True) + + await service._receive_task_handler(report_error) + + assert call_count == 2 + service._try_reconnect.assert_called_once() + + +# --------------------------------------------------------------------------- +# Exponential backoff — server unreachable +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_reconnect_succeeds_on_later_attempt(service, report_error): + """_try_reconnect retries and succeeds on a later attempt.""" + service._reconnect_websocket = AsyncMock( + side_effect=[ConnectionError("fail"), ConnectionError("fail"), True] + ) + + result = await service._try_reconnect(report_error=report_error) + + assert result is True + assert service._reconnect_websocket.call_count == 3 + + +@pytest.mark.asyncio +async def test_reconnect_exhausted_emits_non_fatal_error(service, report_error): + """Exhausting all retries returns False and emits a non-fatal ErrorFrame.""" + service._reconnect_websocket = AsyncMock(side_effect=ConnectionError("Connection refused")) + + result = await service._try_reconnect(report_error=report_error) + + assert result is False + assert service._reconnect_websocket.call_count == 3 + final_error = report_error.call_args_list[-1][0][0] + assert isinstance(final_error, ErrorFrame) + assert final_error.fatal is False + assert "Connection refused" in final_error.error + + +# --------------------------------------------------------------------------- +# Quick failure detection — accept then immediately close +# --------------------------------------------------------------------------- + + @pytest.mark.asyncio async def test_quick_failures_emit_error(service, report_error): - """Connections that fail immediately after being established should emit an error - after MAX_CONSECUTIVE_QUICK_FAILURES consecutive quick failures.""" + """Connections failing immediately after establishment emit error after 3 cycles.""" call_count = 0 async def fail_immediately(): @@ -57,14 +181,11 @@ async def test_quick_failures_emit_error(service, report_error): raise ConnectionClosedError(Close(1008, "Invalid API key"), None) service._receive_messages_impl = AsyncMock(side_effect=fail_immediately) - # Mock _try_reconnect to succeed (handshake passes but connection dies right away) service._try_reconnect = AsyncMock(return_value=True) await service._receive_task_handler(report_error) - # Should have called _receive_messages MAX_RAPID_FAILURES times assert call_count == service._MAX_CONSECUTIVE_QUICK_FAILURES - # Should have emitted a fatal error report_error.assert_called_once() error_frame = report_error.call_args[0][0] assert isinstance(error_frame, ErrorFrame) @@ -74,110 +195,79 @@ async def test_quick_failures_emit_error(service, report_error): @pytest.mark.asyncio async def test_stable_connection_resets_quick_failure_counter(service, report_error): - """A connection that survives beyond the threshold should reset the quick failure counter.""" + """A stable connection resets the quick failure counter; needs 3 new failures to trigger.""" call_count = 0 - async def fail_then_stable_then_fail(): + async def always_fail(): nonlocal call_count call_count += 1 - if call_count <= 2: - # First two calls: quick failures - raise ConnectionClosedError(Close(1006, "Abnormal closure"), None) - elif call_count == 3: - # Third call: simulate a stable connection by advancing time past threshold - raise ConnectionClosedError(Close(1006, "Abnormal closure"), None) - else: - # Fourth and beyond: quick failures again - raise ConnectionClosedError(Close(1006, "Abnormal closure"), None) + raise ConnectionClosedError(Close(1006, "Abnormal closure"), None) - service._receive_messages_impl = AsyncMock(side_effect=fail_then_stable_then_fail) + service._receive_messages_impl = AsyncMock(side_effect=always_fail) service._try_reconnect = AsyncMock(return_value=True) - # Patch time.monotonic to control timing base_time = 1000.0 time_values = iter( [ - # Call 1: set _last_connect_time + # Call 1: set _last_connect_time, check in _maybe_try_reconnect (quick) -> count=1 base_time, - # Call 1: check in _maybe_try_reconnect (rapid: 0s elapsed) base_time, - # Call 2: set _last_connect_time + # Call 2: quick -> count=2 base_time + 1.0, - # Call 2: check in _maybe_try_reconnect (rapid: 0s elapsed) base_time + 1.0, - # Call 3: set _last_connect_time + # Call 3: stable (10s elapsed) -> count=0 base_time + 2.0, - # Call 3: check in _maybe_try_reconnect (stable: 10s elapsed) base_time + 12.0, - # Call 4: set _last_connect_time + # Call 4: quick -> count=1 base_time + 13.0, - # Call 4: check in _maybe_try_reconnect (rapid: 0s elapsed) base_time + 13.0, - # Call 5: set _last_connect_time + # Call 5: quick -> count=2 base_time + 14.0, - # Call 5: check in _maybe_try_reconnect (rapid: 0s elapsed) base_time + 14.0, - # Call 6: set _last_connect_time + # Call 6: quick -> count=3 -> error emitted, loop stops base_time + 15.0, - # Call 6: check in _maybe_try_reconnect (rapid: 0s elapsed) base_time + 15.0, ] ) with patch("pipecat.services.websocket_service.time") as mock_time: mock_time.monotonic = lambda: next(time_values) - await service._receive_task_handler(report_error) - # After the stable connection (call 3), counter resets to 0. - # Then calls 4, 5, 6 are quick failures (counter: 1, 2, 3) -> error emitted assert call_count == 6 report_error.assert_called_once() error_frame = report_error.call_args[0][0] assert error_frame.fatal is False +# --------------------------------------------------------------------------- +# Lifecycle and guards +# --------------------------------------------------------------------------- + + @pytest.mark.asyncio -async def test_graceful_close_counts_toward_quick_failures(service, report_error): - """A _receive_messages that returns normally (graceful close) should also count - toward quick failures if it happens immediately.""" - call_count = 0 +async def test_disconnect_prevents_reconnection(service, report_error): + """After _disconnect(), errors exit the loop without reconnecting or emitting errors.""" + await service._disconnect() - async def return_immediately(): - nonlocal call_count - call_count += 1 - - service._receive_messages_impl = AsyncMock(side_effect=return_immediately) - service._try_reconnect = AsyncMock(return_value=True) + service._receive_messages_impl = AsyncMock( + side_effect=ConnectionClosedError(Close(1006, "Abnormal closure"), None) + ) + service._try_reconnect = AsyncMock() await service._receive_task_handler(report_error) - assert call_count == service._MAX_CONSECUTIVE_QUICK_FAILURES - report_error.assert_called_once() - error_frame = report_error.call_args[0][0] - assert isinstance(error_frame, ErrorFrame) - assert error_frame.fatal is False - - -@pytest.mark.asyncio -async def test_connect_resets_quick_failure_counter(service): - """Calling _connect() should reset the quick failure counter.""" - service._quick_failure_count = 5 - await service._connect() - assert service._quick_failure_count == 0 - - -@pytest.mark.asyncio -async def test_intentional_disconnect_skips_quick_failure_logic(service, report_error): - """When _disconnecting is True, quick failure detection should not run.""" - service._disconnecting = True - service._quick_failure_count = 0 - service._last_connect_time = time.monotonic() - - result = await service._maybe_try_reconnect("test error", report_error) - - assert result is False - # Counter should not have been incremented - assert service._quick_failure_count == 0 - # No error frame should have been emitted report_error.assert_not_called() + service._try_reconnect.assert_not_called() + + +@pytest.mark.asyncio +async def test_connect_resets_state(service): + """_connect() resets _disconnecting and _quick_failure_count.""" + service._disconnecting = True + service._quick_failure_count = 5 + + await service._connect() + + assert service._disconnecting is False + assert service._quick_failure_count == 0