Update LLmUserContextAggregator to conditionally push_aggregation
This commit is contained in:
@@ -12,6 +12,7 @@ from typing import Dict, List, Literal, Optional, Set
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotInterruptionFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
@@ -23,6 +24,7 @@ from pipecat.frames.frames import (
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
InterimTranscriptionFrame,
|
||||
InterruptionConfig,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesAppendFrame,
|
||||
@@ -266,6 +268,8 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
|
||||
self._seen_interim_results = False
|
||||
self._waiting_for_aggregation = False
|
||||
|
||||
self._interruption_config: Optional[InterruptionConfig] = None
|
||||
|
||||
self._aggregation_event = asyncio.Event()
|
||||
self._aggregation_task = None
|
||||
|
||||
@@ -320,20 +324,46 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _process_aggregation(self):
|
||||
"""Process the current aggregation and push it downstream."""
|
||||
aggregation = self._aggregation
|
||||
self.reset()
|
||||
await self.handle_aggregation(aggregation)
|
||||
frame = OpenAILLMContextFrame(self._context)
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def push_aggregation(self):
|
||||
"""Pushes the current aggregation based on interruption configuration and conditions."""
|
||||
if len(self._aggregation) > 0:
|
||||
aggregation = self._aggregation
|
||||
if self._interruption_config and self._bot_speaking:
|
||||
should_interrupt = await self._should_interrupt_based_on_config()
|
||||
|
||||
# Reset the aggregation. Reset it before pushing it down, otherwise
|
||||
# if the tasks gets cancelled we won't be able to clear things up.
|
||||
self.reset()
|
||||
if should_interrupt:
|
||||
logger.debug(
|
||||
"Interruption conditions met - pushing BotInterruptionFrame and aggregation"
|
||||
)
|
||||
await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
|
||||
await self._process_aggregation()
|
||||
else:
|
||||
logger.debug("Interruption conditions not met - not pushing aggregation")
|
||||
# Don't process aggregation, just reset it
|
||||
self.reset()
|
||||
else:
|
||||
# No interruption config - normal behavior (always push aggregation)
|
||||
await self._process_aggregation()
|
||||
|
||||
await self.handle_aggregation(aggregation)
|
||||
async def _should_interrupt_based_on_config(self) -> bool:
|
||||
"""Check if interruption should occur based on configured conditions."""
|
||||
assert self._interruption_config is not None
|
||||
|
||||
frame = OpenAILLMContextFrame(self._context)
|
||||
await self.push_frame(frame)
|
||||
if not self._aggregation or self._interruption_config.min_words is None:
|
||||
return False
|
||||
|
||||
word_count = len(self._aggregation.split())
|
||||
return word_count >= self._interruption_config.min_words
|
||||
|
||||
async def _start(self, frame: StartFrame):
|
||||
self._interruption_config = frame.interruption_config
|
||||
self._create_aggregation_task()
|
||||
|
||||
async def _stop(self, frame: EndFrame):
|
||||
|
||||
Reference in New Issue
Block a user