serializers(twilio): formatting and allow str | bytes | None

This commit is contained in:
Aleix Conchillo Flaqué
2024-06-21 08:25:35 -07:00
parent 25ef0cb87b
commit b62227b4ae
2 changed files with 67 additions and 24 deletions

View File

@@ -1,3 +1,9 @@
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import base64
import json
@@ -12,29 +18,36 @@ class TwilioFrameSerializer(FrameSerializer):
}
def __init__(self):
self.sid = None
self._sid = None
def serialize(self, frame: Frame) -> str | bytes | None:
if not isinstance(frame, AudioRawFrame):
return None
def serialize(self, frame: AudioRawFrame) -> dict:
data = frame.audio
serialized_data = pcm_16000_to_ulaw_8000(data)
payload = base64.b64encode(serialized_data).decode('utf-8')
answer_dict = {"event": "media",
"streamSid": self.sid,
"media": {"payload": payload}}
payload = base64.b64encode(serialized_data).decode("utf-8")
answer = {
"event": "media",
"streamSid": self._sid,
"media": {
"payload": payload
}
}
return answer_dict
return json.dumps(answer)
def deserialize(self, message: bytes) -> AudioRawFrame | None:
data = json.loads(message)
if not self.sid:
self.sid = data['streamSid'] if data.get("streamSid") else None
def deserialize(self, data: str | bytes) -> Frame | None:
message = json.loads(data)
if data['event'] != 'media':
if not self._sid:
self._sid = message["streamSid"] if "streamSid" in message else None
if message["event"] != "media":
return None
else:
payload_base64 = data['media']['payload']
payload_base64 = message["media"]["payload"]
payload = base64.b64decode(payload_base64)
deserialized_data = ulaw_8000_to_pcm_16000(payload)

View File

@@ -1,12 +1,18 @@
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import io
import wave
from fastapi import WebSocket
from typing import Awaitable, Callable
from pydantic.main import BaseModel
from pipecat.serializers.TwilioFrameSerializer import TwilioFrameSerializer
from pipecat.serializers.twilio import TwilioFrameSerializer
from pipecat.frames.frames import AudioRawFrame, StartFrame
from pipecat.processors.frame_processor import FrameProcessor
from pipecat.serializers.base_serializer import FrameSerializer
@@ -16,6 +22,15 @@ from pipecat.transports.base_transport import BaseTransport, TransportParams
from loguru import logger
try:
from fastapi import WebSocket
from starlette.websockets import WebSocketState
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error(
"In order to use FastAPI websockets, you need to `pip install pipecat-ai[websocket]`.")
raise Exception(f"Missing module: {e}")
class FastAPIWebsocketParams(TransportParams):
add_wav_header: bool = False
@@ -30,7 +45,12 @@ class FastAPIWebsocketCallbacks(BaseModel):
class FastAPIWebsocketInputTransport(BaseInputTransport):
def __init__(self, websocket: WebSocket, params: FastAPIWebsocketParams, callbacks: FastAPIWebsocketCallbacks, **kwargs):
def __init__(
self,
websocket: WebSocket,
params: FastAPIWebsocketParams,
callbacks: FastAPIWebsocketCallbacks,
**kwargs):
super().__init__(params, **kwargs)
self._websocket = websocket
@@ -43,7 +63,8 @@ class FastAPIWebsocketInputTransport(BaseInputTransport):
self._receive_task = self.get_event_loop().create_task(self._receive_messages())
async def stop(self):
await self._websocket.close()
if self._websocket.client_state != WebSocketState.DISCONNECTED:
await self._websocket.close()
await super().stop()
async def _receive_messages(self):
@@ -58,6 +79,7 @@ class FastAPIWebsocketInputTransport(BaseInputTransport):
await self._callbacks.on_client_disconnected(self._websocket)
class FastAPIWebsocketOutputTransport(BaseOutputTransport):
def __init__(self, websocket: WebSocket, params: FastAPIWebsocketParams, **kwargs):
@@ -92,17 +114,23 @@ class FastAPIWebsocketOutputTransport(BaseOutputTransport):
frame = wav_frame
payload = self._params.serializer.serialize(frame)
future = asyncio.run_coroutine_threadsafe(
self._websocket.send_json(payload), self.get_event_loop())
future.result()
if payload:
future = asyncio.run_coroutine_threadsafe(
self._websocket.send_text(payload), self.get_event_loop())
future.result()
self._audio_buffer = self._audio_buffer[self._params.audio_frame_size:]
class FastAPIWebsocketTransport(BaseTransport):
def __init__(self, websocket: WebSocket, params: FastAPIWebsocketParams = FastAPIWebsocketParams(), input_name: str | None = None, output_name: str | None = None, loop: asyncio.AbstractEventLoop | None = None):
def __init__(
self,
websocket: WebSocket,
params: FastAPIWebsocketParams = FastAPIWebsocketParams(),
input_name: str | None = None,
output_name: str | None = None,
loop: asyncio.AbstractEventLoop | None = None):
super().__init__(input_name=input_name, output_name=output_name, loop=loop)
self._params = params
@@ -111,8 +139,10 @@ class FastAPIWebsocketTransport(BaseTransport):
on_client_disconnected=self._on_client_disconnected
)
self._input = FastAPIWebsocketInputTransport(websocket, self._params, self._callbacks, name=self._input_name)
self._output = FastAPIWebsocketOutputTransport(websocket, self._params, name=self._output_name)
self._input = FastAPIWebsocketInputTransport(
websocket, self._params, self._callbacks, name=self._input_name)
self._output = FastAPIWebsocketOutputTransport(
websocket, self._params, name=self._output_name)
# Register supported handlers. The user will only be able to register
# these handlers.