serializers: serialize()/deserialize() are now async
This commit is contained in:
@@ -1544,6 +1544,9 @@ async def on_connected(processor):
|
||||
|
||||
### Changed
|
||||
|
||||
- `FrameSerializer.serialize()` and `FrameSerializer.deserialize()` are now
|
||||
`async`.
|
||||
|
||||
- `Filter` has been renamed to `FrameFilter` and it's now under
|
||||
`processors/filters`.
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ class AsyncGeneratorProcessor(FrameProcessor):
|
||||
if isinstance(frame, (CancelFrame, EndFrame)):
|
||||
await self._data_queue.put(None)
|
||||
else:
|
||||
data = self._serializer.serialize(frame)
|
||||
data = await self._serializer.serialize(frame)
|
||||
if data:
|
||||
await self._data_queue.put(data)
|
||||
|
||||
|
||||
@@ -22,9 +22,9 @@ class FrameSerializer(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def serialize(self, frame: Frame) -> str | bytes | None:
|
||||
async def serialize(self, frame: Frame) -> str | bytes | None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def deserialize(self, data: str | bytes) -> Frame | None:
|
||||
async def deserialize(self, data: str | bytes) -> Frame | None:
|
||||
pass
|
||||
|
||||
@@ -25,7 +25,7 @@ class LivekitFrameSerializer(FrameSerializer):
|
||||
def type(self) -> FrameSerializerType:
|
||||
return FrameSerializerType.BINARY
|
||||
|
||||
def serialize(self, frame: Frame) -> str | bytes | None:
|
||||
async def serialize(self, frame: Frame) -> str | bytes | None:
|
||||
if not isinstance(frame, OutputAudioRawFrame):
|
||||
return None
|
||||
audio_frame = AudioFrame(
|
||||
@@ -36,7 +36,7 @@ class LivekitFrameSerializer(FrameSerializer):
|
||||
)
|
||||
return pickle.dumps(audio_frame)
|
||||
|
||||
def deserialize(self, data: str | bytes) -> Frame | None:
|
||||
async def deserialize(self, data: str | bytes) -> Frame | None:
|
||||
audio_frame: AudioFrame = pickle.loads(data)["frame"]
|
||||
return InputAudioRawFrame(
|
||||
audio=bytes(audio_frame.data),
|
||||
|
||||
@@ -41,7 +41,7 @@ class ProtobufFrameSerializer(FrameSerializer):
|
||||
def type(self) -> FrameSerializerType:
|
||||
return FrameSerializerType.BINARY
|
||||
|
||||
def serialize(self, frame: Frame) -> str | bytes | None:
|
||||
async def serialize(self, frame: Frame) -> str | bytes | None:
|
||||
proto_frame = frame_protos.Frame()
|
||||
if type(frame) not in self.SERIALIZABLE_TYPES:
|
||||
logger.warning(f"Frame type {type(frame)} is not serializable")
|
||||
@@ -57,26 +57,7 @@ class ProtobufFrameSerializer(FrameSerializer):
|
||||
|
||||
return proto_frame.SerializeToString()
|
||||
|
||||
def deserialize(self, data: str | bytes) -> Frame | None:
|
||||
"""Returns a Frame object from a Frame protobuf.
|
||||
|
||||
Used to convert frames
|
||||
passed over the wire as protobufs to Frame objects used in pipelines
|
||||
and frame processors.
|
||||
|
||||
>>> serializer = ProtobufFrameSerializer()
|
||||
>>> serializer.deserialize(
|
||||
... serializer.serialize(OutputAudioFrame(data=b'1234567890')))
|
||||
InputAudioFrame(data=b'1234567890')
|
||||
|
||||
>>> serializer.deserialize(
|
||||
... serializer.serialize(TextFrame(text='hello world')))
|
||||
TextFrame(text='hello world')
|
||||
|
||||
>>> serializer.deserialize(serializer.serialize(TranscriptionFrame(
|
||||
... text="Hello there!", participantId="123", timestamp="2021-01-01")))
|
||||
TranscriptionFrame(text='Hello there!', participantId='123', timestamp='2021-01-01')
|
||||
"""
|
||||
async def deserialize(self, data: str | bytes) -> Frame | None:
|
||||
proto = frame_protos.Frame.FromString(data)
|
||||
which = proto.WhichOneof("frame")
|
||||
if which not in self.DESERIALIZABLE_FIELDS:
|
||||
|
||||
@@ -53,7 +53,7 @@ class TelnyxFrameSerializer(FrameSerializer):
|
||||
def type(self) -> FrameSerializerType:
|
||||
return FrameSerializerType.TEXT
|
||||
|
||||
def serialize(self, frame: Frame) -> str | bytes | None:
|
||||
async def serialize(self, frame: Frame) -> str | bytes | None:
|
||||
if isinstance(frame, AudioRawFrame):
|
||||
data = frame.audio
|
||||
|
||||
@@ -80,7 +80,7 @@ class TelnyxFrameSerializer(FrameSerializer):
|
||||
answer = {"event": "clear"}
|
||||
return json.dumps(answer)
|
||||
|
||||
def deserialize(self, data: str | bytes) -> Frame | None:
|
||||
async def deserialize(self, data: str | bytes) -> Frame | None:
|
||||
message = json.loads(data)
|
||||
|
||||
if message["event"] == "media":
|
||||
|
||||
@@ -38,7 +38,7 @@ class TwilioFrameSerializer(FrameSerializer):
|
||||
def type(self) -> FrameSerializerType:
|
||||
return FrameSerializerType.TEXT
|
||||
|
||||
def serialize(self, frame: Frame) -> str | bytes | None:
|
||||
async def serialize(self, frame: Frame) -> str | bytes | None:
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
answer = {"event": "clear", "streamSid": self._stream_sid}
|
||||
return json.dumps(answer)
|
||||
@@ -59,7 +59,7 @@ class TwilioFrameSerializer(FrameSerializer):
|
||||
elif isinstance(frame, (TransportMessageFrame, TransportMessageUrgentFrame)):
|
||||
return json.dumps(frame.message)
|
||||
|
||||
def deserialize(self, data: str | bytes) -> Frame | None:
|
||||
async def deserialize(self, data: str | bytes) -> Frame | None:
|
||||
message = json.loads(data)
|
||||
|
||||
if message["event"] == "media":
|
||||
|
||||
@@ -91,7 +91,7 @@ class FastAPIWebsocketInputTransport(BaseInputTransport):
|
||||
async def _receive_messages(self):
|
||||
try:
|
||||
async for message in self._iter_data():
|
||||
frame = self._params.serializer.deserialize(message)
|
||||
frame = await self._params.serializer.deserialize(message)
|
||||
|
||||
if not frame:
|
||||
continue
|
||||
@@ -163,7 +163,7 @@ class FastAPIWebsocketOutputTransport(BaseOutputTransport):
|
||||
|
||||
async def _write_frame(self, frame: Frame):
|
||||
try:
|
||||
payload = self._params.serializer.serialize(frame)
|
||||
payload = await self._params.serializer.serialize(frame)
|
||||
if payload and self._websocket.client_state == WebSocketState.CONNECTED:
|
||||
await self._send_data(payload)
|
||||
except Exception as e:
|
||||
|
||||
@@ -138,7 +138,7 @@ class WebsocketClientInputTransport(BaseInputTransport):
|
||||
await self._session.disconnect()
|
||||
|
||||
async def on_message(self, websocket, message):
|
||||
frame = self._params.serializer.deserialize(message)
|
||||
frame = await self._params.serializer.deserialize(message)
|
||||
if not frame:
|
||||
return
|
||||
if isinstance(frame, InputAudioRawFrame) and self._params.audio_in_enabled:
|
||||
@@ -200,7 +200,7 @@ class WebsocketClientOutputTransport(BaseOutputTransport):
|
||||
await self._write_audio_sleep()
|
||||
|
||||
async def _write_frame(self, frame: Frame):
|
||||
payload = self._params.serializer.serialize(frame)
|
||||
payload = await self._params.serializer.serialize(frame)
|
||||
if payload:
|
||||
await self._session.send(payload)
|
||||
|
||||
|
||||
@@ -105,7 +105,7 @@ class WebsocketServerInputTransport(BaseInputTransport):
|
||||
# Handle incoming messages
|
||||
try:
|
||||
async for message in websocket:
|
||||
frame = self._params.serializer.deserialize(message)
|
||||
frame = await self._params.serializer.deserialize(message)
|
||||
|
||||
if not frame:
|
||||
continue
|
||||
@@ -193,7 +193,7 @@ class WebsocketServerOutputTransport(BaseOutputTransport):
|
||||
|
||||
async def _write_frame(self, frame: Frame):
|
||||
try:
|
||||
payload = self._params.serializer.serialize(frame)
|
||||
payload = await self._params.serializer.serialize(frame)
|
||||
if payload and self._websocket:
|
||||
await self._websocket.send(payload)
|
||||
except Exception as e:
|
||||
|
||||
@@ -385,7 +385,9 @@ class LiveKitInputTransport(BaseInputTransport):
|
||||
audio_data = await self._client.get_next_audio_frame()
|
||||
if audio_data:
|
||||
audio_frame_event, participant_id = audio_data
|
||||
pipecat_audio_frame = self._convert_livekit_audio_to_pipecat(audio_frame_event)
|
||||
pipecat_audio_frame = await self._convert_livekit_audio_to_pipecat(
|
||||
audio_frame_event
|
||||
)
|
||||
input_audio_frame = InputAudioRawFrame(
|
||||
audio=pipecat_audio_frame.audio,
|
||||
sample_rate=pipecat_audio_frame.sample_rate,
|
||||
@@ -393,12 +395,12 @@ class LiveKitInputTransport(BaseInputTransport):
|
||||
)
|
||||
await self.push_audio_frame(input_audio_frame)
|
||||
|
||||
def _convert_livekit_audio_to_pipecat(
|
||||
async def _convert_livekit_audio_to_pipecat(
|
||||
self, audio_frame_event: rtc.AudioFrameEvent
|
||||
) -> AudioRawFrame:
|
||||
audio_frame = audio_frame_event.frame
|
||||
|
||||
audio_data = self._resampler.resample(
|
||||
audio_data = await self._resampler.resample(
|
||||
audio_frame.data.tobytes(), audio_frame.sample_rate, self._params.audio_in_sample_rate
|
||||
)
|
||||
|
||||
|
||||
@@ -20,17 +20,19 @@ class TestProtobufFrameSerializer(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
async def test_roundtrip(self):
|
||||
text_frame = TextFrame(text="hello world")
|
||||
frame = self.serializer.deserialize(self.serializer.serialize(text_frame))
|
||||
frame = await self.serializer.deserialize(await self.serializer.serialize(text_frame))
|
||||
self.assertEqual(text_frame, frame)
|
||||
|
||||
transcription_frame = TranscriptionFrame(
|
||||
text="Hello there!", user_id="123", timestamp="2021-01-01"
|
||||
)
|
||||
frame = self.serializer.deserialize(self.serializer.serialize(transcription_frame))
|
||||
frame = await self.serializer.deserialize(
|
||||
await self.serializer.serialize(transcription_frame)
|
||||
)
|
||||
self.assertEqual(frame, transcription_frame)
|
||||
|
||||
audio_frame = OutputAudioRawFrame(audio=b"1234567890", sample_rate=16000, num_channels=1)
|
||||
frame = self.serializer.deserialize(self.serializer.serialize(audio_frame))
|
||||
frame = await self.serializer.deserialize(await self.serializer.serialize(audio_frame))
|
||||
self.assertEqual(frame.audio, audio_frame.audio)
|
||||
self.assertEqual(frame.sample_rate, audio_frame.sample_rate)
|
||||
self.assertEqual(frame.num_channels, audio_frame.num_channels)
|
||||
|
||||
Reference in New Issue
Block a user