LLMUserContextAggregator: interrupt the bot if VAD happened a while back

This commit is contained in:
Aleix Conchillo Flaqué
2025-02-12 18:53:51 -08:00
parent 91a628d1ba
commit 4cbcfe2b0b
2 changed files with 120 additions and 15 deletions

View File

@@ -5,10 +5,12 @@
#
import asyncio
import time
from abc import abstractmethod
from typing import List
from pipecat.frames.frames import (
BotInterruptionFrame,
CancelFrame,
EndFrame,
Frame,
@@ -171,12 +173,20 @@ class LLMContextResponseAggregator(BaseLLMResponseAggregator):
class LLMUserContextAggregator(LLMContextResponseAggregator):
# Constructor for the user-side context aggregator.
# NOTE(review): the one-line signature directly below looks like the pre-change
# form kept by the diff view; the multi-line signature after it adds
# bot_interruption_timeout. Confirm which one is live before editing. The body
# may also continue past what is visible here.
def __init__(self, context: OpenAILLMContext, aggregation_timeout: float = 1.0, **kwargs):
def __init__(
self,
context: OpenAILLMContext,
# Seconds the background task waits for another transcription before it
# pushes the aggregation.
aggregation_timeout: float = 1.0,
# Seconds that must have elapsed since the user stopped speaking for a late
# transcription to trigger a BotInterruptionFrame.
bot_interruption_timeout: float = 2.0,
**kwargs,
):
# This aggregator always registers itself with the "user" role.
super().__init__(context=context, role="user", **kwargs)
self._aggregation_timeout = aggregation_timeout
self._bot_interruption_timeout = bot_interruption_timeout
# True after an interim transcription arrives; reset to False by a final one.
self._seen_interim_results = False
# Toggled by the user started/stopped speaking handlers.
self._user_speaking = False
# time.time() of the last "user stopped speaking" event. It starts at 0, so
# before any such event the elapsed time computed from it is very large.
self._last_user_speaking_time = 0
# Set by the transcription handlers and awaited by the aggregation task.
self._aggregation_event = asyncio.Event()
# Handle of the background aggregation task; None while no task exists.
self._aggregation_task = None
@@ -219,43 +229,63 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
await self.push_frame(frame, direction)
# Called with the StartFrame; spawns the background aggregation task.
# NOTE(review): both an inline create_task call and the helper call appear
# here, which looks like a pre-/post-change pair from the diff view. If both
# were live, two tasks would be created and only the second handle kept in
# self._aggregation_task. Confirm that only the helper call remains.
async def _start(self, frame: StartFrame):
self._aggregation_task = self.create_task(self._aggregation_task_handler())
self._create_aggregation_task()
# Called with the EndFrame; cancels the background aggregation task.
# NOTE(review): the inline cancel sequence and the helper call both appear
# here, which looks like a pre-/post-change pair from the diff view. The helper
# performs the same guarded cancel-and-reset, so running both would be
# redundant rather than harmful. Confirm that only the helper call remains.
async def _stop(self, frame: EndFrame):
if self._aggregation_task:
await self.cancel_task(self._aggregation_task)
self._aggregation_task = None
await self._cancel_aggregation_task()
# Called with the CancelFrame; cancels the background aggregation task.
# NOTE(review): the inline cancel sequence and the helper call both appear
# here, which looks like a pre-/post-change pair from the diff view. The helper
# performs the same guarded cancel-and-reset. Confirm that only the helper call
# remains.
async def _cancel(self, frame: CancelFrame):
if self._aggregation_task:
await self.cancel_task(self._aggregation_task)
self._aggregation_task = None
await self._cancel_aggregation_task()
async def _handle_user_started_speaking(self, _: UserStartedSpeakingFrame):
    """Record that the user is currently speaking.

    The frame is not inspected; only the speaking flag is raised.
    """
    self._user_speaking = True
async def _handle_user_stopped_speaking(self, _: UserStoppedSpeakingFrame):
    """Record the end of user speech and push the aggregation if possible.

    The wall-clock time of this event is stored so that a later transcription
    can be compared against it. The aggregation is pushed right away unless
    an interim transcription is still pending a final result.
    """
    self._last_user_speaking_time = time.time()
    self._user_speaking = False

    if self._seen_interim_results:
        # A final transcription is still expected; it will drive the push.
        return

    await self.push_aggregation()
async def _handle_transcription(self, frame: TranscriptionFrame):
    """Append a final transcription to the aggregation and wake the task.

    Args:
        frame: The final transcription whose ``text`` is appended.
    """
    self._aggregation += frame.text

    # A final result supersedes any interim results seen so far.
    self._seen_interim_results = False

    # Signal the aggregation task that new text has arrived.
    self._aggregation_event.set()
async def _handle_interim_transcription(self, _: InterimTranscriptionFrame):
    """Note that an interim result arrived and wake the aggregation task.

    The interim text itself is not added to the aggregation.
    """
    # Signal the aggregation task that transcription activity is ongoing.
    self._aggregation_event.set()
    self._seen_interim_results = True
def _create_aggregation_task(self):
self._aggregation_task = self.create_task(self._aggregation_task_handler())
async def _cancel_aggregation_task(self):
    """Cancel the background aggregation task, if one exists, and forget it."""
    task = self._aggregation_task
    if not task:
        return
    await self.cancel_task(task)
    self._aggregation_task = None
# Background loop that decides when the accumulated user text is pushed.
# NOTE(review): this listing appears to interleave the pre-change loop body
# (event wait followed by a fixed sleep) with its replacement (event wait
# bounded by a timeout). Confirm which lines are live before editing.
async def _aggregation_task_handler(self):
while True:
# Sleep-based form: block until a transcription handler sets the event, wait
# the full aggregation timeout, then push unless the user is speaking.
await self._aggregation_event.wait()
await asyncio.sleep(self._aggregation_timeout)
if not self._user_speaking:
await self.push_aggregation()
self._aggregation_event.clear()
# Timeout-based form: wait for the event for at most the aggregation timeout.
try:
await asyncio.wait_for(self._aggregation_event.wait(), self._aggregation_timeout)
# The event fired in time, i.e. a transcription handler ran; check whether
# the bot should be interrupted.
await self._maybe_push_bot_interruption()
except asyncio.TimeoutError:
# No transcription activity for the whole timeout: push what we have,
# unless the user is still speaking.
if not self._user_speaking:
await self.push_aggregation()
finally:
# Always re-arm the event so the next iteration waits for new activity.
self._aggregation_event.clear()
async def _maybe_push_bot_interruption(self):
    """Interrupt the bot when transcription activity arrives long after speech ended.

    Nothing happens while the user is still speaking. Otherwise the time since
    the last "user stopped speaking" event is measured, and if it exceeds the
    bot interruption timeout a ``BotInterruptionFrame`` is pushed upstream.
    """
    if self._user_speaking:
        return

    elapsed = time.time() - self._last_user_speaking_time
    if elapsed > self._bot_interruption_timeout:
        await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
class LLMAssistantContextAggregator(LLMContextResponseAggregator):
@@ -279,6 +309,12 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
await self._handle_llm_end(frame)
elif isinstance(frame, TextFrame):
await self._handle_text(frame)
elif isinstance(frame, LLMMessagesAppendFrame):
self.add_messages(frame.messages)
elif isinstance(frame, LLMMessagesUpdateFrame):
self.set_messages(frame.messages)
elif isinstance(frame, LLMSetToolsFrame):
self.set_tools(frame.tools)
else:
await self.push_frame(frame, direction)

View File

@@ -7,6 +7,7 @@
import unittest
from pipecat.frames.frames import (
BotInterruptionFrame,
InterimTranscriptionFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
@@ -27,6 +28,8 @@ from pipecat.tests.utils import SleepFrame, run_test
# Aggregation timeout, in seconds, passed to the aggregator under test.
AGGREGATION_TIMEOUT = 0.1
# Sleep used after sending frames; longer than AGGREGATION_TIMEOUT so the
# aggregation timeout can expire.
AGGREGATION_SLEEP = 0.15
# Bot interruption timeout, in seconds, passed to the aggregator under test.
BOT_INTERRUPTION_TIMEOUT = 0.2
# Sleep inserted after the user stops speaking; longer than
# BOT_INTERRUPTION_TIMEOUT so a later transcription counts as "late".
BOT_INTERRUPTION_SLEEP = 0.25
class TestLLMUserContextAggreagator(unittest.IsolatedAsyncioTestCase):
@@ -274,6 +277,72 @@ class TestLLMUserContextAggreagator(unittest.IsolatedAsyncioTestCase):
)
assert received_down[-1].context.messages[0]["content"] == "Hello Pipecat! How are you?"
async def test_t(self):
    """A lone final transcription is aggregated and interrupts the bot.

    No speaking frames are sent. The test expects one context frame
    downstream carrying "Hello!" and one ``BotInterruptionFrame`` upstream.
    """
    aggregator = LLMUserContextAggregator(
        OpenAILLMContext(), aggregation_timeout=AGGREGATION_TIMEOUT
    )

    received_down, _ = await run_test(
        aggregator,
        frames_to_send=[
            TranscriptionFrame(text="Hello!", user_id="cat", timestamp=""),
            SleepFrame(sleep=AGGREGATION_SLEEP),
        ],
        expected_down_frames=[OpenAILLMContextFrame],
        expected_up_frames=[BotInterruptionFrame],
    )

    first_message = received_down[-1].context.messages[0]
    assert first_message["content"] == "Hello!"
async def test_it(self):
    """An interim then a final transcription yield only the final text.

    No speaking frames are sent. The test expects one context frame
    downstream carrying "Hello Pipecat!" and one ``BotInterruptionFrame``
    upstream.
    """
    aggregator = LLMUserContextAggregator(
        OpenAILLMContext(), aggregation_timeout=AGGREGATION_TIMEOUT
    )

    received_down, _ = await run_test(
        aggregator,
        frames_to_send=[
            InterimTranscriptionFrame(text="Hello ", user_id="cat", timestamp=""),
            TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""),
            SleepFrame(sleep=AGGREGATION_SLEEP),
        ],
        expected_down_frames=[OpenAILLMContextFrame],
        expected_up_frames=[BotInterruptionFrame],
    )

    first_message = received_down[-1].context.messages[0]
    assert first_message["content"] == "Hello Pipecat!"
