diff --git a/src/pipecat/processors/frameworks/rtvi.py b/src/pipecat/processors/frameworks/rtvi.py
index c43d6ae12..55e91d7ff 100644
--- a/src/pipecat/processors/frameworks/rtvi.py
+++ b/src/pipecat/processors/frameworks/rtvi.py
@@ -395,6 +395,32 @@ class RTVIServerMessageFrame(SystemFrame):
         return f"{self.name}(data: {self.data})"
 
 
+@dataclass
+class RTVIObserverParams:
+    """
+    Parameters for configuring RTVI Observer behavior.
+
+    Attributes:
+        bot_llm_enabled (bool): Indicates if the bot's LLM messages should be sent.
+        bot_tts_enabled (bool): Indicates if the bot's TTS messages should be sent.
+        bot_speaking_enabled (bool): Indicates if the bot's started/stopped speaking messages should be sent.
+        user_llm_enabled (bool): Indicates if the user's LLM input messages should be sent.
+        user_speaking_enabled (bool): Indicates if the user's started/stopped speaking messages should be sent.
+        user_transcription_enabled (bool): Indicates if user's transcription messages should be sent.
+        metrics_enabled (bool): Indicates if metrics messages should be sent.
+        errors_enabled (bool): Indicates if errors messages should be sent.
+    """
+
+    bot_llm_enabled: bool = True
+    bot_tts_enabled: bool = True
+    bot_speaking_enabled: bool = True
+    user_llm_enabled: bool = True
+    user_speaking_enabled: bool = True
+    user_transcription_enabled: bool = True
+    metrics_enabled: bool = True
+    errors_enabled: bool = True
+
+
 class RTVIObserver(BaseObserver):
     """Pipeline frame observer for RTVI server message handling.
 
@@ -407,14 +433,17 @@ class RTVIObserver(BaseObserver):
     are handled by the RTVIProcessor.
 
     Args:
-        rtvi (FrameProcessor): The RTVI processor to push frames to.
+        rtvi (RTVIProcessor): The RTVI processor to push frames to.
+        params (RTVIObserverParams, optional): Messages to enable/disable. Defaults to all enabled.
     """
 
-    def __init__(self, rtvi: FrameProcessor):
+    def __init__(self, rtvi: "RTVIProcessor", *, params: "Optional[RTVIObserverParams]" = None):
         super().__init__()
         self._rtvi = rtvi
+        self._params = RTVIObserverParams() if params is None else params
         self._bot_transcription = ""
         self._frames_seen = set()
