[WIP] Universal (LLM-agnostic) context machinery to support runtime LLM switching.
- Added universal `LLMContext` and associated context aggregators.
This commit is contained in:
@@ -36,6 +36,7 @@ from pipecat.utils.time import nanoseconds_to_str
|
||||
from pipecat.utils.utils import obj_count, obj_id
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.frame_processor import FrameProcessor
|
||||
|
||||
|
||||
@@ -474,6 +475,20 @@ class TranscriptionUpdateFrame(DataFrame):
|
||||
return f"{self.name}(pts: {pts}, messages: {len(self.messages)})"
|
||||
|
||||
|
||||
@dataclass
class LLMContextFrame(Frame):
    """Frame containing a universal LLM context.

    Used as a signal to LLM services to ingest the provided context and
    generate a response based on it.

    Parameters:
        context: The LLM context containing messages, tools, and configuration.
    """

    # Annotated as a string: LLMContext is only imported under TYPE_CHECKING
    # at the top of this module, so the name is not available at runtime.
    context: "LLMContext"
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMMessagesFrame(DataFrame):
|
||||
"""Frame containing LLM messages for chat completion.
|
||||
|
||||
197
src/pipecat/processors/aggregators/llm_context.py
Normal file
197
src/pipecat/processors/aggregators/llm_context.py
Normal file
@@ -0,0 +1,197 @@
|
||||
#
|
||||
# Copyright (c) 2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Universal LLM context management for LLM services in Pipecat.
|
||||
|
||||
Context contents are represented in a universal format (based on OpenAI)
|
||||
that supports a union of known Pipecat LLM service functionality.
|
||||
|
||||
Whenever an LLM service needs to access context, it does a just-in-time
|
||||
translation from this universal context into whatever format it needs, using a
|
||||
service-specific adapter.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import io
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from openai._types import NOT_GIVEN as OPEN_AI_NOT_GIVEN
|
||||
from openai._types import NotGiven as OpenAINotGiven
|
||||
from openai.types.chat import (
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionToolChoiceOptionParam,
|
||||
ChatCompletionToolParam,
|
||||
)
|
||||
from PIL import Image
|
||||
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.frames.frames import AudioRawFrame, Frame
|
||||
|
||||
# "Re-export" types from OpenAI that we're using as universal context types.
|
||||
# NOTE: this is just for convenience, for now. As soon as the universal types
|
||||
# diverge from OpenAI's, we should ditch this. In fact, audio frames already
|
||||
# diverge from OpenAI's standard format...we really ought to do this.
|
||||
LLMContextMessage = ChatCompletionMessageParam
|
||||
LLMContextTool = ChatCompletionToolParam
|
||||
LLMContextToolChoice = ChatCompletionToolChoiceOptionParam
|
||||
NOT_GIVEN = OPEN_AI_NOT_GIVEN
|
||||
NotGiven = OpenAINotGiven
|
||||
|
||||
|
||||
class LLMContext:
    """Manages conversation context for LLM interactions.

    Handles message history, tool definitions, tool choices, and multimedia
    content for LLM conversations. Provides methods for message manipulation,
    and content formatting.
    """

    # Pillow image modes that the JPEG encoder can write directly. Any other
    # mode (for example "RGBA", "LA" or "P") must be converted first, otherwise
    # ``Image.save(..., format="JPEG")`` raises ``OSError``.
    _JPEG_WRITABLE_MODES = frozenset({"1", "L", "RGB", "RGBX", "CMYK", "YCbCr"})

    def __init__(
        self,
        messages: Optional[List[LLMContextMessage]] = None,
        tools: List[LLMContextTool] | NotGiven | ToolsSchema = NOT_GIVEN,
        tool_choice: LLMContextToolChoice | NotGiven = NOT_GIVEN,
    ):
        """Initialize the LLM context.

        Args:
            messages: Initial list of conversation messages. When omitted (or
                empty) a fresh empty list is used.
            tools: Available tools for the LLM to use. An empty list is stored
                as NOT_GIVEN, exactly as `set_tools` does.
            tool_choice: Tool selection strategy for the LLM.
        """
        self._messages: List[LLMContextMessage] = messages if messages else []
        self._tools: List[LLMContextTool] | NotGiven | ToolsSchema = self._normalize_tools(tools)
        self._tool_choice: LLMContextToolChoice | NotGiven = tool_choice

    @staticmethod
    def _normalize_tools(tools):
        """Map an empty tools list to NOT_GIVEN; return anything else unchanged."""
        if isinstance(tools, list) and not tools:
            return NOT_GIVEN
        return tools

    @property
    def messages(self) -> List[LLMContextMessage]:
        """Get the current messages list.

        Returns:
            List of conversation messages. This is the live internal list, not
            a copy.
        """
        return self._messages

    @property
    def tools(self) -> List[LLMContextTool] | NotGiven | ToolsSchema:
        """Get the tools list.

        Returns:
            The stored tools: a list of tool definitions, a ToolsSchema, or
            NOT_GIVEN when no tools are set.
        """
        return self._tools

    @property
    def tool_choice(self) -> LLMContextToolChoice | NotGiven:
        """Get the current tool choice setting.

        Returns:
            The tool choice configuration.
        """
        return self._tool_choice

    def add_message(self, message: LLMContextMessage):
        """Add a single message to the context.

        Args:
            message: The message to add to the conversation history.
        """
        self._messages.append(message)

    def add_messages(self, messages: List[LLMContextMessage]):
        """Add multiple messages to the context.

        Args:
            messages: List of messages to add to the conversation history.
        """
        self._messages.extend(messages)

    def set_messages(self, messages: List[LLMContextMessage]):
        """Replace all messages in the context.

        Args:
            messages: New list of messages to replace the current history.
        """
        # Slice assignment mutates the existing list in place, so any holder
        # of a reference to `self.messages` observes the new contents.
        self._messages[:] = messages

    def set_tools(self, tools: List[LLMContextTool] | NotGiven | ToolsSchema = NOT_GIVEN):
        """Set the available tools for the LLM.

        Args:
            tools: List of tools available to the LLM, a ToolsSchema, or NOT_GIVEN to disable tools.
        """
        # TODO: convert empty ToolsSchema to NOT_GIVEN if needed?
        # TODO: maybe someday also convert provider-specific tools to ToolsSchema so it's always in a provider-neutral format here? See open_ai_adapter.py for related comment. Pipecat Flows is currently converting provider-specific tools to ToolsSchema...
        self._tools = self._normalize_tools(tools)

    def set_tool_choice(self, tool_choice: LLMContextToolChoice | NotGiven):
        """Set the tool choice configuration.

        Args:
            tool_choice: Tool selection strategy for the LLM.
        """
        self._tool_choice = tool_choice

    def add_image_frame_message(
        self, *, format: str, size: tuple[int, int], image: bytes, text: Optional[str] = None
    ):
        """Add a message containing an image frame.

        The raw image is encoded as a JPEG and embedded in a user message as a
        base64 `data:` URL. Image modes that JPEG cannot store (such as 'RGBA')
        are converted to 'RGB' before encoding.

        Args:
            format: Image format (e.g., 'RGB', 'RGBA').
            size: Image dimensions as (width, height) tuple.
            image: Raw image bytes.
            text: Optional text to include with the image.
        """
        buffer = io.BytesIO()
        pil_image = Image.frombytes(format, size, image)
        if pil_image.mode not in self._JPEG_WRITABLE_MODES:
            # JPEG has no alpha/palette support; saving such modes would raise.
            pil_image = pil_image.convert("RGB")
        pil_image.save(buffer, format="JPEG")
        # TODO: we might not want the universal format to be base64 encoded, since encoding is not needed by all LLM services; today, the Gemini adapter has to decode from base64, which is less than ideal.
        encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")

        content = []
        if text:
            content.append({"type": "text", "text": text})
        content.append(
            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}},
        )
        self.add_message({"role": "user", "content": content})

    # NOTE: today we've only built support for audio frames with the Google
    # LLM, so this "universal" representation skews towards that.
    # When we add support for other LLMs, we may need to adjust this.
    def add_audio_frames_message(
        self, *, audio_frames: list[AudioRawFrame], text: str = "Audio follows"
    ):
        """Add a message containing audio frames.

        Does nothing when `audio_frames` is empty. The sample rate and channel
        count are taken from the first frame; the audio payloads of all frames
        are concatenated into a single bytes object.

        Args:
            audio_frames: List of audio frame objects to include.
            text: Optional text to include with the audio.
        """
        if not audio_frames:
            return

        sample_rate = audio_frames[0].sample_rate
        num_channels = audio_frames[0].num_channels

        content = []
        content.append({"type": "text", "text": text})
        data = b"".join(frame.audio for frame in audio_frames)
        # TODO: filter this out in OpenAI adapter, since it doesn't support audio frames
        content.append(
            {
                "type": "input_audio",
                "input_audio": {
                    "data": data,
                    "sample_rate": sample_rate,
                    "num_channels": num_channels,
                },
            }
        )
        self.add_message({"role": "user", "content": content})
|
||||
850
src/pipecat/processors/aggregators/llm_response_universal.py
Normal file
850
src/pipecat/processors/aggregators/llm_response_universal.py
Normal file
@@ -0,0 +1,850 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""LLM response aggregators for handling conversation context and message aggregation.
|
||||
|
||||
This module provides aggregators that process and accumulate LLM responses, user inputs,
|
||||
and conversation context. These aggregators handle the flow between speech-to-text,
|
||||
LLM processing, and text-to-speech components in conversational AI pipelines.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Literal, Optional, Set
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.interruptions.base_interruption_strategy import BaseInterruptionStrategy
|
||||
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import (
|
||||
BotInterruptionFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EmulateUserStartedSpeakingFrame,
|
||||
EmulateUserStoppedSpeakingFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
FunctionCallCancelFrame,
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
FunctionCallsStartedFrame,
|
||||
InputAudioRawFrame,
|
||||
InterimTranscriptionFrame,
|
||||
LLMContextAssistantTimestampFrame,
|
||||
LLMContextFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesAppendFrame,
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMSetToolChoiceFrame,
|
||||
LLMSetToolsFrame,
|
||||
SpeechControlParamsFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
UserImageRawFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
|
||||
class LLMContextAggregator(FrameProcessor):
    """Shared base for aggregators that keep conversation state in an LLMContext.

    Subclasses accumulate text in `self._aggregation` and, when ready, emit the
    wrapped context downstream (or upstream) as an LLMContextFrame. This base
    class offers the accessors and context-mutation helpers both the user and
    assistant aggregators rely on.
    """

    def __init__(self, *, context: LLMContext, role: str, **kwargs):
        """Set up the aggregator around an existing context.

        Args:
            context: The LLM context that stores the conversation.
            role: Role attributed to this aggregator (e.g. "user", "assistant").
            **kwargs: Forwarded unchanged to the parent class constructor.
        """
        super().__init__(**kwargs)
        # Text accumulated so far for the message currently being built.
        self._aggregation: str = ""
        self._role = role
        self._context = context

    @property
    def messages(self) -> List[dict]:
        """Messages currently held by the wrapped context.

        Returns:
            The context's list of message dictionaries.
        """
        return self._context.messages

    @property
    def role(self) -> str:
        """Role this aggregator speaks for.

        Returns:
            The role string given at construction time.
        """
        return self._role

    @property
    def context(self):
        """The wrapped LLM context.

        Returns:
            The LLMContext instance this aggregator reads from and writes to.
        """
        return self._context

    def get_context_frame(self) -> LLMContextFrame:
        """Wrap the current context in a frame.

        Returns:
            A new LLMContextFrame referencing this aggregator's context.
        """
        context_frame = LLMContextFrame(context=self._context)
        return context_frame

    async def push_context_frame(self, direction: FrameDirection = FrameDirection.DOWNSTREAM):
        """Build a context frame and push it.

        Args:
            direction: Which way to push the frame (upstream or downstream).
        """
        await self.push_frame(self.get_context_frame(), direction)

    def add_messages(self, messages):
        """Append messages to the wrapped context.

        Args:
            messages: Messages to append to the conversation.
        """
        self._context.add_messages(messages)

    def set_messages(self, messages):
        """Replace the wrapped context's messages.

        Args:
            messages: Messages that become the new conversation history.
        """
        self._context.set_messages(messages)

    def set_tools(self, tools: List):
        """Store tool definitions on the wrapped context.

        Args:
            tools: Tool definitions to set on the context.
        """
        self._context.set_tools(tools)

    def set_tool_choice(self, tool_choice: Literal["none", "auto", "required"] | dict):
        """Store the tool choice on the wrapped context.

        Args:
            tool_choice: Tool choice configuration for the context.
        """
        self._context.set_tool_choice(tool_choice)

    async def reset(self):
        """Clear the text accumulated so far."""
        self._aggregation = ""
|
||||
|
||||
|
||||
# NOTE: the "universal" suffix is just meant to distinguish this aggregator
|
||||
# from the old LLMUserContextAggregator while we gradually migrate services to
|
||||
# use the new universal LLMContext and associated patterns. The suffix will go
|
||||
# away once the migration is complete and the other LLMUserContextAggregator is
|
||||
# deprecated.
|
||||
class LLMUserAggregator(LLMContextAggregator):
    """User LLM aggregator that processes speech-to-text transcriptions.

    This aggregator handles the complex logic of aggregating user speech transcriptions
    from STT services. It manages multiple scenarios including:

    - Transcriptions received between VAD events
    - Transcriptions received outside VAD events
    - Interim vs final transcriptions
    - User interruptions during bot speech
    - Emulated VAD for whispered or short utterances

    The aggregator uses timeouts to handle cases where transcriptions arrive
    after VAD events or when no VAD is available.
    """

    def __init__(
        self,
        context: LLMContext,
        *,
        params: Optional[LLMUserAggregatorParams] = None,
        **kwargs,
    ):
        """Initialize the user context aggregator.

        Args:
            context: The LLM context for conversation storage.
            params: Configuration parameters for aggregation behavior.
            **kwargs: Additional arguments. Supports deprecated 'aggregation_timeout'.
        """
        super().__init__(context=context, role="user", **kwargs)
        self._params = params or LLMUserAggregatorParams()
        self._vad_params: Optional[VADParams] = None
        self._turn_params: Optional[SmartTurnParams] = None

        if "aggregation_timeout" in kwargs:
            import warnings

            with warnings.catch_warnings():
                warnings.simplefilter("always")
                warnings.warn(
                    "Parameter 'aggregation_timeout' is deprecated, use 'params' instead.",
                    DeprecationWarning,
                    # Attribute the warning to the code constructing the
                    # aggregator rather than to this line.
                    stacklevel=2,
                )

            self._params.aggregation_timeout = kwargs["aggregation_timeout"]

        self._user_speaking = False
        self._bot_speaking = False
        self._was_bot_speaking = False
        self._emulating_vad = False
        self._seen_interim_results = False
        self._waiting_for_aggregation = False

        # Set whenever a final transcription arrives (or the timer must be
        # restarted); awaited with a timeout by the aggregation task.
        self._aggregation_event = asyncio.Event()
        self._aggregation_task = None

    async def reset(self):
        """Reset the aggregation state and interruption strategies."""
        await super().reset()
        self._was_bot_speaking = False
        self._seen_interim_results = False
        self._waiting_for_aggregation = False
        for strategy in self._interruption_strategies:
            await strategy.reset()

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Process frames for user speech aggregation and context management.

        Transcription frames and context-mutation frames (messages append/update,
        set tools, set tool choice) are consumed here and not forwarded; every
        other frame is pushed on in the direction it arrived.

        Args:
            frame: The frame to process.
            direction: The direction of frame flow in the pipeline.
        """
        await super().process_frame(frame, direction)

        if isinstance(frame, StartFrame):
            # Push StartFrame before start(), because we want StartFrame to be
            # processed by every processor before any other frame is processed.
            await self.push_frame(frame, direction)
            await self._start(frame)
        elif isinstance(frame, EndFrame):
            # Push EndFrame before stop(), because stop() waits on the task to
            # finish and the task finishes when EndFrame is processed.
            await self.push_frame(frame, direction)
            await self._stop(frame)
        elif isinstance(frame, CancelFrame):
            await self._cancel(frame)
            await self.push_frame(frame, direction)
        elif isinstance(frame, InputAudioRawFrame):
            await self._handle_input_audio(frame)
            await self.push_frame(frame, direction)
        elif isinstance(frame, UserStartedSpeakingFrame):
            await self._handle_user_started_speaking(frame)
            await self.push_frame(frame, direction)
        elif isinstance(frame, UserStoppedSpeakingFrame):
            await self._handle_user_stopped_speaking(frame)
            await self.push_frame(frame, direction)
        elif isinstance(frame, BotStartedSpeakingFrame):
            await self._handle_bot_started_speaking(frame)
            await self.push_frame(frame, direction)
        elif isinstance(frame, BotStoppedSpeakingFrame):
            await self._handle_bot_stopped_speaking(frame)
            await self.push_frame(frame, direction)
        elif isinstance(frame, TranscriptionFrame):
            await self._handle_transcription(frame)
        elif isinstance(frame, InterimTranscriptionFrame):
            await self._handle_interim_transcription(frame)
        elif isinstance(frame, LLMMessagesAppendFrame):
            await self._handle_llm_messages_append(frame)
        elif isinstance(frame, LLMMessagesUpdateFrame):
            await self._handle_llm_messages_update(frame)
        elif isinstance(frame, LLMSetToolsFrame):
            self.set_tools(frame.tools)
        elif isinstance(frame, LLMSetToolChoiceFrame):
            self.set_tool_choice(frame.tool_choice)
        elif isinstance(frame, SpeechControlParamsFrame):
            self._vad_params = frame.vad_params
            self._turn_params = frame.turn_params
            await self.push_frame(frame, direction)
        else:
            await self.push_frame(frame, direction)

    async def _process_aggregation(self):
        """Process the current aggregation and push it downstream."""
        # Capture the text first: reset() clears self._aggregation.
        aggregation = self._aggregation
        await self.reset()
        self._context.add_message({"role": self.role, "content": aggregation})
        frame = LLMContextFrame(self._context)
        await self.push_frame(frame)

    async def _push_aggregation(self):
        """Push the current aggregation based on interruption strategies and conditions."""
        if len(self._aggregation) > 0:
            if self.interruption_strategies and self._bot_speaking:
                should_interrupt = await self._should_interrupt_based_on_strategies()

                if should_interrupt:
                    logger.debug(
                        "Interruption conditions met - pushing BotInterruptionFrame and aggregation"
                    )
                    await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
                    await self._process_aggregation()
                else:
                    logger.debug("Interruption conditions not met - not pushing aggregation")
                    # Don't process aggregation, just reset it
                    await self.reset()
            else:
                # No interruption config - normal behavior (always push aggregation)
                await self._process_aggregation()
        # Handles the case where both the user and the bot are not speaking,
        # and the bot was previously speaking before the user interruption.
        # Normally, when the user stops speaking, new text is expected,
        # which triggers the bot to respond. However, if no new text
        # is received, this safeguard ensures
        # the bot doesn't hang indefinitely while waiting to speak again.
        elif not self._seen_interim_results and self._was_bot_speaking and not self._bot_speaking:
            logger.warning("User stopped speaking but no new aggregation received.")
            # Resetting it so we don't trigger this twice
            self._was_bot_speaking = False
            # TODO: we are not enabling this for now, due to some STT services which can take as long as 2 seconds to return a transcription
            # So we need more tests and probably make this feature configurable, disabled it by default.
            # We are just pushing the same previous context to be processed again in this case
            # await self.push_frame(LLMContextFrame(self._context))

    async def _should_interrupt_based_on_strategies(self) -> bool:
        """Check if interruption should occur based on configured strategies.

        Returns:
            True if any interruption strategy indicates interruption should occur.
        """

        async def should_interrupt(strategy: BaseInterruptionStrategy):
            await strategy.append_text(self._aggregation)
            return await strategy.should_interrupt()

        # The list is built eagerly on purpose: every strategy receives the
        # aggregated text, even after one has already voted to interrupt.
        return any([await should_interrupt(s) for s in self._interruption_strategies])

    async def _start(self, frame: StartFrame):
        """Start the background aggregation task."""
        self._create_aggregation_task()

    async def _stop(self, frame: EndFrame):
        """Cancel the background aggregation task on pipeline end."""
        await self._cancel_aggregation_task()

    async def _cancel(self, frame: CancelFrame):
        """Cancel the background aggregation task on pipeline cancellation."""
        await self._cancel_aggregation_task()

    async def _handle_llm_messages_append(self, frame: LLMMessagesAppendFrame):
        """Append the frame's messages; push the context if `run_llm` is set."""
        self.add_messages(frame.messages)
        if frame.run_llm:
            await self.push_context_frame()

    async def _handle_llm_messages_update(self, frame: LLMMessagesUpdateFrame):
        """Replace the context's messages; push the context if `run_llm` is set."""
        self.set_messages(frame.messages)
        if frame.run_llm:
            await self.push_context_frame()

    async def _handle_input_audio(self, frame: InputAudioRawFrame):
        """Feed incoming audio to every configured interruption strategy."""
        for s in self.interruption_strategies:
            await s.append_audio(frame.audio, frame.sample_rate)

    async def _handle_user_started_speaking(self, frame: UserStartedSpeakingFrame):
        """Record that the user started speaking and whether the bot was talking."""
        self._user_speaking = True
        self._waiting_for_aggregation = True
        self._was_bot_speaking = self._bot_speaking

        # If we get a non-emulated UserStartedSpeakingFrame but we are in the
        # middle of emulating VAD, let's stop emulating VAD (i.e. don't send the
        # EmulateUserStoppedSpeakingFrame).
        if not frame.emulated and self._emulating_vad:
            self._emulating_vad = False

    async def _handle_user_stopped_speaking(self, _: UserStoppedSpeakingFrame):
        """Record that the user stopped speaking and push any ready aggregation."""
        self._user_speaking = False
        # We just stopped speaking. Let's see if there's some aggregation to
        # push. If the last thing we saw is an interim transcription, let's wait
        # pushing the aggregation as we will probably get a final transcription.
        if len(self._aggregation) > 0:
            if not self._seen_interim_results:
                await self._push_aggregation()
        # Handles the case where both the user and the bot are not speaking,
        # and the bot was previously speaking before the user interruption.
        # So in this case we are resetting the aggregation timer
        elif not self._seen_interim_results and self._was_bot_speaking and not self._bot_speaking:
            # Reset aggregation timer.
            self._aggregation_event.set()

    async def _handle_bot_started_speaking(self, _: BotStartedSpeakingFrame):
        """Record that the bot started speaking."""
        self._bot_speaking = True

    async def _handle_bot_stopped_speaking(self, _: BotStoppedSpeakingFrame):
        """Record that the bot stopped speaking."""
        self._bot_speaking = False

    async def _handle_transcription(self, frame: TranscriptionFrame):
        """Append a final transcription to the aggregation and restart the timer."""
        text = frame.text

        # Make sure we really have some text.
        if not text.strip():
            return

        self._aggregation += f" {text}" if self._aggregation else text
        # We just got a final result, so let's reset interim results.
        self._seen_interim_results = False
        # Reset aggregation timer.
        self._aggregation_event.set()

    async def _handle_interim_transcription(self, _: InterimTranscriptionFrame):
        """Remember that an interim result was seen (a final one is expected)."""
        self._seen_interim_results = True

    def _create_aggregation_task(self):
        """Create the aggregation task unless one already exists."""
        if not self._aggregation_task:
            self._aggregation_task = self.create_task(self._aggregation_task_handler())

    async def _cancel_aggregation_task(self):
        """Cancel the aggregation task, if any, and forget it."""
        if self._aggregation_task:
            await self.cancel_task(self._aggregation_task)
            self._aggregation_task = None

    async def _aggregation_task_handler(self):
        """Loop forever, pushing the aggregation when the wait times out."""
        while True:
            try:
                # The _aggregation_task_handler handles two distinct timeout scenarios:
                #
                # 1. When emulating_vad=True: Wait for emulated VAD timeout before
                #    pushing aggregation (simulating VAD behavior when no actual VAD
                #    detection occurred).
                #
                # 2. When emulating_vad=False: Use aggregation_timeout as a buffer
                #    to wait for potential late-arriving transcription frames after
                #    a real VAD event.
                #
                # For emulated VAD scenarios, the timeout strategy depends on whether
                # a turn analyzer is configured:
                #
                # - WITH turn analyzer: Use turn_emulated_vad_timeout parameter because
                #   the VAD's stop_secs is set very low (e.g. 0.2s) for rapid speech
                #   chunking to feed the turn analyzer. This low value is too fast
                #   for emulated VAD scenarios where we need to allow users time to
                #   finish speaking (e.g. 0.8s).
                #
                # - WITHOUT turn analyzer: Use VAD's stop_secs directly to maintain
                #   consistent user experience between real VAD detection and
                #   emulated VAD scenarios.
                if not self._emulating_vad:
                    timeout = self._params.aggregation_timeout
                elif self._turn_params:
                    timeout = self._params.turn_emulated_vad_timeout
                else:
                    # Use VAD stop_secs when no turn analyzer is present, fallback if no VAD params
                    timeout = (
                        self._vad_params.stop_secs
                        if self._vad_params
                        else self._params.turn_emulated_vad_timeout
                    )
                await asyncio.wait_for(self._aggregation_event.wait(), timeout)
                await self._maybe_emulate_user_speaking()
            except asyncio.TimeoutError:
                if not self._user_speaking:
                    await self._push_aggregation()

                # If we are emulating VAD we still need to send the user stopped
                # speaking frame.
                if self._emulating_vad:
                    await self.push_frame(
                        EmulateUserStoppedSpeakingFrame(), FrameDirection.UPSTREAM
                    )
                    self._emulating_vad = False
            finally:
                self.reset_watchdog()
                self._aggregation_event.clear()

    async def _maybe_emulate_user_speaking(self):
        """Maybe emulate user speaking based on transcription.

        Emulate user speaking if we got a transcription but it was not
        detected by VAD. Behavior when bot is speaking depends on the
        enable_emulated_vad_interruptions parameter.
        """
        # Check if we received a transcription but VAD was not able to detect
        # voice (e.g. when you whisper a short utterance). In that case, we need
        # to emulate VAD (i.e. user start/stopped speaking), but we do it only
        # if the bot is not speaking. If the bot is speaking and we really have
        # a short utterance we don't really want to interrupt the bot.
        if (
            not self._user_speaking
            and not self._waiting_for_aggregation
            and len(self._aggregation) > 0
        ):
            if self._bot_speaking and not self._params.enable_emulated_vad_interruptions:
                # If emulated VAD interruptions are disabled and bot is speaking, ignore
                logger.debug("Ignoring user speaking emulation, bot is speaking.")
                await self.reset()
            else:
                # Either bot is not speaking, or emulated VAD interruptions are enabled
                # - trigger user speaking emulation.
                await self.push_frame(EmulateUserStartedSpeakingFrame(), FrameDirection.UPSTREAM)
                self._emulating_vad = True
