Merge pull request #3379 from lukepayyapilli/fix/fastapi-websocket-json-text-handling

Fix FastAPIWebsocketTransport to handle both binary and text messages
This commit is contained in:
Mark Backman
2026-01-08 17:26:35 -05:00
committed by GitHub
11 changed files with 109 additions and 90 deletions

View File

@@ -33,7 +33,7 @@ jobs:
- name: Install dependencies
run: |
uv sync --group dev --extra anthropic --extra aws --extra google --extra langchain
uv sync --group dev --extra anthropic --extra aws --extra google --extra langchain --extra websocket
- name: Run tests with coverage
run: |

View File

@@ -37,7 +37,7 @@ jobs:
- name: Install dependencies
run: |
uv sync --group dev --extra anthropic --extra aws --extra google --extra langchain
uv sync --group dev --extra anthropic --extra aws --extra google --extra langchain --extra websocket
- name: Test with pytest
run: |

View File

@@ -7,41 +7,18 @@
"""Frame serialization interfaces for Pipecat."""
from abc import ABC, abstractmethod
from enum import Enum
from pipecat.frames.frames import Frame, StartFrame
class FrameSerializerType(Enum):
"""Enumeration of supported frame serialization formats.
Parameters:
BINARY: Binary serialization format for compact representation.
TEXT: Text-based serialization format for human-readable output.
"""
BINARY = "binary"
TEXT = "text"
class FrameSerializer(ABC):
"""Abstract base class for frame serialization implementations.
Defines the interface for converting frames to/from serialized formats
for transmission or storage. Subclasses must implement serialization
type detection and the core serialize/deserialize methods.
for transmission or storage. Subclasses must implement the core
serialize/deserialize methods.
"""
@property
@abstractmethod
def type(self) -> FrameSerializerType:
"""Get the serialization type supported by this serializer.
Returns:
The FrameSerializerType indicating binary or text format.
"""
pass
async def setup(self, frame: StartFrame):
"""Initialize the serializer with startup configuration.

View File

@@ -25,7 +25,7 @@ from pipecat.frames.frames import (
OutputTransportMessageUrgentFrame,
StartFrame,
)
from pipecat.serializers.base_serializer import FrameSerializer, FrameSerializerType
from pipecat.serializers.base_serializer import FrameSerializer
class ExotelFrameSerializer(FrameSerializer):
@@ -70,15 +70,6 @@ class ExotelFrameSerializer(FrameSerializer):
self._input_resampler = create_stream_resampler()
self._output_resampler = create_stream_resampler()
@property
def type(self) -> FrameSerializerType:
"""Gets the serializer type.
Returns:
The serializer type, either TEXT or BINARY.
"""
return FrameSerializerType.TEXT
async def setup(self, frame: StartFrame):
"""Sets up the serializer with pipeline configuration.

View File

@@ -27,7 +27,7 @@ from pipecat.frames.frames import (
OutputTransportMessageUrgentFrame,
StartFrame,
)
from pipecat.serializers.base_serializer import FrameSerializer, FrameSerializerType
from pipecat.serializers.base_serializer import FrameSerializer
class PlivoFrameSerializer(FrameSerializer):
@@ -85,15 +85,6 @@ class PlivoFrameSerializer(FrameSerializer):
self._output_resampler = create_stream_resampler()
self._hangup_attempted = False
@property
def type(self) -> FrameSerializerType:
"""Gets the serializer type.
Returns:
The serializer type, either TEXT or BINARY.
"""
return FrameSerializerType.TEXT
async def setup(self, frame: StartFrame):
"""Sets up the serializer with pipeline configuration.

View File

@@ -22,7 +22,7 @@ from pipecat.frames.frames import (
TextFrame,
TranscriptionFrame,
)
from pipecat.serializers.base_serializer import FrameSerializer, FrameSerializerType
from pipecat.serializers.base_serializer import FrameSerializer
@dataclasses.dataclass
@@ -64,15 +64,6 @@ class ProtobufFrameSerializer(FrameSerializer):
"""Initialize the Protobuf frame serializer."""
pass
@property
def type(self) -> FrameSerializerType:
"""Get the serializer type.
Returns:
FrameSerializerType.BINARY indicating binary serialization format.
"""
return FrameSerializerType.BINARY
async def serialize(self, frame: Frame) -> str | bytes | None:
"""Serialize a frame to Protocol Buffer binary format.