+        rtvi.set_errors_enabled(self._params.errors_enabled)
 
     async def on_push_frame(
         self,
@@ -441,35 +470,41 @@ class RTVIObserver(BaseObserver):
         # again the next time we see the frame.
         mark_as_seen = True
 
-        if isinstance(frame, (UserStartedSpeakingFrame, UserStoppedSpeakingFrame)):
+        if (
+            isinstance(frame, (UserStartedSpeakingFrame, UserStoppedSpeakingFrame))
+            and self._params.user_speaking_enabled
+        ):
             await self._handle_interruptions(frame)
-        elif isinstance(frame, (BotStartedSpeakingFrame, BotStoppedSpeakingFrame)) and (
-            direction == FrameDirection.UPSTREAM
+        elif (
+            isinstance(frame, (BotStartedSpeakingFrame, BotStoppedSpeakingFrame))
+            and (direction == FrameDirection.UPSTREAM)
+            and self._params.bot_speaking_enabled
         ):
             await self._handle_bot_speaking(frame)
-        elif isinstance(frame, (TranscriptionFrame, InterimTranscriptionFrame)):
+        elif (
+            isinstance(frame, (TranscriptionFrame, InterimTranscriptionFrame))
+            and self._params.user_transcription_enabled
+        ):
             await self._handle_user_transcriptions(frame)
-        elif isinstance(frame, OpenAILLMContextFrame):
+        elif isinstance(frame, OpenAILLMContextFrame) and self._params.user_llm_enabled:
             await self._handle_context(frame)
-        elif isinstance(frame, UserStartedSpeakingFrame):
-            await self._push_bot_transcription()
-        elif isinstance(frame, LLMFullResponseStartFrame):
+        elif isinstance(frame, LLMFullResponseStartFrame) and self._params.bot_llm_enabled:
             await self.push_transport_message_urgent(RTVIBotLLMStartedMessage())
-        elif isinstance(frame, LLMFullResponseEndFrame):
+        elif isinstance(frame, LLMFullResponseEndFrame) and self._params.bot_llm_enabled:
             await self.push_transport_message_urgent(RTVIBotLLMStoppedMessage())
-        elif isinstance(frame, LLMTextFrame):
+        elif isinstance(frame, LLMTextFrame) and self._params.bot_llm_enabled:
             await self._handle_llm_text_frame(frame)
-        elif isinstance(frame, TTSStartedFrame):
+        elif isinstance(frame, TTSStartedFrame) and self._params.bot_tts_enabled:
             await self.push_transport_message_urgent(RTVIBotTTSStartedMessage())
-        elif isinstance(frame, TTSStoppedFrame):
+        elif isinstance(frame, TTSStoppedFrame) and self._params.bot_tts_enabled:
             await self.push_transport_message_urgent(RTVIBotTTSStoppedMessage())
-        elif isinstance(frame, TTSTextFrame):
+        elif isinstance(frame, TTSTextFrame) and self._params.bot_tts_enabled:
             if isinstance(src, BaseOutputTransport):
                 message = RTVIBotTTSTextMessage(data=RTVITextMessageData(text=frame.text))
                 await self.push_transport_message_urgent(message)
             else:
                 mark_as_seen = False
-        elif isinstance(frame, MetricsFrame):
+        elif isinstance(frame, MetricsFrame) and self._params.metrics_enabled:
             await self._handle_metrics(frame)
         elif isinstance(frame, RTVIServerMessageFrame):
             message = RTVIServerMessage(data=frame.data)
@@ -612,6 +647,7 @@ class RTVIProcessor(FrameProcessor):
         self._bot_ready = False
         self._client_ready = False
         self._client_ready_id = ""
+        self._errors_enabled = True
 
         self._registered_actions: Dict[str, RTVIAction] = {}
         self._registered_services: Dict[str, RTVIService] = {}
@@ -651,12 +687,19 @@ class RTVIProcessor(FrameProcessor):
         await self._update_config(self._config, False)
         await self._send_bot_ready()
 
+    def set_errors_enabled(self, enabled: bool):
+        """Enable or disable sending RTVI error messages to the client.
+
+        Args:
+            enabled (bool): Whether error messages and error responses are sent.
+        """
+        self._errors_enabled = enabled
+
     async def interrupt_bot(self):
         await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
 
     async def send_error(self, error: str):
-        message = RTVIError(data=RTVIErrorData(error=error, fatal=False))
-        await self._push_transport_message(message)
+        await self._send_error_frame(ErrorFrame(error=error))
 
     async def handle_message(self, message: RTVIMessage):
         await self._message_queue.put(message)
@@ -915,12 +958,14 @@ class RTVIProcessor(FrameProcessor):
         await self._push_transport_message(message)
 
     async def _send_error_frame(self, frame: ErrorFrame):
-        message = RTVIError(data=RTVIErrorData(error=frame.error, fatal=frame.fatal))
-        await self._push_transport_message(message)
+        if self._errors_enabled:
+            message = RTVIError(data=RTVIErrorData(error=frame.error, fatal=frame.fatal))
+            await self._push_transport_message(message)
 
     async def _send_error_response(self, id: str, error: str):
-        message = RTVIErrorResponse(id=id, data=RTVIErrorResponseData(error=error))
-        await self._push_transport_message(message)
+        if self._errors_enabled:
+            message = RTVIErrorResponse(id=id, data=RTVIErrorResponseData(error=error))
+            await self._push_transport_message(message)
 
     def _action_id(self, service: str, action: str) -> str:
         return f"{service}:{action}"