processors(rtvi): all transport messages should be urgent
This commit is contained in:
@@ -287,8 +287,7 @@ class RTVIProcessor(FrameProcessor):
|
||||
|
||||
async def send_error(self, error: str):
|
||||
message = RTVIError(data=RTVIErrorData(message=error))
|
||||
frame = TransportMessageFrame(message=message.model_dump(exclude_none=True))
|
||||
await self.push_frame(frame)
|
||||
await self._push_transport_message(message)
|
||||
|
||||
async def handle_function_call(
|
||||
self,
|
||||
@@ -302,14 +301,12 @@ class RTVIProcessor(FrameProcessor):
|
||||
tool_call_id=tool_call_id,
|
||||
args=arguments)
|
||||
message = RTVILLMFunctionCallMessage(data=fn)
|
||||
frame = TransportMessageFrame(message=message.model_dump())
|
||||
await self.push_frame(frame)
|
||||
await self._push_transport_message(message, exclude_none=False)
|
||||
|
||||
async def handle_function_call_start(self, function_name: str):
|
||||
fn = RTVILLMFunctionCallStartMessageData(function_name=function_name)
|
||||
message = RTVILLMFunctionCallStartMessage(data=fn)
|
||||
frame = TransportMessageFrame(message=message.model_dump())
|
||||
await self.push_frame(frame)
|
||||
await self._push_transport_message(message, exclude_none=False)
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
if isinstance(frame, SystemFrame):
|
||||
@@ -388,6 +385,12 @@ class RTVIProcessor(FrameProcessor):
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
async def _push_transport_message(self, model: BaseModel, exclude_none: bool = True):
|
||||
frame = TransportMessageFrame(
|
||||
message=model.model_dump(exclude_none=exclude_none),
|
||||
urgent=True)
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def _handle_transcriptions(self, frame: Frame):
|
||||
# TODO(aleix): Once we add support for using custom pipelines, the STTs will
|
||||
# be in the pipeline after this processor.
|
||||
@@ -409,8 +412,7 @@ class RTVIProcessor(FrameProcessor):
|
||||
final=False))
|
||||
|
||||
if message:
|
||||
frame = TransportMessageFrame(message=message.model_dump(exclude_none=True))
|
||||
await self.push_frame(frame)
|
||||
await self._push_transport_message(message)
|
||||
|
||||
async def _handle_interruptions(self, frame: Frame):
|
||||
message = None
|
||||
@@ -420,8 +422,7 @@ class RTVIProcessor(FrameProcessor):
|
||||
message = RTVIUserStoppedSpeakingMessage()
|
||||
|
||||
if message:
|
||||
frame = TransportMessageFrame(message=message.model_dump(exclude_none=True))
|
||||
await self.push_frame(frame)
|
||||
await self._push_transport_message(message)
|
||||
|
||||
async def _message_task_handler(self):
|
||||
while True:
|
||||
@@ -471,19 +472,16 @@ class RTVIProcessor(FrameProcessor):
|
||||
async def _handle_describe_config(self, request_id: str):
|
||||
services = list(self._registered_services.values())
|
||||
message = RTVIDescribeConfig(id=request_id, data=RTVIDescribeConfigData(config=services))
|
||||
frame = TransportMessageFrame(message=message.model_dump(exclude_none=True))
|
||||
await self.push_frame(frame)
|
||||
await self._push_transport_message(message)
|
||||
|
||||
async def _handle_describe_actions(self, request_id: str):
|
||||
actions = list(self._registered_actions.values())
|
||||
message = RTVIDescribeActions(id=request_id, data=RTVIDescribeActionsData(actions=actions))
|
||||
frame = TransportMessageFrame(message=message.model_dump(exclude_none=True))
|
||||
await self.push_frame(frame)
|
||||
await self._push_transport_message(message)
|
||||
|
||||
async def _handle_get_config(self, request_id: str):
|
||||
message = RTVIConfigResponse(id=request_id, data=self._config)
|
||||
frame = TransportMessageFrame(message=message.model_dump(exclude_none=True))
|
||||
await self.push_frame(frame)
|
||||
await self._push_transport_message(message)
|
||||
|
||||
def _update_config_option(self, service: str, config: RTVIServiceOptionConfig):
|
||||
for service_config in self._config.config:
|
||||
@@ -540,8 +538,7 @@ class RTVIProcessor(FrameProcessor):
|
||||
arguments[arg.name] = arg.value
|
||||
result = await action.handler(self, action.service, arguments)
|
||||
message = RTVIActionResponse(id=request_id, data=RTVIActionResponseData(result=result))
|
||||
frame = TransportMessageFrame(message=message.model_dump(exclude_none=True))
|
||||
await self.push_frame(frame)
|
||||
await self._push_transport_message(message)
|
||||
|
||||
async def _transport_on_joined(self, transport, participant):
|
||||
self._transport_joined = True
|
||||
@@ -558,13 +555,11 @@ class RTVIProcessor(FrameProcessor):
|
||||
data=RTVIBotReadyData(
|
||||
version=RTVI_PROTOCOL_VERSION,
|
||||
config=self._config.config))
|
||||
frame = TransportMessageFrame(message=message.model_dump(exclude_none=True))
|
||||
await self.push_frame(frame)
|
||||
await self._push_transport_message(message)
|
||||
|
||||
async def _send_error_response(self, id: str, error: str):
|
||||
message = RTVIErrorResponse(id=id, data=RTVIErrorResponseData(error=error))
|
||||
frame = TransportMessageFrame(message=message.model_dump(exclude_none=True))
|
||||
await self.push_frame(frame)
|
||||
await self._push_transport_message(message)
|
||||
|
||||
def _action_id(self, service: str, action: str) -> str:
|
||||
return f"{service}:{action}"
|
||||
|
||||
Reference in New Issue
Block a user