Merge pull request #515 from pipecat-ai/aleix/rtvi-frame-processors

RTVI frame processors
This commit is contained in:
Aleix Conchillo Flaqué
2024-09-27 00:48:09 -07:00
committed by GitHub
2 changed files with 224 additions and 63 deletions

View File

@@ -26,6 +26,8 @@ class AsyncGeneratorProcessor(FrameProcessor):
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
await self.push_frame(frame, direction)
if isinstance(frame, (CancelFrame, EndFrame)):
await self._data_queue.put(None)
else:

View File

@@ -5,6 +5,7 @@
#
import asyncio
import base64
from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field, PrivateAttr, ValidationError
@@ -20,8 +21,14 @@ from pipecat.frames.frames import (
ErrorFrame,
Frame,
InterimTranscriptionFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
OutputAudioRawFrame,
StartFrame,
SystemFrame,
TTSStartedFrame,
TTSStoppedFrame,
TextFrame,
TranscriptionFrame,
TransportMessageFrame,
UserStartedSpeakingFrame,
@@ -34,7 +41,7 @@ from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from loguru import logger
RTVI_PROTOCOL_VERSION = "0.1"
RTVI_PROTOCOL_VERSION = "0.2"
ActionResult = Union[bool, int, float, str, list, dict]
@@ -242,33 +249,75 @@ class RTVILLMFunctionCallResultData(BaseModel):
result: dict | str
class RTVIBotLLMStartedMessage(BaseModel):
label: Literal["rtvi-ai"] = "rtvi-ai"
type: Literal["bot-llm-started"] = "bot-llm-started"
class RTVIBotLLMStoppedMessage(BaseModel):
label: Literal["rtvi-ai"] = "rtvi-ai"
type: Literal["bot-llm-stopped"] = "bot-llm-stopped"
class RTVIBotTTSStartedMessage(BaseModel):
label: Literal["rtvi-ai"] = "rtvi-ai"
type: Literal["bot-tts-started"] = "bot-tts-started"
class RTVIBotTTSStoppedMessage(BaseModel):
label: Literal["rtvi-ai"] = "rtvi-ai"
type: Literal["bot-tts-stopped"] = "bot-tts-stopped"
class RTVITextMessageData(BaseModel):
text: str
class RTVILLMTextMessage(BaseModel):
class RTVIBotLLMTextMessage(BaseModel):
label: Literal["rtvi-ai"] = "rtvi-ai"
type: Literal["llm-text"] = "llm-text"
type: Literal["bot-llm-text"] = "bot-llm-text"
data: RTVITextMessageData
class RTVITTSTextMessage(BaseModel):
class RTVIBotTTSTextMessage(BaseModel):
label: Literal["rtvi-ai"] = "rtvi-ai"
type: Literal["tts-text"] = "tts-text"
type: Literal["bot-tts-text"] = "bot-tts-text"
data: RTVITextMessageData
class RTVITranscriptionMessageData(BaseModel):
class RTVIAudioMessageData(BaseModel):
audio: str
sample_rate: int
num_channels: int
class RTVIBotAudioMessage(BaseModel):
label: Literal["rtvi-ai"] = "rtvi-ai"
type: Literal["bot-audio"] = "bot-audio"
data: RTVIAudioMessageData
class RTVIBotTranscriptionMessageData(BaseModel):
text: str
class RTVIBotTranscriptionMessage(BaseModel):
label: Literal["rtvi-ai"] = "rtvi-ai"
type: Literal["bot-transcription"] = "bot-transcription"
data: RTVIBotTranscriptionMessageData
class RTVIUserTranscriptionMessageData(BaseModel):
text: str
user_id: str
timestamp: str
final: bool
class RTVITranscriptionMessage(BaseModel):
class RTVIUserTranscriptionMessage(BaseModel):
label: Literal["rtvi-ai"] = "rtvi-ai"
type: Literal["user-transcription"] = "user-transcription"
data: RTVITranscriptionMessageData
data: RTVIUserTranscriptionMessageData
class RTVIUserStartedSpeakingMessage(BaseModel):
@@ -295,6 +344,170 @@ class RTVIProcessorParams(BaseModel):
send_bot_ready: bool = True
class RTVIFrameProcessor(FrameProcessor):
def __init__(self, direction: FrameDirection = FrameDirection.DOWNSTREAM, **kwargs):
super().__init__(**kwargs)
self._direction = direction
async def _push_transport_message(self, model: BaseModel, exclude_none: bool = True):
frame = TransportMessageFrame(
message=model.model_dump(exclude_none=exclude_none), urgent=True
)
await self.push_frame(frame, self._direction)
class RTVISpeakingProcessor(RTVIFrameProcessor):
def __init__(self, **kwargs):
super().__init__(**kwargs)
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
await self.push_frame(frame, direction)
if isinstance(frame, (UserStartedSpeakingFrame, UserStoppedSpeakingFrame)):
await self._handle_interruptions(frame)
elif isinstance(frame, (BotStartedSpeakingFrame, BotStoppedSpeakingFrame)):
await self._handle_bot_speaking(frame)
async def _handle_interruptions(self, frame: Frame):
message = None
if isinstance(frame, UserStartedSpeakingFrame):
message = RTVIUserStartedSpeakingMessage()
elif isinstance(frame, UserStoppedSpeakingFrame):
message = RTVIUserStoppedSpeakingMessage()
if message:
await self._push_transport_message(message)
async def _handle_bot_speaking(self, frame: Frame):
message = None
if isinstance(frame, BotStartedSpeakingFrame):
message = RTVIBotStartedSpeakingMessage()
elif isinstance(frame, BotStoppedSpeakingFrame):
message = RTVIBotStoppedSpeakingMessage()
if message:
await self._push_transport_message(message)
class RTVIUserTranscriptionProcessor(RTVIFrameProcessor):
def __init__(self, **kwargs):
super().__init__(**kwargs)
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
await self.push_frame(frame, direction)
if isinstance(frame, (TranscriptionFrame, InterimTranscriptionFrame)):
await self._handle_user_transcriptions(frame)
async def _handle_user_transcriptions(self, frame: Frame):
message = None
if isinstance(frame, TranscriptionFrame):
message = RTVIUserTranscriptionMessage(
data=RTVIUserTranscriptionMessageData(
text=frame.text, user_id=frame.user_id, timestamp=frame.timestamp, final=True
)
)
elif isinstance(frame, InterimTranscriptionFrame):
message = RTVIUserTranscriptionMessage(
data=RTVIUserTranscriptionMessageData(
text=frame.text, user_id=frame.user_id, timestamp=frame.timestamp, final=False
)
)
if message:
await self._push_transport_message(message)
class RTVIBotLLMProcessor(RTVIFrameProcessor):
def __init__(self, **kwargs):
super().__init__(**kwargs)
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
await self.push_frame(frame, direction)
if isinstance(frame, LLMFullResponseStartFrame):
await self._push_transport_message(RTVIBotLLMStartedMessage())
elif isinstance(frame, LLMFullResponseEndFrame):
await self._push_transport_message(RTVIBotLLMStoppedMessage())
class RTVIBotTTSProcessor(RTVIFrameProcessor):
def __init__(self, **kwargs):
super().__init__(**kwargs)
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
await self.push_frame(frame, direction)
if isinstance(frame, TTSStartedFrame):
await self._push_transport_message(RTVIBotTTSStartedMessage())
elif isinstance(frame, TTSStoppedFrame):
await self._push_transport_message(RTVIBotTTSStoppedMessage())
class RTVIBotLLMTextProcessor(RTVIFrameProcessor):
def __init__(self, **kwargs):
super().__init__(**kwargs)
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
await self.push_frame(frame, direction)
if isinstance(frame, TextFrame):
await self._handle_text(frame)
async def _handle_text(self, frame: TextFrame):
message = RTVIBotLLMTextMessage(data=RTVITextMessageData(text=frame.text))
await self._push_transport_message(message)
class RTVIBotTTSTextProcessor(RTVIFrameProcessor):
def __init__(self, **kwargs):
super().__init__(**kwargs)
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
await self.push_frame(frame, direction)
if isinstance(frame, TextFrame):
await self._handle_text(frame)
async def _handle_text(self, frame: TextFrame):
message = RTVIBotTTSTextMessage(data=RTVITextMessageData(text=frame.text))
await self._push_transport_message(message)
class RTVIBotAudioProcessor(RTVIFrameProcessor):
def __init__(self, **kwargs):
super().__init__(**kwargs)
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
await self.push_frame(frame, direction)
if isinstance(frame, OutputAudioRawFrame):
await self._handle_audio(frame)
async def _handle_audio(self, frame: OutputAudioRawFrame):
encoded = base64.b64encode(frame.audio).decode("utf-8")
message = RTVIBotAudioMessage(
data=RTVIAudioMessageData(
audio=encoded, sample_rate=frame.sample_rate, num_channels=frame.num_channels
)
)
await self._push_transport_message(message)
class RTVIProcessor(FrameProcessor):
def __init__(
self,
@@ -394,22 +607,9 @@ class RTVIProcessor(FrameProcessor):
# finish and the task finishes when EndFrame is processed.
await self.push_frame(frame, direction)
await self._stop(frame)
elif isinstance(frame, UserStartedSpeakingFrame) or isinstance(
frame, UserStoppedSpeakingFrame
):
await self._handle_interruptions(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, BotStartedSpeakingFrame) or isinstance(
frame, BotStoppedSpeakingFrame
):
await self._handle_bot_speaking(frame)
await self.push_frame(frame, direction)
# Data frames
elif isinstance(frame, TranscriptionFrame) or isinstance(frame, InterimTranscriptionFrame):
await self._handle_transcriptions(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, TransportMessageFrame):
await self._message_queue.put(frame)
await self._handle_transport_message(frame)
elif isinstance(frame, RTVIActionFrame):
await self._action_queue.put(frame)
# Other frames
@@ -452,47 +652,6 @@ class RTVIProcessor(FrameProcessor):
)
await self.push_frame(frame)
async def _handle_transcriptions(self, frame: Frame):
# TODO(aleix): Once we add support for using custom pipelines, the STTs will
# be in the pipeline after this processor.
message = None
if isinstance(frame, TranscriptionFrame):
message = RTVITranscriptionMessage(
data=RTVITranscriptionMessageData(
text=frame.text, user_id=frame.user_id, timestamp=frame.timestamp, final=True
)
)
elif isinstance(frame, InterimTranscriptionFrame):
message = RTVITranscriptionMessage(
data=RTVITranscriptionMessageData(
text=frame.text, user_id=frame.user_id, timestamp=frame.timestamp, final=False
)
)
if message:
await self._push_transport_message(message)
async def _handle_interruptions(self, frame: Frame):
message = None
if isinstance(frame, UserStartedSpeakingFrame):
message = RTVIUserStartedSpeakingMessage()
elif isinstance(frame, UserStoppedSpeakingFrame):
message = RTVIUserStoppedSpeakingMessage()
if message:
await self._push_transport_message(message)
async def _handle_bot_speaking(self, frame: Frame):
message = None
if isinstance(frame, BotStartedSpeakingFrame):
message = RTVIBotStartedSpeakingMessage()
elif isinstance(frame, BotStoppedSpeakingFrame):
message = RTVIBotStoppedSpeakingMessage()
if message:
await self._push_transport_message(message)
async def _action_task_handler(self):
while True:
try: