240 lines
8.7 KiB
Python
240 lines
8.7 KiB
Python
from __future__ import annotations
|
|
|
|
import base64
|
|
import binascii
|
|
import json
|
|
from typing import Any
|
|
|
|
from loguru import logger
|
|
|
|
from pipecat.frames.frames import (
|
|
CancelFrame,
|
|
BotStartedSpeakingFrame,
|
|
BotStoppedSpeakingFrame,
|
|
EndFrame,
|
|
Frame,
|
|
InputAudioRawFrame,
|
|
InputTransportMessageFrame,
|
|
OutputAudioRawFrame,
|
|
OutputTransportMessageFrame,
|
|
OutputTransportMessageUrgentFrame,
|
|
TranscriptionFrame,
|
|
UserImageRawFrame,
|
|
)
|
|
from pipecat.serializers.base_serializer import FrameSerializer
|
|
|
|
|
|
# Largest decoded ``input.image`` payload accepted, in bytes (8 MiB).
MAX_INPUT_IMAGE_BYTES = 8 * 2**20

# MIME types accepted for ``input.image``; any other value is rejected.
SUPPORTED_INPUT_IMAGE_MIME_TYPES = {
    "image/jpeg",
    "image/png",
    "image/webp",
}
|
|
|
|
|
|
class ProductWebsocketSerializer(FrameSerializer):
    """Stable app-facing JSON/base64 protocol adapter for Pipecat websocket transport.

    Outbound frames become JSON text events sharing a common envelope
    (``type``, ``protocol``, ``seq``). Inbound data is either raw binary audio
    or a JSON object whose ``type`` selects the frame to produce. Malformed
    inbound messages are logged and dropped (``None``) instead of raising.
    """

    protocol = "va.ws.v1"

    def __init__(self, *, sample_rate: int, channels: int):
        """Create the serializer.

        Args:
            sample_rate: Sample rate assumed for inbound audio that does not
                specify one, and advertised in ``session.started``.
            channels: Channel count assumed for inbound audio that does not
                specify one, and advertised in ``session.started``.
        """
        super().__init__()
        self._sample_rate = sample_rate
        self._channels = channels
        # Counter stamped into every outbound event as ``seq``.
        self._sequence = 0

    async def serialize(self, frame: Frame) -> str | bytes | None:
        """Encode an outbound frame as a JSON protocol event.

        Returns:
            A JSON string for frames this protocol exposes, otherwise ``None``.
        """
        if isinstance(frame, OutputAudioRawFrame):
            return self._event(
                "response.audio.delta",
                audio=base64.b64encode(frame.audio).decode("ascii"),
                bytes=len(frame.audio),
                sample_rate=frame.sample_rate,
                channels=frame.num_channels,
            )

        if isinstance(frame, BotStartedSpeakingFrame):
            return self._event("response.audio.started")

        if isinstance(frame, BotStoppedSpeakingFrame):
            return self._event("response.audio.stopped")

        if isinstance(frame, TranscriptionFrame):
            return self._event(
                "input.transcript.final",
                text=frame.text,
                user_id=frame.user_id,
                timestamp=frame.timestamp,
            )

        if isinstance(frame, (OutputTransportMessageFrame, OutputTransportMessageUrgentFrame)):
            if self.should_ignore_frame(frame):
                return None
            message = frame.message
            # Allow callers to emit any named protocol event by pushing a
            # transport-message frame whose payload already carries a `type`.
            # The payload's other fields are merged alongside `type`, so e.g.
            # `{"type": "response.text.final", "text": "..."}` is sent verbatim.
            if isinstance(message, dict) and isinstance(message.get("type"), str):
                payload = {k: v for k, v in message.items() if k != "type"}
                # Pass the dict through directly rather than as **kwargs: arbitrary
                # client-chosen keys (e.g. "event_type", "self", non-strings) must
                # not collide with Python keyword arguments.
                return self._encode(message["type"], payload)
            return self._event("transport.message", message=message)

        # Text frames are deliberately not serialized. ProductTextStreamProcessor
        # owns response.text.* events; TTS also emits TextFrame subclasses
        # internally, and serializing those here would make the UI render
        # duplicate assistant bubbles.
        return None

    async def deserialize(self, data: str | bytes) -> Frame | None:
        """Decode inbound websocket data into a Pipecat frame.

        Binary data is treated as raw audio at the serializer's default sample
        rate and channel count. Text data must be a JSON object with a ``type``
        field. Invalid or unsupported messages are logged and yield ``None``.
        """
        if isinstance(data, bytes):
            return InputAudioRawFrame(
                audio=data,
                sample_rate=self._sample_rate,
                num_channels=self._channels,
            )

        try:
            message = json.loads(data)
        except json.JSONDecodeError as exc:
            logger.warning(f"Invalid product websocket JSON: {exc}")
            return None

        if not isinstance(message, dict):
            logger.warning("Product websocket message must be a JSON object")
            return None

        message_type = message.get("type")

        if message_type == "session.start":
            chat_id = message.get("chatId") or message.get("chat_id")
            return InputTransportMessageFrame(
                message={
                    "type": "session.started",
                    "protocol": self.protocol,
                    "chatId": chat_id if isinstance(chat_id, str) else None,
                    "audio": {
                        "encoding": "pcm_s16le",
                        "sample_rate": self._sample_rate,
                        "channels": self._channels,
                    },
                }
            )

        if message_type == "session.stop":
            return EndFrame()

        if message_type == "response.cancel":
            # NOTE(review): in Pipecat a CancelFrame is documented as stopping the
            # pipeline right away, not just the current bot response; confirm this
            # is the intended meaning of "response.cancel".
            return CancelFrame(reason="client_cancelled")

        if message_type == "input.audio":
            return self._deserialize_input_audio(message)

        if message_type == "input.image":
            return self._deserialize_input_image(message)

        if message_type == "input.text":
            text = message.get("text")
            if not isinstance(text, str) or not text.strip():
                logger.warning("input.text requires non-empty 'text'")
                return None
            return InputTransportMessageFrame(
                message={
                    "type": "input.text",
                    "text": text,
                    "interrupt": bool(message.get("interrupt", True)),
                }
            )

        if message_type == "session.set_info":
            return InputTransportMessageFrame(
                message={
                    "type": "session.set_info",
                    "request_id": message.get("request_id"),
                    "key": message.get("key"),
                    "value": message.get("value"),
                }
            )

        if message_type == "transport.message":
            payload = message.get("message")
            return InputTransportMessageFrame(message=payload if isinstance(payload, dict) else message)

        logger.warning(f"Unsupported product websocket message type: {message_type!r}")
        return None

    def _deserialize_input_audio(self, message: dict[str, Any]) -> Frame | None:
        """Decode an ``input.audio`` message into an ``InputAudioRawFrame``.

        ``sample_rate`` and ``channels`` fall back to the serializer defaults
        when absent or falsy. Non-integer or non-positive values are rejected
        (logged, ``None``) rather than raising out of ``deserialize``.
        """
        audio = message.get("audio") or message.get("data")
        if not isinstance(audio, str):
            logger.warning("input.audio requires base64 'audio' or 'data'")
            return None

        try:
            sample_rate = int(message.get("sample_rate") or self._sample_rate)
            channels = int(message.get("channels") or self._channels)
        except (TypeError, ValueError, OverflowError):
            # OverflowError: json.loads accepts `Infinity`, and int(inf) overflows.
            logger.warning("input.audio sample_rate and channels must be integers")
            return None

        if sample_rate <= 0 or channels <= 0:
            logger.warning("input.audio requires positive integer sample_rate and channels")
            return None

        try:
            pcm = base64.b64decode(audio)
        except (binascii.Error, ValueError) as exc:
            logger.warning(f"Invalid input.audio base64: {exc}")
            return None

        return InputAudioRawFrame(
            audio=pcm,
            sample_rate=sample_rate,
            num_channels=channels,
        )

    def _deserialize_input_image(self, message: dict[str, Any]) -> Frame | None:
        """Decode an ``input.image`` message into a ``UserImageRawFrame``.

        Requires a base64 ``image``/``data`` field (an optional ``data:`` URL
        prefix is stripped), a supported ``mime_type``/``media_type`` (default
        ``image/jpeg``), and positive integer ``width``/``height``. Payloads
        over ``MAX_INPUT_IMAGE_BYTES`` are rejected. Returns ``None`` on any
        validation failure.
        """
        encoded = message.get("image") or message.get("data")
        if not isinstance(encoded, str):
            logger.warning("input.image requires base64 'image' or 'data'")
            return None

        mime_type = str(message.get("mime_type") or message.get("media_type") or "image/jpeg")
        if mime_type not in SUPPORTED_INPUT_IMAGE_MIME_TYPES:
            logger.warning(
                "input.image unsupported mime_type "
                f"{mime_type!r}; expected one of {sorted(SUPPORTED_INPUT_IMAGE_MIME_TYPES)}"
            )
            return None

        try:
            width = int(message.get("width") or 0)
            height = int(message.get("height") or 0)
        except (TypeError, ValueError, OverflowError):
            # OverflowError: json.loads accepts `Infinity`, and int(inf) overflows.
            logger.warning("input.image width and height must be integers")
            return None

        if width <= 0 or height <= 0:
            logger.warning("input.image requires positive integer width and height")
            return None

        if "," in encoded and encoded.lstrip().startswith("data:"):
            encoded = encoded.split(",", 1)[1]

        # Each 4 base64 characters decode to at most 3 bytes, so a longer string
        # cannot fit within the limit; reject it before allocating the decode.
        max_encoded_len = 4 * -(-MAX_INPUT_IMAGE_BYTES // 3)
        if len(encoded) > max_encoded_len:
            logger.warning(
                f"input.image too large: {len(encoded)} base64 characters; "
                f"max is {MAX_INPUT_IMAGE_BYTES} bytes"
            )
            return None

        try:
            image = base64.b64decode(encoded, validate=True)
        except (binascii.Error, ValueError) as exc:
            logger.warning(f"Invalid input.image base64: {exc}")
            return None

        if len(image) > MAX_INPUT_IMAGE_BYTES:
            logger.warning(
                f"input.image too large: {len(image)} bytes; "
                f"max is {MAX_INPUT_IMAGE_BYTES} bytes"
            )
            return None

        text = message.get("text")
        if text is not None and not isinstance(text, str):
            logger.warning("input.image text must be a string when provided")
            return None

        return UserImageRawFrame(
            image=image,
            size=(width, height),
            format=mime_type,
            user_id=str(message.get("user_id") or "product-user"),
            text=text or "Answer using this camera image.",
            append_to_context=bool(message.get("append_to_context", True)),
        )

    def _event(self, event_type: str, /, **payload: Any) -> str:
        """Serialize a protocol event whose payload fields are given as keywords.

        ``self`` and ``event_type`` are positional-only, so payload fields with
        those names land in ``payload`` instead of raising ``TypeError``.
        """
        return self._encode(event_type, payload)

    def _encode(self, event_type: str, payload: dict[str, Any]) -> str:
        """Serialize one protocol event and advance ``seq``.

        The envelope keys ``type``, ``protocol`` and ``seq`` are always set by
        the serializer. A payload entry using one of those names is dropped
        with a warning so it cannot overwrite the event type or sequence
        number. ``seq`` is only committed after ``json.dumps`` succeeds, so a
        payload that fails to serialize does not leave a gap.
        """
        seq = self._sequence + 1
        event: dict[str, Any] = {
            "type": event_type,
            "protocol": self.protocol,
            "seq": seq,
        }
        for key, value in payload.items():
            if key in ("type", "protocol", "seq"):
                logger.warning(f"Dropping reserved key {key!r} from {event_type!r} payload")
                continue
            event[key] = value
        encoded = json.dumps(event, ensure_ascii=False)
        self._sequence = seq
        return encoded
|