View File

@@ -32,7 +32,7 @@ from pipecat.frames.frames import (
InterruptionFrame,
StartFrame,
)
from pipecat.serializers.base_serializer import FrameSerializer, FrameSerializerType
from pipecat.serializers.base_serializer import FrameSerializer
class TelnyxFrameSerializer(FrameSerializer):
@@ -97,15 +97,6 @@ class TelnyxFrameSerializer(FrameSerializer):
self._output_resampler = create_stream_resampler()
self._hangup_attempted = False
@property
def type(self) -> FrameSerializerType:
"""Gets the serializer type.
Returns:
The serializer type, either TEXT or BINARY.
"""
return FrameSerializerType.TEXT
async def setup(self, frame: StartFrame):
"""Sets up the serializer with pipeline configuration.

View File

@@ -27,7 +27,7 @@ from pipecat.frames.frames import (
OutputTransportMessageUrgentFrame,
StartFrame,
)
from pipecat.serializers.base_serializer import FrameSerializer, FrameSerializerType
from pipecat.serializers.base_serializer import FrameSerializer
class TwilioFrameSerializer(FrameSerializer):
@@ -116,15 +116,6 @@ class TwilioFrameSerializer(FrameSerializer):
self._output_resampler = create_stream_resampler()
self._hangup_attempted = False
@property
def type(self) -> FrameSerializerType:
"""Gets the serializer type.
Returns:
The serializer type, either TEXT or BINARY.
"""
return FrameSerializerType.TEXT
async def setup(self, frame: StartFrame):
"""Sets up the serializer with pipeline configuration.

View File

