introduce StartInterruptionFrame and StopInterruptionFrame

This commit is contained in:
Aleix Conchillo Flaqué
2024-05-17 11:07:12 -07:00
parent f62fe059b1
commit 0bef44c2ff
3 changed files with 38 additions and 9 deletions

View File

@@ -216,6 +216,28 @@ class StopTaskFrame(SystemFrame):
pass
@dataclass
class StartInterruptionFrame(SystemFrame):
"""Emitted by VAD to indicate that a user has started speaking (i.e. is
interruption). This is similar to UserStartedSpeakingFrame except that it
should be pushed concurrently with other frames (so the order is not
guaranteed).
"""
pass
@dataclass
class StopInterruptionFrame(SystemFrame):
"""Emitted by VAD to indicate that a user has stopped speaking (i.e. no more
interruptions). This is similar to UserStoppedSpeakingFrame except that it
should be pushed concurrently with other frames (so the order is not
guaranteed).
"""
pass
#
# Control frames
#

View File

@@ -14,6 +14,8 @@ from pipecat.frames.frames import (
StartFrame,
EndFrame,
Frame,
StartInterruptionFrame,
StopInterruptionFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame)
from pipecat.transports.base_transport import TransportParams
@@ -123,9 +125,14 @@ class BaseInputTransport(FrameProcessor):
#
async def _handle_interruptions(self, frame: Frame):
if self._allow_interruptions and isinstance(frame, UserStartedSpeakingFrame):
self._push_frame_task.cancel()
self._create_push_task()
if self._allow_interruptions:
# Make sure we notify about interruptions quickly out-of-band
if isinstance(frame, UserStartedSpeakingFrame):
self._push_frame_task.cancel()
self._create_push_task()
await self.push_frame(StartInterruptionFrame())
elif isinstance(frame, UserStoppedSpeakingFrame):
await self.push_frame(StopInterruptionFrame())
await self._internal_push_frame(frame)
#

View File

@@ -23,9 +23,9 @@ from pipecat.frames.frames import (
EndFrame,
Frame,
ImageRawFrame,
TransportMessageFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame)
StartInterruptionFrame,
StopInterruptionFrame,
TransportMessageFrame)
from pipecat.transports.base_transport import TransportParams
from loguru import logger
@@ -104,7 +104,7 @@ class BaseOutputTransport(FrameProcessor):
elif isinstance(frame, CancelFrame):
await self.stop()
await self.push_frame(frame, direction)
elif isinstance(frame, UserStartedSpeakingFrame) or isinstance(frame, UserStoppedSpeakingFrame):
elif isinstance(frame, StartInterruptionFrame) or isinstance(frame, StopInterruptionFrame):
await self._handle_interruptions(frame)
await self.push_frame(frame, direction)
elif self._frame_managed_by_sink(frame):
@@ -129,9 +129,9 @@ class BaseOutputTransport(FrameProcessor):
if not self._allow_interruptions:
return
if isinstance(frame, UserStartedSpeakingFrame):
if isinstance(frame, StartInterruptionFrame):
self._is_interrupted.set()
elif isinstance(frame, UserStoppedSpeakingFrame):
elif isinstance(frame, StopInterruptionFrame):
self._is_interrupted.clear()
def _sink_thread_handler(self):