Compare commits

...

12 Commits

Author SHA1 Message Date
Paul Kompfner
adff456daf Progress on LLM failover support 2025-07-28 10:09:23 -04:00
Paul Kompfner
3785481a45 Progress on LLM failover support 2025-07-25 16:55:17 -04:00
Paul Kompfner
ed2177a579 Progress on LLM failover support 2025-07-25 16:36:32 -04:00
Paul Kompfner
ee2aade12c Progress on LLM failover support 2025-07-25 14:56:38 -04:00
Paul Kompfner
602724b984 Progress on LLM failover support 2025-07-24 16:35:51 -04:00
Paul Kompfner
c437ff6a08 Progress on LLM failover support 2025-07-24 09:51:30 -04:00
Paul Kompfner
221e199fe0 Progress on LLM failover support 2025-07-24 09:43:53 -04:00
Paul Kompfner
35628f3af7 Progress on LLM failover support
Rename new `LLMUser/AssistantContextAggregator`s, adding a `_Universal` suffix, allowing old ones to be used while we migrate services gradually to use new universal `LLMContext` and associated patterns.
2025-07-24 09:40:28 -04:00
Paul Kompfner
1de3c9d5fd Progress on LLM failover support 2025-07-24 09:20:28 -04:00
Paul Kompfner
d5d7ee9803 Progress on LLM failover support 2025-07-23 15:42:30 -04:00
Paul Kompfner
e651f1e4df Progress on LLM failover support 2025-07-23 14:42:19 -04:00
Paul Kompfner
36fea8f9e8 Progress on LLM failover support 2025-07-23 14:34:17 -04:00
13 changed files with 1589 additions and 879 deletions

View File

@@ -11,21 +11,45 @@ adapters that handle tool format conversion and standardization.
"""
from abc import ABC, abstractmethod
from typing import Any, List, Union, cast
from typing import Any, Generic, List, TypedDict, TypeVar, Union, cast
from loguru import logger
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.processors.aggregators.llm_context import LLMContext
# Should be a TypedDict
TLLMInvocationParams = TypeVar("TLLMInvocationParams", bound=dict[str, Any])
class BaseLLMAdapter(ABC):
# TODO: fix everywhere we subclass BaseLLMAdapter...
class BaseLLMAdapter(ABC, Generic[TLLMInvocationParams]):
"""Abstract base class for LLM provider adapters.
Provides a standard interface for converting between Pipecat's standardized
tool schemas and provider-specific tool formats. Subclasses must implement
provider-specific conversion logic.
Provides a standard interface for converting to provider-specific formats.
Handles:
- Extracting provider-specific parameters for LLM invocation from a
universal LLM context
- Converting standardized tools schema to provider-specific tool formats.
- Extracting messages from the LLM context for the purposes of logging
about the specific provider.
Subclasses must implement provider-specific conversion logic.
"""
@abstractmethod
def get_llm_invocation_params(self, context: LLMContext) -> TLLMInvocationParams:
    """Build the provider-specific parameters for invoking the LLM.

    Abstract hook: the base class supplies no behavior, so every adapter
    subclass has to provide its own translation.

    Args:
        context: The universal LLM context (messages, tools, etc.) to
            translate.

    Returns:
        The invocation parameters in the provider's own shape.
    """
    ...
@abstractmethod
def to_provider_tools_format(self, tools_schema: ToolsSchema) -> List[Any]:
"""Convert tools schema to the provider's specific format.
@@ -38,6 +62,20 @@ class BaseLLMAdapter(ABC):
"""
pass
@abstractmethod
def get_messages_for_logging(self, context: LLMContext) -> List[dict[str, Any]]:
    """Produce a loggable view of the context's messages for this provider.

    Abstract hook: the base class supplies no behavior, so every adapter
    subclass has to provide its own implementation.

    Args:
        context: The LLM context whose messages should be logged.

    Returns:
        The messages as a list, in a form ready to be logged for this
        provider.
    """
    ...
# TODO: should this also be able to return NotGiven?
def from_standard_tools(self, tools: Any) -> List[Any]:
"""Convert tools from standard format to provider format.
@@ -54,4 +92,38 @@ class BaseLLMAdapter(ABC):
# Fallback to return the same tools in case they are not in a standard format
return tools
def create_wav_header(self, sample_rate, num_channels, bits_per_sample, data_size):
    """Build the 44-byte RIFF/WAVE header that precedes raw PCM audio.

    Args:
        sample_rate: Audio sample rate in Hz.
        num_channels: Number of audio channels.
        bits_per_sample: Bits per audio sample.
        data_size: Size of the audio payload in bytes.

    Returns:
        The WAV header as a bytearray.
    """
    # Derived "fmt " fields: bytes per sample frame and bytes per second.
    block_align = num_channels * (bits_per_sample // 8)
    byte_rate = sample_rate * block_align

    # Every integer field is little-endian, as the RIFF container requires.
    fields = (
        # RIFF chunk descriptor
        b"RIFF",  # ChunkID
        (data_size + 36).to_bytes(4, "little"),  # ChunkSize: total size - 8
        b"WAVE",  # Format
        # "fmt " sub-chunk
        b"fmt ",  # Subchunk1ID
        (16).to_bytes(4, "little"),  # Subchunk1Size (16 for PCM)
        (1).to_bytes(2, "little"),  # AudioFormat (1 for PCM)
        num_channels.to_bytes(2, "little"),  # NumChannels
        sample_rate.to_bytes(4, "little"),  # SampleRate
        byte_rate.to_bytes(4, "little"),  # ByteRate
        block_align.to_bytes(2, "little"),  # BlockAlign
        bits_per_sample.to_bytes(2, "little"),  # BitsPerSample
        # "data" sub-chunk
        b"data",  # Subchunk2ID
        data_size.to_bytes(4, "little"),  # Subchunk2Size
    )
    return bytearray(b"".join(fields))
# TODO: we can move the logic to also handle the Messages here

View File

@@ -6,21 +6,68 @@
"""Gemini LLM adapter for Pipecat."""
from typing import Any, Dict, List, Union
import base64
import json
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, TypedDict, Union
from loguru import logger
from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
from pipecat.adapters.schemas.tools_schema import AdapterType, ToolsSchema
from pipecat.processors.aggregators.llm_context import LLMContext, LLMContextMessage
try:
from google.genai.types import (
Blob,
Content,
FunctionCall,
FunctionResponse,
Part,
)
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use Google AI, you need to `pip install pipecat-ai[google]`.")
raise Exception(f"Missing module: {e}")
class GeminiLLMAdapter(BaseLLMAdapter):
"""LLM adapter for Google's Gemini service.
class GeminiLLMInvocationParams(TypedDict):
"""Context-based parameters for invoking Gemini LLM."""
Provides tool schema conversion functionality to transform standard tool
definitions into Gemini's specific function-calling format for use with
Gemini LLM models.
system_instruction: Optional[str]
messages: List[Content]
tools: List[Any]
class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
"""Gemini-specific adapter for Pipecat.
Handles:
- Extracting parameters for Gemini's API from a universal
LLM context
- Converting Pipecat's standardized tools schema to Gemini's function-calling format.
- Extracting and sanitizing messages from the LLM context for logging with Gemini.
"""
def to_provider_tools_format(self, tools_schema: ToolsSchema) -> List[Dict[str, Any]]:
def get_llm_invocation_params(self, context: LLMContext) -> GeminiLLMInvocationParams:
    """Get Gemini-specific LLM invocation parameters from a universal LLM context.

    Converts the context's messages once and reuses the result for both the
    system instruction and the message list; tools go through
    ``from_standard_tools``.

    Args:
        context: The LLM context containing messages, tools, etc.

    Returns:
        Dictionary with the ``system_instruction``, ``messages`` and
        ``tools`` entries for Gemini's API.
    """
    # The temporary debugging print that used to live here was removed:
    # library code should not write to stdout on every LLM invocation.
    converted = self._from_standard_messages(context.messages)
    return {
        "system_instruction": converted.system_instruction,
        "messages": converted.messages,
        "tools": self.from_standard_tools(context.tools),
    }
def to_provider_tools_format(self, tools_schema: ToolsSchema) -> List[Any]:
"""Convert tool schemas to Gemini's function-calling format.
Args:
@@ -39,3 +86,217 @@ class GeminiLLMAdapter(BaseLLMAdapter):
custom_gemini_tools = tools_schema.custom_tools.get(AdapterType.GEMINI, [])
return formatted_standard_tools + custom_gemini_tools
def get_messages_for_logging(self, context: LLMContext) -> List[dict[str, Any]]:
    """Return the context's messages in Gemini format, sanitized for logging.

    Each converted message is turned into a JSON-style dict and the payload
    of every ``inline_data`` part is replaced by ``"..."`` so binary content
    does not end up in logs.

    Args:
        context: The LLM context containing messages.

    Returns:
        List of message dicts ready for logging about Gemini.
    """
    converted = self._from_standard_messages(context.messages)

    sanitized = []
    for content in converted.messages:
        as_dict = content.to_json_dict()
        try:
            for part in as_dict["parts"] if "parts" in as_dict else ():
                if "inline_data" in part:
                    part["inline_data"]["data"] = "..."
        except Exception as e:
            # Best effort: a message that cannot be sanitized is still logged.
            logger.debug(f"Error: {e}")
        sanitized.append(as_dict)
    return sanitized
@dataclass
class ConvertedMessages:
    """Result of converting universal-context messages for Gemini.

    Parameters:
        messages: The conversation messages as Gemini ``Content`` objects.
        system_instruction: The system instruction pulled out of the
            messages, or None when there was none.
    """

    messages: List[Content]
    system_instruction: Optional[str] = None
def _from_standard_messages(
    self, standard_messages: List[LLMContextMessage]
) -> ConvertedMessages:
    """Convert universal-context messages into Gemini messages plus system instruction.

    Messages that are already ``Content`` objects are kept as they are; all
    others are converted with ``_from_standard_message``. A converted value
    that is not a ``Content`` is treated as the system instruction; if
    several system messages are present, the last one wins.

    When none of the resulting messages is a plain text message (exactly one
    part, carrying text and neither a function call nor a function
    response), the system instruction is additionally appended as a trailing
    user message. Messages without parts are dropped from the result.

    Args:
        standard_messages: Messages from the universal LLM context.

    Returns:
        A ConvertedMessages holding the Gemini-formatted messages and the
        extracted system instruction (None if there was none).
    """
    system_instruction = None
    messages = []

    # Process each message, preserving Google-formatted messages and converting others
    for message in standard_messages:
        if isinstance(message, Content):
            # Keep existing Google-formatted messages (e.g., function calls/responses)
            # TODO: this branch is probably not needed anymore, since LLMContext contains a universal format
            messages.append(message)
            continue

        # Convert standard format to Google format
        converted = self._from_standard_message(message)
        if isinstance(converted, Content):
            # Regular (non-system) message
            messages.append(converted)
        else:
            # System instruction
            system_instruction = converted

    # Check if we only have function-related messages (no regular text).
    # `msg.parts` is checked for truthiness first because a passed-through
    # Content may carry no parts at all (None), and len(None) would raise
    # before the empty-message filter below gets a chance to drop it.
    has_regular_messages = any(
        msg.parts
        and len(msg.parts) == 1
        and getattr(msg.parts[0], "text", None)
        and not getattr(msg.parts[0], "function_call", None)
        and not getattr(msg.parts[0], "function_response", None)
        for msg in messages
    )

    # Add system instruction back as a user message if we only have function messages
    if system_instruction and not has_regular_messages:
        messages.append(Content(role="user", parts=[Part(text=system_instruction)]))

    # Remove any empty messages
    messages = [m for m in messages if m.parts]

    return self.ConvertedMessages(messages=messages, system_instruction=system_instruction)
def _from_standard_message(self, message: LLMContextMessage) -> Content | str:
    """Translate one universal-format message into Gemini's representation.

    System messages are not wrapped: their content is returned unchanged.
    Every other message becomes a ``Content`` whose parts are built from
    tool calls, a tool result, plain text, or a list of text / image /
    audio items. The "assistant" role maps to "model", and tool results are
    also emitted under the "model" role.

    Args:
        message: Message in standard format.

    Returns:
        Content object with role and parts, or the raw content for system
        messages.

    Examples:
        Standard text message::

            {
                "role": "user",
                "content": "Hello there"
            }

        Converts to Google Content with::

            Content(
                role="user",
                parts=[Part(text="Hello there")]
            )

        Standard function call message::

            {
                "role": "assistant",
                "tool_calls": [
                    {
                        "function": {
                            "name": "search",
                            "arguments": '{"query": "test"}'
                        }
                    }
                ]
            }

        Converts to Google Content with::

            Content(
                role="model",
                parts=[Part(function_call=FunctionCall(name="search", args={"query": "test"}))]
            )
    """
    role = message["role"]
    content = message.get("content", [])

    if role == "system":
        # System instructions are returned as plain text
        # TODO: here we've always assumed that system instructions are plain text...is that a safe assumption?
        return content
    if role == "assistant":
        role = "model"

    parts = []
    if message.get("tool_calls"):
        # One function-call part per requested tool call; arguments arrive
        # as a JSON string and are decoded into a dict.
        parts.extend(
            Part(
                function_call=FunctionCall(
                    name=call["function"]["name"],
                    args=json.loads(call["function"]["arguments"]),
                )
            )
            for call in message["tool_calls"]
        )
    elif role == "tool":
        role = "model"
        tool_result = FunctionResponse(
            name="tool_call_result",  # seems to work to hard-code the same name every time
            response=json.loads(message["content"]),
        )
        parts.append(Part(function_response=tool_result))
    elif isinstance(content, str):
        parts.append(Part(text=content))
    elif isinstance(content, list):
        for item in content:
            kind = item["type"]
            if kind == "text":
                parts.append(Part(text=item["text"]))
            elif kind == "image_url":
                # The URL is a data URL; the base64 payload follows the first comma.
                encoded = item["image_url"]["url"].split(",")[1]
                image_blob = Blob(mime_type="image/jpeg", data=base64.b64decode(encoded))
                parts.append(Part(inline_data=image_blob))
            elif kind == "input_audio":
                audio = item["input_audio"]
                # Raw audio gets a WAV header (16 bits per sample) prepended.
                header = self.create_wav_header(
                    audio["sample_rate"],
                    audio["num_channels"],
                    16,
                    len(audio["data"]),
                )
                audio_blob = Blob(mime_type="audio/wav", data=bytes(header + audio["data"]))
                parts.append(Part(inline_data=audio_blob))

    return Content(role=role, parts=parts)

View File

@@ -6,22 +6,62 @@
"""OpenAI LLM adapter for Pipecat."""
from typing import List
import copy
import json
from typing import Any, List, TypedDict
from openai.types.chat import ChatCompletionToolParam
from openai._types import NOT_GIVEN as OPEN_AI_NOT_GIVEN
from openai._types import NotGiven as OpenAINotGiven
from openai.types.chat import (
ChatCompletionMessageParam,
ChatCompletionToolChoiceOptionParam,
ChatCompletionToolParam,
)
from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.processors.aggregators.llm_context import (
LLMContext,
LLMContextMessage,
LLMContextToolChoice,
NotGiven,
)
class OpenAILLMAdapter(BaseLLMAdapter):
"""Adapter for converting tool schemas to OpenAI's format.
class OpenAILLMInvocationParams(TypedDict):
"""Context-based parameters for invoking OpenAI ChatCompletion API."""
Provides conversion utilities for transforming Pipecat's standard tool
schemas into the format expected by OpenAI's ChatCompletion API for
function calling capabilities.
messages: List[ChatCompletionMessageParam]
tools: List[ChatCompletionToolParam] | OpenAINotGiven
tool_choice: ChatCompletionToolChoiceOptionParam | OpenAINotGiven
class OpenAILLMAdapter(BaseLLMAdapter[OpenAILLMInvocationParams]):
"""OpenAI-specific adapter for Pipecat.
Handles:
- Extracting parameters for OpenAI's ChatCompletion API from a universal
LLM context
- Converting Pipecat's standardized tools schema to OpenAI's function-calling format.
- Extracting and sanitizing messages from the LLM context for logging with OpenAI.
"""
def get_llm_invocation_params(self, context: LLMContext) -> OpenAILLMInvocationParams:
    """Get OpenAI-specific LLM invocation parameters from a universal LLM context.

    Messages, tools and tool choice are each routed through this adapter's
    conversion hook, so all three parameters are produced the same way.

    Args:
        context: The LLM context containing messages, tools, etc.

    Returns:
        Dictionary with the ``messages``, ``tools`` and ``tool_choice``
        entries for OpenAI's ChatCompletion API.
    """
    return {
        "messages": self._from_standard_messages(context.messages),
        # TODO: doesn't seem right that we may or may not need to convert tools here; they should already be guaranteed to exist in a universal format in the LLMContext, right?
        "tools": self.from_standard_tools(context.tools),
        # Use the dedicated hook (currently a pass-through) rather than
        # reading the context value directly, mirroring messages above.
        "tool_choice": self._from_standard_tool_choice(context.tool_choice),
    }
def to_provider_tools_format(self, tools_schema: ToolsSchema) -> List[ChatCompletionToolParam]:
"""Convert function schemas to OpenAI's function-calling format.
@@ -37,3 +77,40 @@ class OpenAILLMAdapter(BaseLLMAdapter):
ChatCompletionToolParam(type="function", function=func.to_default_dict())
for func in functions_schema
]
def get_messages_for_logging(self, context: LLMContext) -> List[dict[str, Any]]:
    """Get messages from the LLM context in a format ready for logging about OpenAI.

    Works on deep copies, so the context's own messages are never modified.
    Inline image data URLs and raw audio payloads are replaced with short
    placeholders for safe logging.

    Args:
        context: The LLM context containing messages.

    Returns:
        List of sanitized message dicts ready for logging about OpenAI.
    """
    sanitized = []
    for message in context.messages:
        msg = copy.deepcopy(message)
        content = msg.get("content")
        if isinstance(content, list):
            for item in content:
                item_type = item.get("type")
                if item_type == "image_url":
                    if item["image_url"]["url"].startswith("data:image/"):
                        item["image_url"]["url"] = "data:image/..."
                elif item_type == "input_audio":
                    # Audio is stored as raw bytes in the universal context:
                    # large, and not JSON-serializable if the log is dumped.
                    item["input_audio"]["data"] = "..."
        if "mime_type" in msg and msg["mime_type"].startswith("image/"):
            msg["data"] = "..."
        sanitized.append(msg)
    # Return the list itself (as annotated and as the base class declares),
    # not a JSON string; callers can serialize it if they need to.
    return sanitized
