From 7c7b4c52af63339f2be4bf0bf8c267a8bff3149f Mon Sep 17 00:00:00 2001 From: Filipi Fuchter Date: Fri, 21 Feb 2025 09:11:58 -0300 Subject: [PATCH] Fixed an issue where EndTaskFrame was not triggering on_client_disconnected or closing the WebSocket in FastAPI. --- CHANGELOG.md | 2 + .../transports/network/fastapi_websocket.py | 120 +++++++++++++----- 2 files changed, 89 insertions(+), 33 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 79c663e7f..6c9e2292f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,8 @@ stt = DeepgramSTTService(..., live_options=LiveOptions(model="nova-2-general")) ### Fixed +- Fixed an issue where `EndTaskFrame` was not triggering `on_client_disconnected` or closing the WebSocket in FastAPI. + - Fixed a context aggregator issue that would not append the LLM text response to the context if a function call happened in the same LLM turn. diff --git a/src/pipecat/transports/network/fastapi_websocket.py b/src/pipecat/transports/network/fastapi_websocket.py index 9c4c170f2..565d1906c 100644 --- a/src/pipecat/transports/network/fastapi_websocket.py +++ b/src/pipecat/transports/network/fastapi_websocket.py @@ -55,45 +55,89 @@ class FastAPIWebsocketCallbacks(BaseModel): on_session_timeout: Callable[[WebSocket], Awaitable[None]] +class FastAPIWebsocketClient: + def __init__(self, websocket: WebSocket, is_binary: bool, callbacks: FastAPIWebsocketCallbacks): + self._websocket = websocket + self._closing = False + self._is_binary = is_binary + self._callbacks = callbacks + + def receive(self) -> typing.AsyncIterator[bytes | str]: + return self._websocket.iter_bytes() if self._is_binary else self._websocket.iter_text() + + async def send(self, data: str | bytes): + if self._can_send(): + if self._is_binary: + await self._websocket.send_bytes(data) + else: + await self._websocket.send_text(data) + + async def disconnect(self): + if self.is_connected and not self.is_closing: + self._closing = True + await self._websocket.close() + await self.trigger_client_disconnected() + + async def trigger_client_disconnected(self): + await self._callbacks.on_client_disconnected(self._websocket) + + async def trigger_client_connected(self): + await self._callbacks.on_client_connected(self._websocket) + + async def trigger_client_timout(self): + await self._callbacks.on_session_timeout(self._websocket) + + def _can_send(self): + return self.is_connected and not self.is_closing + + @property + def is_connected(self) -> bool: + return self._websocket.client_state == WebSocketState.CONNECTED + + @property + def is_closing(self) -> bool: + return self._closing + + class FastAPIWebsocketInputTransport(BaseInputTransport): def __init__( self, - websocket: WebSocket, + client: FastAPIWebsocketClient, params: FastAPIWebsocketParams, - callbacks: FastAPIWebsocketCallbacks, **kwargs, ): super().__init__(params, **kwargs) - - self._websocket = websocket + self._client = client self._params = params - self._callbacks = callbacks + self._receive_task = None + self._monitor_websocket_task = None async def start(self, frame: StartFrame): await super().start(frame) await self._params.serializer.setup(frame) if self._params.session_timeout: self._monitor_websocket_task = self.create_task(self._monitor_websocket()) - await self._callbacks.on_client_connected(self._websocket) + await self._client.trigger_client_connected() self._receive_task = self.create_task(self._receive_messages()) + async def _stop_tasks(self): + if self._monitor_websocket_task: + await self.cancel_task(self._monitor_websocket_task) + await self.cancel_task(self._receive_task) + async def stop(self, frame: EndFrame): await super().stop(frame) - await self.cancel_task(self._receive_task) + await self._stop_tasks() + await self._client.disconnect() async def cancel(self, frame: CancelFrame): await super().cancel(frame) - await self.cancel_task(self._receive_task) - - def _iter_data(self) -> typing.AsyncIterator[bytes | str]: - if self._params.serializer.type == FrameSerializerType.BINARY: - return self._websocket.iter_bytes() - else: - return self._websocket.iter_text() + await self._stop_tasks() + await self._client.disconnect() async def _receive_messages(self): try: - async for message in self._iter_data(): + async for message in self._client.receive(): frame = await self._params.serializer.deserialize(message) if not frame: @@ -106,19 +150,23 @@ class FastAPIWebsocketInputTransport(BaseInputTransport): except Exception as e: logger.error(f"{self} exception receiving data: {e.__class__.__name__} ({e})") - await self._callbacks.on_client_disconnected(self._websocket) + await self._client.trigger_client_disconnected() async def _monitor_websocket(self): """Wait for self._params.session_timeout seconds, if the websocket is still open, trigger timeout event.""" await asyncio.sleep(self._params.session_timeout) - await self._callbacks.on_session_timeout(self._websocket) + await self._client.trigger_client_timout() class FastAPIWebsocketOutputTransport(BaseOutputTransport): - def __init__(self, websocket: WebSocket, params: FastAPIWebsocketParams, **kwargs): + def __init__( + self, + client: FastAPIWebsocketClient, + params: FastAPIWebsocketParams, + **kwargs, + ): super().__init__(params, **kwargs) - - self._websocket = websocket + self._client = client self._params = params # write_raw_audio_frames() is called quickly, as soon as we get audio @@ -134,6 +182,14 @@ class FastAPIWebsocketOutputTransport(BaseOutputTransport): await self._params.serializer.setup(frame) self._send_interval = (self._audio_chunk_size / self.sample_rate) / 2 + async def stop(self, frame: EndFrame): + await super().stop(frame) + await self._client.disconnect() + + async def cancel(self, frame: CancelFrame): + await super().cancel(frame) + await self._client.disconnect() + async def process_frame(self, frame: Frame, direction: FrameDirection): await super().process_frame(frame, direction) @@ -145,7 +201,10 @@ class FastAPIWebsocketOutputTransport(BaseOutputTransport): await self._write_frame(frame) async def write_raw_audio_frames(self, frames: bytes): - if self._websocket.client_state != WebSocketState.CONNECTED: + if self._client.is_closing: + return + + if not self._client.is_connected: # Simulate audio playback with a sleep. await self._write_audio_sleep() return @@ -172,25 +231,17 @@ class FastAPIWebsocketOutputTransport(BaseOutputTransport): await self._write_frame(frame) - self._websocket_audio_buffer = bytes() - # Simulate audio playback with a sleep. await self._write_audio_sleep() async def _write_frame(self, frame: Frame): try: payload = await self._params.serializer.serialize(frame) - if payload and self._websocket.client_state == WebSocketState.CONNECTED: - await self._send_data(payload) + if payload: + await self._client.send(payload) except Exception as e: logger.error(f"{self} exception sending data: {e.__class__.__name__} ({e})") - def _send_data(self, data: str | bytes): - if self._params.serializer.type == FrameSerializerType.BINARY: - return self._websocket.send_bytes(data) - else: - return self._websocket.send_text(data) - async def _write_audio_sleep(self): # Simulate a clock. current_time = time.monotonic() @@ -219,11 +270,14 @@ class FastAPIWebsocketTransport(BaseTransport): on_session_timeout=self._on_session_timeout, ) + is_binary = self._params.serializer.type == FrameSerializerType.BINARY + self._client = FastAPIWebsocketClient(websocket, is_binary, self._callbacks) + self._input = FastAPIWebsocketInputTransport( - websocket, self._params, self._callbacks, name=self._input_name + self._client, self._params, name=self._input_name ) self._output = FastAPIWebsocketOutputTransport( - websocket, self._params, name=self._output_name + self._client, self._params, name=self._output_name ) # Register supported handlers. The user will only be able to register