[WIP] Universal (LLM-agnostic) context machinery to support runtime LLM switching.

- Added universal `LLMContext` and associated context aggregators.
This commit is contained in:
Paul Kompfner
2025-08-12 14:20:56 -04:00
parent 37980b0854
commit ff8d158e18
3 changed files with 1062 additions and 0 deletions

View File

@@ -36,6 +36,7 @@ from pipecat.utils.time import nanoseconds_to_str
from pipecat.utils.utils import obj_count, obj_id
if TYPE_CHECKING:
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.frame_processor import FrameProcessor
@@ -474,6 +475,20 @@ class TranscriptionUpdateFrame(DataFrame):
return f"{self.name}(pts: {pts}, messages: {len(self.messages)})"
@dataclass
class LLMContextFrame(Frame):
"""Frame containing a universal LLM context.
Used as a signal to LLM services to ingest the provided context and
generate a response based on it.
Parameters:
context: The LLM context containing messages, tools, and configuration.
"""
context: "LLMContext"
@dataclass
class LLMMessagesFrame(DataFrame):
"""Frame containing LLM messages for chat completion.

View File

@@ -0,0 +1,197 @@
#
# Copyright (c) 2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Universal LLM context management for LLM services in Pipecat.
Context contents are represented in a universal format (based on OpenAI)
that supports a union of known Pipecat LLM service functionality.
Whenever an LLM service needs to access context, it does a just-in-time
translation from this universal context into whatever format it needs, using a
service-specific adapter.
"""
import base64
import io
from dataclasses import dataclass
from typing import Any, List, Optional
from openai._types import NOT_GIVEN as OPEN_AI_NOT_GIVEN
from openai._types import NotGiven as OpenAINotGiven
from openai.types.chat import (
ChatCompletionMessageParam,
ChatCompletionToolChoiceOptionParam,
ChatCompletionToolParam,
)
from PIL import Image
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.frames.frames import AudioRawFrame, Frame
# "Re-export" types from OpenAI that we're using as universal context types.
# NOTE: this is just for convenience, for now. As soon as the universal types
# diverge from OpenAI's, we should ditch this. In fact, audio frames already
# diverge from OpenAI's standard format...we really ought to do this.
LLMContextMessage = ChatCompletionMessageParam
LLMContextTool = ChatCompletionToolParam
LLMContextToolChoice = ChatCompletionToolChoiceOptionParam
NOT_GIVEN = OPEN_AI_NOT_GIVEN
NotGiven = OpenAINotGiven
class LLMContext:
"""Manages conversation context for LLM interactions.
Handles message history, tool definitions, tool choices, and multimedia
content for LLM conversations. Provides methods for message manipulation,
and content formatting.
"""
def __init__(
self,
messages: Optional[List[LLMContextMessage]] = None,
tools: List[LLMContextTool] | NotGiven | ToolsSchema = NOT_GIVEN,
tool_choice: LLMContextToolChoice | NotGiven = NOT_GIVEN,
):
"""Initialize the LLM context.
Args:
messages: Initial list of conversation messages.
tools: Available tools for the LLM to use.
tool_choice: Tool selection strategy for the LLM.
"""
self._messages: List[LLMContextMessage] = messages if messages else []
self._tools: List[LLMContextTool] | NotGiven | ToolsSchema = tools
self._tool_choice: LLMContextToolChoice | NotGiven = tool_choice
@property
def messages(self) -> List[LLMContextMessage]:
"""Get the current messages list.
Returns:
List of conversation messages.
"""
return self._messages
@property
def tools(self) -> List[LLMContextTool] | NotGiven | List[Any]:
"""Get the tools list.
Returns:
Tools list.
"""
return self._tools
@property
def tool_choice(self) -> LLMContextToolChoice | NotGiven:
"""Get the current tool choice setting.
Returns:
The tool choice configuration.
"""
return self._tool_choice
def add_message(self, message: LLMContextMessage):
"""Add a single message to the context.
Args:
message: The message to add to the conversation history.
"""
self._messages.append(message)
def add_messages(self, messages: List[LLMContextMessage]):
"""Add multiple messages to the context.
Args:
messages: List of messages to add to the conversation history.
"""
self._messages.extend(messages)
def set_messages(self, messages: List[LLMContextMessage]):
"""Replace all messages in the context.
Args:
messages: New list of messages to replace the current history.
"""
self._messages[:] = messages
def set_tools(self, tools: List[LLMContextTool] | NotGiven | ToolsSchema = NOT_GIVEN):
"""Set the available tools for the LLM.
Args:
tools: List of tools available to the LLM, a ToolsSchema, or NOT_GIVEN to disable tools.
"""
# TODO: convert empty ToolsSchema to NOT_GIVEN if needed?
# TODO: maybe someday also convert provider-specific tools to ToolsSchema so it's always in a provider-neutral format here? See open_ai_adapter.py for related comment. Pipecat Flows is currently converting provider-specific tools to ToolsSchema...
if isinstance(tools, list) and len(tools) == 0:
tools = NOT_GIVEN
self._tools = tools
def set_tool_choice(self, tool_choice: LLMContextToolChoice | NotGiven):
"""Set the tool choice configuration.
Args:
tool_choice: Tool selection strategy for the LLM.
"""
self._tool_choice = tool_choice
def add_image_frame_message(
self, *, format: str, size: tuple[int, int], image: bytes, text: str = None
):
"""Add a message containing an image frame.
Args:
format: Image format (e.g., 'RGB', 'RGBA').
size: Image dimensions as (width, height) tuple.
image: Raw image bytes.
text: Optional text to include with the image.
"""
buffer = io.BytesIO()
Image.frombytes(format, size, image).save(buffer, format="JPEG")
# TODO: we might not want the universal format to be base64 encoded, since encoding is not needed by all LLM services; today, te Gemini adapter has to decode from base64, which is less than ideal.
encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
content = []
if text:
content.append({"type": "text", "text": text})
content.append(
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}},
)
self.add_message({"role": "user", "content": content})
# NOTE: today we've only built support for audio frames with the Google
# LLM, so this "universal" representation skews towards that.
# When we add support for other LLMs, we may need to adjust this.
def add_audio_frames_message(
self, *, audio_frames: list[AudioRawFrame], text: str = "Audio follows"
):
"""Add a message containing audio frames.
Args:
audio_frames: List of audio frame objects to include.
text: Optional text to include with the audio.
"""
if not audio_frames:
return
sample_rate = audio_frames[0].sample_rate
num_channels = audio_frames[0].num_channels
content = []
content.append({"type": "text", "text": text})
data = b"".join(frame.audio for frame in audio_frames)
# TODO: filter this out in OpenAI adapter, since it doesn't support audio frames
content.append(
{
"type": "input_audio",
"input_audio": {
"data": data,
"sample_rate": sample_rate,
"num_channels": num_channels,
},
}
)
self.add_message({"role": "user", "content": content})

