LLMContextResponseAggregator: add VAD emulation support
This commit is contained in:
@@ -9,6 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Added
|
||||
|
||||
- Added new frames `EmulateUserStartedSpeakingFrame` and
|
||||
`EmulateUserStoppedSpeakingFrame` which can be used to emulate VAD behavior
|
||||
when VAD is not present or is not being triggered.
|
||||
|
||||
- Added a new `audio_in_stream_on_start` field to `TransportParams`.
|
||||
|
||||
- Added a new method `start_audio_in_streaming` in the `BaseInputTransport`.
|
||||
|
||||
@@ -565,6 +565,22 @@ class UserStoppedSpeakingFrame(SystemFrame):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmulateUserStartedSpeakingFrame(SystemFrame):
|
||||
"""Emitted by internal processors upstream to emulate VAD behavior when a
|
||||
user starts speaking."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmulateUserStoppedSpeakingFrame(SystemFrame):
|
||||
"""Emitted by internal processors upstream to emulate VAD behavior when a
|
||||
user stops speaking."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class BotInterruptionFrame(SystemFrame):
|
||||
"""Emitted by when the bot should be interrupted. This will mainly cause the
|
||||
|
||||
@@ -12,6 +12,8 @@ from typing import List
|
||||
from pipecat.frames.frames import (
|
||||
BotInterruptionFrame,
|
||||
CancelFrame,
|
||||
EmulateUserStartedSpeakingFrame,
|
||||
EmulateUserStoppedSpeakingFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
@@ -227,6 +229,7 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
|
||||
self._seen_interim_results = False
|
||||
self._user_speaking = False
|
||||
self._last_user_speaking_time = 0
|
||||
self._emulating_vad = False
|
||||
|
||||
self._aggregation_event = asyncio.Event()
|
||||
self._aggregation_task = None
|
||||
@@ -314,6 +317,14 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
|
||||
except asyncio.TimeoutError:
|
||||
if not self._user_speaking:
|
||||
await self.push_aggregation()
|
||||
|
||||
# If we are emulating VAD we still need to send the user stopped
|
||||
# speaking frame.
|
||||
if self._emulating_vad:
|
||||
await self.push_frame(
|
||||
EmulateUserStoppedSpeakingFrame(), FrameDirection.UPSTREAM
|
||||
)
|
||||
self._emulating_vad = False
|
||||
finally:
|
||||
self._aggregation_event.clear()
|
||||
|
||||
@@ -325,7 +336,13 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
|
||||
if not self._user_speaking:
|
||||
diff_time = time.time() - self._last_user_speaking_time
|
||||
if diff_time > self._bot_interruption_timeout:
|
||||
await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
|
||||
# If we reach this case we received a transcription but VAD was
|
||||
# not able to detect voice (e.g. when you whisper a short
|
||||
# utterance). So, we need to emulate VAD (i.e. user
|
||||
# started/stopped speaking).
|
||||
await self.push_frame(EmulateUserStartedSpeakingFrame(), FrameDirection.UPSTREAM)
|
||||
self._emulating_vad = True
|
||||
|
||||
# Reset time so we don't interrupt again right away.
|
||||
self._last_user_speaking_time = time.time()
|
||||
|
||||
|
||||
@@ -14,6 +14,8 @@ from pipecat.audio.vad.vad_analyzer import VADAnalyzer, VADState
|
||||
from pipecat.frames.frames import (
|
||||
BotInterruptionFrame,
|
||||
CancelFrame,
|
||||
EmulateUserStartedSpeakingFrame,
|
||||
EmulateUserStoppedSpeakingFrame,
|
||||
EndFrame,
|
||||
FilterUpdateSettingsFrame,
|
||||
Frame,
|
||||
@@ -112,9 +114,13 @@ class BaseInputTransport(FrameProcessor):
|
||||
await self.cancel(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, BotInterruptionFrame):
|
||||
logger.debug("Bot interruption")
|
||||
await self._start_interruption()
|
||||
await self.push_frame(StartInterruptionFrame())
|
||||
await self._handle_bot_interruption(frame)
|
||||
elif isinstance(frame, EmulateUserStartedSpeakingFrame):
|
||||
logger.debug("Emulating user started speaking")
|
||||
await self._handle_user_interruption(UserStartedSpeakingFrame())
|
||||
elif isinstance(frame, EmulateUserStoppedSpeakingFrame):
|
||||
logger.debug("Emulating user stopped speaking")
|
||||
await self._handle_user_interruption(UserStoppedSpeakingFrame())
|
||||
# All other system frames
|
||||
elif isinstance(frame, SystemFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
@@ -137,7 +143,13 @@ class BaseInputTransport(FrameProcessor):
|
||||
# Handle interruptions
|
||||
#
|
||||
|
||||
async def _handle_interruptions(self, frame: Frame):
|
||||
async def _handle_bot_interruption(self, frame: BotInterruptionFrame):
|
||||
logger.debug("Bot interruption")
|
||||
if self.interruptions_allowed:
|
||||
await self._start_interruption()
|
||||
await self.push_frame(StartInterruptionFrame())
|
||||
|
||||
async def _handle_user_interruption(self, frame: Frame):
|
||||
if isinstance(frame, UserStartedSpeakingFrame):
|
||||
logger.debug("User started speaking")
|
||||
# Make sure we notify about interruptions quickly out-of-band.
|
||||
@@ -183,7 +195,7 @@ class BaseInputTransport(FrameProcessor):
|
||||
frame = UserStoppedSpeakingFrame()
|
||||
|
||||
if frame:
|
||||
await self._handle_interruptions(frame)
|
||||
await self._handle_user_interruption(frame)
|
||||
|
||||
vad_state = new_vad_state
|
||||
return vad_state
|
||||
|
||||
@@ -9,7 +9,8 @@ import unittest
|
||||
import google.ai.generativelanguage as glm
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotInterruptionFrame,
|
||||
EmulateUserStartedSpeakingFrame,
|
||||
EmulateUserStoppedSpeakingFrame,
|
||||
InterimTranscriptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
@@ -348,7 +349,7 @@ class BaseTestUserContextAggregator:
|
||||
SleepFrame(sleep=AGGREGATION_SLEEP),
|
||||
]
|
||||
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES]
|
||||
expected_up_frames = [BotInterruptionFrame]
|
||||
expected_up_frames = [EmulateUserStartedSpeakingFrame, EmulateUserStoppedSpeakingFrame]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
@@ -370,7 +371,7 @@ class BaseTestUserContextAggregator:
|
||||
SleepFrame(sleep=AGGREGATION_SLEEP),
|
||||
]
|
||||
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES]
|
||||
expected_up_frames = [BotInterruptionFrame]
|
||||
expected_up_frames = [EmulateUserStartedSpeakingFrame, EmulateUserStoppedSpeakingFrame]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
@@ -404,7 +405,7 @@ class BaseTestUserContextAggregator:
|
||||
UserStoppedSpeakingFrame,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
]
|
||||
expected_up_frames = [BotInterruptionFrame]
|
||||
expected_up_frames = [EmulateUserStartedSpeakingFrame, EmulateUserStoppedSpeakingFrame]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
|
||||
Reference in New Issue
Block a user