From b2a7ff6fd3ec96dec7fda75be7bdf85d9eec5660 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Thu, 15 Aug 2024 08:34:08 -0700 Subject: [PATCH] processors(rtvi): all transport messages should be urgent --- src/pipecat/processors/frameworks/rtvi.py | 39 ++++++++++------------- 1 file changed, 17 insertions(+), 22 deletions(-) diff --git a/src/pipecat/processors/frameworks/rtvi.py b/src/pipecat/processors/frameworks/rtvi.py index ac74c6fcf..a0d3e7e39 100644 --- a/src/pipecat/processors/frameworks/rtvi.py +++ b/src/pipecat/processors/frameworks/rtvi.py @@ -287,8 +287,7 @@ class RTVIProcessor(FrameProcessor): async def send_error(self, error: str): message = RTVIError(data=RTVIErrorData(message=error)) - frame = TransportMessageFrame(message=message.model_dump(exclude_none=True)) - await self.push_frame(frame) + await self._push_transport_message(message) async def handle_function_call( self, @@ -302,14 +301,12 @@ class RTVIProcessor(FrameProcessor): tool_call_id=tool_call_id, args=arguments) message = RTVILLMFunctionCallMessage(data=fn) - frame = TransportMessageFrame(message=message.model_dump()) - await self.push_frame(frame) + await self._push_transport_message(message, exclude_none=False) async def handle_function_call_start(self, function_name: str): fn = RTVILLMFunctionCallStartMessageData(function_name=function_name) message = RTVILLMFunctionCallStartMessage(data=fn) - frame = TransportMessageFrame(message=message.model_dump()) - await self.push_frame(frame) + await self._push_transport_message(message, exclude_none=False) async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM): if isinstance(frame, SystemFrame): @@ -388,6 +385,12 @@ class RTVIProcessor(FrameProcessor): except asyncio.CancelledError: break + async def _push_transport_message(self, model: BaseModel, exclude_none: bool = True): + frame = TransportMessageFrame( + message=model.model_dump(exclude_none=exclude_none), + urgent=True) + await self.push_frame(frame) + async def _handle_transcriptions(self, frame: Frame): # TODO(aleix): Once we add support for using custom pipelines, the STTs will # be in the pipeline after this processor. @@ -409,8 +412,7 @@ class RTVIProcessor(FrameProcessor): final=False)) if message: - frame = TransportMessageFrame(message=message.model_dump(exclude_none=True)) - await self.push_frame(frame) + await self._push_transport_message(message) async def _handle_interruptions(self, frame: Frame): message = None @@ -420,8 +422,7 @@ class RTVIProcessor(FrameProcessor): message = RTVIUserStoppedSpeakingMessage() if message: - frame = TransportMessageFrame(message=message.model_dump(exclude_none=True)) - await self.push_frame(frame) + await self._push_transport_message(message) async def _message_task_handler(self): while True: @@ -471,19 +472,16 @@ class RTVIProcessor(FrameProcessor): async def _handle_describe_config(self, request_id: str): services = list(self._registered_services.values()) message = RTVIDescribeConfig(id=request_id, data=RTVIDescribeConfigData(config=services)) - frame = TransportMessageFrame(message=message.model_dump(exclude_none=True)) - await self.push_frame(frame) + await self._push_transport_message(message) async def _handle_describe_actions(self, request_id: str): actions = list(self._registered_actions.values()) message = RTVIDescribeActions(id=request_id, data=RTVIDescribeActionsData(actions=actions)) - frame = TransportMessageFrame(message=message.model_dump(exclude_none=True)) - await self.push_frame(frame) + await self._push_transport_message(message) async def _handle_get_config(self, request_id: str): message = RTVIConfigResponse(id=request_id, data=self._config) - frame = TransportMessageFrame(message=message.model_dump(exclude_none=True)) - await self.push_frame(frame) + await self._push_transport_message(message) def _update_config_option(self, service: str, config: RTVIServiceOptionConfig): for service_config in self._config.config: @@ -540,8 +538,7 @@ class RTVIProcessor(FrameProcessor): arguments[arg.name] = arg.value result = await action.handler(self, action.service, arguments) message = RTVIActionResponse(id=request_id, data=RTVIActionResponseData(result=result)) - frame = TransportMessageFrame(message=message.model_dump(exclude_none=True)) - await self.push_frame(frame) + await self._push_transport_message(message) async def _transport_on_joined(self, transport, participant): self._transport_joined = True @@ -558,13 +555,11 @@ class RTVIProcessor(FrameProcessor): data=RTVIBotReadyData( version=RTVI_PROTOCOL_VERSION, config=self._config.config)) - frame = TransportMessageFrame(message=message.model_dump(exclude_none=True)) - await self.push_frame(frame) + await self._push_transport_message(message) async def _send_error_response(self, id: str, error: str): message = RTVIErrorResponse(id=id, data=RTVIErrorResponseData(error=error)) - frame = TransportMessageFrame(message=message.model_dump(exclude_none=True)) - await self.push_frame(frame) + await self._push_transport_message(message) def _action_id(self, service: str, action: str) -> str: return f"{service}:{action}"