165 lines
5.7 KiB
Python
165 lines
5.7 KiB
Python
from __future__ import annotations

import base64
import json
from typing import Any

from loguru import logger

from pipecat.frames.frames import (
    BotStartedSpeakingFrame,
    BotStoppedSpeakingFrame,
    CancelFrame,
    EndFrame,
    Frame,
    InputAudioRawFrame,
    InputTransportMessageFrame,
    InterimTranscriptionFrame,
    OutputAudioRawFrame,
    OutputTransportMessageFrame,
    OutputTransportMessageUrgentFrame,
    TextFrame,
    TranscriptionFrame,
)
from pipecat.serializers.base_serializer import FrameSerializer
|
|
|
|
|
|
class ProductWebsocketSerializer(FrameSerializer):
    """Stable app-facing JSON/base64 protocol adapter for Pipecat websocket transport.

    Outbound frames become JSON text events that share an envelope of ``type``,
    ``protocol`` and a per-serializer, monotonically increasing ``seq``.
    Inbound binary messages are treated as raw PCM audio; inbound text messages
    must be JSON objects and are mapped to frames by their ``type`` field.
    """

    protocol = "va.ws.v1"

    def __init__(self, *, sample_rate: int, channels: int):
        """Create a serializer.

        Args:
            sample_rate: Sample rate used for binary audio messages, as the
                fallback for ``input.audio`` messages, and advertised in
                ``session.started``.
            channels: Channel count used in the same three places.
        """
        super().__init__()
        self._sample_rate = sample_rate
        self._channels = channels
        # Sequence number of the last successfully encoded event (0 = none yet).
        self._sequence = 0

    async def serialize(self, frame: Frame) -> str | bytes | None:
        """Encode an outbound frame as a JSON protocol event.

        Args:
            frame: Outbound frame to encode.

        Returns:
            The JSON text of the event, or ``None`` when the frame has no
            protocol representation (or could not be encoded) and must not be sent.
        """
        if isinstance(frame, OutputAudioRawFrame):
            return self._event(
                "response.audio.delta",
                audio=base64.b64encode(frame.audio).decode("ascii"),
                bytes=len(frame.audio),
                sample_rate=frame.sample_rate,
                channels=frame.num_channels,
            )

        if isinstance(frame, BotStartedSpeakingFrame):
            return self._event("response.audio.started")

        if isinstance(frame, BotStoppedSpeakingFrame):
            return self._event("response.audio.stopped")

        # In Pipecat both transcription frame types subclass TextFrame, so they
        # must be handled before the generic TextFrame branch below.
        if isinstance(frame, TranscriptionFrame):
            return self._event(
                "input.transcript.final",
                text=frame.text,
                user_id=frame.user_id,
                timestamp=frame.timestamp,
            )

        if isinstance(frame, InterimTranscriptionFrame):
            # Partial user speech has no protocol event; without this guard it
            # would be misreported to the client as bot response text.
            return None

        if isinstance(frame, TextFrame):
            return self._event("response.text.delta", text=frame.text)

        if isinstance(frame, (OutputTransportMessageFrame, OutputTransportMessageUrgentFrame)):
            if self.should_ignore_frame(frame):
                return None
            return self._transport_message_event(frame.message)

        return None

    def _transport_message_event(self, message: Any) -> str | None:
        """Encode the payload of an outbound transport-message frame.

        A dict payload carrying a string ``type`` is emitted as that named event,
        with its other fields merged next to the envelope; e.g.
        ``{"type": "response.text.final", "text": "..."}`` becomes a
        ``response.text.final`` event with a ``text`` field. Payload fields named
        like envelope fields (``protocol``, ``seq``) are ignored. Any other
        payload is wrapped as a ``transport.message`` event under ``message``.
        """
        if isinstance(message, dict) and isinstance(message.get("type"), str):
            payload = {key: value for key, value in message.items() if key != "type"}
            # Passed as a dict, not **kwargs: arbitrary caller keys (e.g.
            # "event_type" or non-str keys) must not clash with parameters.
            return self._encode_event(message["type"], payload)
        return self._encode_event("transport.message", {"message": message})

    async def deserialize(self, data: str | bytes) -> Frame | None:
        """Decode an inbound websocket message into a frame.

        Args:
            data: Raw PCM audio (``bytes``) or a JSON text message (``str``).

        Returns:
            The decoded frame, or ``None`` (after logging a warning) when the
            message is malformed or of an unsupported type.
        """
        if isinstance(data, bytes):
            return InputAudioRawFrame(
                audio=data,
                sample_rate=self._sample_rate,
                num_channels=self._channels,
            )

        try:
            message = json.loads(data)
        except json.JSONDecodeError as exc:
            logger.warning(f"Invalid product websocket JSON: {exc}")
            return None

        if not isinstance(message, dict):
            logger.warning("Product websocket message must be a JSON object")
            return None

        message_type = message.get("type")

        if message_type == "session.start":
            # Surface the protocol and audio format into the pipeline as a
            # `session.started` transport message.
            return InputTransportMessageFrame(
                message={
                    "type": "session.started",
                    "protocol": self.protocol,
                    "audio": {
                        "encoding": "pcm_s16le",
                        "sample_rate": self._sample_rate,
                        "channels": self._channels,
                    },
                }
            )

        if message_type == "session.stop":
            return EndFrame()

        if message_type == "response.cancel":
            # NOTE(review): in Pipecat a CancelFrame stops the whole pipeline
            # rather than interrupting the current response — confirm this is
            # the intended meaning of `response.cancel`.
            return CancelFrame(reason="client_cancelled")

        if message_type == "input.audio":
            return self._input_audio_frame(message)

        if message_type == "input.text":
            text = message.get("text")
            if not isinstance(text, str) or not text.strip():
                logger.warning("input.text requires non-empty 'text'")
                return None
            return InputTransportMessageFrame(
                message={
                    "type": "input.text",
                    "text": text,
                    "interrupt": bool(message.get("interrupt", True)),
                }
            )

        if message_type == "transport.message":
            payload = message.get("message")
            # Non-dict payloads are forwarded as the whole original envelope.
            return InputTransportMessageFrame(message=payload if isinstance(payload, dict) else message)

        logger.warning(f"Unsupported product websocket message type: {message_type!r}")
        return None

    def _input_audio_frame(self, message: dict[str, Any]) -> InputAudioRawFrame | None:
        """Build an audio frame from an ``input.audio`` JSON message.

        The PCM bytes come base64-encoded in ``audio`` (or ``data``); optional
        ``sample_rate`` / ``channels`` override the configured defaults.
        Returns ``None`` (after logging) if any field is invalid.
        """
        audio = message.get("audio") or message.get("data")
        if not isinstance(audio, str):
            logger.warning("input.audio requires base64 'audio' or 'data'")
            return None
        try:
            pcm = base64.b64decode(audio)
        except ValueError as exc:  # binascii.Error is a ValueError subclass
            logger.warning(f"Invalid input.audio base64: {exc}")
            return None

        sample_rate = self._audio_param(message, "sample_rate", self._sample_rate)
        num_channels = self._audio_param(message, "channels", self._channels)
        if sample_rate is None or num_channels is None:
            return None
        return InputAudioRawFrame(audio=pcm, sample_rate=sample_rate, num_channels=num_channels)

    @staticmethod
    def _audio_param(message: dict[str, Any], key: str, default: int) -> int | None:
        """Read an optional positive-integer audio parameter from client input.

        Missing or falsy values fall back to ``default``. Values that cannot be
        coerced to a positive int (e.g. ``"abc"``, lists, ``Infinity``, ``NaN``,
        ``0.5``) are logged and yield ``None`` instead of raising, so one bad
        client message cannot break the receive path.
        """
        value = message.get(key)
        if not value:
            return int(default)
        try:
            parsed = int(value)
        except (TypeError, ValueError, OverflowError):
            parsed = 0
        if parsed <= 0:
            logger.warning(f"input.audio has invalid {key!r}: {value!r}")
            return None
        return parsed

    def _event(self, event_type: str, **payload: Any) -> str | None:
        """Encode an event built from keyword fields (see ``_encode_event``)."""
        return self._encode_event(event_type, payload)

    def _encode_event(self, event_type: str, payload: dict[Any, Any]) -> str | None:
        """Serialize one event with the shared ``type``/``protocol``/``seq`` envelope.

        Envelope fields always win over payload fields of the same name, so a
        payload cannot spoof the protocol or rewind the sequence. The sequence
        counter only advances when encoding succeeds, so a failed encode does
        not consume a sequence number.

        Returns:
            The JSON text, or ``None`` (after logging) if the payload is not
            JSON-serializable.
        """
        seq = self._sequence + 1
        event: dict[Any, Any] = {"type": event_type, "protocol": self.protocol, "seq": seq}
        for key, value in payload.items():
            if key in event:
                logger.warning(f"Ignoring reserved field {key!r} in {event_type!r} event payload")
                continue
            event[key] = value
        try:
            encoded = json.dumps(event, ensure_ascii=False)
        except (TypeError, ValueError) as exc:
            logger.warning(f"Dropping unserializable {event_type!r} event: {exc}")
            return None
        self._sequence = seq
        return encoded