def _from_standard_messages(
self, messages: List[LLMContextMessage]
) -> List[ChatCompletionMessageParam]:
# Just a pass-through: messages is already the right type
return messages
def _from_standard_tool_choice(
self, tool_choice: LLMContextToolChoice | NotGiven
) -> ChatCompletionToolChoiceOptionParam | OpenAINotGiven:
# Just a pass-through: tool_choice is already the right type
return tool_choice

View File

@@ -378,7 +378,7 @@ class TranslationFrame(TextFrame):
@dataclass
class OpenAILLMContextAssistantTimestampFrame(DataFrame):
class LLMContextAssistantTimestampFrame(DataFrame):
"""Timestamp information for assistant messages in LLM context.
Parameters:

View File

@@ -0,0 +1,211 @@
#
# Copyright (c) 2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Universal LLM context management for LLM services in Pipecat.
Context contents are represented in a generic format (extended from OpenAI)
that supports a union of known Pipecat LLM service functionality.
Whenever an LLM service needs to access context, it does a just-in-time
translation from this universal context into whatever format it needs, using a
service-specific adapter.
"""
import base64
import io
from dataclasses import dataclass
from typing import Any, List, Optional
from openai._types import NOT_GIVEN as OPEN_AI_NOT_GIVEN
from openai._types import NotGiven as OpenAINotGiven
from openai.types.chat import (
ChatCompletionMessageParam,
ChatCompletionToolChoiceOptionParam,
ChatCompletionToolParam,
)
from PIL import Image
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.frames.frames import AudioRawFrame, Frame
# "Re-export" types from OpenAI that we're using as universal context types.
# NOTE: this is just for convenience, for now. As soon as the universal types
# diverge from OpenAI's, we should ditch this. In fact, audio frames already
# diverge from OpenAI's standard format...we really ought to do this.
LLMContextMessage = ChatCompletionMessageParam
LLMContextTool = ChatCompletionToolParam
LLMContextToolChoice = ChatCompletionToolChoiceOptionParam
NOT_GIVEN = OPEN_AI_NOT_GIVEN
NotGiven = OpenAINotGiven
class LLMContext:
    """Manages conversation context for LLM interactions.

    Handles message history, tool definitions, tool choices, and multimedia
    content for LLM conversations. Provides methods for message manipulation,
    and content formatting.
    """

    def __init__(
        self,
        messages: Optional[List[LLMContextMessage]] = None,
        tools: List[LLMContextTool] | NotGiven | ToolsSchema = NOT_GIVEN,
        tool_choice: LLMContextToolChoice | NotGiven = NOT_GIVEN,
    ):
        """Initialize the LLM context.

        Args:
            messages: Initial list of conversation messages.
            tools: Available tools for the LLM to use. An empty list is
                stored as NOT_GIVEN, exactly as ``set_tools`` does.
            tool_choice: Tool selection strategy for the LLM.
        """
        self._messages: List[LLMContextMessage] = messages if messages else []
        self._tools: List[LLMContextTool] | NotGiven | ToolsSchema = self._normalize_tools(tools)
        self._tool_choice: LLMContextToolChoice | NotGiven = tool_choice

    @staticmethod
    def _normalize_tools(
        tools: List[LLMContextTool] | NotGiven | ToolsSchema,
    ) -> List[LLMContextTool] | NotGiven | ToolsSchema:
        """Map an empty tools list to NOT_GIVEN; return anything else unchanged."""
        # TODO: convert empty ToolsSchema to NOT_GIVEN if needed?
        # TODO: maybe someday also convert provider-specific tools to ToolsSchema so it's always in a provider-neutral format here? See open_ai_adapter.py for related comment. Pipecat Flows is currently converting provider-specific tools to ToolsSchema...
        if isinstance(tools, list) and len(tools) == 0:
            return NOT_GIVEN
        return tools

    @property
    def messages(self) -> List[LLMContextMessage]:
        """Get the current messages list.

        Returns:
            List of conversation messages.
        """
        return self._messages

    @property
    def tools(self) -> List[LLMContextTool] | NotGiven | ToolsSchema:
        """Get the tools as stored.

        Returns:
            The tools list, a ToolsSchema, or NOT_GIVEN when no tools are set.
        """
        return self._tools

    @property
    def tool_choice(self) -> LLMContextToolChoice | NotGiven:
        """Get the current tool choice setting.

        Returns:
            The tool choice configuration.
        """
        return self._tool_choice

    def add_message(self, message: LLMContextMessage):
        """Add a single message to the context.

        Args:
            message: The message to add to the conversation history.
        """
        self._messages.append(message)

    def add_messages(self, messages: List[LLMContextMessage]):
        """Add multiple messages to the context.

        Args:
            messages: List of messages to add to the conversation history.
        """
        self._messages.extend(messages)

    def set_messages(self, messages: List[LLMContextMessage]):
        """Replace all messages in the context.

        The existing list object is updated in place, so references to it
        obtained earlier observe the new contents.

        Args:
            messages: New list of messages to replace the current history.
        """
        self._messages[:] = messages

    def set_tools(self, tools: List[LLMContextTool] | NotGiven | ToolsSchema = NOT_GIVEN):
        """Set the available tools for the LLM.

        Args:
            tools: List of tools available to the LLM, a ToolsSchema, or NOT_GIVEN to disable tools.
        """
        self._tools = self._normalize_tools(tools)

    def set_tool_choice(self, tool_choice: LLMContextToolChoice | NotGiven):
        """Set the tool choice configuration.

        Args:
            tool_choice: Tool selection strategy for the LLM.
        """
        self._tool_choice = tool_choice

    def add_image_frame_message(
        self, *, format: str, size: tuple[int, int], image: bytes, text: Optional[str] = None
    ):
        """Add a user message containing an image frame.

        The raw image is re-encoded as JPEG and embedded as a base64 data URL.

        Args:
            format: Image format (e.g., 'RGB', 'RGBA').
            size: Image dimensions as (width, height) tuple.
            image: Raw image bytes.
            text: Optional text to include with the image.
        """
        buffer = io.BytesIO()
        pil_image = Image.frombytes(format, size, image)
        # JPEG cannot store an alpha channel (saving an 'RGBA' image raises),
        # so anything outside the modes JPEG can write is converted to RGB.
        if pil_image.mode not in ("L", "RGB", "CMYK"):
            pil_image = pil_image.convert("RGB")
        pil_image.save(buffer, format="JPEG")
        # TODO: we might not want the universal format to be base64 encoded, since encoding is not needed by all LLM services; today, the Gemini adapter has to decode from base64, which is less than ideal.
        encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")

        content = []
        if text:
            content.append({"type": "text", "text": text})
        content.append(
            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}},
        )
        self.add_message({"role": "user", "content": content})

    # NOTE: today we've only built support for audio frames with the Google
    # LLM, so this "universal" representation skews towards that.
    # When we add support for other LLMs, we may need to adjust this.
    def add_audio_frames_message(
        self, *, audio_frames: list[AudioRawFrame], text: str = "Audio follows"
    ):
        """Add a user message containing audio frames.

        Does nothing when ``audio_frames`` is empty. Sample rate and channel
        count are taken from the first frame; the audio of all frames is
        concatenated into a single raw byte payload.

        Args:
            audio_frames: List of audio frame objects to include.
            text: Text to include ahead of the audio.
        """
        if not audio_frames:
            return

        sample_rate = audio_frames[0].sample_rate
        num_channels = audio_frames[0].num_channels

        content = []
        content.append({"type": "text", "text": text})
        data = b"".join(frame.audio for frame in audio_frames)
        # TODO: filter this out in OpenAI adapter, since it doesn't support audio frames
        content.append(
            {
                "type": "input_audio",
                "input_audio": {
                    "data": data,
                    "sample_rate": sample_rate,
                    "num_channels": num_channels,
                },
            }
        )
        self.add_message({"role": "user", "content": content})
