LLMUserContextAggregator: interrupt the bot if VAD happened a while back
This commit is contained in:
@@ -5,10 +5,12 @@
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from typing import List
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotInterruptionFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
@@ -171,12 +173,20 @@ class LLMContextResponseAggregator(BaseLLMResponseAggregator):
|
||||
|
||||
|
||||
class LLMUserContextAggregator(LLMContextResponseAggregator):
|
||||
def __init__(self, context: OpenAILLMContext, aggregation_timeout: float = 1.0, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
context: OpenAILLMContext,
|
||||
aggregation_timeout: float = 1.0,
|
||||
bot_interruption_timeout: float = 2.0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(context=context, role="user", **kwargs)
|
||||
self._aggregation_timeout = aggregation_timeout
|
||||
self._bot_interruption_timeout = bot_interruption_timeout
|
||||
|
||||
self._seen_interim_results = False
|
||||
self._user_speaking = False
|
||||
self._last_user_speaking_time = 0
|
||||
|
||||
self._aggregation_event = asyncio.Event()
|
||||
self._aggregation_task = None
|
||||
@@ -219,43 +229,63 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _start(self, frame: StartFrame):
|
||||
self._aggregation_task = self.create_task(self._aggregation_task_handler())
|
||||
self._create_aggregation_task()
|
||||
|
||||
async def _stop(self, frame: EndFrame):
|
||||
if self._aggregation_task:
|
||||
await self.cancel_task(self._aggregation_task)
|
||||
self._aggregation_task = None
|
||||
await self._cancel_aggregation_task()
|
||||
|
||||
async def _cancel(self, frame: CancelFrame):
|
||||
if self._aggregation_task:
|
||||
await self.cancel_task(self._aggregation_task)
|
||||
self._aggregation_task = None
|
||||
await self._cancel_aggregation_task()
|
||||
|
||||
async def _handle_user_started_speaking(self, _: UserStartedSpeakingFrame):
|
||||
self._user_speaking = True
|
||||
|
||||
async def _handle_user_stopped_speaking(self, _: UserStoppedSpeakingFrame):
|
||||
self._last_user_speaking_time = time.time()
|
||||
self._user_speaking = False
|
||||
if not self._seen_interim_results:
|
||||
await self.push_aggregation()
|
||||
|
||||
async def _handle_transcription(self, frame: TranscriptionFrame):
|
||||
self._aggregation += frame.text
|
||||
# We just got our final result, so let's reset interim results.
|
||||
# We just got a final result, so let's reset interim results.
|
||||
self._seen_interim_results = False
|
||||
# Wakeup our task.
|
||||
# Reset aggregation timer.
|
||||
self._aggregation_event.set()
|
||||
|
||||
async def _handle_interim_transcription(self, _: InterimTranscriptionFrame):
|
||||
self._seen_interim_results = True
|
||||
# Reset aggregation timer.
|
||||
self._aggregation_event.set()
|
||||
|
||||
def _create_aggregation_task(self):
|
||||
self._aggregation_task = self.create_task(self._aggregation_task_handler())
|
||||
|
||||
async def _cancel_aggregation_task(self):
|
||||
if self._aggregation_task:
|
||||
await self.cancel_task(self._aggregation_task)
|
||||
self._aggregation_task = None
|
||||
|
||||
async def _aggregation_task_handler(self):
|
||||
while True:
|
||||
await self._aggregation_event.wait()
|
||||
await asyncio.sleep(self._aggregation_timeout)
|
||||
if not self._user_speaking:
|
||||
await self.push_aggregation()
|
||||
self._aggregation_event.clear()
|
||||
try:
|
||||
await asyncio.wait_for(self._aggregation_event.wait(), self._aggregation_timeout)
|
||||
await self._maybe_push_bot_interruption()
|
||||
except asyncio.TimeoutError:
|
||||
if not self._user_speaking:
|
||||
await self.push_aggregation()
|
||||
finally:
|
||||
self._aggregation_event.clear()
|
||||
|
||||
async def _maybe_push_bot_interruption(self):
|
||||
"""If the user stopped speaking a while back and we got a transcription
|
||||
frame we might want to interrupt the bot.
|
||||
|
||||
"""
|
||||
if not self._user_speaking:
|
||||
diff_time = time.time() - self._last_user_speaking_time
|
||||
if diff_time > self._bot_interruption_timeout:
|
||||
await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
|
||||
|
||||
|
||||
class LLMAssistantContextAggregator(LLMContextResponseAggregator):
|
||||
@@ -279,6 +309,12 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
|
||||
await self._handle_llm_end(frame)
|
||||
elif isinstance(frame, TextFrame):
|
||||
await self._handle_text(frame)
|
||||
elif isinstance(frame, LLMMessagesAppendFrame):
|
||||
self.add_messages(frame.messages)
|
||||
elif isinstance(frame, LLMMessagesUpdateFrame):
|
||||
self.set_messages(frame.messages)
|
||||
elif isinstance(frame, LLMSetToolsFrame):
|
||||
self.set_tools(frame.tools)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
import unittest
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotInterruptionFrame,
|
||||
InterimTranscriptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
@@ -27,6 +28,8 @@ from pipecat.tests.utils import SleepFrame, run_test
|
||||
|
||||
AGGREGATION_TIMEOUT = 0.1
|
||||
AGGREGATION_SLEEP = 0.15
|
||||
BOT_INTERRUPTION_TIMEOUT = 0.2
|
||||
BOT_INTERRUPTION_SLEEP = 0.25
|
||||
|
||||
|
||||
class TestLLMUserContextAggreagator(unittest.IsolatedAsyncioTestCase):
|
||||
@@ -274,6 +277,72 @@ class TestLLMUserContextAggreagator(unittest.IsolatedAsyncioTestCase):
|
||||
)
|
||||
assert received_down[-1].context.messages[0]["content"] == "Hello Pipecat! How are you?"
|
||||
|
||||
async def test_t(self):
|
||||
context = OpenAILLMContext()
|
||||
aggregator = LLMUserContextAggregator(context, aggregation_timeout=AGGREGATION_TIMEOUT)
|
||||
frames_to_send = [
|
||||
TranscriptionFrame(text="Hello!", user_id="cat", timestamp=""),
|
||||
SleepFrame(sleep=AGGREGATION_SLEEP),
|
||||
]
|
||||
expected_down_frames = [OpenAILLMContextFrame]
|
||||
expected_up_frames = [BotInterruptionFrame]
|
||||
(received_down, _) = await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
expected_up_frames=expected_up_frames,
|
||||
)
|
||||
assert received_down[-1].context.messages[0]["content"] == "Hello!"
|
||||
|
||||
async def test_it(self):
|
||||
context = OpenAILLMContext()
|
||||
aggregator = LLMUserContextAggregator(context, aggregation_timeout=AGGREGATION_TIMEOUT)
|
||||
frames_to_send = [
|
||||
InterimTranscriptionFrame(text="Hello ", user_id="cat", timestamp=""),
|
||||
TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""),
|
||||
SleepFrame(sleep=AGGREGATION_SLEEP),
|
||||
]
|
||||
expected_down_frames = [OpenAILLMContextFrame]
|
||||
expected_up_frames = [BotInterruptionFrame]
|
||||
(received_down, _) = await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
expected_up_frames=expected_up_frames,
|
||||
)
|
||||
assert received_down[-1].context.messages[0]["content"] == "Hello Pipecat!"
|
||||
|
||||
async def test_sie_delay_it(self):
|
||||
context = OpenAILLMContext()
|
||||
aggregator = LLMUserContextAggregator(
|
||||
context,
|
||||
aggregation_timeout=AGGREGATION_TIMEOUT,
|
||||
bot_interruption_timeout=BOT_INTERRUPTION_TIMEOUT,
|
||||
)
|
||||
frames_to_send = [
|
||||
UserStartedSpeakingFrame(),
|
||||
InterimTranscriptionFrame(text="How ", user_id="cat", timestamp=""),
|
||||
SleepFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
SleepFrame(BOT_INTERRUPTION_SLEEP),
|
||||
InterimTranscriptionFrame(text="are you?", user_id="cat", timestamp=""),
|
||||
TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""),
|
||||
SleepFrame(sleep=AGGREGATION_SLEEP),
|
||||
]
|
||||
expected_down_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
OpenAILLMContextFrame,
|
||||
]
|
||||
expected_up_frames = [BotInterruptionFrame]
|
||||
(received_down, _) = await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
expected_up_frames=expected_up_frames,
|
||||
)
|
||||
assert received_down[-1].context.messages[0]["content"] == "How are you?"
|
||||
|
||||
|
||||
class TestLLMAssistantContextAggreagator(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_empty(self):
|
||||
|
||||
Reference in New Issue
Block a user