Merge pull request #515 from pipecat-ai/aleix/rtvi-frame-processors
RTVI frame processors
This commit is contained in:
@@ -26,6 +26,8 @@ class AsyncGeneratorProcessor(FrameProcessor):
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, (CancelFrame, EndFrame)):
|
||||
await self._data_queue.put(None)
|
||||
else:
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Union
|
||||
from pydantic import BaseModel, Field, PrivateAttr, ValidationError
|
||||
@@ -20,8 +21,14 @@ from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
OutputAudioRawFrame,
|
||||
StartFrame,
|
||||
SystemFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
TransportMessageFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
@@ -34,7 +41,7 @@ from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from loguru import logger
|
||||
|
||||
|
||||
RTVI_PROTOCOL_VERSION = "0.1"
|
||||
RTVI_PROTOCOL_VERSION = "0.2"
|
||||
|
||||
ActionResult = Union[bool, int, float, str, list, dict]
|
||||
|
||||
@@ -242,33 +249,75 @@ class RTVILLMFunctionCallResultData(BaseModel):
|
||||
result: dict | str
|
||||
|
||||
|
||||
class RTVIBotLLMStartedMessage(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
type: Literal["bot-llm-started"] = "bot-llm-started"
|
||||
|
||||
|
||||
class RTVIBotLLMStoppedMessage(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
type: Literal["bot-llm-stopped"] = "bot-llm-stopped"
|
||||
|
||||
|
||||
class RTVIBotTTSStartedMessage(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
type: Literal["bot-tts-started"] = "bot-tts-started"
|
||||
|
||||
|
||||
class RTVIBotTTSStoppedMessage(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
type: Literal["bot-tts-stopped"] = "bot-tts-stopped"
|
||||
|
||||
|
||||
class RTVITextMessageData(BaseModel):
|
||||
text: str
|
||||
|
||||
|
||||
class RTVILLMTextMessage(BaseModel):
|
||||
class RTVIBotLLMTextMessage(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
type: Literal["llm-text"] = "llm-text"
|
||||
type: Literal["bot-llm-text"] = "bot-llm-text"
|
||||
data: RTVITextMessageData
|
||||
|
||||
|
||||
class RTVITTSTextMessage(BaseModel):
|
||||
class RTVIBotTTSTextMessage(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
type: Literal["tts-text"] = "tts-text"
|
||||
type: Literal["bot-tts-text"] = "bot-tts-text"
|
||||
data: RTVITextMessageData
|
||||
|
||||
|
||||
class RTVITranscriptionMessageData(BaseModel):
|
||||
class RTVIAudioMessageData(BaseModel):
|
||||
audio: str
|
||||
sample_rate: int
|
||||
num_channels: int
|
||||
|
||||
|
||||
class RTVIBotAudioMessage(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
type: Literal["bot-audio"] = "bot-audio"
|
||||
data: RTVIAudioMessageData
|
||||
|
||||
|
||||
class RTVIBotTranscriptionMessageData(BaseModel):
|
||||
text: str
|
||||
|
||||
|
||||
class RTVIBotTranscriptionMessage(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
type: Literal["bot-transcription"] = "bot-transcription"
|
||||
data: RTVIBotTranscriptionMessageData
|
||||
|
||||
|
||||
class RTVIUserTranscriptionMessageData(BaseModel):
|
||||
text: str
|
||||
user_id: str
|
||||
timestamp: str
|
||||
final: bool
|
||||
|
||||
|
||||
class RTVITranscriptionMessage(BaseModel):
|
||||
class RTVIUserTranscriptionMessage(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
type: Literal["user-transcription"] = "user-transcription"
|
||||
data: RTVITranscriptionMessageData
|
||||
data: RTVIUserTranscriptionMessageData
|
||||
|
||||
|
||||
class RTVIUserStartedSpeakingMessage(BaseModel):
|
||||
@@ -295,6 +344,170 @@ class RTVIProcessorParams(BaseModel):
|
||||
send_bot_ready: bool = True
|
||||
|
||||
|
||||
class RTVIFrameProcessor(FrameProcessor):
|
||||
def __init__(self, direction: FrameDirection = FrameDirection.DOWNSTREAM, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._direction = direction
|
||||
|
||||
async def _push_transport_message(self, model: BaseModel, exclude_none: bool = True):
|
||||
frame = TransportMessageFrame(
|
||||
message=model.model_dump(exclude_none=exclude_none), urgent=True
|
||||
)
|
||||
await self.push_frame(frame, self._direction)
|
||||
|
||||
|
||||
class RTVISpeakingProcessor(RTVIFrameProcessor):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, (UserStartedSpeakingFrame, UserStoppedSpeakingFrame)):
|
||||
await self._handle_interruptions(frame)
|
||||
elif isinstance(frame, (BotStartedSpeakingFrame, BotStoppedSpeakingFrame)):
|
||||
await self._handle_bot_speaking(frame)
|
||||
|
||||
async def _handle_interruptions(self, frame: Frame):
|
||||
message = None
|
||||
if isinstance(frame, UserStartedSpeakingFrame):
|
||||
message = RTVIUserStartedSpeakingMessage()
|
||||
elif isinstance(frame, UserStoppedSpeakingFrame):
|
||||
message = RTVIUserStoppedSpeakingMessage()
|
||||
|
||||
if message:
|
||||
await self._push_transport_message(message)
|
||||
|
||||
async def _handle_bot_speaking(self, frame: Frame):
|
||||
message = None
|
||||
if isinstance(frame, BotStartedSpeakingFrame):
|
||||
message = RTVIBotStartedSpeakingMessage()
|
||||
elif isinstance(frame, BotStoppedSpeakingFrame):
|
||||
message = RTVIBotStoppedSpeakingMessage()
|
||||
|
||||
if message:
|
||||
await self._push_transport_message(message)
|
||||
|
||||
|
||||
class RTVIUserTranscriptionProcessor(RTVIFrameProcessor):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, (TranscriptionFrame, InterimTranscriptionFrame)):
|
||||
await self._handle_user_transcriptions(frame)
|
||||
|
||||
async def _handle_user_transcriptions(self, frame: Frame):
|
||||
message = None
|
||||
if isinstance(frame, TranscriptionFrame):
|
||||
message = RTVIUserTranscriptionMessage(
|
||||
data=RTVIUserTranscriptionMessageData(
|
||||
text=frame.text, user_id=frame.user_id, timestamp=frame.timestamp, final=True
|
||||
)
|
||||
)
|
||||
elif isinstance(frame, InterimTranscriptionFrame):
|
||||
message = RTVIUserTranscriptionMessage(
|
||||
data=RTVIUserTranscriptionMessageData(
|
||||
text=frame.text, user_id=frame.user_id, timestamp=frame.timestamp, final=False
|
||||
)
|
||||
)
|
||||
|
||||
if message:
|
||||
await self._push_transport_message(message)
|
||||
|
||||
|
||||
class RTVIBotLLMProcessor(RTVIFrameProcessor):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, LLMFullResponseStartFrame):
|
||||
await self._push_transport_message(RTVIBotLLMStartedMessage())
|
||||
elif isinstance(frame, LLMFullResponseEndFrame):
|
||||
await self._push_transport_message(RTVIBotLLMStoppedMessage())
|
||||
|
||||
|
||||
class RTVIBotTTSProcessor(RTVIFrameProcessor):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, TTSStartedFrame):
|
||||
await self._push_transport_message(RTVIBotTTSStartedMessage())
|
||||
elif isinstance(frame, TTSStoppedFrame):
|
||||
await self._push_transport_message(RTVIBotTTSStoppedMessage())
|
||||
|
||||
|
||||
class RTVIBotLLMTextProcessor(RTVIFrameProcessor):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, TextFrame):
|
||||
await self._handle_text(frame)
|
||||
|
||||
async def _handle_text(self, frame: TextFrame):
|
||||
message = RTVIBotLLMTextMessage(data=RTVITextMessageData(text=frame.text))
|
||||
await self._push_transport_message(message)
|
||||
|
||||
|
||||
class RTVIBotTTSTextProcessor(RTVIFrameProcessor):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, TextFrame):
|
||||
await self._handle_text(frame)
|
||||
|
||||
async def _handle_text(self, frame: TextFrame):
|
||||
message = RTVIBotTTSTextMessage(data=RTVITextMessageData(text=frame.text))
|
||||
await self._push_transport_message(message)
|
||||
|
||||
|
||||
class RTVIBotAudioProcessor(RTVIFrameProcessor):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, OutputAudioRawFrame):
|
||||
await self._handle_audio(frame)
|
||||
|
||||
async def _handle_audio(self, frame: OutputAudioRawFrame):
|
||||
encoded = base64.b64encode(frame.audio).decode("utf-8")
|
||||
message = RTVIBotAudioMessage(
|
||||
data=RTVIAudioMessageData(
|
||||
audio=encoded, sample_rate=frame.sample_rate, num_channels=frame.num_channels
|
||||
)
|
||||
)
|
||||
await self._push_transport_message(message)
|
||||
|
||||
|
||||
class RTVIProcessor(FrameProcessor):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -394,22 +607,9 @@ class RTVIProcessor(FrameProcessor):
|
||||
# finish and the task finishes when EndFrame is processed.
|
||||
await self.push_frame(frame, direction)
|
||||
await self._stop(frame)
|
||||
elif isinstance(frame, UserStartedSpeakingFrame) or isinstance(
|
||||
frame, UserStoppedSpeakingFrame
|
||||
):
|
||||
await self._handle_interruptions(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, BotStartedSpeakingFrame) or isinstance(
|
||||
frame, BotStoppedSpeakingFrame
|
||||
):
|
||||
await self._handle_bot_speaking(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
# Data frames
|
||||
elif isinstance(frame, TranscriptionFrame) or isinstance(frame, InterimTranscriptionFrame):
|
||||
await self._handle_transcriptions(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, TransportMessageFrame):
|
||||
await self._message_queue.put(frame)
|
||||
await self._handle_transport_message(frame)
|
||||
elif isinstance(frame, RTVIActionFrame):
|
||||
await self._action_queue.put(frame)
|
||||
# Other frames
|
||||
@@ -452,47 +652,6 @@ class RTVIProcessor(FrameProcessor):
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def _handle_transcriptions(self, frame: Frame):
|
||||
# TODO(aleix): Once we add support for using custom pipelines, the STTs will
|
||||
# be in the pipeline after this processor.
|
||||
|
||||
message = None
|
||||
if isinstance(frame, TranscriptionFrame):
|
||||
message = RTVITranscriptionMessage(
|
||||
data=RTVITranscriptionMessageData(
|
||||
text=frame.text, user_id=frame.user_id, timestamp=frame.timestamp, final=True
|
||||
)
|
||||
)
|
||||
elif isinstance(frame, InterimTranscriptionFrame):
|
||||
message = RTVITranscriptionMessage(
|
||||
data=RTVITranscriptionMessageData(
|
||||
text=frame.text, user_id=frame.user_id, timestamp=frame.timestamp, final=False
|
||||
)
|
||||
)
|
||||
|
||||
if message:
|
||||
await self._push_transport_message(message)
|
||||
|
||||
async def _handle_interruptions(self, frame: Frame):
|
||||
message = None
|
||||
if isinstance(frame, UserStartedSpeakingFrame):
|
||||
message = RTVIUserStartedSpeakingMessage()
|
||||
elif isinstance(frame, UserStoppedSpeakingFrame):
|
||||
message = RTVIUserStoppedSpeakingMessage()
|
||||
|
||||
if message:
|
||||
await self._push_transport_message(message)
|
||||
|
||||
async def _handle_bot_speaking(self, frame: Frame):
|
||||
message = None
|
||||
if isinstance(frame, BotStartedSpeakingFrame):
|
||||
message = RTVIBotStartedSpeakingMessage()
|
||||
elif isinstance(frame, BotStoppedSpeakingFrame):
|
||||
message = RTVIBotStoppedSpeakingMessage()
|
||||
|
||||
if message:
|
||||
await self._push_transport_message(message)
|
||||
|
||||
async def _action_task_handler(self):
|
||||
while True:
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user