@dataclass
class LLMContextFrame(Frame):
    """Frame that carries an LLM context through the pipeline.

    LLM services treat it as the signal to ingest the attached context and
    generate a response based on it.

    Parameters:
        context: The LLM context containing messages, tools, and configuration.
    """

    context: LLMContext

View File

@@ -36,6 +36,7 @@ from pipecat.frames.frames import (
FunctionCallsStartedFrame,
InputAudioRawFrame,
InterimTranscriptionFrame,
LLMContextAssistantTimestampFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMMessagesAppendFrame,
@@ -44,7 +45,6 @@ from pipecat.frames.frames import (
LLMSetToolChoiceFrame,
LLMSetToolsFrame,
LLMTextFrame,
OpenAILLMContextAssistantTimestampFrame,
SpeechControlParamsFrame,
StartFrame,
StartInterruptionFrame,
@@ -864,7 +864,7 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
await self.push_context_frame()
# Push timestamp frame with current time
timestamp_frame = OpenAILLMContextAssistantTimestampFrame(timestamp=time_now_iso8601())
timestamp_frame = LLMContextAssistantTimestampFrame(timestamp=time_now_iso8601())
await self.push_frame(timestamp_frame)
async def _handle_llm_messages_append(self, frame: LLMMessagesAppendFrame):

View File

@@ -0,0 +1,874 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""LLM response aggregators for handling conversation context and message aggregation.
This module provides aggregators that process and accumulate LLM responses, user inputs,
and conversation context. These aggregators handle the flow between speech-to-text,
LLM processing, and text-to-speech components in conversational AI pipelines.
"""
import asyncio
import json
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional, Set
from loguru import logger
from pipecat.audio.interruptions.base_interruption_strategy import BaseInterruptionStrategy
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
from pipecat.audio.vad.vad_analyzer import VADParams
from pipecat.frames.frames import (
BotInterruptionFrame,
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
CancelFrame,
EmulateUserStartedSpeakingFrame,
EmulateUserStoppedSpeakingFrame,
EndFrame,
Frame,
FunctionCallCancelFrame,
FunctionCallInProgressFrame,
FunctionCallResultFrame,
FunctionCallsStartedFrame,
InputAudioRawFrame,
InterimTranscriptionFrame,
LLMContextAssistantTimestampFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMMessagesAppendFrame,
LLMMessagesUpdateFrame,
LLMSetToolChoiceFrame,
LLMSetToolsFrame,
SpeechControlParamsFrame,
StartFrame,
StartInterruptionFrame,
TextFrame,
TranscriptionFrame,
UserImageRawFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
)
from pipecat.processors.aggregators.llm_context import LLMContext, LLMContextFrame
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.utils.time import time_now_iso8601
@dataclass
class LLMUserContextAggregatorParams:
    """Tunable timeouts for LLM user context aggregation.

    Parameters:
        aggregation_timeout: Maximum time in seconds to wait for additional
            transcription content before pushing the aggregated result. Only
            relevant when the transcription is slow to arrive.
        turn_emulated_vad_timeout: Maximum time in seconds to wait for
            emulated VAD when using turn-based analysis. Applies when a
            transcription is received but VAD didn't detect speech (e.g.,
            whispered utterances).
    """

    aggregation_timeout: float = 0.5
    turn_emulated_vad_timeout: float = 0.8
@dataclass
class LLMAssistantContextAggregatorParams:
    """Settings for LLM assistant context aggregation.

    Parameters:
        expect_stripped_words: Whether text frames are expected to contain
            stripped words, which are handled by adding spaces between
            tokens.
    """

    expect_stripped_words: bool = True
class LLMContextAggregator(FrameProcessor):
    """Common base for aggregators that keep their conversation in an LLMContext.

    Holds the context, the role this aggregator speaks for, and the text
    aggregated so far, and knows how to wrap the context in an
    LLMContextFrame and push it through the pipeline.
    """

    def __init__(self, *, context: LLMContext, role: str, **kwargs):
        """Initialize the context aggregator.

        Args:
            context: The LLM context to use for conversation storage.
            role: The role this aggregator represents (e.g. "user", "assistant").
            **kwargs: Additional arguments passed to parent class.
        """
        super().__init__(**kwargs)
        self._context = context
        self._role = role
        # Text accumulated so far; cleared by reset().
        self._aggregation: str = ""

    @property
    def context(self):
        """Get the LLM context.

        Returns:
            The LLMContext instance used by this aggregator.
        """
        return self._context

    @property
    def role(self) -> str:
        """Get the role for this aggregator.

        Returns:
            The role string for this aggregator.
        """
        return self._role

    @property
    def messages(self) -> List[dict]:
        """Get messages from the LLM context.

        Returns:
            List of message dictionaries from the context.
        """
        return self._context.messages

    def get_context_frame(self) -> LLMContextFrame:
        """Wrap the current context in a new frame.

        Returns:
            LLMContextFrame containing the current context.
        """
        return LLMContextFrame(context=self._context)

    async def push_context_frame(self, direction: FrameDirection = FrameDirection.DOWNSTREAM):
        """Build a context frame and push it in the given direction.

        Args:
            direction: The direction to push the frame (upstream or downstream).
        """
        await self.push_frame(self.get_context_frame(), direction)

    def add_messages(self, messages):
        """Append messages to the context.

        Args:
            messages: Messages to add to the conversation context.
        """
        self._context.add_messages(messages)

    def set_messages(self, messages):
        """Replace the context's messages.

        Args:
            messages: Messages to replace the current context messages.
        """
        self._context.set_messages(messages)

    def set_tools(self, tools: List):
        """Set the tools on the context.

        Args:
            tools: List of tool definitions to set in the context.
        """
        self._context.set_tools(tools)

    def set_tool_choice(self, tool_choice: Literal["none", "auto", "required"] | dict):
        """Set the tool choice on the context.

        Args:
            tool_choice: Tool choice configuration for the context.
        """
        self._context.set_tool_choice(tool_choice)

    async def reset(self):
        """Clear the text aggregated so far."""
        self._aggregation = ""
# NOTE: the "universal" suffix is just meant to distinguish this aggregator
# from the old LLMUserContextAggregator while we gradually migrate services to
# use the new universal LLMContext and associated patterns. The suffix will go
# away once the migration is complete and the other LLMUserContextAggregator is
# deprecated.
class LLMUserContextAggregator_Universal(LLMContextAggregator):
"""User LLM aggregator that processes speech-to-text transcriptions.
This aggregator handles the complex logic of aggregating user speech transcriptions
from STT services. It manages multiple scenarios including:
- Transcriptions received between VAD events
- Transcriptions received outside VAD events
- Interim vs final transcriptions
- User interruptions during bot speech
- Emulated VAD for whispered or short utterances
The aggregator uses timeouts to handle cases where transcriptions arrive
after VAD events or when no VAD is available.
"""
def __init__(
    self,
    context: LLMContext,
    *,
    params: Optional[LLMUserContextAggregatorParams] = None,
    **kwargs,
):
    """Set up the user context aggregator.

    Args:
        context: The LLM context for conversation storage.
        params: Configuration parameters for aggregation behavior; defaults
            are used when omitted.
        **kwargs: Additional arguments, forwarded to the parent class.
            Supports deprecated 'aggregation_timeout'.
    """
    super().__init__(context=context, role="user", **kwargs)

    self._params = params or LLMUserContextAggregatorParams()
    self._vad_params: Optional[VADParams] = None
    self._turn_params: Optional[SmartTurnParams] = None

    # Deprecated path: a bare 'aggregation_timeout' keyword still overrides
    # the value in params, but always emits a deprecation warning.
    if "aggregation_timeout" in kwargs:
        import warnings

        with warnings.catch_warnings():
            warnings.simplefilter("always")
            warnings.warn(
                "Parameter 'aggregation_timeout' is deprecated, use 'params' instead.",
                DeprecationWarning,
            )

        self._params.aggregation_timeout = kwargs["aggregation_timeout"]

    # Speaking state for both sides of the conversation.
    self._user_speaking = False
    self._bot_speaking = False
    self._was_bot_speaking = False
    self._emulating_vad = False

    # Transcription / aggregation bookkeeping.
    self._seen_interim_results = False
    self._waiting_for_aggregation = False
    self._aggregation_event = asyncio.Event()
    self._aggregation_task = None
async def reset(self):
    """Reset the aggregation state and interruption strategies.

    Clears the parent's aggregation state and this aggregator's per-turn
    flags, then resets each interruption strategy in order.
    """
    await super().reset()
    self._was_bot_speaking = False
    self._seen_interim_results = False
    self._waiting_for_aggregation = False
    # Plain loop instead of a list comprehension: the awaits are wanted only
    # for their side effects, so no throwaway list of results is built.
    for strategy in self._interruption_strategies:
        await strategy.reset()
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process frames for user speech aggregation and context management.
Args:
frame: The frame to process.
direction: The direction of frame flow in the pipeline.
"""
await super().process_frame(frame, direction)
if isinstance(frame, StartFrame):
# Push StartFrame before start(), because we want StartFrame to be
# processed by every processor before any other frame is processed.
await self.push_frame(frame, direction)
await self._start(frame)
elif isinstance(frame, EndFrame):
# Push EndFrame before stop(), because stop() waits on the task to
# finish and the task finishes when EndFrame is processed.
await self.push_frame(frame, direction)
await self._stop(frame)
elif isinstance(frame, CancelFrame):
await self._cancel(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, InputAudioRawFrame):
await self._handle_input_audio(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, UserStartedSpeakingFrame):
await self._handle_user_started_speaking(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, UserStoppedSpeakingFrame):
await self._handle_user_stopped_speaking(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, BotStartedSpeakingFrame):
await self._handle_bot_started_speaking(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, BotStoppedSpeakingFrame):
await self._handle_bot_stopped_speaking(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, TranscriptionFrame):
await self._handle_transcription(frame)
elif isinstance(frame, InterimTranscriptionFrame):
await self._handle_interim_transcription(frame)
elif isinstance(frame, LLMMessagesAppendFrame):
await self._handle_llm_messages_append(frame)
elif isinstance(frame, LLMMessagesUpdateFrame):
await self._handle_llm_messages_update(frame)
elif isinstance(frame, LLMSetToolsFrame):
self.set_tools(frame.tools)
elif isinstance(frame, LLMSetToolChoiceFrame):
self.set_tool_choice(frame.tool_choice)
elif isinstance(frame, SpeechControlParamsFrame):
self._vad_params = frame.vad_params
self._turn_params = frame.turn_params
await self.push_frame(frame, direction)
else:
await self.push_frame(frame, direction)
async def _process_aggregation(self):
"""Process the current aggregation and push it downstream."""
aggregation = self._aggregation
await self.reset()
self._context.add_message({"role": self.role, "content": aggregation})
frame = LLMContextFrame(self._context)
await self.push_frame(frame)
async def _push_aggregation(self):
"""Push the current aggregation based on interruption strategies and conditions."""
if len(self._aggregation) > 0:
if self.interruption_strategies and self._bot_speaking:
should_interrupt = await self._should_interrupt_based_on_strategies()
if should_interrupt:
logger.debug(
"Interruption conditions met - pushing BotInterruptionFrame and aggregation"
)
await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
await self._process_aggregation()
else:
logger.debug("Interruption conditions not met - not pushing aggregation")
# Don't process aggregation, just reset it
await self.reset()
else:
# No interruption config - normal behavior (always push aggregation)
await self._process_aggregation()
# Handles the case where both the user and the bot are not speaking,
# and the bot was previously speaking before the user interruption.
# Normally, when the user stops speaking, new text is expected,
# which triggers the bot to respond. However, if no new text
# is received, this safeguard ensures
# the bot doesn't hang indefinitely while waiting to speak again.
elif not self._seen_interim_results and self._was_bot_speaking and not self._bot_speaking:
logger.warning("User stopped speaking but no new aggregation received.")
# Resetting it so we don't trigger this twice
self._was_bot_speaking = False
# TODO: we are not enabling this for now, due to some STT services which can take as long as 2 seconds two return a transcription
# So we need more tests and probably make this feature configurable, disabled it by default.
# We are just pushing the same previous context to be processed again in this case
# await self.push_frame(LLMContextFrame(self._context))
async def _should_interrupt_based_on_strategies(self) -> bool:
"""Check if interruption should occur based on configured strategies.
Returns:
True if any interruption strategy indicates interruption should occur.
"""
async def should_interrupt(strategy: BaseInterruptionStrategy):
await strategy.append_text(self._aggregation)
return await strategy.should_interrupt()
return any([await should_interrupt(s) for s in self._interruption_strategies])
async def _start(self, frame: StartFrame):
self._create_aggregation_task()
async def _stop(self, frame: EndFrame):
await self._cancel_aggregation_task()
async def _cancel(self, frame: CancelFrame):
await self._cancel_aggregation_task()
async def _handle_llm_messages_append(self, frame: LLMMessagesAppendFrame):
self.add_messages(frame.messages)
if frame.run_llm:
await self.push_context_frame()
async def _handle_llm_messages_update(self, frame: LLMMessagesUpdateFrame):
self.set_messages(frame.messages)
if frame.run_llm:
await self.push_context_frame()
async def _handle_input_audio(self, frame: InputAudioRawFrame):
for s in self.interruption_strategies:
await s.append_audio(frame.audio, frame.sample_rate)
async def _handle_user_started_speaking(self, frame: UserStartedSpeakingFrame):
self._user_speaking = True
self._waiting_for_aggregation = True
self._was_bot_speaking = self._bot_speaking
# If we get a non-emulated UserStartedSpeakingFrame but we are in the
# middle of emulating VAD, let's stop emulating VAD (i.e. don't send the
# EmulateUserStoppedSpeakingFrame).
if not frame.emulated and self._emulating_vad:
self._emulating_vad = False
async def _handle_user_stopped_speaking(self, _: UserStoppedSpeakingFrame):
self._user_speaking = False
# We just stopped speaking. Let's see if there's some aggregation to
# push. If the last thing we saw is an interim transcription, let's wait
# pushing the aggregation as we will probably get a final transcription.
if len(self._aggregation) > 0:
if not self._seen_interim_results:
await self._push_aggregation()
# Handles the case where both the user and the bot are not speaking,
# and the bot was previously speaking before the user interruption.
# So in this case we are resetting the aggregation timer
elif not self._seen_interim_results and self._was_bot_speaking and not self._bot_speaking:
# Reset aggregation timer.
self._aggregation_event.set()
async def _handle_bot_started_speaking(self, _: BotStartedSpeakingFrame):
self._bot_speaking = True
async def _handle_bot_stopped_speaking(self, _: BotStoppedSpeakingFrame):
self._bot_speaking = False
async def _handle_transcription(self, frame: TranscriptionFrame):
text = frame.text
# Make sure we really have some text.
if not text.strip():
return
self._aggregation += f" {text}" if self._aggregation else text
# We just got a final result, so let's reset interim results.
self._seen_interim_results = False
# Reset aggregation timer.
self._aggregation_event.set()
async def _handle_interim_transcription(self, _: InterimTranscriptionFrame):
self._seen_interim_results = True
def _create_aggregation_task(self):
if not self._aggregation_task:
self._aggregation_task = self.create_task(self._aggregation_task_handler())
async def _cancel_aggregation_task(self):
if self._aggregation_task:
await self.cancel_task(self._aggregation_task)
self._aggregation_task = None
async def _aggregation_task_handler(self):
while True:
try:
# The _aggregation_task_handler handles two distinct timeout scenarios:
#
# 1. When emulating_vad=True: Wait for emulated VAD timeout before
# pushing aggregation (simulating VAD behavior when no actual VAD
# detection occurred).
#
# 2. When emulating_vad=False: Use aggregation_timeout as a buffer
# to wait for potential late-arriving transcription frames after
# a real VAD event.
#
# For emulated VAD scenarios, the timeout strategy depends on whether
# a turn analyzer is configured:
#
# - WITH turn analyzer: Use turn_emulated_vad_timeout parameter because
# the VAD's stop_secs is set very low (e.g. 0.2s) for rapid speech
# chunking to feed the turn analyzer. This low value is too fast
# for emulated VAD scenarios where we need to allow users time to
# finish speaking (e.g. 0.8s).
#
# - WITHOUT turn analyzer: Use VAD's stop_secs directly to maintain
# consistent user experience between real VAD detection and
# emulated VAD scenarios.
if not self._emulating_vad:
timeout = self._params.aggregation_timeout
elif self._turn_params:
timeout = self._params.turn_emulated_vad_timeout
else:
# Use VAD stop_secs when no turn analyzer is present, fallback if no VAD params
timeout = (
self._vad_params.stop_secs
if self._vad_params
else self._params.turn_emulated_vad_timeout
)
await asyncio.wait_for(self._aggregation_event.wait(), timeout)
await self._maybe_emulate_user_speaking()
except asyncio.TimeoutError:
if not self._user_speaking:
await self._push_aggregation()
# If we are emulating VAD we still need to send the user stopped
# speaking frame.
if self._emulating_vad:
await self.push_frame(
EmulateUserStoppedSpeakingFrame(), FrameDirection.UPSTREAM
)
self._emulating_vad = False
finally:
self.reset_watchdog()
self._aggregation_event.clear()
async def _maybe_emulate_user_speaking(self):
"""Maybe emulate user speaking based on transcription.
Emulate user speaking if we got a transcription but it was not
detected by VAD. Only do that if the bot is not speaking.
"""
# Check if we received a transcription but VAD was not able to detect
# voice (e.g. when you whisper a short utterance). In that case, we need
# to emulate VAD (i.e. user start/stopped speaking), but we do it only
# if the bot is not speaking. If the bot is speaking and we really have
# a short utterance we don't really want to interrupt the bot.
if (
not self._user_speaking
and not self._waiting_for_aggregation
and len(self._aggregation) > 0
):
if self._bot_speaking:
# If we reached this case and the bot is speaking, let's ignore
# what the user said.
logger.debug("Ignoring user speaking emulation, bot is speaking.")
await self.reset()
else:
# The bot is not speaking so, let's trigger user speaking
# emulation.
await self.push_frame(EmulateUserStartedSpeakingFrame(), FrameDirection.UPSTREAM)
self._emulating_vad = True
# NOTE: the "universal" suffix is just meant to distinguish this aggregator
# from the old LLMAssistantContextAggregator while we gradually migrate services
# to use the new universal LLMContext and associated patterns. The suffix will
# go away once the migration is complete and the other
# LLMAssistantContextAggregator is deprecated.
class LLMAssistantContextAggregator_Universal(LLMContextAggregator):
    """Assistant LLM aggregator that processes bot responses and function calls.

    This aggregator handles the complex logic of processing assistant responses including:

    - Text frame aggregation between response start/end markers
    - Function call lifecycle management
    - Context updates with timestamps
    - Tool execution and result handling
    - Interruption handling during responses

    The aggregator manages function calls in progress and coordinates between
    text generation and tool execution phases of LLM responses.
    """

    def __init__(
        self,
        context: LLMContext,
        *,
        params: Optional[LLMAssistantContextAggregatorParams] = None,
        **kwargs,
    ):
        """Initialize the assistant context aggregator.

        Args:
            context: The LLM context for conversation storage.
            params: Configuration parameters for aggregation behavior.
            **kwargs: Additional arguments. Supports deprecated 'expect_stripped_words'.
        """
        super().__init__(context=context, role="assistant", **kwargs)
        self._params = params or LLMAssistantContextAggregatorParams()

        if "expect_stripped_words" in kwargs:
            import warnings

            with warnings.catch_warnings():
                warnings.simplefilter("always")
                # stacklevel=2 attributes the warning to the code that
                # constructed the aggregator rather than to this module.
                warnings.warn(
                    "Parameter 'expect_stripped_words' is deprecated, use 'params' instead.",
                    DeprecationWarning,
                    stacklevel=2,
                )
            self._params.expect_stripped_words = kwargs["expect_stripped_words"]

        # Number of LLM responses currently open (start frame seen, end frame
        # not yet seen). Text is only aggregated while this is non-zero.
        self._started = 0
        # tool_call_id -> in-progress frame. The value is None between
        # FunctionCallsStartedFrame and the matching FunctionCallInProgressFrame.
        self._function_calls_in_progress: Dict[str, Optional[FunctionCallInProgressFrame]] = {}
        self._context_updated_tasks: Set[asyncio.Task] = set()

    @property
    def has_function_calls_in_progress(self) -> bool:
        """Check if there are any function calls currently in progress.

        Returns:
            True if function calls are in progress, False otherwise.
        """
        return bool(self._function_calls_in_progress)

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Process frames for assistant response aggregation and function call management.

        Args:
            frame: The frame to process.
            direction: The direction of frame flow in the pipeline.
        """
        await super().process_frame(frame, direction)

        if isinstance(frame, StartInterruptionFrame):
            await self._handle_interruptions(frame)
            await self.push_frame(frame, direction)
        elif isinstance(frame, LLMFullResponseStartFrame):
            await self._handle_llm_start(frame)
        elif isinstance(frame, LLMFullResponseEndFrame):
            await self._handle_llm_end(frame)
        elif isinstance(frame, TextFrame):
            await self._handle_text(frame)
        elif isinstance(frame, LLMMessagesAppendFrame):
            await self._handle_llm_messages_append(frame)
        elif isinstance(frame, LLMMessagesUpdateFrame):
            await self._handle_llm_messages_update(frame)
        elif isinstance(frame, LLMSetToolsFrame):
            self.set_tools(frame.tools)
        elif isinstance(frame, LLMSetToolChoiceFrame):
            self.set_tool_choice(frame.tool_choice)
        elif isinstance(frame, FunctionCallsStartedFrame):
            await self._handle_function_calls_started(frame)
        elif isinstance(frame, FunctionCallInProgressFrame):
            await self._handle_function_call_in_progress(frame)
        elif isinstance(frame, FunctionCallResultFrame):
            await self._handle_function_call_result(frame)
        elif isinstance(frame, FunctionCallCancelFrame):
            await self._handle_function_call_cancel(frame)
        elif isinstance(frame, UserImageRawFrame) and frame.request and frame.request.tool_call_id:
            await self._handle_user_image_frame(frame)
        elif isinstance(frame, BotStoppedSpeakingFrame):
            await self._push_aggregation()
            await self.push_frame(frame, direction)
        else:
            await self.push_frame(frame, direction)

    async def _push_aggregation(self):
        """Push the current assistant aggregation with timestamp."""
        if not self._aggregation:
            return

        # Capture the text before reset() clears the aggregation state.
        aggregation = self._aggregation.strip()
        await self.reset()

        if aggregation:
            self._context.add_message({"role": "assistant", "content": aggregation})

        # Push context frame
        await self.push_context_frame()

        # Push timestamp frame with current time
        timestamp_frame = LLMContextAssistantTimestampFrame(timestamp=time_now_iso8601())
        await self.push_frame(timestamp_frame)

    async def _handle_llm_messages_append(self, frame: LLMMessagesAppendFrame):
        """Append messages to the context, pushing a context frame upstream if requested."""
        self.add_messages(frame.messages)
        if frame.run_llm:
            await self.push_context_frame(FrameDirection.UPSTREAM)

    async def _handle_llm_messages_update(self, frame: LLMMessagesUpdateFrame):
        """Replace the context messages, pushing a context frame upstream if requested."""
        self.set_messages(frame.messages)
        if frame.run_llm:
            await self.push_context_frame(FrameDirection.UPSTREAM)

    async def _handle_interruptions(self, frame: StartInterruptionFrame):
        """Flush what was aggregated so far and reset the response state."""
        await self._push_aggregation()
        self._started = 0
        await self.reset()

    async def _handle_function_calls_started(self, frame: FunctionCallsStartedFrame):
        """Register every announced function call as pending."""
        function_names = [f"{f.function_name}:{f.tool_call_id}" for f in frame.function_calls]
        logger.debug(f"{self} FunctionCallsStartedFrame: {function_names}")
        for function_call in frame.function_calls:
            self._function_calls_in_progress[function_call.tool_call_id] = None

    async def _handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame):
        """Add the tool call and an IN_PROGRESS placeholder result to the context."""
        logger.debug(
            f"{self} FunctionCallInProgressFrame: [{frame.function_name}:{frame.tool_call_id}]"
        )

        # Update context with the in-progress function call
        self._context.add_message(
            {
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": frame.tool_call_id,
                        "function": {
                            "name": frame.function_name,
                            "arguments": json.dumps(frame.arguments),
                        },
                        "type": "function",
                    }
                ],
            }
        )
        self._context.add_message(
            {
                "role": "tool",
                "content": "IN_PROGRESS",
                "tool_call_id": frame.tool_call_id,
            }
        )

        self._function_calls_in_progress[frame.tool_call_id] = frame

    async def _handle_function_call_result(self, frame: FunctionCallResultFrame):
        """Store a function call result in the context and run the LLM if needed."""
        logger.debug(
            f"{self} FunctionCallResultFrame: [{frame.function_name}:{frame.tool_call_id}]"
        )
        if frame.tool_call_id not in self._function_calls_in_progress:
            logger.warning(
                f"FunctionCallResultFrame tool_call_id [{frame.tool_call_id}] is not running"
            )
            return

        del self._function_calls_in_progress[frame.tool_call_id]

        properties = frame.properties

        # Update context with the function call result
        if frame.result:
            result = json.dumps(frame.result)
            self._update_function_call_result(frame.function_name, frame.tool_call_id, result)
        else:
            self._update_function_call_result(frame.function_name, frame.tool_call_id, "COMPLETED")

        run_llm = False

        # Run inference if the function call result requires it.
        if frame.result:
            if properties and properties.run_llm is not None:
                # If the tool call result has a run_llm property, use it.
                run_llm = properties.run_llm
            elif frame.run_llm is not None:
                # If the frame is indicating we should run the LLM, do it.
                run_llm = frame.run_llm
            else:
                # If this is the last function call in progress, run the LLM.
                run_llm = not bool(self._function_calls_in_progress)

        if run_llm:
            await self.push_context_frame(FrameDirection.UPSTREAM)

        # Call the `on_context_updated` callback once the function call result
        # is added to the context. Also, run this in a separate task to make
        # sure we don't block the pipeline.
        if properties and properties.on_context_updated:
            task_name = f"{frame.function_name}:{frame.tool_call_id}:on_context_updated"
            task = self.create_task(properties.on_context_updated(), task_name)
            self._context_updated_tasks.add(task)
            task.add_done_callback(self._context_updated_task_finished)

    async def _handle_function_call_cancel(self, frame: FunctionCallCancelFrame):
        """Mark a cancelled function call in the context and stop tracking it."""
        logger.debug(
            f"{self} FunctionCallCancelFrame: [{frame.function_name}:{frame.tool_call_id}]"
        )
        if frame.tool_call_id not in self._function_calls_in_progress:
            return

        in_progress = self._function_calls_in_progress[frame.tool_call_id]
        if in_progress is None:
            # The call was announced (FunctionCallsStartedFrame) but its
            # FunctionCallInProgressFrame never arrived, so there is no tool
            # message to update. Drop the placeholder so it doesn't keep
            # reporting a call in progress. A late in-progress frame would
            # simply register the call again.
            del self._function_calls_in_progress[frame.tool_call_id]
            return

        if in_progress.cancel_on_interruption:
            # Update context with the function call cancellation
            self._update_function_call_result(frame.function_name, frame.tool_call_id, "CANCELLED")
            del self._function_calls_in_progress[frame.tool_call_id]

    def _update_function_call_result(self, function_name: str, tool_call_id: str, result: Any):
        """Overwrite the content of the tool message(s) matching tool_call_id.

        Args:
            function_name: Name of the function (unused, kept for interface parity).
            tool_call_id: Identifier of the tool call whose result is updated.
            result: New content for the matching tool message.
        """
        for message in self._context.messages:
            # .get() so messages without these keys are skipped instead of
            # raising KeyError.
            if (
                message.get("role") == "tool"
                and message.get("tool_call_id")
                and message.get("tool_call_id") == tool_call_id
            ):
                message["content"] = result

    async def _handle_user_image_frame(self, frame: UserImageRawFrame):
        """Complete an image-request function call and add the image to the context."""
        logger.debug(
            f"{self} UserImageRawFrame: [{frame.request.function_name}:{frame.request.tool_call_id}]"
        )

        if frame.request.tool_call_id not in self._function_calls_in_progress:
            logger.warning(
                f"UserImageRawFrame tool_call_id [{frame.request.tool_call_id}] is not running"
            )
            return

        del self._function_calls_in_progress[frame.request.tool_call_id]

        # Update context with the image frame.
        # _update_function_call_result is synchronous; it must not be awaited.
        self._update_function_call_result(
            frame.request.function_name, frame.request.tool_call_id, "COMPLETED"
        )
        self._context.add_image_frame_message(
            format=frame.format,
            size=frame.size,
            image=frame.image,
            text=frame.request.context,
        )

        await self._push_aggregation()
        await self.push_context_frame(FrameDirection.UPSTREAM)

    async def _handle_llm_start(self, _: LLMFullResponseStartFrame):
        """Record that an LLM response has started."""
        self._started += 1

    async def _handle_llm_end(self, _: LLMFullResponseEndFrame):
        """Record that an LLM response has ended and push the aggregation."""
        # Clamp at zero: an end frame arriving after an interruption reset the
        # counter would otherwise make it negative, which is truthy in
        # _handle_text and would also swallow the next real response.
        self._started = max(0, self._started - 1)
        await self._push_aggregation()

    async def _handle_text(self, frame: TextFrame):
        """Append response text to the aggregation while a response is open."""
        if not self._started:
            return

        if self._params.expect_stripped_words:
            self._aggregation += f" {frame.text}" if self._aggregation else frame.text
        else:
            self._aggregation += frame.text

    def _context_updated_task_finished(self, task: asyncio.Task):
        """Done-callback that untracks a finished `on_context_updated` task."""
        self._context_updated_tasks.discard(task)
        # The task is finished so this should exit immediately. We need to do
        # this because otherwise the task manager would report a dangling task
        # if we don't remove it.
        asyncio.run_coroutine_threadsafe(self.wait_for_task(task), self.get_event_loop())
@dataclass
class LLMContextAggregatorPair:
    """Pair of LLM context aggregators for user and assistant messages.

    Parameters:
        _user: User context aggregator for processing user messages.
        _assistant: Assistant context aggregator for processing assistant messages.
    """

    _user: LLMUserContextAggregator_Universal
    _assistant: LLMAssistantContextAggregator_Universal

    @staticmethod
    def create(
        context: LLMContext,
        *,
        user_params: Optional[LLMUserContextAggregatorParams] = None,
        assistant_params: Optional[LLMAssistantContextAggregatorParams] = None,
    ) -> "LLMContextAggregatorPair":
        """Factory method to create an LLMContextAggregatorPair.

        Args:
            context: The context managed by the aggregators.
            user_params: Parameters for the user context aggregator. When
                omitted, the aggregator builds its own default parameters.
            assistant_params: Parameters for the assistant context aggregator.
                When omitted, the aggregator builds its own default parameters.

        Returns:
            LLMContextAggregatorPair: A new instance with configured aggregators.
        """
        # Defaults are None rather than params instances: a default instance
        # is created once at definition time and would be shared (and mutated)
        # by every pair created without explicit params. Each aggregator
        # creates a fresh params object when it receives None.
        user = LLMUserContextAggregator_Universal(context, params=user_params)
        assistant = LLMAssistantContextAggregator_Universal(context, params=assistant_params)
        return LLMContextAggregatorPair(_user=user, _assistant=assistant)

    def user(self) -> LLMUserContextAggregator_Universal:
        """Get the user context aggregator.

        Returns:
            The user context aggregator instance.
        """
        return self._user

    def assistant(self) -> LLMAssistantContextAggregator_Universal:
        """Get the assistant context aggregator.

        Returns:
            The assistant context aggregator instance.
        """
        return self._assistant

View File

@@ -72,7 +72,6 @@ from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.services.llm_service import (
FunctionCallParams, # TODO(aleix): we shouldn't import `services` from `processors`
)
from pipecat.services.openai.llm import OpenAIContextAggregatorPair
from pipecat.transports.base_input import BaseInputTransport
from pipecat.transports.base_output import BaseOutputTransport
from pipecat.transports.base_transport import BaseTransport

View File

@@ -10,49 +10,29 @@ This module provides Google Gemini integration for the Pipecat framework,
including LLM services, context management, and message aggregation.
"""
import base64
import io
import json
import os
import uuid
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from loguru import logger
from PIL import Image
from pydantic import BaseModel, Field
from pipecat.adapters.services.gemini_adapter import GeminiLLMAdapter
from pipecat.adapters.services.gemini_adapter import GeminiLLMAdapter, GeminiLLMInvocationParams
from pipecat.frames.frames import (
AudioRawFrame,
Frame,
FunctionCallCancelFrame,
FunctionCallInProgressFrame,
FunctionCallResultFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMMessagesFrame,
LLMTextFrame,
LLMUpdateSettingsFrame,
UserImageRawFrame,
VisionImageRawFrame,
)
from pipecat.metrics.metrics import LLMTokenUsage
from pipecat.processors.aggregators.llm_response import (
LLMAssistantAggregatorParams,
LLMUserAggregatorParams,
)
from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContext,
OpenAILLMContextFrame,
)
from pipecat.processors.aggregators.llm_context import LLMContext, LLMContextFrame
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.google.frames import LLMSearchResponseFrame
from pipecat.services.llm_service import FunctionCallFromLLM, LLMService
from pipecat.services.openai.llm import (
OpenAIAssistantContextAggregator,
OpenAIUserContextAggregator,
)
from pipecat.utils.asyncio.watchdog_async_iterator import WatchdogAsyncIterator
from pipecat.utils.tracing.service_decorators import traced_llm
@@ -63,13 +43,8 @@ try:
from google import genai
from google.api_core.exceptions import DeadlineExceeded
from google.genai.types import (
Blob,
Content,
FunctionCall,
FunctionResponse,
GenerateContentConfig,
HttpOptions,
Part,
)
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
@@ -77,577 +52,12 @@ except ModuleNotFoundError as e:
raise Exception(f"Missing module: {e}")
class GoogleUserContextAggregator(OpenAIUserContextAggregator):
    """User-side context aggregator that stores messages in Google's format.

    Builds on the OpenAI user context aggregator, but records the aggregated
    user text as a Google ``Content`` holding a single text ``Part``.
    """

    async def push_aggregation(self):
        """Add the aggregated user text to the context and push a context frame."""
        if not self._aggregation:
            return

        user_message = Content(role="user", parts=[Part(text=self._aggregation)])
        self._context.add_message(user_message)

        # Clear the text before pushing anything: if the task is cancelled
        # while pushing, there would be no other chance to clean up.
        self._aggregation = ""

        # Send the updated context downstream.
        await self.push_frame(OpenAILLMContextFrame(self._context))

        # Finally clear the rest of the accumulator state.
        await self.reset()
