# engine-v5-pipecat-core/engine/product_protocol.py
# Repository snapshot: 2026-06-01 08:41:00 +08:00 (240 lines, 8.7 KiB, Python)

from __future__ import annotations
import base64
import binascii
import json
from typing import Any
from loguru import logger
from pipecat.frames.frames import (
CancelFrame,
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
EndFrame,
Frame,
InputAudioRawFrame,
InputTransportMessageFrame,
OutputAudioRawFrame,
OutputTransportMessageFrame,
OutputTransportMessageUrgentFrame,
TranscriptionFrame,
UserImageRawFrame,
)
from pipecat.serializers.base_serializer import FrameSerializer
# Largest decoded ``input.image`` payload accepted, in bytes (8 MiB).
MAX_INPUT_IMAGE_BYTES = 8 << 20
# MIME types accepted for ``input.image``; any other type is rejected.
SUPPORTED_INPUT_IMAGE_MIME_TYPES = {"image/png", "image/webp", "image/jpeg"}
class ProductWebsocketSerializer(FrameSerializer):
    """Stable app-facing JSON/base64 protocol adapter for Pipecat websocket transport.

    Outbound frames are encoded as JSON text events sharing one envelope
    (``type``, ``protocol``, ``seq``). Inbound text messages must be JSON
    objects with a ``type`` field and are mapped to Pipecat frames; inbound
    binary messages are treated as raw PCM at the configured sample rate and
    channel count. Malformed inbound messages are logged and yield ``None``
    instead of raising.
    """

    protocol = "va.ws.v1"

    def __init__(self, *, sample_rate: int, channels: int):
        """Create a serializer.

        Args:
            sample_rate: Sample rate assumed for inbound binary audio, and the
                default for ``input.audio`` messages that omit ``sample_rate``.
            channels: Channel count assumed for inbound binary audio, and the
                default for ``input.audio`` messages that omit ``channels``.
        """
        super().__init__()
        self._sample_rate = sample_rate
        self._channels = channels
        # Sequence number of the most recently encoded event.
        self._sequence = 0

    async def serialize(self, frame: Frame) -> str | bytes | None:
        """Encode an outbound frame as a JSON event string.

        Returns ``None`` for frames with no product-protocol representation,
        or for transport messages that cannot be JSON-encoded.
        """
        if isinstance(frame, OutputAudioRawFrame):
            return self._event(
                "response.audio.delta",
                audio=base64.b64encode(frame.audio).decode("ascii"),
                bytes=len(frame.audio),
                sample_rate=frame.sample_rate,
                channels=frame.num_channels,
            )
        if isinstance(frame, BotStartedSpeakingFrame):
            return self._event("response.audio.started")
        if isinstance(frame, BotStoppedSpeakingFrame):
            return self._event("response.audio.stopped")
        if isinstance(frame, TranscriptionFrame):
            return self._event(
                "input.transcript.final",
                text=frame.text,
                user_id=frame.user_id,
                timestamp=frame.timestamp,
            )
        # ProductTextStreamProcessor owns response.text.* events. TTS also
        # emits TextFrame subclasses internally, and serializing those here
        # would make the UI render duplicate assistant bubbles.
        if isinstance(frame, (OutputTransportMessageFrame, OutputTransportMessageUrgentFrame)):
            if self.should_ignore_frame(frame):
                return None
            return self._serialize_transport_message(frame.message)
        return None

    def _serialize_transport_message(self, message: Any) -> str | None:
        """Encode a caller-supplied transport message as a protocol event.

        A dict whose ``type`` is a string is emitted as that named event, with
        its remaining fields merged alongside ``type`` (so
        ``{"type": "response.text.final", "text": "..."}`` is sent verbatim).
        Anything else is wrapped as a ``transport.message`` event. Messages
        that are not JSON-serializable are logged and dropped (``None``).
        """
        if isinstance(message, dict) and isinstance(message.get("type"), str):
            event_type = message["type"]
            payload = {k: v for k, v in message.items() if k != "type"}
        else:
            event_type = "transport.message"
            payload = {"message": message}
        try:
            # Pass the dict directly rather than via ``**payload``: keys such as
            # "event_type" or non-string keys would otherwise raise TypeError.
            return self._encode_event(event_type, payload)
        except (TypeError, ValueError) as exc:
            logger.warning(
                f"Dropping non-JSON-serializable {event_type!r} transport message: {exc}"
            )
            return None

    async def deserialize(self, data: str | bytes) -> Frame | None:
        """Decode one inbound websocket message into a Pipecat frame.

        Binary messages become ``InputAudioRawFrame`` at the configured sample
        rate/channels. Text messages are parsed as JSON objects and dispatched
        on ``type``; invalid or unsupported messages are logged and yield
        ``None``.
        """
        if isinstance(data, bytes):
            return InputAudioRawFrame(
                audio=data,
                sample_rate=self._sample_rate,
                num_channels=self._channels,
            )
        try:
            message = json.loads(data)
        except (ValueError, RecursionError) as exc:
            # JSONDecodeError is a ValueError subclass. A plain ValueError is
            # also raised for integer literals beyond the interpreter's digit
            # limit (Python 3.11+), and RecursionError for deeply nested input;
            # none of these should propagate out of deserialize().
            logger.warning(f"Invalid product websocket JSON: {exc}")
            return None
        if not isinstance(message, dict):
            logger.warning("Product websocket message must be a JSON object")
            return None
        message_type = message.get("type")
        if message_type == "session.start":
            chat_id = message.get("chatId") or message.get("chat_id")
            return InputTransportMessageFrame(
                message={
                    "type": "session.started",
                    "protocol": self.protocol,
                    "chatId": chat_id if isinstance(chat_id, str) else None,
                    "audio": {
                        "encoding": "pcm_s16le",
                        "sample_rate": self._sample_rate,
                        "channels": self._channels,
                    },
                }
            )
        if message_type == "session.stop":
            return EndFrame()
        if message_type == "response.cancel":
            return CancelFrame(reason="client_cancelled")
        if message_type == "input.audio":
            return self._deserialize_input_audio(message)
        if message_type == "input.image":
            return self._deserialize_input_image(message)
        if message_type == "input.text":
            text = message.get("text")
            if not isinstance(text, str) or not text.strip():
                logger.warning("input.text requires non-empty 'text'")
                return None
            return InputTransportMessageFrame(
                message={
                    "type": "input.text",
                    "text": text,
                    "interrupt": bool(message.get("interrupt", True)),
                }
            )
        if message_type == "session.set_info":
            return InputTransportMessageFrame(
                message={
                    "type": "session.set_info",
                    "request_id": message.get("request_id"),
                    "key": message.get("key"),
                    "value": message.get("value"),
                }
            )
        if message_type == "transport.message":
            payload = message.get("message")
            return InputTransportMessageFrame(message=payload if isinstance(payload, dict) else message)
        logger.warning(f"Unsupported product websocket message type: {message_type!r}")
        return None

    def _deserialize_input_audio(self, message: dict[str, Any]) -> Frame | None:
        """Decode an ``input.audio`` message carrying base64 PCM in ``audio``/``data``.

        ``sample_rate`` and ``channels`` default to the configured values when
        absent or falsy, and must otherwise be positive integers.
        """
        audio = message.get("audio") or message.get("data")
        if not isinstance(audio, str):
            logger.warning("input.audio requires base64 'audio' or 'data'")
            return None
        try:
            pcm = base64.b64decode(audio)
        except (binascii.Error, ValueError) as exc:
            logger.warning(f"Invalid input.audio base64: {exc}")
            return None
        # Client-supplied numbers are untrusted: reject invalid values instead
        # of letting int() raise out of deserialize().
        sample_rate = self._positive_int(message.get("sample_rate"), self._sample_rate)
        channels = self._positive_int(message.get("channels"), self._channels)
        if sample_rate is None or channels is None:
            logger.warning("input.audio sample_rate and channels must be positive integers")
            return None
        return InputAudioRawFrame(
            audio=pcm,
            sample_rate=sample_rate,
            num_channels=channels,
        )

    @staticmethod
    def _positive_int(value: Any, default: int) -> int | None:
        """Return ``value`` as a positive int, ``default`` if falsy, else ``None`` if invalid."""
        if not value:
            return default
        try:
            number = int(value)
        except (TypeError, ValueError, OverflowError):
            # OverflowError: json.loads accepts Infinity, and int(inf) overflows.
            return None
        return number if number > 0 else None

    def _deserialize_input_image(self, message: dict[str, Any]) -> Frame | None:
        """Decode an ``input.image`` message into a ``UserImageRawFrame``.

        The image is read as base64 from ``image`` (or ``data``), optionally
        wrapped in a ``data:<mime>;base64,`` URL. The MIME type is taken from
        ``mime_type``/``media_type``, else the data-URL header, else
        ``image/jpeg``. Positive ``width`` and ``height`` are required. Any
        invalid field is logged and yields ``None``.
        """
        encoded = message.get("image") or message.get("data")
        if not isinstance(encoded, str):
            logger.warning("input.image requires base64 'image' or 'data'")
            return None
        data_url_mime: str | None = None
        if "," in encoded and encoded.lstrip().startswith("data:"):
            header, _, encoded = encoded.partition(",")
            # "data:image/png;base64" -> "image/png". Used when no explicit
            # mime_type is given so PNG/WebP data URLs are not labelled JPEG.
            data_url_mime = header.lstrip()[len("data:"):].split(";", 1)[0].strip() or None
        declared_mime = message.get("mime_type") or message.get("media_type") or data_url_mime
        # MIME types are case-insensitive (RFC 2045); normalize before lookup.
        mime_type = str(declared_mime or "image/jpeg").strip().lower()
        if mime_type not in SUPPORTED_INPUT_IMAGE_MIME_TYPES:
            logger.warning(
                "input.image unsupported mime_type "
                f"{mime_type!r}; expected one of {sorted(SUPPORTED_INPUT_IMAGE_MIME_TYPES)}"
            )
            return None
        try:
            width = int(message.get("width") or 0)
            height = int(message.get("height") or 0)
        except (TypeError, ValueError, OverflowError):
            # OverflowError: json.loads accepts Infinity, and int(inf) overflows.
            logger.warning("input.image width and height must be integers")
            return None
        if width <= 0 or height <= 0:
            logger.warning("input.image requires positive integer width and height")
            return None
        # Strict (validate=True) padded base64 of MAX_INPUT_IMAGE_BYTES is at
        # most this long. Anything longer is too large (or invalid), so reject
        # it before allocating the decoded bytes.
        max_encoded_len = 4 * -(-MAX_INPUT_IMAGE_BYTES // 3)
        if len(encoded) > max_encoded_len:
            logger.warning(
                f"input.image too large: {len(encoded)} base64 chars; "
                f"max is {MAX_INPUT_IMAGE_BYTES} bytes"
            )
            return None
        try:
            image = base64.b64decode(encoded, validate=True)
        except (binascii.Error, ValueError) as exc:
            logger.warning(f"Invalid input.image base64: {exc}")
            return None
        if len(image) > MAX_INPUT_IMAGE_BYTES:
            logger.warning(
                f"input.image too large: {len(image)} bytes; "
                f"max is {MAX_INPUT_IMAGE_BYTES} bytes"
            )
            return None
        text = message.get("text")
        if text is not None and not isinstance(text, str):
            logger.warning("input.image text must be a string when provided")
            return None
        return UserImageRawFrame(
            image=image,
            size=(width, height),
            format=mime_type,
            user_id=str(message.get("user_id") or "product-user"),
            text=text or "Answer using this camera image.",
            append_to_context=bool(message.get("append_to_context", True)),
        )

    def _event(self, event_type: str, /, **payload: Any) -> str:
        """Encode an event from keyword fields; see ``_encode_event``.

        ``event_type`` is positional-only so a payload field of that name is
        treated as data rather than colliding with the parameter.
        """
        return self._encode_event(event_type, payload)

    def _encode_event(self, event_type: str, payload: dict[Any, Any]) -> str:
        """Serialize one event as ``{"type", "protocol", "seq", **payload}`` JSON.

        Envelope fields always win over same-named payload keys, so a payload
        cannot spoof ``protocol`` or break ``seq`` ordering. The sequence
        counter is advanced only after encoding succeeds, so a failed encode
        leaves no gap.

        Raises:
            TypeError, ValueError: if the payload is not JSON-serializable.
        """
        seq = self._sequence + 1
        envelope: dict[Any, Any] = {"type": event_type, "protocol": self.protocol, "seq": seq}
        extra = {key: value for key, value in payload.items() if key not in envelope}
        if len(extra) != len(payload):
            shadowed = [key for key in payload if key in envelope]
            logger.warning(f"Ignoring reserved envelope keys in {event_type!r} event: {shadowed}")
        encoded = json.dumps({**envelope, **extra}, ensure_ascii=False)
        self._sequence = seq
        return encoded