async def test_sie_delay_it(self):
    """Transcriptions arriving late after the user stopped speaking interrupt the bot.

    Sequence: started speaking, interim, stopped speaking, a pause longer
    than the bot interruption timeout, then an interim and a final
    transcription. The speaking frames are expected downstream ahead of the
    context frame, and a ``BotInterruptionFrame`` is expected upstream.
    """
    aggregator = LLMUserContextAggregator(
        OpenAILLMContext(),
        aggregation_timeout=AGGREGATION_TIMEOUT,
        bot_interruption_timeout=BOT_INTERRUPTION_TIMEOUT,
    )

    received_down, _ = await run_test(
        aggregator,
        frames_to_send=[
            UserStartedSpeakingFrame(),
            InterimTranscriptionFrame(text="How ", user_id="cat", timestamp=""),
            SleepFrame(),
            UserStoppedSpeakingFrame(),
            # Let more than the bot interruption timeout pass after speech ends.
            SleepFrame(BOT_INTERRUPTION_SLEEP),
            InterimTranscriptionFrame(text="are you?", user_id="cat", timestamp=""),
            TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""),
            SleepFrame(sleep=AGGREGATION_SLEEP),
        ],
        expected_down_frames=[
            UserStartedSpeakingFrame,
            UserStoppedSpeakingFrame,
            OpenAILLMContextFrame,
        ],
        expected_up_frames=[BotInterruptionFrame],
    )

    first_message = received_down[-1].context.messages[0]
    assert first_message["content"] == "How are you?"
class TestLLMAssistantContextAggreagator(unittest.IsolatedAsyncioTestCase):
async def test_empty(self):