class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator):
    """Assistant-side context aggregator that stores messages in Google's format.

    Builds on the OpenAI assistant context aggregator, but records assistant
    text, function calls and function responses as Google ``Content``/``Part``
    objects.
    """

    async def handle_aggregation(self, aggregation: str):
        """Store the aggregated assistant text in the context.

        Args:
            aggregation: The aggregated text response from the assistant.
        """
        model_message = Content(role="model", parts=[Part(text=aggregation)])
        self._context.add_message(model_message)

    async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame):
        """Record a function call plus an IN_PROGRESS placeholder response.

        Args:
            frame: Frame containing function call details.
        """
        # The model's request to call the function.
        call = FunctionCall(id=frame.tool_call_id, name=frame.function_name, args=frame.arguments)
        self._context.add_message(Content(role="model", parts=[Part(function_call=call)]))

        # A placeholder response that is overwritten once the call finishes.
        placeholder = FunctionResponse(
            id=frame.tool_call_id,
            name=frame.function_name,
            response={"response": "IN_PROGRESS"},
        )
        self._context.add_message(Content(role="user", parts=[Part(function_response=placeholder)]))

    async def handle_function_call_result(self, frame: FunctionCallResultFrame):
        """Replace the placeholder response with the function call result.

        Args:
            frame: Frame containing function call result.
        """
        # A falsy result is recorded as the literal "COMPLETED".
        outcome = frame.result if frame.result else "COMPLETED"
        await self._update_function_call_result(frame.function_name, frame.tool_call_id, outcome)

    async def handle_function_call_cancel(self, frame: FunctionCallCancelFrame):
        """Replace the placeholder response with a cancellation marker.

        Args:
            frame: Frame containing function call cancellation details.
        """
        await self._update_function_call_result(
            frame.function_name, frame.tool_call_id, "CANCELLED"
        )

    async def _update_function_call_result(
        self, function_name: str, tool_call_id: str, result: Any
    ):
        """Write ``result`` into every function response matching ``tool_call_id``."""
        for message in self._context.messages:
            if message.role != "user":
                continue
            for part in message.parts:
                response = part.function_response
                if response and response.id == tool_call_id:
                    response.response = {"value": json.dumps(result)}

    async def handle_user_image_frame(self, frame: UserImageRawFrame):
        """Complete the requesting function call and add the image to the context.

        Args:
            frame: Frame containing user image data and request context.
        """
        request = frame.request
        await self._update_function_call_result(
            request.function_name, request.tool_call_id, "COMPLETED"
        )
        self._context.add_image_frame_message(
            format=frame.format,
            size=frame.size,
            image=frame.image,
            text=request.context,
        )
