diff --git a/.github/workflows/coverage.yaml b/.github/workflows/coverage.yaml index 9fad65dbc..0dd30f9e5 100644 --- a/.github/workflows/coverage.yaml +++ b/.github/workflows/coverage.yaml @@ -33,7 +33,7 @@ jobs: - name: Install dependencies run: | - uv sync --group dev --extra anthropic --extra aws --extra google --extra langchain + uv sync --group dev --extra anthropic --extra aws --extra google --extra langchain --extra websocket - name: Run tests with coverage run: | diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 857ebb489..8e58845e4 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -37,7 +37,7 @@ jobs: - name: Install dependencies run: | - uv sync --group dev --extra anthropic --extra aws --extra google --extra langchain + uv sync --group dev --extra anthropic --extra aws --extra google --extra langchain --extra websocket - name: Test with pytest run: | diff --git a/src/pipecat/serializers/base_serializer.py b/src/pipecat/serializers/base_serializer.py index 823d00c7b..490951a69 100644 --- a/src/pipecat/serializers/base_serializer.py +++ b/src/pipecat/serializers/base_serializer.py @@ -7,41 +7,18 @@ """Frame serialization interfaces for Pipecat.""" from abc import ABC, abstractmethod -from enum import Enum from pipecat.frames.frames import Frame, StartFrame -class FrameSerializerType(Enum): - """Enumeration of supported frame serialization formats. - - Parameters: - BINARY: Binary serialization format for compact representation. - TEXT: Text-based serialization format for human-readable output. - """ - - BINARY = "binary" - TEXT = "text" - - class FrameSerializer(ABC): """Abstract base class for frame serialization implementations. Defines the interface for converting frames to/from serialized formats - for transmission or storage. Subclasses must implement serialization - type detection and the core serialize/deserialize methods. + for transmission or storage. Subclasses must implement the core + serialize/deserialize methods. """ - @property - @abstractmethod - def type(self) -> FrameSerializerType: - """Get the serialization type supported by this serializer. - - Returns: - The FrameSerializerType indicating binary or text format. - """ - pass - async def setup(self, frame: StartFrame): """Initialize the serializer with startup configuration. diff --git a/src/pipecat/serializers/exotel.py b/src/pipecat/serializers/exotel.py index f9c2b2e5f..61d6eeada 100644 --- a/src/pipecat/serializers/exotel.py +++ b/src/pipecat/serializers/exotel.py @@ -25,7 +25,7 @@ from pipecat.frames.frames import ( OutputTransportMessageUrgentFrame, StartFrame, ) -from pipecat.serializers.base_serializer import FrameSerializer, FrameSerializerType +from pipecat.serializers.base_serializer import FrameSerializer class ExotelFrameSerializer(FrameSerializer): @@ -70,15 +70,6 @@ class ExotelFrameSerializer(FrameSerializer): self._input_resampler = create_stream_resampler() self._output_resampler = create_stream_resampler() - @property - def type(self) -> FrameSerializerType: - """Gets the serializer type. - - Returns: - The serializer type, either TEXT or BINARY. - """ - return FrameSerializerType.TEXT - async def setup(self, frame: StartFrame): """Sets up the serializer with pipeline configuration. diff --git a/src/pipecat/serializers/plivo.py b/src/pipecat/serializers/plivo.py index d54f86fe9..2a57d3698 100644 --- a/src/pipecat/serializers/plivo.py +++ b/src/pipecat/serializers/plivo.py @@ -27,7 +27,7 @@ from pipecat.frames.frames import ( OutputTransportMessageUrgentFrame, StartFrame, ) -from pipecat.serializers.base_serializer import FrameSerializer, FrameSerializerType +from pipecat.serializers.base_serializer import FrameSerializer class PlivoFrameSerializer(FrameSerializer): @@ -85,15 +85,6 @@ class PlivoFrameSerializer(FrameSerializer): self._output_resampler = create_stream_resampler() self._hangup_attempted = False - @property - def type(self) -> FrameSerializerType: - """Gets the serializer type. - - Returns: - The serializer type, either TEXT or BINARY. - """ - return FrameSerializerType.TEXT - async def setup(self, frame: StartFrame): """Sets up the serializer with pipeline configuration. diff --git a/src/pipecat/serializers/protobuf.py b/src/pipecat/serializers/protobuf.py index f079a1f72..6d989c7dd 100644 --- a/src/pipecat/serializers/protobuf.py +++ b/src/pipecat/serializers/protobuf.py @@ -22,7 +22,7 @@ from pipecat.frames.frames import ( TextFrame, TranscriptionFrame, ) -from pipecat.serializers.base_serializer import FrameSerializer, FrameSerializerType +from pipecat.serializers.base_serializer import FrameSerializer @dataclasses.dataclass @@ -64,15 +64,6 @@ class ProtobufFrameSerializer(FrameSerializer): """Initialize the Protobuf frame serializer.""" pass - @property - def type(self) -> FrameSerializerType: - """Get the serializer type. - - Returns: - FrameSerializerType.BINARY indicating binary serialization format. - """ - return FrameSerializerType.BINARY - async def serialize(self, frame: Frame) -> str | bytes | None: """Serialize a frame to Protocol Buffer binary format. diff --git a/src/pipecat/serializers/telnyx.py b/src/pipecat/serializers/telnyx.py index a9e837f06..769244f93 100644 --- a/src/pipecat/serializers/telnyx.py +++ b/src/pipecat/serializers/telnyx.py @@ -32,7 +32,7 @@ from pipecat.frames.frames import ( InterruptionFrame, StartFrame, ) -from pipecat.serializers.base_serializer import FrameSerializer, FrameSerializerType +from pipecat.serializers.base_serializer import FrameSerializer class TelnyxFrameSerializer(FrameSerializer): @@ -97,15 +97,6 @@ class TelnyxFrameSerializer(FrameSerializer): self._output_resampler = create_stream_resampler() self._hangup_attempted = False - @property - def type(self) -> FrameSerializerType: - """Gets the serializer type. - - Returns: - The serializer type, either TEXT or BINARY. - """ - return FrameSerializerType.TEXT - async def setup(self, frame: StartFrame): """Sets up the serializer with pipeline configuration. diff --git a/src/pipecat/serializers/twilio.py b/src/pipecat/serializers/twilio.py index c9569044c..2e60399fd 100644 --- a/src/pipecat/serializers/twilio.py +++ b/src/pipecat/serializers/twilio.py @@ -27,7 +27,7 @@ from pipecat.frames.frames import ( OutputTransportMessageUrgentFrame, StartFrame, ) -from pipecat.serializers.base_serializer import FrameSerializer, FrameSerializerType +from pipecat.serializers.base_serializer import FrameSerializer class TwilioFrameSerializer(FrameSerializer): @@ -116,15 +116,6 @@ class TwilioFrameSerializer(FrameSerializer): self._output_resampler = create_stream_resampler() self._hangup_attempted = False - @property - def type(self) -> FrameSerializerType: - """Gets the serializer type. - - Returns: - The serializer type, either TEXT or BINARY. - """ - return FrameSerializerType.TEXT - async def setup(self, frame: StartFrame): """Sets up the serializer with pipeline configuration. diff --git a/src/pipecat/transports/websocket/fastapi.py b/src/pipecat/transports/websocket/fastapi.py index ca328e097..1bcc59e8b 100644 --- a/src/pipecat/transports/websocket/fastapi.py +++ b/src/pipecat/transports/websocket/fastapi.py @@ -33,7 +33,7 @@ from pipecat.frames.frames import ( StartFrame, ) from pipecat.processors.frame_processor import FrameDirection -from pipecat.serializers.base_serializer import FrameSerializer, FrameSerializerType +from pipecat.serializers.base_serializer import FrameSerializer from pipecat.transports.base_input import BaseInputTransport from pipecat.transports.base_output import BaseOutputTransport from pipecat.transports.base_transport import BaseTransport, TransportParams @@ -77,6 +77,26 @@ class FastAPIWebsocketCallbacks(BaseModel): on_session_timeout: Callable[[WebSocket], Awaitable[None]] +class _WebSocketMessageIterator: + """Async iterator for WebSocket messages that yields both binary and text.""" + + def __init__(self, websocket: WebSocket): + self._websocket = websocket + + def __aiter__(self): + return self + + async def __anext__(self) -> bytes | str: + message = await self._websocket.receive() + if message["type"] == "websocket.disconnect": + raise StopAsyncIteration + if "bytes" in message and message["bytes"] is not None: + return message["bytes"] + if "text" in message and message["text"] is not None: + return message["text"] + raise StopAsyncIteration + + class FastAPIWebsocketClient: """WebSocket client wrapper for handling connections and message passing. @@ -84,17 +104,15 @@ class FastAPIWebsocketClient: with support for both binary and text message types. """ - def __init__(self, websocket: WebSocket, is_binary: bool, callbacks: FastAPIWebsocketCallbacks): + def __init__(self, websocket: WebSocket, callbacks: FastAPIWebsocketCallbacks): """Initialize the WebSocket client. Args: websocket: The FastAPI WebSocket connection. - is_binary: Whether to use binary message format. callbacks: Event callback functions. """ self._websocket = websocket self._closing = False - self._is_binary = is_binary self._callbacks = callbacks self._leave_counter = 0 @@ -110,9 +128,9 @@ class FastAPIWebsocketClient: """Get an async iterator for receiving WebSocket messages. Returns: - An async iterator yielding bytes or strings based on message type. + An async iterator yielding bytes or strings. """ - return self._websocket.iter_bytes() if self._is_binary else self._websocket.iter_text() + return _WebSocketMessageIterator(self._websocket) async def send(self, data: str | bytes): """Send data through the WebSocket connection. @@ -122,7 +140,7 @@ class FastAPIWebsocketClient: """ try: if self._can_send(): - if self._is_binary: + if isinstance(data, bytes): await self._websocket.send_bytes(data) else: await self._websocket.send_text(data) @@ -510,10 +528,7 @@ class FastAPIWebsocketTransport(BaseTransport): on_session_timeout=self._on_session_timeout, ) - is_binary = False - if self._params.serializer: - is_binary = self._params.serializer.type == FrameSerializerType.BINARY - self._client = FastAPIWebsocketClient(websocket, is_binary, self._callbacks) + self._client = FastAPIWebsocketClient(websocket, self._callbacks) self._input = FastAPIWebsocketInputTransport( self, self._client, self._params, name=self._input_name diff --git a/tests/test_fastapi_websocket.py b/tests/test_fastapi_websocket.py new file mode 100644 index 000000000..44a1ff61a --- /dev/null +++ b/tests/test_fastapi_websocket.py @@ -0,0 +1,70 @@ +# +# Copyright (c) 2024-2026, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import unittest +from unittest.mock import AsyncMock + +from pipecat.transports.websocket.fastapi import _WebSocketMessageIterator + + +class TestWebSocketMessageIterator(unittest.IsolatedAsyncioTestCase): + async def test_yields_binary_message(self): + mock_websocket = AsyncMock() + mock_websocket.receive.side_effect = [ + {"type": "websocket.receive", "bytes": b"binary data", "text": None}, + {"type": "websocket.disconnect"}, + ] + + iterator = _WebSocketMessageIterator(mock_websocket) + messages = [msg async for msg in iterator] + + self.assertEqual(len(messages), 1) + self.assertEqual(messages[0], b"binary data") + + async def test_yields_text_message(self): + mock_websocket = AsyncMock() + mock_websocket.receive.side_effect = [ + {"type": "websocket.receive", "bytes": None, "text": "text data"}, + {"type": "websocket.disconnect"}, + ] + + iterator = _WebSocketMessageIterator(mock_websocket) + messages = [msg async for msg in iterator] + + self.assertEqual(len(messages), 1) + self.assertEqual(messages[0], "text data") + + async def test_yields_mixed_messages(self): + mock_websocket = AsyncMock() + mock_websocket.receive.side_effect = [ + {"type": "websocket.receive", "bytes": b"binary", "text": None}, + {"type": "websocket.receive", "bytes": None, "text": "text"}, + {"type": "websocket.receive", "bytes": b"more binary", "text": None}, + {"type": "websocket.disconnect"}, + ] + + iterator = _WebSocketMessageIterator(mock_websocket) + messages = [msg async for msg in iterator] + + self.assertEqual(len(messages), 3) + self.assertEqual(messages[0], b"binary") + self.assertEqual(messages[1], "text") + self.assertEqual(messages[2], b"more binary") + + async def test_stops_on_disconnect(self): + mock_websocket = AsyncMock() + mock_websocket.receive.side_effect = [ + {"type": "websocket.disconnect"}, + ] + + iterator = _WebSocketMessageIterator(mock_websocket) + messages = [msg async for msg in iterator] + + self.assertEqual(len(messages), 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_protobuf_serializer.py b/tests/test_protobuf_serializer.py index 60a697ae0..99df3f96b 100644 --- a/tests/test_protobuf_serializer.py +++ b/tests/test_protobuf_serializer.py @@ -21,7 +21,7 @@ class TestProtobufFrameSerializer(unittest.IsolatedAsyncioTestCase): async def test_roundtrip(self): text_frame = TextFrame(text="hello world") frame = await self.serializer.deserialize(await self.serializer.serialize(text_frame)) - self.assertEqual(text_frame, frame) + self.assertEqual(frame.text, text_frame.text) transcription_frame = TranscriptionFrame( text="Hello there!", user_id="123", timestamp="2021-01-01" @@ -29,7 +29,9 @@ class TestProtobufFrameSerializer(unittest.IsolatedAsyncioTestCase): frame = await self.serializer.deserialize( await self.serializer.serialize(transcription_frame) ) - self.assertEqual(frame, transcription_frame) + self.assertEqual(frame.text, transcription_frame.text) + self.assertEqual(frame.user_id, transcription_frame.user_id) + self.assertEqual(frame.timestamp, transcription_frame.timestamp) audio_frame = OutputAudioRawFrame(audio=b"1234567890", sample_rate=16000, num_channels=1) frame = await self.serializer.deserialize(await self.serializer.serialize(audio_frame))