When the remote side disconnects while `send()` is in flight, `send()` was setting `_closing = True`. This prevented the receive loop from firing `on_client_disconnected`, causing the pipeline to hang waiting for a disconnect signal that never came. The fix stops `send()` from setting `_closing` (that flag means *we* initiated the close) and instead has `_can_send()` check Starlette's `application_state`, so that subsequent sends are suppressed after a failure. Fixes #3912.
207 lines
7.7 KiB
Python
207 lines
7.7 KiB
Python
#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
|
|
import asyncio
|
|
import unittest
|
|
from unittest.mock import AsyncMock, PropertyMock
|
|
|
|
from starlette.websockets import WebSocketState
|
|
|
|
from pipecat.transports.websocket.fastapi import (
|
|
FastAPIWebsocketCallbacks,
|
|
FastAPIWebsocketClient,
|
|
_WebSocketMessageIterator,
|
|
)
|
|
|
|
|
|
class TestWebSocketMessageIterator(unittest.IsolatedAsyncioTestCase):
    """Unit tests for ``_WebSocketMessageIterator``.

    Each test scripts the sequence of raw ASGI events that a mocked
    ``receive()`` returns and checks which payloads the iterator yields before
    it stops at ``websocket.disconnect``.
    """

    @staticmethod
    async def _drain(events):
        """Feed ``events`` to a mocked ``receive()`` and collect every yielded message."""
        websocket = AsyncMock()
        websocket.receive.side_effect = events
        return [message async for message in _WebSocketMessageIterator(websocket)]

    async def test_yields_binary_message(self):
        received = await self._drain(
            [
                {"type": "websocket.receive", "bytes": b"binary data", "text": None},
                {"type": "websocket.disconnect"},
            ]
        )

        self.assertEqual(received, [b"binary data"])

    async def test_yields_text_message(self):
        received = await self._drain(
            [
                {"type": "websocket.receive", "bytes": None, "text": "text data"},
                {"type": "websocket.disconnect"},
            ]
        )

        self.assertEqual(received, ["text data"])

    async def test_yields_mixed_messages(self):
        received = await self._drain(
            [
                {"type": "websocket.receive", "bytes": b"binary", "text": None},
                {"type": "websocket.receive", "bytes": None, "text": "text"},
                {"type": "websocket.receive", "bytes": b"more binary", "text": None},
                {"type": "websocket.disconnect"},
            ]
        )

        # Order and payload type (bytes vs str) must both be preserved.
        self.assertEqual(received, [b"binary", "text", b"more binary"])

    async def test_stops_on_disconnect(self):
        received = await self._drain([{"type": "websocket.disconnect"}])

        self.assertEqual(received, [])