@dataclass
class GoogleContextAggregatorPair:
    """Bundle of the Google user and assistant context aggregators.

    Parameters:
        _user: Aggregator that handles user messages.
        _assistant: Aggregator that handles assistant responses.
    """

    _user: GoogleUserContextAggregator
    _assistant: GoogleAssistantContextAggregator

    def user(self) -> GoogleUserContextAggregator:
        """Return the user context aggregator.

        Returns:
            The user context aggregator instance.
        """
        return self._user

    def assistant(self) -> GoogleAssistantContextAggregator:
        """Return the assistant context aggregator.

        Returns:
            The assistant context aggregator instance.
        """
        return self._assistant
class GoogleLLMContext(OpenAILLMContext):
"""Google AI LLM context that extends OpenAI context for Google-specific formatting.
This class handles conversion between OpenAI-style messages and Google AI's
Content/Part format, including system messages, function calls, and media.
"""
def __init__(
    self,
    messages: Optional[List[dict]] = None,
    tools: Optional[List[dict]] = None,
    tool_choice: Optional[dict] = None,
):
    """Create a Google LLM context.

    Args:
        messages: Initial messages in OpenAI format.
        tools: Available tools/functions for the model.
        tool_choice: Tool choice configuration.
    """
    super().__init__(messages=messages, tools=tools, tool_choice=tool_choice)
    # Starts unset; from_standard_message() stores system-role content here.
    self.system_message = None
