Add support for new RTVI client message protocol: handling and responding
This commit is contained in:
committed by
Mattie Ruth
parent
c4a9fc7f88
commit
43049c865c
@@ -13,6 +13,7 @@ and frame observation for the RTVI protocol.
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Any,
|
||||
@@ -44,6 +45,7 @@ from pipecat.frames.frames import (
|
||||
InterimTranscriptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesAppendFrame,
|
||||
LLMTextFrame,
|
||||
MetricsFrame,
|
||||
StartFrame,
|
||||
@@ -71,6 +73,7 @@ from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.services.llm_service import (
|
||||
FunctionCallParams, # TODO(aleix): we shouldn't import `services` from `processors`
|
||||
)
|
||||
from pipecat.services.openai.llm import OpenAIContextAggregatorPair
|
||||
from pipecat.transports.base_input import BaseInputTransport
|
||||
from pipecat.transports.base_output import BaseOutputTransport
|
||||
from pipecat.transports.base_transport import BaseTransport
|
||||
@@ -240,6 +243,66 @@ class RTVIActionFrame(DataFrame):
|
||||
message_id: Optional[str] = None
|
||||
|
||||
|
||||
class RTVIRawClientMessageData(BaseModel):
|
||||
"""Data structure expected from client messages sent to the RTVI server."""
|
||||
|
||||
t: str
|
||||
d: Optional[Any] = None
|
||||
|
||||
|
||||
class RTVIClientMessage(BaseModel):
|
||||
"""Cleansed data structure for client messages for handling."""
|
||||
|
||||
msg_id: str
|
||||
type: str
|
||||
data: Optional[Any] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RTVIClientMessageFrame(SystemFrame):
|
||||
"""A frame for sending messages from the client to the RTVI server.
|
||||
|
||||
This frame is meant for custom messaging from the client to the server
|
||||
and expects a server-response message.
|
||||
"""
|
||||
|
||||
msg_id: str
|
||||
type: str
|
||||
data: Optional[Any] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RTVIServerResponseFrame(SystemFrame):
|
||||
"""A frame for sending messages from the client to the RTVI server.
|
||||
|
||||
This frame is meant for custom messaging from the client to the server
|
||||
and expects a server-response message.
|
||||
"""
|
||||
|
||||
client_msg: RTVIClientMessageFrame
|
||||
data: Optional[Any] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
class RTVIRawServerResponseData(BaseModel):
|
||||
"""Data structure for server responses to client messages."""
|
||||
|
||||
t: str
|
||||
d: Optional[Any] = None
|
||||
|
||||
|
||||
class RTVIServerResponse(BaseModel):
|
||||
"""A response message from the client to the RTVI server.
|
||||
|
||||
This message is used to respond to custom messages sent by the server.
|
||||
"""
|
||||
|
||||
label: RTVIMessageLiteral = RTVI_MESSAGE_LABEL
|
||||
type: Literal["server-response"] = "server-response"
|
||||
id: str
|
||||
data: RTVIRawServerResponseData
|
||||
|
||||
|
||||
class RTVIMessage(BaseModel):
|
||||
"""Base RTVI message structure.
|
||||
|
||||
@@ -418,6 +481,18 @@ class RTVILLMFunctionCallMessage(BaseModel):
|
||||
data: RTVILLMFunctionCallMessageData
|
||||
|
||||
|
||||
class RTVIAppendToContextData(BaseModel):
|
||||
role: Literal["user", "assistant"] | str
|
||||
content: Any
|
||||
run_immediately: bool = False
|
||||
|
||||
|
||||
class RTVIAppendToContext(BaseModel):
|
||||
label: RTVIMessageLiteral = RTVI_MESSAGE_LABEL
|
||||
type: Literal["append-to-context"] = "append-to-context"
|
||||
data: RTVIAppendToContextData
|
||||
|
||||
|
||||
class RTVILLMFunctionCallStartMessageData(BaseModel):
|
||||
"""Data for LLM function call start notification.
|
||||
|
||||
@@ -752,6 +827,11 @@ class RTVIObserver(BaseObserver):
|
||||
elif isinstance(frame, RTVIServerMessageFrame):
|
||||
message = RTVIServerMessage(data=frame.data)
|
||||
await self.push_transport_message_urgent(message)
|
||||
elif isinstance(frame, RTVIServerResponseFrame):
|
||||
if frame.error is not None:
|
||||
await self._send_error_response(frame)
|
||||
else:
|
||||
await self._send_server_response(frame)
|
||||
|
||||
if mark_as_seen:
|
||||
self._frames_seen.add(frame.id)
|
||||
@@ -879,6 +959,22 @@ class RTVIObserver(BaseObserver):
|
||||
message = RTVIMetricsMessage(data=metrics)
|
||||
await self.push_transport_message_urgent(message)
|
||||
|
||||
async def _send_server_response(self, frame: RTVIServerResponseFrame):
|
||||
"""Send a response to the client for a specific request."""
|
||||
message = RTVIServerResponse(
|
||||
id=str(frame.client_msg.msg_id),
|
||||
data=RTVIRawServerResponseData(t=frame.client_msg.type, d=frame.data),
|
||||
)
|
||||
await self.push_transport_message_urgent(message)
|
||||
|
||||
async def _send_error_response(self, frame: RTVIServerResponseFrame):
|
||||
"""Send a response to the client for a specific request."""
|
||||
if self._params.errors_enabled:
|
||||
message = RTVIErrorResponse(
|
||||
id=str(frame.client_msg.msg_id), data=RTVIErrorResponseData(error=frame.error)
|
||||
)
|
||||
await self.push_transport_message_urgent(message)
|
||||
|
||||
|
||||
class RTVIProcessor(FrameProcessor):
|
||||
"""Main processor for handling RTVI protocol messages and actions.
|
||||
@@ -921,6 +1017,7 @@ class RTVIProcessor(FrameProcessor):
|
||||
|
||||
self._register_event_handler("on_bot_started")
|
||||
self._register_event_handler("on_client_ready")
|
||||
self._register_event_handler("on_client_message")
|
||||
|
||||
self._input_transport = None
|
||||
self._transport = transport
|
||||
@@ -936,6 +1033,15 @@ class RTVIProcessor(FrameProcessor):
|
||||
Args:
|
||||
action: The action to register.
|
||||
"""
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"The actions API is deprecated, use server and client messages instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
id = self._action_id(action.service, action.action)
|
||||
self._registered_actions[id] = action
|
||||
|
||||
@@ -945,6 +1051,15 @@ class RTVIProcessor(FrameProcessor):
|
||||
Args:
|
||||
service: The service to register.
|
||||
"""
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"The actions API is deprecated, use server and client messages instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
self._registered_services[service.name] = service
|
||||
|
||||
async def set_client_ready(self):
|
||||
@@ -970,6 +1085,21 @@ class RTVIProcessor(FrameProcessor):
|
||||
"""Send a bot interruption frame upstream."""
|
||||
await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
|
||||
|
||||
async def send_server_message(self, data: Any):
|
||||
"""Send a server message to the client."""
|
||||
message = RTVIServerMessage(data=data)
|
||||
await self._send_server_message(message)
|
||||
|
||||
async def send_server_response(self, client_msg: RTVIClientMessage, data: Any):
|
||||
"""Send a server response for a given client message."""
|
||||
message = RTVIServerResponse(
|
||||
id=client_msg.msg_id, data=RTVIRawServerResponseData(t=client_msg.type, d=data)
|
||||
)
|
||||
await self._send_server_message(message)
|
||||
|
||||
async def send_error_response(self, client_msg: RTVIClientMessage, error: str):
|
||||
await self._send_error_response(id=client_msg.msg_id, error=error)
|
||||
|
||||
async def send_error(self, error: str):
|
||||
"""Send an error message to the client.
|
||||
|
||||
@@ -1148,6 +1278,9 @@ class RTVIProcessor(FrameProcessor):
|
||||
await self._handle_update_config(message.id, update_config)
|
||||
case "disconnect-bot":
|
||||
await self.push_frame(EndTaskFrame(), FrameDirection.UPSTREAM)
|
||||
case "client-message":
|
||||
data = RTVIRawClientMessageData.model_validate(message.data)
|
||||
await self._handle_client_message(message.id, data)
|
||||
case "action":
|
||||
action = RTVIActionRun.model_validate(message.data)
|
||||
action_frame = RTVIActionFrame(message_id=message.id, rtvi_action_run=action)
|
||||
@@ -1155,11 +1288,14 @@ class RTVIProcessor(FrameProcessor):
|
||||
case "llm-function-call-result":
|
||||
data = RTVILLMFunctionCallResultData.model_validate(message.data)
|
||||
await self._handle_function_call_result(data)
|
||||
case "append-to-context":
|
||||
data = RTVIAppendToContextData.model_validate(message.data)
|
||||
await self._handle_update_context(data)
|
||||
case "raw-audio" | "raw-audio-batch":
|
||||
await self._handle_audio_buffer(message.data)
|
||||
|
||||
case _:
|
||||
await self._send_error_response(message.id, f"Unsupported type {message.type}")
|
||||
await self._send_error_response(message.id, f"UNSUPPORTED type {message.type}")
|
||||
|
||||
except ValidationError as e:
|
||||
await self._send_error_response(message.id, f"Invalid message: {e}")
|
||||
@@ -1201,18 +1337,45 @@ class RTVIProcessor(FrameProcessor):
|
||||
|
||||
async def _handle_describe_config(self, request_id: str):
|
||||
"""Handle a describe-config request."""
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Configuration helpers are deprecated. If your application needs this behavior, use custom server and client messages.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
services = list(self._registered_services.values())
|
||||
message = RTVIDescribeConfig(id=request_id, data=RTVIDescribeConfigData(config=services))
|
||||
await self._push_transport_message(message)
|
||||
|
||||
async def _handle_describe_actions(self, request_id: str):
|
||||
"""Handle a describe-actions request."""
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"The Actions API is deprecated, use custom server and client messages instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
actions = list(self._registered_actions.values())
|
||||
message = RTVIDescribeActions(id=request_id, data=RTVIDescribeActionsData(actions=actions))
|
||||
await self._push_transport_message(message)
|
||||
|
||||
async def _handle_get_config(self, request_id: str):
|
||||
"""Handle a get-config request."""
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Configuration helpers are deprecated. If your application needs this behavior, use custom server and client messages.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
message = RTVIConfigResponse(id=request_id, data=self._config)
|
||||
await self._push_transport_message(message)
|
||||
|
||||
@@ -1230,6 +1393,15 @@ class RTVIProcessor(FrameProcessor):
|
||||
|
||||
async def _update_service_config(self, config: RTVIServiceConfig):
|
||||
"""Update configuration for a specific service."""
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Configuration helpers are deprecated. If your application needs this behavior, use custom server and client messages.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
service = self._registered_services[config.service]
|
||||
for option in config.options:
|
||||
handler = service._options_dict[option.name].handler
|
||||
@@ -1238,6 +1410,15 @@ class RTVIProcessor(FrameProcessor):
|
||||
|
||||
async def _update_config(self, data: RTVIConfig, interrupt: bool):
|
||||
"""Update the RTVI configuration."""
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Configuration helpers are deprecated. If your application needs this behavior, use custom server and client messages.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
if interrupt:
|
||||
await self.interrupt_bot()
|
||||
for service_config in data.config:
|
||||
@@ -1248,6 +1429,33 @@ class RTVIProcessor(FrameProcessor):
|
||||
await self._update_config(RTVIConfig(config=data.config), data.interrupt)
|
||||
await self._handle_get_config(request_id)
|
||||
|
||||
async def _handle_update_context(self, data: RTVIAppendToContextData):
|
||||
if data.run_immediately:
|
||||
await self.interrupt_bot()
|
||||
frame = LLMMessagesAppendFrame(
|
||||
messages=[{"role": data.role, "content": data.content}],
|
||||
run_llm=data.run_immediately,
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def _handle_client_message(self, msg_id: str, data: RTVIRawClientMessageData):
|
||||
"""Handle a client message frame."""
|
||||
if not data:
|
||||
await self._send_error_response(msg_id, "Malformed client message")
|
||||
return
|
||||
|
||||
# Create a RTVIClientMessageFrame to push the message
|
||||
frame = RTVIClientMessageFrame(msg_id=msg_id, type=data.t, data=data.d)
|
||||
await self.push_frame(frame)
|
||||
await self._call_event_handler(
|
||||
"on_client_message",
|
||||
RTVIClientMessage(
|
||||
msg_id=msg_id,
|
||||
type=data.t,
|
||||
data=data.d,
|
||||
),
|
||||
)
|
||||
|
||||
async def _handle_function_call_result(self, data):
|
||||
"""Handle a function call result from the client."""
|
||||
frame = FunctionCallResultFrame(
|
||||
@@ -1284,6 +1492,10 @@ class RTVIProcessor(FrameProcessor):
|
||||
)
|
||||
await self._push_transport_message(message)
|
||||
|
||||
async def _send_server_message(self, message: RTVIServerMessage | RTVIServerResponse):
|
||||
"""Send a message or response to the client."""
|
||||
await self._push_transport_message(message)
|
||||
|
||||
async def _send_error_frame(self, frame: ErrorFrame):
|
||||
"""Send an error frame as an RTVI error message."""
|
||||
if self._errors_enabled:
|
||||
|
||||
Reference in New Issue
Block a user