diff --git a/src/pipecat/frames/frames.py b/src/pipecat/frames/frames.py index ed7ac75e8..9d5b15b36 100644 --- a/src/pipecat/frames/frames.py +++ b/src/pipecat/frames/frames.py @@ -216,6 +216,28 @@ class StopTaskFrame(SystemFrame): pass +@dataclass +class StartInterruptionFrame(SystemFrame): + """Emitted by VAD to indicate that a user has started speaking (i.e. is + interruption). This is similar to UserStartedSpeakingFrame except that it + should be pushed concurrently with other frames (so the order is not + guaranteed). + + """ + pass + + +@dataclass +class StopInterruptionFrame(SystemFrame): + """Emitted by VAD to indicate that a user has stopped speaking (i.e. no more + interruptions). This is similar to UserStoppedSpeakingFrame except that it + should be pushed concurrently with other frames (so the order is not + guaranteed). + + """ + pass + + # # Control frames # diff --git a/src/pipecat/transports/base_input.py b/src/pipecat/transports/base_input.py index 92473c769..f3ee492d3 100644 --- a/src/pipecat/transports/base_input.py +++ b/src/pipecat/transports/base_input.py @@ -14,6 +14,8 @@ from pipecat.frames.frames import ( StartFrame, EndFrame, Frame, + StartInterruptionFrame, + StopInterruptionFrame, UserStartedSpeakingFrame, UserStoppedSpeakingFrame) from pipecat.transports.base_transport import TransportParams @@ -123,9 +125,14 @@ class BaseInputTransport(FrameProcessor): # async def _handle_interruptions(self, frame: Frame): - if self._allow_interruptions and isinstance(frame, UserStartedSpeakingFrame): - self._push_frame_task.cancel() - self._create_push_task() + if self._allow_interruptions: + # Make sure we notify about interruptions quickly out-of-band + if isinstance(frame, UserStartedSpeakingFrame): + self._push_frame_task.cancel() + self._create_push_task() + await self.push_frame(StartInterruptionFrame()) + elif isinstance(frame, UserStoppedSpeakingFrame): + await self.push_frame(StopInterruptionFrame()) await self._internal_push_frame(frame) # diff --git a/src/pipecat/transports/base_output.py b/src/pipecat/transports/base_output.py index ff9040406..7d18f1265 100644 --- a/src/pipecat/transports/base_output.py +++ b/src/pipecat/transports/base_output.py @@ -23,9 +23,9 @@ from pipecat.frames.frames import ( EndFrame, Frame, ImageRawFrame, - TransportMessageFrame, - UserStartedSpeakingFrame, - UserStoppedSpeakingFrame) + StartInterruptionFrame, + StopInterruptionFrame, + TransportMessageFrame) from pipecat.transports.base_transport import TransportParams from loguru import logger @@ -104,7 +104,7 @@ class BaseOutputTransport(FrameProcessor): elif isinstance(frame, CancelFrame): await self.stop() await self.push_frame(frame, direction) - elif isinstance(frame, UserStartedSpeakingFrame) or isinstance(frame, UserStoppedSpeakingFrame): + elif isinstance(frame, StartInterruptionFrame) or isinstance(frame, StopInterruptionFrame): await self._handle_interruptions(frame) await self.push_frame(frame, direction) elif self._frame_managed_by_sink(frame): @@ -129,9 +129,9 @@ class BaseOutputTransport(FrameProcessor): if not self._allow_interruptions: return - if isinstance(frame, UserStartedSpeakingFrame): + if isinstance(frame, StartInterruptionFrame): self._is_interrupted.set() - elif isinstance(frame, UserStoppedSpeakingFrame): + elif isinstance(frame, StopInterruptionFrame): self._is_interrupted.clear() def _sink_thread_handler(self):