Fix FastAPI WebSocket disconnect race condition causing pipeline hang

When the remote side disconnects while send() is in flight, send() was
setting _closing=True. This prevented the receive loop from firing
on_client_disconnected, causing the pipeline to hang waiting for a
disconnect signal that never came.

The fix removes _closing from send() (that flag means we initiated the
close) and instead checks Starlette application_state in _can_send()
to suppress subsequent sends after a failure.

Fixes #3912
This commit is contained in:
Mark Backman
2026-03-28 00:01:13 -04:00
parent b33df03724
commit 47e53890e3
2 changed files with 144 additions and 12 deletions

View File

@@ -150,17 +150,9 @@ class FastAPIWebsocketClient:
else:
await self._websocket.send_text(data)
except Exception as e:
logger.error(
logger.warning(
f"{self} exception sending data: {e.__class__.__name__} ({e}), application_state: {self._websocket.application_state}"
)
# For some reason the websocket is disconnected, and we are not able to send data
# So let's properly handle it and disconnect the transport if it is not already disconnecting
if (
self._websocket.application_state == WebSocketState.DISCONNECTED
and not self.is_closing
):
logger.warning("Closing already disconnected websocket!")
self._closing = True
async def disconnect(self):
"""Disconnect the WebSocket client."""
@@ -189,7 +181,11 @@ class FastAPIWebsocketClient:
def _can_send(self):
"""Check if data can be sent through the WebSocket."""
return self.is_connected and not self.is_closing
return (
self.is_connected
and not self.is_closing
and self._websocket.application_state != WebSocketState.DISCONNECTED
)
@property
def is_connected(self) -> bool:

View File

@@ -4,10 +4,17 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import unittest
from unittest.mock import AsyncMock
from unittest.mock import AsyncMock, PropertyMock
from pipecat.transports.websocket.fastapi import _WebSocketMessageIterator
from starlette.websockets import WebSocketState
from pipecat.transports.websocket.fastapi import (
FastAPIWebsocketCallbacks,
FastAPIWebsocketClient,
_WebSocketMessageIterator,
)
class TestWebSocketMessageIterator(unittest.IsolatedAsyncioTestCase):
@@ -66,5 +73,134 @@ class TestWebSocketMessageIterator(unittest.IsolatedAsyncioTestCase):
self.assertEqual(len(messages), 0)
class TestSendDisconnectRace(unittest.IsolatedAsyncioTestCase):
"""Tests for the race condition in issue #3912.
When the remote side disconnects while send() is in flight, send() should
not set _closing = True, because that flag means "we initiated the close."
Setting it from send() prevents the receive loop from firing
on_client_disconnected, which can cause the pipeline to hang.
"""
def _make_client(self, mock_ws):
callbacks = FastAPIWebsocketCallbacks(
on_client_connected=AsyncMock(),
on_client_disconnected=AsyncMock(),
on_session_timeout=AsyncMock(),
)
client = FastAPIWebsocketClient(mock_ws, callbacks)
return client, callbacks
async def test_send_disconnect_does_not_set_closing(self):
"""send() should not set _closing when the remote side disconnects."""
mock_ws = AsyncMock()
type(mock_ws).client_state = PropertyMock(return_value=WebSocketState.CONNECTED)
type(mock_ws).application_state = PropertyMock(return_value=WebSocketState.DISCONNECTED)
mock_ws.send_bytes.side_effect = Exception("connection closed")
client, _ = self._make_client(mock_ws)
await client.send(b"audio data")
self.assertFalse(client.is_closing)
async def test_send_suppressed_after_disconnect(self):
"""After a failed send, _can_send() returns False via application_state.
Simulates real Starlette behavior: application_state starts CONNECTED,
transitions to DISCONNECTED when send_bytes raises (Starlette does this
internally on OSError before re-raising as WebSocketDisconnect).
"""
mock_ws = AsyncMock()
type(mock_ws).client_state = PropertyMock(return_value=WebSocketState.CONNECTED)
# application_state transitions from CONNECTED → DISCONNECTED on send failure
app_state = {"state": WebSocketState.CONNECTED}
type(mock_ws).application_state = PropertyMock(side_effect=lambda: app_state["state"])
def fail_and_transition(data):
app_state["state"] = WebSocketState.DISCONNECTED
raise Exception("connection closed")
mock_ws.send_bytes.side_effect = fail_and_transition
client, _ = self._make_client(mock_ws)
# First send: _can_send() passes (app_state CONNECTED), send_bytes raises,
# Starlette sets app_state to DISCONNECTED
await client.send(b"audio data")
# Second send: _can_send() returns False (app_state now DISCONNECTED)
await client.send(b"more audio")
# send_bytes was only called once (the first attempt)
mock_ws.send_bytes.assert_called_once()
async def test_disconnect_callback_fires_when_send_races_receive(self):
"""Regression test for issue #3912.
The receive loop is blocked waiting for the next message. Meanwhile,
send() is called and hits an exception because the remote side closed.
Then the receive loop unblocks and sees the disconnect.
on_client_disconnected must still fire, because the remote side
initiated the close — not us.
"""
send_done = asyncio.Event()
mock_ws = AsyncMock()
type(mock_ws).client_state = PropertyMock(return_value=WebSocketState.CONNECTED)
type(mock_ws).application_state = PropertyMock(return_value=WebSocketState.DISCONNECTED)
mock_ws.send_bytes.side_effect = Exception("connection closed")
# receive() blocks until send has completed, then returns disconnect.
# This enforces the exact ordering that causes the bug.
async def mock_receive():
await send_done.wait()
return {"type": "websocket.disconnect"}
mock_ws.receive = mock_receive
client, callbacks = self._make_client(mock_ws)
# Simulate the _receive_messages logic from FastAPIWebsocketInputTransport
async def receive_loop():
try:
async for _ in _WebSocketMessageIterator(mock_ws):
pass
except Exception:
pass
if not client.is_closing:
await client.trigger_client_disconnected()
recv_task = asyncio.create_task(receive_loop())
# Let the receive loop start and block on receive()
await asyncio.sleep(0)
# send() races — hits exception but does NOT set _closing
await client.send(b"audio data")
self.assertFalse(client.is_closing)
# Unblock the receive loop — it sees the disconnect
send_done.set()
await recv_task
# The callback fires because _closing was not poisoned by send()
callbacks.on_client_disconnected.assert_called_once()
async def test_send_text_disconnect_does_not_set_closing(self):
"""Same as test_send_disconnect_does_not_set_closing but with text data."""
mock_ws = AsyncMock()
type(mock_ws).client_state = PropertyMock(return_value=WebSocketState.CONNECTED)
type(mock_ws).application_state = PropertyMock(return_value=WebSocketState.DISCONNECTED)
mock_ws.send_text.side_effect = Exception("connection closed")
client, _ = self._make_client(mock_ws)
await client.send("text data")
self.assertFalse(client.is_closing)
if __name__ == "__main__":
unittest.main()