Merge pull request #3379 from lukepayyapilli/fix/fastapi-websocket-json-text-handling
Fix FastAPIWebsocketTransport to handle both binary and text messages
This commit is contained in:
2
.github/workflows/coverage.yaml
vendored
2
.github/workflows/coverage.yaml
vendored
@@ -33,7 +33,7 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv sync --group dev --extra anthropic --extra aws --extra google --extra langchain
|
||||
uv sync --group dev --extra anthropic --extra aws --extra google --extra langchain --extra websocket
|
||||
|
||||
- name: Run tests with coverage
|
||||
run: |
|
||||
|
||||
2
.github/workflows/tests.yaml
vendored
2
.github/workflows/tests.yaml
vendored
@@ -37,7 +37,7 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv sync --group dev --extra anthropic --extra aws --extra google --extra langchain
|
||||
uv sync --group dev --extra anthropic --extra aws --extra google --extra langchain --extra websocket
|
||||
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
|
||||
@@ -7,41 +7,18 @@
|
||||
"""Frame serialization interfaces for Pipecat."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
|
||||
from pipecat.frames.frames import Frame, StartFrame
|
||||
|
||||
|
||||
class FrameSerializerType(Enum):
|
||||
"""Enumeration of supported frame serialization formats.
|
||||
|
||||
Parameters:
|
||||
BINARY: Binary serialization format for compact representation.
|
||||
TEXT: Text-based serialization format for human-readable output.
|
||||
"""
|
||||
|
||||
BINARY = "binary"
|
||||
TEXT = "text"
|
||||
|
||||
|
||||
class FrameSerializer(ABC):
|
||||
"""Abstract base class for frame serialization implementations.
|
||||
|
||||
Defines the interface for converting frames to/from serialized formats
|
||||
for transmission or storage. Subclasses must implement serialization
|
||||
type detection and the core serialize/deserialize methods.
|
||||
for transmission or storage. Subclasses must implement the core
|
||||
serialize/deserialize methods.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def type(self) -> FrameSerializerType:
|
||||
"""Get the serialization type supported by this serializer.
|
||||
|
||||
Returns:
|
||||
The FrameSerializerType indicating binary or text format.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def setup(self, frame: StartFrame):
|
||||
"""Initialize the serializer with startup configuration.
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ from pipecat.frames.frames import (
|
||||
OutputTransportMessageUrgentFrame,
|
||||
StartFrame,
|
||||
)
|
||||
from pipecat.serializers.base_serializer import FrameSerializer, FrameSerializerType
|
||||
from pipecat.serializers.base_serializer import FrameSerializer
|
||||
|
||||
|
||||
class ExotelFrameSerializer(FrameSerializer):
|
||||
@@ -70,15 +70,6 @@ class ExotelFrameSerializer(FrameSerializer):
|
||||
self._input_resampler = create_stream_resampler()
|
||||
self._output_resampler = create_stream_resampler()
|
||||
|
||||
@property
|
||||
def type(self) -> FrameSerializerType:
|
||||
"""Gets the serializer type.
|
||||
|
||||
Returns:
|
||||
The serializer type, either TEXT or BINARY.
|
||||
"""
|
||||
return FrameSerializerType.TEXT
|
||||
|
||||
async def setup(self, frame: StartFrame):
|
||||
"""Sets up the serializer with pipeline configuration.
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ from pipecat.frames.frames import (
|
||||
OutputTransportMessageUrgentFrame,
|
||||
StartFrame,
|
||||
)
|
||||
from pipecat.serializers.base_serializer import FrameSerializer, FrameSerializerType
|
||||
from pipecat.serializers.base_serializer import FrameSerializer
|
||||
|
||||
|
||||
class PlivoFrameSerializer(FrameSerializer):
|
||||
@@ -85,15 +85,6 @@ class PlivoFrameSerializer(FrameSerializer):
|
||||
self._output_resampler = create_stream_resampler()
|
||||
self._hangup_attempted = False
|
||||
|
||||
@property
|
||||
def type(self) -> FrameSerializerType:
|
||||
"""Gets the serializer type.
|
||||
|
||||
Returns:
|
||||
The serializer type, either TEXT or BINARY.
|
||||
"""
|
||||
return FrameSerializerType.TEXT
|
||||
|
||||
async def setup(self, frame: StartFrame):
|
||||
"""Sets up the serializer with pipeline configuration.
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ from pipecat.frames.frames import (
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
)
|
||||
from pipecat.serializers.base_serializer import FrameSerializer, FrameSerializerType
|
||||
from pipecat.serializers.base_serializer import FrameSerializer
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@@ -64,15 +64,6 @@ class ProtobufFrameSerializer(FrameSerializer):
|
||||
"""Initialize the Protobuf frame serializer."""
|
||||
pass
|
||||
|
||||
@property
|
||||
def type(self) -> FrameSerializerType:
|
||||
"""Get the serializer type.
|
||||
|
||||
Returns:
|
||||
FrameSerializerType.BINARY indicating binary serialization format.
|
||||
"""
|
||||
return FrameSerializerType.BINARY
|
||||
|
||||
async def serialize(self, frame: Frame) -> str | bytes | None:
|
||||
"""Serialize a frame to Protocol Buffer binary format.
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ from pipecat.frames.frames import (
|
||||
InterruptionFrame,
|
||||
StartFrame,
|
||||
)
|
||||
from pipecat.serializers.base_serializer import FrameSerializer, FrameSerializerType
|
||||
from pipecat.serializers.base_serializer import FrameSerializer
|
||||
|
||||
|
||||
class TelnyxFrameSerializer(FrameSerializer):
|
||||
@@ -97,15 +97,6 @@ class TelnyxFrameSerializer(FrameSerializer):
|
||||
self._output_resampler = create_stream_resampler()
|
||||
self._hangup_attempted = False
|
||||
|
||||
@property
|
||||
def type(self) -> FrameSerializerType:
|
||||
"""Gets the serializer type.
|
||||
|
||||
Returns:
|
||||
The serializer type, either TEXT or BINARY.
|
||||
"""
|
||||
return FrameSerializerType.TEXT
|
||||
|
||||
async def setup(self, frame: StartFrame):
|
||||
"""Sets up the serializer with pipeline configuration.
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ from pipecat.frames.frames import (
|
||||
OutputTransportMessageUrgentFrame,
|
||||
StartFrame,
|
||||
)
|
||||
from pipecat.serializers.base_serializer import FrameSerializer, FrameSerializerType
|
||||
from pipecat.serializers.base_serializer import FrameSerializer
|
||||
|
||||
|
||||
class TwilioFrameSerializer(FrameSerializer):
|
||||
@@ -116,15 +116,6 @@ class TwilioFrameSerializer(FrameSerializer):
|
||||
self._output_resampler = create_stream_resampler()
|
||||
self._hangup_attempted = False
|
||||
|
||||
@property
|
||||
def type(self) -> FrameSerializerType:
|
||||
"""Gets the serializer type.
|
||||
|
||||
Returns:
|
||||
The serializer type, either TEXT or BINARY.
|
||||
"""
|
||||
return FrameSerializerType.TEXT
|
||||
|
||||
async def setup(self, frame: StartFrame):
|
||||
"""Sets up the serializer with pipeline configuration.
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ from pipecat.frames.frames import (
|
||||
StartFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.serializers.base_serializer import FrameSerializer, FrameSerializerType
|
||||
from pipecat.serializers.base_serializer import FrameSerializer
|
||||
from pipecat.transports.base_input import BaseInputTransport
|
||||
from pipecat.transports.base_output import BaseOutputTransport
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
@@ -77,6 +77,26 @@ class FastAPIWebsocketCallbacks(BaseModel):
|
||||
on_session_timeout: Callable[[WebSocket], Awaitable[None]]
|
||||
|
||||
|
||||
class _WebSocketMessageIterator:
|
||||
"""Async iterator for WebSocket messages that yields both binary and text."""
|
||||
|
||||
def __init__(self, websocket: WebSocket):
|
||||
self._websocket = websocket
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> bytes | str:
|
||||
message = await self._websocket.receive()
|
||||
if message["type"] == "websocket.disconnect":
|
||||
raise StopAsyncIteration
|
||||
if "bytes" in message and message["bytes"] is not None:
|
||||
return message["bytes"]
|
||||
if "text" in message and message["text"] is not None:
|
||||
return message["text"]
|
||||
raise StopAsyncIteration
|
||||
|
||||
|
||||
class FastAPIWebsocketClient:
|
||||
"""WebSocket client wrapper for handling connections and message passing.
|
||||
|
||||
@@ -84,17 +104,15 @@ class FastAPIWebsocketClient:
|
||||
with support for both binary and text message types.
|
||||
"""
|
||||
|
||||
def __init__(self, websocket: WebSocket, is_binary: bool, callbacks: FastAPIWebsocketCallbacks):
|
||||
def __init__(self, websocket: WebSocket, callbacks: FastAPIWebsocketCallbacks):
|
||||
"""Initialize the WebSocket client.
|
||||
|
||||
Args:
|
||||
websocket: The FastAPI WebSocket connection.
|
||||
is_binary: Whether to use binary message format.
|
||||
callbacks: Event callback functions.
|
||||
"""
|
||||
self._websocket = websocket
|
||||
self._closing = False
|
||||
self._is_binary = is_binary
|
||||
self._callbacks = callbacks
|
||||
self._leave_counter = 0
|
||||
|
||||
@@ -110,9 +128,9 @@ class FastAPIWebsocketClient:
|
||||
"""Get an async iterator for receiving WebSocket messages.
|
||||
|
||||
Returns:
|
||||
An async iterator yielding bytes or strings based on message type.
|
||||
An async iterator yielding bytes or strings.
|
||||
"""
|
||||
return self._websocket.iter_bytes() if self._is_binary else self._websocket.iter_text()
|
||||
return _WebSocketMessageIterator(self._websocket)
|
||||
|
||||
async def send(self, data: str | bytes):
|
||||
"""Send data through the WebSocket connection.
|
||||
@@ -122,7 +140,7 @@ class FastAPIWebsocketClient:
|
||||
"""
|
||||
try:
|
||||
if self._can_send():
|
||||
if self._is_binary:
|
||||
if isinstance(data, bytes):
|
||||
await self._websocket.send_bytes(data)
|
||||
else:
|
||||
await self._websocket.send_text(data)
|
||||
@@ -510,10 +528,7 @@ class FastAPIWebsocketTransport(BaseTransport):
|
||||
on_session_timeout=self._on_session_timeout,
|
||||
)
|
||||
|
||||
is_binary = False
|
||||
if self._params.serializer:
|
||||
is_binary = self._params.serializer.type == FrameSerializerType.BINARY
|
||||
self._client = FastAPIWebsocketClient(websocket, is_binary, self._callbacks)
|
||||
self._client = FastAPIWebsocketClient(websocket, self._callbacks)
|
||||
|
||||
self._input = FastAPIWebsocketInputTransport(
|
||||
self, self._client, self._params, name=self._input_name
|
||||
|
||||
70
tests/test_fastapi_websocket.py
Normal file
70
tests/test_fastapi_websocket.py
Normal file
@@ -0,0 +1,70 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import unittest
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from pipecat.transports.websocket.fastapi import _WebSocketMessageIterator
|
||||
|
||||
|
||||
class TestWebSocketMessageIterator(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_yields_binary_message(self):
|
||||
mock_websocket = AsyncMock()
|
||||
mock_websocket.receive.side_effect = [
|
||||
{"type": "websocket.receive", "bytes": b"binary data", "text": None},
|
||||
{"type": "websocket.disconnect"},
|
||||
]
|
||||
|
||||
iterator = _WebSocketMessageIterator(mock_websocket)
|
||||
messages = [msg async for msg in iterator]
|
||||
|
||||
self.assertEqual(len(messages), 1)
|
||||
self.assertEqual(messages[0], b"binary data")
|
||||
|
||||
async def test_yields_text_message(self):
|
||||
mock_websocket = AsyncMock()
|
||||
mock_websocket.receive.side_effect = [
|
||||
{"type": "websocket.receive", "bytes": None, "text": "text data"},
|
||||
{"type": "websocket.disconnect"},
|
||||
]
|
||||
|
||||
iterator = _WebSocketMessageIterator(mock_websocket)
|
||||
messages = [msg async for msg in iterator]
|
||||
|
||||
self.assertEqual(len(messages), 1)
|
||||
self.assertEqual(messages[0], "text data")
|
||||
|
||||
async def test_yields_mixed_messages(self):
|
||||
mock_websocket = AsyncMock()
|
||||
mock_websocket.receive.side_effect = [
|
||||
{"type": "websocket.receive", "bytes": b"binary", "text": None},
|
||||
{"type": "websocket.receive", "bytes": None, "text": "text"},
|
||||
{"type": "websocket.receive", "bytes": b"more binary", "text": None},
|
||||
{"type": "websocket.disconnect"},
|
||||
]
|
||||
|
||||
iterator = _WebSocketMessageIterator(mock_websocket)
|
||||
messages = [msg async for msg in iterator]
|
||||
|
||||
self.assertEqual(len(messages), 3)
|
||||
self.assertEqual(messages[0], b"binary")
|
||||
self.assertEqual(messages[1], "text")
|
||||
self.assertEqual(messages[2], b"more binary")
|
||||
|
||||
async def test_stops_on_disconnect(self):
|
||||
mock_websocket = AsyncMock()
|
||||
mock_websocket.receive.side_effect = [
|
||||
{"type": "websocket.disconnect"},
|
||||
]
|
||||
|
||||
iterator = _WebSocketMessageIterator(mock_websocket)
|
||||
messages = [msg async for msg in iterator]
|
||||
|
||||
self.assertEqual(len(messages), 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -21,7 +21,7 @@ class TestProtobufFrameSerializer(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_roundtrip(self):
|
||||
text_frame = TextFrame(text="hello world")
|
||||
frame = await self.serializer.deserialize(await self.serializer.serialize(text_frame))
|
||||
self.assertEqual(text_frame, frame)
|
||||
self.assertEqual(frame.text, text_frame.text)
|
||||
|
||||
transcription_frame = TranscriptionFrame(
|
||||
text="Hello there!", user_id="123", timestamp="2021-01-01"
|
||||
@@ -29,7 +29,9 @@ class TestProtobufFrameSerializer(unittest.IsolatedAsyncioTestCase):
|
||||
frame = await self.serializer.deserialize(
|
||||
await self.serializer.serialize(transcription_frame)
|
||||
)
|
||||
self.assertEqual(frame, transcription_frame)
|
||||
self.assertEqual(frame.text, transcription_frame.text)
|
||||
self.assertEqual(frame.user_id, transcription_frame.user_id)
|
||||
self.assertEqual(frame.timestamp, transcription_frame.timestamp)
|
||||
|
||||
audio_frame = OutputAudioRawFrame(audio=b"1234567890", sample_rate=16000, num_channels=1)
|
||||
frame = await self.serializer.deserialize(await self.serializer.serialize(audio_frame))
|
||||
|
||||
Reference in New Issue
Block a user