Merge pull request #1257 from pipecat-ai/fastapi_disconnect_issue
Fixed an issue where FastAPI was not triggering on_client_disconnected.
This commit is contained in:
@@ -32,6 +32,8 @@ stt = DeepgramSTTService(..., live_options=LiveOptions(model="nova-2-general"))
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed an issue where `EndTaskFrame` was not triggering `on_client_disconnected` or closing the WebSocket in FastAPI.
|
||||
|
||||
- Fixed a context aggregator issue that would not append the LLM text response
|
||||
to the context if a function call happened in the same LLM turn.
|
||||
|
||||
|
||||
@@ -55,45 +55,89 @@ class FastAPIWebsocketCallbacks(BaseModel):
|
||||
on_session_timeout: Callable[[WebSocket], Awaitable[None]]
|
||||
|
||||
|
||||
class FastAPIWebsocketClient:
|
||||
def __init__(self, websocket: WebSocket, is_binary: bool, callbacks: FastAPIWebsocketCallbacks):
|
||||
self._websocket = websocket
|
||||
self._closing = False
|
||||
self._is_binary = is_binary
|
||||
self._callbacks = callbacks
|
||||
|
||||
def receive(self) -> typing.AsyncIterator[bytes | str]:
|
||||
return self._websocket.iter_bytes() if self._is_binary else self._websocket.iter_text()
|
||||
|
||||
async def send(self, data: str | bytes):
|
||||
if self._can_send():
|
||||
if self._is_binary:
|
||||
await self._websocket.send_bytes(data)
|
||||
else:
|
||||
await self._websocket.send_text(data)
|
||||
|
||||
async def disconnect(self):
|
||||
if self.is_connected and not self.is_closing:
|
||||
self._closing = True
|
||||
await self._websocket.close()
|
||||
await self.trigger_client_disconnected()
|
||||
|
||||
async def trigger_client_disconnected(self):
|
||||
await self._callbacks.on_client_disconnected(self._websocket)
|
||||
|
||||
async def trigger_client_connected(self):
|
||||
await self._callbacks.on_client_connected(self._websocket)
|
||||
|
||||
async def trigger_client_timout(self):
|
||||
await self._callbacks.on_session_timeout(self._websocket)
|
||||
|
||||
def _can_send(self):
|
||||
return self.is_connected and not self.is_closing
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self._websocket.client_state == WebSocketState.CONNECTED
|
||||
|
||||
@property
|
||||
def is_closing(self) -> bool:
|
||||
return self._closing
|
||||
|
||||
|
||||
class FastAPIWebsocketInputTransport(BaseInputTransport):
|
||||
def __init__(
|
||||
self,
|
||||
websocket: WebSocket,
|
||||
client: FastAPIWebsocketClient,
|
||||
params: FastAPIWebsocketParams,
|
||||
callbacks: FastAPIWebsocketCallbacks,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(params, **kwargs)
|
||||
|
||||
self._websocket = websocket
|
||||
self._client = client
|
||||
self._params = params
|
||||
self._callbacks = callbacks
|
||||
self._receive_task = None
|
||||
self._monitor_websocket_task = None
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
await self._params.serializer.setup(frame)
|
||||
if self._params.session_timeout:
|
||||
self._monitor_websocket_task = self.create_task(self._monitor_websocket())
|
||||
await self._callbacks.on_client_connected(self._websocket)
|
||||
await self._client.trigger_client_connected()
|
||||
self._receive_task = self.create_task(self._receive_messages())
|
||||
|
||||
async def _stop_tasks(self):
|
||||
if self._monitor_websocket_task:
|
||||
await self.cancel_task(self._monitor_websocket_task)
|
||||
await self.cancel_task(self._receive_task)
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
await self.cancel_task(self._receive_task)
|
||||
await self._stop_tasks()
|
||||
await self._client.disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
await self.cancel_task(self._receive_task)
|
||||
|
||||
def _iter_data(self) -> typing.AsyncIterator[bytes | str]:
|
||||
if self._params.serializer.type == FrameSerializerType.BINARY:
|
||||
return self._websocket.iter_bytes()
|
||||
else:
|
||||
return self._websocket.iter_text()
|
||||
await self._stop_tasks()
|
||||
await self._client.disconnect()
|
||||
|
||||
async def _receive_messages(self):
|
||||
try:
|
||||
async for message in self._iter_data():
|
||||
async for message in self._client.receive():
|
||||
frame = await self._params.serializer.deserialize(message)
|
||||
|
||||
if not frame:
|
||||
@@ -106,19 +150,23 @@ class FastAPIWebsocketInputTransport(BaseInputTransport):
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception receiving data: {e.__class__.__name__} ({e})")
|
||||
|
||||
await self._callbacks.on_client_disconnected(self._websocket)
|
||||
await self._client.trigger_client_disconnected()
|
||||
|
||||
async def _monitor_websocket(self):
|
||||
"""Wait for self._params.session_timeout seconds, if the websocket is still open, trigger timeout event."""
|
||||
await asyncio.sleep(self._params.session_timeout)
|
||||
await self._callbacks.on_session_timeout(self._websocket)
|
||||
await self._client.trigger_client_timout()
|
||||
|
||||
|
||||
class FastAPIWebsocketOutputTransport(BaseOutputTransport):
|
||||
def __init__(self, websocket: WebSocket, params: FastAPIWebsocketParams, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
client: FastAPIWebsocketClient,
|
||||
params: FastAPIWebsocketParams,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(params, **kwargs)
|
||||
|
||||
self._websocket = websocket
|
||||
self._client = client
|
||||
self._params = params
|
||||
|
||||
# write_raw_audio_frames() is called quickly, as soon as we get audio
|
||||
@@ -134,6 +182,14 @@ class FastAPIWebsocketOutputTransport(BaseOutputTransport):
|
||||
await self._params.serializer.setup(frame)
|
||||
self._send_interval = (self._audio_chunk_size / self.sample_rate) / 2
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
await self._client.disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
await self._client.disconnect()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
@@ -145,7 +201,10 @@ class FastAPIWebsocketOutputTransport(BaseOutputTransport):
|
||||
await self._write_frame(frame)
|
||||
|
||||
async def write_raw_audio_frames(self, frames: bytes):
|
||||
if self._websocket.client_state != WebSocketState.CONNECTED:
|
||||
if self._client.is_closing:
|
||||
return
|
||||
|
||||
if not self._client.is_connected:
|
||||
# Simulate audio playback with a sleep.
|
||||
await self._write_audio_sleep()
|
||||
return
|
||||
@@ -172,25 +231,17 @@ class FastAPIWebsocketOutputTransport(BaseOutputTransport):
|
||||
|
||||
await self._write_frame(frame)
|
||||
|
||||
self._websocket_audio_buffer = bytes()
|
||||
|
||||
# Simulate audio playback with a sleep.
|
||||
await self._write_audio_sleep()
|
||||
|
||||
async def _write_frame(self, frame: Frame):
|
||||
try:
|
||||
payload = await self._params.serializer.serialize(frame)
|
||||
if payload and self._websocket.client_state == WebSocketState.CONNECTED:
|
||||
await self._send_data(payload)
|
||||
if payload:
|
||||
await self._client.send(payload)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception sending data: {e.__class__.__name__} ({e})")
|
||||
|
||||
def _send_data(self, data: str | bytes):
|
||||
if self._params.serializer.type == FrameSerializerType.BINARY:
|
||||
return self._websocket.send_bytes(data)
|
||||
else:
|
||||
return self._websocket.send_text(data)
|
||||
|
||||
async def _write_audio_sleep(self):
|
||||
# Simulate a clock.
|
||||
current_time = time.monotonic()
|
||||
@@ -219,11 +270,14 @@ class FastAPIWebsocketTransport(BaseTransport):
|
||||
on_session_timeout=self._on_session_timeout,
|
||||
)
|
||||
|
||||
is_binary = self._params.serializer.type == FrameSerializerType.BINARY
|
||||
self._client = FastAPIWebsocketClient(websocket, is_binary, self._callbacks)
|
||||
|
||||
self._input = FastAPIWebsocketInputTransport(
|
||||
websocket, self._params, self._callbacks, name=self._input_name
|
||||
self._client, self._params, name=self._input_name
|
||||
)
|
||||
self._output = FastAPIWebsocketOutputTransport(
|
||||
websocket, self._params, name=self._output_name
|
||||
self._client, self._params, name=self._output_name
|
||||
)
|
||||
|
||||
# Register supported handlers. The user will only be able to register
|
||||
|
||||
Reference in New Issue
Block a user