|
||||
|
||||
|
||||
# NOTE: the "universal" suffix is just meant to distinguish this aggregator
|
||||
# from the old LLMAssistantContextAggregator while we gradually migrate services
|
||||
# to use the new universal LLMContext and associated patterns. The suffix will
|
||||
# go away once the migration is complete and the other
|
||||
# LLMAssistantContextAggregator is deprecated.
|
||||
class LLMAssistantAggregator(LLMContextAggregator):
|
||||
"""Assistant LLM aggregator that processes bot responses and function calls.
|
||||
|
||||
This aggregator handles the complex logic of processing assistant responses including:
|
||||
|
||||
- Text frame aggregation between response start/end markers
|
||||
- Function call lifecycle management
|
||||
- Context updates with timestamps
|
||||
- Tool execution and result handling
|
||||
- Interruption handling during responses
|
||||
|
||||
The aggregator manages function calls in progress and coordinates between
|
||||
text generation and tool execution phases of LLM responses.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    context: LLMContext,
    *,
    params: Optional[LLMAssistantAggregatorParams] = None,
    **kwargs,
):
    """Initialize the assistant context aggregator.

    Args:
        context: The universal LLM context for conversation storage.
        params: Configuration parameters for aggregation behavior.
        **kwargs: Additional arguments. Supports deprecated 'expect_stripped_words'.
    """
    super().__init__(context=context, role="assistant", **kwargs)
    self._params = params or LLMAssistantAggregatorParams()

    if "expect_stripped_words" in kwargs:
        import warnings

        with warnings.catch_warnings():
            warnings.simplefilter("always")
            warnings.warn(
                "Parameter 'expect_stripped_words' is deprecated, use 'params' instead.",
                DeprecationWarning,
                # Attribute the warning to the code constructing the
                # aggregator rather than to this line.
                stacklevel=2,
            )

        self._params.expect_stripped_words = kwargs["expect_stripped_words"]

    self._started = 0
    # Keyed by string; a value may be None (see the Optional in the annotation).
    self._function_calls_in_progress: Dict[str, Optional[FunctionCallInProgressFrame]] = {}
    self._context_updated_tasks: Set[asyncio.Task] = set()
|
||||
|
||||
@property
def has_function_calls_in_progress(self) -> bool:
    """Tell whether any function call is still being tracked.

    Returns:
        True when the in-progress function call table is non-empty,
        False when it is empty.
    """
    pending_calls = self._function_calls_in_progress
    return len(pending_calls) > 0
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process frames for assistant response aggregation and function call management.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction of frame flow in the pipeline.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
await self._handle_interruptions(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, LLMFullResponseStartFrame):
|
||||
await self._handle_llm_start(frame)
|
||||
elif isinstance(frame, LLMFullResponseEndFrame):
|
||||
await self._handle_llm_end(frame)
|
||||
elif isinstance(frame, TextFrame):
|
||||
await self._handle_text(frame)
|
||||
elif isinstance(frame, LLMMessagesAppendFrame):
|
||||
await self._handle_llm_messages_append(frame)
|
||||
elif isinstance(frame, LLMMessagesUpdateFrame):
|
||||
await self._handle_llm_messages_update(frame)
|
||||
elif isinstance(frame, LLMSetToolsFrame):
|
||||
self.set_tools(frame.tools)
|
||||
elif isinstance(frame, LLMSetToolChoiceFrame):
|
||||
self.set_tool_choice(frame.tool_choice)
|
||||
elif isinstance(frame, FunctionCallsStartedFrame):
|
||||
await self._handle_function_calls_started(frame)
|
||||
elif isinstance(frame, FunctionCallInProgressFrame):
|
||||
await self._handle_function_call_in_progress(frame)
|
||||
elif isinstance(frame, FunctionCallResultFrame):
|
||||
await self._handle_function_call_result(frame)
|
||||
elif isinstance(frame, FunctionCallCancelFrame):
|
||||
await self._handle_function_call_cancel(frame)
|
||||
elif isinstance(frame, UserImageRawFrame) and frame.request and frame.request.tool_call_id:
|
||||
await self._handle_user_image_frame(frame)
|
||||
elif isinstance(frame, BotStoppedSpeakingFrame):
|
||||
await self._push_aggregation()
|
||||
await self.push_frame(frame, direction)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _push_aggregation(self):
|
||||
"""Push the current assistant aggregation with timestamp."""
|
||||
if not self._aggregation:
|
||||
return
|
||||
|
||||
aggregation = self._aggregation.strip()
|
||||
await self.reset()
|
||||
|
||||
if aggregation:
|
||||
self._context.add_message({"role": "assistant", "content": aggregation})
|
||||
|
||||
# Push context frame
|
||||
await self.push_context_frame()
|
||||
|
||||
# Push timestamp frame with current time
|
||||
timestamp_frame = LLMContextAssistantTimestampFrame(timestamp=time_now_iso8601())
|
||||
await self.push_frame(timestamp_frame)
|
||||
|
||||
async def _handle_llm_messages_append(self, frame: LLMMessagesAppendFrame):
|
||||
self.add_messages(frame.messages)
|
||||
if frame.run_llm:
|
||||
await self.push_context_frame(FrameDirection.UPSTREAM)
|
||||
|
||||
async def _handle_llm_messages_update(self, frame: LLMMessagesUpdateFrame):
|
||||
self.set_messages(frame.messages)
|
||||
if frame.run_llm:
|
||||
await self.push_context_frame(FrameDirection.UPSTREAM)
|
||||
|
||||
async def _handle_interruptions(self, frame: StartInterruptionFrame):
|
||||
await self._push_aggregation()
|
||||
self._started = 0
|
||||
await self.reset()
|
||||
|
||||
async def _handle_function_calls_started(self, frame: FunctionCallsStartedFrame):
|
||||
function_names = [f"{f.function_name}:{f.tool_call_id}" for f in frame.function_calls]
|
||||
logger.debug(f"{self} FunctionCallsStartedFrame: {function_names}")
|
||||
for function_call in frame.function_calls:
|
||||
self._function_calls_in_progress[function_call.tool_call_id] = None
|
||||
|
||||
async def _handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame):
|
||||
logger.debug(
|
||||
f"{self} FunctionCallInProgressFrame: [{frame.function_name}:{frame.tool_call_id}]"
|
||||
)
|
||||
|
||||
# Update context with the in-progress function call
|
||||
self._context.add_message(
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": frame.tool_call_id,
|
||||
"function": {
|
||||
"name": frame.function_name,
|
||||
"arguments": json.dumps(frame.arguments),
|
||||
},
|
||||
"type": "function",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
self._context.add_message(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "IN_PROGRESS",
|
||||
"tool_call_id": frame.tool_call_id,
|
||||
}
|
||||
)
|
||||
|
||||
self._function_calls_in_progress[frame.tool_call_id] = frame
|
||||
|
||||
async def _handle_function_call_result(self, frame: FunctionCallResultFrame):
|
||||
logger.debug(
|
||||
f"{self} FunctionCallResultFrame: [{frame.function_name}:{frame.tool_call_id}]"
|
||||
)
|
||||
if frame.tool_call_id not in self._function_calls_in_progress:
|
||||
logger.warning(
|
||||
f"FunctionCallResultFrame tool_call_id [{frame.tool_call_id}] is not running"
|
||||
)
|
||||
return
|
||||
|
||||
del self._function_calls_in_progress[frame.tool_call_id]
|
||||
|
||||
properties = frame.properties
|
||||
|
||||
# Update context with the function call result
|
||||
if frame.result:
|
||||
result = json.dumps(frame.result)
|
||||
self._update_function_call_result(frame.function_name, frame.tool_call_id, result)
|
||||
else:
|
||||
self._update_function_call_result(frame.function_name, frame.tool_call_id, "COMPLETED")
|
||||
|
||||
run_llm = False
|
||||
|
||||
# Run inference if the function call result requires it.
|
||||
if frame.result:
|
||||
if properties and properties.run_llm is not None:
|
||||
# If the tool call result has a run_llm property, use it.
|
||||
run_llm = properties.run_llm
|
||||
elif frame.run_llm is not None:
|
||||
# If the frame is indicating we should run the LLM, do it.
|
||||
run_llm = frame.run_llm
|
||||
else:
|
||||
# If this is the last function call in progress, run the LLM.
|
||||
run_llm = not bool(self._function_calls_in_progress)
|
||||
|
||||
if run_llm:
|
||||
await self.push_context_frame(FrameDirection.UPSTREAM)
|
||||
|
||||
# Call the `on_context_updated` callback once the function call result
|
||||
# is added to the context. Also, run this in a separate task to make
|
||||
# sure we don't block the pipeline.
|
||||
if properties and properties.on_context_updated:
|
||||
task_name = f"{frame.function_name}:{frame.tool_call_id}:on_context_updated"
|
||||
task = self.create_task(properties.on_context_updated(), task_name)
|
||||
self._context_updated_tasks.add(task)
|
||||
task.add_done_callback(self._context_updated_task_finished)
|
||||
|
||||
async def _handle_function_call_cancel(self, frame: FunctionCallCancelFrame):
|
||||
logger.debug(
|
||||
f"{self} FunctionCallCancelFrame: [{frame.function_name}:{frame.tool_call_id}]"
|
||||
)
|
||||
if frame.tool_call_id not in self._function_calls_in_progress:
|
||||
return
|
||||
|
||||
if self._function_calls_in_progress[frame.tool_call_id].cancel_on_interruption:
|
||||
# Update context with the function call cancellation
|
||||
self._update_function_call_result(frame.function_name, frame.tool_call_id, "CANCELLED")
|
||||
del self._function_calls_in_progress[frame.tool_call_id]
|
||||
|
||||
def _update_function_call_result(self, function_name: str, tool_call_id: str, result: Any):
|
||||
for message in self._context.messages:
|
||||
if (
|
||||
message["role"] == "tool"
|
||||
and message["tool_call_id"]
|
||||
and message["tool_call_id"] == tool_call_id
|
||||
):
|
||||
message["content"] = result
|
||||
|
||||
async def _handle_user_image_frame(self, frame: UserImageRawFrame):
|
||||
logger.debug(
|
||||
f"{self} UserImageRawFrame: [{frame.request.function_name}:{frame.request.tool_call_id}]"
|
||||
)
|
||||
|
||||
if frame.request.tool_call_id not in self._function_calls_in_progress:
|
||||
logger.warning(
|
||||
f"UserImageRawFrame tool_call_id [{frame.request.tool_call_id}] is not running"
|
||||
)
|
||||
return
|
||||
|
||||
del self._function_calls_in_progress[frame.request.tool_call_id]
|
||||
|
||||
# Update context with the image frame
|
||||
await self._update_function_call_result(
|
||||
frame.request.function_name, frame.request.tool_call_id, "COMPLETED"
|
||||
)
|
||||
self._context.add_image_frame_message(
|
||||
format=frame.format,
|
||||
size=frame.size,
|
||||
image=frame.image,
|
||||
text=frame.request.context,
|
||||
)
|
||||
|
||||
await self._push_aggregation()
|
||||
await self.push_context_frame(FrameDirection.UPSTREAM)
|
||||
|
||||
async def _handle_llm_start(self, _: LLMFullResponseStartFrame):
|
||||
self._started += 1
|
||||
|
||||
async def _handle_llm_end(self, _: LLMFullResponseEndFrame):
|
||||
self._started -= 1
|
||||
await self._push_aggregation()
|
||||
|
||||
async def _handle_text(self, frame: TextFrame):
|
||||
if not self._started:
|
||||
return
|
||||
|
||||
if self._params.expect_stripped_words:
|
||||
self._aggregation += f" {frame.text}" if self._aggregation else frame.text
|
||||
else:
|
||||
self._aggregation += frame.text
|
||||
|
||||
def _context_updated_task_finished(self, task: asyncio.Task):
|
||||
self._context_updated_tasks.discard(task)
|
||||
# The task is finished so this should exit immediately. We need to do
|
||||
# this because otherwise the task manager would report a dangling task
|
||||
# if we don't remove it.
|
||||
asyncio.run_coroutine_threadsafe(self.wait_for_task(task), self.get_event_loop())
|
||||
|
||||
|
||||
@dataclass
class LLMContextAggregatorPair:
    """Pair of LLM context aggregators for user and assistant messages.

    Parameters:
        _user: User context aggregator for processing user messages.
        _assistant: Assistant context aggregator for processing assistant messages.
    """

    _user: LLMUserAggregator
    _assistant: LLMAssistantAggregator

    @staticmethod
    def create(
        context: LLMContext,
        *,
        user_params: Optional[LLMUserAggregatorParams] = None,
        assistant_params: Optional[LLMAssistantAggregatorParams] = None,
    ) -> "LLMContextAggregatorPair":
        """Factory method to create an LLMContextAggregatorPair.

        Both aggregators are built around the same ``context`` object.

        Args:
            context: The context managed by the aggregators.
            user_params: Parameters for the user context aggregator. A fresh
                ``LLMUserAggregatorParams()`` is created when omitted.
            assistant_params: Parameters for the assistant context aggregator.
                A fresh ``LLMAssistantAggregatorParams()`` is created when
                omitted.

        Returns:
            LLMContextAggregatorPair: A new instance with configured aggregators.
        """
        # Build default params per call. Default argument values are evaluated
        # once at function definition, so using instances as defaults would
        # share (and leak mutations of) one params object across every pair.
        if user_params is None:
            user_params = LLMUserAggregatorParams()
        if assistant_params is None:
            assistant_params = LLMAssistantAggregatorParams()

        user = LLMUserAggregator(context, params=user_params)
        assistant = LLMAssistantAggregator(context, params=assistant_params)
        return LLMContextAggregatorPair(_user=user, _assistant=assistant)

    def user(self) -> LLMUserAggregator:
        """Get the user context aggregator.

        Returns:
            The user context aggregator instance.
        """
        return self._user

    def assistant(self) -> LLMAssistantAggregator:
        """Get the assistant context aggregator.

        Returns:
            The assistant context aggregator instance.
        """
        return self._assistant
|
||||
Reference in New Issue
Block a user