Merge pull request #1688 from pipecat-ai/aleix/add-rtvi-observer-params

RTVIObserver: add RTVIObserverParams to configure what to send
This commit is contained in:
Aleix Conchillo Flaqué
2025-04-30 15:11:18 -07:00
committed by GitHub

View File

@@ -395,6 +395,32 @@ class RTVIServerMessageFrame(SystemFrame):
return f"{self.name}(data: {self.data})"
@dataclass
class RTVIObserverParams:
"""
Parameters for configuring RTVI Observer behavior.
Attributes:
bot_llm_enabled (bool): Indicates if the bot's LLM messages should be sent.
bot_tts_enabled (bool): Indicates if the bot's TTS messages should be sent.
bot_speaking_enabled (bool): Indicates if the bot's started/stopped speaking messages should be sent.
user_llm_enabled (bool): Indicates if the user's LLM input messages should be sent.
user_speaking_enabled (bool): Indicates if the user's started/stopped speaking messages should be sent.
user_transcription_enabled (bool): Indicates if user's transcription messages should be sent.
metrics_enabled (bool): Indicates if metrics messages should be sent.
errors_enabled (bool): Indicates if errors messages should be sent.
"""
bot_llm_enabled: bool = True
bot_tts_enabled: bool = True
bot_speaking_enabled: bool = True
user_llm_enabled: bool = True
user_speaking_enabled: bool = True
user_transcription_enabled: bool = True
metrics_enabled: bool = True
errors_enabled: bool = True
class RTVIObserver(BaseObserver):
"""Pipeline frame observer for RTVI server message handling.
@@ -407,14 +433,17 @@ class RTVIObserver(BaseObserver):
are handled by the RTVIProcessor.
Args:
rtvi (FrameProcessor): The RTVI processor to push frames to.
rtvi (RTVIProcessor): The RTVI processor to push frames to.
params (RTVIObserverParams): Settings to enable/disable specific messages.
"""
def __init__(self, rtvi: FrameProcessor):
def __init__(self, rtvi: "RTVIProcessor", *, params: RTVIObserverParams = RTVIObserverParams()):
super().__init__()
self._rtvi = rtvi
self._params = params
self._bot_transcription = ""
self._frames_seen = set()
rtvi.set_errors_enabled(self._params.errors_enabled)
async def on_push_frame(
self,
@@ -441,35 +470,41 @@ class RTVIObserver(BaseObserver):
# again the next time we see the frame.
mark_as_seen = True
if isinstance(frame, (UserStartedSpeakingFrame, UserStoppedSpeakingFrame)):
if (
isinstance(frame, (UserStartedSpeakingFrame, UserStoppedSpeakingFrame))
and self._params.user_speaking_enabled
):
await self._handle_interruptions(frame)
elif isinstance(frame, (BotStartedSpeakingFrame, BotStoppedSpeakingFrame)) and (
direction == FrameDirection.UPSTREAM
elif (
isinstance(frame, (BotStartedSpeakingFrame, BotStoppedSpeakingFrame))
and (direction == FrameDirection.UPSTREAM)
and self._params.bot_speaking_enabled
):
await self._handle_bot_speaking(frame)
elif isinstance(frame, (TranscriptionFrame, InterimTranscriptionFrame)):
elif (
isinstance(frame, (TranscriptionFrame, InterimTranscriptionFrame))
and self._params.user_transcription_enabled
):
await self._handle_user_transcriptions(frame)
elif isinstance(frame, OpenAILLMContextFrame):
elif isinstance(frame, OpenAILLMContextFrame) and self._params.user_llm_enabled:
await self._handle_context(frame)
elif isinstance(frame, UserStartedSpeakingFrame):
await self._push_bot_transcription()
elif isinstance(frame, LLMFullResponseStartFrame):
elif isinstance(frame, LLMFullResponseStartFrame) and self._params.bot_llm_enabled:
await self.push_transport_message_urgent(RTVIBotLLMStartedMessage())
elif isinstance(frame, LLMFullResponseEndFrame):
elif isinstance(frame, LLMFullResponseEndFrame) and self._params.bot_llm_enabled:
await self.push_transport_message_urgent(RTVIBotLLMStoppedMessage())
elif isinstance(frame, LLMTextFrame):
elif isinstance(frame, LLMTextFrame) and self._params.bot_llm_enabled:
await self._handle_llm_text_frame(frame)
elif isinstance(frame, TTSStartedFrame):
elif isinstance(frame, TTSStartedFrame) and self._params.bot_tts_enabled:
await self.push_transport_message_urgent(RTVIBotTTSStartedMessage())
elif isinstance(frame, TTSStoppedFrame):
elif isinstance(frame, TTSStoppedFrame) and self._params.bot_tts_enabled:
await self.push_transport_message_urgent(RTVIBotTTSStoppedMessage())
elif isinstance(frame, TTSTextFrame):
elif isinstance(frame, TTSTextFrame) and self._params.bot_tts_enabled:
if isinstance(src, BaseOutputTransport):
message = RTVIBotTTSTextMessage(data=RTVITextMessageData(text=frame.text))
await self.push_transport_message_urgent(message)
else:
mark_as_seen = False
elif isinstance(frame, MetricsFrame):
elif isinstance(frame, MetricsFrame) and self._params.metrics_enabled:
await self._handle_metrics(frame)
elif isinstance(frame, RTVIServerMessageFrame):
message = RTVIServerMessage(data=frame.data)
@@ -612,6 +647,7 @@ class RTVIProcessor(FrameProcessor):
self._bot_ready = False
self._client_ready = False
self._client_ready_id = ""
self._errors_enabled = True
self._registered_actions: Dict[str, RTVIAction] = {}
self._registered_services: Dict[str, RTVIService] = {}
@@ -651,12 +687,14 @@ class RTVIProcessor(FrameProcessor):
await self._update_config(self._config, False)
await self._send_bot_ready()
def set_errors_enabled(self, enabled: bool):
self._errors_enabled = enabled
async def interrupt_bot(self):
await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
async def send_error(self, error: str):
message = RTVIError(data=RTVIErrorData(error=error, fatal=False))
await self._push_transport_message(message)
await self._send_error_frame(ErrorFrame(error=error))
async def handle_message(self, message: RTVIMessage):
await self._message_queue.put(message)
@@ -915,12 +953,14 @@ class RTVIProcessor(FrameProcessor):
await self._push_transport_message(message)
async def _send_error_frame(self, frame: ErrorFrame):
message = RTVIError(data=RTVIErrorData(error=frame.error, fatal=frame.fatal))
await self._push_transport_message(message)
if self._errors_enabled:
message = RTVIError(data=RTVIErrorData(error=frame.error, fatal=frame.fatal))
await self._push_transport_message(message)
async def _send_error_response(self, id: str, error: str):
message = RTVIErrorResponse(id=id, data=RTVIErrorResponseData(error=error))
await self._push_transport_message(message)
if self._errors_enabled:
message = RTVIErrorResponse(id=id, data=RTVIErrorResponseData(error=error))
await self._push_transport_message(message)
def _action_id(self, service: str, action: str) -> str:
return f"{service}:{action}"