|
|
|
|
|
|
class TestSendDisconnectRace(unittest.IsolatedAsyncioTestCase):
    """Tests for the race condition in issue #3912.

    When the remote side disconnects while send() is in flight, send() should
    not set _closing = True, because that flag means "we initiated the close."
    Setting it from send() prevents the receive loop from firing
    on_client_disconnected, which can cause the pipeline to hang.

    Every mocked websocket here starts with ``application_state`` CONNECTED so
    that ``_can_send()`` lets the first send through and the failure path in
    send() is really exercised. Starting it DISCONNECTED would make
    ``_can_send()`` short-circuit before the send call, and the tests would
    pass even if send() still set _closing on failure.
    """

    def _make_client(self, mock_ws):
        """Wrap ``mock_ws`` in a client with AsyncMock callbacks; return both."""
        callbacks = FastAPIWebsocketCallbacks(
            on_client_connected=AsyncMock(),
            on_client_disconnected=AsyncMock(),
            on_session_timeout=AsyncMock(),
        )
        client = FastAPIWebsocketClient(mock_ws, callbacks)
        return client, callbacks

    @staticmethod
    def _make_failing_ws(send_method):
        """Return a mock websocket whose ``send_method`` fails like a dropped peer.

        Simulates real Starlette behavior: ``client_state`` stays CONNECTED and
        ``application_state`` starts CONNECTED, then transitions to DISCONNECTED
        inside the failing send call before the exception propagates (Starlette
        does this internally on OSError before re-raising as WebSocketDisconnect).

        Args:
            send_method: Name of the websocket send coroutine to make fail,
                e.g. ``"send_bytes"`` or ``"send_text"``.

        Returns:
            The configured ``AsyncMock`` websocket.
        """
        mock_ws = AsyncMock()
        type(mock_ws).client_state = PropertyMock(return_value=WebSocketState.CONNECTED)

        # Mutable cell so the failing send can change what the property reports.
        app_state = {"state": WebSocketState.CONNECTED}
        type(mock_ws).application_state = PropertyMock(side_effect=lambda: app_state["state"])

        def fail_and_transition(*_args, **_kwargs):
            app_state["state"] = WebSocketState.DISCONNECTED
            raise Exception("connection closed")

        getattr(mock_ws, send_method).side_effect = fail_and_transition
        return mock_ws

    async def test_send_disconnect_does_not_set_closing(self):
        """send() should not set _closing when the remote side disconnects."""
        mock_ws = self._make_failing_ws("send_bytes")
        client, _ = self._make_client(mock_ws)

        await client.send(b"audio data")

        # Guard against a vacuous pass: the failing send must actually have run.
        mock_ws.send_bytes.assert_called_once()
        self.assertFalse(client.is_closing)

    async def test_send_suppressed_after_disconnect(self):
        """After a failed send, _can_send() returns False via application_state.

        The first send gets through (application_state CONNECTED) and fails,
        flipping application_state to DISCONNECTED; the second send must be
        suppressed without touching the websocket.
        """
        mock_ws = self._make_failing_ws("send_bytes")
        client, _ = self._make_client(mock_ws)

        # First send: _can_send() passes (app_state CONNECTED), send_bytes raises,
        # Starlette sets app_state to DISCONNECTED
        await client.send(b"audio data")
        # Second send: _can_send() returns False (app_state now DISCONNECTED)
        await client.send(b"more audio")

        # send_bytes was only called once (the first attempt)
        mock_ws.send_bytes.assert_called_once()
        # Suppressing sends must not be mistaken for a locally initiated close.
        self.assertFalse(client.is_closing)

    async def test_disconnect_callback_fires_when_send_races_receive(self):
        """Regression test for issue #3912.

        The receive loop is blocked waiting for the next message. Meanwhile,
        send() is called and hits an exception because the remote side closed.
        Then the receive loop unblocks and sees the disconnect.

        on_client_disconnected must still fire, because the remote side
        initiated the close — not us.
        """
        send_done = asyncio.Event()

        mock_ws = self._make_failing_ws("send_bytes")

        # receive() blocks until send has completed, then returns disconnect.
        # This enforces the exact ordering that causes the bug.
        async def mock_receive():
            await send_done.wait()
            return {"type": "websocket.disconnect"}

        mock_ws.receive = mock_receive

        client, callbacks = self._make_client(mock_ws)

        # Simulate the _receive_messages logic from FastAPIWebsocketInputTransport
        async def receive_loop():
            try:
                async for _ in _WebSocketMessageIterator(mock_ws):
                    pass
            except Exception:
                # Errors just end the simulated loop; the disconnect check below
                # is what this test is about.
                pass
            if not client.is_closing:
                await client.trigger_client_disconnected()

        recv_task = asyncio.create_task(receive_loop())

        # Let the receive loop start and block on receive()
        await asyncio.sleep(0)

        # send() races — actually hits the exception but does NOT set _closing
        await client.send(b"audio data")
        mock_ws.send_bytes.assert_called_once()
        self.assertFalse(client.is_closing)

        # Unblock the receive loop — it sees the disconnect. Bound the wait so
        # a regression fails this test instead of hanging the whole suite.
        send_done.set()
        await asyncio.wait_for(recv_task, timeout=5.0)

        # The callback fires because _closing was not poisoned by send()
        callbacks.on_client_disconnected.assert_called_once()

    async def test_send_text_disconnect_does_not_set_closing(self):
        """Same as test_send_disconnect_does_not_set_closing but with text data."""
        mock_ws = self._make_failing_ws("send_text")
        client, _ = self._make_client(mock_ws)

        await client.send("text data")

        # Guard against a vacuous pass: the failing send must actually have run.
        mock_ws.send_text.assert_called_once()
        self.assertFalse(client.is_closing)
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
|