diff --git a/src/pipecat/frames/frames.py b/src/pipecat/frames/frames.py
index d1d3806d5..afd057029 100644
--- a/src/pipecat/frames/frames.py
+++ b/src/pipecat/frames/frames.py
@@ -36,6 +36,7 @@ from pipecat.utils.time import nanoseconds_to_str
 from pipecat.utils.utils import obj_count, obj_id
 
 if TYPE_CHECKING:
+    from pipecat.processors.aggregators.llm_context import LLMContext
     from pipecat.processors.frame_processor import FrameProcessor
 
 
@@ -474,6 +475,20 @@ class TranscriptionUpdateFrame(DataFrame):
         return f"{self.name}(pts: {pts}, messages: {len(self.messages)})"
 
 
+@dataclass
+class LLMContextFrame(Frame):
+    """Frame containing a universal LLM context.
+
+    Used as a signal to LLM services to ingest the provided context and
+    generate a response based on it.
+
+    Parameters:
+        context: The universal LLMContext (messages, tools, and tool choice).
+    """
+
+    context: "LLMContext"
+
+
 @dataclass
 class LLMMessagesFrame(DataFrame):
     """Frame containing LLM messages for chat completion.
diff --git a/src/pipecat/processors/aggregators/llm_context.py b/src/pipecat/processors/aggregators/llm_context.py
new file mode 100644
index 000000000..3208d38f7
--- /dev/null
+++ b/src/pipecat/processors/aggregators/llm_context.py
@@ -0,0 +1,197 @@
+#
+# Copyright (c) 2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+"""Universal LLM context management for LLM services in Pipecat.
+
+Context contents are represented in a universal format (based on OpenAI)
+that supports a union of known Pipecat LLM service functionality.
+
+Whenever an LLM service needs to access context, it does a just-in-time
+translation from this universal context into whatever format it needs, using a
+service-specific adapter.
+"""
+
+import base64
+import io
+from dataclasses import dataclass
+from typing import Any, List, Optional
+
+from openai._types import NOT_GIVEN as OPEN_AI_NOT_GIVEN
+from openai._types import NotGiven as OpenAINotGiven
+from openai.types.chat import (
+    ChatCompletionMessageParam,
+    ChatCompletionToolChoiceOptionParam,
+    ChatCompletionToolParam,
+)
+from PIL import Image
+
+from pipecat.adapters.schemas.tools_schema import ToolsSchema
+from pipecat.frames.frames import AudioRawFrame, Frame
+
+# "Re-export" types from OpenAI that we're using as universal context types.
+# NOTE: this is just for convenience, for now. As soon as the universal types
+# diverge from OpenAI's, we should ditch this. In fact, audio frames already
+# diverge from OpenAI's standard format...we really ought to do this.
+LLMContextMessage = ChatCompletionMessageParam
+LLMContextTool = ChatCompletionToolParam
+LLMContextToolChoice = ChatCompletionToolChoiceOptionParam
+NOT_GIVEN = OPEN_AI_NOT_GIVEN
+NotGiven = OpenAINotGiven
+
+
+class LLMContext:
+    """Manages conversation context for LLM interactions.
+
+    Handles message history, tool definitions, tool choices, and multimedia
+    content for LLM conversations. Provides methods for message manipulation
+    and content formatting.
+    """
+
+    def __init__(
+        self,
+        messages: Optional[List[LLMContextMessage]] = None,
+        tools: List[LLMContextTool] | NotGiven | ToolsSchema = NOT_GIVEN,
+        tool_choice: LLMContextToolChoice | NotGiven = NOT_GIVEN,
+    ):
+        """Initialize the LLM context.
+
+        Args:
+            messages: Initial list of conversation messages. The list is stored
+                by reference (not copied); a new empty list is created when None.
+            tools: Available tools for the LLM to use. An empty list is stored as
+                NOT_GIVEN, exactly as `set_tools()` does.
+            tool_choice: Tool selection strategy for the LLM.
+        """
+        # Compare against None rather than truthiness so that a caller-provided
+        # empty list is kept by reference just like a non-empty one.
+        self._messages: List[LLMContextMessage] = messages if messages is not None else []
+        self._tools: List[LLMContextTool] | NotGiven | ToolsSchema = NOT_GIVEN
+        self._tool_choice: LLMContextToolChoice | NotGiven = tool_choice
+        # Go through the setter so construction and set_tools() normalize alike.
+        self.set_tools(tools)
+
+    @property
+    def messages(self) -> List[LLMContextMessage]:
+        """Get the current messages list.
+
+        Returns:
+            The internal list of conversation messages (not a copy).
+        """
+        return self._messages
+
+    @property
+    def tools(self) -> List[LLMContextTool] | NotGiven | ToolsSchema:
+        """Get the tools currently set on the context.
+
+        Returns:
+            A list of tools, a ToolsSchema, or NOT_GIVEN when no tools are set.
+        """
+        return self._tools
+
+    @property
+    def tool_choice(self) -> LLMContextToolChoice | NotGiven:
+        """Get the current tool choice setting.
+
+        Returns:
+            The tool choice configuration.
+        """
+        return self._tool_choice
+
+    def add_message(self, message: LLMContextMessage):
+        """Add a single message to the context.
+
+        Args:
+            message: The message to add to the conversation history.
+        """
+        self._messages.append(message)
+
+    def add_messages(self, messages: List[LLMContextMessage]):
+        """Add multiple messages to the context.
+
+        Args:
+            messages: List of messages to add to the conversation history.
+        """
+        self._messages.extend(messages)
+
+    def set_messages(self, messages: List[LLMContextMessage]):
+        """Replace all messages in the context.
+
+        Args:
+            messages: New list of messages to replace the current history.
+        """
+        # Slice assignment replaces the contents in place, so anything holding
+        # a reference to the list returned by `messages` sees the new history.
+        self._messages[:] = messages
+
+    def set_tools(self, tools: List[LLMContextTool] | NotGiven | ToolsSchema = NOT_GIVEN):
+        """Set the available tools for the LLM.
+
+        Args:
+            tools: List of tools available to the LLM, a ToolsSchema, or NOT_GIVEN to disable
+                tools. An empty list is stored as NOT_GIVEN.
+        """
+        # TODO: convert empty ToolsSchema to NOT_GIVEN if needed?
+        # TODO: maybe someday also convert provider-specific tools to ToolsSchema so
+        # it's always in a provider-neutral format here? See open_ai_adapter.py for
+        # related comment. Pipecat Flows is currently converting provider-specific
+        # tools to ToolsSchema...
+        if isinstance(tools, list) and not tools:
+            tools = NOT_GIVEN
+        self._tools = tools
+
+    def set_tool_choice(self, tool_choice: LLMContextToolChoice | NotGiven):
+        """Set the tool choice configuration.
+
+        Args:
+            tool_choice: Tool selection strategy for the LLM.
+        """
+        self._tool_choice = tool_choice
+
+    def add_image_frame_message(
+        self, *, format: str, size: tuple[int, int], image: bytes, text: Optional[str] = None
+    ):
+        """Add a user message containing an image frame.
+
+        The raw pixels are encoded as JPEG and embedded in the message as a
+        base64 `data:image/jpeg` URL.
+
+        Args:
+            format: Pixel mode of the raw bytes, passed to Pillow (e.g., 'RGB', 'RGBA').
+            size: Image dimensions as (width, height) tuple.
+            image: Raw image bytes.
+            text: Optional text to include with the image; omitted when empty or None.
+        """
+        pil_image = Image.frombytes(format, size, image)
+        # JPEG has neither an alpha channel nor a palette, and Pillow refuses to
+        # save such modes (including the 'RGBA' input documented above), so
+        # flatten them to RGB before encoding.
+        if pil_image.mode in ("RGBA", "LA", "PA", "P"):
+            pil_image = pil_image.convert("RGB")
+
+        buffer = io.BytesIO()
+        pil_image.save(buffer, format="JPEG")
+        # TODO: we might not want the universal format to be base64 encoded, since
+        # encoding is not needed by all LLM services; today, the Gemini adapter has
+        # to decode from base64, which is less than ideal.
+        encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
+
+        content = []
+        if text:
+            content.append({"type": "text", "text": text})
+        content.append(
+            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}},
+        )
+        self.add_message({"role": "user", "content": content})
+
+    # NOTE: today we've only built support for audio frames with the Google
+    # LLM, so this "universal" representation skews towards that.
+    # When we add support for other LLMs, we may need to adjust this.
+    def add_audio_frames_message(
+        self, *, audio_frames: list[AudioRawFrame], text: str = "Audio follows"
+    ):
+        """Add a user message containing audio frames.
+
+        Does nothing when `audio_frames` is empty. The sample rate and channel
+        count are taken from the first frame; the audio of all frames is
+        concatenated and stored as raw bytes (it is not base64-encoded here).
+
+        Args:
+            audio_frames: List of audio frame objects to include.
+            text: Text part placed before the audio (defaults to "Audio follows").
+        """
+        if not audio_frames:
+            return
+
+        sample_rate = audio_frames[0].sample_rate
+        num_channels = audio_frames[0].num_channels
+
+        content = []
+        content.append({"type": "text", "text": text})
+        data = b"".join(frame.audio for frame in audio_frames)
+        # TODO: filter this out in OpenAI adapter, since it doesn't support audio frames
+        content.append(
+            {
+                "type": "input_audio",
+                "input_audio": {
+                    "data": data,
+                    "sample_rate": sample_rate,
+                    "num_channels": num_channels,
+                },
+            }
+        )
+        self.add_message({"role": "user", "content": content})
diff --git a/src/pipecat/processors/aggregators/llm_response_universal.py b/src/pipecat/processors/aggregators/llm_response_universal.py
new file mode 100644
index 000000000..177a1c075
--- /dev/null
+++ b/src/pipecat/processors/aggregators/llm_response_universal.py
@@ -0,0 +1,850 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+"""LLM response aggregators for handling conversation context and message aggregation.
+
+This module provides aggregators that process and accumulate LLM responses, user inputs,
+and conversation context. These aggregators handle the flow between speech-to-text,
+LLM processing, and text-to-speech components in conversational AI pipelines.
+"""
+
+import asyncio
+import json
+from dataclasses import dataclass
+from typing import Any, Dict, List, Literal, Optional, Set
+
+from loguru import logger
+
+from pipecat.audio.interruptions.base_interruption_strategy import BaseInterruptionStrategy
+from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
+from pipecat.audio.vad.vad_analyzer import VADParams
+from pipecat.frames.frames import (
+    BotInterruptionFrame,
+    BotStartedSpeakingFrame,
+    BotStoppedSpeakingFrame,
+    CancelFrame,
+    EmulateUserStartedSpeakingFrame,
+    EmulateUserStoppedSpeakingFrame,
+    EndFrame,
+    Frame,
+    FunctionCallCancelFrame,
+    FunctionCallInProgressFrame,
+    FunctionCallResultFrame,
+    FunctionCallsStartedFrame,
+    InputAudioRawFrame,
+    InterimTranscriptionFrame,
+    LLMContextAssistantTimestampFrame,
+    LLMContextFrame,
+    LLMFullResponseEndFrame,
+    LLMFullResponseStartFrame,
+    LLMMessagesAppendFrame,
+    LLMMessagesUpdateFrame,
+    LLMSetToolChoiceFrame,
+    LLMSetToolsFrame,
+    SpeechControlParamsFrame,
+    StartFrame,
+    StartInterruptionFrame,
+    TextFrame,
+    TranscriptionFrame,
+    UserImageRawFrame,
+    UserStartedSpeakingFrame,
+    UserStoppedSpeakingFrame,
+)
+from pipecat.processors.aggregators.llm_context import LLMContext
+from pipecat.processors.aggregators.llm_response import (
+    LLMAssistantAggregatorParams,
+    LLMUserAggregatorParams,
+)
+from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
+from pipecat.utils.time import time_now_iso8601
+
+
+class LLMContextAggregator(FrameProcessor):
+    """Base LLM aggregator that uses an LLMContext for conversation storage.
+
+    This aggregator maintains conversation state using an LLMContext and
+    pushes LLMContextFrame objects as aggregation frames. It provides
+    common functionality for context-based conversation management: thin
+    helpers that forward message/tool updates to the context, plus a text
+    aggregation buffer that subclasses fill and `reset()` clears.
+    """
+
+    def __init__(self, *, context: LLMContext, role: str, **kwargs):
+        """Initialize the context response aggregator.
+
+        Args:
+            context: The LLM context to use for conversation storage.
+            role: The role this aggregator represents (e.g. "user", "assistant").
+            **kwargs: Additional arguments passed to parent class.
+        """
+        super().__init__(**kwargs)
+        self._context = context
+        self._role = role
+
+        # Text accumulated so far for the message currently being built.
+        self._aggregation: str = ""
+
+    @property
+    def messages(self) -> List[dict]:
+        """Get messages from the LLM context.
+
+        Returns:
+            The message list exposed by the context's `messages` property.
+        """
+        return self._context.messages
+
+    @property
+    def role(self) -> str:
+        """Get the role for this aggregator.
+
+        Returns:
+            The role string given at construction time.
+        """
+        return self._role
+
+    @property
+    def context(self):
+        """Get the LLM context.
+
+        Returns:
+            The LLMContext instance used by this aggregator.
+        """
+        return self._context
+
+    def get_context_frame(self) -> LLMContextFrame:
+        """Create a context frame with the current context.
+
+        Returns:
+            A new LLMContextFrame referencing this aggregator's context object
+            (the context itself is not copied).
+        """
+        return LLMContextFrame(context=self._context)
+
+    async def push_context_frame(self, direction: FrameDirection = FrameDirection.DOWNSTREAM):
+        """Push a context frame in the specified direction.
+
+        Args:
+            direction: The direction to push the frame (upstream or downstream).
+        """
+        frame = self.get_context_frame()
+        await self.push_frame(frame, direction)
+
+    def add_messages(self, messages):
+        """Add messages to the context.
+
+        Args:
+            messages: Messages to add to the conversation context.
+        """
+        self._context.add_messages(messages)
+
+    def set_messages(self, messages):
+        """Set the context messages.
+
+        Args:
+            messages: Messages to replace the current context messages.
+        """
+        self._context.set_messages(messages)
+
+    def set_tools(self, tools: List):
+        """Set tools in the context.
+
+        Args:
+            tools: List of tool definitions to set in the context.
+        """
+        self._context.set_tools(tools)
+
+    def set_tool_choice(self, tool_choice: Literal["none", "auto", "required"] | dict):
+        """Set tool choice in the context.
+
+        Args:
+            tool_choice: Tool choice configuration for the context.
+        """
+        self._context.set_tool_choice(tool_choice)
+
+    async def reset(self):
+        """Reset the aggregation state.
+
+        Only the text aggregation buffer is cleared; the context is left as is.
+        """
+        self._aggregation = ""
+
+
+# NOTE: the "universal" suffix is just meant to distinguish this aggregator
+# from the old LLMUserContextAggregator while we gradually migrate services to
+# use the new universal LLMContext and associated patterns. The suffix will go
+# away once the migration is complete and the other LLMUserContextAggregator is
+# deprecated.
+class LLMUserAggregator(LLMContextAggregator):
+    """User LLM aggregator that processes speech-to-text transcriptions.
+
+    This aggregator handles the complex logic of aggregating user speech transcriptions
+    from STT services. It manages multiple scenarios including:
+
+    - Transcriptions received between VAD events
+    - Transcriptions received outside VAD events
+    - Interim vs final transcriptions
+    - User interruptions during bot speech
+    - Emulated VAD for whispered or short utterances
+
+    The aggregator uses timeouts to handle cases where transcriptions arrive
+    after VAD events or when no VAD is available.
+    """
+
+    def __init__(
+        self,
+        context: LLMContext,
+        *,
+        params: Optional[LLMUserAggregatorParams] = None,
+        **kwargs,
+    ):
+        """Initialize the user context aggregator.
+
+        Args:
+            context: The LLM context for conversation storage.
+            params: Configuration parameters for aggregation behavior.
+            **kwargs: Additional arguments. Supports deprecated 'aggregation_timeout'.
+        """
+        super().__init__(context=context, role="user", **kwargs)
+        self._params = params or LLMUserAggregatorParams()
+        self._vad_params: Optional[VADParams] = None
+        self._turn_params: Optional[SmartTurnParams] = None
+
+        if "aggregation_timeout" in kwargs:
+            import warnings
+
+            with warnings.catch_warnings():
+                warnings.simplefilter("always")
+                # stacklevel=2 attributes the warning to the code constructing
+                # the aggregator instead of to this module.
+                warnings.warn(
+                    "Parameter 'aggregation_timeout' is deprecated, use 'params' instead.",
+                    DeprecationWarning,
+                    stacklevel=2,
+                )
+
+            self._params.aggregation_timeout = kwargs["aggregation_timeout"]
+
+        self._user_speaking = False
+        self._bot_speaking = False
+        self._was_bot_speaking = False
+        self._emulating_vad = False
+        self._seen_interim_results = False
+        self._waiting_for_aggregation = False
+
+        self._aggregation_event = asyncio.Event()
+        self._aggregation_task = None
+
+    async def reset(self):
+        """Reset the aggregation state and interruption strategies."""
+        await super().reset()
+        self._was_bot_speaking = False
+        self._seen_interim_results = False
+        self._waiting_for_aggregation = False
+        for strategy in self._interruption_strategies:
+            await strategy.reset()
+
+    async def process_frame(self, frame: Frame, direction: FrameDirection):
+        """Process frames for user speech aggregation and context management.
+
+        Transcription frames and LLM message/tool update frames are consumed
+        here; every other frame is pushed on in its original direction.
+
+        Args:
+            frame: The frame to process.
+            direction: The direction of frame flow in the pipeline.
+        """
+        await super().process_frame(frame, direction)
+
+        if isinstance(frame, StartFrame):
+            # Push StartFrame before start(), because we want StartFrame to be
+            # processed by every processor before any other frame is processed.
+            await self.push_frame(frame, direction)
+            await self._start(frame)
+        elif isinstance(frame, EndFrame):
+            # Push EndFrame before stop(), because stop() waits on the task to
+            # finish and the task finishes when EndFrame is processed.
+            await self.push_frame(frame, direction)
+            await self._stop(frame)
+        elif isinstance(frame, CancelFrame):
+            await self._cancel(frame)
+            await self.push_frame(frame, direction)
+        elif isinstance(frame, InputAudioRawFrame):
+            await self._handle_input_audio(frame)
+            await self.push_frame(frame, direction)
+        elif isinstance(frame, UserStartedSpeakingFrame):
+            await self._handle_user_started_speaking(frame)
+            await self.push_frame(frame, direction)
+        elif isinstance(frame, UserStoppedSpeakingFrame):
+            await self._handle_user_stopped_speaking(frame)
+            await self.push_frame(frame, direction)
+        elif isinstance(frame, BotStartedSpeakingFrame):
+            await self._handle_bot_started_speaking(frame)
+            await self.push_frame(frame, direction)
+        elif isinstance(frame, BotStoppedSpeakingFrame):
+            await self._handle_bot_stopped_speaking(frame)
+            await self.push_frame(frame, direction)
+        elif isinstance(frame, TranscriptionFrame):
+            await self._handle_transcription(frame)
+        elif isinstance(frame, InterimTranscriptionFrame):
+            await self._handle_interim_transcription(frame)
+        elif isinstance(frame, LLMMessagesAppendFrame):
+            await self._handle_llm_messages_append(frame)
+        elif isinstance(frame, LLMMessagesUpdateFrame):
+            await self._handle_llm_messages_update(frame)
+        elif isinstance(frame, LLMSetToolsFrame):
+            self.set_tools(frame.tools)
+        elif isinstance(frame, LLMSetToolChoiceFrame):
+            self.set_tool_choice(frame.tool_choice)
+        elif isinstance(frame, SpeechControlParamsFrame):
+            self._vad_params = frame.vad_params
+            self._turn_params = frame.turn_params
+            await self.push_frame(frame, direction)
+        else:
+            await self.push_frame(frame, direction)
+
+    async def _process_aggregation(self):
+        """Add the current aggregation to the context and push it downstream."""
+        # Capture the text first: reset() clears the aggregation buffer.
+        aggregation = self._aggregation
+        await self.reset()
+        self._context.add_message({"role": self.role, "content": aggregation})
+        # Same helper the other handlers use to emit the context frame.
+        await self.push_context_frame()
+
+    async def _push_aggregation(self):
+        """Push the current aggregation based on interruption strategies and conditions."""
+        if self._aggregation:
+            if self.interruption_strategies and self._bot_speaking:
+                should_interrupt = await self._should_interrupt_based_on_strategies()
+
+                if should_interrupt:
+                    logger.debug(
+                        "Interruption conditions met - pushing BotInterruptionFrame and aggregation"
+                    )
+                    await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
+                    await self._process_aggregation()
+                else:
+                    logger.debug("Interruption conditions not met - not pushing aggregation")
+                    # Don't process aggregation, just reset it
+                    await self.reset()
+            else:
+                # No interruption config - normal behavior (always push aggregation)
+                await self._process_aggregation()
+        # Handles the case where both the user and the bot are not speaking,
+        # and the bot was previously speaking before the user interruption.
+        # Normally, when the user stops speaking, new text is expected,
+        # which triggers the bot to respond. However, if no new text
+        # is received, this safeguard ensures
+        # the bot doesn't hang indefinitely while waiting to speak again.
+        elif not self._seen_interim_results and self._was_bot_speaking and not self._bot_speaking:
+            logger.warning("User stopped speaking but no new aggregation received.")
+            # Resetting it so we don't trigger this twice
+            self._was_bot_speaking = False
+            # TODO: we are not enabling this for now, due to some STT services which can
+            # take as long as 2 seconds to return a transcription.
+            # So we need more tests and probably make this feature configurable, disabled
+            # by default.
+            # We are just pushing the same previous context to be processed again in this case
+            # await self.push_frame(LLMContextFrame(self._context))
+
+    async def _should_interrupt_based_on_strategies(self) -> bool:
+        """Check if interruption should occur based on configured strategies.
+
+        Returns:
+            True if any interruption strategy indicates interruption should occur.
+        """
+
+        async def should_interrupt(strategy: BaseInterruptionStrategy):
+            await strategy.append_text(self._aggregation)
+            return await strategy.should_interrupt()
+
+        # The list is built in full before any() runs, so every strategy is fed
+        # the text and queried even after one of them has answered True.
+        return any([await should_interrupt(s) for s in self._interruption_strategies])
+
+    async def _start(self, frame: StartFrame):
+        """Create the aggregation task when a StartFrame is processed."""
+        self._create_aggregation_task()
+
+    async def _stop(self, frame: EndFrame):
+        """Cancel the aggregation task when an EndFrame is processed."""
+        await self._cancel_aggregation_task()
+
+    async def _cancel(self, frame: CancelFrame):
+        """Cancel the aggregation task when a CancelFrame is processed."""
+        await self._cancel_aggregation_task()
+
+    async def _handle_llm_messages_append(self, frame: LLMMessagesAppendFrame):
+        """Append the frame's messages; push the context if `run_llm` is set."""
+        self.add_messages(frame.messages)
+        if frame.run_llm:
+            await self.push_context_frame()
+
+    async def _handle_llm_messages_update(self, frame: LLMMessagesUpdateFrame):
+        """Replace the context messages; push the context if `run_llm` is set."""
+        self.set_messages(frame.messages)
+        if frame.run_llm:
+            await self.push_context_frame()
+
+    async def _handle_input_audio(self, frame: InputAudioRawFrame):
+        """Feed incoming audio to every interruption strategy."""
+        for s in self.interruption_strategies:
+            await s.append_audio(frame.audio, frame.sample_rate)
+
+    async def _handle_user_started_speaking(self, frame: UserStartedSpeakingFrame):
+        """Record that the user started speaking and whether the bot was speaking."""
+        self._user_speaking = True
+        self._waiting_for_aggregation = True
+        self._was_bot_speaking = self._bot_speaking
+
+        # If we get a non-emulated UserStartedSpeakingFrame but we are in the
+        # middle of emulating VAD, let's stop emulating VAD (i.e. don't send the
+        # EmulateUserStoppedSpeakingFrame).
+        if not frame.emulated and self._emulating_vad:
+            self._emulating_vad = False
+
+    async def _handle_user_stopped_speaking(self, _: UserStoppedSpeakingFrame):
+        """Push any pending aggregation, or reset the timer, when the user stops speaking."""
+        self._user_speaking = False
+        # We just stopped speaking. Let's see if there's some aggregation to
+        # push. If the last thing we saw is an interim transcription, let's wait
+        # pushing the aggregation as we will probably get a final transcription.
+        if self._aggregation:
+            if not self._seen_interim_results:
+                await self._push_aggregation()
+        # Handles the case where both the user and the bot are not speaking,
+        # and the bot was previously speaking before the user interruption.
+        # So in this case we are resetting the aggregation timer
+        elif not self._seen_interim_results and self._was_bot_speaking and not self._bot_speaking:
+            # Reset aggregation timer.
+            self._aggregation_event.set()
+
+    async def _handle_bot_started_speaking(self, _: BotStartedSpeakingFrame):
+        """Record that the bot is speaking."""
+        self._bot_speaking = True
+
+    async def _handle_bot_stopped_speaking(self, _: BotStoppedSpeakingFrame):
+        """Record that the bot stopped speaking."""
+        self._bot_speaking = False
+
+    async def _handle_transcription(self, frame: TranscriptionFrame):
+        """Append a final transcription to the aggregation and reset the timer."""
+        text = frame.text
+
+        # Make sure we really have some text.
+        if not text.strip():
+            return
+
+        self._aggregation += f" {text}" if self._aggregation else text
+        # We just got a final result, so let's reset interim results.
+        self._seen_interim_results = False
+        # Reset aggregation timer.
+        self._aggregation_event.set()
+
+    async def _handle_interim_transcription(self, _: InterimTranscriptionFrame):
+        """Remember that an interim (non-final) transcription was seen."""
+        self._seen_interim_results = True
+
+    def _create_aggregation_task(self):
+        """Start the aggregation task unless one already exists."""
+        if not self._aggregation_task:
+            self._aggregation_task = self.create_task(self._aggregation_task_handler())
+
+    async def _cancel_aggregation_task(self):
+        """Cancel the aggregation task, if any, and forget it."""
+        if self._aggregation_task:
+            await self.cancel_task(self._aggregation_task)
+            self._aggregation_task = None
+
+    async def _aggregation_task_handler(self):
+        """Wait on the aggregation event in a loop, pushing the aggregation on timeout."""
+        while True:
+            try:
+                # The _aggregation_task_handler handles two distinct timeout scenarios:
+                #
+                # 1. When emulating_vad=True: Wait for emulated VAD timeout before
+                #    pushing aggregation (simulating VAD behavior when no actual VAD
+                #    detection occurred).
+                #
+                # 2. When emulating_vad=False: Use aggregation_timeout as a buffer
+                #    to wait for potential late-arriving transcription frames after
+                #    a real VAD event.
+                #
+                # For emulated VAD scenarios, the timeout strategy depends on whether
+                # a turn analyzer is configured:
+                #
+                # - WITH turn analyzer: Use turn_emulated_vad_timeout parameter because
+                #   the VAD's stop_secs is set very low (e.g. 0.2s) for rapid speech
+                #   chunking to feed the turn analyzer. This low value is too fast
+                #   for emulated VAD scenarios where we need to allow users time to
+                #   finish speaking (e.g. 0.8s).
+                #
+                # - WITHOUT turn analyzer: Use VAD's stop_secs directly to maintain
+                #   consistent user experience between real VAD detection and
+                #   emulated VAD scenarios.
+                if not self._emulating_vad:
+                    timeout = self._params.aggregation_timeout
+                elif self._turn_params:
+                    timeout = self._params.turn_emulated_vad_timeout
+                else:
+                    # Use VAD stop_secs when no turn analyzer is present, fallback if no VAD params
+                    timeout = (
+                        self._vad_params.stop_secs
+                        if self._vad_params
+                        else self._params.turn_emulated_vad_timeout
+                    )
+                await asyncio.wait_for(self._aggregation_event.wait(), timeout)
+                await self._maybe_emulate_user_speaking()
+            except asyncio.TimeoutError:
+                if not self._user_speaking:
+                    await self._push_aggregation()
+
+                # If we are emulating VAD we still need to send the user stopped
+                # speaking frame.
+                if self._emulating_vad:
+                    await self.push_frame(
+                        EmulateUserStoppedSpeakingFrame(), FrameDirection.UPSTREAM
+                    )
+                    self._emulating_vad = False
+            finally:
+                self.reset_watchdog()
+                self._aggregation_event.clear()
+
+    async def _maybe_emulate_user_speaking(self):
+        """Maybe emulate user speaking based on transcription.
+
+        Emulate user speaking if we got a transcription but it was not
+        detected by VAD. Behavior when bot is speaking depends on the
+        enable_emulated_vad_interruptions parameter.
+        """
+        # Check if we received a transcription but VAD was not able to detect
+        # voice (e.g. when you whisper a short utterance). In that case, we need
+        # to emulate VAD (i.e. user start/stopped speaking), but we do it only
+        # if the bot is not speaking. If the bot is speaking and we really have
+        # a short utterance we don't really want to interrupt the bot.
+        if (
+            not self._user_speaking
+            and not self._waiting_for_aggregation
+            and self._aggregation
+        ):
+            if self._bot_speaking and not self._params.enable_emulated_vad_interruptions:
+                # If emulated VAD interruptions are disabled and bot is speaking, ignore
+                logger.debug("Ignoring user speaking emulation, bot is speaking.")
+                await self.reset()
+            else:
+                # Either bot is not speaking, or emulated VAD interruptions are enabled
+                # - trigger user speaking emulation.
+                await self.push_frame(EmulateUserStartedSpeakingFrame(), FrameDirection.UPSTREAM)
+                self._emulating_vad = True
+
+
+# NOTE: the "universal" suffix is just meant to distinguish this aggregator
+# from the old LLMAssistantContextAggregator while we gradually migrate services
+# to use the new universal LLMContext and associated patterns. The suffix will
+# go away once the migration is complete and the other
+# LLMAssistantContextAggregator is deprecated.
+class LLMAssistantAggregator(LLMContextAggregator):
+    """Assistant LLM aggregator that processes bot responses and function calls.
+
+    This aggregator handles the complex logic of processing assistant responses including:
+
+    - Text frame aggregation between response start/end markers
+    - Function call lifecycle management
+    - Context updates with timestamps
+    - Tool execution and result handling
+    - Interruption handling during responses
+
+    The aggregator manages function calls in progress and coordinates between
+    text generation and tool execution phases of LLM responses.
+    """
+
+    def __init__(
+        self,
+        context: LLMContext,
+        *,
+        params: Optional[LLMAssistantAggregatorParams] = None,
+        **kwargs,
+    ):
+        """Initialize the assistant context aggregator.
+
+        Args:
+            context: The LLM context for conversation storage.
+            params: Configuration parameters for aggregation behavior.
+            **kwargs: Additional arguments. Supports deprecated 'expect_stripped_words'.
+        """
+        super().__init__(context=context, role="assistant", **kwargs)
+        self._params = params or LLMAssistantAggregatorParams()
+
+        if "expect_stripped_words" in kwargs:
+            import warnings
+
+            with warnings.catch_warnings():
+                warnings.simplefilter("always")
+                # stacklevel=2 attributes the warning to the code constructing
+                # the aggregator instead of to this module.
+                warnings.warn(
+                    "Parameter 'expect_stripped_words' is deprecated, use 'params' instead.",
+                    DeprecationWarning,
+                    stacklevel=2,
+                )
+
+            self._params.expect_stripped_words = kwargs["expect_stripped_words"]
+
+        self._started = 0
+        # Keyed by tool_call_id. The value is None from FunctionCallsStartedFrame
+        # until the matching FunctionCallInProgressFrame arrives.
+        self._function_calls_in_progress: Dict[str, Optional[FunctionCallInProgressFrame]] = {}
+        self._context_updated_tasks: Set[asyncio.Task] = set()
+
+    @property
+    def has_function_calls_in_progress(self) -> bool:
+        """Check if there are any function calls currently in progress.
+
+        Returns:
+            True if function calls are in progress, False otherwise.
+        """
+        return bool(self._function_calls_in_progress)
+
+    async def process_frame(self, frame: Frame, direction: FrameDirection):
+        """Process frames for assistant response aggregation and function call management.
+
+        Args:
+            frame: The frame to process.
+            direction: The direction of frame flow in the pipeline.
+        """
+        await super().process_frame(frame, direction)
+
+        if isinstance(frame, StartInterruptionFrame):
+            await self._handle_interruptions(frame)
+            await self.push_frame(frame, direction)
+        elif isinstance(frame, LLMFullResponseStartFrame):
+            await self._handle_llm_start(frame)
+        elif isinstance(frame, LLMFullResponseEndFrame):
+            await self._handle_llm_end(frame)
+        elif isinstance(frame, TextFrame):
+            await self._handle_text(frame)
+        elif isinstance(frame, LLMMessagesAppendFrame):
+            await self._handle_llm_messages_append(frame)
+        elif isinstance(frame, LLMMessagesUpdateFrame):
+            await self._handle_llm_messages_update(frame)
+        elif isinstance(frame, LLMSetToolsFrame):
+            self.set_tools(frame.tools)
+        elif isinstance(frame, LLMSetToolChoiceFrame):
+            self.set_tool_choice(frame.tool_choice)
+        elif isinstance(frame, FunctionCallsStartedFrame):
+            await self._handle_function_calls_started(frame)
+        elif isinstance(frame, FunctionCallInProgressFrame):
+            await self._handle_function_call_in_progress(frame)
+        elif isinstance(frame, FunctionCallResultFrame):
+            await self._handle_function_call_result(frame)
+        elif isinstance(frame, FunctionCallCancelFrame):
+            await self._handle_function_call_cancel(frame)
+        elif isinstance(frame, UserImageRawFrame) and frame.request and frame.request.tool_call_id:
+            await self._handle_user_image_frame(frame)
+        elif isinstance(frame, BotStoppedSpeakingFrame):
+            await self._push_aggregation()
+            await self.push_frame(frame, direction)
+        else:
+            await self.push_frame(frame, direction)
+
+    async def _push_aggregation(self):
+        """Push the current assistant aggregation with timestamp.
+
+        Does nothing when the aggregation buffer is empty. A buffer containing
+        only whitespace adds no message, but the context and timestamp frames
+        are still pushed.
+        """
+        if not self._aggregation:
+            return
+
+        # Capture the text first: reset() clears the aggregation buffer.
+        aggregation = self._aggregation.strip()
+        await self.reset()
+
+        if aggregation:
+            self._context.add_message({"role": "assistant", "content": aggregation})
+
+        # Push context frame
+        await self.push_context_frame()
+
+        # Push timestamp frame with current time
+        timestamp_frame = LLMContextAssistantTimestampFrame(timestamp=time_now_iso8601())
+        await self.push_frame(timestamp_frame)
+
+    async def _handle_llm_messages_append(self, frame: LLMMessagesAppendFrame):
+        """Append the frame's messages; push the context upstream if `run_llm` is set."""
+        self.add_messages(frame.messages)
+        if frame.run_llm:
+            await self.push_context_frame(FrameDirection.UPSTREAM)
+
+    async def _handle_llm_messages_update(self, frame: LLMMessagesUpdateFrame):
+        """Replace the context messages; push the context upstream if `run_llm` is set."""
+        self.set_messages(frame.messages)
+        if frame.run_llm:
+            await self.push_context_frame(FrameDirection.UPSTREAM)
+
+    async def _handle_interruptions(self, frame: StartInterruptionFrame):
+        """Flush what was aggregated so far, then clear the response state."""
+        await self._push_aggregation()
+        self._started = 0
+        await self.reset()
+
+    async def _handle_function_calls_started(self, frame: FunctionCallsStartedFrame):
+        """Register every announced function call with a None placeholder."""
+        function_names = [f"{f.function_name}:{f.tool_call_id}" for f in frame.function_calls]
+        logger.debug(f"{self} FunctionCallsStartedFrame: {function_names}")
+        for function_call in frame.function_calls:
+            self._function_calls_in_progress[function_call.tool_call_id] = None
+
+    async def _handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame):
+        """Add the tool call and an IN_PROGRESS tool result to the context."""
+        logger.debug(
+            f"{self} FunctionCallInProgressFrame: [{frame.function_name}:{frame.tool_call_id}]"
+        )
+
+        # Update context with the in-progress function call
+        self._context.add_message(
+            {
+                "role": "assistant",
+                "tool_calls": [
+                    {
+                        "id": frame.tool_call_id,
+                        "function": {
+                            "name": frame.function_name,
+                            "arguments": json.dumps(frame.arguments),
+                        },
+                        "type": "function",
+                    }
+                ],
+            }
+        )
+        self._context.add_message(
+            {
+                "role": "tool",
+                "content": "IN_PROGRESS",
+                "tool_call_id": frame.tool_call_id,
+            }
+        )
+
+        self._function_calls_in_progress[frame.tool_call_id] = frame
+
+    async def _handle_function_call_result(self, frame: FunctionCallResultFrame):
+        """Store a function call result in the context and run the LLM if required."""
+        logger.debug(
+            f"{self} FunctionCallResultFrame: [{frame.function_name}:{frame.tool_call_id}]"
+        )
+        if frame.tool_call_id not in self._function_calls_in_progress:
+            logger.warning(
+                f"FunctionCallResultFrame tool_call_id [{frame.tool_call_id}] is not running"
+            )
+            return
+
+        del self._function_calls_in_progress[frame.tool_call_id]
+
+        properties = frame.properties
+
+        # Update context with the function call result.
+        # NOTE: this is a truthiness check, so falsy results (None, 0, "", [], {})
+        # are recorded as "COMPLETED" and never trigger the LLM below.
+        if frame.result:
+            result = json.dumps(frame.result)
+            self._update_function_call_result(frame.function_name, frame.tool_call_id, result)
+        else:
+            self._update_function_call_result(frame.function_name, frame.tool_call_id, "COMPLETED")
+
+        run_llm = False
+
+        # Run inference if the function call result requires it.
+        if frame.result:
+            if properties and properties.run_llm is not None:
+                # If the tool call result has a run_llm property, use it.
+                run_llm = properties.run_llm
+            elif frame.run_llm is not None:
+                # If the frame is indicating we should run the LLM, do it.
+                run_llm = frame.run_llm
+            else:
+                # If this is the last function call in progress, run the LLM.
+                run_llm = not bool(self._function_calls_in_progress)
+
+        if run_llm:
+            await self.push_context_frame(FrameDirection.UPSTREAM)
+
+        # Call the `on_context_updated` callback once the function call result
+        # is added to the context. Also, run this in a separate task to make
+        # sure we don't block the pipeline.
+        if properties and properties.on_context_updated:
+            task_name = f"{frame.function_name}:{frame.tool_call_id}:on_context_updated"
+            task = self.create_task(properties.on_context_updated(), task_name)
+            self._context_updated_tasks.add(task)
+            task.add_done_callback(self._context_updated_task_finished)
+
+    async def _handle_function_call_cancel(self, frame: FunctionCallCancelFrame):
+        """Mark a cancellable in-progress function call as CANCELLED in the context."""
+        logger.debug(
+            f"{self} FunctionCallCancelFrame: [{frame.function_name}:{frame.tool_call_id}]"
+        )
+        if frame.tool_call_id not in self._function_calls_in_progress:
+            return
+
+        function_call = self._function_calls_in_progress[frame.tool_call_id]
+        # The entry is still the None placeholder when the cancel arrives after
+        # FunctionCallsStartedFrame but before FunctionCallInProgressFrame; no
+        # tool message exists in the context yet, so there is nothing to update.
+        if function_call and function_call.cancel_on_interruption:
+            # Update context with the function call cancellation
+            self._update_function_call_result(frame.function_name, frame.tool_call_id, "CANCELLED")
+            del self._function_calls_in_progress[frame.tool_call_id]
+
+    def _update_function_call_result(self, function_name: str, tool_call_id: str, result: Any):
+        """Overwrite the content of the tool message(s) matching `tool_call_id`.
+
+        This is a plain synchronous method and must not be awaited.
+
+        Args:
+            function_name: Name of the function (not used for matching).
+            tool_call_id: Identifier of the tool call whose message is updated.
+            result: New content for the matching tool message(s).
+        """
+        for message in self._context.messages:
+            if message["role"] != "tool":
+                continue
+            message_call_id = message.get("tool_call_id")
+            if message_call_id and message_call_id == tool_call_id:
+                message["content"] = result
+
+    async def _handle_user_image_frame(self, frame: UserImageRawFrame):
+        """Complete the function call that requested an image and add the image to the context."""
+        logger.debug(
+            f"{self} UserImageRawFrame: [{frame.request.function_name}:{frame.request.tool_call_id}]"
+        )
+
+        if frame.request.tool_call_id not in self._function_calls_in_progress:
+            logger.warning(
+                f"UserImageRawFrame tool_call_id [{frame.request.tool_call_id}] is not running"
+            )
+            return
+
+        del self._function_calls_in_progress[frame.request.tool_call_id]
+
+        # Update context with the image frame. _update_function_call_result() is
+        # synchronous and returns None, so it must not be awaited.
+        self._update_function_call_result(
+            frame.request.function_name, frame.request.tool_call_id, "COMPLETED"
+        )
+        self._context.add_image_frame_message(
+            format=frame.format,
+            size=frame.size,
+            image=frame.image,
+            text=frame.request.context,
+        )
+
+        await self._push_aggregation()
+        await self.push_context_frame(FrameDirection.UPSTREAM)
+
+    async def _handle_llm_start(self,
_: LLMFullResponseStartFrame): + self._started += 1 + + async def _handle_llm_end(self, _: LLMFullResponseEndFrame): + self._started -= 1 + await self._push_aggregation() + + async def _handle_text(self, frame: TextFrame): + if not self._started: + return + + if self._params.expect_stripped_words: + self._aggregation += f" {frame.text}" if self._aggregation else frame.text + else: + self._aggregation += frame.text + + def _context_updated_task_finished(self, task: asyncio.Task): + self._context_updated_tasks.discard(task) + # The task is finished so this should exit immediately. We need to do + # this because otherwise the task manager would report a dangling task + # if we don't remove it. + asyncio.run_coroutine_threadsafe(self.wait_for_task(task), self.get_event_loop()) + + +@dataclass +class LLMContextAggregatorPair: + """Pair of LLM context aggregators for user and assistant messages. + + Parameters: + _user: User context aggregator for processing user messages. + _assistant: Assistant context aggregator for processing assistant messages. + """ + + _user: LLMUserAggregator + _assistant: LLMAssistantAggregator + + @staticmethod + def create( + context: LLMContext, + *, + user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(), + assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(), + ) -> "LLMContextAggregatorPair": + """Factory method to create an LLMContextAggregatorPair. + + Args: + context: The context managed by the aggregators. + user_params: Parameters for the user context aggregator. + assistant_params: Parameters for the assistant context aggregator. + + Returns: + LLMContextAggregatorPair: A new instance with configured aggregators. + """ + user = LLMUserAggregator(context, params=user_params) + assistant = LLMAssistantAggregator(context, params=assistant_params) + return LLMContextAggregatorPair(_user=user, _assistant=assistant) + + def user(self) -> LLMUserAggregator: + """Get the user context aggregator. 
+ + Returns: + The user context aggregator instance. + """ + return self._user + + def assistant(self) -> LLMAssistantAggregator: + """Get the assistant context aggregator. + + Returns: + The assistant context aggregator instance. + """ + return self._assistant