From 47e53890e3e171d2b8244c6482c08ff95b68dc4d Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Sat, 28 Mar 2026 00:01:13 -0400 Subject: [PATCH] Fix FastAPI WebSocket disconnect race condition causing pipeline hang When the remote side disconnects while send() is in flight, send() was setting _closing=True. This prevented the receive loop from firing on_client_disconnected, causing the pipeline to hang waiting for a disconnect signal that never came. The fix removes _closing from send() (that flag means we initiated the close) and instead checks Starlette application_state in _can_send() to suppress subsequent sends after a failure. Fixes #3912 --- src/pipecat/transports/websocket/fastapi.py | 16 +-- tests/test_fastapi_websocket.py | 140 +++++++++++++++++++- 2 files changed, 144 insertions(+), 12 deletions(-) diff --git a/src/pipecat/transports/websocket/fastapi.py b/src/pipecat/transports/websocket/fastapi.py index 0fde2b9ae..d9b7d7ae1 100644 --- a/src/pipecat/transports/websocket/fastapi.py +++ b/src/pipecat/transports/websocket/fastapi.py @@ -150,17 +150,9 @@ class FastAPIWebsocketClient: else: await self._websocket.send_text(data) except Exception as e: - logger.error( + logger.warning( f"{self} exception sending data: {e.__class__.__name__} ({e}), application_state: {self._websocket.application_state}" ) - # For some reason the websocket is disconnected, and we are not able to send data - # So let's properly handle it and disconnect the transport if it is not already disconnecting - if ( - self._websocket.application_state == WebSocketState.DISCONNECTED - and not self.is_closing - ): - logger.warning("Closing already disconnected websocket!") - self._closing = True async def disconnect(self): """Disconnect the WebSocket client.""" @@ -189,7 +181,11 @@ class FastAPIWebsocketClient: def _can_send(self): """Check if data can be sent through the WebSocket.""" - return self.is_connected and not self.is_closing + return ( + self.is_connected + and not self.is_closing + and self._websocket.application_state != WebSocketState.DISCONNECTED + ) @property def is_connected(self) -> bool: diff --git a/tests/test_fastapi_websocket.py b/tests/test_fastapi_websocket.py index 44a1ff61a..d1fa435d9 100644 --- a/tests/test_fastapi_websocket.py +++ b/tests/test_fastapi_websocket.py @@ -4,10 +4,17 @@ # SPDX-License-Identifier: BSD 2-Clause License # +import asyncio import unittest -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, PropertyMock -from pipecat.transports.websocket.fastapi import _WebSocketMessageIterator +from starlette.websockets import WebSocketState + +from pipecat.transports.websocket.fastapi import ( + FastAPIWebsocketCallbacks, + FastAPIWebsocketClient, + _WebSocketMessageIterator, +) class TestWebSocketMessageIterator(unittest.IsolatedAsyncioTestCase): @@ -66,5 +73,134 @@ class TestWebSocketMessageIterator(unittest.IsolatedAsyncioTestCase): self.assertEqual(len(messages), 0) +class TestSendDisconnectRace(unittest.IsolatedAsyncioTestCase): + """Tests for the race condition in issue #3912. + + When the remote side disconnects while send() is in flight, send() should + not set _closing = True, because that flag means "we initiated the close." + Setting it from send() prevents the receive loop from firing + on_client_disconnected, which can cause the pipeline to hang. + """ + + def _make_client(self, mock_ws): + callbacks = FastAPIWebsocketCallbacks( + on_client_connected=AsyncMock(), + on_client_disconnected=AsyncMock(), + on_session_timeout=AsyncMock(), + ) + client = FastAPIWebsocketClient(mock_ws, callbacks) + return client, callbacks + + async def test_send_disconnect_does_not_set_closing(self): + """send() should not set _closing when the remote side disconnects.""" + mock_ws = AsyncMock() + type(mock_ws).client_state = PropertyMock(return_value=WebSocketState.CONNECTED) + type(mock_ws).application_state = PropertyMock(return_value=WebSocketState.DISCONNECTED) + mock_ws.send_bytes.side_effect = Exception("connection closed") + + client, _ = self._make_client(mock_ws) + + await client.send(b"audio data") + + self.assertFalse(client.is_closing) + + async def test_send_suppressed_after_disconnect(self): + """After a failed send, _can_send() returns False via application_state. + + Simulates real Starlette behavior: application_state starts CONNECTED, + transitions to DISCONNECTED when send_bytes raises (Starlette does this + internally on OSError before re-raising as WebSocketDisconnect). + """ + mock_ws = AsyncMock() + type(mock_ws).client_state = PropertyMock(return_value=WebSocketState.CONNECTED) + + # application_state transitions from CONNECTED → DISCONNECTED on send failure + app_state = {"state": WebSocketState.CONNECTED} + type(mock_ws).application_state = PropertyMock(side_effect=lambda: app_state["state"]) + + def fail_and_transition(data): + app_state["state"] = WebSocketState.DISCONNECTED + raise Exception("connection closed") + + mock_ws.send_bytes.side_effect = fail_and_transition + + client, _ = self._make_client(mock_ws) + + # First send: _can_send() passes (app_state CONNECTED), send_bytes raises, + # Starlette sets app_state to DISCONNECTED + await client.send(b"audio data") + # Second send: _can_send() returns False (app_state now DISCONNECTED) + await client.send(b"more audio") + + # send_bytes was only called once (the first attempt) + mock_ws.send_bytes.assert_called_once() + + async def test_disconnect_callback_fires_when_send_races_receive(self): + """Regression test for issue #3912. + + The receive loop is blocked waiting for the next message. Meanwhile, + send() is called and hits an exception because the remote side closed. + Then the receive loop unblocks and sees the disconnect. + + on_client_disconnected must still fire, because the remote side + initiated the close — not us. + """ + send_done = asyncio.Event() + + mock_ws = AsyncMock() + type(mock_ws).client_state = PropertyMock(return_value=WebSocketState.CONNECTED) + type(mock_ws).application_state = PropertyMock(return_value=WebSocketState.DISCONNECTED) + mock_ws.send_bytes.side_effect = Exception("connection closed") + + # receive() blocks until send has completed, then returns disconnect. + # This enforces the exact ordering that causes the bug. + async def mock_receive(): + await send_done.wait() + return {"type": "websocket.disconnect"} + + mock_ws.receive = mock_receive + + client, callbacks = self._make_client(mock_ws) + + # Simulate the _receive_messages logic from FastAPIWebsocketInputTransport + async def receive_loop(): + try: + async for _ in _WebSocketMessageIterator(mock_ws): + pass + except Exception: + pass + if not client.is_closing: + await client.trigger_client_disconnected() + + recv_task = asyncio.create_task(receive_loop()) + + # Let the receive loop start and block on receive() + await asyncio.sleep(0) + + # send() races — hits exception but does NOT set _closing + await client.send(b"audio data") + self.assertFalse(client.is_closing) + + # Unblock the receive loop — it sees the disconnect + send_done.set() + await recv_task + + # The callback fires because _closing was not poisoned by send() + callbacks.on_client_disconnected.assert_called_once() + + async def test_send_text_disconnect_does_not_set_closing(self): + """Same as test_send_disconnect_does_not_set_closing but with text data.""" + mock_ws = AsyncMock() + type(mock_ws).client_state = PropertyMock(return_value=WebSocketState.CONNECTED) + type(mock_ws).application_state = PropertyMock(return_value=WebSocketState.DISCONNECTED) + mock_ws.send_text.side_effect = Exception("connection closed") + + client, _ = self._make_client(mock_ws) + + await client.send("text data") + + self.assertFalse(client.is_closing) + + if __name__ == "__main__": unittest.main()