Merge pull request #4201 from pipecat-ai/mb/handle-recurring-disconnects
Fix WebsocketService infinite reconnection loop
This commit is contained in:
1
changelog/4201.changed.md
Normal file
1
changelog/4201.changed.md
Normal file
@@ -0,0 +1 @@
|
||||
- `WebsocketService` reconnection errors are now non-fatal. When a websocket service exhausts its reconnection attempts (either via exponential backoff or quick failure detection), it emits a non-fatal `ErrorFrame` instead of a fatal one. This allows application-level failover (e.g. `ServiceSwitcher`) to handle the failure instead of killing the entire pipeline.
|
||||
1
changelog/4201.fixed.md
Normal file
1
changelog/4201.fixed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Fixed `WebsocketService` entering an infinite reconnection loop when a server accepts the WebSocket handshake but immediately closes the connection (e.g. invalid API key, close code 1008). The service now detects connections that fail repeatedly within seconds of being established and stops retrying after 3 consecutive quick failures.
|
||||
@@ -7,6 +7,7 @@
|
||||
"""Base websocket service with automatic reconnection and error handling."""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Awaitable, Callable, Optional
|
||||
|
||||
@@ -27,6 +28,13 @@ class WebsocketService(ABC):
|
||||
Subclasses implement service-specific connection and message handling logic.
|
||||
"""
|
||||
|
||||
# Rapid failure detection: when a server accepts the WebSocket handshake but
|
||||
# immediately closes the connection (e.g. invalid API key, policy rejection),
|
||||
# exponential backoff won't help because the handshake keeps succeeding. We
|
||||
# detect this by tracking how long the connection survives after being established.
|
||||
_MIN_STABLE_CONNECTION_DURATION = 5.0 # seconds
|
||||
_MAX_CONSECUTIVE_QUICK_FAILURES = 3
|
||||
|
||||
def __init__(self, *, reconnect_on_error: bool = True, **kwargs):
|
||||
"""Initialize the websocket service.
|
||||
|
||||
@@ -38,6 +46,8 @@ class WebsocketService(ABC):
|
||||
self._reconnect_on_error = reconnect_on_error
|
||||
self._reconnect_in_progress: bool = False
|
||||
self._disconnecting: bool = False
|
||||
self._quick_failure_count: int = 0
|
||||
self._last_connect_time: float = 0.0
|
||||
|
||||
async def _verify_connection(self) -> bool:
|
||||
"""Verify the websocket connection is active and responsive.
|
||||
@@ -86,6 +96,7 @@ class WebsocketService(ABC):
|
||||
logger.warning(f"{self} reconnecting, attempt {attempt}")
|
||||
if await self._reconnect_websocket(attempt):
|
||||
logger.info(f"{self} reconnected successfully on attempt {attempt}")
|
||||
self._last_connect_time = time.monotonic()
|
||||
return True
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
@@ -96,12 +107,12 @@ class WebsocketService(ABC):
|
||||
)
|
||||
wait_time = exponential_backoff_time(attempt)
|
||||
await asyncio.sleep(wait_time)
|
||||
fatal_msg = f"{self} failed to reconnect after {max_retries} attempts"
|
||||
msg = f"{self} failed to reconnect after {max_retries} attempts"
|
||||
if last_exception:
|
||||
fatal_msg += f": {last_exception}"
|
||||
logger.error(fatal_msg)
|
||||
msg += f": {last_exception}"
|
||||
logger.error(msg)
|
||||
if report_error:
|
||||
await report_error(ErrorFrame(fatal_msg, fatal=True))
|
||||
await report_error(ErrorFrame(msg))
|
||||
return False
|
||||
finally:
|
||||
self._reconnect_in_progress = False
|
||||
@@ -145,6 +156,31 @@ class WebsocketService(ABC):
|
||||
logger.debug(f"{self} receive loop ended during disconnect")
|
||||
return False
|
||||
|
||||
# Check if the connection died too quickly after being established. This
|
||||
# catches cases where the handshake succeeds but the server immediately
|
||||
# closes (e.g. invalid API key). Exponential backoff won't help here
|
||||
# because the handshake keeps succeeding — we need to stop the loop.
|
||||
if self._last_connect_time > 0:
|
||||
connection_duration = time.monotonic() - self._last_connect_time
|
||||
if connection_duration < self._MIN_STABLE_CONNECTION_DURATION:
|
||||
self._quick_failure_count += 1
|
||||
logger.warning(
|
||||
f"{self} connection lasted only {connection_duration:.1f}s "
|
||||
f"({self._quick_failure_count}/{self._MAX_CONSECUTIVE_QUICK_FAILURES} "
|
||||
f"consecutive quick failures)"
|
||||
)
|
||||
if self._quick_failure_count >= self._MAX_CONSECUTIVE_QUICK_FAILURES:
|
||||
msg = (
|
||||
f"{self} connection failed {self._MAX_CONSECUTIVE_QUICK_FAILURES} "
|
||||
f"times immediately after connecting"
|
||||
)
|
||||
logger.error(msg)
|
||||
await report_error(ErrorFrame(msg))
|
||||
return False
|
||||
else:
|
||||
# Connection was stable — reset the counter.
|
||||
self._quick_failure_count = 0
|
||||
|
||||
# Log the message
|
||||
logger.warning(error_message)
|
||||
|
||||
@@ -168,6 +204,7 @@ class WebsocketService(ABC):
|
||||
report_error: Callback function to report connection errors.
|
||||
"""
|
||||
while True:
|
||||
self._last_connect_time = time.monotonic()
|
||||
try:
|
||||
await self._receive_messages()
|
||||
# _receive_messages() returned normally. This happens when the websocket
|
||||
@@ -205,6 +242,7 @@ class WebsocketService(ABC):
|
||||
additional setup required.
|
||||
"""
|
||||
self._disconnecting = False
|
||||
self._quick_failure_count = 0
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Disconnect from the service and set disconnecting flag.
|
||||
|
||||
273
tests/test_websocket_service.py
Normal file
273
tests/test_websocket_service.py
Normal file
@@ -0,0 +1,273 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Tests for WebsocketService reconnection and lifecycle behavior."""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK
|
||||
from websockets.frames import Close
|
||||
|
||||
from pipecat.frames.frames import ErrorFrame
|
||||
from pipecat.services.websocket_service import WebsocketService
|
||||
|
||||
|
||||
class ConcreteWebsocketService(WebsocketService):
    """Smallest possible WebsocketService subclass, used only by these tests.

    Connecting and disconnecting are no-ops. Receiving is delegated to the
    ``_receive_messages_impl`` hook so each test can script the behavior of
    a single receive call.
    """

    def __init__(self, **kwargs):
        """Forward all keyword arguments to the base class and clear the hook."""
        super().__init__(**kwargs)
        # Tests assign an AsyncMock here; while unset, receiving does nothing.
        self._receive_messages_impl: AsyncMock | None = None

    async def _connect_websocket(self):
        """Do nothing; no real websocket is opened in tests."""

    async def _disconnect_websocket(self):
        """Do nothing; there is no real websocket to close."""

    async def _receive_messages(self):
        """Await the test-provided hook, or return immediately if none is set."""
        if not self._receive_messages_impl:
            return
        await self._receive_messages_impl()