@staticmethod
def upgrade_to_google(obj: OpenAILLMContext) -> "GoogleLLMContext":
    """Convert an OpenAI context into a Google context in place.

    Args:
        obj: OpenAI LLM context to upgrade.

    Returns:
        The same object, now a GoogleLLMContext with restructured messages.
        An object that is already a GoogleLLMContext (or is not an
        OpenAILLMContext at all) is returned untouched.
    """
    already_google = isinstance(obj, GoogleLLMContext)
    if not already_google and isinstance(obj, OpenAILLMContext):
        logger.debug(f"Upgrading to Google: {obj}")
        # Re-class the existing instance rather than copying it, so that
        # references held elsewhere see the upgraded context.
        obj.__class__ = GoogleLLMContext
        obj._restructure_from_openai_messages()
    return obj
def set_messages(self, messages: List):
    """Replace all messages, then restructure them into Google format.

    Args:
        messages: List of messages to set.
    """
    # Slice assignment keeps the same list object, replacing only its contents.
    self._messages[:] = messages
    self._restructure_from_openai_messages()
def add_messages(self, messages: List):
    """Append messages to the context, converting each to Google format.

    Args:
        messages: List of messages to add (can be mixed formats).
    """
    # Everything is converted before the context is touched, so a failing
    # conversion leaves the existing messages unchanged.
    pending = []
    for incoming in messages:
        # Content objects are already in Gemini format. Anything else goes
        # through the standard-format converter, which may yield None.
        gemini_message = (
            incoming if isinstance(incoming, Content) else self.from_standard_message(incoming)
        )
        if gemini_message is not None:
            pending.append(gemini_message)

    self._messages.extend(pending)
def get_messages_for_logging(self):
    """Return the messages as dictionaries that are safe to log.

    Returns:
        List of message dictionaries with inline data redacted.
    """

    def redact_inline_data(entry):
        # Swap every inline payload for a short placeholder.
        if "parts" in entry:
            for part in entry["parts"]:
                if "inline_data" in part:
                    part["inline_data"]["data"] = "..."

    loggable = []
    for message in self.messages:
        entry = message.to_json_dict()
        try:
            redact_inline_data(entry)
        except Exception as e:
            # Redaction is best effort; the entry is still logged.
            logger.debug(f"Error: {e}")
        loggable.append(entry)
    return loggable
def add_image_frame_message(
    self, *, format: str, size: tuple[int, int], image: bytes, text: str = None
):
    """Append a user message carrying an image, optionally preceded by text.

    Args:
        format: Image format (e.g., 'RGB', 'RGBA').
        size: Image dimensions as (width, height).
        image: Raw image bytes.
        text: Optional text to accompany the image.
    """
    # Re-encode the raw pixel data as JPEG in memory.
    jpeg_buffer = io.BytesIO()
    Image.frombytes(format, size, image).save(jpeg_buffer, format="JPEG")

    # The text part, when present, comes before the image part.
    parts = [Part(text=text)] if text else []
    image_blob = Blob(mime_type="image/jpeg", data=jpeg_buffer.getvalue())
    parts.append(Part(inline_data=image_blob))

    self.add_message(Content(role="user", parts=parts))
def add_audio_frames_message(
    self, *, audio_frames: list[AudioRawFrame], text: str = "Audio follows"
):
    """Append a user message carrying the given audio frames as one WAV blob.

    Args:
        audio_frames: List of audio frames to add.
        text: Text description of the audio content.
    """
    if not audio_frames:
        return

    # The first frame's format is used for the whole recording.
    first_frame = audio_frames[0]
    sample_rate = first_frame.sample_rate
    num_channels = first_frame.num_channels

    data = b"".join(frame.audio for frame in audio_frames)
    wav_bytes = bytes(self.create_wav_header(sample_rate, num_channels, 16, len(data)) + data)

    # NOTE(aleix): According to the docs only text or inline_data should be needed.
    # (see https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference)
    parts = [
        Part(text=text),
        Part(inline_data=Blob(mime_type="audio/wav", data=wav_bytes)),
    ]
    self.add_message(Content(role="user", parts=parts))
