Merge pull request #1688 from pipecat-ai/aleix/add-rtvi-observer-params
RTVIObserver: add RTVIObserverParams to configure what to send
This commit is contained in:
@@ -395,6 +395,32 @@ class RTVIServerMessageFrame(SystemFrame):
|
||||
return f"{self.name}(data: {self.data})"
|
||||
|
||||
|
||||
@dataclass
class RTVIObserverParams:
    """Parameters for configuring RTVI Observer behavior.

    Every flag defaults to ``True``, so a default-constructed instance
    enables all message categories listed below.

    Attributes:
        bot_llm_enabled (bool): Indicates if the bot's LLM messages should be sent.
        bot_tts_enabled (bool): Indicates if the bot's TTS messages should be sent.
        bot_speaking_enabled (bool): Indicates if the bot's started/stopped speaking
            messages should be sent.
        user_llm_enabled (bool): Indicates if the user's LLM input messages should be sent.
        user_speaking_enabled (bool): Indicates if the user's started/stopped speaking
            messages should be sent.
        user_transcription_enabled (bool): Indicates if the user's transcription
            messages should be sent.
        metrics_enabled (bool): Indicates if metrics messages should be sent.
        errors_enabled (bool): Indicates if error messages should be sent.
    """

    bot_llm_enabled: bool = True
    bot_tts_enabled: bool = True
    bot_speaking_enabled: bool = True
    user_llm_enabled: bool = True
    user_speaking_enabled: bool = True
    user_transcription_enabled: bool = True
    metrics_enabled: bool = True
    errors_enabled: bool = True
