serializers: serialize()/deserialize() are now async

This commit is contained in:
Aleix Conchillo Flaqué
2025-01-31 12:47:49 -08:00
parent cda34a1320
commit dcfb86583d
12 changed files with 30 additions and 42 deletions

View File

@@ -1544,6 +1544,9 @@ async def on_connected(processor):
### Changed
- `FrameSerializer.serialize()` and `FrameSerializer.deserialize()` are now
`async`.
- `Filter` has been renamed to `FrameFilter` and it's now under
`processors/filters`.

View File

@@ -30,7 +30,7 @@ class AsyncGeneratorProcessor(FrameProcessor):
if isinstance(frame, (CancelFrame, EndFrame)):
await self._data_queue.put(None)
else:
data = self._serializer.serialize(frame)
data = await self._serializer.serialize(frame)
if data:
await self._data_queue.put(data)

View File

@@ -22,9 +22,9 @@ class FrameSerializer(ABC):
pass
@abstractmethod
def serialize(self, frame: Frame) -> str | bytes | None:
async def serialize(self, frame: Frame) -> str | bytes | None:
pass
@abstractmethod
def deserialize(self, data: str | bytes) -> Frame | None:
async def deserialize(self, data: str | bytes) -> Frame | None:
pass

View File

@@ -25,7 +25,7 @@ class LivekitFrameSerializer(FrameSerializer):
def type(self) -> FrameSerializerType:
return FrameSerializerType.BINARY
def serialize(self, frame: Frame) -> str | bytes | None:
async def serialize(self, frame: Frame) -> str | bytes | None:
if not isinstance(frame, OutputAudioRawFrame):
return None
audio_frame = AudioFrame(
@@ -36,7 +36,7 @@ class LivekitFrameSerializer(FrameSerializer):
)
return pickle.dumps(audio_frame)
def deserialize(self, data: str | bytes) -> Frame | None:
async def deserialize(self, data: str | bytes) -> Frame | None:
audio_frame: AudioFrame = pickle.loads(data)["frame"]
return InputAudioRawFrame(
audio=bytes(audio_frame.data),

View File

@@ -41,7 +41,7 @@ class ProtobufFrameSerializer(FrameSerializer):
def type(self) -> FrameSerializerType:
return FrameSerializerType.BINARY
def serialize(self, frame: Frame) -> str | bytes | None:
async def serialize(self, frame: Frame) -> str | bytes | None:
proto_frame = frame_protos.Frame()
if type(frame) not in self.SERIALIZABLE_TYPES:
logger.warning(f"Frame type {type(frame)} is not serializable")
@@ -57,26 +57,7 @@ class ProtobufFrameSerializer(FrameSerializer):
return proto_frame.SerializeToString()
def deserialize(self, data: str | bytes) -> Frame | None:
"""Returns a Frame object from a Frame protobuf.
Used to convert frames
passed over the wire as protobufs to Frame objects used in pipelines
and frame processors.
>>> serializer = ProtobufFrameSerializer()
>>> serializer.deserialize(
... serializer.serialize(OutputAudioFrame(data=b'1234567890')))
InputAudioFrame(data=b'1234567890')
>>> serializer.deserialize(
... serializer.serialize(TextFrame(text='hello world')))
TextFrame(text='hello world')
>>> serializer.deserialize(serializer.serialize(TranscriptionFrame(
... text="Hello there!", participantId="123", timestamp="2021-01-01")))
TranscriptionFrame(text='Hello there!', participantId='123', timestamp='2021-01-01')
"""
async def deserialize(self, data: str | bytes) -> Frame | None:
proto = frame_protos.Frame.FromString(data)
which = proto.WhichOneof("frame")
if which not in self.DESERIALIZABLE_FIELDS:

View File

@@ -53,7 +53,7 @@ class TelnyxFrameSerializer(FrameSerializer):
def type(self) -> FrameSerializerType:
return FrameSerializerType.TEXT
def serialize(self, frame: Frame) -> str | bytes | None:
async def serialize(self, frame: Frame) -> str | bytes | None:
if isinstance(frame, AudioRawFrame):
data = frame.audio
@@ -80,7 +80,7 @@ class TelnyxFrameSerializer(FrameSerializer):
answer = {"event": "clear"}
return json.dumps(answer)
def deserialize(self, data: str | bytes) -> Frame | None:
async def deserialize(self, data: str | bytes) -> Frame | None:
message = json.loads(data)
if message["event"] == "media":

View File

@@ -38,7 +38,7 @@ class TwilioFrameSerializer(FrameSerializer):
def type(self) -> FrameSerializerType:
return FrameSerializerType.TEXT
def serialize(self, frame: Frame) -> str | bytes | None:
async def serialize(self, frame: Frame) -> str | bytes | None:
if isinstance(frame, StartInterruptionFrame):
answer = {"event": "clear", "streamSid": self._stream_sid}
return json.dumps(answer)
@@ -59,7 +59,7 @@ class TwilioFrameSerializer(FrameSerializer):
elif isinstance(frame, (TransportMessageFrame, TransportMessageUrgentFrame)):
return json.dumps(frame.message)
def deserialize(self, data: str | bytes) -> Frame | None:
async def deserialize(self, data: str | bytes) -> Frame | None:
message = json.loads(data)
if message["event"] == "media":

View File

@@ -91,7 +91,7 @@ class FastAPIWebsocketInputTransport(BaseInputTransport):
async def _receive_messages(self):
try:
async for message in self._iter_data():
frame = self._params.serializer.deserialize(message)
frame = await self._params.serializer.deserialize(message)
if not frame:
continue
@@ -163,7 +163,7 @@ class FastAPIWebsocketOutputTransport(BaseOutputTransport):
async def _write_frame(self, frame: Frame):
try:
payload = self._params.serializer.serialize(frame)
payload = await self._params.serializer.serialize(frame)
if payload and self._websocket.client_state == WebSocketState.CONNECTED:
await self._send_data(payload)
except Exception as e:

View File

@@ -138,7 +138,7 @@ class WebsocketClientInputTransport(BaseInputTransport):
await self._session.disconnect()
async def on_message(self, websocket, message):
frame = self._params.serializer.deserialize(message)
frame = await self._params.serializer.deserialize(message)
if not frame:
return
if isinstance(frame, InputAudioRawFrame) and self._params.audio_in_enabled:
@@ -200,7 +200,7 @@ class WebsocketClientOutputTransport(BaseOutputTransport):
await self._write_audio_sleep()
async def _write_frame(self, frame: Frame):
payload = self._params.serializer.serialize(frame)
payload = await self._params.serializer.serialize(frame)
if payload:
await self._session.send(payload)

View File

@@ -105,7 +105,7 @@ class WebsocketServerInputTransport(BaseInputTransport):
# Handle incoming messages
try:
async for message in websocket:
frame = self._params.serializer.deserialize(message)
frame = await self._params.serializer.deserialize(message)
if not frame:
continue
@@ -193,7 +193,7 @@ class WebsocketServerOutputTransport(BaseOutputTransport):
async def _write_frame(self, frame: Frame):
try:
payload = self._params.serializer.serialize(frame)
payload = await self._params.serializer.serialize(frame)
if payload and self._websocket:
await self._websocket.send(payload)
except Exception as e:

View File

@@ -385,7 +385,9 @@ class LiveKitInputTransport(BaseInputTransport):
audio_data = await self._client.get_next_audio_frame()
if audio_data:
audio_frame_event, participant_id = audio_data
pipecat_audio_frame = self._convert_livekit_audio_to_pipecat(audio_frame_event)
pipecat_audio_frame = await self._convert_livekit_audio_to_pipecat(
audio_frame_event
)
input_audio_frame = InputAudioRawFrame(
audio=pipecat_audio_frame.audio,
sample_rate=pipecat_audio_frame.sample_rate,
@@ -393,12 +395,12 @@ class LiveKitInputTransport(BaseInputTransport):
)
await self.push_audio_frame(input_audio_frame)
def _convert_livekit_audio_to_pipecat(
async def _convert_livekit_audio_to_pipecat(
self, audio_frame_event: rtc.AudioFrameEvent
) -> AudioRawFrame:
audio_frame = audio_frame_event.frame
audio_data = self._resampler.resample(
audio_data = await self._resampler.resample(
audio_frame.data.tobytes(), audio_frame.sample_rate, self._params.audio_in_sample_rate
)

View File

@@ -20,17 +20,19 @@ class TestProtobufFrameSerializer(unittest.IsolatedAsyncioTestCase):
async def test_roundtrip(self):
text_frame = TextFrame(text="hello world")
frame = self.serializer.deserialize(self.serializer.serialize(text_frame))
frame = await self.serializer.deserialize(await self.serializer.serialize(text_frame))
self.assertEqual(text_frame, frame)
transcription_frame = TranscriptionFrame(
text="Hello there!", user_id="123", timestamp="2021-01-01"
)
frame = self.serializer.deserialize(self.serializer.serialize(transcription_frame))
frame = await self.serializer.deserialize(
await self.serializer.serialize(transcription_frame)
)
self.assertEqual(frame, transcription_frame)
audio_frame = OutputAudioRawFrame(audio=b"1234567890", sample_rate=16000, num_channels=1)
frame = self.serializer.deserialize(self.serializer.serialize(audio_frame))
frame = await self.serializer.deserialize(await self.serializer.serialize(audio_frame))
self.assertEqual(frame.audio, audio_frame.audio)
self.assertEqual(frame.sample_rate, audio_frame.sample_rate)
self.assertEqual(frame.num_channels, audio_frame.num_channels)