diff --git a/src/pipecat/processors/aggregators/llm_response.py b/src/pipecat/processors/aggregators/llm_response.py
index 0422524f8..0ed6f2f0e 100644
--- a/src/pipecat/processors/aggregators/llm_response.py
+++ b/src/pipecat/processors/aggregators/llm_response.py
@@ -5,10 +5,12 @@
 #
 
 import asyncio
+import time
 from abc import abstractmethod
 from typing import List
 
 from pipecat.frames.frames import (
+    BotInterruptionFrame,
     CancelFrame,
     EndFrame,
     Frame,
@@ -171,12 +173,22 @@ class LLMContextResponseAggregator(BaseLLMResponseAggregator):
 
 
 class LLMUserContextAggregator(LLMContextResponseAggregator):
-    def __init__(self, context: OpenAILLMContext, aggregation_timeout: float = 1.0, **kwargs):
+    def __init__(
+        self,
+        context: OpenAILLMContext,
+        aggregation_timeout: float = 1.0,
+        bot_interruption_timeout: float = 2.0,
+        **kwargs,
+    ):
         super().__init__(context=context, role="user", **kwargs)
         self._aggregation_timeout = aggregation_timeout
+        self._bot_interruption_timeout = bot_interruption_timeout
 
         self._seen_interim_results = False
         self._user_speaking = False
+        # time.monotonic() reading taken when the user last stopped speaking;
+        # -inf until that happens, so the elapsed time exceeds any timeout.
+        self._last_user_speaking_time = float("-inf")
 
         self._aggregation_event = asyncio.Event()
         self._aggregation_task = None
@@ -219,43 +231,72 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
             await self.push_frame(frame, direction)
 
     async def _start(self, frame: StartFrame):
-        self._aggregation_task = self.create_task(self._aggregation_task_handler())
+        self._create_aggregation_task()
 
     async def _stop(self, frame: EndFrame):
-        if self._aggregation_task:
-            await self.cancel_task(self._aggregation_task)
-            self._aggregation_task = None
+        await self._cancel_aggregation_task()
 
     async def _cancel(self, frame: CancelFrame):
-        if self._aggregation_task:
-            await self.cancel_task(self._aggregation_task)
-            self._aggregation_task = None
+        await self._cancel_aggregation_task()
 
     async def _handle_user_started_speaking(self, _: UserStartedSpeakingFrame):
         self._user_speaking = True
 
     async def _handle_user_stopped_speaking(self, _: UserStoppedSpeakingFrame):
+        # Monotonic clock: elapsed-time checks must not be skewed by wall-clock changes.
+        self._last_user_speaking_time = time.monotonic()
         self._user_speaking = False
         if not self._seen_interim_results:
             await self.push_aggregation()
 
     async def _handle_transcription(self, frame: TranscriptionFrame):
         self._aggregation += frame.text
-        # We just got our final result, so let's reset interim results.
+        # We just got a final result, so let's reset interim results.
         self._seen_interim_results = False
-        # Wakeup our task.
+        # Reset aggregation timer.
         self._aggregation_event.set()
 
     async def _handle_interim_transcription(self, _: InterimTranscriptionFrame):
         self._seen_interim_results = True
+        # Reset aggregation timer.
+        self._aggregation_event.set()
+
+    def _create_aggregation_task(self):
+        """Start the background task that runs the aggregation timer loop."""
+        self._aggregation_task = self.create_task(self._aggregation_task_handler())
+
+    async def _cancel_aggregation_task(self):
+        """Cancel the aggregation task, if one is running, and forget it."""
+        if self._aggregation_task:
+            await self.cancel_task(self._aggregation_task)
+            self._aggregation_task = None
 
     async def _aggregation_task_handler(self):
         while True:
-            await self._aggregation_event.wait()
-            await asyncio.sleep(self._aggregation_timeout)
-            if not self._user_speaking:
-                await self.push_aggregation()
-            self._aggregation_event.clear()
+            try:
+                # Woken by a transcription before the timeout: check whether
+                # the bot should be interrupted, then restart the timer.
+                await asyncio.wait_for(self._aggregation_event.wait(), self._aggregation_timeout)
+                await self._maybe_push_bot_interruption()
+            except asyncio.TimeoutError:
+                # No transcription for aggregation_timeout seconds: push the
+                # aggregation unless the user is still speaking.
+                if not self._user_speaking:
+                    await self.push_aggregation()
+            finally:
+                self._aggregation_event.clear()
+
+    async def _maybe_push_bot_interruption(self):
+        """Push a BotInterruptionFrame upstream if the user has been quiet.
+
+        Called when a transcription (interim or final) wakes the aggregation
+        task. If the user is not speaking and stopped more than
+        ``bot_interruption_timeout`` seconds ago, the bot is interrupted.
+        """
+        if not self._user_speaking:
+            diff_time = time.monotonic() - self._last_user_speaking_time
+            if diff_time > self._bot_interruption_timeout:
+                await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
 
 
 class LLMAssistantContextAggregator(LLMContextResponseAggregator):
@@ -279,6 +320,12 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
             await self._handle_llm_end(frame)
         elif isinstance(frame, TextFrame):
             await self._handle_text(frame)
+        elif isinstance(frame, LLMMessagesAppendFrame):
+            self.add_messages(frame.messages)
+        elif isinstance(frame, LLMMessagesUpdateFrame):
+            self.set_messages(frame.messages)
+        elif isinstance(frame, LLMSetToolsFrame):
+            self.set_tools(frame.tools)
         else:
             await self.push_frame(frame, direction)
 
diff --git a/tests/test_llm_response.py b/tests/test_llm_response.py
index cb4620633..e8ec92b0d 100644
--- a/tests/test_llm_response.py
+++ b/tests/test_llm_response.py
@@ -7,6 +7,7 @@
 import unittest
 
 from pipecat.frames.frames import (
+    BotInterruptionFrame,
     InterimTranscriptionFrame,
     LLMFullResponseEndFrame,
     LLMFullResponseStartFrame,
@@ -27,6 +28,8 @@ from pipecat.tests.utils import SleepFrame, run_test
 
 AGGREGATION_TIMEOUT = 0.1
 AGGREGATION_SLEEP = 0.15
+BOT_INTERRUPTION_TIMEOUT = 0.2
+BOT_INTERRUPTION_SLEEP = 0.25
 
 
 class TestLLMUserContextAggreagator(unittest.IsolatedAsyncioTestCase):
@@ -274,6 +277,72 @@ class TestLLMUserContextAggreagator(unittest.IsolatedAsyncioTestCase):
         )
         assert received_down[-1].context.messages[0]["content"] == "Hello Pipecat! How are you?"
 
+    async def test_t(self):  # a lone final transcription, no speaking frames
+        context = OpenAILLMContext()
+        aggregator = LLMUserContextAggregator(context, aggregation_timeout=AGGREGATION_TIMEOUT)
+        frames_to_send = [
+            TranscriptionFrame(text="Hello!", user_id="cat", timestamp=""),
+            SleepFrame(sleep=AGGREGATION_SLEEP),  # outlast the aggregation timeout
+        ]
+        expected_down_frames = [OpenAILLMContextFrame]
+        expected_up_frames = [BotInterruptionFrame]  # pushed upstream by the aggregator
+        (received_down, _) = await run_test(
+            aggregator,
+            frames_to_send=frames_to_send,
+            expected_down_frames=expected_down_frames,
+            expected_up_frames=expected_up_frames,
+        )
+        assert received_down[-1].context.messages[0]["content"] == "Hello!"
+
+    async def test_it(self):  # interim then final transcription, no speaking frames
+        context = OpenAILLMContext()
+        aggregator = LLMUserContextAggregator(context, aggregation_timeout=AGGREGATION_TIMEOUT)
+        frames_to_send = [
+            InterimTranscriptionFrame(text="Hello ", user_id="cat", timestamp=""),
+            TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""),
+            SleepFrame(sleep=AGGREGATION_SLEEP),  # outlast the aggregation timeout
+        ]
+        expected_down_frames = [OpenAILLMContextFrame]
+        expected_up_frames = [BotInterruptionFrame]  # pushed upstream by the aggregator
+        (received_down, _) = await run_test(
+            aggregator,
+            frames_to_send=frames_to_send,
+            expected_down_frames=expected_down_frames,
+            expected_up_frames=expected_up_frames,
+        )
+        assert received_down[-1].context.messages[0]["content"] == "Hello Pipecat!"
+
+    async def test_sie_delay_it(self):  # start, interim, stop, long pause, interim, final
+        context = OpenAILLMContext()
+        aggregator = LLMUserContextAggregator(
+            context,
+            aggregation_timeout=AGGREGATION_TIMEOUT,
+            bot_interruption_timeout=BOT_INTERRUPTION_TIMEOUT,
+        )
+        frames_to_send = [
+            UserStartedSpeakingFrame(),
+            InterimTranscriptionFrame(text="How ", user_id="cat", timestamp=""),
+            SleepFrame(),  # default-length pause before the user stops speaking
+            UserStoppedSpeakingFrame(),
+            SleepFrame(sleep=BOT_INTERRUPTION_SLEEP),  # stay quiet past bot_interruption_timeout
+            InterimTranscriptionFrame(text="are you?", user_id="cat", timestamp=""),
+            TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""),
+            SleepFrame(sleep=AGGREGATION_SLEEP),  # outlast the aggregation timeout
+        ]
+        expected_down_frames = [
+            UserStartedSpeakingFrame,
+            UserStoppedSpeakingFrame,
+            OpenAILLMContextFrame,
+        ]
+        expected_up_frames = [BotInterruptionFrame]  # the late transcription interrupts the bot
+        (received_down, _) = await run_test(
+            aggregator,
+            frames_to_send=frames_to_send,
+            expected_down_frames=expected_down_frames,
+            expected_up_frames=expected_up_frames,
+        )
+        assert received_down[-1].context.messages[0]["content"] == "How are you?"
+
 
 class TestLLMAssistantContextAggreagator(unittest.IsolatedAsyncioTestCase):
     async def test_empty(self):