Files
engine-v5-pipecat-core/engine/product_protocol.py
2026-05-21 13:08:40 +08:00

165 lines
5.7 KiB
Python

from __future__ import annotations
import base64
import json
from typing import Any
from loguru import logger
from pipecat.frames.frames import (
CancelFrame,
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
EndFrame,
Frame,
InputAudioRawFrame,
InputTransportMessageFrame,
OutputAudioRawFrame,
OutputTransportMessageFrame,
OutputTransportMessageUrgentFrame,
TextFrame,
TranscriptionFrame,
)
from pipecat.serializers.base_serializer import FrameSerializer
class ProductWebsocketSerializer(FrameSerializer):
    """Stable app-facing JSON/base64 protocol adapter for Pipecat websocket transport.

    Outbound frames are rendered as JSON text events. Each event carries an
    envelope of ``type``, ``protocol`` and ``seq``, where ``seq`` is a
    per-serializer counter incremented once per successfully encoded event.
    Inbound binary messages are treated as raw audio. Inbound text messages
    must be JSON objects whose ``type`` selects the handling.
    """

    protocol = "va.ws.v1"

    # Envelope keys owned by the serializer; same-named payload fields are
    # ignored so callers cannot spoof the protocol id or break seq ordering.
    _RESERVED_KEYS = frozenset({"type", "protocol", "seq"})

    def __init__(self, *, sample_rate: int, channels: int):
        """Create a serializer.

        Args:
            sample_rate: Sample rate used for inbound audio that does not
                specify one, and advertised in ``session.started``.
            channels: Channel count used for inbound audio that does not
                specify one, and advertised in ``session.started``.
        """
        super().__init__()
        self._sample_rate = sample_rate
        self._channels = channels
        # seq value of the most recently emitted event (0 = none yet).
        self._sequence = 0

    async def serialize(self, frame: Frame) -> str | bytes | None:
        """Convert an outbound Pipecat frame into a protocol event.

        Returns:
            A JSON string for supported frame types. Returns ``None`` for
            frames with no representation in this protocol, for frames the
            base serializer says to ignore, and for transport messages that
            cannot be JSON-encoded (the last case is logged).
        """
        if isinstance(frame, OutputAudioRawFrame):
            return self._event(
                "response.audio.delta",
                audio=base64.b64encode(frame.audio).decode("ascii"),
                bytes=len(frame.audio),
                sample_rate=frame.sample_rate,
                channels=frame.num_channels,
            )
        if isinstance(frame, BotStartedSpeakingFrame):
            return self._event("response.audio.started")
        if isinstance(frame, BotStoppedSpeakingFrame):
            return self._event("response.audio.stopped")
        # Checked before TextFrame: TranscriptionFrame is assumed to subclass
        # TextFrame in Pipecat, so the order of these two checks matters.
        if isinstance(frame, TranscriptionFrame):
            return self._event(
                "input.transcript.final",
                text=frame.text,
                user_id=frame.user_id,
                timestamp=frame.timestamp,
            )
        if isinstance(frame, TextFrame):
            return self._event("response.text.delta", text=frame.text)
        if isinstance(frame, (OutputTransportMessageFrame, OutputTransportMessageUrgentFrame)):
            if self.should_ignore_frame(frame):
                return None
            return self._serialize_transport_message(frame.message)
        return None

    def _serialize_transport_message(self, message: Any) -> str | None:
        """Encode an outbound transport-message payload as a protocol event.

        Callers can emit any named protocol event by pushing a transport
        message whose payload is a dict with a string ``type``. The payload's
        other fields are merged alongside ``type``, so
        ``{"type": "response.text.final", "text": "..."}`` is sent as that
        event. Reserved envelope keys (``protocol``, ``seq``) in the payload
        are ignored. Any other payload is wrapped as ``transport.message``.

        Returns:
            The JSON event, or ``None`` if the payload is not JSON-encodable.
        """
        try:
            if isinstance(message, dict) and isinstance(message.get("type"), str):
                fields = {key: value for key, value in message.items() if key != "type"}
                return self._encode(message["type"], fields)
            return self._encode("transport.message", {"message": message})
        except (TypeError, ValueError) as exc:
            # json.dumps raises TypeError for unsupported types/keys and
            # ValueError for circular references; drop rather than crash output.
            logger.warning(f"Dropping unencodable product websocket message: {exc}")
            return None

    async def deserialize(self, data: str | bytes) -> Frame | None:
        """Convert an inbound websocket message into a Pipecat frame.

        Binary messages become ``InputAudioRawFrame`` at the configured sample
        rate and channel count. Text messages must be JSON objects.
        Recognised ``type`` values are ``session.start``, ``session.stop``,
        ``response.cancel``, ``input.audio``, ``input.text`` and
        ``transport.message``.

        Returns:
            The corresponding frame, or ``None`` (after logging a warning) for
            invalid JSON, non-object JSON, invalid fields or unsupported types.
        """
        if isinstance(data, bytes):
            return InputAudioRawFrame(
                audio=data,
                sample_rate=self._sample_rate,
                num_channels=self._channels,
            )
        try:
            message = json.loads(data)
        except json.JSONDecodeError as exc:
            logger.warning(f"Invalid product websocket JSON: {exc}")
            return None
        if not isinstance(message, dict):
            logger.warning("Product websocket message must be a JSON object")
            return None
        message_type = message.get("type")
        if message_type == "session.start":
            return InputTransportMessageFrame(
                message={
                    "type": "session.started",
                    "protocol": self.protocol,
                    "audio": {
                        "encoding": "pcm_s16le",
                        "sample_rate": self._sample_rate,
                        "channels": self._channels,
                    },
                }
            )
        if message_type == "session.stop":
            return EndFrame()
        if message_type == "response.cancel":
            return CancelFrame(reason="client_cancelled")
        if message_type == "input.audio":
            return self._input_audio_frame(message)
        if message_type == "input.text":
            text = message.get("text")
            if not isinstance(text, str) or not text.strip():
                logger.warning("input.text requires non-empty 'text'")
                return None
            return InputTransportMessageFrame(
                message={
                    "type": "input.text",
                    "text": text,
                    "interrupt": bool(message.get("interrupt", True)),
                }
            )
        if message_type == "transport.message":
            payload = message.get("message")
            return InputTransportMessageFrame(message=payload if isinstance(payload, dict) else message)
        logger.warning(f"Unsupported product websocket message type: {message_type!r}")
        return None

    def _input_audio_frame(self, message: dict[str, Any]) -> InputAudioRawFrame | None:
        """Decode an ``input.audio`` message into an audio frame.

        The base64 audio is read from ``audio`` (or ``data`` when ``audio`` is
        falsy). Optional ``sample_rate`` and ``channels`` override the
        configured defaults when truthy.

        Returns:
            The frame, or ``None`` (after logging) when the audio is missing or
            not valid base64, or the overrides are not positive integers.
        """
        audio = message.get("audio") or message.get("data")
        if not isinstance(audio, str):
            logger.warning("input.audio requires base64 'audio' or 'data'")
            return None
        try:
            pcm = base64.b64decode(audio)
        except ValueError as exc:  # binascii.Error subclasses ValueError
            logger.warning(f"Invalid input.audio base64: {exc}")
            return None
        sample_rate = self._positive_int(message.get("sample_rate"), self._sample_rate)
        channels = self._positive_int(message.get("channels"), self._channels)
        if sample_rate is None or channels is None:
            logger.warning("input.audio 'sample_rate' and 'channels' must be positive integers")
            return None
        return InputAudioRawFrame(audio=pcm, sample_rate=sample_rate, num_channels=channels)

    @staticmethod
    def _positive_int(value: Any, default: int) -> int | None:
        """Coerce an optional client-supplied number to a positive ``int``.

        Falsy values (missing, ``0``, ``""``, ``None``) fall back to
        ``default``. Returns ``None`` when the value cannot be converted with
        ``int()`` or is not positive.
        """
        if not value:
            return default
        try:
            number = int(value)
        except (TypeError, ValueError, OverflowError):
            # OverflowError: JSON "Infinity" parses to float('inf').
            return None
        return number if number > 0 else None

    def _event(self, event_type: str, /, **payload: Any) -> str:
        """Build the next JSON event of ``event_type`` with keyword payload fields.

        ``self`` and ``event_type`` are positional-only, so payload fields with
        those names cannot collide with the parameters.
        """
        return self._encode(event_type, payload)

    def _encode(self, event_type: str, payload: dict[Any, Any]) -> str:
        """Serialize an event envelope followed by ``payload`` fields.

        The envelope keys ``type``, ``protocol`` and ``seq`` come first and are
        never overridden by same-named payload keys. The sequence counter is
        advanced only after encoding succeeds, so a failed encode leaves no
        gap in ``seq``.

        Raises:
            TypeError: if ``payload`` contains values or keys JSON cannot encode.
            ValueError: if ``payload`` contains a circular reference.
        """
        seq = self._sequence + 1
        envelope: dict[Any, Any] = {
            "type": event_type,
            "protocol": self.protocol,
            "seq": seq,
        }
        for key, value in payload.items():
            if key not in self._RESERVED_KEYS:
                envelope[key] = value
        encoded = json.dumps(envelope, ensure_ascii=False)
        self._sequence = seq
        return encoded