|
||||
|
||||
|
||||
class RTVIObserver(BaseObserver):
|
||||
"""Pipeline frame observer for RTVI server message handling.
|
||||
|
||||
@@ -407,14 +433,17 @@ class RTVIObserver(BaseObserver):
|
||||
are handled by the RTVIProcessor.
|
||||
|
||||
Args:
|
||||
rtvi (FrameProcessor): The RTVI processor to push frames to.
|
||||
rtvi (RTVIProcessor): The RTVI processor to push frames to.
|
||||
params (RTVIObserverParams): Settings to enable/disable specific messages.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    rtvi: "RTVIProcessor",
    *,
    params: "RTVIObserverParams | None" = None,
):
    """Create the observer and pass the error setting on to the processor.

    Args:
        rtvi: The RTVI processor to push frames to.
        params: Settings to enable/disable specific messages. When omitted
            (or ``None``), a fresh ``RTVIObserverParams()`` is created for
            this observer.
    """
    super().__init__()
    self._rtvi = rtvi
    # Build the default per instance. A default written in the signature is
    # evaluated once at definition time, so every observer relying on it
    # would share (and could mutate) the same params object.
    self._params = params if params is not None else RTVIObserverParams()
    self._bot_transcription = ""
    self._frames_seen = set()
    # The processor, not the observer, emits error messages, so hand the
    # errors flag over to it.
    rtvi.set_errors_enabled(self._params.errors_enabled)
|
||||
|
||||
async def on_push_frame(
|
||||
self,
|
||||
@@ -441,35 +470,41 @@ class RTVIObserver(BaseObserver):
|
||||
# again the next time we see the frame.
|
||||
mark_as_seen = True
|
||||
|
||||
if isinstance(frame, (UserStartedSpeakingFrame, UserStoppedSpeakingFrame)):
|
||||
if (
|
||||
isinstance(frame, (UserStartedSpeakingFrame, UserStoppedSpeakingFrame))
|
||||
and self._params.user_speaking_enabled
|
||||
):
|
||||
await self._handle_interruptions(frame)
|
||||
elif isinstance(frame, (BotStartedSpeakingFrame, BotStoppedSpeakingFrame)) and (
|
||||
direction == FrameDirection.UPSTREAM
|
||||
elif (
|
||||
isinstance(frame, (BotStartedSpeakingFrame, BotStoppedSpeakingFrame))
|
||||
and (direction == FrameDirection.UPSTREAM)
|
||||
and self._params.bot_speaking_enabled
|
||||
):
|
||||
await self._handle_bot_speaking(frame)
|
||||
elif isinstance(frame, (TranscriptionFrame, InterimTranscriptionFrame)):
|
||||
elif (
|
||||
isinstance(frame, (TranscriptionFrame, InterimTranscriptionFrame))
|
||||
and self._params.user_transcription_enabled
|
||||
):
|
||||
await self._handle_user_transcriptions(frame)
|
||||
elif isinstance(frame, OpenAILLMContextFrame):
|
||||
elif isinstance(frame, OpenAILLMContextFrame) and self._params.user_llm_enabled:
|
||||
await self._handle_context(frame)
|
||||
elif isinstance(frame, UserStartedSpeakingFrame):
|
||||
await self._push_bot_transcription()
|
||||
elif isinstance(frame, LLMFullResponseStartFrame):
|
||||
elif isinstance(frame, LLMFullResponseStartFrame) and self._params.bot_llm_enabled:
|
||||
await self.push_transport_message_urgent(RTVIBotLLMStartedMessage())
|
||||
elif isinstance(frame, LLMFullResponseEndFrame):
|
||||
elif isinstance(frame, LLMFullResponseEndFrame) and self._params.bot_llm_enabled:
|
||||
await self.push_transport_message_urgent(RTVIBotLLMStoppedMessage())
|
||||
elif isinstance(frame, LLMTextFrame):
|
||||
elif isinstance(frame, LLMTextFrame) and self._params.bot_llm_enabled:
|
||||
await self._handle_llm_text_frame(frame)
|
||||
elif isinstance(frame, TTSStartedFrame):
|
||||
elif isinstance(frame, TTSStartedFrame) and self._params.bot_tts_enabled:
|
||||
await self.push_transport_message_urgent(RTVIBotTTSStartedMessage())
|
||||
elif isinstance(frame, TTSStoppedFrame):
|
||||
elif isinstance(frame, TTSStoppedFrame) and self._params.bot_tts_enabled:
|
||||
await self.push_transport_message_urgent(RTVIBotTTSStoppedMessage())
|
||||
elif isinstance(frame, TTSTextFrame):
|
||||
elif isinstance(frame, TTSTextFrame) and self._params.bot_tts_enabled:
|
||||
if isinstance(src, BaseOutputTransport):
|
||||
message = RTVIBotTTSTextMessage(data=RTVITextMessageData(text=frame.text))
|
||||
await self.push_transport_message_urgent(message)
|
||||
else:
|
||||
mark_as_seen = False
|
||||
elif isinstance(frame, MetricsFrame):
|
||||
elif isinstance(frame, MetricsFrame) and self._params.metrics_enabled:
|
||||
await self._handle_metrics(frame)
|
||||
elif isinstance(frame, RTVIServerMessageFrame):
|
||||
message = RTVIServerMessage(data=frame.data)
|
||||
@@ -612,6 +647,7 @@ class RTVIProcessor(FrameProcessor):
|
||||
self._bot_ready = False
|
||||
self._client_ready = False
|
||||
self._client_ready_id = ""
|
||||
self._errors_enabled = True
|
||||
|
||||
self._registered_actions: Dict[str, RTVIAction] = {}
|
||||
self._registered_services: Dict[str, RTVIService] = {}
|
||||
@@ -651,12 +687,14 @@ class RTVIProcessor(FrameProcessor):
|
||||
await self._update_config(self._config, False)
|
||||
await self._send_bot_ready()
|
||||
|
||||
def set_errors_enabled(self, enabled: bool) -> None:
    """Store the flag that decides whether error messages are sent.

    Args:
        enabled: New value for ``self._errors_enabled``.
    """
    self._errors_enabled = enabled
|
||||
|
||||
async def interrupt_bot(self) -> None:
    """Push a ``BotInterruptionFrame`` in the upstream direction."""
    await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
|
||||
|
||||
async def send_error(self, error: str) -> None:
    """Report an error by wrapping it in an ``ErrorFrame``.

    The frame is handed to ``_send_error_frame`` and nothing is pushed
    directly from here, so the error goes through a single send path.

    Args:
        error: Error text to report.
    """
    # NOTE(review): the frame's ``fatal`` value comes from ErrorFrame's own
    # default - presumably False, matching the previous explicit
    # ``fatal=False``; confirm against the ErrorFrame definition.
    await self._send_error_frame(ErrorFrame(error=error))
|
||||
|
||||
async def handle_message(self, message: RTVIMessage) -> None:
    """Put an incoming message on ``self._message_queue``.

    Args:
        message: The RTVI message to enqueue.
    """
    await self._message_queue.put(message)
|
||||
@@ -915,12 +953,14 @@ class RTVIProcessor(FrameProcessor):
|
||||
await self._push_transport_message(message)
|
||||
|
||||
async def _send_error_frame(self, frame: ErrorFrame) -> None:
    """Push an ``RTVIError`` transport message built from an error frame.

    Nothing is built or pushed when ``self._errors_enabled`` is false.

    Args:
        frame: Error frame whose ``error`` and ``fatal`` values are copied
            into the outgoing message.
    """
    # Single gated send: an unconditional push here would duplicate the
    # message and ignore the errors flag.
    if not self._errors_enabled:
        return
    message = RTVIError(data=RTVIErrorData(error=frame.error, fatal=frame.fatal))
    await self._push_transport_message(message)
|
||||
|
||||
async def _send_error_response(self, id: str, error: str) -> None:
    """Push an ``RTVIErrorResponse`` transport message for a request.

    Nothing is built or pushed when ``self._errors_enabled`` is false.

    Args:
        id: Identifier placed in the response's ``id`` field. The name
            shadows the builtin but is kept for keyword-argument callers.
        error: Error text placed in the response data.
    """
    # Single gated send: an unconditional push here would duplicate the
    # response and ignore the errors flag.
    if not self._errors_enabled:
        return
    message = RTVIErrorResponse(id=id, data=RTVIErrorResponseData(error=error))
    await self._push_transport_message(message)
|
||||
|
||||
def _action_id(self, service: str, action: str) -> str:
|
||||
return f"{service}:{action}"
|
||||
|
||||
Reference in New Issue
Block a user