diff --git a/CHANGELOG.md b/CHANGELOG.md index 3b1a3bafd..ca7d08e3e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added new frames `EmulateUserStartedSpeakingFrame` and + `EmulateUserStoppedSpeakingFrame` which can be used to emulated VAD behavior + without VAD being present or not being triggered. + - Added a new `audio_in_stream_on_start` field to `TransportParams`. - Added a new method `start_audio_in_streaming` in the `BaseInputTransport`. diff --git a/src/pipecat/frames/frames.py b/src/pipecat/frames/frames.py index 09d0a93b0..c9b812a1c 100644 --- a/src/pipecat/frames/frames.py +++ b/src/pipecat/frames/frames.py @@ -565,6 +565,22 @@ class UserStoppedSpeakingFrame(SystemFrame): pass +@dataclass +class EmulateUserStartedSpeakingFrame(SystemFrame): + """Emitted by internal processors upstream to emulate VAD behavior when a + user starts speaking.""" + + pass + + +@dataclass +class EmulateUserStoppedSpeakingFrame(SystemFrame): + """Emitted by internal processors upstream to emulate VAD behavior when a + user stops speaking.""" + + pass + + @dataclass class BotInterruptionFrame(SystemFrame): """Emitted by when the bot should be interrupted. This will mainly cause the diff --git a/src/pipecat/processors/aggregators/llm_response.py b/src/pipecat/processors/aggregators/llm_response.py index bb2e28ca7..950c155e6 100644 --- a/src/pipecat/processors/aggregators/llm_response.py +++ b/src/pipecat/processors/aggregators/llm_response.py @@ -12,6 +12,8 @@ from typing import List from pipecat.frames.frames import ( BotInterruptionFrame, CancelFrame, + EmulateUserStartedSpeakingFrame, + EmulateUserStoppedSpeakingFrame, EndFrame, Frame, InterimTranscriptionFrame, @@ -227,6 +229,7 @@ class LLMUserContextAggregator(LLMContextResponseAggregator): self._seen_interim_results = False self._user_speaking = False self._last_user_speaking_time = 0 + self._emulating_vad = False self._aggregation_event = asyncio.Event() self._aggregation_task = None @@ -314,6 +317,14 @@ class LLMUserContextAggregator(LLMContextResponseAggregator): except asyncio.TimeoutError: if not self._user_speaking: await self.push_aggregation() + + # If we are emulating VAD we still need to send the user stopped + # speaking frame. + if self._emulating_vad: + await self.push_frame( + EmulateUserStoppedSpeakingFrame(), FrameDirection.UPSTREAM + ) + self._emulating_vad = False finally: self._aggregation_event.clear() @@ -325,7 +336,13 @@ class LLMUserContextAggregator(LLMContextResponseAggregator): if not self._user_speaking: diff_time = time.time() - self._last_user_speaking_time if diff_time > self._bot_interruption_timeout: - await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM) + # If we reach this case we received a transcription but VAD was + # not able to detect voice (e.g. when you whisper a short + # utterance). So, we need to emulate VAD (i.e. user + # start/stopped speaking). + await self.push_frame(EmulateUserStartedSpeakingFrame(), FrameDirection.UPSTREAM) + self._emulating_vad = True + # Reset time so we don't interrupt again right away. self._last_user_speaking_time = time.time() diff --git a/src/pipecat/transports/base_input.py b/src/pipecat/transports/base_input.py index 6bfe86001..42eb162da 100644 --- a/src/pipecat/transports/base_input.py +++ b/src/pipecat/transports/base_input.py @@ -14,6 +14,8 @@ from pipecat.audio.vad.vad_analyzer import VADAnalyzer, VADState from pipecat.frames.frames import ( BotInterruptionFrame, CancelFrame, + EmulateUserStartedSpeakingFrame, + EmulateUserStoppedSpeakingFrame, EndFrame, FilterUpdateSettingsFrame, Frame, @@ -112,9 +114,13 @@ class BaseInputTransport(FrameProcessor): await self.cancel(frame) await self.push_frame(frame, direction) elif isinstance(frame, BotInterruptionFrame): - logger.debug("Bot interruption") - await self._start_interruption() - await self.push_frame(StartInterruptionFrame()) + await self._handle_bot_interruption(frame) + elif isinstance(frame, EmulateUserStartedSpeakingFrame): + logger.debug("Emulating user started speaking") + await self._handle_user_interruption(UserStartedSpeakingFrame()) + elif isinstance(frame, EmulateUserStoppedSpeakingFrame): + logger.debug("Emulating user stopped speaking") + await self._handle_user_interruption(UserStoppedSpeakingFrame()) # All other system frames elif isinstance(frame, SystemFrame): await self.push_frame(frame, direction) @@ -137,7 +143,13 @@ class BaseInputTransport(FrameProcessor): # Handle interruptions # - async def _handle_interruptions(self, frame: Frame): + async def _handle_bot_interruption(self, frame: BotInterruptionFrame): + logger.debug("Bot interruption") + if self.interruptions_allowed: + await self._start_interruption() + await self.push_frame(StartInterruptionFrame()) + + async def _handle_user_interruption(self, frame: Frame): if isinstance(frame, UserStartedSpeakingFrame): logger.debug("User started speaking") # Make sure we notify about interruptions quickly out-of-band. @@ -183,7 +195,7 @@ class BaseInputTransport(FrameProcessor): frame = UserStoppedSpeakingFrame() if frame: - await self._handle_interruptions(frame) + await self._handle_user_interruption(frame) vad_state = new_vad_state return vad_state diff --git a/tests/test_context_aggregators.py b/tests/test_context_aggregators.py index 7190afce2..d4b8c35ce 100644 --- a/tests/test_context_aggregators.py +++ b/tests/test_context_aggregators.py @@ -9,7 +9,8 @@ import unittest import google.ai.generativelanguage as glm from pipecat.frames.frames import ( - BotInterruptionFrame, + EmulateUserStartedSpeakingFrame, + EmulateUserStoppedSpeakingFrame, InterimTranscriptionFrame, LLMFullResponseEndFrame, LLMFullResponseStartFrame, @@ -348,7 +349,7 @@ class BaseTestUserContextAggregator: SleepFrame(sleep=AGGREGATION_SLEEP), ] expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES] - expected_up_frames = [BotInterruptionFrame] + expected_up_frames = [EmulateUserStartedSpeakingFrame, EmulateUserStoppedSpeakingFrame] await run_test( aggregator, frames_to_send=frames_to_send, @@ -370,7 +371,7 @@ class BaseTestUserContextAggregator: SleepFrame(sleep=AGGREGATION_SLEEP), ] expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES] - expected_up_frames = [BotInterruptionFrame] + expected_up_frames = [EmulateUserStartedSpeakingFrame, EmulateUserStoppedSpeakingFrame] await run_test( aggregator, frames_to_send=frames_to_send, @@ -404,7 +405,7 @@ class BaseTestUserContextAggregator: UserStoppedSpeakingFrame, *self.EXPECTED_CONTEXT_FRAMES, ] - expected_up_frames = [BotInterruptionFrame] + expected_up_frames = [EmulateUserStartedSpeakingFrame, EmulateUserStoppedSpeakingFrame] await run_test( aggregator, frames_to_send=frames_to_send,