diff --git a/src/pipecat/serializers/twilio.py b/src/pipecat/serializers/twilio.py index 5ac6227dd..ea7a562aa 100644 --- a/src/pipecat/serializers/twilio.py +++ b/src/pipecat/serializers/twilio.py @@ -1,3 +1,9 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + import base64 import json @@ -12,29 +18,36 @@ class TwilioFrameSerializer(FrameSerializer): } def __init__(self): - self.sid = None + self._sid = None + def serialize(self, frame: Frame) -> str | bytes | None: + if not isinstance(frame, AudioRawFrame): + return None - def serialize(self, frame: AudioRawFrame) -> dict: data = frame.audio serialized_data = pcm_16000_to_ulaw_8000(data) - payload = base64.b64encode(serialized_data).decode('utf-8') - answer_dict = {"event": "media", - "streamSid": self.sid, - "media": {"payload": payload}} + payload = base64.b64encode(serialized_data).decode("utf-8") + answer = { + "event": "media", + "streamSid": self._sid, + "media": { + "payload": payload + } + } - return answer_dict + return json.dumps(answer) - def deserialize(self, message: bytes) -> AudioRawFrame | None: - data = json.loads(message) - if not self.sid: - self.sid = data['streamSid'] if data.get("streamSid") else None + def deserialize(self, data: str | bytes) -> Frame | None: + message = json.loads(data) - if data['event'] != 'media': + if not self._sid: + self._sid = message["streamSid"] if "streamSid" in message else None + + if message["event"] != "media": return None else: - payload_base64 = data['media']['payload'] + payload_base64 = message["media"]["payload"] payload = base64.b64decode(payload_base64) deserialized_data = ulaw_8000_to_pcm_16000(payload) diff --git a/src/pipecat/transports/network/fastapi_websocket.py b/src/pipecat/transports/network/fastapi_websocket.py index fedcd9b4c..6e5307780 100644 --- a/src/pipecat/transports/network/fastapi_websocket.py +++ b/src/pipecat/transports/network/fastapi_websocket.py @@ -1,12 +1,18 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + + import asyncio import io import wave -from fastapi import WebSocket from typing import Awaitable, Callable from pydantic.main import BaseModel -from pipecat.serializers.TwilioFrameSerializer import TwilioFrameSerializer +from pipecat.serializers.twilio import TwilioFrameSerializer from pipecat.frames.frames import AudioRawFrame, StartFrame from pipecat.processors.frame_processor import FrameProcessor from pipecat.serializers.base_serializer import FrameSerializer @@ -16,6 +22,15 @@ from pipecat.transports.base_transport import BaseTransport, TransportParams from loguru import logger +try: + from fastapi import WebSocket + from starlette.websockets import WebSocketState +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error( + "In order to use FastAPI websockets, you need to `pip install pipecat-ai[websocket]`.") + raise Exception(f"Missing module: {e}") + class FastAPIWebsocketParams(TransportParams): add_wav_header: bool = False @@ -30,7 +45,12 @@ class FastAPIWebsocketCallbacks(BaseModel): class FastAPIWebsocketInputTransport(BaseInputTransport): - def __init__(self, websocket: WebSocket, params: FastAPIWebsocketParams, callbacks: FastAPIWebsocketCallbacks, **kwargs): + def __init__( + self, + websocket: WebSocket, + params: FastAPIWebsocketParams, + callbacks: FastAPIWebsocketCallbacks, + **kwargs): super().__init__(params, **kwargs) self._websocket = websocket @@ -43,7 +63,8 @@ class FastAPIWebsocketInputTransport(BaseInputTransport): self._receive_task = self.get_event_loop().create_task(self._receive_messages()) async def stop(self): - await self._websocket.close() + if self._websocket.client_state != WebSocketState.DISCONNECTED: + await self._websocket.close() await super().stop() async def _receive_messages(self): @@ -58,6 +79,7 @@ class FastAPIWebsocketInputTransport(BaseInputTransport): await self._callbacks.on_client_disconnected(self._websocket) + class FastAPIWebsocketOutputTransport(BaseOutputTransport): def __init__(self, websocket: WebSocket, params: FastAPIWebsocketParams, **kwargs): @@ -92,17 +114,23 @@ class FastAPIWebsocketOutputTransport(BaseOutputTransport): frame = wav_frame payload = self._params.serializer.serialize(frame) - - future = asyncio.run_coroutine_threadsafe( - self._websocket.send_json(payload), self.get_event_loop()) - future.result() + if payload: + future = asyncio.run_coroutine_threadsafe( + self._websocket.send_text(payload), self.get_event_loop()) + future.result() self._audio_buffer = self._audio_buffer[self._params.audio_frame_size:] class FastAPIWebsocketTransport(BaseTransport): - def __init__(self, websocket: WebSocket, params: FastAPIWebsocketParams = FastAPIWebsocketParams(), input_name: str | None = None, output_name: str | None = None, loop: asyncio.AbstractEventLoop | None = None): + def __init__( + self, + websocket: WebSocket, + params: FastAPIWebsocketParams = FastAPIWebsocketParams(), + input_name: str | None = None, + output_name: str | None = None, + loop: asyncio.AbstractEventLoop | None = None): super().__init__(input_name=input_name, output_name=output_name, loop=loop) self._params = params @@ -111,8 +139,10 @@ class FastAPIWebsocketTransport(BaseTransport): on_client_disconnected=self._on_client_disconnected ) - self._input = FastAPIWebsocketInputTransport(websocket, self._params, self._callbacks, name=self._input_name) - self._output = FastAPIWebsocketOutputTransport(websocket, self._params, name=self._output_name) + self._input = FastAPIWebsocketInputTransport( + websocket, self._params, self._callbacks, name=self._input_name) + self._output = FastAPIWebsocketOutputTransport( + websocket, self._params, name=self._output_name) # Register supported handlers. The user will only be able to register # these handlers.