diff --git a/src/pipecat/processors/async_generator.py b/src/pipecat/processors/async_generator.py index 66b2a3e99..4f9bc85d0 100644 --- a/src/pipecat/processors/async_generator.py +++ b/src/pipecat/processors/async_generator.py @@ -26,6 +26,8 @@ class AsyncGeneratorProcessor(FrameProcessor): async def process_frame(self, frame: Frame, direction: FrameDirection): await super().process_frame(frame, direction) + await self.push_frame(frame, direction) + if isinstance(frame, (CancelFrame, EndFrame)): await self._data_queue.put(None) else: diff --git a/src/pipecat/processors/frameworks/rtvi.py b/src/pipecat/processors/frameworks/rtvi.py index 930b2331d..f88660f60 100644 --- a/src/pipecat/processors/frameworks/rtvi.py +++ b/src/pipecat/processors/frameworks/rtvi.py @@ -5,6 +5,7 @@ # import asyncio +import base64 from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Union from pydantic import BaseModel, Field, PrivateAttr, ValidationError @@ -20,8 +21,14 @@ from pipecat.frames.frames import ( ErrorFrame, Frame, InterimTranscriptionFrame, + LLMFullResponseEndFrame, + LLMFullResponseStartFrame, + OutputAudioRawFrame, StartFrame, SystemFrame, + TTSStartedFrame, + TTSStoppedFrame, + TextFrame, TranscriptionFrame, TransportMessageFrame, UserStartedSpeakingFrame, @@ -34,7 +41,7 @@ from pipecat.processors.frame_processor import FrameDirection, FrameProcessor from loguru import logger -RTVI_PROTOCOL_VERSION = "0.1" +RTVI_PROTOCOL_VERSION = "0.2" ActionResult = Union[bool, int, float, str, list, dict] @@ -242,33 +249,75 @@ class RTVILLMFunctionCallResultData(BaseModel): result: dict | str +class RTVIBotLLMStartedMessage(BaseModel): + label: Literal["rtvi-ai"] = "rtvi-ai" + type: Literal["bot-llm-started"] = "bot-llm-started" + + +class RTVIBotLLMStoppedMessage(BaseModel): + label: Literal["rtvi-ai"] = "rtvi-ai" + type: Literal["bot-llm-stopped"] = "bot-llm-stopped" + + +class RTVIBotTTSStartedMessage(BaseModel): + label: Literal["rtvi-ai"] = "rtvi-ai" + type: Literal["bot-tts-started"] = "bot-tts-started" + + +class RTVIBotTTSStoppedMessage(BaseModel): + label: Literal["rtvi-ai"] = "rtvi-ai" + type: Literal["bot-tts-stopped"] = "bot-tts-stopped" + + class RTVITextMessageData(BaseModel): text: str -class RTVILLMTextMessage(BaseModel): +class RTVIBotLLMTextMessage(BaseModel): label: Literal["rtvi-ai"] = "rtvi-ai" - type: Literal["llm-text"] = "llm-text" + type: Literal["bot-llm-text"] = "bot-llm-text" data: RTVITextMessageData -class RTVITTSTextMessage(BaseModel): +class RTVIBotTTSTextMessage(BaseModel): label: Literal["rtvi-ai"] = "rtvi-ai" - type: Literal["tts-text"] = "tts-text" + type: Literal["bot-tts-text"] = "bot-tts-text" data: RTVITextMessageData -class RTVITranscriptionMessageData(BaseModel): +class RTVIAudioMessageData(BaseModel): + audio: str + sample_rate: int + num_channels: int + + +class RTVIBotAudioMessage(BaseModel): + label: Literal["rtvi-ai"] = "rtvi-ai" + type: Literal["bot-audio"] = "bot-audio" + data: RTVIAudioMessageData + + +class RTVIBotTranscriptionMessageData(BaseModel): + text: str + + +class RTVIBotTranscriptionMessage(BaseModel): + label: Literal["rtvi-ai"] = "rtvi-ai" + type: Literal["bot-transcription"] = "bot-transcription" + data: RTVIBotTranscriptionMessageData + + +class RTVIUserTranscriptionMessageData(BaseModel): text: str user_id: str timestamp: str final: bool -class RTVITranscriptionMessage(BaseModel): +class RTVIUserTranscriptionMessage(BaseModel): label: Literal["rtvi-ai"] = "rtvi-ai" type: Literal["user-transcription"] = "user-transcription" - data: RTVITranscriptionMessageData + data: RTVIUserTranscriptionMessageData class RTVIUserStartedSpeakingMessage(BaseModel): @@ -295,6 +344,170 @@ class RTVIProcessorParams(BaseModel): send_bot_ready: bool = True +class RTVIFrameProcessor(FrameProcessor): + def __init__(self, direction: FrameDirection = FrameDirection.DOWNSTREAM, **kwargs): + super().__init__(**kwargs) + self._direction = direction + + async def _push_transport_message(self, model: BaseModel, exclude_none: bool = True): + frame = TransportMessageFrame( + message=model.model_dump(exclude_none=exclude_none), urgent=True + ) + await self.push_frame(frame, self._direction) + + +class RTVISpeakingProcessor(RTVIFrameProcessor): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + await self.push_frame(frame, direction) + + if isinstance(frame, (UserStartedSpeakingFrame, UserStoppedSpeakingFrame)): + await self._handle_interruptions(frame) + elif isinstance(frame, (BotStartedSpeakingFrame, BotStoppedSpeakingFrame)): + await self._handle_bot_speaking(frame) + + async def _handle_interruptions(self, frame: Frame): + message = None + if isinstance(frame, UserStartedSpeakingFrame): + message = RTVIUserStartedSpeakingMessage() + elif isinstance(frame, UserStoppedSpeakingFrame): + message = RTVIUserStoppedSpeakingMessage() + + if message: + await self._push_transport_message(message) + + async def _handle_bot_speaking(self, frame: Frame): + message = None + if isinstance(frame, BotStartedSpeakingFrame): + message = RTVIBotStartedSpeakingMessage() + elif isinstance(frame, BotStoppedSpeakingFrame): + message = RTVIBotStoppedSpeakingMessage() + + if message: + await self._push_transport_message(message) + + +class RTVIUserTranscriptionProcessor(RTVIFrameProcessor): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + await self.push_frame(frame, direction) + + if isinstance(frame, (TranscriptionFrame, InterimTranscriptionFrame)): + await self._handle_user_transcriptions(frame) + + async def _handle_user_transcriptions(self, frame: Frame): + message = None + if isinstance(frame, TranscriptionFrame): + message = RTVIUserTranscriptionMessage( + data=RTVIUserTranscriptionMessageData( + text=frame.text, user_id=frame.user_id, timestamp=frame.timestamp, final=True + ) + ) + elif isinstance(frame, InterimTranscriptionFrame): + message = RTVIUserTranscriptionMessage( + data=RTVIUserTranscriptionMessageData( + text=frame.text, user_id=frame.user_id, timestamp=frame.timestamp, final=False + ) + ) + + if message: + await self._push_transport_message(message) + + +class RTVIBotLLMProcessor(RTVIFrameProcessor): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + await self.push_frame(frame, direction) + + if isinstance(frame, LLMFullResponseStartFrame): + await self._push_transport_message(RTVIBotLLMStartedMessage()) + elif isinstance(frame, LLMFullResponseEndFrame): + await self._push_transport_message(RTVIBotLLMStoppedMessage()) + + +class RTVIBotTTSProcessor(RTVIFrameProcessor): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + await self.push_frame(frame, direction) + + if isinstance(frame, TTSStartedFrame): + await self._push_transport_message(RTVIBotTTSStartedMessage()) + elif isinstance(frame, TTSStoppedFrame): + await self._push_transport_message(RTVIBotTTSStoppedMessage()) + + +class RTVIBotLLMTextProcessor(RTVIFrameProcessor): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + await self.push_frame(frame, direction) + + if isinstance(frame, TextFrame): + await self._handle_text(frame) + + async def _handle_text(self, frame: TextFrame): + message = RTVIBotLLMTextMessage(data=RTVITextMessageData(text=frame.text)) + await self._push_transport_message(message) + + +class RTVIBotTTSTextProcessor(RTVIFrameProcessor): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + await self.push_frame(frame, direction) + + if isinstance(frame, TextFrame): + await self._handle_text(frame) + + async def _handle_text(self, frame: TextFrame): + message = RTVIBotTTSTextMessage(data=RTVITextMessageData(text=frame.text)) + await self._push_transport_message(message) + + +class RTVIBotAudioProcessor(RTVIFrameProcessor): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + await self.push_frame(frame, direction) + + if isinstance(frame, OutputAudioRawFrame): + await self._handle_audio(frame) + + async def _handle_audio(self, frame: OutputAudioRawFrame): + encoded = base64.b64encode(frame.audio).decode("utf-8") + message = RTVIBotAudioMessage( + data=RTVIAudioMessageData( + audio=encoded, sample_rate=frame.sample_rate, num_channels=frame.num_channels + ) + ) + await self._push_transport_message(message) + + class RTVIProcessor(FrameProcessor): def __init__( self, @@ -394,22 +607,9 @@ class RTVIProcessor(FrameProcessor): # finish and the task finishes when EndFrame is processed. await self.push_frame(frame, direction) await self._stop(frame) - elif isinstance(frame, UserStartedSpeakingFrame) or isinstance( - frame, UserStoppedSpeakingFrame - ): - await self._handle_interruptions(frame) - await self.push_frame(frame, direction) - elif isinstance(frame, BotStartedSpeakingFrame) or isinstance( - frame, BotStoppedSpeakingFrame - ): - await self._handle_bot_speaking(frame) - await self.push_frame(frame, direction) # Data frames - elif isinstance(frame, TranscriptionFrame) or isinstance(frame, InterimTranscriptionFrame): - await self._handle_transcriptions(frame) - await self.push_frame(frame, direction) elif isinstance(frame, TransportMessageFrame): - await self._message_queue.put(frame) + await self._handle_transport_message(frame) elif isinstance(frame, RTVIActionFrame): await self._action_queue.put(frame) # Other frames @@ -452,47 +652,6 @@ class RTVIProcessor(FrameProcessor): ) await self.push_frame(frame) - async def _handle_transcriptions(self, frame: Frame): - # TODO(aleix): Once we add support for using custom pipelines, the STTs will - # be in the pipeline after this processor. - - message = None - if isinstance(frame, TranscriptionFrame): - message = RTVITranscriptionMessage( - data=RTVITranscriptionMessageData( - text=frame.text, user_id=frame.user_id, timestamp=frame.timestamp, final=True - ) - ) - elif isinstance(frame, InterimTranscriptionFrame): - message = RTVITranscriptionMessage( - data=RTVITranscriptionMessageData( - text=frame.text, user_id=frame.user_id, timestamp=frame.timestamp, final=False - ) - ) - - if message: - await self._push_transport_message(message) - - async def _handle_interruptions(self, frame: Frame): - message = None - if isinstance(frame, UserStartedSpeakingFrame): - message = RTVIUserStartedSpeakingMessage() - elif isinstance(frame, UserStoppedSpeakingFrame): - message = RTVIUserStoppedSpeakingMessage() - - if message: - await self._push_transport_message(message) - - async def _handle_bot_speaking(self, frame: Frame): - message = None - if isinstance(frame, BotStartedSpeakingFrame): - message = RTVIBotStartedSpeakingMessage() - elif isinstance(frame, BotStoppedSpeakingFrame): - message = RTVIBotStoppedSpeakingMessage() - - if message: - await self._push_transport_message(message) - async def _action_task_handler(self): while True: try: