processors(rtvi): all transport messages should be urgent

This commit is contained in:
Aleix Conchillo Flaqué
2024-08-15 08:34:08 -07:00
parent 425a730d7c
commit b2a7ff6fd3

View File

@@ -287,8 +287,7 @@ class RTVIProcessor(FrameProcessor):
async def send_error(self, error: str):
message = RTVIError(data=RTVIErrorData(message=error))
frame = TransportMessageFrame(message=message.model_dump(exclude_none=True))
await self.push_frame(frame)
await self._push_transport_message(message)
async def handle_function_call(
self,
@@ -302,14 +301,12 @@ class RTVIProcessor(FrameProcessor):
tool_call_id=tool_call_id,
args=arguments)
message = RTVILLMFunctionCallMessage(data=fn)
frame = TransportMessageFrame(message=message.model_dump())
await self.push_frame(frame)
await self._push_transport_message(message, exclude_none=False)
async def handle_function_call_start(self, function_name: str):
fn = RTVILLMFunctionCallStartMessageData(function_name=function_name)
message = RTVILLMFunctionCallStartMessage(data=fn)
frame = TransportMessageFrame(message=message.model_dump())
await self.push_frame(frame)
await self._push_transport_message(message, exclude_none=False)
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
if isinstance(frame, SystemFrame):
@@ -388,6 +385,12 @@ class RTVIProcessor(FrameProcessor):
except asyncio.CancelledError:
break
async def _push_transport_message(self, model: BaseModel, exclude_none: bool = True):
frame = TransportMessageFrame(
message=model.model_dump(exclude_none=exclude_none),
urgent=True)
await self.push_frame(frame)
async def _handle_transcriptions(self, frame: Frame):
# TODO(aleix): Once we add support for using custom pipelines, the STTs will
# be in the pipeline after this processor.
@@ -409,8 +412,7 @@ class RTVIProcessor(FrameProcessor):
final=False))
if message:
frame = TransportMessageFrame(message=message.model_dump(exclude_none=True))
await self.push_frame(frame)
await self._push_transport_message(message)
async def _handle_interruptions(self, frame: Frame):
message = None
@@ -420,8 +422,7 @@ class RTVIProcessor(FrameProcessor):
message = RTVIUserStoppedSpeakingMessage()
if message:
frame = TransportMessageFrame(message=message.model_dump(exclude_none=True))
await self.push_frame(frame)
await self._push_transport_message(message)
async def _message_task_handler(self):
while True:
@@ -471,19 +472,16 @@ class RTVIProcessor(FrameProcessor):
async def _handle_describe_config(self, request_id: str):
services = list(self._registered_services.values())
message = RTVIDescribeConfig(id=request_id, data=RTVIDescribeConfigData(config=services))
frame = TransportMessageFrame(message=message.model_dump(exclude_none=True))
await self.push_frame(frame)
await self._push_transport_message(message)
async def _handle_describe_actions(self, request_id: str):
actions = list(self._registered_actions.values())
message = RTVIDescribeActions(id=request_id, data=RTVIDescribeActionsData(actions=actions))
frame = TransportMessageFrame(message=message.model_dump(exclude_none=True))
await self.push_frame(frame)
await self._push_transport_message(message)
async def _handle_get_config(self, request_id: str):
message = RTVIConfigResponse(id=request_id, data=self._config)
frame = TransportMessageFrame(message=message.model_dump(exclude_none=True))
await self.push_frame(frame)
await self._push_transport_message(message)
def _update_config_option(self, service: str, config: RTVIServiceOptionConfig):
for service_config in self._config.config:
@@ -540,8 +538,7 @@ class RTVIProcessor(FrameProcessor):
arguments[arg.name] = arg.value
result = await action.handler(self, action.service, arguments)
message = RTVIActionResponse(id=request_id, data=RTVIActionResponseData(result=result))
frame = TransportMessageFrame(message=message.model_dump(exclude_none=True))
await self.push_frame(frame)
await self._push_transport_message(message)
async def _transport_on_joined(self, transport, participant):
self._transport_joined = True
@@ -558,13 +555,11 @@ class RTVIProcessor(FrameProcessor):
data=RTVIBotReadyData(
version=RTVI_PROTOCOL_VERSION,
config=self._config.config))
frame = TransportMessageFrame(message=message.model_dump(exclude_none=True))
await self.push_frame(frame)
await self._push_transport_message(message)
async def _send_error_response(self, id: str, error: str):
message = RTVIErrorResponse(id=id, data=RTVIErrorResponseData(error=error))
frame = TransportMessageFrame(message=message.model_dump(exclude_none=True))
await self.push_frame(frame)
await self._push_transport_message(message)
def _action_id(self, service: str, action: str) -> str:
return f"{service}:{action}"