@@ -33,7 +33,7 @@ from pipecat.frames.frames import (
StartFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.serializers.base_serializer import FrameSerializer, FrameSerializerType
from pipecat.serializers.base_serializer import FrameSerializer
from pipecat.transports.base_input import BaseInputTransport
from pipecat.transports.base_output import BaseOutputTransport
from pipecat.transports.base_transport import BaseTransport, TransportParams
@@ -77,6 +77,26 @@ class FastAPIWebsocketCallbacks(BaseModel):
on_session_timeout: Callable[[WebSocket], Awaitable[None]]
class _WebSocketMessageIterator:
"""Async iterator for WebSocket messages that yields both binary and text."""
def __init__(self, websocket: WebSocket):
self._websocket = websocket
def __aiter__(self):
return self
async def __anext__(self) -> bytes | str:
message = await self._websocket.receive()
if message["type"] == "websocket.disconnect":
raise StopAsyncIteration
if "bytes" in message and message["bytes"] is not None:
return message["bytes"]
if "text" in message and message["text"] is not None:
return message["text"]
raise StopAsyncIteration
class FastAPIWebsocketClient:
"""WebSocket client wrapper for handling connections and message passing.
@@ -84,17 +104,15 @@ class FastAPIWebsocketClient:
with support for both binary and text message types.
"""
def __init__(self, websocket: WebSocket, is_binary: bool, callbacks: FastAPIWebsocketCallbacks):
def __init__(self, websocket: WebSocket, callbacks: FastAPIWebsocketCallbacks):
"""Initialize the WebSocket client.
Args:
websocket: The FastAPI WebSocket connection.
is_binary: Whether to use binary message format.
callbacks: Event callback functions.
"""
self._websocket = websocket
self._closing = False
self._is_binary = is_binary
self._callbacks = callbacks
self._leave_counter = 0
@@ -110,9 +128,9 @@ class FastAPIWebsocketClient:
"""Get an async iterator for receiving WebSocket messages.
Returns:
An async iterator yielding bytes or strings based on message type.
An async iterator yielding bytes or strings.
"""
return self._websocket.iter_bytes() if self._is_binary else self._websocket.iter_text()
return _WebSocketMessageIterator(self._websocket)
async def send(self, data: str | bytes):
"""Send data through the WebSocket connection.
@@ -122,7 +140,7 @@ class FastAPIWebsocketClient:
"""
try:
if self._can_send():
if self._is_binary:
if isinstance(data, bytes):
await self._websocket.send_bytes(data)
else:
await self._websocket.send_text(data)
@@ -510,10 +528,7 @@ class FastAPIWebsocketTransport(BaseTransport):
on_session_timeout=self._on_session_timeout,
)
is_binary = False
if self._params.serializer:
is_binary = self._params.serializer.type == FrameSerializerType.BINARY
self._client = FastAPIWebsocketClient(websocket, is_binary, self._callbacks)
self._client = FastAPIWebsocketClient(websocket, self._callbacks)
self._input = FastAPIWebsocketInputTransport(
self, self._client, self._params, name=self._input_name

View File

@@ -0,0 +1,70 @@
#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import unittest
from unittest.mock import AsyncMock
from pipecat.transports.websocket.fastapi import _WebSocketMessageIterator
class TestWebSocketMessageIterator(unittest.IsolatedAsyncioTestCase):
async def test_yields_binary_message(self):
mock_websocket = AsyncMock()
mock_websocket.receive.side_effect = [
{"type": "websocket.receive", "bytes": b"binary data", "text": None},
{"type": "websocket.disconnect"},
]
iterator = _WebSocketMessageIterator(mock_websocket)
messages = [msg async for msg in iterator]
self.assertEqual(len(messages), 1)
self.assertEqual(messages[0], b"binary data")
async def test_yields_text_message(self):
mock_websocket = AsyncMock()
mock_websocket.receive.side_effect = [
{"type": "websocket.receive", "bytes": None, "text": "text data"},
{"type": "websocket.disconnect"},
]
iterator = _WebSocketMessageIterator(mock_websocket)
messages = [msg async for msg in iterator]
self.assertEqual(len(messages), 1)
self.assertEqual(messages[0], "text data")
async def test_yields_mixed_messages(self):
mock_websocket = AsyncMock()
mock_websocket.receive.side_effect = [
{"type": "websocket.receive", "bytes": b"binary", "text": None},
{"type": "websocket.receive", "bytes": None, "text": "text"},
{"type": "websocket.receive", "bytes": b"more binary", "text": None},
{"type": "websocket.disconnect"},
]
iterator = _WebSocketMessageIterator(mock_websocket)
messages = [msg async for msg in iterator]
self.assertEqual(len(messages), 3)
self.assertEqual(messages[0], b"binary")
self.assertEqual(messages[1], "text")
self.assertEqual(messages[2], b"more binary")
async def test_stops_on_disconnect(self):
mock_websocket = AsyncMock()
mock_websocket.receive.side_effect = [
{"type": "websocket.disconnect"},
]
iterator = _WebSocketMessageIterator(mock_websocket)
messages = [msg async for msg in iterator]
self.assertEqual(len(messages), 0)
if __name__ == "__main__":
unittest.main()

View File

@@ -21,7 +21,7 @@ class TestProtobufFrameSerializer(unittest.IsolatedAsyncioTestCase):
async def test_roundtrip(self):
text_frame = TextFrame(text="hello world")
frame = await self.serializer.deserialize(await self.serializer.serialize(text_frame))
self.assertEqual(text_frame, frame)
self.assertEqual(frame.text, text_frame.text)
transcription_frame = TranscriptionFrame(
text="Hello there!", user_id="123", timestamp="2021-01-01"
@@ -29,7 +29,9 @@ class TestProtobufFrameSerializer(unittest.IsolatedAsyncioTestCase):
frame = await self.serializer.deserialize(
await self.serializer.serialize(transcription_frame)
)
self.assertEqual(frame, transcription_frame)
self.assertEqual(frame.text, transcription_frame.text)
self.assertEqual(frame.user_id, transcription_frame.user_id)
self.assertEqual(frame.timestamp, transcription_frame.timestamp)
audio_frame = OutputAudioRawFrame(audio=b"1234567890", sample_rate=16000, num_channels=1)
frame = await self.serializer.deserialize(await self.serializer.serialize(audio_frame))