|
||||
|
||||
|
||||
@pytest.fixture
def service():
    """Provide a fresh ConcreteWebsocketService built with default arguments."""
    instance = ConcreteWebsocketService()
    return instance
|
||||
|
||||
|
||||
@pytest.fixture
def report_error():
    """Provide an AsyncMock standing in for the error-reporting callback."""
    callback = AsyncMock()
    return callback
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _no_sleep():
    """Replace asyncio.sleep with an AsyncMock so backoff waits return at once.

    Applied automatically to every test in this module; the patch is undone
    when the test finishes.
    """
    sleep_patcher = patch(
        "pipecat.services.websocket_service.asyncio.sleep", new_callable=AsyncMock
    )
    sleep_patcher.start()
    try:
        yield
    finally:
        sleep_patcher.stop()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Receive loop — how each exception type is handled
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_connection_closed_ok_exits_cleanly(service, report_error):
    """A ConnectionClosedOK ends the receive loop: no error, no reconnect."""
    normal_close = ConnectionClosedOK(Close(1000, "Normal closure"), None)
    reconnect = AsyncMock()
    service._receive_messages_impl = AsyncMock(side_effect=normal_close)
    service._try_reconnect = reconnect

    await service._receive_task_handler(report_error)

    report_error.assert_not_called()
    reconnect.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_connection_closed_error_triggers_reconnect(service, report_error):
    """A ConnectionClosedError causes one reconnect, then the loop resumes."""
    invocations = []

    async def fail_then_exit():
        invocations.append(None)
        if len(invocations) == 1:
            # First receive call: simulate an abnormal close.
            raise ConnectionClosedError(Close(1006, "Abnormal closure"), None)
        # Later calls: flag a disconnect so the handler leaves its loop.
        service._disconnecting = True

    reconnect = AsyncMock(return_value=True)
    service._receive_messages_impl = AsyncMock(side_effect=fail_then_exit)
    service._try_reconnect = reconnect

    await service._receive_task_handler(report_error)

    assert len(invocations) == 2
    reconnect.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_graceful_server_close_triggers_reconnect(service, report_error):
    """A plain return from _receive_messages (server close frame) reconnects."""
    invocations = []

    async def return_then_exit():
        invocations.append(None)
        if len(invocations) <= 1:
            # First call returns normally without raising.
            return
        # Second call: flag a disconnect so the handler leaves its loop.
        service._disconnecting = True

    reconnect = AsyncMock(return_value=True)
    service._receive_messages_impl = AsyncMock(side_effect=return_then_exit)
    service._try_reconnect = reconnect

    await service._receive_task_handler(report_error)

    assert len(invocations) == 2
    reconnect.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_general_exception_triggers_reconnect(service, report_error):
    """An arbitrary exception raised while receiving leads to a reconnect."""
    invocations = []

    async def fail_then_exit():
        invocations.append(None)
        if len(invocations) == 1:
            # First receive call: fail with a non-websocket error.
            raise RuntimeError("something broke")
        # Later calls: flag a disconnect so the handler leaves its loop.
        service._disconnecting = True

    reconnect = AsyncMock(return_value=True)
    service._receive_messages_impl = AsyncMock(side_effect=fail_then_exit)
    service._try_reconnect = reconnect

    await service._receive_task_handler(report_error)

    assert len(invocations) == 2
    reconnect.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Exponential backoff — server unreachable
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_reconnect_succeeds_on_later_attempt(service, report_error):
    """_try_reconnect keeps retrying and returns True once an attempt works."""
    # Two failed attempts followed by a successful one.
    outcomes = [ConnectionError("fail"), ConnectionError("fail"), True]
    reconnect_websocket = AsyncMock(side_effect=outcomes)
    service._reconnect_websocket = reconnect_websocket

    succeeded = await service._try_reconnect(report_error=report_error)

    assert succeeded is True
    assert reconnect_websocket.call_count == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_reconnect_exhausted_emits_non_fatal_error(service, report_error):
    """When every retry fails, the result is False plus a non-fatal ErrorFrame."""
    reconnect_websocket = AsyncMock(side_effect=ConnectionError("Connection refused"))
    service._reconnect_websocket = reconnect_websocket

    succeeded = await service._try_reconnect(report_error=report_error)

    assert succeeded is False
    assert reconnect_websocket.call_count == 3

    # The last reported frame is the one summarizing the exhausted retries.
    last_report = report_error.call_args_list[-1]
    final_error = last_report.args[0]
    assert isinstance(final_error, ErrorFrame)
    assert final_error.fatal is False
    assert "Connection refused" in final_error.error
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Quick failure detection — accept then immediately close
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_quick_failures_emit_error(service, report_error):
    """Connections that die right after being established stop the loop.

    Every receive call raises immediately while the mocked ``_try_reconnect``
    keeps reporting success. After ``_MAX_CONSECUTIVE_QUICK_FAILURES`` such
    cycles the handler must return and report exactly one non-fatal
    ``ErrorFrame``.
    """
    # Read the limit from the class so both assertions below track the same
    # constant instead of a hard-coded number.
    limit = service._MAX_CONSECUTIVE_QUICK_FAILURES
    call_count = 0

    async def fail_immediately():
        nonlocal call_count
        call_count += 1
        raise ConnectionClosedError(Close(1008, "Invalid API key"), None)

    service._receive_messages_impl = AsyncMock(side_effect=fail_immediately)
    service._try_reconnect = AsyncMock(return_value=True)

    await service._receive_task_handler(report_error)

    assert call_count == limit
    report_error.assert_called_once()
    error_frame = report_error.call_args[0][0]
    assert isinstance(error_frame, ErrorFrame)
    assert error_frame.fatal is False
    assert f"failed {limit} times immediately after connecting" in error_frame.error
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_stable_connection_resets_quick_failure_counter(service, report_error):
    """A long-lived connection zeroes the counter; 3 fresh quick failures are needed."""
    invocations = []

    async def always_fail():
        invocations.append(None)
        raise ConnectionClosedError(Close(1006, "Abnormal closure"), None)

    service._receive_messages_impl = AsyncMock(side_effect=always_fail)
    service._try_reconnect = AsyncMock(return_value=True)

    base_time = 1000.0
    # One (connect_time, check_time) pair per receive cycle. The first value
    # feeds _last_connect_time and the second feeds the duration check.
    cycles = [
        (base_time, base_time),  # cycle 1: quick -> count=1
        (base_time + 1.0, base_time + 1.0),  # cycle 2: quick -> count=2
        (base_time + 2.0, base_time + 12.0),  # cycle 3: stable (10s) -> count=0
        (base_time + 13.0, base_time + 13.0),  # cycle 4: quick -> count=1
        (base_time + 14.0, base_time + 14.0),  # cycle 5: quick -> count=2
        (base_time + 15.0, base_time + 15.0),  # cycle 6: quick -> count=3 -> error
    ]
    clock = iter([stamp for pair in cycles for stamp in pair])

    with patch("pipecat.services.websocket_service.time") as mock_time:
        # Each monotonic() call consumes the next scripted timestamp.
        mock_time.monotonic = clock.__next__
        await service._receive_task_handler(report_error)

    assert len(invocations) == 6
    report_error.assert_called_once()
    error_frame = report_error.call_args.args[0]
    assert error_frame.fatal is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Lifecycle and guards
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_disconnect_prevents_reconnection(service, report_error):
    """Once _disconnect() has run, a receive error neither reconnects nor reports."""
    await service._disconnect()

    abnormal_close = ConnectionClosedError(Close(1006, "Abnormal closure"), None)
    reconnect = AsyncMock()
    service._receive_messages_impl = AsyncMock(side_effect=abnormal_close)
    service._try_reconnect = reconnect

    await service._receive_task_handler(report_error)

    report_error.assert_not_called()
    reconnect.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_connect_resets_state(service):
    """_connect() clears the disconnecting flag and the quick-failure counter."""
    # Start from a dirty state so the reset is observable.
    service._quick_failure_count = 5
    service._disconnecting = True

    await service._connect()

    assert service._quick_failure_count == 0
    assert service._disconnecting is False
|
||||
Reference in New Issue
Block a user