def from_standard_message(self, message):
    """Convert a standard-format message to a Google Content object.

    Handles text, images, function calls and function results. A system
    message is not converted: its content is stored in
    ``self.system_message`` and ``None`` is returned.

    Args:
        message: Message in standard format (a dict with a ``role`` key and,
            depending on the role, ``content`` and/or ``tool_calls``).

    Returns:
        Content object with role and parts, or None for system messages.

    Examples:
        Standard text message::

            {"role": "user", "content": "Hello there"}

        Converts to::

            Content(role="user", parts=[Part(text="Hello there")])

        Standard function call message::

            {
                "role": "assistant",
                "tool_calls": [
                    {"function": {"name": "search", "arguments": '{"query": "test"}'}}
                ]
            }

        Converts to::

            Content(
                role="model",
                parts=[Part(function_call=FunctionCall(name="search", args={"query": "test"}))]
            )
    """
    role = message["role"]
    content = message.get("content", [])
    if role == "system":
        self.system_message = content
        return None
    if role == "assistant":
        role = "model"

    parts = []
    if message.get("tool_calls"):
        for tc in message["tool_calls"]:
            parts.append(
                Part(
                    function_call=FunctionCall(
                        name=tc["function"]["name"],
                        args=json.loads(tc["function"]["arguments"]),
                    )
                )
            )
    elif role == "tool":
        role = "model"
        raw_result = message["content"]
        try:
            response = json.loads(raw_result)
        except (TypeError, ValueError):
            # Tool results are not always JSON (e.g. plain status strings
            # such as "IN_PROGRESS" or "COMPLETED"); keep them verbatim
            # instead of failing the whole conversion.
            response = raw_result
        if not isinstance(response, dict):
            # Wrap non-object results so the response is always a mapping.
            response = {"value": response}
        parts.append(
            Part(
                function_response=FunctionResponse(
                    name="tool_call_result",  # seems to work to hard-code the same name every time
                    response=response,
                )
            )
        )
    elif isinstance(content, str):
        parts.append(Part(text=content))
    elif isinstance(content, list):
        for c in content:
            if c["type"] == "text":
                parts.append(Part(text=c["text"]))
            elif c["type"] == "image_url":
                # Expected shape: "data:<mime type>;base64,<payload>".
                url_fields = c["image_url"]["url"].split(",")
                payload = url_fields[1]
                header = url_fields[0]
                # Use the MIME type declared by the data URL so non-JPEG
                # payloads are not mislabeled; fall back to the previous
                # hard-coded default when none is declared.
                mime_type = "image/jpeg"
                if header.startswith("data:"):
                    declared = header[len("data:") :].split(";")[0]
                    if declared:
                        mime_type = declared
                parts.append(
                    Part(
                        inline_data=Blob(
                            mime_type=mime_type,
                            data=base64.b64decode(payload),
                        )
                    )
                )

    return Content(role=role, parts=parts)
def to_standard_messages(self, obj) -> list:
    """Convert a Google Content object to the standard structured format.

    Handles text, inline data, function calls and function responses from
    Google's Content/Part objects. The ``model`` role is mapped to
    ``assistant``; a function response turns the message into a ``tool``
    message. Every function-call part is kept, so a Content carrying
    several calls yields several ``tool_calls`` entries.

    Args:
        obj: Google Content object with role and parts.

    Returns:
        List containing a single message in standard format. The
        ``content`` key is omitted when no content was produced (e.g. a
        message made only of function calls).

    Examples:
        Google Content with text::

            Content(role="user", parts=[Part(text="Hello")])

        Converts to::

            [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}]

        Google Content with function call::

            Content(
                role="model",
                parts=[Part(function_call=FunctionCall(name="search", args={"q": "test"}))]
            )

        Converts to::

            [
                {
                    "role": "assistant",
                    "tool_calls": [
                        {
                            "id": "search",
                            "type": "function",
                            "function": {"name": "search", "arguments": '{"q": "test"}'}
                        }
                    ]
                }
            ]

        Google Content with image::

            Content(
                role="user",
                parts=[Part(inline_data=Blob(mime_type="image/jpeg", data=bytes_data))]
            )

        Converts to::

            [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {"url": "data:image/jpeg;base64,<encoded_data>"}
                        }
                    ]
                }
            ]
    """
    msg = {"role": obj.role, "content": []}
    if msg["role"] == "model":
        msg["role"] = "assistant"

    for part in obj.parts:
        if part.text:
            msg["content"].append({"type": "text", "text": part.text})
        elif part.inline_data:
            encoded = base64.b64encode(part.inline_data.data).decode("utf-8")
            msg["content"].append(
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:{part.inline_data.mime_type};base64,{encoded}"},
                }
            )
        elif part.function_call:
            args = part.function_call.args if hasattr(part.function_call, "args") else {}
            # Accumulate instead of assigning: a single Content may carry
            # several function calls and each of them must be preserved.
            msg.setdefault("tool_calls", []).append(
                {
                    # The function name doubles as the tool call id.
                    "id": part.function_call.name,
                    "type": "function",
                    "function": {
                        "name": part.function_call.name,
                        "arguments": json.dumps(args),
                    },
                }
            )
        elif part.function_response:
            msg["role"] = "tool"
            resp = (
                part.function_response.response
                if hasattr(part.function_response, "response")
                else {}
            )
            msg["tool_call_id"] = part.function_response.name
            msg["content"] = json.dumps(resp)

    # there might be no content parts for tool_calls messages
    if not msg["content"]:
        del msg["content"]
    return [msg]
def _restructure_from_openai_messages(self):
    """Rewrite ``self._messages`` so that every entry is a Google Content.

    Entries that already are ``Content`` objects are kept untouched; all
    others go through ``from_standard_message``. System messages are dropped
    from the list by that conversion (it returns ``None`` for them and
    records the text in ``self.system_message``).

    Note:
        When a system message was seen and the resulting list holds no
        message consisting of exactly one plain-text part, the system text
        is appended as a trailing ``user`` message. Messages without parts
        are removed at the end.
    """
    self.system_message = None

    google_messages = []
    for entry in self._messages:
        if isinstance(entry, Content):
            # Already Google-formatted (e.g. function calls/responses).
            candidate = entry
        else:
            # Standard format; system messages come back as None.
            candidate = self.from_standard_message(entry)
        if candidate is not None:
            google_messages.append(candidate)

    # Replace the contents in place so the existing list object is reused.
    self._messages[:] = google_messages

    def _is_plain_text(msg):
        # "Regular" means: a single part that carries text and is neither a
        # function call nor a function response.
        if len(msg.parts) != 1:
            return False
        only_part = msg.parts[0]
        if not getattr(only_part, "text", None):
            return False
        if getattr(only_part, "function_call", None):
            return False
        return not getattr(only_part, "function_response", None)

    has_regular_messages = any(_is_plain_text(msg) for msg in self._messages)

    # Re-introduce the system text as a user message when only
    # function-related messages are present.
    if self.system_message and not has_regular_messages:
        self._messages.append(Content(role="user", parts=[Part(text=self.system_message)]))

    # Drop messages that ended up with no parts.
    self._messages = [msg for msg in self._messages if msg.parts]
