Compare commits
12 Commits
hush/conte
...
pk/flows-g
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b776049535 | ||
|
|
3aa403a16f | ||
|
|
50dace147d | ||
|
|
e039abd290 | ||
|
|
63f88c0add | ||
|
|
2aef572e38 | ||
|
|
ba3100be0d | ||
|
|
9e65e77095 | ||
|
|
3183f9c077 | ||
|
|
90e6f0dca8 | ||
|
|
44b917c546 | ||
|
|
c1ac1a6326 |
3
changelog/3620.added.2.md
Normal file
3
changelog/3620.added.2.md
Normal file
@@ -0,0 +1,3 @@
|
||||
- Added `LLMMessagesTransformFrame` to facilitate programmatically editing context in a frame-based way.
|
||||
|
||||
The previous approach required the caller to directly grab a reference to the context object, grab a "snapshot" of its messages _at that point in time_, transform the messages, and then push an `LLMMessagesUpdateFrame` with the transformed messages. This approach can lead to problems: what if there had already been a change to the context queued in the pipeline? The transformed messages would simply overwrite it without consideration.
|
||||
1
changelog/3620.added.md
Normal file
1
changelog/3620.added.md
Normal file
@@ -0,0 +1 @@
|
||||
- Added support to Gemini Live (`GeminiLiveLLMService`) for programmatically swapping tools or editing context at runtime; now you can use `LLMMessagesAppendFrame`, `LLMMessagesUpdateFrame`, `LLMMessagesTransformFrame`, and `LLMSetToolsFrame` with Gemini Live, like you would with text-to-text services. Note that this new functionality only works if you're using `LLMContext` and `LLMContextAggregatorPair` rather than the deprecated `OpenAILLMContext` and associated aggregators.
|
||||
@@ -57,6 +57,10 @@ async def fetch_weather_from_api(params: FunctionCallParams):
|
||||
)
|
||||
|
||||
|
||||
async def fetch_restaurant_recommendation(params: FunctionCallParams):
|
||||
await params.result_callback({"name": "The Golden Dragon"})
|
||||
|
||||
|
||||
async def get_news(params: FunctionCallParams):
|
||||
await params.result_callback(
|
||||
{
|
||||
@@ -69,10 +73,6 @@ async def get_news(params: FunctionCallParams):
|
||||
)
|
||||
|
||||
|
||||
async def fetch_restaurant_recommendation(params: FunctionCallParams):
|
||||
await params.result_callback({"name": "The Golden Dragon"})
|
||||
|
||||
|
||||
weather_function = FunctionSchema(
|
||||
name="get_current_weather",
|
||||
description="Get the current weather",
|
||||
@@ -90,13 +90,6 @@ weather_function = FunctionSchema(
|
||||
required=["location", "format"],
|
||||
)
|
||||
|
||||
get_news_function = FunctionSchema(
|
||||
name="get_news",
|
||||
description="Get the current news.",
|
||||
properties={},
|
||||
required=[],
|
||||
)
|
||||
|
||||
restaurant_function = FunctionSchema(
|
||||
name="get_restaurant_recommendation",
|
||||
description="Get a restaurant recommendation",
|
||||
@@ -109,6 +102,13 @@ restaurant_function = FunctionSchema(
|
||||
required=["location"],
|
||||
)
|
||||
|
||||
get_news_function = FunctionSchema(
|
||||
name="get_news",
|
||||
description="Get the current news.",
|
||||
properties={},
|
||||
required=[],
|
||||
)
|
||||
|
||||
# Create tools schema
|
||||
tools = ToolsSchema(standard_tools=[weather_function, restaurant_function])
|
||||
|
||||
@@ -215,8 +215,9 @@ Remember, your responses should be short. Just one or two sentences, usually. Re
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
# Add a new tool at runtime after a delay.
|
||||
# Add a new tool (get_news) at runtime after a delay
|
||||
await asyncio.sleep(15)
|
||||
logger.info(f"Adding new tool get_news at runtime...")
|
||||
new_tools = ToolsSchema(
|
||||
standard_tools=[weather_function, restaurant_function, get_news_function]
|
||||
)
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#
|
||||
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
@@ -15,7 +16,7 @@ from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import AdapterType, ToolsSchema
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.frames.frames import LLMRunFrame, LLMSetToolsFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
@@ -51,6 +52,18 @@ async def fetch_restaurant_recommendation(params: FunctionCallParams):
|
||||
await params.result_callback({"name": "The Golden Dragon"})
|
||||
|
||||
|
||||
async def get_news(params: FunctionCallParams):
|
||||
await params.result_callback(
|
||||
{
|
||||
"news": [
|
||||
"Massive UFO currently hovering above New York City",
|
||||
"Stock markets reach all-time highs",
|
||||
"Living dinosaur species discovered in the Amazon rainforest",
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
system_instruction = """
|
||||
You are a helpful assistant who can answer questions and use tools.
|
||||
|
||||
@@ -109,6 +122,12 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
},
|
||||
required=["location"],
|
||||
)
|
||||
get_news_function = FunctionSchema(
|
||||
name="get_news",
|
||||
description="Get the current news.",
|
||||
properties={},
|
||||
required=[],
|
||||
)
|
||||
search_tool = {"google_search": {}}
|
||||
# KNOWN ISSUE: If using GeminiVertexLiveLLMService, it appears
|
||||
# you cannot use the "google_search" tool alongside other tools.
|
||||
@@ -126,6 +145,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
|
||||
llm.register_function("get_current_weather", fetch_weather_from_api)
|
||||
llm.register_function("get_restaurant_recommendation", fetch_restaurant_recommendation)
|
||||
llm.register_function("get_news", get_news)
|
||||
|
||||
# You can provide the system instructions and tools in the context rather
|
||||
# than as arguments to GeminiLiveLLMService, but note that doing so will
|
||||
@@ -174,6 +194,14 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
# Add a new tool (get_news) at runtime after a delay
|
||||
await asyncio.sleep(15)
|
||||
logger.info(f"Adding new tool get_news at runtime...")
|
||||
new_tools = ToolsSchema(
|
||||
standard_tools=[weather_function, restaurant_function, get_news_function]
|
||||
)
|
||||
await task.queue_frames([LLMSetToolsFrame(tools=new_tools)])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
|
||||
@@ -10,6 +10,7 @@ This module provides schemas for managing both standardized function tools
|
||||
and custom adapter-specific tools in the Pipecat framework.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
@@ -17,6 +18,36 @@ from pipecat.adapters.schemas.direct_function import DirectFunction, DirectFunct
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolsSchemaDiff:
|
||||
"""Represents the differences between two ToolsSchema instances.
|
||||
|
||||
Parameters:
|
||||
standard_tools_added: Names of newly added standard tools.
|
||||
standard_tools_removed: Names of removed standard tools.
|
||||
standard_tools_modified: True if any existing standard tool's definition changed.
|
||||
custom_tools_changed: True if the custom_tools dictionary differs.
|
||||
"""
|
||||
|
||||
standard_tools_added: List[str] = field(default_factory=list)
|
||||
standard_tools_removed: List[str] = field(default_factory=list)
|
||||
standard_tools_modified: bool = False
|
||||
custom_tools_changed: bool = False
|
||||
|
||||
def has_changes(self) -> bool:
|
||||
"""Check if there are any differences.
|
||||
|
||||
Returns:
|
||||
True if any field indicates a change, False otherwise.
|
||||
"""
|
||||
return bool(
|
||||
self.standard_tools_added
|
||||
or self.standard_tools_removed
|
||||
or self.standard_tools_modified
|
||||
or self.custom_tools_changed
|
||||
)
|
||||
|
||||
|
||||
class AdapterType(Enum):
|
||||
"""Supported adapter types for custom tools.
|
||||
|
||||
@@ -92,3 +123,44 @@ class ToolsSchema:
|
||||
value: Dictionary mapping adapter types to their custom tool definitions.
|
||||
"""
|
||||
self._custom_tools = value
|
||||
|
||||
def diff(self, other: "ToolsSchema") -> ToolsSchemaDiff:
|
||||
"""Compare this ToolsSchema to another and return the differences.
|
||||
|
||||
Args:
|
||||
other: The ToolsSchema to compare against (the "after" state).
|
||||
|
||||
Returns:
|
||||
ToolsSchemaDiff containing the differences between self and other.
|
||||
"""
|
||||
result = ToolsSchemaDiff()
|
||||
|
||||
# Build maps of tool name -> FunctionSchema for comparison
|
||||
self_tools_by_name: Dict[str, FunctionSchema] = {
|
||||
tool.name: tool for tool in self._standard_tools
|
||||
}
|
||||
other_tools_by_name: Dict[str, FunctionSchema] = {
|
||||
tool.name: tool for tool in other._standard_tools
|
||||
}
|
||||
|
||||
self_names = set(self_tools_by_name.keys())
|
||||
other_names = set(other_tools_by_name.keys())
|
||||
|
||||
# Find added and removed tools
|
||||
result.standard_tools_added = sorted(other_names - self_names)
|
||||
result.standard_tools_removed = sorted(self_names - other_names)
|
||||
|
||||
# Check for modified tools (same name, different definition)
|
||||
common_names = self_names & other_names
|
||||
for name in common_names:
|
||||
self_tool = self_tools_by_name[name]
|
||||
other_tool = other_tools_by_name[name]
|
||||
# Compare using to_default_dict() for full schema comparison
|
||||
if self_tool.to_default_dict() != other_tool.to_default_dict():
|
||||
result.standard_tools_modified = True
|
||||
break
|
||||
|
||||
# Compare custom tools
|
||||
result.custom_tools_changed = self._custom_tools != other._custom_tools
|
||||
|
||||
return result
|
||||
|
||||
@@ -53,19 +53,32 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
"""Get the identifier used in LLMSpecificMessage instances for Google."""
|
||||
return "google"
|
||||
|
||||
def get_llm_invocation_params(self, context: LLMContext) -> GeminiLLMInvocationParams:
|
||||
def get_llm_invocation_params(
|
||||
self, context: LLMContext, *, convert_function_messages_to_text: bool = False
|
||||
) -> GeminiLLMInvocationParams:
|
||||
"""Get Gemini-specific LLM invocation parameters from a universal LLM context.
|
||||
|
||||
Args:
|
||||
context: The LLM context containing messages, tools, etc.
|
||||
convert_function_messages_to_text: If True, convert function_call and function_response
|
||||
parts to specially-formatted text messages. This is needed for Gemini Live
|
||||
(at least with "models/gemini-2.5-flash-native-audio-preview-12-2025", the
|
||||
default at the time of this writing) which cannot handle function-call-related
|
||||
messages when initializing conversation history.
|
||||
See https://stackoverflow.com/a/79851394.
|
||||
|
||||
Returns:
|
||||
Dictionary of parameters for Gemini's API.
|
||||
"""
|
||||
messages = self._from_universal_context_messages(self.get_messages(context))
|
||||
converted = self._from_universal_context_messages(self.get_messages(context))
|
||||
messages = converted.messages
|
||||
|
||||
if convert_function_messages_to_text:
|
||||
messages = self._convert_function_messages_to_text(messages)
|
||||
|
||||
return {
|
||||
"system_instruction": messages.system_instruction,
|
||||
"messages": messages.messages,
|
||||
"system_instruction": converted.system_instruction,
|
||||
"messages": messages,
|
||||
# NOTE: LLMContext's tools are guaranteed to be a ToolsSchema (or NOT_GIVEN)
|
||||
"tools": self.from_standard_tools(context.tools),
|
||||
}
|
||||
@@ -668,3 +681,43 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _convert_function_messages_to_text(self, messages: List[Content]) -> List[Content]:
|
||||
"""Convert function_call and function_response parts to text messages.
|
||||
|
||||
Args:
|
||||
messages: List of Content messages to process.
|
||||
|
||||
Returns:
|
||||
List of Content messages with function-related parts converted to text.
|
||||
"""
|
||||
converted_messages = []
|
||||
for msg in messages:
|
||||
if msg.parts:
|
||||
converted_parts = []
|
||||
for part in msg.parts:
|
||||
if func_call := getattr(part, "function_call", None):
|
||||
# Convert function call to text
|
||||
args_str = json.dumps(func_call.args) if func_call.args else "{}"
|
||||
text = (
|
||||
f"[Historical function call (for context only, not a template): "
|
||||
f"{func_call.name}({args_str})]"
|
||||
)
|
||||
converted_parts.append(Part(text=text))
|
||||
elif func_response := getattr(part, "function_response", None):
|
||||
# Convert function response to text
|
||||
response_str = (
|
||||
json.dumps(func_response.response) if func_response.response else "{}"
|
||||
)
|
||||
text = (
|
||||
f"[Historical function result (for context only): "
|
||||
f"{func_response.name} returned {response_str}]"
|
||||
)
|
||||
converted_parts.append(Part(text=text))
|
||||
else:
|
||||
converted_parts.append(part)
|
||||
if converted_parts:
|
||||
converted_messages.append(Content(role=msg.role, parts=converted_parts))
|
||||
else:
|
||||
converted_messages.append(msg)
|
||||
return converted_messages
|
||||
|
||||
@@ -38,7 +38,7 @@ from pipecat.utils.time import nanoseconds_to_str
|
||||
from pipecat.utils.utils import obj_count, obj_id
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext, NotGiven
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext, LLMContextMessage, NotGiven
|
||||
from pipecat.processors.frame_processor import FrameProcessor
|
||||
|
||||
|
||||
@@ -667,9 +667,16 @@ class LLMContextFrame(Frame):
|
||||
|
||||
Parameters:
|
||||
context: The LLM context containing messages, tools, and configuration.
|
||||
messages_programmatically_edited: Whether the context messages were
|
||||
programmatically edited (e.g. via LLMMessagesAppendFrame or
|
||||
LLMMessagesUpdateFrame) since the last context frame was pushed.
|
||||
This is used by speech-to-speech LLM services (like Gemini Live) to
|
||||
distinguish between messages that originated from the LLM output
|
||||
itself vs. messages that were externally injected.
|
||||
"""
|
||||
|
||||
context: "LLMContext"
|
||||
messages_programmatically_edited: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -822,6 +829,25 @@ class LLMMessagesUpdateFrame(DataFrame):
|
||||
run_llm: Optional[bool] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMMessagesTransformFrame(DataFrame):
|
||||
"""Frame containing a transform function to modify the current context's LLM messages.
|
||||
|
||||
A frame containing a transform function that takes the context's current list
|
||||
of LLM messages and returns a modified list.
|
||||
|
||||
Only compatible with LLMContext and not the deprecated OpenAILLMContext.
|
||||
|
||||
Parameters:
|
||||
transform: A function that takes a list of messages and returns a
|
||||
modified list.
|
||||
run_llm: Whether the context update should be sent to the LLM.
|
||||
"""
|
||||
|
||||
transform: Callable[[List["LLMContextMessage"]], List["LLMContextMessage"]]
|
||||
run_llm: Optional[bool] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMSetToolsFrame(DataFrame):
|
||||
"""Frame containing tools for LLM function calling.
|
||||
|
||||
@@ -18,8 +18,8 @@ import asyncio
|
||||
import base64
|
||||
import io
|
||||
import wave
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, TypeAlias, Union
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeAlias, Union
|
||||
|
||||
from loguru import logger
|
||||
from openai._types import NOT_GIVEN as OPEN_AI_NOT_GIVEN
|
||||
@@ -30,7 +30,7 @@ from openai.types.chat import (
|
||||
)
|
||||
from PIL import Image
|
||||
|
||||
from pipecat.adapters.schemas.tools_schema import AdapterType, ToolsSchema
|
||||
from pipecat.adapters.schemas.tools_schema import AdapterType, ToolsSchema, ToolsSchemaDiff
|
||||
from pipecat.frames.frames import AudioRawFrame
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -47,6 +47,39 @@ NOT_GIVEN = OPEN_AI_NOT_GIVEN
|
||||
NotGiven = OpenAINotGiven
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMContextDiff:
|
||||
"""Represents the differences between two LLMContext instances.
|
||||
|
||||
Parameters:
|
||||
messages_appended: New messages appended at the end. Empty if history_edited is True.
|
||||
history_edited: True if earlier messages were changed, inserted, or removed.
|
||||
tool_calls_resolved: List of tool_call_ids that changed from "IN_PROGRESS" to a result.
|
||||
tools_diff: Differences in tools configuration, or None if unchanged or both NOT_GIVEN.
|
||||
tool_choice_changed: True if the tool_choice setting differs.
|
||||
"""
|
||||
|
||||
messages_appended: List["LLMContextMessage"] = field(default_factory=list)
|
||||
history_edited: bool = False
|
||||
tool_calls_resolved: List[str] = field(default_factory=list)
|
||||
tools_diff: ToolsSchemaDiff = field(default_factory=ToolsSchemaDiff)
|
||||
tool_choice_changed: bool = False
|
||||
|
||||
def has_changes(self) -> bool:
|
||||
"""Check if there are any differences.
|
||||
|
||||
Returns:
|
||||
True if any field indicates a change, False otherwise.
|
||||
"""
|
||||
return bool(
|
||||
self.messages_appended
|
||||
or self.history_edited
|
||||
or self.tool_calls_resolved
|
||||
or self.tools_diff.has_changes()
|
||||
or self.tool_choice_changed
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMSpecificMessage:
|
||||
"""A container for a context message that is specific to a particular LLM service.
|
||||
@@ -341,6 +374,19 @@ class LLMContext:
|
||||
"""
|
||||
self._messages[:] = messages
|
||||
|
||||
def transform_messages(
|
||||
self, transform: Callable[[List[LLMContextMessage]], List[LLMContextMessage]]
|
||||
):
|
||||
"""Transform the current messages using the provided function.
|
||||
|
||||
Args:
|
||||
transform: A function that takes the current list of messages and returns
|
||||
a modified list of messages to set in the context.
|
||||
"""
|
||||
current_messages = self._messages
|
||||
new_messages = transform(current_messages)
|
||||
self.set_messages(new_messages)
|
||||
|
||||
def set_tools(self, tools: ToolsSchema | NotGiven = NOT_GIVEN):
|
||||
"""Set the available tools for the LLM.
|
||||
|
||||
@@ -410,3 +456,137 @@ class LLMContext:
|
||||
raise TypeError(
|
||||
f"In LLMContext, tools must be a ToolsSchema object or NOT_GIVEN. Got type: {type(tools)}",
|
||||
)
|
||||
|
||||
def diff(self, other: "LLMContext") -> LLMContextDiff:
|
||||
"""Compare this context to another and return the differences.
|
||||
|
||||
Compares self (the "before" state) to other (the "after" state) and
|
||||
identifies what has changed.
|
||||
|
||||
Args:
|
||||
other: The LLMContext to compare against (the "after" state).
|
||||
|
||||
Returns:
|
||||
ContextDiff containing the differences between self and other.
|
||||
"""
|
||||
result = LLMContextDiff()
|
||||
|
||||
# Compare messages
|
||||
self_messages = self._messages
|
||||
other_messages = other._messages
|
||||
self_len = len(self_messages)
|
||||
other_len = len(other_messages)
|
||||
|
||||
# Check if history was edited (messages removed, modified, or inserted in the middle)
|
||||
if other_len < self_len:
|
||||
# Messages were removed
|
||||
result.history_edited = True
|
||||
else:
|
||||
# Check if the prefix matches (first self_len messages should be identical)
|
||||
for i in range(self_len):
|
||||
if not self._messages_equal(self_messages[i], other_messages[i]):
|
||||
result.history_edited = True
|
||||
break
|
||||
|
||||
# If history wasn't edited, capture appended messages
|
||||
if not result.history_edited and other_len > self_len:
|
||||
result.messages_appended = other_messages[self_len:]
|
||||
|
||||
# Find resolved tool calls (IN_PROGRESS -> something else)
|
||||
result.tool_calls_resolved = self._find_resolved_tool_calls(other)
|
||||
|
||||
# Compare tools
|
||||
result.tools_diff = self._compute_tools_diff(other)
|
||||
|
||||
# Compare tool_choice
|
||||
# (For some reason if they're both NOT_GIVEN, equality check returns False?)
|
||||
if not self._tool_choice and not other._tool_choice:
|
||||
result.tool_choice_changed = False
|
||||
else:
|
||||
result.tool_choice_changed = self._tool_choice != other._tool_choice
|
||||
|
||||
return result
|
||||
|
||||
def _messages_equal(self, msg1: LLMContextMessage, msg2: LLMContextMessage) -> bool:
|
||||
"""Compare two messages for equality.
|
||||
|
||||
Args:
|
||||
msg1: First message to compare.
|
||||
msg2: Second message to compare.
|
||||
|
||||
Returns:
|
||||
True if the messages are equal, False otherwise.
|
||||
"""
|
||||
# Handle LLMSpecificMessage
|
||||
if isinstance(msg1, LLMSpecificMessage) and isinstance(msg2, LLMSpecificMessage):
|
||||
return msg1.llm == msg2.llm and msg1.message == msg2.message
|
||||
elif isinstance(msg1, LLMSpecificMessage) or isinstance(msg2, LLMSpecificMessage):
|
||||
return False
|
||||
|
||||
# Both are standard messages (dicts)
|
||||
return msg1 == msg2
|
||||
|
||||
def _find_resolved_tool_calls(self, other: "LLMContext") -> List[str]:
|
||||
"""Find tool calls that changed from IN_PROGRESS to a resolved state.
|
||||
|
||||
Args:
|
||||
other: The context to compare against (the "after" state).
|
||||
|
||||
Returns:
|
||||
List of tool_call_ids that were resolved.
|
||||
"""
|
||||
resolved: List[str] = []
|
||||
|
||||
# Build a map of tool_call_id -> content for "other" context
|
||||
other_tool_contents: Dict[str, Any] = {}
|
||||
for msg in other._messages:
|
||||
if isinstance(msg, dict) and msg.get("role") == "tool":
|
||||
tool_call_id = msg.get("tool_call_id")
|
||||
if tool_call_id:
|
||||
other_tool_contents[tool_call_id] = msg.get("content")
|
||||
|
||||
# Find tool messages in self that are IN_PROGRESS but resolved in other
|
||||
for msg in self._messages:
|
||||
if isinstance(msg, dict) and msg.get("role") == "tool":
|
||||
if msg.get("content") == "IN_PROGRESS":
|
||||
tool_call_id = msg.get("tool_call_id")
|
||||
if tool_call_id and tool_call_id in other_tool_contents:
|
||||
other_content = other_tool_contents[tool_call_id]
|
||||
if other_content != "IN_PROGRESS":
|
||||
resolved.append(tool_call_id)
|
||||
|
||||
return resolved
|
||||
|
||||
def _compute_tools_diff(self, other: "LLMContext") -> ToolsSchemaDiff:
|
||||
"""Compute the difference in tools between self and other.
|
||||
|
||||
Args:
|
||||
other: The context to compare against (the "after" state).
|
||||
|
||||
Returns:
|
||||
ToolsSchemaDiff if there are changes, None if both are NOT_GIVEN or identical.
|
||||
"""
|
||||
self_has_tools = isinstance(self._tools, ToolsSchema)
|
||||
other_has_tools = isinstance(other._tools, ToolsSchema)
|
||||
|
||||
if not self_has_tools and not other_has_tools:
|
||||
# Both are NOT_GIVEN
|
||||
return ToolsSchemaDiff()
|
||||
|
||||
if self_has_tools and other_has_tools:
|
||||
# Both have tools - use ToolsSchema.diff()
|
||||
diff = self._tools.diff(other._tools)
|
||||
return diff
|
||||
|
||||
if not self_has_tools and other_has_tools:
|
||||
# Tools were added (self is NOT_GIVEN, other has tools)
|
||||
return ToolsSchemaDiff(
|
||||
standard_tools_added=[tool.name for tool in other._tools.standard_tools],
|
||||
custom_tools_changed=other._tools.custom_tools is not None,
|
||||
)
|
||||
|
||||
# Tools were removed (self has tools, other is NOT_GIVEN)
|
||||
return ToolsSchemaDiff(
|
||||
standard_tools_removed=[tool.name for tool in self._tools.standard_tools],
|
||||
custom_tools_changed=self._tools.custom_tools is not None,
|
||||
)
|
||||
|
||||
@@ -16,7 +16,7 @@ import json
|
||||
import warnings
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Literal, Optional, Set, Type
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Set, Type
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -40,6 +40,7 @@ from pipecat.frames.frames import (
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesAppendFrame,
|
||||
LLMMessagesTransformFrame,
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMRunFrame,
|
||||
LLMSetToolChoiceFrame,
|
||||
@@ -231,7 +232,16 @@ class LLMContextAggregator(FrameProcessor):
|
||||
Returns:
|
||||
LLMContextFrame containing the current context.
|
||||
"""
|
||||
return LLMContextFrame(context=self._context)
|
||||
# Check if messages were programmatically edited since the last push.
|
||||
# This flag is stored as a runtime attribute on the shared context
|
||||
# object so that both user and assistant aggregators can see it.
|
||||
messages_programmatically_edited = getattr(
|
||||
self._context, "_pipecat_messages_programmatically_edited", False
|
||||
)
|
||||
return LLMContextFrame(
|
||||
context=self._context,
|
||||
messages_programmatically_edited=messages_programmatically_edited,
|
||||
)
|
||||
|
||||
async def push_context_frame(self, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
"""Push a context frame in the specified direction.
|
||||
@@ -241,6 +251,9 @@ class LLMContextAggregator(FrameProcessor):
|
||||
"""
|
||||
frame = self._get_context_frame()
|
||||
await self.push_frame(frame, direction)
|
||||
# Clear the programmatic edit flag after pushing, since the context
|
||||
# frame now carries this information to downstream processors.
|
||||
self._context._pipecat_messages_programmatically_edited = False
|
||||
|
||||
def add_messages(self, messages):
|
||||
"""Add messages to the context.
|
||||
@@ -258,6 +271,17 @@ class LLMContextAggregator(FrameProcessor):
|
||||
"""
|
||||
self._context.set_messages(messages)
|
||||
|
||||
def transform_messages(
|
||||
self, transform: Callable[[List[LLMContextMessage]], List[LLMContextMessage]]
|
||||
):
|
||||
"""Transform the context messages using a provided function.
|
||||
|
||||
Args:
|
||||
transform: A function that takes the current list of messages and returns
|
||||
a modified list of messages to set in the context.
|
||||
"""
|
||||
self._context.transform_messages(transform)
|
||||
|
||||
def set_tools(self, tools: ToolsSchema | NotGiven):
|
||||
"""Set tools in the context.
|
||||
|
||||
@@ -458,13 +482,17 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
await self._handle_llm_messages_append(frame)
|
||||
elif isinstance(frame, LLMMessagesUpdateFrame):
|
||||
await self._handle_llm_messages_update(frame)
|
||||
elif isinstance(frame, LLMMessagesTransformFrame):
|
||||
await self._handle_llm_messages_transform(frame)
|
||||
elif isinstance(frame, LLMSetToolsFrame):
|
||||
self.set_tools(frame.tools)
|
||||
# Push the LLMSetToolsFrame as well, since speech-to-speech LLM
|
||||
# services (like OpenAI Realtime) may need to know about tool
|
||||
# changes; unlike text-based LLM services they won't just "pick up
|
||||
# the change" on the next LLM run, as the LLM is continuously
|
||||
# running.
|
||||
# Push the LLMSetToolsFrame as well, since some realtime (aka
|
||||
# speech-to-speech) LLM services (like OpenAI Realtime) may need to
|
||||
# be directly be informed of tool changes; unlike text-based LLM
|
||||
# services they can't necessarily rely on "picking up the change"
|
||||
# from the context on the next LLM run, as the LLM is continuously
|
||||
# running and they may need to apply the change sooner than the
|
||||
# next context frame.
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, LLMSetToolChoiceFrame):
|
||||
self.set_tool_choice(frame.tool_choice)
|
||||
@@ -573,11 +601,28 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
|
||||
async def _handle_llm_messages_append(self, frame: LLMMessagesAppendFrame):
|
||||
self.add_messages(frame.messages)
|
||||
# Mark the context as programmatically edited. This flag is stored as a
|
||||
# runtime attribute on the shared context object so that both user and
|
||||
# assistant aggregators can see it.
|
||||
self._context._pipecat_messages_programmatically_edited = True
|
||||
if frame.run_llm:
|
||||
await self.push_context_frame()
|
||||
|
||||
async def _handle_llm_messages_update(self, frame: LLMMessagesUpdateFrame):
|
||||
self.set_messages(frame.messages)
|
||||
# Mark the context as programmatically edited. This flag is stored as a
|
||||
# runtime attribute on the shared context object so that both user and
|
||||
# assistant aggregators can see it.
|
||||
self._context._pipecat_messages_programmatically_edited = True
|
||||
if frame.run_llm:
|
||||
await self.push_context_frame()
|
||||
|
||||
async def _handle_llm_messages_transform(self, frame: LLMMessagesTransformFrame):
|
||||
self.transform_messages(frame.transform)
|
||||
# Mark the context as programmatically edited. This flag is stored as a
|
||||
# runtime attribute on the shared context object so that both user and
|
||||
# assistant aggregators can see it.
|
||||
self._context._pipecat_messages_programmatically_edited = True
|
||||
if frame.run_llm:
|
||||
await self.push_context_frame()
|
||||
|
||||
@@ -866,6 +911,8 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
await self._handle_llm_messages_append(frame)
|
||||
elif isinstance(frame, LLMMessagesUpdateFrame):
|
||||
await self._handle_llm_messages_update(frame)
|
||||
elif isinstance(frame, LLMMessagesTransformFrame):
|
||||
await self._handle_llm_messages_transform(frame)
|
||||
elif isinstance(frame, LLMSetToolsFrame):
|
||||
self.set_tools(frame.tools)
|
||||
elif isinstance(frame, LLMSetToolChoiceFrame):
|
||||
@@ -909,11 +956,28 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
|
||||
async def _handle_llm_messages_append(self, frame: LLMMessagesAppendFrame):
|
||||
self.add_messages(frame.messages)
|
||||
# Mark the context as programmatically edited. This flag is stored as a
|
||||
# runtime attribute on the shared context object so that both user and
|
||||
# assistant aggregators can see it.
|
||||
self._context._pipecat_messages_programmatically_edited = True
|
||||
if frame.run_llm:
|
||||
await self.push_context_frame(FrameDirection.UPSTREAM)
|
||||
|
||||
async def _handle_llm_messages_update(self, frame: LLMMessagesUpdateFrame):
|
||||
self.set_messages(frame.messages)
|
||||
# Mark the context as programmatically edited. This flag is stored as
|
||||
# a runtime attribute on the shared context object so that both user
|
||||
# and assistant aggregators can see it.
|
||||
self._context._pipecat_messages_programmatically_edited = True
|
||||
if frame.run_llm:
|
||||
await self.push_context_frame(FrameDirection.UPSTREAM)
|
||||
|
||||
async def _handle_llm_messages_transform(self, frame: LLMMessagesTransformFrame):
|
||||
self.transform_messages(frame.transform)
|
||||
# Mark the context as programmatically edited. This flag is stored as a
|
||||
# runtime attribute on the shared context object so that both user and
|
||||
# assistant aggregators can see it.
|
||||
self._context._pipecat_messages_programmatically_edited = True
|
||||
if frame.run_llm:
|
||||
await self.push_context_frame(FrameDirection.UPSTREAM)
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ voice transcription, streaming responses, and tool usage.
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import copy
|
||||
import io
|
||||
import time
|
||||
import uuid
|
||||
@@ -59,7 +60,12 @@ from pipecat.frames.frames import (
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_context import (
|
||||
NOT_GIVEN,
|
||||
LLMContext,
|
||||
LLMContextDiff,
|
||||
LLMContextMessage,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMUserAggregatorParams,
|
||||
@@ -684,6 +690,7 @@ class GeminiLiveLLMService(LLMService):
|
||||
self._audio_input_paused = start_audio_paused
|
||||
self._video_input_paused = start_video_paused
|
||||
self._context = None
|
||||
self._context_snapshot = None
|
||||
self._api_key = api_key
|
||||
self._http_options = update_google_client_http_options(http_options)
|
||||
self._session: AsyncSession = None
|
||||
@@ -826,6 +833,9 @@ class GeminiLiveLLMService(LLMService):
|
||||
logger.error("Context already set. Can only set up Gemini Live context once.")
|
||||
return
|
||||
self._context = GeminiLiveContext.upgrade(context)
|
||||
self._context_snapshot = copy.deepcopy(
|
||||
self._context
|
||||
) # Take a static snapshot guaranteed not to change
|
||||
await self._create_initial_response()
|
||||
|
||||
#
|
||||
@@ -915,7 +925,12 @@ class GeminiLiveLLMService(LLMService):
|
||||
if isinstance(frame, LLMContextFrame)
|
||||
else LLMContext.from_openai_context(frame.context)
|
||||
)
|
||||
await self._handle_context(context)
|
||||
messages_programmatically_edited = (
|
||||
frame.messages_programmatically_edited
|
||||
if isinstance(frame, LLMContextFrame)
|
||||
else False
|
||||
)
|
||||
await self._handle_context(context, messages_programmatically_edited)
|
||||
elif isinstance(frame, InputTextRawFrame):
|
||||
await self._send_user_text(frame.text)
|
||||
await self.push_frame(frame, direction)
|
||||
@@ -950,14 +965,21 @@ class GeminiLiveLLMService(LLMService):
|
||||
elif isinstance(frame, LLMUpdateSettingsFrame):
|
||||
await self._update_settings(frame.settings)
|
||||
elif isinstance(frame, LLMSetToolsFrame):
|
||||
await self._update_settings()
|
||||
# We actually don't need to do anything here; the next time we get
|
||||
# a context frame (like after a user transcription is appended),
|
||||
# we'll detect that tools have changed and reconnect then to apply
|
||||
# the new tools.
|
||||
pass
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _handle_context(self, context: LLMContext):
|
||||
async def _handle_context(self, context: LLMContext, messages_programmatically_edited: bool):
|
||||
if not self._context:
|
||||
# We got our initial context
|
||||
self._context = context
|
||||
self._context_snapshot = copy.deepcopy(
|
||||
context
|
||||
) # Take a static snapshot guaranteed not to change
|
||||
|
||||
# If context contains system instruction or tools, reconnect in
|
||||
# order to apply them.
|
||||
@@ -981,6 +1003,7 @@ class GeminiLiveLLMService(LLMService):
|
||||
"Tools provided both at init time and in context; using context-provided value."
|
||||
)
|
||||
if system_instruction or tools:
|
||||
self._session_resumption_handle = None
|
||||
await self._reconnect()
|
||||
|
||||
# Initialize our bookkeeping of already-completed tool calls in
|
||||
@@ -1011,17 +1034,71 @@ class GeminiLiveLLMService(LLMService):
|
||||
# We got an updated context.
|
||||
self._context = context
|
||||
|
||||
# Here we assume that the updated context will contain either:
|
||||
# - new messages (that the Gemini Live service, with its own
|
||||
# context management, is already aware of), or
|
||||
# - tool call results (that we need to tell the remote service
|
||||
# about).
|
||||
# (In the future, we could do more sophisticated diffing here,
|
||||
# which would enable the user to programmatically manipulate the
|
||||
# context).
|
||||
diff = self._context_snapshot.diff(self._context)
|
||||
self._context_snapshot = copy.deepcopy(
|
||||
context
|
||||
) # Take a static snapshot guaranteed not to change
|
||||
|
||||
# Send results for newly-completed function calls, if any.
|
||||
await self._process_completed_function_calls(send_new_results=True)
|
||||
if self._context_update_requires_reconnect(diff, messages_programmatically_edited):
|
||||
logger.debug("Context update requires reconnect.")
|
||||
|
||||
# Reconnect
|
||||
self._session_resumption_handle = None
|
||||
await self._reconnect()
|
||||
|
||||
# Initialize our bookkeeping of already-completed tool calls in
|
||||
# the context
|
||||
await self._process_completed_function_calls(send_new_results=False)
|
||||
|
||||
# Trigger "initial" response with new connection
|
||||
await self._create_initial_response()
|
||||
else:
|
||||
logger.debug("Context update does not require reconnect.")
|
||||
|
||||
# Send results for newly-completed function calls, if any.
|
||||
await self._process_completed_function_calls(send_new_results=True)
|
||||
|
||||
def _context_update_requires_reconnect(
|
||||
self, diff: LLMContextDiff, messages_programmatically_edited: bool
|
||||
) -> bool:
|
||||
"""Check if an update to our LLM context requires reconnection.
|
||||
|
||||
Args:
|
||||
diff: The context diff representing the update
|
||||
messages_programmatically_edited: Whether context messages were
|
||||
programmatically edited (e.g. with LLMMessagesAppendFrame)
|
||||
|
||||
Returns:
|
||||
True if reconnection is required, False otherwise.
|
||||
"""
|
||||
# We need to reconnect in 3 cases:
|
||||
# 1. If the conversation history was edited
|
||||
# 2. If the tools available to the model were changed
|
||||
# 3. If messages were appended programmatically by the user, e.g. with
|
||||
# LLMMessagesAppendFrame (NOT if they originated from Gemini Live
|
||||
# itself, since the remote service's internal bookkeeping is already
|
||||
# aware of those)
|
||||
#
|
||||
# Note that *ideally* in this 3rd case we would just send the newly
|
||||
# appended messages without reconnecting, but in all my testing so far,
|
||||
# I haven't been able to get that to work as reliably as I'd like.
|
||||
# (Commments in the Gemini Live Python SDK also warn against
|
||||
# programmatically sending new messages after the initial conversation
|
||||
# history seeding is done and the voice chat has begun.)
|
||||
if (
|
||||
diff.history_edited
|
||||
or diff.tools_diff.has_changes()
|
||||
or (diff.messages_appended and messages_programmatically_edited)
|
||||
):
|
||||
if diff.history_edited:
|
||||
logger.debug("Context update: history edited.")
|
||||
if diff.tools_diff.has_changes():
|
||||
logger.debug("Context update: tools changed.")
|
||||
if diff.messages_appended and messages_programmatically_edited:
|
||||
logger.debug(f"Context update: messages programmatically appended.")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def _process_completed_function_calls(self, send_new_results: bool):
|
||||
# Check for set of completed function calls in the context
|
||||
@@ -1042,6 +1119,9 @@ class GeminiLiveLLMService(LLMService):
|
||||
):
|
||||
# Found a newly-completed function call - send the result to the service
|
||||
if send_new_results:
|
||||
logger.debug(
|
||||
f"Sending newly-completed tool call result for tool '{tool_name}'"
|
||||
)
|
||||
await self._tool_result(
|
||||
tool_call_id, tool_name, part.function_response.response
|
||||
)
|
||||
@@ -1383,7 +1463,9 @@ class GeminiLiveLLMService(LLMService):
|
||||
return
|
||||
|
||||
adapter: GeminiLLMAdapter = self.get_llm_adapter()
|
||||
messages = adapter.get_llm_invocation_params(self._context).get("messages", [])
|
||||
messages = adapter.get_llm_invocation_params(
|
||||
self._context, convert_function_messages_to_text=True
|
||||
).get("messages", [])
|
||||
if not messages:
|
||||
return
|
||||
|
||||
@@ -1413,7 +1495,9 @@ class GeminiLiveLLMService(LLMService):
|
||||
# in the right format
|
||||
context = LLMContext(messages=messages_list)
|
||||
adapter: GeminiLLMAdapter = self.get_llm_adapter()
|
||||
messages = adapter.get_llm_invocation_params(context).get("messages", [])
|
||||
messages = adapter.get_llm_invocation_params(
|
||||
context, convert_function_messages_to_text=True
|
||||
).get("messages", [])
|
||||
|
||||
if not messages:
|
||||
return
|
||||
|
||||
@@ -1057,5 +1057,287 @@ class TestLLMAssistantAggregator(
|
||||
)
|
||||
|
||||
|
||||
class TestLLMContextDiff(unittest.TestCase):
|
||||
"""Tests for the LLMContext.diff() method."""
|
||||
|
||||
def test_diff_identical_contexts(self):
|
||||
"""Test diff of two identical contexts returns no changes."""
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
context1 = LLMContext(messages=messages.copy())
|
||||
context2 = LLMContext(messages=messages.copy())
|
||||
|
||||
diff = context1.diff(context2)
|
||||
self.assertFalse(diff.has_changes())
|
||||
self.assertEqual(diff.messages_appended, [])
|
||||
self.assertFalse(diff.history_edited)
|
||||
self.assertEqual(diff.tool_calls_resolved, [])
|
||||
self.assertFalse(diff.tools_diff.has_changes())
|
||||
self.assertFalse(diff.tool_choice_changed)
|
||||
|
||||
def test_diff_messages_appended(self):
|
||||
"""Test diff detects appended messages."""
|
||||
msg1 = {"role": "user", "content": "Hello"}
|
||||
msg2 = {"role": "assistant", "content": "Hi there!"}
|
||||
|
||||
context1 = LLMContext(messages=[msg1])
|
||||
context2 = LLMContext(messages=[msg1, msg2])
|
||||
|
||||
diff = context1.diff(context2)
|
||||
self.assertTrue(diff.has_changes())
|
||||
self.assertEqual(len(diff.messages_appended), 1)
|
||||
self.assertEqual(diff.messages_appended[0], msg2)
|
||||
self.assertFalse(diff.history_edited)
|
||||
|
||||
def test_diff_multiple_messages_appended(self):
|
||||
"""Test diff detects multiple appended messages."""
|
||||
msg1 = {"role": "user", "content": "Hello"}
|
||||
msg2 = {"role": "assistant", "content": "Hi!"}
|
||||
msg3 = {"role": "user", "content": "How are you?"}
|
||||
|
||||
context1 = LLMContext(messages=[msg1])
|
||||
context2 = LLMContext(messages=[msg1, msg2, msg3])
|
||||
|
||||
diff = context1.diff(context2)
|
||||
self.assertTrue(diff.has_changes())
|
||||
self.assertEqual(len(diff.messages_appended), 2)
|
||||
self.assertEqual(diff.messages_appended[0], msg2)
|
||||
self.assertEqual(diff.messages_appended[1], msg3)
|
||||
self.assertFalse(diff.history_edited)
|
||||
|
||||
def test_diff_message_removed(self):
|
||||
"""Test diff detects message removal as history edit."""
|
||||
msg1 = {"role": "user", "content": "Hello"}
|
||||
msg2 = {"role": "assistant", "content": "Hi!"}
|
||||
|
||||
context1 = LLMContext(messages=[msg1, msg2])
|
||||
context2 = LLMContext(messages=[msg1])
|
||||
|
||||
diff = context1.diff(context2)
|
||||
self.assertTrue(diff.has_changes())
|
||||
self.assertEqual(diff.messages_appended, []) # Empty when history edited
|
||||
self.assertTrue(diff.history_edited)
|
||||
|
||||
def test_diff_message_modified(self):
|
||||
"""Test diff detects message modification as history edit."""
|
||||
msg1 = {"role": "user", "content": "Hello"}
|
||||
msg2_v1 = {"role": "assistant", "content": "Hi!"}
|
||||
msg2_v2 = {"role": "assistant", "content": "Hello there!"}
|
||||
|
||||
context1 = LLMContext(messages=[msg1, msg2_v1])
|
||||
context2 = LLMContext(messages=[msg1, msg2_v2])
|
||||
|
||||
diff = context1.diff(context2)
|
||||
self.assertTrue(diff.has_changes())
|
||||
self.assertTrue(diff.history_edited)
|
||||
self.assertEqual(diff.messages_appended, [])
|
||||
|
||||
def test_diff_message_inserted_in_middle(self):
|
||||
"""Test diff detects message insertion in middle as history edit."""
|
||||
msg1 = {"role": "user", "content": "Hello"}
|
||||
msg2 = {"role": "assistant", "content": "Hi!"}
|
||||
msg_inserted = {"role": "system", "content": "System message"}
|
||||
|
||||
context1 = LLMContext(messages=[msg1, msg2])
|
||||
context2 = LLMContext(messages=[msg1, msg_inserted, msg2])
|
||||
|
||||
diff = context1.diff(context2)
|
||||
self.assertTrue(diff.has_changes())
|
||||
self.assertTrue(diff.history_edited)
|
||||
self.assertEqual(diff.messages_appended, [])
|
||||
|
||||
def test_diff_tool_call_resolved_to_result(self):
|
||||
"""Test diff detects tool call resolution to actual result."""
|
||||
msg1 = {"role": "user", "content": "What's the weather?"}
|
||||
msg2 = {
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{"id": "call_123", "function": {"name": "get_weather", "arguments": "{}"}}
|
||||
],
|
||||
}
|
||||
tool_in_progress = {"role": "tool", "content": "IN_PROGRESS", "tool_call_id": "call_123"}
|
||||
tool_resolved = {
|
||||
"role": "tool",
|
||||
"content": '{"temperature": 72}',
|
||||
"tool_call_id": "call_123",
|
||||
}
|
||||
|
||||
context1 = LLMContext(messages=[msg1, msg2, tool_in_progress])
|
||||
context2 = LLMContext(messages=[msg1, msg2, tool_resolved])
|
||||
|
||||
diff = context1.diff(context2)
|
||||
self.assertTrue(diff.has_changes())
|
||||
self.assertEqual(diff.tool_calls_resolved, ["call_123"])
|
||||
# Note: the tool message content changed, so history is edited
|
||||
self.assertTrue(diff.history_edited)
|
||||
|
||||
def test_diff_tool_call_resolved_to_completed(self):
|
||||
"""Test diff detects tool call resolution to COMPLETED."""
|
||||
msg1 = {"role": "user", "content": "Do something"}
|
||||
tool_in_progress = {"role": "tool", "content": "IN_PROGRESS", "tool_call_id": "call_456"}
|
||||
tool_completed = {"role": "tool", "content": "COMPLETED", "tool_call_id": "call_456"}
|
||||
|
||||
context1 = LLMContext(messages=[msg1, tool_in_progress])
|
||||
context2 = LLMContext(messages=[msg1, tool_completed])
|
||||
|
||||
diff = context1.diff(context2)
|
||||
self.assertTrue(diff.has_changes())
|
||||
self.assertEqual(diff.tool_calls_resolved, ["call_456"])
|
||||
|
||||
def test_diff_tool_call_resolved_to_cancelled(self):
|
||||
"""Test diff detects tool call resolution to CANCELLED."""
|
||||
msg1 = {"role": "user", "content": "Do something"}
|
||||
tool_in_progress = {"role": "tool", "content": "IN_PROGRESS", "tool_call_id": "call_789"}
|
||||
tool_cancelled = {"role": "tool", "content": "CANCELLED", "tool_call_id": "call_789"}
|
||||
|
||||
context1 = LLMContext(messages=[msg1, tool_in_progress])
|
||||
context2 = LLMContext(messages=[msg1, tool_cancelled])
|
||||
|
||||
diff = context1.diff(context2)
|
||||
self.assertTrue(diff.has_changes())
|
||||
self.assertEqual(diff.tool_calls_resolved, ["call_789"])
|
||||
|
||||
def test_diff_tool_call_still_in_progress(self):
|
||||
"""Test diff does not report tool call as resolved if still IN_PROGRESS."""
|
||||
msg1 = {"role": "user", "content": "Do something"}
|
||||
tool_in_progress = {"role": "tool", "content": "IN_PROGRESS", "tool_call_id": "call_123"}
|
||||
|
||||
context1 = LLMContext(messages=[msg1, tool_in_progress])
|
||||
context2 = LLMContext(messages=[msg1, tool_in_progress])
|
||||
|
||||
diff = context1.diff(context2)
|
||||
self.assertFalse(diff.has_changes())
|
||||
self.assertEqual(diff.tool_calls_resolved, [])
|
||||
|
||||
def test_diff_tool_choice_changed(self):
|
||||
"""Test diff detects tool_choice changes."""
|
||||
msg1 = {"role": "user", "content": "Hello"}
|
||||
|
||||
context1 = LLMContext(messages=[msg1], tool_choice="auto")
|
||||
context2 = LLMContext(messages=[msg1], tool_choice="none")
|
||||
|
||||
diff = context1.diff(context2)
|
||||
self.assertTrue(diff.has_changes())
|
||||
self.assertTrue(diff.tool_choice_changed)
|
||||
|
||||
def test_diff_tool_choice_unchanged(self):
|
||||
"""Test diff reports no change when tool_choice is the same."""
|
||||
msg1 = {"role": "user", "content": "Hello"}
|
||||
|
||||
context1 = LLMContext(messages=[msg1], tool_choice="auto")
|
||||
context2 = LLMContext(messages=[msg1], tool_choice="auto")
|
||||
|
||||
diff = context1.diff(context2)
|
||||
self.assertFalse(diff.has_changes())
|
||||
self.assertFalse(diff.tool_choice_changed)
|
||||
|
||||
def test_diff_empty_contexts(self):
|
||||
"""Test diff of two empty contexts returns no changes."""
|
||||
context1 = LLMContext()
|
||||
context2 = LLMContext()
|
||||
|
||||
diff = context1.diff(context2)
|
||||
self.assertFalse(diff.has_changes())
|
||||
|
||||
|
||||
class TestLLMContextDiffWithTools(unittest.TestCase):
|
||||
"""Tests for LLMContext.diff() with tools configuration changes."""
|
||||
|
||||
def _create_tools_schema(self, tool_names: list[str]) -> "ToolsSchema":
|
||||
"""Helper to create a ToolsSchema with named tools."""
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
|
||||
tools = [
|
||||
FunctionSchema(name=name, description=f"Test {name}", properties={}, required=[])
|
||||
for name in tool_names
|
||||
]
|
||||
return ToolsSchema(standard_tools=tools)
|
||||
|
||||
def test_diff_tools_added_from_not_given(self):
|
||||
"""Test diff detects tools being added when self has no tools."""
|
||||
from pipecat.processors.aggregators.llm_context import NOT_GIVEN
|
||||
|
||||
msg1 = {"role": "user", "content": "Hello"}
|
||||
tools = self._create_tools_schema(["get_weather", "get_time"])
|
||||
|
||||
context1 = LLMContext(messages=[msg1], tools=NOT_GIVEN)
|
||||
context2 = LLMContext(messages=[msg1], tools=tools)
|
||||
|
||||
diff = context1.diff(context2)
|
||||
self.assertTrue(diff.has_changes())
|
||||
self.assertEqual(sorted(diff.tools_diff.standard_tools_added), ["get_time", "get_weather"])
|
||||
self.assertEqual(diff.tools_diff.standard_tools_removed, [])
|
||||
|
||||
def test_diff_tools_removed_to_not_given(self):
|
||||
"""Test diff detects tools being removed when other has no tools."""
|
||||
from pipecat.processors.aggregators.llm_context import NOT_GIVEN
|
||||
|
||||
msg1 = {"role": "user", "content": "Hello"}
|
||||
tools = self._create_tools_schema(["get_weather", "get_time"])
|
||||
|
||||
context1 = LLMContext(messages=[msg1], tools=tools)
|
||||
context2 = LLMContext(messages=[msg1], tools=NOT_GIVEN)
|
||||
|
||||
diff = context1.diff(context2)
|
||||
self.assertTrue(diff.has_changes())
|
||||
self.assertEqual(diff.tools_diff.standard_tools_added, [])
|
||||
self.assertEqual(
|
||||
sorted(diff.tools_diff.standard_tools_removed), ["get_time", "get_weather"]
|
||||
)
|
||||
|
||||
def test_diff_both_not_given(self):
|
||||
"""Test diff returns None tools_diff when both have no tools."""
|
||||
from pipecat.processors.aggregators.llm_context import NOT_GIVEN
|
||||
|
||||
msg1 = {"role": "user", "content": "Hello"}
|
||||
|
||||
context1 = LLMContext(messages=[msg1], tools=NOT_GIVEN)
|
||||
context2 = LLMContext(messages=[msg1], tools=NOT_GIVEN)
|
||||
|
||||
diff = context1.diff(context2)
|
||||
self.assertFalse(diff.has_changes())
|
||||
self.assertFalse(diff.tools_diff.has_changes())
|
||||
|
||||
def test_diff_tools_modified(self):
|
||||
"""Test diff detects tool modification via ToolsSchema.diff()."""
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
|
||||
msg1 = {"role": "user", "content": "Hello"}
|
||||
|
||||
tool_v1 = FunctionSchema(
|
||||
name="get_weather",
|
||||
description="Get weather v1",
|
||||
properties={"location": {"type": "string"}},
|
||||
required=["location"],
|
||||
)
|
||||
tool_v2 = FunctionSchema(
|
||||
name="get_weather",
|
||||
description="Get weather v2",
|
||||
properties={"city": {"type": "string"}},
|
||||
required=["city"],
|
||||
)
|
||||
|
||||
context1 = LLMContext(messages=[msg1], tools=ToolsSchema(standard_tools=[tool_v1]))
|
||||
context2 = LLMContext(messages=[msg1], tools=ToolsSchema(standard_tools=[tool_v2]))
|
||||
|
||||
diff = context1.diff(context2)
|
||||
self.assertTrue(diff.has_changes())
|
||||
self.assertTrue(diff.tools_diff.standard_tools_modified)
|
||||
|
||||
def test_diff_tools_unchanged(self):
|
||||
"""Test diff returns None tools_diff when tools are identical."""
|
||||
msg1 = {"role": "user", "content": "Hello"}
|
||||
tools1 = self._create_tools_schema(["get_weather"])
|
||||
tools2 = self._create_tools_schema(["get_weather"])
|
||||
|
||||
context1 = LLMContext(messages=[msg1], tools=tools1)
|
||||
context2 = LLMContext(messages=[msg1], tools=tools2)
|
||||
|
||||
diff = context1.diff(context2)
|
||||
self.assertFalse(diff.has_changes())
|
||||
self.assertFalse(diff.tools_diff.has_changes())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -18,6 +18,7 @@ from pipecat.frames.frames import (
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesAppendFrame,
|
||||
LLMMessagesTransformFrame,
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMRunFrame,
|
||||
LLMTextFrame,
|
||||
@@ -147,6 +148,52 @@ class TestLLMUserAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
)
|
||||
assert context.messages[0]["content"] == "Hi there!"
|
||||
|
||||
async def test_llm_messages_transform(self):
|
||||
context = LLMContext()
|
||||
# Set up initial messages
|
||||
context.set_messages(
|
||||
[
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
{"role": "user", "content": "How are you?"},
|
||||
]
|
||||
)
|
||||
|
||||
pipeline = Pipeline([LLMUserAggregator(context)])
|
||||
|
||||
# Transform that keeps only user messages
|
||||
def keep_user_messages(messages):
|
||||
return [m for m in messages if m["role"] == "user"]
|
||||
|
||||
frames_to_send = [LLMMessagesTransformFrame(transform=keep_user_messages)]
|
||||
await run_test(
|
||||
pipeline,
|
||||
frames_to_send=frames_to_send,
|
||||
)
|
||||
assert len(context.messages) == 2
|
||||
assert context.messages[0]["content"] == "Hello"
|
||||
assert context.messages[1]["content"] == "How are you?"
|
||||
|
||||
async def test_llm_messages_transform_run(self):
|
||||
context = LLMContext()
|
||||
# Set up initial messages
|
||||
context.set_messages([{"role": "user", "content": "Hello"}])
|
||||
|
||||
pipeline = Pipeline([LLMUserAggregator(context)])
|
||||
|
||||
# Transform that modifies the content
|
||||
def uppercase_content(messages):
|
||||
return [{"role": m["role"], "content": m["content"].upper()} for m in messages]
|
||||
|
||||
frames_to_send = [LLMMessagesTransformFrame(transform=uppercase_content, run_llm=True)]
|
||||
expected_down_frames = [LLMContextFrame]
|
||||
await run_test(
|
||||
pipeline,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
assert context.messages[0]["content"] == "HELLO"
|
||||
|
||||
async def test_default_user_turn_strategies(self):
|
||||
context = LLMContext()
|
||||
user_aggregator = LLMUserAggregator(context)
|
||||
|
||||
184
tests/test_tools_schema.py
Normal file
184
tests/test_tools_schema.py
Normal file
@@ -0,0 +1,184 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import unittest
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import AdapterType, ToolsSchema, ToolsSchemaDiff
|
||||
|
||||
|
||||
class TestToolsSchemaDiff(unittest.TestCase):
|
||||
"""Tests for the ToolsSchemaDiff dataclass."""
|
||||
|
||||
def test_has_changes_empty(self):
|
||||
"""Test has_changes returns False for empty diff."""
|
||||
diff = ToolsSchemaDiff()
|
||||
self.assertFalse(diff.has_changes())
|
||||
|
||||
def test_has_changes_with_added(self):
|
||||
"""Test has_changes returns True when tools are added."""
|
||||
diff = ToolsSchemaDiff(standard_tools_added=["tool1"])
|
||||
self.assertTrue(diff.has_changes())
|
||||
|
||||
def test_has_changes_with_removed(self):
|
||||
"""Test has_changes returns True when tools are removed."""
|
||||
diff = ToolsSchemaDiff(standard_tools_removed=["tool1"])
|
||||
self.assertTrue(diff.has_changes())
|
||||
|
||||
def test_has_changes_with_modified(self):
|
||||
"""Test has_changes returns True when tools are modified."""
|
||||
diff = ToolsSchemaDiff(standard_tools_modified=True)
|
||||
self.assertTrue(diff.has_changes())
|
||||
|
||||
def test_has_changes_with_custom_changed(self):
|
||||
"""Test has_changes returns True when custom tools changed."""
|
||||
diff = ToolsSchemaDiff(custom_tools_changed=True)
|
||||
self.assertTrue(diff.has_changes())
|
||||
|
||||
|
||||
class TestToolsSchemaDiffMethod(unittest.TestCase):
|
||||
"""Tests for the ToolsSchema.diff() method."""
|
||||
|
||||
def _create_function_schema(
|
||||
self, name: str, description: str = "Test function", properties: dict = None
|
||||
) -> FunctionSchema:
|
||||
"""Helper to create a FunctionSchema."""
|
||||
return FunctionSchema(
|
||||
name=name,
|
||||
description=description,
|
||||
properties=properties or {},
|
||||
required=[],
|
||||
)
|
||||
|
||||
def test_diff_identical_schemas(self):
|
||||
"""Test diff of two identical schemas returns no changes."""
|
||||
tool1 = self._create_function_schema("get_weather")
|
||||
schema1 = ToolsSchema(standard_tools=[tool1])
|
||||
schema2 = ToolsSchema(standard_tools=[self._create_function_schema("get_weather")])
|
||||
|
||||
diff = schema1.diff(schema2)
|
||||
self.assertFalse(diff.has_changes())
|
||||
self.assertEqual(diff.standard_tools_added, [])
|
||||
self.assertEqual(diff.standard_tools_removed, [])
|
||||
self.assertFalse(diff.standard_tools_modified)
|
||||
self.assertFalse(diff.custom_tools_changed)
|
||||
|
||||
def test_diff_tool_added(self):
|
||||
"""Test diff detects added tools."""
|
||||
tool1 = self._create_function_schema("get_weather")
|
||||
tool2 = self._create_function_schema("get_time")
|
||||
|
||||
schema1 = ToolsSchema(standard_tools=[tool1])
|
||||
schema2 = ToolsSchema(standard_tools=[tool1, tool2])
|
||||
|
||||
diff = schema1.diff(schema2)
|
||||
self.assertTrue(diff.has_changes())
|
||||
self.assertEqual(diff.standard_tools_added, ["get_time"])
|
||||
self.assertEqual(diff.standard_tools_removed, [])
|
||||
self.assertFalse(diff.standard_tools_modified)
|
||||
|
||||
def test_diff_tool_removed(self):
|
||||
"""Test diff detects removed tools."""
|
||||
tool1 = self._create_function_schema("get_weather")
|
||||
tool2 = self._create_function_schema("get_time")
|
||||
|
||||
schema1 = ToolsSchema(standard_tools=[tool1, tool2])
|
||||
schema2 = ToolsSchema(standard_tools=[tool1])
|
||||
|
||||
diff = schema1.diff(schema2)
|
||||
self.assertTrue(diff.has_changes())
|
||||
self.assertEqual(diff.standard_tools_added, [])
|
||||
self.assertEqual(diff.standard_tools_removed, ["get_time"])
|
||||
self.assertFalse(diff.standard_tools_modified)
|
||||
|
||||
def test_diff_tool_modified(self):
|
||||
"""Test diff detects modified tools (same name, different definition)."""
|
||||
tool1_v1 = self._create_function_schema(
|
||||
"get_weather", description="Get weather v1", properties={"location": {"type": "string"}}
|
||||
)
|
||||
tool1_v2 = self._create_function_schema(
|
||||
"get_weather",
|
||||
description="Get weather v2",
|
||||
properties={"city": {"type": "string"}},
|
||||
)
|
||||
|
||||
schema1 = ToolsSchema(standard_tools=[tool1_v1])
|
||||
schema2 = ToolsSchema(standard_tools=[tool1_v2])
|
||||
|
||||
diff = schema1.diff(schema2)
|
||||
self.assertTrue(diff.has_changes())
|
||||
self.assertEqual(diff.standard_tools_added, [])
|
||||
self.assertEqual(diff.standard_tools_removed, [])
|
||||
self.assertTrue(diff.standard_tools_modified)
|
||||
|
||||
def test_diff_multiple_changes(self):
|
||||
"""Test diff with multiple types of changes."""
|
||||
tool_keep = self._create_function_schema("keep_tool")
|
||||
tool_remove = self._create_function_schema("remove_tool")
|
||||
tool_add = self._create_function_schema("add_tool")
|
||||
|
||||
schema1 = ToolsSchema(standard_tools=[tool_keep, tool_remove])
|
||||
schema2 = ToolsSchema(standard_tools=[tool_keep, tool_add])
|
||||
|
||||
diff = schema1.diff(schema2)
|
||||
self.assertTrue(diff.has_changes())
|
||||
self.assertEqual(diff.standard_tools_added, ["add_tool"])
|
||||
self.assertEqual(diff.standard_tools_removed, ["remove_tool"])
|
||||
|
||||
def test_diff_empty_schemas(self):
|
||||
"""Test diff of two empty schemas returns no changes."""
|
||||
schema1 = ToolsSchema(standard_tools=[])
|
||||
schema2 = ToolsSchema(standard_tools=[])
|
||||
|
||||
diff = schema1.diff(schema2)
|
||||
self.assertFalse(diff.has_changes())
|
||||
|
||||
def test_diff_custom_tools_changed(self):
|
||||
"""Test diff detects custom tools changes."""
|
||||
tool1 = self._create_function_schema("get_weather")
|
||||
custom1 = {AdapterType.GEMINI: [{"name": "search"}]}
|
||||
custom2 = {AdapterType.GEMINI: [{"name": "search_v2"}]}
|
||||
|
||||
schema1 = ToolsSchema(standard_tools=[tool1], custom_tools=custom1)
|
||||
schema2 = ToolsSchema(standard_tools=[tool1], custom_tools=custom2)
|
||||
|
||||
diff = schema1.diff(schema2)
|
||||
self.assertTrue(diff.has_changes())
|
||||
self.assertTrue(diff.custom_tools_changed)
|
||||
# Standard tools unchanged
|
||||
self.assertEqual(diff.standard_tools_added, [])
|
||||
self.assertEqual(diff.standard_tools_removed, [])
|
||||
self.assertFalse(diff.standard_tools_modified)
|
||||
|
||||
def test_diff_custom_tools_added(self):
|
||||
"""Test diff detects custom tools being added."""
|
||||
tool1 = self._create_function_schema("get_weather")
|
||||
|
||||
schema1 = ToolsSchema(standard_tools=[tool1])
|
||||
schema2 = ToolsSchema(
|
||||
standard_tools=[tool1], custom_tools={AdapterType.GEMINI: [{"name": "search"}]}
|
||||
)
|
||||
|
||||
diff = schema1.diff(schema2)
|
||||
self.assertTrue(diff.has_changes())
|
||||
self.assertTrue(diff.custom_tools_changed)
|
||||
|
||||
def test_diff_custom_tools_removed(self):
|
||||
"""Test diff detects custom tools being removed."""
|
||||
tool1 = self._create_function_schema("get_weather")
|
||||
|
||||
schema1 = ToolsSchema(
|
||||
standard_tools=[tool1], custom_tools={AdapterType.GEMINI: [{"name": "search"}]}
|
||||
)
|
||||
schema2 = ToolsSchema(standard_tools=[tool1])
|
||||
|
||||
diff = schema1.diff(schema2)
|
||||
self.assertTrue(diff.has_changes())
|
||||
self.assertTrue(diff.custom_tools_changed)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user