View File

@@ -0,0 +1,850 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""LLM response aggregators for handling conversation context and message aggregation.
This module provides aggregators that process and accumulate LLM responses, user inputs,
and conversation context. These aggregators handle the flow between speech-to-text,
LLM processing, and text-to-speech components in conversational AI pipelines.
"""
import asyncio
import json
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional, Set
from loguru import logger
from pipecat.audio.interruptions.base_interruption_strategy import BaseInterruptionStrategy
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
from pipecat.audio.vad.vad_analyzer import VADParams
from pipecat.frames.frames import (
BotInterruptionFrame,
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
CancelFrame,
EmulateUserStartedSpeakingFrame,
EmulateUserStoppedSpeakingFrame,
EndFrame,
Frame,
FunctionCallCancelFrame,
FunctionCallInProgressFrame,
FunctionCallResultFrame,
FunctionCallsStartedFrame,
InputAudioRawFrame,
InterimTranscriptionFrame,
LLMContextAssistantTimestampFrame,
LLMContextFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMMessagesAppendFrame,
LLMMessagesUpdateFrame,
LLMSetToolChoiceFrame,
LLMSetToolsFrame,
SpeechControlParamsFrame,
StartFrame,
StartInterruptionFrame,
TextFrame,
TranscriptionFrame,
UserImageRawFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
)
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response import (
LLMAssistantAggregatorParams,
LLMUserAggregatorParams,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.utils.time import time_now_iso8601
class LLMContextAggregator(FrameProcessor):
"""Base LLM aggregator that uses an LLMContext for conversation storage.
This aggregator maintains conversation state using an LLMContext and
pushes LLMContextFrame objects as aggregation frames. It provides
common functionality for context-based conversation management.
"""
def __init__(self, *, context: LLMContext, role: str, **kwargs):
"""Initialize the context response aggregator.
Args:
context: The LLM context to use for conversation storage.
role: The role this aggregator represents (e.g. "user", "assistant").
**kwargs: Additional arguments passed to parent class.
"""
super().__init__(**kwargs)
self._context = context
self._role = role
self._aggregation: str = ""
@property
def messages(self) -> List[dict]:
"""Get messages from the LLM context.
Returns:
List of message dictionaries from the context.
"""
return self._context.messages
@property
def role(self) -> str:
"""Get the role for this aggregator.
Returns:
The role string for this aggregator.
"""
return self._role
@property
def context(self):
"""Get the LLM context.
Returns:
The LLMContext instance used by this aggregator.
"""
return self._context
def get_context_frame(self) -> LLMContextFrame:
"""Create a context frame with the current context.
Returns:
LLMContextFrame containing the current context.
"""
return LLMContextFrame(context=self._context)
async def push_context_frame(self, direction: FrameDirection = FrameDirection.DOWNSTREAM):
"""Push a context frame in the specified direction.
Args:
direction: The direction to push the frame (upstream or downstream).
"""
frame = self.get_context_frame()
await self.push_frame(frame, direction)
def add_messages(self, messages):
"""Add messages to the context.
Args:
messages: Messages to add to the conversation context.
"""
self._context.add_messages(messages)
def set_messages(self, messages):
"""Set the context messages.
Args:
messages: Messages to replace the current context messages.
"""
self._context.set_messages(messages)
def set_tools(self, tools: List):
"""Set tools in the context.
Args:
tools: List of tool definitions to set in the context.
"""
self._context.set_tools(tools)
def set_tool_choice(self, tool_choice: Literal["none", "auto", "required"] | dict):
"""Set tool choice in the context.
Args:
tool_choice: Tool choice configuration for the context.
"""
self._context.set_tool_choice(tool_choice)
async def reset(self):
"""Reset the aggregation state."""
self._aggregation = ""
# NOTE: the "universal" suffix is just meant to distinguish this aggregator
# from the old LLMUserContextAggregator while we gradually migrate service to
# use the new universal LLMContext and associated patterns. The suffix will go
# away once the migration is complete and the other LLMUserContextAggregator is
# deprecated.
class LLMUserAggregator(LLMContextAggregator):
"""User LLM aggregator that processes speech-to-text transcriptions.
This aggregator handles the complex logic of aggregating user speech transcriptions
from STT services. It manages multiple scenarios including:
- Transcriptions received between VAD events
- Transcriptions received outside VAD events
- Interim vs final transcriptions
- User interruptions during bot speech
- Emulated VAD for whispered or short utterances
The aggregator uses timeouts to handle cases where transcriptions arrive
after VAD events or when no VAD is available.
"""
def __init__(
self,
context: LLMContext,
*,
params: Optional[LLMUserAggregatorParams] = None,
**kwargs,
):
"""Initialize the user context aggregator.
Args:
context: The LLM context for conversation storage.
params: Configuration parameters for aggregation behavior.
**kwargs: Additional arguments. Supports deprecated 'aggregation_timeout'.
"""
super().__init__(context=context, role="user", **kwargs)
self._params = params or LLMUserAggregatorParams()
self._vad_params: Optional[VADParams] = None
self._turn_params: Optional[SmartTurnParams] = None
if "aggregation_timeout" in kwargs:
import warnings
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
"Parameter 'aggregation_timeout' is deprecated, use 'params' instead.",
DeprecationWarning,
)
self._params.aggregation_timeout = kwargs["aggregation_timeout"]
self._user_speaking = False
self._bot_speaking = False
self._was_bot_speaking = False
self._emulating_vad = False
self._seen_interim_results = False
self._waiting_for_aggregation = False
self._aggregation_event = asyncio.Event()
self._aggregation_task = None
async def reset(self):
"""Reset the aggregation state and interruption strategies."""
await super().reset()
self._was_bot_speaking = False
self._seen_interim_results = False
self._waiting_for_aggregation = False
[await s.reset() for s in self._interruption_strategies]
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process frames for user speech aggregation and context management.
Args:
frame: The frame to process.
direction: The direction of frame flow in the pipeline.
"""
await super().process_frame(frame, direction)
if isinstance(frame, StartFrame):
# Push StartFrame before start(), because we want StartFrame to be
# processed by every processor before any other frame is processed.
await self.push_frame(frame, direction)
await self._start(frame)
elif isinstance(frame, EndFrame):
# Push EndFrame before stop(), because stop() waits on the task to
# finish and the task finishes when EndFrame is processed.
await self.push_frame(frame, direction)
await self._stop(frame)
elif isinstance(frame, CancelFrame):
await self._cancel(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, InputAudioRawFrame):
await self._handle_input_audio(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, UserStartedSpeakingFrame):
await self._handle_user_started_speaking(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, UserStoppedSpeakingFrame):
await self._handle_user_stopped_speaking(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, BotStartedSpeakingFrame):
await self._handle_bot_started_speaking(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, BotStoppedSpeakingFrame):
await self._handle_bot_stopped_speaking(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, TranscriptionFrame):
await self._handle_transcription(frame)
elif isinstance(frame, InterimTranscriptionFrame):
await self._handle_interim_transcription(frame)
elif isinstance(frame, LLMMessagesAppendFrame):
await self._handle_llm_messages_append(frame)
elif isinstance(frame, LLMMessagesUpdateFrame):
await self._handle_llm_messages_update(frame)
elif isinstance(frame, LLMSetToolsFrame):
self.set_tools(frame.tools)
elif isinstance(frame, LLMSetToolChoiceFrame):
self.set_tool_choice(frame.tool_choice)
elif isinstance(frame, SpeechControlParamsFrame):
self._vad_params = frame.vad_params
self._turn_params = frame.turn_params
await self.push_frame(frame, direction)
else:
await self.push_frame(frame, direction)
async def _process_aggregation(self):
"""Process the current aggregation and push it downstream."""
aggregation = self._aggregation
await self.reset()
self._context.add_message({"role": self.role, "content": aggregation})
frame = LLMContextFrame(self._context)
await self.push_frame(frame)
async def _push_aggregation(self):
"""Push the current aggregation based on interruption strategies and conditions."""
if len(self._aggregation) > 0:
if self.interruption_strategies and self._bot_speaking:
should_interrupt = await self._should_interrupt_based_on_strategies()
if should_interrupt:
logger.debug(
"Interruption conditions met - pushing BotInterruptionFrame and aggregation"
)
await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
await self._process_aggregation()
else:
logger.debug("Interruption conditions not met - not pushing aggregation")
# Don't process aggregation, just reset it
await self.reset()
else:
# No interruption config - normal behavior (always push aggregation)
await self._process_aggregation()
# Handles the case where both the user and the bot are not speaking,
# and the bot was previously speaking before the user interruption.
# Normally, when the user stops speaking, new text is expected,
# which triggers the bot to respond. However, if no new text
# is received, this safeguard ensures
# the bot doesn't hang indefinitely while waiting to speak again.
elif not self._seen_interim_results and self._was_bot_speaking and not self._bot_speaking:
logger.warning("User stopped speaking but no new aggregation received.")
# Resetting it so we don't trigger this twice
self._was_bot_speaking = False
# TODO: we are not enabling this for now, due to some STT services which can take as long as 2 seconds two return a transcription
# So we need more tests and probably make this feature configurable, disabled it by default.
# We are just pushing the same previous context to be processed again in this case
# await self.push_frame(LLMContextFrame(self._context))
async def _should_interrupt_based_on_strategies(self) -> bool:
"""Check if interruption should occur based on configured strategies.
Returns:
True if any interruption strategy indicates interruption should occur.
"""
async def should_interrupt(strategy: BaseInterruptionStrategy):
await strategy.append_text(self._aggregation)
return await strategy.should_interrupt()
return any([await should_interrupt(s) for s in self._interruption_strategies])
async def _start(self, frame: StartFrame):
self._create_aggregation_task()
async def _stop(self, frame: EndFrame):
await self._cancel_aggregation_task()
async def _cancel(self, frame: CancelFrame):
await self._cancel_aggregation_task()
async def _handle_llm_messages_append(self, frame: LLMMessagesAppendFrame):
self.add_messages(frame.messages)
if frame.run_llm:
await self.push_context_frame()
async def _handle_llm_messages_update(self, frame: LLMMessagesUpdateFrame):
self.set_messages(frame.messages)
if frame.run_llm:
await self.push_context_frame()
async def _handle_input_audio(self, frame: InputAudioRawFrame):
for s in self.interruption_strategies:
await s.append_audio(frame.audio, frame.sample_rate)
async def _handle_user_started_speaking(self, frame: UserStartedSpeakingFrame):
self._user_speaking = True
self._waiting_for_aggregation = True
self._was_bot_speaking = self._bot_speaking
# If we get a non-emulated UserStartedSpeakingFrame but we are in the
# middle of emulating VAD, let's stop emulating VAD (i.e. don't send the
# EmulateUserStoppedSpeakingFrame).
if not frame.emulated and self._emulating_vad:
self._emulating_vad = False
async def _handle_user_stopped_speaking(self, _: UserStoppedSpeakingFrame):
self._user_speaking = False
# We just stopped speaking. Let's see if there's some aggregation to
# push. If the last thing we saw is an interim transcription, let's wait
# pushing the aggregation as we will probably get a final transcription.
if len(self._aggregation) > 0:
if not self._seen_interim_results:
await self._push_aggregation()
# Handles the case where both the user and the bot are not speaking,
# and the bot was previously speaking before the user interruption.
# So in this case we are resetting the aggregation timer
elif not self._seen_interim_results and self._was_bot_speaking and not self._bot_speaking:
# Reset aggregation timer.
self._aggregation_event.set()
async def _handle_bot_started_speaking(self, _: BotStartedSpeakingFrame):
self._bot_speaking = True
async def _handle_bot_stopped_speaking(self, _: BotStoppedSpeakingFrame):
self._bot_speaking = False
async def _handle_transcription(self, frame: TranscriptionFrame):
text = frame.text
# Make sure we really have some text.
if not text.strip():
return
self._aggregation += f" {text}" if self._aggregation else text
# We just got a final result, so let's reset interim results.
self._seen_interim_results = False
# Reset aggregation timer.
self._aggregation_event.set()
async def _handle_interim_transcription(self, _: InterimTranscriptionFrame):
self._seen_interim_results = True
def _create_aggregation_task(self):
if not self._aggregation_task:
self._aggregation_task = self.create_task(self._aggregation_task_handler())
async def _cancel_aggregation_task(self):
if self._aggregation_task:
await self.cancel_task(self._aggregation_task)
self._aggregation_task = None
async def _aggregation_task_handler(self):
while True:
try:
# The _aggregation_task_handler handles two distinct timeout scenarios:
#
# 1. When emulating_vad=True: Wait for emulated VAD timeout before
# pushing aggregation (simulating VAD behavior when no actual VAD
# detection occurred).
#
# 2. When emulating_vad=False: Use aggregation_timeout as a buffer
# to wait for potential late-arriving transcription frames after
# a real VAD event.
#
# For emulated VAD scenarios, the timeout strategy depends on whether
# a turn analyzer is configured:
#
# - WITH turn analyzer: Use turn_emulated_vad_timeout parameter because
# the VAD's stop_secs is set very low (e.g. 0.2s) for rapid speech
# chunking to feed the turn analyzer. This low value is too fast
# for emulated VAD scenarios where we need to allow users time to
# finish speaking (e.g. 0.8s).
#
# - WITHOUT turn analyzer: Use VAD's stop_secs directly to maintain
# consistent user experience between real VAD detection and
# emulated VAD scenarios.
if not self._emulating_vad:
timeout = self._params.aggregation_timeout
elif self._turn_params:
timeout = self._params.turn_emulated_vad_timeout
else:
# Use VAD stop_secs when no turn analyzer is present, fallback if no VAD params
timeout = (
self._vad_params.stop_secs
if self._vad_params
else self._params.turn_emulated_vad_timeout
)
await asyncio.wait_for(self._aggregation_event.wait(), timeout)
await self._maybe_emulate_user_speaking()
except asyncio.TimeoutError:
if not self._user_speaking:
await self._push_aggregation()
# If we are emulating VAD we still need to send the user stopped
# speaking frame.
if self._emulating_vad:
await self.push_frame(
EmulateUserStoppedSpeakingFrame(), FrameDirection.UPSTREAM
)
self._emulating_vad = False
finally:
self.reset_watchdog()
self._aggregation_event.clear()
async def _maybe_emulate_user_speaking(self):
"""Maybe emulate user speaking based on transcription.
Emulate user speaking if we got a transcription but it was not
detected by VAD. Behavior when bot is speaking depends on the
enable_emulated_vad_interruptions parameter.
"""
# Check if we received a transcription but VAD was not able to detect
# voice (e.g. when you whisper a short utterance). In that case, we need
# to emulate VAD (i.e. user start/stopped speaking), but we do it only
# if the bot is not speaking. If the bot is speaking and we really have
# a short utterance we don't really want to interrupt the bot.
if (
not self._user_speaking
and not self._waiting_for_aggregation
and len(self._aggregation) > 0
):
if self._bot_speaking and not self._params.enable_emulated_vad_interruptions:
# If emulated VAD interruptions are disabled and bot is speaking, ignore
logger.debug("Ignoring user speaking emulation, bot is speaking.")
await self.reset()
else:
# Either bot is not speaking, or emulated VAD interruptions are enabled
# - trigger user speaking emulation.
await self.push_frame(EmulateUserStartedSpeakingFrame(), FrameDirection.UPSTREAM)
self._emulating_vad = True
# NOTE: the "universal" suffix is just meant to distinguish this aggregator
# from the old LLMAssistantContextAggregator while we gradually migrate service
# to use the new universal LLMContext and associated patterns. The suffix will
# go away once the migration is complete and the other
# LLMAssistantContextAggregator is deprecated.
class LLMAssistantAggregator(LLMContextAggregator):
"""Assistant LLM aggregator that processes bot responses and function calls.
This aggregator handles the complex logic of processing assistant responses including:
- Text frame aggregation between response start/end markers
- Function call lifecycle management
- Context updates with timestamps
- Tool execution and result handling
- Interruption handling during responses
The aggregator manages function calls in progress and coordinates between
text generation and tool execution phases of LLM responses.
"""
def __init__(
self,
context: LLMContext,
*,
params: Optional[LLMAssistantAggregatorParams] = None,
**kwargs,
):
"""Initialize the assistant context aggregator.
Args:
context: The OpenAI LLM context for conversation storage.
params: Configuration parameters for aggregation behavior.
**kwargs: Additional arguments. Supports deprecated 'expect_stripped_words'.
"""
super().__init__(context=context, role="assistant", **kwargs)
self._params = params or LLMAssistantAggregatorParams()
if "expect_stripped_words" in kwargs:
import warnings
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
"Parameter 'expect_stripped_words' is deprecated, use 'params' instead.",
DeprecationWarning,
)
self._params.expect_stripped_words = kwargs["expect_stripped_words"]
self._started = 0
self._function_calls_in_progress: Dict[str, Optional[FunctionCallInProgressFrame]] = {}
self._context_updated_tasks: Set[asyncio.Task] = set()
@property
def has_function_calls_in_progress(self) -> bool:
"""Check if there are any function calls currently in progress.
Returns:
True if function calls are in progress, False otherwise.
"""
return bool(self._function_calls_in_progress)
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process frames for assistant response aggregation and function call management.
Args:
frame: The frame to process.
direction: The direction of frame flow in the pipeline.
"""
await super().process_frame(frame, direction)
if isinstance(frame, StartInterruptionFrame):
await self._handle_interruptions(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, LLMFullResponseStartFrame):
await self._handle_llm_start(frame)
elif isinstance(frame, LLMFullResponseEndFrame):
await self._handle_llm_end(frame)
elif isinstance(frame, TextFrame):
await self._handle_text(frame)
elif isinstance(frame, LLMMessagesAppendFrame):
await self._handle_llm_messages_append(frame)
elif isinstance(frame, LLMMessagesUpdateFrame):
await self._handle_llm_messages_update(frame)
elif isinstance(frame, LLMSetToolsFrame):
self.set_tools(frame.tools)
elif isinstance(frame, LLMSetToolChoiceFrame):
self.set_tool_choice(frame.tool_choice)
elif isinstance(frame, FunctionCallsStartedFrame):
await self._handle_function_calls_started(frame)
elif isinstance(frame, FunctionCallInProgressFrame):
await self._handle_function_call_in_progress(frame)
elif isinstance(frame, FunctionCallResultFrame):
await self._handle_function_call_result(frame)
elif isinstance(frame, FunctionCallCancelFrame):
await self._handle_function_call_cancel(frame)
elif isinstance(frame, UserImageRawFrame) and frame.request and frame.request.tool_call_id:
await self._handle_user_image_frame(frame)
elif isinstance(frame, BotStoppedSpeakingFrame):
await self._push_aggregation()
await self.push_frame(frame, direction)
else:
await self.push_frame(frame, direction)
async def _push_aggregation(self):
"""Push the current assistant aggregation with timestamp."""
if not self._aggregation:
return
aggregation = self._aggregation.strip()
await self.reset()
if aggregation:
self._context.add_message({"role": "assistant", "content": aggregation})
# Push context frame
await self.push_context_frame()
# Push timestamp frame with current time
timestamp_frame = LLMContextAssistantTimestampFrame(timestamp=time_now_iso8601())
await self.push_frame(timestamp_frame)
async def _handle_llm_messages_append(self, frame: LLMMessagesAppendFrame):
self.add_messages(frame.messages)
if frame.run_llm:
await self.push_context_frame(FrameDirection.UPSTREAM)
async def _handle_llm_messages_update(self, frame: LLMMessagesUpdateFrame):
self.set_messages(frame.messages)
if frame.run_llm:
await self.push_context_frame(FrameDirection.UPSTREAM)
async def _handle_interruptions(self, frame: StartInterruptionFrame):
await self._push_aggregation()
self._started = 0
await self.reset()
async def _handle_function_calls_started(self, frame: FunctionCallsStartedFrame):
function_names = [f"{f.function_name}:{f.tool_call_id}" for f in frame.function_calls]
logger.debug(f"{self} FunctionCallsStartedFrame: {function_names}")
for function_call in frame.function_calls:
self._function_calls_in_progress[function_call.tool_call_id] = None
async def _handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame):
logger.debug(
f"{self} FunctionCallInProgressFrame: [{frame.function_name}:{frame.tool_call_id}]"
)
# Update context with the in-progress function call
self._context.add_message(
{
"role": "assistant",
"tool_calls": [
{
"id": frame.tool_call_id,
"function": {
"name": frame.function_name,
"arguments": json.dumps(frame.arguments),
},
"type": "function",
}
],
}
)
self._context.add_message(
{
"role": "tool",
"content": "IN_PROGRESS",
"tool_call_id": frame.tool_call_id,
}
)
self._function_calls_in_progress[frame.tool_call_id] = frame
async def _handle_function_call_result(self, frame: FunctionCallResultFrame):
logger.debug(
f"{self} FunctionCallResultFrame: [{frame.function_name}:{frame.tool_call_id}]"
)
if frame.tool_call_id not in self._function_calls_in_progress:
logger.warning(
f"FunctionCallResultFrame tool_call_id [{frame.tool_call_id}] is not running"
)
return
del self._function_calls_in_progress[frame.tool_call_id]
properties = frame.properties
# Update context with the function call result
if frame.result:
result = json.dumps(frame.result)
self._update_function_call_result(frame.function_name, frame.tool_call_id, result)
else:
self._update_function_call_result(frame.function_name, frame.tool_call_id, "COMPLETED")
run_llm = False
# Run inference if the function call result requires it.
if frame.result:
if properties and properties.run_llm is not None:
# If the tool call result has a run_llm property, use it.
run_llm = properties.run_llm
elif frame.run_llm is not None:
# If the frame is indicating we should run the LLM, do it.
run_llm = frame.run_llm
else:
# If this is the last function call in progress, run the LLM.
run_llm = not bool(self._function_calls_in_progress)
if run_llm:
await self.push_context_frame(FrameDirection.UPSTREAM)
# Call the `on_context_updated` callback once the function call result
# is added to the context. Also, run this in a separate task to make
# sure we don't block the pipeline.
if properties and properties.on_context_updated:
task_name = f"{frame.function_name}:{frame.tool_call_id}:on_context_updated"
task = self.create_task(properties.on_context_updated(), task_name)
self._context_updated_tasks.add(task)
task.add_done_callback(self._context_updated_task_finished)
async def _handle_function_call_cancel(self, frame: FunctionCallCancelFrame):
logger.debug(
f"{self} FunctionCallCancelFrame: [{frame.function_name}:{frame.tool_call_id}]"
)
if frame.tool_call_id not in self._function_calls_in_progress:
return
if self._function_calls_in_progress[frame.tool_call_id].cancel_on_interruption:
# Update context with the function call cancellation
self._update_function_call_result(frame.function_name, frame.tool_call_id, "CANCELLED")
del self._function_calls_in_progress[frame.tool_call_id]
def _update_function_call_result(self, function_name: str, tool_call_id: str, result: Any):
for message in self._context.messages:
if (
message["role"] == "tool"
and message["tool_call_id"]
and message["tool_call_id"] == tool_call_id
):
message["content"] = result
async def _handle_user_image_frame(self, frame: UserImageRawFrame):
logger.debug(
f"{self} UserImageRawFrame: [{frame.request.function_name}:{frame.request.tool_call_id}]"
)
if frame.request.tool_call_id not in self._function_calls_in_progress:
logger.warning(
f"UserImageRawFrame tool_call_id [{frame.request.tool_call_id}] is not running"
)
return
del self._function_calls_in_progress[frame.request.tool_call_id]
# Update context with the image frame
await self._update_function_call_result(
frame.request.function_name, frame.request.tool_call_id, "COMPLETED"
)
self._context.add_image_frame_message(
format=frame.format,
size=frame.size,
image=frame.image,
text=frame.request.context,
)
await self._push_aggregation()
await self.push_context_frame(FrameDirection.UPSTREAM)
async def _handle_llm_start(self, _: LLMFullResponseStartFrame):
self._started += 1
async def _handle_llm_end(self, _: LLMFullResponseEndFrame):
self._started -= 1
await self._push_aggregation()
async def _handle_text(self, frame: TextFrame):
if not self._started:
return
if self._params.expect_stripped_words:
self._aggregation += f" {frame.text}" if self._aggregation else frame.text
else:
self._aggregation += frame.text
def _context_updated_task_finished(self, task: asyncio.Task):
self._context_updated_tasks.discard(task)
# The task is finished so this should exit immediately. We need to do
# this because otherwise the task manager would report a dangling task
# if we don't remove it.
asyncio.run_coroutine_threadsafe(self.wait_for_task(task), self.get_event_loop())
@dataclass
class LLMContextAggregatorPair:
"""Pair of LLM context aggregators for user and assistant messages.
Parameters:
_user: User context aggregator for processing user messages.
_assistant: Assistant context aggregator for processing assistant messages.
"""
_user: LLMUserAggregator
_assistant: LLMAssistantAggregator
@staticmethod
def create(
context: LLMContext,
*,
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
) -> "LLMContextAggregatorPair":
"""Factory method to create an LLMContextAggregatorPair.
Args:
context: The context managed by the aggregators.
user_params: Parameters for the user context aggregator.
assistant_params: Parameters for the assistant context aggregator.
Returns:
LLMContextAggregatorPair: A new instance with configured aggregators.
"""
user = LLMUserAggregator(context, params=user_params)
assistant = LLMAssistantAggregator(context, params=assistant_params)
return LLMContextAggregatorPair(_user=user, _assistant=assistant)
def user(self) -> LLMUserAggregator:
"""Get the user context aggregator.
Returns:
The user context aggregator instance.
"""
return self._user
def assistant(self) -> LLMAssistantAggregator:
"""Get the assistant context aggregator.
Returns:
The assistant context aggregator instance.
"""
return self._assistant