Compare commits
12 Commits
hush/TurnT
...
pk/prototy
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
adff456daf | ||
|
|
3785481a45 | ||
|
|
ed2177a579 | ||
|
|
ee2aade12c | ||
|
|
602724b984 | ||
|
|
c437ff6a08 | ||
|
|
221e199fe0 | ||
|
|
35628f3af7 | ||
|
|
1de3c9d5fd | ||
|
|
d5d7ee9803 | ||
|
|
e651f1e4df | ||
|
|
36fea8f9e8 |
@@ -11,21 +11,45 @@ adapters that handle tool format conversion and standardization.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Union, cast
|
||||
from typing import Any, Generic, List, TypedDict, TypeVar, Union, cast
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
|
||||
# Should be a TypedDict
|
||||
TLLMInvocationParams = TypeVar("TLLMInvocationParams", bound=dict[str, Any])
|
||||
|
||||
|
||||
class BaseLLMAdapter(ABC):
|
||||
# TODO: fix everywhere we subclass BaseLLMAdapter...
|
||||
class BaseLLMAdapter(ABC, Generic[TLLMInvocationParams]):
|
||||
"""Abstract base class for LLM provider adapters.
|
||||
|
||||
Provides a standard interface for converting between Pipecat's standardized
|
||||
tool schemas and provider-specific tool formats. Subclasses must implement
|
||||
provider-specific conversion logic.
|
||||
Provides a standard interface for converting to provider-specific formats.
|
||||
|
||||
Handles:
|
||||
- Extracting provider-specific parameters for LLM invocation from a
|
||||
universal LLM context
|
||||
- Converting standardized tools schema to provider-specific tool formats.
|
||||
- Extracting messages from the LLM context for the purposes of logging
|
||||
about the specific provider.
|
||||
|
||||
Subclasses must implement provider-specific conversion logic.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_llm_invocation_params(self, context: LLMContext) -> TLLMInvocationParams:
|
||||
"""Get provider-specific LLM invocation parameters from a universal LLM context.
|
||||
|
||||
Args:
|
||||
context: The LLM context containing messages, tools, etc.
|
||||
|
||||
Returns:
|
||||
Provider-specific parameters for invoking the LLM.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def to_provider_tools_format(self, tools_schema: ToolsSchema) -> List[Any]:
|
||||
"""Convert tools schema to the provider's specific format.
|
||||
@@ -38,6 +62,20 @@ class BaseLLMAdapter(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_messages_for_logging(self, context: LLMContext) -> List[dict[str, Any]]:
|
||||
"""Get messages from the LLM context in a format ready for logging about this provider.
|
||||
|
||||
Args:
|
||||
context: The LLM context containing messages.
|
||||
|
||||
Returns:
|
||||
List of messages in a format ready for logging about this
|
||||
provider.
|
||||
"""
|
||||
pass
|
||||
|
||||
# TODO: should this also be able to return NotGiven?
|
||||
def from_standard_tools(self, tools: Any) -> List[Any]:
|
||||
"""Convert tools from standard format to provider format.
|
||||
|
||||
@@ -54,4 +92,38 @@ class BaseLLMAdapter(ABC):
|
||||
# Fallback to return the same tools in case they are not in a standard format
|
||||
return tools
|
||||
|
||||
def create_wav_header(self, sample_rate, num_channels, bits_per_sample, data_size):
|
||||
"""Create a WAV file header for audio data.
|
||||
|
||||
Args:
|
||||
sample_rate: Audio sample rate in Hz.
|
||||
num_channels: Number of audio channels.
|
||||
bits_per_sample: Bits per audio sample.
|
||||
data_size: Size of audio data in bytes.
|
||||
|
||||
Returns:
|
||||
WAV header as a bytearray.
|
||||
"""
|
||||
# RIFF chunk descriptor
|
||||
header = bytearray()
|
||||
header.extend(b"RIFF") # ChunkID
|
||||
header.extend((data_size + 36).to_bytes(4, "little")) # ChunkSize: total size - 8
|
||||
header.extend(b"WAVE") # Format
|
||||
# "fmt " sub-chunk
|
||||
header.extend(b"fmt ") # Subchunk1ID
|
||||
header.extend((16).to_bytes(4, "little")) # Subchunk1Size (16 for PCM)
|
||||
header.extend((1).to_bytes(2, "little")) # AudioFormat (1 for PCM)
|
||||
header.extend(num_channels.to_bytes(2, "little")) # NumChannels
|
||||
header.extend(sample_rate.to_bytes(4, "little")) # SampleRate
|
||||
# Calculate byte rate and block align
|
||||
byte_rate = sample_rate * num_channels * (bits_per_sample // 8)
|
||||
block_align = num_channels * (bits_per_sample // 8)
|
||||
header.extend(byte_rate.to_bytes(4, "little")) # ByteRate
|
||||
header.extend(block_align.to_bytes(2, "little")) # BlockAlign
|
||||
header.extend(bits_per_sample.to_bytes(2, "little")) # BitsPerSample
|
||||
# "data" sub-chunk
|
||||
header.extend(b"data") # Subchunk2ID
|
||||
header.extend(data_size.to_bytes(4, "little")) # Subchunk2Size
|
||||
return header
|
||||
|
||||
# TODO: we can move the logic to also handle the Messages here
|
||||
|
||||
@@ -6,21 +6,68 @@
|
||||
|
||||
"""Gemini LLM adapter for Pipecat."""
|
||||
|
||||
from typing import Any, Dict, List, Union
|
||||
import base64
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, TypedDict, Union
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
|
||||
from pipecat.adapters.schemas.tools_schema import AdapterType, ToolsSchema
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext, LLMContextMessage
|
||||
|
||||
try:
|
||||
from google.genai.types import (
|
||||
Blob,
|
||||
Content,
|
||||
FunctionCall,
|
||||
FunctionResponse,
|
||||
Part,
|
||||
)
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use Google AI, you need to `pip install pipecat-ai[google]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class GeminiLLMAdapter(BaseLLMAdapter):
|
||||
"""LLM adapter for Google's Gemini service.
|
||||
class GeminiLLMInvocationParams(TypedDict):
|
||||
"""Context-based parameters for invoking Gemini LLM."""
|
||||
|
||||
Provides tool schema conversion functionality to transform standard tool
|
||||
definitions into Gemini's specific function-calling format for use with
|
||||
Gemini LLM models.
|
||||
system_instruction: Optional[str]
|
||||
messages: List[Content]
|
||||
tools: List[Any]
|
||||
|
||||
|
||||
class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
"""Gemini-specific adapter for Pipecat.
|
||||
|
||||
Handles:
|
||||
- Extracting parameters for Gemini's API from a universal
|
||||
LLM context
|
||||
- Converting Pipecat's standardized tools schema to Gemini's function-calling format.
|
||||
- Extracting and sanitizing messages from the LLM context for logging with Gemini.
|
||||
"""
|
||||
|
||||
def to_provider_tools_format(self, tools_schema: ToolsSchema) -> List[Dict[str, Any]]:
|
||||
def get_llm_invocation_params(self, context: LLMContext) -> GeminiLLMInvocationParams:
|
||||
"""Get Gemini-specific LLM invocation parameters from a universal LLM context.
|
||||
|
||||
Args:
|
||||
context: The LLM context containing messages, tools, etc.
|
||||
|
||||
Returns:
|
||||
Dictionary of parameters for Gemini's API.
|
||||
"""
|
||||
# TODO: remove when done testing
|
||||
print(f"[pk] {self}: Getting LLM invocation params...")
|
||||
messages = self._from_standard_messages(context.messages)
|
||||
return {
|
||||
"system_instruction": messages.system_instruction,
|
||||
"messages": messages.messages,
|
||||
"tools": self.from_standard_tools(context.tools),
|
||||
}
|
||||
|
||||
def to_provider_tools_format(self, tools_schema: ToolsSchema) -> List[Any]:
|
||||
"""Convert tool schemas to Gemini's function-calling format.
|
||||
|
||||
Args:
|
||||
@@ -39,3 +86,217 @@ class GeminiLLMAdapter(BaseLLMAdapter):
|
||||
custom_gemini_tools = tools_schema.custom_tools.get(AdapterType.GEMINI, [])
|
||||
|
||||
return formatted_standard_tools + custom_gemini_tools
|
||||
|
||||
def get_messages_for_logging(self, context: LLMContext) -> List[dict[str, Any]]:
|
||||
"""Get messages from the LLM context in a format ready for logging about Gemini.
|
||||
|
||||
Removes or truncates sensitive data like image content for safe logging.
|
||||
|
||||
Args:
|
||||
context: The LLM context containing messages.
|
||||
|
||||
Returns:
|
||||
List of messages in a format ready for logging about Gemini.
|
||||
"""
|
||||
# Get messages in Gemini's format
|
||||
messages = self._from_standard_messages(context.messages).messages
|
||||
|
||||
# Sanitize messages for logging
|
||||
messages_for_logging = []
|
||||
for message in messages:
|
||||
obj = message.to_json_dict()
|
||||
try:
|
||||
if "parts" in obj:
|
||||
for part in obj["parts"]:
|
||||
if "inline_data" in part:
|
||||
part["inline_data"]["data"] = "..."
|
||||
except Exception as e:
|
||||
logger.debug(f"Error: {e}")
|
||||
messages_for_logging.append(obj)
|
||||
return messages_for_logging
|
||||
|
||||
@dataclass
|
||||
class ConvertedMessages:
|
||||
"""Container for converted messages.
|
||||
|
||||
Holds the converted messages in a format suitable for Gemini's API.
|
||||
"""
|
||||
|
||||
messages: List[Content]
|
||||
system_instruction: Optional[str] = None
|
||||
|
||||
def _from_standard_messages(
|
||||
self, standard_messages: List[LLMContextMessage]
|
||||
) -> ConvertedMessages:
|
||||
"""Restructures messages to ensure proper Google format and message ordering.
|
||||
|
||||
This method handles conversion of OpenAI-formatted messages to Google format,
|
||||
with special handling for function calls, function responses, and system messages.
|
||||
System messages are added back to the context as user messages when needed.
|
||||
|
||||
The final message order is preserved as:
|
||||
1. Function calls (from model)
|
||||
2. Function responses (from user)
|
||||
3. Text messages (converted from system messages)
|
||||
|
||||
Note:
|
||||
System messages are only added back when there are no regular text
|
||||
messages in the context, ensuring proper conversation continuity
|
||||
after function calls.
|
||||
"""
|
||||
system_instruction = None
|
||||
messages = []
|
||||
|
||||
# Process each message, preserving Google-formatted messages and converting others
|
||||
for message in standard_messages:
|
||||
if isinstance(message, Content):
|
||||
# Keep existing Google-formatted messages (e.g., function calls/responses)
|
||||
# TODO: this branch is probably not needed anymore, since LLMContext contains a universal format
|
||||
messages.append(message)
|
||||
continue
|
||||
|
||||
# Convert standard format to Google format
|
||||
converted = self._from_standard_message(message)
|
||||
if isinstance(converted, Content):
|
||||
# Regular (non-system) message
|
||||
messages.append(converted)
|
||||
else:
|
||||
# System instruction
|
||||
system_instruction = converted
|
||||
|
||||
# Check if we only have function-related messages (no regular text)
|
||||
has_regular_messages = any(
|
||||
len(msg.parts) == 1
|
||||
and getattr(msg.parts[0], "text", None)
|
||||
and not getattr(msg.parts[0], "function_call", None)
|
||||
and not getattr(msg.parts[0], "function_response", None)
|
||||
for msg in messages
|
||||
)
|
||||
|
||||
# Add system instruction back as a user message if we only have function messages
|
||||
if system_instruction and not has_regular_messages:
|
||||
messages.append(Content(role="user", parts=[Part(text=system_instruction)]))
|
||||
|
||||
# Remove any empty messages
|
||||
messages = [m for m in messages if m.parts]
|
||||
|
||||
return self.ConvertedMessages(messages=messages, system_instruction=system_instruction)
|
||||
|
||||
def _from_standard_message(self, message: LLMContextMessage) -> Content | str:
|
||||
"""Convert standard format message to Google Content object.
|
||||
|
||||
Handles conversion of text, images, and function calls to Google's
|
||||
format.
|
||||
System instructions are returned as a plain string.
|
||||
|
||||
Args:
|
||||
message: Message in standard format.
|
||||
|
||||
Returns:
|
||||
Content object with role and parts, or a plain string for system
|
||||
messages.
|
||||
|
||||
Examples:
|
||||
Standard text message::
|
||||
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello there"
|
||||
}
|
||||
|
||||
Converts to Google Content with::
|
||||
|
||||
Content(
|
||||
role="user",
|
||||
parts=[Part(text="Hello there")]
|
||||
)
|
||||
|
||||
Standard function call message::
|
||||
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"name": "search",
|
||||
"arguments": '{"query": "test"}'
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
Converts to Google Content with::
|
||||
|
||||
Content(
|
||||
role="model",
|
||||
parts=[Part(function_call=FunctionCall(name="search", args={"query": "test"}))]
|
||||
)
|
||||
"""
|
||||
role = message["role"]
|
||||
content = message.get("content", [])
|
||||
if role == "system":
|
||||
# System instructions are returned as plain text
|
||||
# TODO: here we've always assumed that system instructions are plain text...is that a safe assumption?
|
||||
return content
|
||||
elif role == "assistant":
|
||||
role = "model"
|
||||
|
||||
parts = []
|
||||
if message.get("tool_calls"):
|
||||
for tc in message["tool_calls"]:
|
||||
parts.append(
|
||||
Part(
|
||||
function_call=FunctionCall(
|
||||
name=tc["function"]["name"],
|
||||
args=json.loads(tc["function"]["arguments"]),
|
||||
)
|
||||
)
|
||||
)
|
||||
elif role == "tool":
|
||||
role = "model"
|
||||
parts.append(
|
||||
Part(
|
||||
function_response=FunctionResponse(
|
||||
name="tool_call_result", # seems to work to hard-code the same name every time
|
||||
response=json.loads(message["content"]),
|
||||
)
|
||||
)
|
||||
)
|
||||
elif isinstance(content, str):
|
||||
parts.append(Part(text=content))
|
||||
elif isinstance(content, list):
|
||||
for c in content:
|
||||
if c["type"] == "text":
|
||||
parts.append(Part(text=c["text"]))
|
||||
elif c["type"] == "image_url":
|
||||
parts.append(
|
||||
Part(
|
||||
inline_data=Blob(
|
||||
mime_type="image/jpeg",
|
||||
data=base64.b64decode(c["image_url"]["url"].split(",")[1]),
|
||||
)
|
||||
)
|
||||
)
|
||||
elif c["type"] == "input_audio":
|
||||
input_audio = c["input_audio"]
|
||||
parts.append(
|
||||
Part(
|
||||
inline_data=Blob(
|
||||
mime_type="audio/wav",
|
||||
data=(
|
||||
bytes(
|
||||
self.create_wav_header(
|
||||
input_audio["sample_rate"],
|
||||
input_audio["num_channels"],
|
||||
16,
|
||||
len(input_audio["data"]),
|
||||
)
|
||||
+ input_audio["data"]
|
||||
)
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
message = Content(role=role, parts=parts)
|
||||
return message
|
||||
|
||||
@@ -6,22 +6,62 @@
|
||||
|
||||
"""OpenAI LLM adapter for Pipecat."""
|
||||
|
||||
from typing import List
|
||||
import copy
|
||||
import json
|
||||
from typing import Any, List, TypedDict
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
from openai._types import NOT_GIVEN as OPEN_AI_NOT_GIVEN
|
||||
from openai._types import NotGiven as OpenAINotGiven
|
||||
from openai.types.chat import (
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionToolChoiceOptionParam,
|
||||
ChatCompletionToolParam,
|
||||
)
|
||||
|
||||
from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.processors.aggregators.llm_context import (
|
||||
LLMContext,
|
||||
LLMContextMessage,
|
||||
LLMContextToolChoice,
|
||||
NotGiven,
|
||||
)
|
||||
|
||||
|
||||
class OpenAILLMAdapter(BaseLLMAdapter):
|
||||
"""Adapter for converting tool schemas to OpenAI's format.
|
||||
class OpenAILLMInvocationParams(TypedDict):
|
||||
"""Context-based parameters for invoking OpenAI ChatCompletion API."""
|
||||
|
||||
Provides conversion utilities for transforming Pipecat's standard tool
|
||||
schemas into the format expected by OpenAI's ChatCompletion API for
|
||||
function calling capabilities.
|
||||
messages: List[ChatCompletionMessageParam]
|
||||
tools: List[ChatCompletionToolParam] | OpenAINotGiven
|
||||
tool_choice: ChatCompletionToolChoiceOptionParam | OpenAINotGiven
|
||||
|
||||
|
||||
class OpenAILLMAdapter(BaseLLMAdapter[OpenAILLMInvocationParams]):
|
||||
"""OpenAI-specific adapter for Pipecat.
|
||||
|
||||
Handles:
|
||||
- Extracting parameters for OpenAI's ChatCompletion API from a universal
|
||||
LLM context
|
||||
- Converting Pipecat's standardized tools schema to OpenAI's function-calling format.
|
||||
- Extracting and sanitizing messages from the LLM context for logging with OpenAI.
|
||||
"""
|
||||
|
||||
def get_llm_invocation_params(self, context: LLMContext) -> OpenAILLMInvocationParams:
|
||||
"""Get OpenAI-specific LLM invocation parameters from a universal LLM context.
|
||||
|
||||
Args:
|
||||
context: The LLM context containing messages, tools, etc.
|
||||
|
||||
Returns:
|
||||
Dictionary of parameters for OpenAI's ChatCompletion API.
|
||||
"""
|
||||
return {
|
||||
"messages": self._from_standard_messages(context.messages),
|
||||
# TODO: doesn't seem right that we may or may not need to convert tools here; they should already be guaranteed to exist in a universal format in the LLMContext, right?
|
||||
"tools": self.from_standard_tools(context.tools),
|
||||
"tool_choice": context.tool_choice,
|
||||
}
|
||||
|
||||
def to_provider_tools_format(self, tools_schema: ToolsSchema) -> List[ChatCompletionToolParam]:
|
||||
"""Convert function schemas to OpenAI's function-calling format.
|
||||
|
||||
@@ -37,3 +77,40 @@ class OpenAILLMAdapter(BaseLLMAdapter):
|
||||
ChatCompletionToolParam(type="function", function=func.to_default_dict())
|
||||
for func in functions_schema
|
||||
]
|
||||
|
||||
def get_messages_for_logging(self, context: LLMContext) -> List[dict[str, Any]]:
|
||||
"""Get messages from the LLM context in a format ready for logging about OpenAI.
|
||||
|
||||
Removes or truncates sensitive data like image content for safe logging.
|
||||
|
||||
Args:
|
||||
context: The LLM context containing messages.
|
||||
|
||||
Returns:
|
||||
List of messages in a format ready for logging about OpenAI.
|
||||
"""
|
||||
msgs = []
|
||||
for message in context.messages:
|
||||
msg = copy.deepcopy(message)
|
||||
if "content" in msg:
|
||||
if isinstance(msg["content"], list):
|
||||
for item in msg["content"]:
|
||||
if item["type"] == "image_url":
|
||||
if item["image_url"]["url"].startswith("data:image/"):
|
||||
item["image_url"]["url"] = "data:image/..."
|
||||
if "mime_type" in msg and msg["mime_type"].startswith("image/"):
|
||||
msg["data"] = "..."
|
||||
msgs.append(msg)
|
||||
return json.dumps(msgs, ensure_ascii=False)
|
||||
|
||||
def _from_standard_messages(
|
||||
self, messages: List[LLMContextMessage]
|
||||
) -> List[ChatCompletionMessageParam]:
|
||||
# Just a pass-through: messages is already the right type
|
||||
return messages
|
||||
|
||||
def _from_standard_tool_choice(
|
||||
self, tool_choice: LLMContextToolChoice | NotGiven
|
||||
) -> ChatCompletionToolChoiceOptionParam | OpenAINotGiven:
|
||||
# Just a pass-through: tool_choice is already the right type
|
||||
return tool_choice
|
||||
|
||||
@@ -378,7 +378,7 @@ class TranslationFrame(TextFrame):
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenAILLMContextAssistantTimestampFrame(DataFrame):
|
||||
class LLMContextAssistantTimestampFrame(DataFrame):
|
||||
"""Timestamp information for assistant messages in LLM context.
|
||||
|
||||
Parameters:
|
||||
|
||||
211
src/pipecat/processors/aggregators/llm_context.py
Normal file
211
src/pipecat/processors/aggregators/llm_context.py
Normal file
@@ -0,0 +1,211 @@
|
||||
#
|
||||
# Copyright (c) 2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Universal LLM context management for LLM services in Pipecat.
|
||||
|
||||
Context contents are represented in a generic format (extended from OpenAI)
|
||||
that supports a union of known Pipecat LLM service functionality.
|
||||
|
||||
Whenever an LLM service needs to access context, it does a just-in-time
|
||||
translation from this universal context into whatever format it needs, using a
|
||||
service-specific adapter.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import io
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from openai._types import NOT_GIVEN as OPEN_AI_NOT_GIVEN
|
||||
from openai._types import NotGiven as OpenAINotGiven
|
||||
from openai.types.chat import (
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionToolChoiceOptionParam,
|
||||
ChatCompletionToolParam,
|
||||
)
|
||||
from PIL import Image
|
||||
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.frames.frames import AudioRawFrame, Frame
|
||||
|
||||
# "Re-export" types from OpenAI that we're using as universal context types.
|
||||
# NOTE: this is just for convenience, for now. As soon as the universal types
|
||||
# diverge from OpenAI's, we should ditch this. In fact, audio frames already
|
||||
# diverge from OpenAI's standard format...we really ought to do this.
|
||||
LLMContextMessage = ChatCompletionMessageParam
|
||||
LLMContextTool = ChatCompletionToolParam
|
||||
LLMContextToolChoice = ChatCompletionToolChoiceOptionParam
|
||||
NOT_GIVEN = OPEN_AI_NOT_GIVEN
|
||||
NotGiven = OpenAINotGiven
|
||||
|
||||
|
||||
class LLMContext:
|
||||
"""Manages conversation context for LLM interactions.
|
||||
|
||||
Handles message history, tool definitions, tool choices, and multimedia
|
||||
content for LLM conversations. Provides methods for message manipulation,
|
||||
and content formatting.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
messages: Optional[List[LLMContextMessage]] = None,
|
||||
tools: List[LLMContextTool] | NotGiven | ToolsSchema = NOT_GIVEN,
|
||||
tool_choice: LLMContextToolChoice | NotGiven = NOT_GIVEN,
|
||||
):
|
||||
"""Initialize the LLM context.
|
||||
|
||||
Args:
|
||||
messages: Initial list of conversation messages.
|
||||
tools: Available tools for the LLM to use.
|
||||
tool_choice: Tool selection strategy for the LLM.
|
||||
"""
|
||||
self._messages: List[LLMContextMessage] = messages if messages else []
|
||||
self._tools: List[LLMContextTool] | NotGiven | ToolsSchema = tools
|
||||
self._tool_choice: LLMContextToolChoice | NotGiven = tool_choice
|
||||
|
||||
@property
|
||||
def messages(self) -> List[LLMContextMessage]:
|
||||
"""Get the current messages list.
|
||||
|
||||
Returns:
|
||||
List of conversation messages.
|
||||
"""
|
||||
return self._messages
|
||||
|
||||
@property
|
||||
def tools(self) -> List[LLMContextTool] | NotGiven | List[Any]:
|
||||
"""Get the tools list.
|
||||
|
||||
Returns:
|
||||
Tools list.
|
||||
"""
|
||||
return self._tools
|
||||
|
||||
@property
|
||||
def tool_choice(self) -> LLMContextToolChoice | NotGiven:
|
||||
"""Get the current tool choice setting.
|
||||
|
||||
Returns:
|
||||
The tool choice configuration.
|
||||
"""
|
||||
return self._tool_choice
|
||||
|
||||
def add_message(self, message: LLMContextMessage):
|
||||
"""Add a single message to the context.
|
||||
|
||||
Args:
|
||||
message: The message to add to the conversation history.
|
||||
"""
|
||||
self._messages.append(message)
|
||||
|
||||
def add_messages(self, messages: List[LLMContextMessage]):
|
||||
"""Add multiple messages to the context.
|
||||
|
||||
Args:
|
||||
messages: List of messages to add to the conversation history.
|
||||
"""
|
||||
self._messages.extend(messages)
|
||||
|
||||
def set_messages(self, messages: List[LLMContextMessage]):
|
||||
"""Replace all messages in the context.
|
||||
|
||||
Args:
|
||||
messages: New list of messages to replace the current history.
|
||||
"""
|
||||
self._messages[:] = messages
|
||||
|
||||
def set_tools(self, tools: List[LLMContextTool] | NotGiven | ToolsSchema = NOT_GIVEN):
|
||||
"""Set the available tools for the LLM.
|
||||
|
||||
Args:
|
||||
tools: List of tools available to the LLM, a ToolsSchema, or NOT_GIVEN to disable tools.
|
||||
"""
|
||||
# TODO: convert empty ToolsSchema to NOT_GIVEN if needed?
|
||||
# TODO: maybe someday also convert provider-specific tools to ToolsSchema so it's always in a provider-neutral format here? See open_ai_adapter.py for related comment. Pipecat Flows is currently converting provider-specific tools to ToolsSchema...
|
||||
if isinstance(tools, list) and len(tools) == 0:
|
||||
tools = NOT_GIVEN
|
||||
self._tools = tools
|
||||
|
||||
def set_tool_choice(self, tool_choice: LLMContextToolChoice | NotGiven):
|
||||
"""Set the tool choice configuration.
|
||||
|
||||
Args:
|
||||
tool_choice: Tool selection strategy for the LLM.
|
||||
"""
|
||||
self._tool_choice = tool_choice
|
||||
|
||||
def add_image_frame_message(
|
||||
self, *, format: str, size: tuple[int, int], image: bytes, text: str = None
|
||||
):
|
||||
"""Add a message containing an image frame.
|
||||
|
||||
Args:
|
||||
format: Image format (e.g., 'RGB', 'RGBA').
|
||||
size: Image dimensions as (width, height) tuple.
|
||||
image: Raw image bytes.
|
||||
text: Optional text to include with the image.
|
||||
"""
|
||||
buffer = io.BytesIO()
|
||||
Image.frombytes(format, size, image).save(buffer, format="JPEG")
|
||||
# TODO: we might not want the universal format to be base64 encoded, since encoding is not needed by all LLM services; today, te Gemini adapter has to decode from base64, which is less than ideal.
|
||||
encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
content = []
|
||||
if text:
|
||||
content.append({"type": "text", "text": text})
|
||||
content.append(
|
||||
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}},
|
||||
)
|
||||
self.add_message({"role": "user", "content": content})
|
||||
|
||||
# NOTE: today we've only built support for audio frames with the Google
|
||||
# LLM, so this "universal" representation skews towards that.
|
||||
# When we add support for other LLMs, we may need to adjust this.
|
||||
def add_audio_frames_message(
|
||||
self, *, audio_frames: list[AudioRawFrame], text: str = "Audio follows"
|
||||
):
|
||||
"""Add a message containing audio frames.
|
||||
|
||||
Args:
|
||||
audio_frames: List of audio frame objects to include.
|
||||
text: Optional text to include with the audio.
|
||||
"""
|
||||
if not audio_frames:
|
||||
return
|
||||
|
||||
sample_rate = audio_frames[0].sample_rate
|
||||
num_channels = audio_frames[0].num_channels
|
||||
|
||||
content = []
|
||||
content.append({"type": "text", "text": text})
|
||||
data = b"".join(frame.audio for frame in audio_frames)
|
||||
# TODO: filter this out in OpenAI adapter, since it doesn't support audio frames
|
||||
content.append(
|
||||
{
|
||||
"type": "input_audio",
|
||||
"input_audio": {
|
||||
"data": data,
|
||||
"sample_rate": sample_rate,
|
||||
"num_channels": num_channels,
|
||||
},
|
||||
}
|
||||
)
|
||||
self.add_message({"role": "user", "content": content})
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMContextFrame(Frame):
|
||||
"""Frame containing LLM context.
|
||||
|
||||
Used as a signal to LLM services to ingest the provided context and
|
||||
generate a response based on it.
|
||||
|
||||
Parameters:
|
||||
context: The LLM context containing messages, tools, and configuration.
|
||||
"""
|
||||
|
||||
context: LLMContext
|
||||
@@ -36,6 +36,7 @@ from pipecat.frames.frames import (
|
||||
FunctionCallsStartedFrame,
|
||||
InputAudioRawFrame,
|
||||
InterimTranscriptionFrame,
|
||||
LLMContextAssistantTimestampFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesAppendFrame,
|
||||
@@ -44,7 +45,6 @@ from pipecat.frames.frames import (
|
||||
LLMSetToolChoiceFrame,
|
||||
LLMSetToolsFrame,
|
||||
LLMTextFrame,
|
||||
OpenAILLMContextAssistantTimestampFrame,
|
||||
SpeechControlParamsFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
@@ -864,7 +864,7 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
|
||||
await self.push_context_frame()
|
||||
|
||||
# Push timestamp frame with current time
|
||||
timestamp_frame = OpenAILLMContextAssistantTimestampFrame(timestamp=time_now_iso8601())
|
||||
timestamp_frame = LLMContextAssistantTimestampFrame(timestamp=time_now_iso8601())
|
||||
await self.push_frame(timestamp_frame)
|
||||
|
||||
async def _handle_llm_messages_append(self, frame: LLMMessagesAppendFrame):
|
||||
|
||||
874
src/pipecat/processors/aggregators/llm_response_universal.py
Normal file
874
src/pipecat/processors/aggregators/llm_response_universal.py
Normal file
@@ -0,0 +1,874 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""LLM response aggregators for handling conversation context and message aggregation.
|
||||
|
||||
This module provides aggregators that process and accumulate LLM responses, user inputs,
|
||||
and conversation context. These aggregators handle the flow between speech-to-text,
|
||||
LLM processing, and text-to-speech components in conversational AI pipelines.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Literal, Optional, Set
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.interruptions.base_interruption_strategy import BaseInterruptionStrategy
|
||||
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import (
|
||||
BotInterruptionFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EmulateUserStartedSpeakingFrame,
|
||||
EmulateUserStoppedSpeakingFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
FunctionCallCancelFrame,
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
FunctionCallsStartedFrame,
|
||||
InputAudioRawFrame,
|
||||
InterimTranscriptionFrame,
|
||||
LLMContextAssistantTimestampFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesAppendFrame,
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMSetToolChoiceFrame,
|
||||
LLMSetToolsFrame,
|
||||
SpeechControlParamsFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
UserImageRawFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext, LLMContextFrame
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMUserContextAggregatorParams:
|
||||
"""Parameters for configuring LLM user context aggregation behavior.
|
||||
|
||||
Parameters:
|
||||
aggregation_timeout: Maximum time in seconds to wait for additional
|
||||
transcription content before pushing aggregated result. This
|
||||
timeout is used only when the transcription is slow to arrive.
|
||||
turn_emulated_vad_timeout: Maximum time in seconds to wait for emulated
|
||||
VAD when using turn-based analysis. Applied when transcription is
|
||||
received but VAD didn't detect speech (e.g., whispered utterances).
|
||||
"""
|
||||
|
||||
aggregation_timeout: float = 0.5
|
||||
turn_emulated_vad_timeout: float = 0.8
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMAssistantContextAggregatorParams:
|
||||
"""Parameters for configuring LLM assistant context aggregation behavior.
|
||||
|
||||
Parameters:
|
||||
expect_stripped_words: Whether to expect and handle stripped words
|
||||
in text frames by adding spaces between tokens.
|
||||
"""
|
||||
|
||||
expect_stripped_words: bool = True
|
||||
|
||||
|
||||
class LLMContextAggregator(FrameProcessor):
|
||||
"""Base LLM aggregator that uses an LLMContext for conversation storage.
|
||||
|
||||
This aggregator maintains conversation state using an LLMContext and
|
||||
pushes LLMContextFrame objects as aggregation frames. It provides
|
||||
common functionality for context-based conversation management.
|
||||
"""
|
||||
|
||||
def __init__(self, *, context: LLMContext, role: str, **kwargs):
|
||||
"""Initialize the context response aggregator.
|
||||
|
||||
Args:
|
||||
context: The LLM context to use for conversation storage.
|
||||
role: The role this aggregator represents (e.g. "user", "assistant").
|
||||
**kwargs: Additional arguments passed to parent class.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._context = context
|
||||
self._role = role
|
||||
|
||||
self._aggregation: str = ""
|
||||
|
||||
@property
|
||||
def messages(self) -> List[dict]:
|
||||
"""Get messages from the LLM context.
|
||||
|
||||
Returns:
|
||||
List of message dictionaries from the context.
|
||||
"""
|
||||
return self._context.messages
|
||||
|
||||
@property
|
||||
def role(self) -> str:
|
||||
"""Get the role for this aggregator.
|
||||
|
||||
Returns:
|
||||
The role string for this aggregator.
|
||||
"""
|
||||
return self._role
|
||||
|
||||
@property
|
||||
def context(self):
|
||||
"""Get the LLM context.
|
||||
|
||||
Returns:
|
||||
The LLMContext instance used by this aggregator.
|
||||
"""
|
||||
return self._context
|
||||
|
||||
def get_context_frame(self) -> LLMContextFrame:
|
||||
"""Create a context frame with the current context.
|
||||
|
||||
Returns:
|
||||
LLMContextFrame containing the current context.
|
||||
"""
|
||||
return LLMContextFrame(context=self._context)
|
||||
|
||||
async def push_context_frame(self, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
"""Push a context frame in the specified direction.
|
||||
|
||||
Args:
|
||||
direction: The direction to push the frame (upstream or downstream).
|
||||
"""
|
||||
frame = self.get_context_frame()
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
def add_messages(self, messages):
|
||||
"""Add messages to the context.
|
||||
|
||||
Args:
|
||||
messages: Messages to add to the conversation context.
|
||||
"""
|
||||
self._context.add_messages(messages)
|
||||
|
||||
def set_messages(self, messages):
|
||||
"""Set the context messages.
|
||||
|
||||
Args:
|
||||
messages: Messages to replace the current context messages.
|
||||
"""
|
||||
self._context.set_messages(messages)
|
||||
|
||||
def set_tools(self, tools: List):
|
||||
"""Set tools in the context.
|
||||
|
||||
Args:
|
||||
tools: List of tool definitions to set in the context.
|
||||
"""
|
||||
self._context.set_tools(tools)
|
||||
|
||||
def set_tool_choice(self, tool_choice: Literal["none", "auto", "required"] | dict):
|
||||
"""Set tool choice in the context.
|
||||
|
||||
Args:
|
||||
tool_choice: Tool choice configuration for the context.
|
||||
"""
|
||||
self._context.set_tool_choice(tool_choice)
|
||||
|
||||
async def reset(self):
|
||||
"""Reset the aggregation state."""
|
||||
self._aggregation = ""
|
||||
|
||||
|
||||
# NOTE: the "universal" suffix is just meant to distinguish this aggregator
|
||||
# from the old LLMUserContextAggregator while we gradually migrate service to
|
||||
# use the new universal LLMContext and associated patterns. The suffix will go
|
||||
# away once the migration is complete and the other LLMUserContextAggregator is
|
||||
# deprecated.
|
||||
class LLMUserContextAggregator_Universal(LLMContextAggregator):
|
||||
"""User LLM aggregator that processes speech-to-text transcriptions.
|
||||
|
||||
This aggregator handles the complex logic of aggregating user speech transcriptions
|
||||
from STT services. It manages multiple scenarios including:
|
||||
|
||||
- Transcriptions received between VAD events
|
||||
- Transcriptions received outside VAD events
|
||||
- Interim vs final transcriptions
|
||||
- User interruptions during bot speech
|
||||
- Emulated VAD for whispered or short utterances
|
||||
|
||||
The aggregator uses timeouts to handle cases where transcriptions arrive
|
||||
after VAD events or when no VAD is available.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
context: LLMContext,
|
||||
*,
|
||||
params: Optional[LLMUserContextAggregatorParams] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the user context aggregator.
|
||||
|
||||
Args:
|
||||
context: The LLM context for conversation storage.
|
||||
params: Configuration parameters for aggregation behavior.
|
||||
**kwargs: Additional arguments. Supports deprecated 'aggregation_timeout'.
|
||||
"""
|
||||
super().__init__(context=context, role="user", **kwargs)
|
||||
self._params = params or LLMUserContextAggregatorParams()
|
||||
self._vad_params: Optional[VADParams] = None
|
||||
self._turn_params: Optional[SmartTurnParams] = None
|
||||
|
||||
if "aggregation_timeout" in kwargs:
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Parameter 'aggregation_timeout' is deprecated, use 'params' instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
self._params.aggregation_timeout = kwargs["aggregation_timeout"]
|
||||
|
||||
self._user_speaking = False
|
||||
self._bot_speaking = False
|
||||
self._was_bot_speaking = False
|
||||
self._emulating_vad = False
|
||||
self._seen_interim_results = False
|
||||
self._waiting_for_aggregation = False
|
||||
|
||||
self._aggregation_event = asyncio.Event()
|
||||
self._aggregation_task = None
|
||||
|
||||
async def reset(self):
|
||||
"""Reset the aggregation state and interruption strategies."""
|
||||
await super().reset()
|
||||
self._was_bot_speaking = False
|
||||
self._seen_interim_results = False
|
||||
self._waiting_for_aggregation = False
|
||||
[await s.reset() for s in self._interruption_strategies]
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process frames for user speech aggregation and context management.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction of frame flow in the pipeline.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, StartFrame):
|
||||
# Push StartFrame before start(), because we want StartFrame to be
|
||||
# processed by every processor before any other frame is processed.
|
||||
await self.push_frame(frame, direction)
|
||||
await self._start(frame)
|
||||
elif isinstance(frame, EndFrame):
|
||||
# Push EndFrame before stop(), because stop() waits on the task to
|
||||
# finish and the task finishes when EndFrame is processed.
|
||||
await self.push_frame(frame, direction)
|
||||
await self._stop(frame)
|
||||
elif isinstance(frame, CancelFrame):
|
||||
await self._cancel(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, InputAudioRawFrame):
|
||||
await self._handle_input_audio(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, UserStartedSpeakingFrame):
|
||||
await self._handle_user_started_speaking(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, UserStoppedSpeakingFrame):
|
||||
await self._handle_user_stopped_speaking(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, BotStartedSpeakingFrame):
|
||||
await self._handle_bot_started_speaking(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, BotStoppedSpeakingFrame):
|
||||
await self._handle_bot_stopped_speaking(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, TranscriptionFrame):
|
||||
await self._handle_transcription(frame)
|
||||
elif isinstance(frame, InterimTranscriptionFrame):
|
||||
await self._handle_interim_transcription(frame)
|
||||
elif isinstance(frame, LLMMessagesAppendFrame):
|
||||
await self._handle_llm_messages_append(frame)
|
||||
elif isinstance(frame, LLMMessagesUpdateFrame):
|
||||
await self._handle_llm_messages_update(frame)
|
||||
elif isinstance(frame, LLMSetToolsFrame):
|
||||
self.set_tools(frame.tools)
|
||||
elif isinstance(frame, LLMSetToolChoiceFrame):
|
||||
self.set_tool_choice(frame.tool_choice)
|
||||
elif isinstance(frame, SpeechControlParamsFrame):
|
||||
self._vad_params = frame.vad_params
|
||||
self._turn_params = frame.turn_params
|
||||
await self.push_frame(frame, direction)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _process_aggregation(self):
|
||||
"""Process the current aggregation and push it downstream."""
|
||||
aggregation = self._aggregation
|
||||
await self.reset()
|
||||
self._context.add_message({"role": self.role, "content": aggregation})
|
||||
frame = LLMContextFrame(self._context)
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def _push_aggregation(self):
|
||||
"""Push the current aggregation based on interruption strategies and conditions."""
|
||||
if len(self._aggregation) > 0:
|
||||
if self.interruption_strategies and self._bot_speaking:
|
||||
should_interrupt = await self._should_interrupt_based_on_strategies()
|
||||
|
||||
if should_interrupt:
|
||||
logger.debug(
|
||||
"Interruption conditions met - pushing BotInterruptionFrame and aggregation"
|
||||
)
|
||||
await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
|
||||
await self._process_aggregation()
|
||||
else:
|
||||
logger.debug("Interruption conditions not met - not pushing aggregation")
|
||||
# Don't process aggregation, just reset it
|
||||
await self.reset()
|
||||
else:
|
||||
# No interruption config - normal behavior (always push aggregation)
|
||||
await self._process_aggregation()
|
||||
# Handles the case where both the user and the bot are not speaking,
|
||||
# and the bot was previously speaking before the user interruption.
|
||||
# Normally, when the user stops speaking, new text is expected,
|
||||
# which triggers the bot to respond. However, if no new text
|
||||
# is received, this safeguard ensures
|
||||
# the bot doesn't hang indefinitely while waiting to speak again.
|
||||
elif not self._seen_interim_results and self._was_bot_speaking and not self._bot_speaking:
|
||||
logger.warning("User stopped speaking but no new aggregation received.")
|
||||
# Resetting it so we don't trigger this twice
|
||||
self._was_bot_speaking = False
|
||||
# TODO: we are not enabling this for now, due to some STT services which can take as long as 2 seconds two return a transcription
|
||||
# So we need more tests and probably make this feature configurable, disabled it by default.
|
||||
# We are just pushing the same previous context to be processed again in this case
|
||||
# await self.push_frame(LLMContextFrame(self._context))
|
||||
|
||||
async def _should_interrupt_based_on_strategies(self) -> bool:
|
||||
"""Check if interruption should occur based on configured strategies.
|
||||
|
||||
Returns:
|
||||
True if any interruption strategy indicates interruption should occur.
|
||||
"""
|
||||
|
||||
async def should_interrupt(strategy: BaseInterruptionStrategy):
|
||||
await strategy.append_text(self._aggregation)
|
||||
return await strategy.should_interrupt()
|
||||
|
||||
return any([await should_interrupt(s) for s in self._interruption_strategies])
|
||||
|
||||
async def _start(self, frame: StartFrame):
|
||||
self._create_aggregation_task()
|
||||
|
||||
async def _stop(self, frame: EndFrame):
|
||||
await self._cancel_aggregation_task()
|
||||
|
||||
async def _cancel(self, frame: CancelFrame):
|
||||
await self._cancel_aggregation_task()
|
||||
|
||||
async def _handle_llm_messages_append(self, frame: LLMMessagesAppendFrame):
|
||||
self.add_messages(frame.messages)
|
||||
if frame.run_llm:
|
||||
await self.push_context_frame()
|
||||
|
||||
async def _handle_llm_messages_update(self, frame: LLMMessagesUpdateFrame):
|
||||
self.set_messages(frame.messages)
|
||||
if frame.run_llm:
|
||||
await self.push_context_frame()
|
||||
|
||||
async def _handle_input_audio(self, frame: InputAudioRawFrame):
|
||||
for s in self.interruption_strategies:
|
||||
await s.append_audio(frame.audio, frame.sample_rate)
|
||||
|
||||
async def _handle_user_started_speaking(self, frame: UserStartedSpeakingFrame):
|
||||
self._user_speaking = True
|
||||
self._waiting_for_aggregation = True
|
||||
self._was_bot_speaking = self._bot_speaking
|
||||
|
||||
# If we get a non-emulated UserStartedSpeakingFrame but we are in the
|
||||
# middle of emulating VAD, let's stop emulating VAD (i.e. don't send the
|
||||
# EmulateUserStoppedSpeakingFrame).
|
||||
if not frame.emulated and self._emulating_vad:
|
||||
self._emulating_vad = False
|
||||
|
||||
async def _handle_user_stopped_speaking(self, _: UserStoppedSpeakingFrame):
|
||||
self._user_speaking = False
|
||||
# We just stopped speaking. Let's see if there's some aggregation to
|
||||
# push. If the last thing we saw is an interim transcription, let's wait
|
||||
# pushing the aggregation as we will probably get a final transcription.
|
||||
if len(self._aggregation) > 0:
|
||||
if not self._seen_interim_results:
|
||||
await self._push_aggregation()
|
||||
# Handles the case where both the user and the bot are not speaking,
|
||||
# and the bot was previously speaking before the user interruption.
|
||||
# So in this case we are resetting the aggregation timer
|
||||
elif not self._seen_interim_results and self._was_bot_speaking and not self._bot_speaking:
|
||||
# Reset aggregation timer.
|
||||
self._aggregation_event.set()
|
||||
|
||||
async def _handle_bot_started_speaking(self, _: BotStartedSpeakingFrame):
|
||||
self._bot_speaking = True
|
||||
|
||||
async def _handle_bot_stopped_speaking(self, _: BotStoppedSpeakingFrame):
|
||||
self._bot_speaking = False
|
||||
|
||||
async def _handle_transcription(self, frame: TranscriptionFrame):
|
||||
text = frame.text
|
||||
|
||||
# Make sure we really have some text.
|
||||
if not text.strip():
|
||||
return
|
||||
|
||||
self._aggregation += f" {text}" if self._aggregation else text
|
||||
# We just got a final result, so let's reset interim results.
|
||||
self._seen_interim_results = False
|
||||
# Reset aggregation timer.
|
||||
self._aggregation_event.set()
|
||||
|
||||
async def _handle_interim_transcription(self, _: InterimTranscriptionFrame):
|
||||
self._seen_interim_results = True
|
||||
|
||||
def _create_aggregation_task(self):
|
||||
if not self._aggregation_task:
|
||||
self._aggregation_task = self.create_task(self._aggregation_task_handler())
|
||||
|
||||
async def _cancel_aggregation_task(self):
|
||||
if self._aggregation_task:
|
||||
await self.cancel_task(self._aggregation_task)
|
||||
self._aggregation_task = None
|
||||
|
||||
async def _aggregation_task_handler(self):
|
||||
while True:
|
||||
try:
|
||||
# The _aggregation_task_handler handles two distinct timeout scenarios:
|
||||
#
|
||||
# 1. When emulating_vad=True: Wait for emulated VAD timeout before
|
||||
# pushing aggregation (simulating VAD behavior when no actual VAD
|
||||
# detection occurred).
|
||||
#
|
||||
# 2. When emulating_vad=False: Use aggregation_timeout as a buffer
|
||||
# to wait for potential late-arriving transcription frames after
|
||||
# a real VAD event.
|
||||
#
|
||||
# For emulated VAD scenarios, the timeout strategy depends on whether
|
||||
# a turn analyzer is configured:
|
||||
#
|
||||
# - WITH turn analyzer: Use turn_emulated_vad_timeout parameter because
|
||||
# the VAD's stop_secs is set very low (e.g. 0.2s) for rapid speech
|
||||
# chunking to feed the turn analyzer. This low value is too fast
|
||||
# for emulated VAD scenarios where we need to allow users time to
|
||||
# finish speaking (e.g. 0.8s).
|
||||
#
|
||||
# - WITHOUT turn analyzer: Use VAD's stop_secs directly to maintain
|
||||
# consistent user experience between real VAD detection and
|
||||
# emulated VAD scenarios.
|
||||
if not self._emulating_vad:
|
||||
timeout = self._params.aggregation_timeout
|
||||
elif self._turn_params:
|
||||
timeout = self._params.turn_emulated_vad_timeout
|
||||
else:
|
||||
# Use VAD stop_secs when no turn analyzer is present, fallback if no VAD params
|
||||
timeout = (
|
||||
self._vad_params.stop_secs
|
||||
if self._vad_params
|
||||
else self._params.turn_emulated_vad_timeout
|
||||
)
|
||||
await asyncio.wait_for(self._aggregation_event.wait(), timeout)
|
||||
await self._maybe_emulate_user_speaking()
|
||||
except asyncio.TimeoutError:
|
||||
if not self._user_speaking:
|
||||
await self._push_aggregation()
|
||||
|
||||
# If we are emulating VAD we still need to send the user stopped
|
||||
# speaking frame.
|
||||
if self._emulating_vad:
|
||||
await self.push_frame(
|
||||
EmulateUserStoppedSpeakingFrame(), FrameDirection.UPSTREAM
|
||||
)
|
||||
self._emulating_vad = False
|
||||
finally:
|
||||
self.reset_watchdog()
|
||||
self._aggregation_event.clear()
|
||||
|
||||
async def _maybe_emulate_user_speaking(self):
|
||||
"""Maybe emulate user speaking based on transcription.
|
||||
|
||||
Emulate user speaking if we got a transcription but it was not
|
||||
detected by VAD. Only do that if the bot is not speaking.
|
||||
"""
|
||||
# Check if we received a transcription but VAD was not able to detect
|
||||
# voice (e.g. when you whisper a short utterance). In that case, we need
|
||||
# to emulate VAD (i.e. user start/stopped speaking), but we do it only
|
||||
# if the bot is not speaking. If the bot is speaking and we really have
|
||||
# a short utterance we don't really want to interrupt the bot.
|
||||
if (
|
||||
not self._user_speaking
|
||||
and not self._waiting_for_aggregation
|
||||
and len(self._aggregation) > 0
|
||||
):
|
||||
if self._bot_speaking:
|
||||
# If we reached this case and the bot is speaking, let's ignore
|
||||
# what the user said.
|
||||
logger.debug("Ignoring user speaking emulation, bot is speaking.")
|
||||
await self.reset()
|
||||
else:
|
||||
# The bot is not speaking so, let's trigger user speaking
|
||||
# emulation.
|
||||
await self.push_frame(EmulateUserStartedSpeakingFrame(), FrameDirection.UPSTREAM)
|
||||
self._emulating_vad = True
|
||||
|
||||
|
||||
# NOTE: the "universal" suffix is just meant to distinguish this aggregator
|
||||
# from the old LLMAssistantContextAggregator while we gradually migrate service
|
||||
# to use the new universal LLMContext and associated patterns. The suffix will
|
||||
# go away once the migration is complete and the other
|
||||
# LLMAssistantContextAggregator is deprecated.
|
||||
class LLMAssistantContextAggregator_Universal(LLMContextAggregator):
|
||||
"""Assistant LLM aggregator that processes bot responses and function calls.
|
||||
|
||||
This aggregator handles the complex logic of processing assistant responses including:
|
||||
|
||||
- Text frame aggregation between response start/end markers
|
||||
- Function call lifecycle management
|
||||
- Context updates with timestamps
|
||||
- Tool execution and result handling
|
||||
- Interruption handling during responses
|
||||
|
||||
The aggregator manages function calls in progress and coordinates between
|
||||
text generation and tool execution phases of LLM responses.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
context: LLMContext,
|
||||
*,
|
||||
params: Optional[LLMAssistantContextAggregatorParams] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the assistant context aggregator.
|
||||
|
||||
Args:
|
||||
context: The OpenAI LLM context for conversation storage.
|
||||
params: Configuration parameters for aggregation behavior.
|
||||
**kwargs: Additional arguments. Supports deprecated 'expect_stripped_words'.
|
||||
"""
|
||||
super().__init__(context=context, role="assistant", **kwargs)
|
||||
self._params = params or LLMAssistantContextAggregatorParams()
|
||||
|
||||
if "expect_stripped_words" in kwargs:
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Parameter 'expect_stripped_words' is deprecated, use 'params' instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
self._params.expect_stripped_words = kwargs["expect_stripped_words"]
|
||||
|
||||
self._started = 0
|
||||
self._function_calls_in_progress: Dict[str, Optional[FunctionCallInProgressFrame]] = {}
|
||||
self._context_updated_tasks: Set[asyncio.Task] = set()
|
||||
|
||||
@property
|
||||
def has_function_calls_in_progress(self) -> bool:
|
||||
"""Check if there are any function calls currently in progress.
|
||||
|
||||
Returns:
|
||||
True if function calls are in progress, False otherwise.
|
||||
"""
|
||||
return bool(self._function_calls_in_progress)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process frames for assistant response aggregation and function call management.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction of frame flow in the pipeline.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
await self._handle_interruptions(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, LLMFullResponseStartFrame):
|
||||
await self._handle_llm_start(frame)
|
||||
elif isinstance(frame, LLMFullResponseEndFrame):
|
||||
await self._handle_llm_end(frame)
|
||||
elif isinstance(frame, TextFrame):
|
||||
await self._handle_text(frame)
|
||||
elif isinstance(frame, LLMMessagesAppendFrame):
|
||||
await self._handle_llm_messages_append(frame)
|
||||
elif isinstance(frame, LLMMessagesUpdateFrame):
|
||||
await self._handle_llm_messages_update(frame)
|
||||
elif isinstance(frame, LLMSetToolsFrame):
|
||||
self.set_tools(frame.tools)
|
||||
elif isinstance(frame, LLMSetToolChoiceFrame):
|
||||
self.set_tool_choice(frame.tool_choice)
|
||||
elif isinstance(frame, FunctionCallsStartedFrame):
|
||||
await self._handle_function_calls_started(frame)
|
||||
elif isinstance(frame, FunctionCallInProgressFrame):
|
||||
await self._handle_function_call_in_progress(frame)
|
||||
elif isinstance(frame, FunctionCallResultFrame):
|
||||
await self._handle_function_call_result(frame)
|
||||
elif isinstance(frame, FunctionCallCancelFrame):
|
||||
await self._handle_function_call_cancel(frame)
|
||||
elif isinstance(frame, UserImageRawFrame) and frame.request and frame.request.tool_call_id:
|
||||
await self._handle_user_image_frame(frame)
|
||||
elif isinstance(frame, BotStoppedSpeakingFrame):
|
||||
await self._push_aggregation()
|
||||
await self.push_frame(frame, direction)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _push_aggregation(self):
|
||||
"""Push the current assistant aggregation with timestamp."""
|
||||
if not self._aggregation:
|
||||
return
|
||||
|
||||
aggregation = self._aggregation.strip()
|
||||
await self.reset()
|
||||
|
||||
if aggregation:
|
||||
self._context.add_message({"role": "assistant", "content": aggregation})
|
||||
|
||||
# Push context frame
|
||||
await self.push_context_frame()
|
||||
|
||||
# Push timestamp frame with current time
|
||||
timestamp_frame = LLMContextAssistantTimestampFrame(timestamp=time_now_iso8601())
|
||||
await self.push_frame(timestamp_frame)
|
||||
|
||||
async def _handle_llm_messages_append(self, frame: LLMMessagesAppendFrame):
|
||||
self.add_messages(frame.messages)
|
||||
if frame.run_llm:
|
||||
await self.push_context_frame(FrameDirection.UPSTREAM)
|
||||
|
||||
async def _handle_llm_messages_update(self, frame: LLMMessagesUpdateFrame):
|
||||
self.set_messages(frame.messages)
|
||||
if frame.run_llm:
|
||||
await self.push_context_frame(FrameDirection.UPSTREAM)
|
||||
|
||||
async def _handle_interruptions(self, frame: StartInterruptionFrame):
|
||||
await self._push_aggregation()
|
||||
self._started = 0
|
||||
await self.reset()
|
||||
|
||||
async def _handle_function_calls_started(self, frame: FunctionCallsStartedFrame):
|
||||
function_names = [f"{f.function_name}:{f.tool_call_id}" for f in frame.function_calls]
|
||||
logger.debug(f"{self} FunctionCallsStartedFrame: {function_names}")
|
||||
for function_call in frame.function_calls:
|
||||
self._function_calls_in_progress[function_call.tool_call_id] = None
|
||||
|
||||
async def _handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame):
|
||||
logger.debug(
|
||||
f"{self} FunctionCallInProgressFrame: [{frame.function_name}:{frame.tool_call_id}]"
|
||||
)
|
||||
|
||||
# Update context with the in-progress function call
|
||||
self._context.add_message(
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": frame.tool_call_id,
|
||||
"function": {
|
||||
"name": frame.function_name,
|
||||
"arguments": json.dumps(frame.arguments),
|
||||
},
|
||||
"type": "function",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
self._context.add_message(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "IN_PROGRESS",
|
||||
"tool_call_id": frame.tool_call_id,
|
||||
}
|
||||
)
|
||||
|
||||
self._function_calls_in_progress[frame.tool_call_id] = frame
|
||||
|
||||
async def _handle_function_call_result(self, frame: FunctionCallResultFrame):
|
||||
logger.debug(
|
||||
f"{self} FunctionCallResultFrame: [{frame.function_name}:{frame.tool_call_id}]"
|
||||
)
|
||||
if frame.tool_call_id not in self._function_calls_in_progress:
|
||||
logger.warning(
|
||||
f"FunctionCallResultFrame tool_call_id [{frame.tool_call_id}] is not running"
|
||||
)
|
||||
return
|
||||
|
||||
del self._function_calls_in_progress[frame.tool_call_id]
|
||||
|
||||
properties = frame.properties
|
||||
|
||||
# Update context with the function call result
|
||||
if frame.result:
|
||||
result = json.dumps(frame.result)
|
||||
self._update_function_call_result(frame.function_name, frame.tool_call_id, result)
|
||||
else:
|
||||
self._update_function_call_result(frame.function_name, frame.tool_call_id, "COMPLETED")
|
||||
|
||||
run_llm = False
|
||||
|
||||
# Run inference if the function call result requires it.
|
||||
if frame.result:
|
||||
if properties and properties.run_llm is not None:
|
||||
# If the tool call result has a run_llm property, use it.
|
||||
run_llm = properties.run_llm
|
||||
elif frame.run_llm is not None:
|
||||
# If the frame is indicating we should run the LLM, do it.
|
||||
run_llm = frame.run_llm
|
||||
else:
|
||||
# If this is the last function call in progress, run the LLM.
|
||||
run_llm = not bool(self._function_calls_in_progress)
|
||||
|
||||
if run_llm:
|
||||
await self.push_context_frame(FrameDirection.UPSTREAM)
|
||||
|
||||
# Call the `on_context_updated` callback once the function call result
|
||||
# is added to the context. Also, run this in a separate task to make
|
||||
# sure we don't block the pipeline.
|
||||
if properties and properties.on_context_updated:
|
||||
task_name = f"{frame.function_name}:{frame.tool_call_id}:on_context_updated"
|
||||
task = self.create_task(properties.on_context_updated(), task_name)
|
||||
self._context_updated_tasks.add(task)
|
||||
task.add_done_callback(self._context_updated_task_finished)
|
||||
|
||||
async def _handle_function_call_cancel(self, frame: FunctionCallCancelFrame):
|
||||
logger.debug(
|
||||
f"{self} FunctionCallCancelFrame: [{frame.function_name}:{frame.tool_call_id}]"
|
||||
)
|
||||
if frame.tool_call_id not in self._function_calls_in_progress:
|
||||
return
|
||||
|
||||
if self._function_calls_in_progress[frame.tool_call_id].cancel_on_interruption:
|
||||
# Update context with the function call cancellation
|
||||
self._update_function_call_result(frame.function_name, frame.tool_call_id, "CANCELLED")
|
||||
del self._function_calls_in_progress[frame.tool_call_id]
|
||||
|
||||
def _update_function_call_result(self, function_name: str, tool_call_id: str, result: Any):
|
||||
for message in self._context.messages:
|
||||
if (
|
||||
message["role"] == "tool"
|
||||
and message["tool_call_id"]
|
||||
and message["tool_call_id"] == tool_call_id
|
||||
):
|
||||
message["content"] = result
|
||||
|
||||
async def _handle_user_image_frame(self, frame: UserImageRawFrame):
|
||||
logger.debug(
|
||||
f"{self} UserImageRawFrame: [{frame.request.function_name}:{frame.request.tool_call_id}]"
|
||||
)
|
||||
|
||||
if frame.request.tool_call_id not in self._function_calls_in_progress:
|
||||
logger.warning(
|
||||
f"UserImageRawFrame tool_call_id [{frame.request.tool_call_id}] is not running"
|
||||
)
|
||||
return
|
||||
|
||||
del self._function_calls_in_progress[frame.request.tool_call_id]
|
||||
|
||||
# Update context with the image frame
|
||||
await self._update_function_call_result(
|
||||
frame.request.function_name, frame.request.tool_call_id, "COMPLETED"
|
||||
)
|
||||
self._context.add_image_frame_message(
|
||||
format=frame.format,
|
||||
size=frame.size,
|
||||
image=frame.image,
|
||||
text=frame.request.context,
|
||||
)
|
||||
|
||||
await self._push_aggregation()
|
||||
await self.push_context_frame(FrameDirection.UPSTREAM)
|
||||
|
||||
async def _handle_llm_start(self, _: LLMFullResponseStartFrame):
|
||||
self._started += 1
|
||||
|
||||
async def _handle_llm_end(self, _: LLMFullResponseEndFrame):
|
||||
self._started -= 1
|
||||
await self._push_aggregation()
|
||||
|
||||
async def _handle_text(self, frame: TextFrame):
|
||||
if not self._started:
|
||||
return
|
||||
|
||||
if self._params.expect_stripped_words:
|
||||
self._aggregation += f" {frame.text}" if self._aggregation else frame.text
|
||||
else:
|
||||
self._aggregation += frame.text
|
||||
|
||||
def _context_updated_task_finished(self, task: asyncio.Task):
|
||||
self._context_updated_tasks.discard(task)
|
||||
# The task is finished so this should exit immediately. We need to do
|
||||
# this because otherwise the task manager would report a dangling task
|
||||
# if we don't remove it.
|
||||
asyncio.run_coroutine_threadsafe(self.wait_for_task(task), self.get_event_loop())
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMContextAggregatorPair:
|
||||
"""Pair of LLM context aggregators for user and assistant messages.
|
||||
|
||||
Parameters:
|
||||
_user: User context aggregator for processing user messages.
|
||||
_assistant: Assistant context aggregator for processing assistant messages.
|
||||
"""
|
||||
|
||||
_user: LLMUserContextAggregator_Universal
|
||||
_assistant: LLMAssistantContextAggregator_Universal
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
context: LLMContext,
|
||||
*,
|
||||
user_params: LLMUserContextAggregatorParams = LLMUserContextAggregatorParams(),
|
||||
assistant_params: LLMAssistantContextAggregatorParams = LLMAssistantContextAggregatorParams(),
|
||||
) -> "LLMContextAggregatorPair":
|
||||
"""Factory method to create an LLMContextAggregatorPair.
|
||||
|
||||
Args:
|
||||
context: The context managed by the aggregators.
|
||||
user_params: Parameters for the user context aggregator.
|
||||
assistant_params: Parameters for the assistant context aggregator.
|
||||
|
||||
Returns:
|
||||
LLMContextAggregatorPair: A new instance with configured aggregators.
|
||||
"""
|
||||
user = LLMUserContextAggregator_Universal(context, params=user_params)
|
||||
assistant = LLMAssistantContextAggregator_Universal(context, params=assistant_params)
|
||||
return LLMContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
def user(self) -> LLMUserContextAggregator_Universal:
|
||||
"""Get the user context aggregator.
|
||||
|
||||
Returns:
|
||||
The user context aggregator instance.
|
||||
"""
|
||||
return self._user
|
||||
|
||||
def assistant(self) -> LLMAssistantContextAggregator_Universal:
|
||||
"""Get the assistant context aggregator.
|
||||
|
||||
Returns:
|
||||
The assistant context aggregator instance.
|
||||
"""
|
||||
return self._assistant
|
||||
@@ -72,7 +72,6 @@ from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.services.llm_service import (
|
||||
FunctionCallParams, # TODO(aleix): we shouldn't import `services` from `processors`
|
||||
)
|
||||
from pipecat.services.openai.llm import OpenAIContextAggregatorPair
|
||||
from pipecat.transports.base_input import BaseInputTransport
|
||||
from pipecat.transports.base_output import BaseOutputTransport
|
||||
from pipecat.transports.base_transport import BaseTransport
|
||||
|
||||
@@ -10,49 +10,29 @@ This module provides Google Gemini integration for the Pipecat framework,
|
||||
including LLM services, context management, and message aggregation.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from loguru import logger
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from pipecat.adapters.services.gemini_adapter import GeminiLLMAdapter
|
||||
from pipecat.adapters.services.gemini_adapter import GeminiLLMAdapter, GeminiLLMInvocationParams
|
||||
from pipecat.frames.frames import (
|
||||
AudioRawFrame,
|
||||
Frame,
|
||||
FunctionCallCancelFrame,
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesFrame,
|
||||
LLMTextFrame,
|
||||
LLMUpdateSettingsFrame,
|
||||
UserImageRawFrame,
|
||||
VisionImageRawFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext, LLMContextFrame
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.google.frames import LLMSearchResponseFrame
|
||||
from pipecat.services.llm_service import FunctionCallFromLLM, LLMService
|
||||
from pipecat.services.openai.llm import (
|
||||
OpenAIAssistantContextAggregator,
|
||||
OpenAIUserContextAggregator,
|
||||
)
|
||||
from pipecat.utils.asyncio.watchdog_async_iterator import WatchdogAsyncIterator
|
||||
from pipecat.utils.tracing.service_decorators import traced_llm
|
||||
|
||||
@@ -63,13 +43,8 @@ try:
|
||||
from google import genai
|
||||
from google.api_core.exceptions import DeadlineExceeded
|
||||
from google.genai.types import (
|
||||
Blob,
|
||||
Content,
|
||||
FunctionCall,
|
||||
FunctionResponse,
|
||||
GenerateContentConfig,
|
||||
HttpOptions,
|
||||
Part,
|
||||
)
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
@@ -77,577 +52,12 @@ except ModuleNotFoundError as e:
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class GoogleUserContextAggregator(OpenAIUserContextAggregator):
|
||||
"""Google-specific user context aggregator.
|
||||
|
||||
Extends OpenAI user context aggregator to handle Google AI's specific
|
||||
Content and Part message format for user messages.
|
||||
"""
|
||||
|
||||
async def push_aggregation(self):
|
||||
"""Push aggregated user text as a Google Content message."""
|
||||
if len(self._aggregation) > 0:
|
||||
self._context.add_message(Content(role="user", parts=[Part(text=self._aggregation)]))
|
||||
|
||||
# Reset the aggregation. Reset it before pushing it down, otherwise
|
||||
# if the tasks gets cancelled we won't be able to clear things up.
|
||||
self._aggregation = ""
|
||||
|
||||
# Push context frame
|
||||
frame = OpenAILLMContextFrame(self._context)
|
||||
await self.push_frame(frame)
|
||||
|
||||
# Reset our accumulator state.
|
||||
await self.reset()
|
||||
|
||||
|
||||
class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
"""Google-specific assistant context aggregator.
|
||||
|
||||
Extends OpenAI assistant context aggregator to handle Google AI's specific
|
||||
Content and Part message format for assistant responses and function calls.
|
||||
"""
|
||||
|
||||
async def handle_aggregation(self, aggregation: str):
|
||||
"""Handle aggregated assistant text response.
|
||||
|
||||
Args:
|
||||
aggregation: The aggregated text response from the assistant.
|
||||
"""
|
||||
self._context.add_message(Content(role="model", parts=[Part(text=aggregation)]))
|
||||
|
||||
async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame):
|
||||
"""Handle function call in progress frame.
|
||||
|
||||
Args:
|
||||
frame: Frame containing function call details.
|
||||
"""
|
||||
self._context.add_message(
|
||||
Content(
|
||||
role="model",
|
||||
parts=[
|
||||
Part(
|
||||
function_call=FunctionCall(
|
||||
id=frame.tool_call_id, name=frame.function_name, args=frame.arguments
|
||||
)
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
self._context.add_message(
|
||||
Content(
|
||||
role="user",
|
||||
parts=[
|
||||
Part(
|
||||
function_response=FunctionResponse(
|
||||
id=frame.tool_call_id,
|
||||
name=frame.function_name,
|
||||
response={"response": "IN_PROGRESS"},
|
||||
)
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
async def handle_function_call_result(self, frame: FunctionCallResultFrame):
|
||||
"""Handle function call result frame.
|
||||
|
||||
Args:
|
||||
frame: Frame containing function call result.
|
||||
"""
|
||||
if frame.result:
|
||||
await self._update_function_call_result(
|
||||
frame.function_name, frame.tool_call_id, frame.result
|
||||
)
|
||||
else:
|
||||
await self._update_function_call_result(
|
||||
frame.function_name, frame.tool_call_id, "COMPLETED"
|
||||
)
|
||||
|
||||
async def handle_function_call_cancel(self, frame: FunctionCallCancelFrame):
|
||||
"""Handle function call cancellation frame.
|
||||
|
||||
Args:
|
||||
frame: Frame containing function call cancellation details.
|
||||
"""
|
||||
await self._update_function_call_result(
|
||||
frame.function_name, frame.tool_call_id, "CANCELLED"
|
||||
)
|
||||
|
||||
async def _update_function_call_result(
|
||||
self, function_name: str, tool_call_id: str, result: Any
|
||||
):
|
||||
for message in self._context.messages:
|
||||
if message.role == "user":
|
||||
for part in message.parts:
|
||||
if part.function_response and part.function_response.id == tool_call_id:
|
||||
part.function_response.response = {"value": json.dumps(result)}
|
||||
|
||||
async def handle_user_image_frame(self, frame: UserImageRawFrame):
|
||||
"""Handle user image frame.
|
||||
|
||||
Args:
|
||||
frame: Frame containing user image data and request context.
|
||||
"""
|
||||
await self._update_function_call_result(
|
||||
frame.request.function_name, frame.request.tool_call_id, "COMPLETED"
|
||||
)
|
||||
self._context.add_image_frame_message(
|
||||
format=frame.format,
|
||||
size=frame.size,
|
||||
image=frame.image,
|
||||
text=frame.request.context,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GoogleContextAggregatorPair:
|
||||
"""Pair of Google context aggregators for user and assistant messages.
|
||||
|
||||
Parameters:
|
||||
_user: User context aggregator for handling user messages.
|
||||
_assistant: Assistant context aggregator for handling assistant responses.
|
||||
"""
|
||||
|
||||
_user: GoogleUserContextAggregator
|
||||
_assistant: GoogleAssistantContextAggregator
|
||||
|
||||
def user(self) -> GoogleUserContextAggregator:
|
||||
"""Get the user context aggregator.
|
||||
|
||||
Returns:
|
||||
The user context aggregator instance.
|
||||
"""
|
||||
return self._user
|
||||
|
||||
def assistant(self) -> GoogleAssistantContextAggregator:
|
||||
"""Get the assistant context aggregator.
|
||||
|
||||
Returns:
|
||||
The assistant context aggregator instance.
|
||||
"""
|
||||
return self._assistant
|
||||
|
||||
|
||||
class GoogleLLMContext(OpenAILLMContext):
|
||||
"""Google AI LLM context that extends OpenAI context for Google-specific formatting.
|
||||
|
||||
This class handles conversion between OpenAI-style messages and Google AI's
|
||||
Content/Part format, including system messages, function calls, and media.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
messages: Optional[List[dict]] = None,
|
||||
tools: Optional[List[dict]] = None,
|
||||
tool_choice: Optional[dict] = None,
|
||||
):
|
||||
"""Initialize GoogleLLMContext.
|
||||
|
||||
Args:
|
||||
messages: Initial messages in OpenAI format.
|
||||
tools: Available tools/functions for the model.
|
||||
tool_choice: Tool choice configuration.
|
||||
"""
|
||||
super().__init__(messages=messages, tools=tools, tool_choice=tool_choice)
|
||||
self.system_message = None
|
||||
|
||||
@staticmethod
|
||||
def upgrade_to_google(obj: OpenAILLMContext) -> "GoogleLLMContext":
|
||||
"""Upgrade an OpenAI context to a Google context.
|
||||
|
||||
Args:
|
||||
obj: OpenAI LLM context to upgrade.
|
||||
|
||||
Returns:
|
||||
GoogleLLMContext instance with converted messages.
|
||||
"""
|
||||
if isinstance(obj, OpenAILLMContext) and not isinstance(obj, GoogleLLMContext):
|
||||
logger.debug(f"Upgrading to Google: {obj}")
|
||||
obj.__class__ = GoogleLLMContext
|
||||
obj._restructure_from_openai_messages()
|
||||
return obj
|
||||
|
||||
def set_messages(self, messages: List):
|
||||
"""Set messages and restructure them for Google format.
|
||||
|
||||
Args:
|
||||
messages: List of messages to set.
|
||||
"""
|
||||
self._messages[:] = messages
|
||||
self._restructure_from_openai_messages()
|
||||
|
||||
def add_messages(self, messages: List):
|
||||
"""Add messages to the context, converting to Google format as needed.
|
||||
|
||||
Args:
|
||||
messages: List of messages to add (can be mixed formats).
|
||||
"""
|
||||
# Convert each message individually
|
||||
converted_messages = []
|
||||
for msg in messages:
|
||||
if isinstance(msg, Content):
|
||||
# Already in Gemini format
|
||||
converted_messages.append(msg)
|
||||
else:
|
||||
# Convert from standard format to Gemini format
|
||||
converted = self.from_standard_message(msg)
|
||||
if converted is not None:
|
||||
converted_messages.append(converted)
|
||||
|
||||
# Add the converted messages to our existing messages
|
||||
self._messages.extend(converted_messages)
|
||||
|
||||
def get_messages_for_logging(self):
|
||||
"""Get messages formatted for logging with sensitive data redacted.
|
||||
|
||||
Returns:
|
||||
List of message dictionaries with inline data redacted.
|
||||
"""
|
||||
msgs = []
|
||||
for message in self.messages:
|
||||
obj = message.to_json_dict()
|
||||
try:
|
||||
if "parts" in obj:
|
||||
for part in obj["parts"]:
|
||||
if "inline_data" in part:
|
||||
part["inline_data"]["data"] = "..."
|
||||
except Exception as e:
|
||||
logger.debug(f"Error: {e}")
|
||||
msgs.append(obj)
|
||||
return msgs
|
||||
|
||||
def add_image_frame_message(
|
||||
self, *, format: str, size: tuple[int, int], image: bytes, text: str = None
|
||||
):
|
||||
"""Add an image message to the context.
|
||||
|
||||
Args:
|
||||
format: Image format (e.g., 'RGB', 'RGBA').
|
||||
size: Image dimensions as (width, height).
|
||||
image: Raw image bytes.
|
||||
text: Optional text to accompany the image.
|
||||
"""
|
||||
buffer = io.BytesIO()
|
||||
Image.frombytes(format, size, image).save(buffer, format="JPEG")
|
||||
|
||||
parts = []
|
||||
if text:
|
||||
parts.append(Part(text=text))
|
||||
parts.append(Part(inline_data=Blob(mime_type="image/jpeg", data=buffer.getvalue())))
|
||||
|
||||
self.add_message(Content(role="user", parts=parts))
|
||||
|
||||
def add_audio_frames_message(
|
||||
self, *, audio_frames: list[AudioRawFrame], text: str = "Audio follows"
|
||||
):
|
||||
"""Add audio frames as a message to the context.
|
||||
|
||||
Args:
|
||||
audio_frames: List of audio frames to add.
|
||||
text: Text description of the audio content.
|
||||
"""
|
||||
if not audio_frames:
|
||||
return
|
||||
|
||||
sample_rate = audio_frames[0].sample_rate
|
||||
num_channels = audio_frames[0].num_channels
|
||||
|
||||
parts = []
|
||||
data = b"".join(frame.audio for frame in audio_frames)
|
||||
# NOTE(aleix): According to the docs only text or inline_data should be needed.
|
||||
# (see https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference)
|
||||
parts.append(Part(text=text))
|
||||
parts.append(
|
||||
Part(
|
||||
inline_data=Blob(
|
||||
mime_type="audio/wav",
|
||||
data=(
|
||||
bytes(
|
||||
self.create_wav_header(sample_rate, num_channels, 16, len(data)) + data
|
||||
)
|
||||
),
|
||||
)
|
||||
),
|
||||
)
|
||||
self.add_message(Content(role="user", parts=parts))
|
||||
# message = {"mime_type": "audio/mp3", "data": bytes(data + create_wav_header(sample_rate, num_channels, 16, len(data)))}
|
||||
# self.add_message(message)
|
||||
|
||||
def from_standard_message(self, message):
|
||||
"""Convert standard format message to Google Content object.
|
||||
|
||||
Handles conversion of text, images, and function calls to Google's format.
|
||||
System messages are stored separately and return None.
|
||||
|
||||
Args:
|
||||
message: Message in standard format.
|
||||
|
||||
Returns:
|
||||
Content object with role and parts, or None for system messages.
|
||||
|
||||
Examples:
|
||||
Standard text message::
|
||||
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello there"
|
||||
}
|
||||
|
||||
Converts to Google Content with::
|
||||
|
||||
Content(
|
||||
role="user",
|
||||
parts=[Part(text="Hello there")]
|
||||
)
|
||||
|
||||
Standard function call message::
|
||||
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"name": "search",
|
||||
"arguments": '{"query": "test"}'
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
Converts to Google Content with::
|
||||
|
||||
Content(
|
||||
role="model",
|
||||
parts=[Part(function_call=FunctionCall(name="search", args={"query": "test"}))]
|
||||
)
|
||||
|
||||
System message returns None and stores content in self.system_message.
|
||||
"""
|
||||
role = message["role"]
|
||||
content = message.get("content", [])
|
||||
if role == "system":
|
||||
self.system_message = content
|
||||
return None
|
||||
elif role == "assistant":
|
||||
role = "model"
|
||||
|
||||
parts = []
|
||||
if message.get("tool_calls"):
|
||||
for tc in message["tool_calls"]:
|
||||
parts.append(
|
||||
Part(
|
||||
function_call=FunctionCall(
|
||||
name=tc["function"]["name"],
|
||||
args=json.loads(tc["function"]["arguments"]),
|
||||
)
|
||||
)
|
||||
)
|
||||
elif role == "tool":
|
||||
role = "model"
|
||||
parts.append(
|
||||
Part(
|
||||
function_response=FunctionResponse(
|
||||
name="tool_call_result", # seems to work to hard-code the same name every time
|
||||
response=json.loads(message["content"]),
|
||||
)
|
||||
)
|
||||
)
|
||||
elif isinstance(content, str):
|
||||
parts.append(Part(text=content))
|
||||
elif isinstance(content, list):
|
||||
for c in content:
|
||||
if c["type"] == "text":
|
||||
parts.append(Part(text=c["text"]))
|
||||
elif c["type"] == "image_url":
|
||||
parts.append(
|
||||
Part(
|
||||
inline_data=Blob(
|
||||
mime_type="image/jpeg",
|
||||
data=base64.b64decode(c["image_url"]["url"].split(",")[1]),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
message = Content(role=role, parts=parts)
|
||||
return message
|
||||
|
||||
def to_standard_messages(self, obj) -> list:
|
||||
"""Convert Google Content object to standard structured format.
|
||||
|
||||
Handles text, images, and function calls from Google's Content/Part objects.
|
||||
|
||||
Args:
|
||||
obj: Google Content object with role and parts.
|
||||
|
||||
Returns:
|
||||
List containing a single message in standard format.
|
||||
|
||||
Examples:
|
||||
Google Content with text::
|
||||
|
||||
Content(
|
||||
role="user",
|
||||
parts=[Part(text="Hello")]
|
||||
)
|
||||
|
||||
Converts to::
|
||||
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "Hello"}]
|
||||
}
|
||||
]
|
||||
|
||||
Google Content with function call::
|
||||
|
||||
Content(
|
||||
role="model",
|
||||
parts=[Part(function_call=FunctionCall(name="search", args={"q": "test"}))]
|
||||
)
|
||||
|
||||
Converts to::
|
||||
|
||||
[
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "search",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search",
|
||||
"arguments": '{"q": "test"}'
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
Google Content with image::
|
||||
|
||||
Content(
|
||||
role="user",
|
||||
parts=[Part(inline_data=Blob(mime_type="image/jpeg", data=bytes_data))]
|
||||
)
|
||||
|
||||
Converts to::
|
||||
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "data:image/jpeg;base64,<encoded_data>"}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
"""
|
||||
msg = {"role": obj.role, "content": []}
|
||||
if msg["role"] == "model":
|
||||
msg["role"] = "assistant"
|
||||
|
||||
for part in obj.parts:
|
||||
if part.text:
|
||||
msg["content"].append({"type": "text", "text": part.text})
|
||||
elif part.inline_data:
|
||||
encoded = base64.b64encode(part.inline_data.data).decode("utf-8")
|
||||
msg["content"].append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:{part.inline_data.mime_type};base64,{encoded}"},
|
||||
}
|
||||
)
|
||||
elif part.function_call:
|
||||
args = part.function_call.args if hasattr(part.function_call, "args") else {}
|
||||
msg["tool_calls"] = [
|
||||
{
|
||||
"id": part.function_call.name,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": part.function_call.name,
|
||||
"arguments": json.dumps(args),
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
elif part.function_response:
|
||||
msg["role"] = "tool"
|
||||
resp = (
|
||||
part.function_response.response
|
||||
if hasattr(part.function_response, "response")
|
||||
else {}
|
||||
)
|
||||
msg["tool_call_id"] = part.function_response.name
|
||||
msg["content"] = json.dumps(resp)
|
||||
|
||||
# there might be no content parts for tool_calls messages
|
||||
if not msg["content"]:
|
||||
del msg["content"]
|
||||
return [msg]
|
||||
|
||||
def _restructure_from_openai_messages(self):
|
||||
"""Restructures messages to ensure proper Google format and message ordering.
|
||||
|
||||
This method handles conversion of OpenAI-formatted messages to Google format,
|
||||
with special handling for function calls, function responses, and system messages.
|
||||
System messages are added back to the context as user messages when needed.
|
||||
|
||||
The final message order is preserved as:
|
||||
1. Function calls (from model)
|
||||
2. Function responses (from user)
|
||||
3. Text messages (converted from system messages)
|
||||
|
||||
Note:
|
||||
System messages are only added back when there are no regular text
|
||||
messages in the context, ensuring proper conversation continuity
|
||||
after function calls.
|
||||
"""
|
||||
self.system_message = None
|
||||
converted_messages = []
|
||||
|
||||
# Process each message, preserving Google-formatted messages and converting others
|
||||
for message in self._messages:
|
||||
if isinstance(message, Content):
|
||||
# Keep existing Google-formatted messages (e.g., function calls/responses)
|
||||
converted_messages.append(message)
|
||||
continue
|
||||
|
||||
# Convert OpenAI format to Google format, system messages return None
|
||||
converted = self.from_standard_message(message)
|
||||
if converted is not None:
|
||||
converted_messages.append(converted)
|
||||
|
||||
# Update message list
|
||||
self._messages[:] = converted_messages
|
||||
|
||||
# Check if we only have function-related messages (no regular text)
|
||||
has_regular_messages = any(
|
||||
len(msg.parts) == 1
|
||||
and getattr(msg.parts[0], "text", None)
|
||||
and not getattr(msg.parts[0], "function_call", None)
|
||||
and not getattr(msg.parts[0], "function_response", None)
|
||||
for msg in self._messages
|
||||
)
|
||||
|
||||
# Add system message back as a user message if we only have function messages
|
||||
if self.system_message and not has_regular_messages:
|
||||
self._messages.append(Content(role="user", parts=[Part(text=self.system_message)]))
|
||||
|
||||
# Remove any empty messages
|
||||
self._messages = [m for m in self._messages if m.parts]
|
||||
|
||||
|
||||
class GoogleLLMService(LLMService):
|
||||
"""Google AI (Gemini) LLM service implementation.
|
||||
|
||||
This class implements inference with Google's AI models, translating internally
|
||||
from OpenAILLMContext to the messages format expected by the Google AI model.
|
||||
We use OpenAILLMContext as a lingua franca for all LLM services to enable
|
||||
easy switching between different LLMs.
|
||||
from the universal LLMContext to the message format expected by the Google
|
||||
AI model.
|
||||
"""
|
||||
|
||||
# Overriding the default adapter to use the Gemini one.
|
||||
@@ -750,7 +160,7 @@ class GoogleLLMService(LLMService):
|
||||
logger.exception(f"Failed to unset thinking budget: {e}")
|
||||
|
||||
@traced_llm
|
||||
async def _process_context(self, context: OpenAILLMContext):
|
||||
async def _process_context(self, context: LLMContext):
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
|
||||
prompt_tokens = 0
|
||||
@@ -763,19 +173,31 @@ class GoogleLLMService(LLMService):
|
||||
search_result = ""
|
||||
|
||||
try:
|
||||
logger.debug(
|
||||
# f"{self}: Generating chat [{self._system_instruction}] | [{context.get_messages_for_logging()}]"
|
||||
f"{self}: Generating chat [{context.get_messages_for_logging()}]"
|
||||
adapter = self.get_llm_adapter()
|
||||
llm_invocation_params: GeminiLLMInvocationParams = adapter.get_llm_invocation_params(
|
||||
context
|
||||
)
|
||||
|
||||
messages = context.messages
|
||||
if context.system_message and self._system_instruction != context.system_message:
|
||||
logger.debug(f"System instruction changed: {context.system_message}")
|
||||
self._system_instruction = context.system_message
|
||||
logger.debug(
|
||||
# TODO: figure out a nice way to also log system instruction
|
||||
# f"{self}: Generating chat [{self._system_instruction}] | [{adapter.get_messages_for_logging(context)}]"
|
||||
f"{self}: Generating chat [{adapter.get_messages_for_logging(context)}]"
|
||||
)
|
||||
|
||||
messages = llm_invocation_params["messages"]
|
||||
if (
|
||||
llm_invocation_params.get("system_instruction")
|
||||
and self._system_instruction != llm_invocation_params["system_instruction"]
|
||||
):
|
||||
logger.debug(
|
||||
f"System instruction changed: {llm_invocation_params['system_instruction']}"
|
||||
)
|
||||
self._system_instruction = llm_invocation_params["system_instruction"]
|
||||
|
||||
# TODO: test what happens when there are no tools
|
||||
tools = []
|
||||
if context.tools:
|
||||
tools = context.tools
|
||||
if llm_invocation_params.get("tools"):
|
||||
tools = llm_invocation_params["tools"]
|
||||
elif self._tools:
|
||||
tools = self._tools
|
||||
tool_config = None
|
||||
@@ -922,12 +344,12 @@ class GoogleLLMService(LLMService):
|
||||
|
||||
context = None
|
||||
|
||||
if isinstance(frame, OpenAILLMContextFrame):
|
||||
context = GoogleLLMContext.upgrade_to_google(frame.context)
|
||||
if isinstance(frame, LLMContextFrame):
|
||||
context = frame.context
|
||||
elif isinstance(frame, LLMMessagesFrame):
|
||||
context = GoogleLLMContext(frame.messages)
|
||||
context = LLMContext(messages=frame.messages)
|
||||
elif isinstance(frame, VisionImageRawFrame):
|
||||
context = GoogleLLMContext()
|
||||
context = LLMContext()
|
||||
context.add_image_frame_message(
|
||||
format=frame.format, size=frame.size, image=frame.image, text=frame.text
|
||||
)
|
||||
@@ -938,34 +360,3 @@ class GoogleLLMService(LLMService):
|
||||
|
||||
if context:
|
||||
await self._process_context(context)
|
||||
|
||||
def create_context_aggregator(
|
||||
self,
|
||||
context: OpenAILLMContext,
|
||||
*,
|
||||
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
|
||||
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
|
||||
) -> GoogleContextAggregatorPair:
|
||||
"""Create Google-specific context aggregators.
|
||||
|
||||
Creates a pair of context aggregators optimized for Google's message format,
|
||||
including support for function calls, tool usage, and image handling.
|
||||
|
||||
Args:
|
||||
context: The LLM context to create aggregators for.
|
||||
user_params: Parameters for user message aggregation.
|
||||
assistant_params: Parameters for assistant message aggregation.
|
||||
|
||||
Returns:
|
||||
GoogleContextAggregatorPair: A pair of context aggregators, one for
|
||||
the user and one for the assistant, encapsulated in an
|
||||
GoogleContextAggregatorPair.
|
||||
|
||||
"""
|
||||
context.set_llm_adapter(self.get_llm_adapter())
|
||||
|
||||
if isinstance(context, OpenAILLMContext):
|
||||
context = GoogleLLMContext.upgrade_to_google(context)
|
||||
user = GoogleUserContextAggregator(context, params=user_params)
|
||||
assistant = GoogleAssistantContextAggregator(context, params=assistant_params)
|
||||
return GoogleContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
@@ -41,6 +41,7 @@ from pipecat.frames.frames import (
|
||||
StartInterruptionFrame,
|
||||
UserImageRequestFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMUserAggregatorParams,
|
||||
@@ -89,7 +90,8 @@ class FunctionCallParams:
|
||||
tool_call_id: str
|
||||
arguments: Mapping[str, Any]
|
||||
llm: "LLMService"
|
||||
context: OpenAILLMContext
|
||||
# TODO: after migration of all services to universal LLMContext, OpenAILLMContext can be removed
|
||||
context: LLMContext | OpenAILLMContext
|
||||
result_callback: FunctionCallResultCallback
|
||||
|
||||
|
||||
@@ -418,7 +420,10 @@ class LLMService(AIService):
|
||||
else:
|
||||
await self._sequential_runner_queue.put(runner_item)
|
||||
|
||||
async def _call_start_function(self, context: OpenAILLMContext, function_name: str):
|
||||
# TODO: after migration of all services to universal LLMContext, OpenAILLMContext can be removed
|
||||
async def _call_start_function(
|
||||
self, context: LLMContext | OpenAILLMContext, function_name: str
|
||||
):
|
||||
if function_name in self._start_callbacks.keys():
|
||||
await self._start_callbacks[function_name](function_name, self, context)
|
||||
elif None in self._start_callbacks.keys():
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Base OpenAI LLM service implementation."""
|
||||
"""Base LLM service implementation for services that use the AsyncOpenAI client."""
|
||||
|
||||
import base64
|
||||
import json
|
||||
@@ -21,6 +21,7 @@ from openai import (
|
||||
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from pipecat.adapters.services.open_ai_adapter import OpenAILLMInvocationParams
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
LLMFullResponseEndFrame,
|
||||
@@ -31,9 +32,9 @@ from pipecat.frames.frames import (
|
||||
VisionImageRawFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
from pipecat.processors.aggregators.llm_context import (
|
||||
LLMContext,
|
||||
LLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.llm_service import FunctionCallFromLLM, LLMService
|
||||
@@ -44,8 +45,8 @@ from pipecat.utils.tracing.service_decorators import traced_llm
|
||||
class BaseOpenAILLMService(LLMService):
|
||||
"""Base class for all services that use the AsyncOpenAI client.
|
||||
|
||||
This service consumes OpenAILLMContextFrame frames, which contain a reference
|
||||
to an OpenAILLMContext object. The context defines what is sent to the LLM for
|
||||
This service consumes LLMContextFrame frames, which contain a reference to
|
||||
an LLMContext object. The context defines what is sent to the LLM for
|
||||
completion, including user, assistant, and system messages, as well as tool
|
||||
choices and function call configurations.
|
||||
"""
|
||||
@@ -173,13 +174,13 @@ class BaseOpenAILLMService(LLMService):
|
||||
return True
|
||||
|
||||
async def get_chat_completions(
|
||||
self, context: OpenAILLMContext, messages: List[ChatCompletionMessageParam]
|
||||
self, params_from_context: OpenAILLMInvocationParams
|
||||
) -> AsyncStream[ChatCompletionChunk]:
|
||||
"""Get streaming chat completions from OpenAI API.
|
||||
|
||||
Args:
|
||||
context: The LLM context containing tools and configuration.
|
||||
messages: List of chat completion messages to send.
|
||||
params_from_context: Parameters, derived from the LLM context, to
|
||||
use for the chat completion. Contains messages, tools, and tool choice.
|
||||
|
||||
Returns:
|
||||
Async stream of chat completion chunks.
|
||||
@@ -187,9 +188,6 @@ class BaseOpenAILLMService(LLMService):
|
||||
params = {
|
||||
"model": self.model_name,
|
||||
"stream": True,
|
||||
"messages": messages,
|
||||
"tools": context.tools,
|
||||
"tool_choice": context.tool_choice,
|
||||
"stream_options": {"include_usage": True},
|
||||
"frequency_penalty": self._settings["frequency_penalty"],
|
||||
"presence_penalty": self._settings["presence_penalty"],
|
||||
@@ -200,39 +198,28 @@ class BaseOpenAILLMService(LLMService):
|
||||
"max_completion_tokens": self._settings["max_completion_tokens"],
|
||||
}
|
||||
|
||||
# Messages, tools, tool_choice
|
||||
params.update(params_from_context)
|
||||
|
||||
params.update(self._settings["extra"])
|
||||
|
||||
chunks = await self._client.chat.completions.create(**params)
|
||||
return chunks
|
||||
|
||||
async def _stream_chat_completions(
|
||||
self, context: OpenAILLMContext
|
||||
self, context: LLMContext
|
||||
) -> AsyncStream[ChatCompletionChunk]:
|
||||
logger.debug(f"{self}: Generating chat [{context.get_messages_for_logging()}]")
|
||||
adapter = self.get_llm_adapter()
|
||||
logger.debug(f"{self}: Generating chat [{adapter.get_messages_for_logging(context)}]")
|
||||
|
||||
messages: List[ChatCompletionMessageParam] = context.get_messages()
|
||||
params: OpenAILLMInvocationParams = adapter.get_llm_invocation_params(context)
|
||||
|
||||
# base64 encode any images
|
||||
for message in messages:
|
||||
if message.get("mime_type") == "image/jpeg":
|
||||
encoded_image = base64.b64encode(message["data"].getvalue()).decode("utf-8")
|
||||
text = message["content"]
|
||||
message["content"] = [
|
||||
{"type": "text", "text": text},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
|
||||
},
|
||||
]
|
||||
del message["data"]
|
||||
del message["mime_type"]
|
||||
|
||||
chunks = await self.get_chat_completions(context, messages)
|
||||
chunks = await self.get_chat_completions(params)
|
||||
|
||||
return chunks
|
||||
|
||||
@traced_llm
|
||||
async def _process_context(self, context: OpenAILLMContext):
|
||||
async def _process_context(self, context: LLMContext):
|
||||
functions_list = []
|
||||
arguments_list = []
|
||||
tool_id_list = []
|
||||
@@ -331,7 +318,7 @@ class BaseOpenAILLMService(LLMService):
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process frames for LLM completion requests.
|
||||
|
||||
Handles OpenAILLMContextFrame, LLMMessagesFrame, VisionImageRawFrame,
|
||||
Handles LLMContextFrame, LLMMessagesFrame, VisionImageRawFrame,
|
||||
and LLMUpdateSettingsFrame to trigger LLM completions and manage settings.
|
||||
|
||||
Args:
|
||||
@@ -341,12 +328,12 @@ class BaseOpenAILLMService(LLMService):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
context = None
|
||||
if isinstance(frame, OpenAILLMContextFrame):
|
||||
context: OpenAILLMContext = frame.context
|
||||
if isinstance(frame, LLMContextFrame):
|
||||
context: LLMContext = frame.context
|
||||
elif isinstance(frame, LLMMessagesFrame):
|
||||
context = OpenAILLMContext.from_messages(frame.messages)
|
||||
context = LLMContext(messages=frame.messages)
|
||||
elif isinstance(frame, VisionImageRawFrame):
|
||||
context = OpenAILLMContext()
|
||||
context = LLMContext()
|
||||
context.add_image_frame_message(
|
||||
format=frame.format, size=frame.size, image=frame.image, text=frame.text
|
||||
)
|
||||
|
||||
@@ -16,45 +16,16 @@ from pipecat.frames.frames import (
|
||||
FunctionCallResultFrame,
|
||||
UserImageRawFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMAssistantContextAggregator,
|
||||
LLMUserAggregatorParams,
|
||||
LLMUserContextAggregator,
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMAssistantContextAggregator_Universal,
|
||||
LLMAssistantContextAggregatorParams,
|
||||
LLMUserContextAggregator_Universal,
|
||||
LLMUserContextAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.openai.base_llm import BaseOpenAILLMService
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenAIContextAggregatorPair:
|
||||
"""Pair of OpenAI context aggregators for user and assistant messages.
|
||||
|
||||
Parameters:
|
||||
_user: User context aggregator for processing user messages.
|
||||
_assistant: Assistant context aggregator for processing assistant messages.
|
||||
"""
|
||||
|
||||
_user: "OpenAIUserContextAggregator"
|
||||
_assistant: "OpenAIAssistantContextAggregator"
|
||||
|
||||
def user(self) -> "OpenAIUserContextAggregator":
|
||||
"""Get the user context aggregator.
|
||||
|
||||
Returns:
|
||||
The user context aggregator instance.
|
||||
"""
|
||||
return self._user
|
||||
|
||||
def assistant(self) -> "OpenAIAssistantContextAggregator":
|
||||
"""Get the assistant context aggregator.
|
||||
|
||||
Returns:
|
||||
The assistant context aggregator instance.
|
||||
"""
|
||||
return self._assistant
|
||||
|
||||
|
||||
class OpenAILLMService(BaseOpenAILLMService):
|
||||
"""OpenAI LLM service implementation.
|
||||
|
||||
@@ -78,141 +49,3 @@ class OpenAILLMService(BaseOpenAILLMService):
|
||||
**kwargs: Additional arguments passed to the parent BaseOpenAILLMService.
|
||||
"""
|
||||
super().__init__(model=model, params=params, **kwargs)
|
||||
|
||||
def create_context_aggregator(
|
||||
self,
|
||||
context: OpenAILLMContext,
|
||||
*,
|
||||
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
|
||||
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
|
||||
) -> OpenAIContextAggregatorPair:
|
||||
"""Create OpenAI-specific context aggregators.
|
||||
|
||||
Creates a pair of context aggregators optimized for OpenAI's message format,
|
||||
including support for function calls, tool usage, and image handling.
|
||||
|
||||
Args:
|
||||
context: The LLM context to create aggregators for.
|
||||
user_params: Parameters for user message aggregation.
|
||||
assistant_params: Parameters for assistant message aggregation.
|
||||
|
||||
Returns:
|
||||
OpenAIContextAggregatorPair: A pair of context aggregators, one for
|
||||
the user and one for the assistant, encapsulated in an
|
||||
OpenAIContextAggregatorPair.
|
||||
|
||||
"""
|
||||
context.set_llm_adapter(self.get_llm_adapter())
|
||||
user = OpenAIUserContextAggregator(context, params=user_params)
|
||||
assistant = OpenAIAssistantContextAggregator(context, params=assistant_params)
|
||||
return OpenAIContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
|
||||
class OpenAIUserContextAggregator(LLMUserContextAggregator):
|
||||
"""OpenAI-specific user context aggregator.
|
||||
|
||||
Handles aggregation of user messages for OpenAI LLM services.
|
||||
Inherits all functionality from the base LLMUserContextAggregator.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
"""OpenAI-specific assistant context aggregator.
|
||||
|
||||
Handles aggregation of assistant messages for OpenAI LLM services,
|
||||
with specialized support for OpenAI's function calling format,
|
||||
tool usage tracking, and image message handling.
|
||||
"""
|
||||
|
||||
async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame):
|
||||
"""Handle a function call in progress.
|
||||
|
||||
Adds the function call to the context with an IN_PROGRESS status
|
||||
to track ongoing function execution.
|
||||
|
||||
Args:
|
||||
frame: Frame containing function call progress information.
|
||||
"""
|
||||
self._context.add_message(
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": frame.tool_call_id,
|
||||
"function": {
|
||||
"name": frame.function_name,
|
||||
"arguments": json.dumps(frame.arguments),
|
||||
},
|
||||
"type": "function",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
self._context.add_message(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "IN_PROGRESS",
|
||||
"tool_call_id": frame.tool_call_id,
|
||||
}
|
||||
)
|
||||
|
||||
async def handle_function_call_result(self, frame: FunctionCallResultFrame):
|
||||
"""Handle the result of a function call.
|
||||
|
||||
Updates the context with the function call result, replacing any
|
||||
previous IN_PROGRESS status.
|
||||
|
||||
Args:
|
||||
frame: Frame containing the function call result.
|
||||
"""
|
||||
if frame.result:
|
||||
result = json.dumps(frame.result)
|
||||
await self._update_function_call_result(frame.function_name, frame.tool_call_id, result)
|
||||
else:
|
||||
await self._update_function_call_result(
|
||||
frame.function_name, frame.tool_call_id, "COMPLETED"
|
||||
)
|
||||
|
||||
async def handle_function_call_cancel(self, frame: FunctionCallCancelFrame):
|
||||
"""Handle a cancelled function call.
|
||||
|
||||
Updates the context to mark the function call as cancelled.
|
||||
|
||||
Args:
|
||||
frame: Frame containing the function call cancellation information.
|
||||
"""
|
||||
await self._update_function_call_result(
|
||||
frame.function_name, frame.tool_call_id, "CANCELLED"
|
||||
)
|
||||
|
||||
async def _update_function_call_result(
|
||||
self, function_name: str, tool_call_id: str, result: Any
|
||||
):
|
||||
for message in self._context.messages:
|
||||
if (
|
||||
message["role"] == "tool"
|
||||
and message["tool_call_id"]
|
||||
and message["tool_call_id"] == tool_call_id
|
||||
):
|
||||
message["content"] = result
|
||||
|
||||
async def handle_user_image_frame(self, frame: UserImageRawFrame):
|
||||
"""Handle a user image frame from a function call request.
|
||||
|
||||
Marks the associated function call as completed and adds the image
|
||||
to the context for processing.
|
||||
|
||||
Args:
|
||||
frame: Frame containing the user image and request context.
|
||||
"""
|
||||
await self._update_function_call_result(
|
||||
frame.request.function_name, frame.request.tool_call_id, "COMPLETED"
|
||||
)
|
||||
self._context.add_image_frame_message(
|
||||
format=frame.format,
|
||||
size=frame.size,
|
||||
image=frame.image,
|
||||
text=frame.request.context,
|
||||
)
|
||||
|
||||
@@ -17,9 +17,9 @@ from pipecat.frames.frames import (
|
||||
FunctionCallResultFrame,
|
||||
FunctionCallResultProperties,
|
||||
InterimTranscriptionFrame,
|
||||
LLMContextAssistantTimestampFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
OpenAILLMContextAssistantTimestampFrame,
|
||||
SpeechControlParamsFrame,
|
||||
StartInterruptionFrame,
|
||||
TextFrame,
|
||||
@@ -738,7 +738,7 @@ class TestAnthropicAssistantContextAggregator(
|
||||
):
|
||||
CONTEXT_CLASS = AnthropicLLMContext
|
||||
AGGREGATOR_CLASS = AnthropicAssistantContextAggregator
|
||||
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]
|
||||
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, LLMContextAssistantTimestampFrame]
|
||||
|
||||
def check_message_multi_content(
|
||||
self, context: OpenAILLMContext, content_index: int, index: int, content: str
|
||||
@@ -773,7 +773,7 @@ class TestAWSBedrockAssistantContextAggregator(
|
||||
):
|
||||
CONTEXT_CLASS = AWSBedrockLLMContext
|
||||
AGGREGATOR_CLASS = AWSBedrockAssistantContextAggregator
|
||||
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]
|
||||
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, LLMContextAssistantTimestampFrame]
|
||||
|
||||
def check_message_multi_content(
|
||||
self, context: OpenAILLMContext, content_index: int, index: int, content: str
|
||||
@@ -814,7 +814,7 @@ class TestGoogleAssistantContextAggregator(
|
||||
):
|
||||
CONTEXT_CLASS = GoogleLLMContext
|
||||
AGGREGATOR_CLASS = GoogleAssistantContextAggregator
|
||||
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]
|
||||
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, LLMContextAssistantTimestampFrame]
|
||||
|
||||
def check_message_content(self, context: OpenAILLMContext, index: int, content: str):
|
||||
obj = context.messages[index].to_json_dict()
|
||||
@@ -848,4 +848,4 @@ class TestOpenAIAssistantContextAggregator(
|
||||
):
|
||||
CONTEXT_CLASS = OpenAILLMContext
|
||||
AGGREGATOR_CLASS = OpenAIAssistantContextAggregator
|
||||
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]
|
||||
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, LLMContextAssistantTimestampFrame]
|
||||
|
||||
Reference in New Issue
Block a user