LLMContextResponseAggregator: add VAD emulation support

This commit is contained in:
Aleix Conchillo Flaqué
2025-02-13 13:59:43 -08:00
parent a6502df72c
commit 5909dff423
5 changed files with 60 additions and 10 deletions

View File

@@ -9,6 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Added new frames `EmulateUserStartedSpeakingFrame` and
`EmulateUserStoppedSpeakingFrame` which can be used to emulated VAD behavior
without VAD being present or not being triggered.
- Added a new `audio_in_stream_on_start` field to `TransportParams`.
- Added a new method `start_audio_in_streaming` in the `BaseInputTransport`.

View File

@@ -565,6 +565,22 @@ class UserStoppedSpeakingFrame(SystemFrame):
pass
@dataclass
class EmulateUserStartedSpeakingFrame(SystemFrame):
"""Emitted by internal processors upstream to emulate VAD behavior when a
user starts speaking."""
pass
@dataclass
class EmulateUserStoppedSpeakingFrame(SystemFrame):
"""Emitted by internal processors upstream to emulate VAD behavior when a
user stops speaking."""
pass
@dataclass
class BotInterruptionFrame(SystemFrame):
"""Emitted by when the bot should be interrupted. This will mainly cause the

View File

@@ -12,6 +12,8 @@ from typing import List
from pipecat.frames.frames import (
BotInterruptionFrame,
CancelFrame,
EmulateUserStartedSpeakingFrame,
EmulateUserStoppedSpeakingFrame,
EndFrame,
Frame,
InterimTranscriptionFrame,
@@ -227,6 +229,7 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
self._seen_interim_results = False
self._user_speaking = False
self._last_user_speaking_time = 0
self._emulating_vad = False
self._aggregation_event = asyncio.Event()
self._aggregation_task = None
@@ -314,6 +317,14 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
except asyncio.TimeoutError:
if not self._user_speaking:
await self.push_aggregation()
# If we are emulating VAD we still need to send the user stopped
# speaking frame.
if self._emulating_vad:
await self.push_frame(
EmulateUserStoppedSpeakingFrame(), FrameDirection.UPSTREAM
)
self._emulating_vad = False
finally:
self._aggregation_event.clear()
@@ -325,7 +336,13 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
if not self._user_speaking:
diff_time = time.time() - self._last_user_speaking_time
if diff_time > self._bot_interruption_timeout:
await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
# If we reach this case we received a transcription but VAD was
# not able to detect voice (e.g. when you whisper a short
# utterance). So, we need to emulate VAD (i.e. user
# start/stopped speaking).
await self.push_frame(EmulateUserStartedSpeakingFrame(), FrameDirection.UPSTREAM)
self._emulating_vad = True
# Reset time so we don't interrupt again right away.
self._last_user_speaking_time = time.time()

View File

@@ -14,6 +14,8 @@ from pipecat.audio.vad.vad_analyzer import VADAnalyzer, VADState
from pipecat.frames.frames import (
BotInterruptionFrame,
CancelFrame,
EmulateUserStartedSpeakingFrame,
EmulateUserStoppedSpeakingFrame,
EndFrame,
FilterUpdateSettingsFrame,
Frame,
@@ -112,9 +114,13 @@ class BaseInputTransport(FrameProcessor):
await self.cancel(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, BotInterruptionFrame):
logger.debug("Bot interruption")
await self._start_interruption()
await self.push_frame(StartInterruptionFrame())
await self._handle_bot_interruption(frame)
elif isinstance(frame, EmulateUserStartedSpeakingFrame):
logger.debug("Emulating user started speaking")
await self._handle_user_interruption(UserStartedSpeakingFrame())
elif isinstance(frame, EmulateUserStoppedSpeakingFrame):
logger.debug("Emulating user stopped speaking")
await self._handle_user_interruption(UserStoppedSpeakingFrame())
# All other system frames
elif isinstance(frame, SystemFrame):
await self.push_frame(frame, direction)
@@ -137,7 +143,13 @@ class BaseInputTransport(FrameProcessor):
# Handle interruptions
#
async def _handle_interruptions(self, frame: Frame):
async def _handle_bot_interruption(self, frame: BotInterruptionFrame):
logger.debug("Bot interruption")
if self.interruptions_allowed:
await self._start_interruption()
await self.push_frame(StartInterruptionFrame())
async def _handle_user_interruption(self, frame: Frame):
if isinstance(frame, UserStartedSpeakingFrame):
logger.debug("User started speaking")
# Make sure we notify about interruptions quickly out-of-band.
@@ -183,7 +195,7 @@ class BaseInputTransport(FrameProcessor):
frame = UserStoppedSpeakingFrame()
if frame:
await self._handle_interruptions(frame)
await self._handle_user_interruption(frame)
vad_state = new_vad_state
return vad_state

View File

@@ -9,7 +9,8 @@ import unittest
import google.ai.generativelanguage as glm
from pipecat.frames.frames import (
BotInterruptionFrame,
EmulateUserStartedSpeakingFrame,
EmulateUserStoppedSpeakingFrame,
InterimTranscriptionFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
@@ -348,7 +349,7 @@ class BaseTestUserContextAggregator:
SleepFrame(sleep=AGGREGATION_SLEEP),
]
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES]
expected_up_frames = [BotInterruptionFrame]
expected_up_frames = [EmulateUserStartedSpeakingFrame, EmulateUserStoppedSpeakingFrame]
await run_test(
aggregator,
frames_to_send=frames_to_send,
@@ -370,7 +371,7 @@ class BaseTestUserContextAggregator:
SleepFrame(sleep=AGGREGATION_SLEEP),
]
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES]
expected_up_frames = [BotInterruptionFrame]
expected_up_frames = [EmulateUserStartedSpeakingFrame, EmulateUserStoppedSpeakingFrame]
await run_test(
aggregator,
frames_to_send=frames_to_send,
@@ -404,7 +405,7 @@ class BaseTestUserContextAggregator:
UserStoppedSpeakingFrame,
*self.EXPECTED_CONTEXT_FRAMES,
]
expected_up_frames = [BotInterruptionFrame]
expected_up_frames = [EmulateUserStartedSpeakingFrame, EmulateUserStoppedSpeakingFrame]
await run_test(
aggregator,
frames_to_send=frames_to_send,