Merge pull request #1257 from pipecat-ai/fastapi_disconnect_issue

Fixed an issue where FastAPI was not triggering on_client_disconnected.
This commit is contained in:
Filipi da Silva Fuchter
2025-02-21 09:15:15 -03:00
committed by GitHub
2 changed files with 89 additions and 33 deletions

View File

@@ -32,6 +32,8 @@ stt = DeepgramSTTService(..., live_options=LiveOptions(model="nova-2-general"))
### Fixed
- Fixed an issue where `EndTaskFrame` was not triggering `on_client_disconnected` or closing the WebSocket in FastAPI.
- Fixed a context aggregator issue that would not append the LLM text response
to the context if a function call happened in the same LLM turn.

View File

@@ -55,45 +55,89 @@ class FastAPIWebsocketCallbacks(BaseModel):
on_session_timeout: Callable[[WebSocket], Awaitable[None]]
class FastAPIWebsocketClient:
def __init__(self, websocket: WebSocket, is_binary: bool, callbacks: FastAPIWebsocketCallbacks):
self._websocket = websocket
self._closing = False
self._is_binary = is_binary
self._callbacks = callbacks
def receive(self) -> typing.AsyncIterator[bytes | str]:
return self._websocket.iter_bytes() if self._is_binary else self._websocket.iter_text()
async def send(self, data: str | bytes):
if self._can_send():
if self._is_binary:
await self._websocket.send_bytes(data)
else:
await self._websocket.send_text(data)
async def disconnect(self):
if self.is_connected and not self.is_closing:
self._closing = True
await self._websocket.close()
await self.trigger_client_disconnected()
async def trigger_client_disconnected(self):
await self._callbacks.on_client_disconnected(self._websocket)
async def trigger_client_connected(self):
await self._callbacks.on_client_connected(self._websocket)
async def trigger_client_timout(self):
await self._callbacks.on_session_timeout(self._websocket)
def _can_send(self):
return self.is_connected and not self.is_closing
@property
def is_connected(self) -> bool:
return self._websocket.client_state == WebSocketState.CONNECTED
@property
def is_closing(self) -> bool:
return self._closing
class FastAPIWebsocketInputTransport(BaseInputTransport):
def __init__(
self,
websocket: WebSocket,
client: FastAPIWebsocketClient,
params: FastAPIWebsocketParams,
callbacks: FastAPIWebsocketCallbacks,
**kwargs,
):
super().__init__(params, **kwargs)
self._websocket = websocket
self._client = client
self._params = params
self._callbacks = callbacks
self._receive_task = None
self._monitor_websocket_task = None
async def start(self, frame: StartFrame):
await super().start(frame)
await self._params.serializer.setup(frame)
if self._params.session_timeout:
self._monitor_websocket_task = self.create_task(self._monitor_websocket())
await self._callbacks.on_client_connected(self._websocket)
await self._client.trigger_client_connected()
self._receive_task = self.create_task(self._receive_messages())
async def _stop_tasks(self):
if self._monitor_websocket_task:
await self.cancel_task(self._monitor_websocket_task)
await self.cancel_task(self._receive_task)
async def stop(self, frame: EndFrame):
await super().stop(frame)
await self.cancel_task(self._receive_task)
await self._stop_tasks()
await self._client.disconnect()
async def cancel(self, frame: CancelFrame):
await super().cancel(frame)
await self.cancel_task(self._receive_task)
def _iter_data(self) -> typing.AsyncIterator[bytes | str]:
if self._params.serializer.type == FrameSerializerType.BINARY:
return self._websocket.iter_bytes()
else:
return self._websocket.iter_text()
await self._stop_tasks()
await self._client.disconnect()
async def _receive_messages(self):
try:
async for message in self._iter_data():
async for message in self._client.receive():
frame = await self._params.serializer.deserialize(message)
if not frame:
@@ -106,19 +150,23 @@ class FastAPIWebsocketInputTransport(BaseInputTransport):
except Exception as e:
logger.error(f"{self} exception receiving data: {e.__class__.__name__} ({e})")
await self._callbacks.on_client_disconnected(self._websocket)
await self._client.trigger_client_disconnected()
async def _monitor_websocket(self):
"""Wait for self._params.session_timeout seconds, if the websocket is still open, trigger timeout event."""
await asyncio.sleep(self._params.session_timeout)
await self._callbacks.on_session_timeout(self._websocket)
await self._client.trigger_client_timout()
class FastAPIWebsocketOutputTransport(BaseOutputTransport):
def __init__(self, websocket: WebSocket, params: FastAPIWebsocketParams, **kwargs):
def __init__(
self,
client: FastAPIWebsocketClient,
params: FastAPIWebsocketParams,
**kwargs,
):
super().__init__(params, **kwargs)
self._websocket = websocket
self._client = client
self._params = params
# write_raw_audio_frames() is called quickly, as soon as we get audio
@@ -134,6 +182,14 @@ class FastAPIWebsocketOutputTransport(BaseOutputTransport):
await self._params.serializer.setup(frame)
self._send_interval = (self._audio_chunk_size / self.sample_rate) / 2
async def stop(self, frame: EndFrame):
await super().stop(frame)
await self._client.disconnect()
async def cancel(self, frame: CancelFrame):
await super().cancel(frame)
await self._client.disconnect()
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
@@ -145,7 +201,10 @@ class FastAPIWebsocketOutputTransport(BaseOutputTransport):
await self._write_frame(frame)
async def write_raw_audio_frames(self, frames: bytes):
if self._websocket.client_state != WebSocketState.CONNECTED:
if self._client.is_closing:
return
if not self._client.is_connected:
# Simulate audio playback with a sleep.
await self._write_audio_sleep()
return
@@ -172,25 +231,17 @@ class FastAPIWebsocketOutputTransport(BaseOutputTransport):
await self._write_frame(frame)
self._websocket_audio_buffer = bytes()
# Simulate audio playback with a sleep.
await self._write_audio_sleep()
async def _write_frame(self, frame: Frame):
try:
payload = await self._params.serializer.serialize(frame)
if payload and self._websocket.client_state == WebSocketState.CONNECTED:
await self._send_data(payload)
if payload:
await self._client.send(payload)
except Exception as e:
logger.error(f"{self} exception sending data: {e.__class__.__name__} ({e})")
def _send_data(self, data: str | bytes):
if self._params.serializer.type == FrameSerializerType.BINARY:
return self._websocket.send_bytes(data)
else:
return self._websocket.send_text(data)
async def _write_audio_sleep(self):
# Simulate a clock.
current_time = time.monotonic()
@@ -219,11 +270,14 @@ class FastAPIWebsocketTransport(BaseTransport):
on_session_timeout=self._on_session_timeout,
)
is_binary = self._params.serializer.type == FrameSerializerType.BINARY
self._client = FastAPIWebsocketClient(websocket, is_binary, self._callbacks)
self._input = FastAPIWebsocketInputTransport(
websocket, self._params, self._callbacks, name=self._input_name
self._client, self._params, name=self._input_name
)
self._output = FastAPIWebsocketOutputTransport(
websocket, self._params, name=self._output_name
self._client, self._params, name=self._output_name
)
# Register supported handlers. The user will only be able to register