class GoogleLLMService(LLMService):
"""Google AI (Gemini) LLM service implementation.
This class implements inference with Google's AI models, translating internally
from OpenAILLMContext to the messages format expected by the Google AI model.
We use OpenAILLMContext as a lingua franca for all LLM services to enable
easy switching between different LLMs.
from the universal LLMContext to the message format expected by the Google
AI model.
"""
# Overriding the default adapter to use the Gemini one.
@@ -750,7 +160,7 @@ class GoogleLLMService(LLMService):
logger.exception(f"Failed to unset thinking budget: {e}")
@traced_llm
async def _process_context(self, context: OpenAILLMContext):
async def _process_context(self, context: LLMContext):
await self.push_frame(LLMFullResponseStartFrame())
prompt_tokens = 0
@@ -763,19 +173,31 @@ class GoogleLLMService(LLMService):
search_result = ""
try:
logger.debug(
# f"{self}: Generating chat [{self._system_instruction}] | [{context.get_messages_for_logging()}]"
f"{self}: Generating chat [{context.get_messages_for_logging()}]"
adapter = self.get_llm_adapter()
llm_invocation_params: GeminiLLMInvocationParams = adapter.get_llm_invocation_params(
context
)
messages = context.messages
if context.system_message and self._system_instruction != context.system_message:
logger.debug(f"System instruction changed: {context.system_message}")
self._system_instruction = context.system_message
logger.debug(
# TODO: figure out a nice way to also log system instruction
# f"{self}: Generating chat [{self._system_instruction}] | [{adapter.get_messages_for_logging(context)}]"
f"{self}: Generating chat [{adapter.get_messages_for_logging(context)}]"
)
messages = llm_invocation_params["messages"]
if (
llm_invocation_params.get("system_instruction")
and self._system_instruction != llm_invocation_params["system_instruction"]
):
logger.debug(
f"System instruction changed: {llm_invocation_params['system_instruction']}"
)
self._system_instruction = llm_invocation_params["system_instruction"]
# TODO: test what happens when there are no tools
tools = []
if context.tools:
tools = context.tools
if llm_invocation_params.get("tools"):
tools = llm_invocation_params["tools"]
elif self._tools:
tools = self._tools
tool_config = None
@@ -922,12 +344,12 @@ class GoogleLLMService(LLMService):
context = None
if isinstance(frame, OpenAILLMContextFrame):
context = GoogleLLMContext.upgrade_to_google(frame.context)
if isinstance(frame, LLMContextFrame):
context = frame.context
elif isinstance(frame, LLMMessagesFrame):
context = GoogleLLMContext(frame.messages)
context = LLMContext(messages=frame.messages)
elif isinstance(frame, VisionImageRawFrame):
context = GoogleLLMContext()
context = LLMContext()
context.add_image_frame_message(
format=frame.format, size=frame.size, image=frame.image, text=frame.text
)
@@ -938,34 +360,3 @@ class GoogleLLMService(LLMService):
if context:
await self._process_context(context)
def create_context_aggregator(
    self,
    context: OpenAILLMContext,
    *,
    user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
    assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
) -> GoogleContextAggregatorPair:
    """Build the user/assistant aggregator pair for a Google context.

    The service's LLM adapter is set on the context first; an
    ``OpenAILLMContext`` is then upgraded through
    ``GoogleLLMContext.upgrade_to_google`` before both aggregators are
    created on it.

    Args:
        context: The LLM context to create aggregators for.
        user_params: Parameters for user message aggregation.
        assistant_params: Parameters for assistant message aggregation.

    Returns:
        GoogleContextAggregatorPair: The user and assistant aggregators,
        both bound to the same (possibly upgraded) context.
    """
    # The adapter must be set before the upgrade, on the incoming object.
    context.set_llm_adapter(self.get_llm_adapter())

    if isinstance(context, OpenAILLMContext):
        context = GoogleLLMContext.upgrade_to_google(context)

    return GoogleContextAggregatorPair(
        _user=GoogleUserContextAggregator(context, params=user_params),
        _assistant=GoogleAssistantContextAggregator(context, params=assistant_params),
    )

View File

@@ -41,6 +41,7 @@ from pipecat.frames.frames import (
StartInterruptionFrame,
UserImageRequestFrame,
)
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response import (
LLMAssistantAggregatorParams,
LLMUserAggregatorParams,
@@ -89,7 +90,8 @@ class FunctionCallParams:
tool_call_id: str
arguments: Mapping[str, Any]
llm: "LLMService"
context: OpenAILLMContext
# TODO: after migration of all services to universal LLMContext, OpenAILLMContext can be removed
context: LLMContext | OpenAILLMContext
result_callback: FunctionCallResultCallback
@@ -418,7 +420,10 @@ class LLMService(AIService):
else:
await self._sequential_runner_queue.put(runner_item)
async def _call_start_function(self, context: OpenAILLMContext, function_name: str):
# TODO: after migration of all services to universal LLMContext, OpenAILLMContext can be removed
async def _call_start_function(
self, context: LLMContext | OpenAILLMContext, function_name: str
):
if function_name in self._start_callbacks.keys():
await self._start_callbacks[function_name](function_name, self, context)
elif None in self._start_callbacks.keys():

View File

@@ -4,7 +4,7 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Base OpenAI LLM service implementation."""
"""Base LLM service implementation for services that use the AsyncOpenAI client."""
import base64
import json
@@ -21,6 +21,7 @@ from openai import (
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
from pydantic import BaseModel, Field
from pipecat.adapters.services.open_ai_adapter import OpenAILLMInvocationParams
from pipecat.frames.frames import (
Frame,
LLMFullResponseEndFrame,
@@ -31,9 +32,9 @@ from pipecat.frames.frames import (
VisionImageRawFrame,
)
from pipecat.metrics.metrics import LLMTokenUsage
from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContext,
OpenAILLMContextFrame,
from pipecat.processors.aggregators.llm_context import (
LLMContext,
LLMContextFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.llm_service import FunctionCallFromLLM, LLMService
@@ -44,8 +45,8 @@ from pipecat.utils.tracing.service_decorators import traced_llm
class BaseOpenAILLMService(LLMService):
"""Base class for all services that use the AsyncOpenAI client.
This service consumes OpenAILLMContextFrame frames, which contain a reference
to an OpenAILLMContext object. The context defines what is sent to the LLM for
This service consumes LLMContextFrame frames, which contain a reference to
an LLMContext object. The context defines what is sent to the LLM for
completion, including user, assistant, and system messages, as well as tool
choices and function call configurations.
"""
@@ -173,13 +174,13 @@ class BaseOpenAILLMService(LLMService):
return True
async def get_chat_completions(
self, context: OpenAILLMContext, messages: List[ChatCompletionMessageParam]
self, params_from_context: OpenAILLMInvocationParams
) -> AsyncStream[ChatCompletionChunk]:
"""Get streaming chat completions from OpenAI API.
Args:
context: The LLM context containing tools and configuration.
messages: List of chat completion messages to send.
params_from_context: Parameters, derived from the LLM context, to
use for the chat completion. Contains messages, tools, and tool choice.
Returns:
Async stream of chat completion chunks.
@@ -187,9 +188,6 @@ class BaseOpenAILLMService(LLMService):
params = {
"model": self.model_name,
"stream": True,
"messages": messages,
"tools": context.tools,
"tool_choice": context.tool_choice,
"stream_options": {"include_usage": True},
"frequency_penalty": self._settings["frequency_penalty"],
"presence_penalty": self._settings["presence_penalty"],
@@ -200,39 +198,28 @@ class BaseOpenAILLMService(LLMService):
"max_completion_tokens": self._settings["max_completion_tokens"],
}
# Messages, tools, tool_choice
params.update(params_from_context)
params.update(self._settings["extra"])
chunks = await self._client.chat.completions.create(**params)
return chunks
async def _stream_chat_completions(
self, context: OpenAILLMContext
self, context: LLMContext
) -> AsyncStream[ChatCompletionChunk]:
logger.debug(f"{self}: Generating chat [{context.get_messages_for_logging()}]")
adapter = self.get_llm_adapter()
logger.debug(f"{self}: Generating chat [{adapter.get_messages_for_logging(context)}]")
messages: List[ChatCompletionMessageParam] = context.get_messages()
params: OpenAILLMInvocationParams = adapter.get_llm_invocation_params(context)
# base64 encode any images
for message in messages:
if message.get("mime_type") == "image/jpeg":
encoded_image = base64.b64encode(message["data"].getvalue()).decode("utf-8")
text = message["content"]
message["content"] = [
{"type": "text", "text": text},
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
},
]
del message["data"]
del message["mime_type"]
chunks = await self.get_chat_completions(context, messages)
chunks = await self.get_chat_completions(params)
return chunks
@traced_llm
async def _process_context(self, context: OpenAILLMContext):
async def _process_context(self, context: LLMContext):
functions_list = []
arguments_list = []
tool_id_list = []
@@ -331,7 +318,7 @@ class BaseOpenAILLMService(LLMService):
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process frames for LLM completion requests.
Handles OpenAILLMContextFrame, LLMMessagesFrame, VisionImageRawFrame,
Handles LLMContextFrame, LLMMessagesFrame, VisionImageRawFrame,
and LLMUpdateSettingsFrame to trigger LLM completions and manage settings.
Args:
@@ -341,12 +328,12 @@ class BaseOpenAILLMService(LLMService):
await super().process_frame(frame, direction)
context = None
if isinstance(frame, OpenAILLMContextFrame):
context: OpenAILLMContext = frame.context
if isinstance(frame, LLMContextFrame):
context: LLMContext = frame.context
elif isinstance(frame, LLMMessagesFrame):
context = OpenAILLMContext.from_messages(frame.messages)
context = LLMContext(messages=frame.messages)
elif isinstance(frame, VisionImageRawFrame):
context = OpenAILLMContext()
context = LLMContext()
context.add_image_frame_message(
format=frame.format, size=frame.size, image=frame.image, text=frame.text
)

View File

@@ -16,45 +16,16 @@ from pipecat.frames.frames import (
FunctionCallResultFrame,
UserImageRawFrame,
)
from pipecat.processors.aggregators.llm_response import (
LLMAssistantAggregatorParams,
LLMAssistantContextAggregator,
LLMUserAggregatorParams,
LLMUserContextAggregator,
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response_universal import (
LLMAssistantContextAggregator_Universal,
LLMAssistantContextAggregatorParams,
LLMUserContextAggregator_Universal,
LLMUserContextAggregatorParams,
)
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.openai.base_llm import BaseOpenAILLMService
@dataclass
class OpenAIContextAggregatorPair:
    """Holder for the two OpenAI context aggregators of one conversation.

    Parameters:
        _user: Aggregator that processes user messages.
        _assistant: Aggregator that processes assistant messages.
    """

    _user: "OpenAIUserContextAggregator"
    _assistant: "OpenAIAssistantContextAggregator"

    def user(self) -> "OpenAIUserContextAggregator":
        """Return the aggregator stored for the user side.

        Returns:
            The user context aggregator instance.
        """
        user_aggregator = self._user
        return user_aggregator

    def assistant(self) -> "OpenAIAssistantContextAggregator":
        """Return the aggregator stored for the assistant side.

        Returns:
            The assistant context aggregator instance.
        """
        assistant_aggregator = self._assistant
        return assistant_aggregator
class OpenAILLMService(BaseOpenAILLMService):
"""OpenAI LLM service implementation.
@@ -78,141 +49,3 @@ class OpenAILLMService(BaseOpenAILLMService):
**kwargs: Additional arguments passed to the parent BaseOpenAILLMService.
"""
super().__init__(model=model, params=params, **kwargs)
def create_context_aggregator(
    self,
    context: OpenAILLMContext,
    *,
    user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
    assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
) -> OpenAIContextAggregatorPair:
    """Build the user/assistant aggregator pair for an OpenAI context.

    The service's LLM adapter is set on the context before the two
    aggregators are created on it.

    Args:
        context: The LLM context to create aggregators for.
        user_params: Parameters for user message aggregation.
        assistant_params: Parameters for assistant message aggregation.

    Returns:
        OpenAIContextAggregatorPair: The user and assistant aggregators,
        both bound to the given context.
    """
    context.set_llm_adapter(self.get_llm_adapter())
    return OpenAIContextAggregatorPair(
        _user=OpenAIUserContextAggregator(context, params=user_params),
        _assistant=OpenAIAssistantContextAggregator(context, params=assistant_params),
    )
class OpenAIUserContextAggregator(LLMUserContextAggregator):
    """User-side context aggregator for OpenAI LLM services.

    Adds nothing of its own: all behavior comes from the base
    LLMUserContextAggregator.
    """
class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
    """OpenAI-specific assistant context aggregator.

    Records function calls in the context using OpenAI's ``tool_calls`` /
    ``tool`` message format, keeps the matching ``tool`` message up to date
    as the call progresses, and adds user images requested by a function
    call.
    """

    async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame):
        """Handle a function call in progress.

        Adds an assistant ``tool_calls`` message for the call, followed by a
        ``tool`` message whose content is the placeholder ``"IN_PROGRESS"``.

        Args:
            frame: Frame containing function call progress information.
        """
        self._context.add_message(
            {
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": frame.tool_call_id,
                        "function": {
                            "name": frame.function_name,
                            "arguments": json.dumps(frame.arguments),
                        },
                        "type": "function",
                    }
                ],
            }
        )
        self._context.add_message(
            {
                "role": "tool",
                "content": "IN_PROGRESS",
                "tool_call_id": frame.tool_call_id,
            }
        )

    async def handle_function_call_result(self, frame: FunctionCallResultFrame):
        """Handle the result of a function call.

        Replaces the content of the matching ``tool`` message with the
        JSON-encoded result, or with ``"COMPLETED"`` when the call produced
        no result (``None``).

        Args:
            frame: Frame containing the function call result.
        """
        # Compare against None rather than relying on truthiness: falsy
        # results such as 0, False, "" or [] are real results and must reach
        # the context instead of being reported as a bare "COMPLETED".
        if frame.result is not None:
            result = json.dumps(frame.result)
        else:
            result = "COMPLETED"
        await self._update_function_call_result(frame.function_name, frame.tool_call_id, result)

    async def handle_function_call_cancel(self, frame: FunctionCallCancelFrame):
        """Handle a cancelled function call.

        Sets the content of the matching ``tool`` message to ``"CANCELLED"``.

        Args:
            frame: Frame containing the function call cancellation information.
        """
        await self._update_function_call_result(
            frame.function_name, frame.tool_call_id, "CANCELLED"
        )

    async def _update_function_call_result(
        self, function_name: str, tool_call_id: str, result: Any
    ):
        """Set ``result`` as the content of every tool message for ``tool_call_id``.

        Args:
            function_name: Name of the function; not used for matching.
            tool_call_id: Identifier of the tool call whose message is updated.
            result: New content for the matching tool message(s).
        """
        for message in self._context.messages:
            # .get() keeps a tool message without a tool_call_id from
            # aborting the scan with a KeyError.
            if (
                message.get("role") == "tool"
                and message.get("tool_call_id")
                and message.get("tool_call_id") == tool_call_id
            ):
                message["content"] = result

    async def handle_user_image_frame(self, frame: UserImageRawFrame):
        """Handle a user image frame from a function call request.

        Marks the associated function call as ``"COMPLETED"`` and adds the
        image to the context, using the request's context as its text.

        Args:
            frame: Frame containing the user image and request context.
        """
        await self._update_function_call_result(
            frame.request.function_name, frame.request.tool_call_id, "COMPLETED"
        )
        self._context.add_image_frame_message(
            format=frame.format,
            size=frame.size,
            image=frame.image,
            text=frame.request.context,
        )

View File

@@ -17,9 +17,9 @@ from pipecat.frames.frames import (
FunctionCallResultFrame,
FunctionCallResultProperties,
InterimTranscriptionFrame,
LLMContextAssistantTimestampFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
OpenAILLMContextAssistantTimestampFrame,
SpeechControlParamsFrame,
StartInterruptionFrame,
TextFrame,
@@ -738,7 +738,7 @@ class TestAnthropicAssistantContextAggregator(
):
CONTEXT_CLASS = AnthropicLLMContext
AGGREGATOR_CLASS = AnthropicAssistantContextAggregator
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, LLMContextAssistantTimestampFrame]
def check_message_multi_content(
self, context: OpenAILLMContext, content_index: int, index: int, content: str
@@ -773,7 +773,7 @@ class TestAWSBedrockAssistantContextAggregator(
):
CONTEXT_CLASS = AWSBedrockLLMContext
AGGREGATOR_CLASS = AWSBedrockAssistantContextAggregator
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, LLMContextAssistantTimestampFrame]
def check_message_multi_content(
self, context: OpenAILLMContext, content_index: int, index: int, content: str
@@ -814,7 +814,7 @@ class TestGoogleAssistantContextAggregator(
):
CONTEXT_CLASS = GoogleLLMContext
AGGREGATOR_CLASS = GoogleAssistantContextAggregator
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, LLMContextAssistantTimestampFrame]
def check_message_content(self, context: OpenAILLMContext, index: int, content: str):
obj = context.messages[index].to_json_dict()
@@ -848,4 +848,4 @@ class TestOpenAIAssistantContextAggregator(
):
CONTEXT_CLASS = OpenAILLMContext
AGGREGATOR_CLASS = OpenAIAssistantContextAggregator
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, LLMContextAssistantTimestampFrame]