processors(rtvi): handle frames pushed from outside in order

This commit is contained in:
Aleix Conchillo Flaqué
2024-08-11 23:07:28 -07:00
parent 0d85c0085f
commit c4c2058df9

View File

@@ -219,6 +219,12 @@ class RTVIProcessor(FrameProcessor):
def register_service(self, service: RTVIService):
self._registered_services[service.name] = service
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
if isinstance(frame, SystemFrame):
await super().push_frame(frame, direction)
else:
await self._internal_push_frame(frame, direction)
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
@@ -232,15 +238,24 @@ class RTVIProcessor(FrameProcessor):
# Control frames
elif isinstance(frame, StartFrame):
await self._start(frame)
await self._internal_push_frame(frame, direction)
await self.push_frame(frame, direction)
elif isinstance(frame, EndFrame):
# Push EndFrame before stop(), because stop() waits on the task to
# finish and the task finishes when EndFrame is processed.
await self._internal_push_frame(frame, direction)
await self.push_frame(frame, direction)
await self._stop(frame)
elif isinstance(frame, UserStartedSpeakingFrame) or isinstance(frame, UserStoppedSpeakingFrame):
await self._handle_interruptions(frame)
await self.push_frame(frame, direction)
# Data frames
elif isinstance(frame, TransportMessageFrame):
await self._handle_message(frame)
elif isinstance(frame, TranscriptionFrame) or isinstance(frame, InterimTranscriptionFrame):
await self._handle_transcriptions(frame)
await self.push_frame(frame, direction)
# Other frames
else:
await self._internal_push_frame(frame, direction)
await self.push_frame(frame, direction)
async def cleanup(self):
if self._pipeline:
@@ -268,23 +283,12 @@ class RTVIProcessor(FrameProcessor):
while running:
try:
(frame, direction) = await self._frame_queue.get()
await self._handle_frame(frame, direction)
await super().push_frame(frame, direction)
self._frame_queue.task_done()
running = not isinstance(frame, EndFrame)
except asyncio.CancelledError:
break
async def _handle_frame(self, frame: Frame, direction: FrameDirection):
if isinstance(frame, TransportMessageFrame):
await self._handle_message(frame)
else:
await self.push_frame(frame, direction)
if isinstance(frame, TranscriptionFrame) or isinstance(frame, InterimTranscriptionFrame):
await self._handle_transcriptions(frame)
elif isinstance(frame, UserStartedSpeakingFrame) or isinstance(frame, UserStoppedSpeakingFrame):
await self._handle_interruptions(frame)
async def _handle_transcriptions(self, frame: Frame):
# TODO(aleix): Once we add support for using custom pipelines, the STTs will
# be in the pipeline after this processor.