serializers(twilio): formatting and allow str | bytes | None
This commit is contained in:
@@ -1,3 +1,9 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import base64
|
||||
import json
|
||||
|
||||
@@ -12,29 +18,36 @@ class TwilioFrameSerializer(FrameSerializer):
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
self.sid = None
|
||||
self._sid = None
|
||||
|
||||
def serialize(self, frame: Frame) -> str | bytes | None:
|
||||
if not isinstance(frame, AudioRawFrame):
|
||||
return None
|
||||
|
||||
def serialize(self, frame: AudioRawFrame) -> dict:
|
||||
data = frame.audio
|
||||
|
||||
serialized_data = pcm_16000_to_ulaw_8000(data)
|
||||
payload = base64.b64encode(serialized_data).decode('utf-8')
|
||||
answer_dict = {"event": "media",
|
||||
"streamSid": self.sid,
|
||||
"media": {"payload": payload}}
|
||||
payload = base64.b64encode(serialized_data).decode("utf-8")
|
||||
answer = {
|
||||
"event": "media",
|
||||
"streamSid": self._sid,
|
||||
"media": {
|
||||
"payload": payload
|
||||
}
|
||||
}
|
||||
|
||||
return answer_dict
|
||||
return json.dumps(answer)
|
||||
|
||||
def deserialize(self, message: bytes) -> AudioRawFrame | None:
|
||||
data = json.loads(message)
|
||||
if not self.sid:
|
||||
self.sid = data['streamSid'] if data.get("streamSid") else None
|
||||
def deserialize(self, data: str | bytes) -> Frame | None:
|
||||
message = json.loads(data)
|
||||
|
||||
if data['event'] != 'media':
|
||||
if not self._sid:
|
||||
self._sid = message["streamSid"] if "streamSid" in message else None
|
||||
|
||||
if message["event"] != "media":
|
||||
return None
|
||||
else:
|
||||
payload_base64 = data['media']['payload']
|
||||
payload_base64 = message["media"]["payload"]
|
||||
payload = base64.b64decode(payload_base64)
|
||||
|
||||
deserialized_data = ulaw_8000_to_pcm_16000(payload)
|
||||
|
||||
@@ -1,12 +1,18 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import wave
|
||||
from fastapi import WebSocket
|
||||
|
||||
from typing import Awaitable, Callable
|
||||
from pydantic.main import BaseModel
|
||||
|
||||
from pipecat.serializers.TwilioFrameSerializer import TwilioFrameSerializer
|
||||
from pipecat.serializers.twilio import TwilioFrameSerializer
|
||||
from pipecat.frames.frames import AudioRawFrame, StartFrame
|
||||
from pipecat.processors.frame_processor import FrameProcessor
|
||||
from pipecat.serializers.base_serializer import FrameSerializer
|
||||
@@ -16,6 +22,15 @@ from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
from fastapi import WebSocket
|
||||
from starlette.websockets import WebSocketState
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use FastAPI websockets, you need to `pip install pipecat-ai[websocket]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class FastAPIWebsocketParams(TransportParams):
|
||||
add_wav_header: bool = False
|
||||
@@ -30,7 +45,12 @@ class FastAPIWebsocketCallbacks(BaseModel):
|
||||
|
||||
class FastAPIWebsocketInputTransport(BaseInputTransport):
|
||||
|
||||
def __init__(self, websocket: WebSocket, params: FastAPIWebsocketParams, callbacks: FastAPIWebsocketCallbacks, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
websocket: WebSocket,
|
||||
params: FastAPIWebsocketParams,
|
||||
callbacks: FastAPIWebsocketCallbacks,
|
||||
**kwargs):
|
||||
super().__init__(params, **kwargs)
|
||||
|
||||
self._websocket = websocket
|
||||
@@ -43,7 +63,8 @@ class FastAPIWebsocketInputTransport(BaseInputTransport):
|
||||
self._receive_task = self.get_event_loop().create_task(self._receive_messages())
|
||||
|
||||
async def stop(self):
|
||||
await self._websocket.close()
|
||||
if self._websocket.client_state != WebSocketState.DISCONNECTED:
|
||||
await self._websocket.close()
|
||||
await super().stop()
|
||||
|
||||
async def _receive_messages(self):
|
||||
@@ -58,6 +79,7 @@ class FastAPIWebsocketInputTransport(BaseInputTransport):
|
||||
|
||||
await self._callbacks.on_client_disconnected(self._websocket)
|
||||
|
||||
|
||||
class FastAPIWebsocketOutputTransport(BaseOutputTransport):
|
||||
|
||||
def __init__(self, websocket: WebSocket, params: FastAPIWebsocketParams, **kwargs):
|
||||
@@ -92,17 +114,23 @@ class FastAPIWebsocketOutputTransport(BaseOutputTransport):
|
||||
frame = wav_frame
|
||||
|
||||
payload = self._params.serializer.serialize(frame)
|
||||
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._websocket.send_json(payload), self.get_event_loop())
|
||||
future.result()
|
||||
if payload:
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._websocket.send_text(payload), self.get_event_loop())
|
||||
future.result()
|
||||
|
||||
self._audio_buffer = self._audio_buffer[self._params.audio_frame_size:]
|
||||
|
||||
|
||||
class FastAPIWebsocketTransport(BaseTransport):
|
||||
|
||||
def __init__(self, websocket: WebSocket, params: FastAPIWebsocketParams = FastAPIWebsocketParams(), input_name: str | None = None, output_name: str | None = None, loop: asyncio.AbstractEventLoop | None = None):
|
||||
def __init__(
|
||||
self,
|
||||
websocket: WebSocket,
|
||||
params: FastAPIWebsocketParams = FastAPIWebsocketParams(),
|
||||
input_name: str | None = None,
|
||||
output_name: str | None = None,
|
||||
loop: asyncio.AbstractEventLoop | None = None):
|
||||
super().__init__(input_name=input_name, output_name=output_name, loop=loop)
|
||||
self._params = params
|
||||
|
||||
@@ -111,8 +139,10 @@ class FastAPIWebsocketTransport(BaseTransport):
|
||||
on_client_disconnected=self._on_client_disconnected
|
||||
)
|
||||
|
||||
self._input = FastAPIWebsocketInputTransport(websocket, self._params, self._callbacks, name=self._input_name)
|
||||
self._output = FastAPIWebsocketOutputTransport(websocket, self._params, name=self._output_name)
|
||||
self._input = FastAPIWebsocketInputTransport(
|
||||
websocket, self._params, self._callbacks, name=self._input_name)
|
||||
self._output = FastAPIWebsocketOutputTransport(
|
||||
websocket, self._params, name=self._output_name)
|
||||
|
||||
# Register supported handlers. The user will only be able to register
|
||||
# these handlers.
|
||||
|
||||
Reference in New Issue
Block a user