Compare commits

...

12 Commits

Author SHA1 Message Date
Paul Kompfner
b776049535 Tweak CHANGELOG entry 2026-02-06 09:52:08 -05:00
Paul Kompfner
3aa403a16f Add a LLMMessagesTransformFrame to facilitate programmatically editing context in a frame-based way.
The previous approach required the caller to directly grab a reference to the context object, grab a "snapshot" of its messages *at that point in time*, transform the messages, and then push an `LLMMessagesUpdateFrame` with the transformed messages. This approach can lead to problems: what if there had already been a change to the context queued in the pipeline? The transformed messages would simply overwrite it without consideration.
2026-02-05 16:37:45 -05:00
Paul Kompfner
50dace147d Remove stray debugging print statements 2026-02-05 12:37:28 -05:00
Paul Kompfner
e039abd290 Add README for changes to make Gemini Live work with Pipecat Flows 2026-02-05 11:27:57 -05:00
Paul Kompfner
63f88c0add Deal with some minor code TODOs 2026-02-04 15:26:52 -05:00
Paul Kompfner
2aef572e38 Add to the Gemini Live function calling example runtime tool addition with LLMSetToolsFrame 2026-02-04 15:05:09 -05:00
Paul Kompfner
ba3100be0d Another change to support Gemini Live in Pipecat Flows: rather than strip function call and response messages out of context before sending to Gemini Live when seeding conversation history, which we were doing to sidestep a seeming Gemini Live limitation (see https://stackoverflow.com/a/79851394), convert them to regular text messages with special formatting 2026-02-04 11:44:58 -05:00
Paul Kompfner
9e65e77095 Another change to support Gemini Live in Pipecat Flows: now even in the case of programmatically-appended context messages, do a full reconnection—I've tried many things to get "live updates" to context working reliably and have been unable to (Gemini code comments also warn against doing "live updates" after the initial context seeding prior to starting audio input). 2026-02-03 21:22:24 -05:00
Paul Kompfner
3183f9c077 Revert "Another change to support Gemini Live in Pipecat Flows: avoid triggering a model response twice when there is a Flows node transition that returns a value from the transition function and loads new context messages in the new node"
This reverts commit 90e6f0dca8.
2026-02-03 17:02:31 -05:00
Paul Kompfner
90e6f0dca8 Another change to support Gemini Live in Pipecat Flows: avoid triggering a model response twice when there is a Flows node transition that returns a value from the transition function and loads new context messages in the new node 2026-02-03 16:56:34 -05:00
Paul Kompfner
44b917c546 Another change to support Gemini Live in Pipecat Flows: if the only change to the context is newly-appended messages, send them to the server.
This requires us to distinguish between newly-appended "bookkeeping" messages that just reflect what Gemini Live already said, and messages that were programmatically inserted, such as from the transition to the new Pipecat Flows node.

This change makes it so that using `LLMMessagesAppendFrame` will have the desired effect, of updating the Gemini Live conversation.
2026-02-03 15:45:00 -05:00
Paul Kompfner
c1ac1a6326 Changes to support Gemini Live in Pipecat Flows:
- Detect when a newly-received context warrants a reconnection to the Gemini Live API; we need to reconnect in order to re-seed new conversation history or swap out the current set of tools. This reconnection occurs when Pipecat Flows transitions between conversational nodes, as context has been edited and/or tools added/removed.
- Strip function call and response messages out of context before sending to Gemini Live when seeding conversation history, to sidestep a seeming Gemini Live limitation (see https://stackoverflow.com/a/79851394)
2026-02-02 16:34:12 -05:00
13 changed files with 1069 additions and 44 deletions

View File

@@ -0,0 +1,3 @@
- Added `LLMMessagesTransformFrame` to facilitate programmatically editing context in a frame-based way.
The previous approach required the caller to directly grab a reference to the context object, grab a "snapshot" of its messages _at that point in time_, transform the messages, and then push an `LLMMessagesUpdateFrame` with the transformed messages. This approach can lead to problems: what if there had already been a change to the context queued in the pipeline? The transformed messages would simply overwrite it without consideration.

1
changelog/3620.added.md Normal file
View File

@@ -0,0 +1 @@
- Added support to Gemini Live (`GeminiLiveLLMService`) for programmatically swapping tools or editing context at runtime; now you can use `LLMMessagesAppendFrame`, `LLMMessagesUpdateFrame`, `LLMMessagesTransformFrame`, and `LLMSetToolsFrame` with Gemini Live, like you would with text-to-text services. Note that this new functionality only works if you're using `LLMContext` and `LLMContextAggregatorPair` rather than the deprecated `OpenAILLMContext` and associated aggregators.

View File

@@ -57,6 +57,10 @@ async def fetch_weather_from_api(params: FunctionCallParams):
)
async def fetch_restaurant_recommendation(params: FunctionCallParams):
await params.result_callback({"name": "The Golden Dragon"})
async def get_news(params: FunctionCallParams):
await params.result_callback(
{
@@ -69,10 +73,6 @@ async def get_news(params: FunctionCallParams):
)
async def fetch_restaurant_recommendation(params: FunctionCallParams):
await params.result_callback({"name": "The Golden Dragon"})
weather_function = FunctionSchema(
name="get_current_weather",
description="Get the current weather",
@@ -90,13 +90,6 @@ weather_function = FunctionSchema(
required=["location", "format"],
)
get_news_function = FunctionSchema(
name="get_news",
description="Get the current news.",
properties={},
required=[],
)
restaurant_function = FunctionSchema(
name="get_restaurant_recommendation",
description="Get a restaurant recommendation",
@@ -109,6 +102,13 @@ restaurant_function = FunctionSchema(
required=["location"],
)
get_news_function = FunctionSchema(
name="get_news",
description="Get the current news.",
properties={},
required=[],
)
# Create tools schema
tools = ToolsSchema(standard_tools=[weather_function, restaurant_function])
@@ -215,8 +215,9 @@ Remember, your responses should be short. Just one or two sentences, usually. Re
# Kick off the conversation.
await task.queue_frames([LLMRunFrame()])
# Add a new tool at runtime after a delay.
# Add a new tool (get_news) at runtime after a delay
await asyncio.sleep(15)
logger.info(f"Adding new tool get_news at runtime...")
new_tools = ToolsSchema(
standard_tools=[weather_function, restaurant_function, get_news_function]
)

View File

@@ -5,6 +5,7 @@
#
import asyncio
import os
from datetime import datetime
@@ -15,7 +16,7 @@ from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import AdapterType, ToolsSchema
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.audio.vad.vad_analyzer import VADParams
from pipecat.frames.frames import LLMRunFrame
from pipecat.frames.frames import LLMRunFrame, LLMSetToolsFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
@@ -51,6 +52,18 @@ async def fetch_restaurant_recommendation(params: FunctionCallParams):
await params.result_callback({"name": "The Golden Dragon"})
async def get_news(params: FunctionCallParams):
    """Answer the ``get_news`` tool call with a fixed list of headlines.

    Args:
        params: Function call parameters; the result is delivered through
            ``params.result_callback``.
    """
    headlines = [
        "Massive UFO currently hovering above New York City",
        "Stock markets reach all-time highs",
        "Living dinosaur species discovered in the Amazon rainforest",
    ]
    await params.result_callback({"news": headlines})
system_instruction = """
You are a helpful assistant who can answer questions and use tools.
@@ -109,6 +122,12 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
},
required=["location"],
)
get_news_function = FunctionSchema(
name="get_news",
description="Get the current news.",
properties={},
required=[],
)
search_tool = {"google_search": {}}
# KNOWN ISSUE: If using GeminiVertexLiveLLMService, it appears
# you cannot use the "google_search" tool alongside other tools.
@@ -126,6 +145,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
llm.register_function("get_current_weather", fetch_weather_from_api)
llm.register_function("get_restaurant_recommendation", fetch_restaurant_recommendation)
llm.register_function("get_news", get_news)
# You can provide the system instructions and tools in the context rather
# than as arguments to GeminiLiveLLMService, but note that doing so will
@@ -174,6 +194,14 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
# Kick off the conversation.
await task.queue_frames([LLMRunFrame()])
# Add a new tool (get_news) at runtime after a delay
await asyncio.sleep(15)
logger.info(f"Adding new tool get_news at runtime...")
new_tools = ToolsSchema(
standard_tools=[weather_function, restaurant_function, get_news_function]
)
await task.queue_frames([LLMSetToolsFrame(tools=new_tools)])
@transport.event_handler("on_client_disconnected")
async def on_client_disconnected(transport, client):
logger.info(f"Client disconnected")

View File

@@ -10,6 +10,7 @@ This module provides schemas for managing both standardized function tools
and custom adapter-specific tools in the Pipecat framework.
"""
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional
@@ -17,6 +18,36 @@ from pipecat.adapters.schemas.direct_function import DirectFunction, DirectFunct
from pipecat.adapters.schemas.function_schema import FunctionSchema
@dataclass
class ToolsSchemaDiff:
    """Summary of how one ToolsSchema differs from another.

    Parameters:
        standard_tools_added: Names of standard tools present only in the "after" schema.
        standard_tools_removed: Names of standard tools present only in the "before" schema.
        standard_tools_modified: True if a standard tool kept its name but changed definition.
        custom_tools_changed: True if the custom_tools dictionaries are not equal.
    """

    standard_tools_added: List[str] = field(default_factory=list)
    standard_tools_removed: List[str] = field(default_factory=list)
    standard_tools_modified: bool = False
    custom_tools_changed: bool = False

    def has_changes(self) -> bool:
        """Report whether this diff records any difference at all.

        Returns:
            True if at least one field is truthy, False otherwise.
        """
        indicators = (
            self.standard_tools_added,
            self.standard_tools_removed,
            self.standard_tools_modified,
            self.custom_tools_changed,
        )
        return any(indicators)
class AdapterType(Enum):
"""Supported adapter types for custom tools.
@@ -92,3 +123,44 @@ class ToolsSchema:
value: Dictionary mapping adapter types to their custom tool definitions.
"""
self._custom_tools = value
def diff(self, other: "ToolsSchema") -> ToolsSchemaDiff:
    """Describe what changed going from this ToolsSchema to another.

    Args:
        other: The ToolsSchema to compare against (the "after" state).

    Returns:
        ToolsSchemaDiff containing the differences between self and other.
    """
    # Index both sides by tool name so set operations can be used.
    before = {tool.name: tool for tool in self._standard_tools}
    after = {tool.name: tool for tool in other._standard_tools}

    added = sorted(after.keys() - before.keys())
    removed = sorted(before.keys() - after.keys())

    # A tool counts as modified when it keeps its name but its full
    # to_default_dict() representation differs.
    modified = any(
        before[name].to_default_dict() != after[name].to_default_dict()
        for name in before.keys() & after.keys()
    )

    return ToolsSchemaDiff(
        standard_tools_added=added,
        standard_tools_removed=removed,
        standard_tools_modified=modified,
        custom_tools_changed=self._custom_tools != other._custom_tools,
    )

View File

@@ -53,19 +53,32 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
"""Get the identifier used in LLMSpecificMessage instances for Google."""
return "google"
def get_llm_invocation_params(self, context: LLMContext) -> GeminiLLMInvocationParams:
def get_llm_invocation_params(
self, context: LLMContext, *, convert_function_messages_to_text: bool = False
) -> GeminiLLMInvocationParams:
"""Get Gemini-specific LLM invocation parameters from a universal LLM context.
Args:
context: The LLM context containing messages, tools, etc.
convert_function_messages_to_text: If True, convert function_call and function_response
parts to specially-formatted text messages. This is needed for Gemini Live
(at least with "models/gemini-2.5-flash-native-audio-preview-12-2025", the
default at the time of this writing) which cannot handle function-call-related
messages when initializing conversation history.
See https://stackoverflow.com/a/79851394.
Returns:
Dictionary of parameters for Gemini's API.
"""
messages = self._from_universal_context_messages(self.get_messages(context))
converted = self._from_universal_context_messages(self.get_messages(context))
messages = converted.messages
if convert_function_messages_to_text:
messages = self._convert_function_messages_to_text(messages)
return {
"system_instruction": messages.system_instruction,
"messages": messages.messages,
"system_instruction": converted.system_instruction,
"messages": messages,
# NOTE: LLMContext's tools are guaranteed to be a ToolsSchema (or NOT_GIVEN)
"tools": self.from_standard_tools(context.tools),
}
@@ -668,3 +681,43 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
return True
return False
def _convert_function_messages_to_text(self, messages: List[Content]) -> List[Content]:
"""Convert function_call and function_response parts to text messages.

Each part with a truthy ``function_call`` attribute is replaced by a text
Part of the form "[Historical function call (...): name(args)]", and each
part with a truthy ``function_response`` attribute by a text Part of the
form "[Historical function result (...): name returned response]". All
other parts are carried over unchanged. Rewritten messages are new Content
objects that keep the original role.

Args:
messages: List of Content messages to process.
Returns:
List of Content messages with function-related parts converted to text.
"""
converted_messages = []
for msg in messages:
if msg.parts:
converted_parts = []
for part in msg.parts:
# getattr with a default tolerates parts that lack the attribute.
if func_call := getattr(part, "function_call", None):
# Convert function call to text
# Missing or empty args are rendered as an empty JSON object.
args_str = json.dumps(func_call.args) if func_call.args else "{}"
text = (
f"[Historical function call (for context only, not a template): "
f"{func_call.name}({args_str})]"
)
converted_parts.append(Part(text=text))
elif func_response := getattr(part, "function_response", None):
# Convert function response to text
# Missing or empty responses are rendered as an empty JSON object.
response_str = (
json.dumps(func_response.response) if func_response.response else "{}"
)
text = (
f"[Historical function result (for context only): "
f"{func_response.name} returned {response_str}]"
)
converted_parts.append(Part(text=text))
else:
# Not function-related: keep the part as-is.
converted_parts.append(part)
# NOTE(review): indentation is not visible in this view, so it is unclear
# whether the "else" below pairs with "if converted_parts" or with
# "if msg.parts" (i.e. whether messages without parts are passed through
# or dropped) - confirm against the real file.
if converted_parts:
converted_messages.append(Content(role=msg.role, parts=converted_parts))
else:
converted_messages.append(msg)
return converted_messages

View File

@@ -38,7 +38,7 @@ from pipecat.utils.time import nanoseconds_to_str
from pipecat.utils.utils import obj_count, obj_id
if TYPE_CHECKING:
from pipecat.processors.aggregators.llm_context import LLMContext, NotGiven
from pipecat.processors.aggregators.llm_context import LLMContext, LLMContextMessage, NotGiven
from pipecat.processors.frame_processor import FrameProcessor
@@ -667,9 +667,16 @@ class LLMContextFrame(Frame):
Parameters:
context: The LLM context containing messages, tools, and configuration.
messages_programmatically_edited: Whether the context messages were
programmatically edited (e.g. via LLMMessagesAppendFrame or
LLMMessagesUpdateFrame) since the last context frame was pushed.
This is used by speech-to-speech LLM services (like Gemini Live) to
distinguish between messages that originated from the LLM output
itself vs. messages that were externally injected.
"""
context: "LLMContext"
messages_programmatically_edited: bool = False
@dataclass
@@ -822,6 +829,25 @@ class LLMMessagesUpdateFrame(DataFrame):
run_llm: Optional[bool] = None
@dataclass
class LLMMessagesTransformFrame(DataFrame):
    """Frame carrying a function that rewrites the current context's LLM messages.

    The transform is given the context's current list of LLM messages and
    returns the list that should replace it.

    Only compatible with LLMContext and not the deprecated OpenAILLMContext.

    Parameters:
        transform: A function that takes a list of messages and returns a
            modified list.
        run_llm: Whether the context update should be sent to the LLM.
    """

    transform: Callable[[List["LLMContextMessage"]], List["LLMContextMessage"]]
    run_llm: Optional[bool] = None
@dataclass
class LLMSetToolsFrame(DataFrame):
"""Frame containing tools for LLM function calling.

View File

@@ -18,8 +18,8 @@ import asyncio
import base64
import io
import wave
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, List, Optional, TypeAlias, Union
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeAlias, Union
from loguru import logger
from openai._types import NOT_GIVEN as OPEN_AI_NOT_GIVEN
@@ -30,7 +30,7 @@ from openai.types.chat import (
)
from PIL import Image
from pipecat.adapters.schemas.tools_schema import AdapterType, ToolsSchema
from pipecat.adapters.schemas.tools_schema import AdapterType, ToolsSchema, ToolsSchemaDiff
from pipecat.frames.frames import AudioRawFrame
if TYPE_CHECKING:
@@ -47,6 +47,39 @@ NOT_GIVEN = OPEN_AI_NOT_GIVEN
NotGiven = OpenAINotGiven
@dataclass
class LLMContextDiff:
    """Represents the differences between two LLMContext instances.

    Parameters:
        messages_appended: New messages appended at the end. Empty if history_edited is True.
        history_edited: True if earlier messages were changed, inserted, or removed.
        tool_calls_resolved: List of tool_call_ids that changed from "IN_PROGRESS" to a result.
        tools_diff: Differences in tools configuration. This is always a ToolsSchemaDiff
            (never None); an instance whose has_changes() is False means no tool changes.
        tool_choice_changed: True if the tool_choice setting differs.
    """

    messages_appended: List["LLMContextMessage"] = field(default_factory=list)
    history_edited: bool = False
    tool_calls_resolved: List[str] = field(default_factory=list)
    tools_diff: ToolsSchemaDiff = field(default_factory=ToolsSchemaDiff)
    tool_choice_changed: bool = False

    def has_changes(self) -> bool:
        """Check if there are any differences.

        Returns:
            True if any field indicates a change, False otherwise.
        """
        return bool(
            self.messages_appended
            or self.history_edited
            or self.tool_calls_resolved
            or self.tools_diff.has_changes()
            or self.tool_choice_changed
        )
@dataclass
class LLMSpecificMessage:
"""A container for a context message that is specific to a particular LLM service.
@@ -341,6 +374,19 @@ class LLMContext:
"""
self._messages[:] = messages
def transform_messages(
    self, transform: Callable[[List[LLMContextMessage]], List[LLMContextMessage]]
):
    """Replace the context's messages with the output of ``transform``.

    Args:
        transform: A function that receives the current list of messages and
            returns the list of messages to store in the context.
    """
    self.set_messages(transform(self._messages))
def set_tools(self, tools: ToolsSchema | NotGiven = NOT_GIVEN):
"""Set the available tools for the LLM.
@@ -410,3 +456,137 @@ class LLMContext:
raise TypeError(
f"In LLMContext, tools must be a ToolsSchema object or NOT_GIVEN. Got type: {type(tools)}",
)
def diff(self, other: "LLMContext") -> LLMContextDiff:
    """Compare this context to another and return the differences.

    Compares self (the "before" state) to other (the "after" state) and
    identifies what has changed.

    Args:
        other: The LLMContext to compare against (the "after" state).

    Returns:
        LLMContextDiff containing the differences between self and other.
    """
    result = LLMContextDiff()

    before = self._messages
    after = other._messages
    before_len = len(before)
    after_len = len(after)

    # History counts as edited when messages were removed, or when any of
    # the first before_len messages no longer matches (modified, or
    # something was inserted in the middle).
    if after_len < before_len:
        result.history_edited = True
    else:
        # zip() stops at the shorter list, i.e. after before_len pairs.
        result.history_edited = any(
            not self._messages_equal(old, new) for old, new in zip(before, after)
        )

    # Only an untouched prefix lets the tail be reported as "appended".
    if not result.history_edited and after_len > before_len:
        result.messages_appended = after[before_len:]

    # Find resolved tool calls (IN_PROGRESS -> something else)
    result.tool_calls_resolved = self._find_resolved_tool_calls(other)

    # Compare tools
    result.tools_diff = self._compute_tools_diff(other)

    # Compare tool_choice. Two falsy values (e.g. both NOT_GIVEN) are treated
    # as unchanged without using ==, because that comparison was observed to
    # report a difference.
    # NOTE(review): presumably this happens when one context is a deep copy,
    # so its NOT_GIVEN is a distinct object compared by identity - confirm.
    if not self._tool_choice and not other._tool_choice:
        result.tool_choice_changed = False
    else:
        result.tool_choice_changed = self._tool_choice != other._tool_choice

    return result
def _messages_equal(self, msg1: LLMContextMessage, msg2: LLMContextMessage) -> bool:
    """Decide whether two context messages are equal.

    Args:
        msg1: First message to compare.
        msg2: Second message to compare.

    Returns:
        True if the messages are equal, False otherwise.
    """
    first_is_specific = isinstance(msg1, LLMSpecificMessage)
    second_is_specific = isinstance(msg2, LLMSpecificMessage)

    # One LLM-specific message and one standard message never match.
    if first_is_specific != second_is_specific:
        return False

    # Two LLM-specific messages match on both target LLM and payload.
    if first_is_specific:
        return msg1.llm == msg2.llm and msg1.message == msg2.message

    # Two standard messages (dicts) use plain equality.
    return msg1 == msg2
def _find_resolved_tool_calls(self, other: "LLMContext") -> List[str]:
"""Find tool calls that changed from IN_PROGRESS to a resolved state.
Args:
other: The context to compare against (the "after" state).
Returns:
List of tool_call_ids that were resolved.
"""
resolved: List[str] = []
# Build a map of tool_call_id -> content for "other" context
other_tool_contents: Dict[str, Any] = {}
for msg in other._messages:
if isinstance(msg, dict) and msg.get("role") == "tool":
tool_call_id = msg.get("tool_call_id")
if tool_call_id:
other_tool_contents[tool_call_id] = msg.get("content")
# Find tool messages in self that are IN_PROGRESS but resolved in other
for msg in self._messages:
if isinstance(msg, dict) and msg.get("role") == "tool":
if msg.get("content") == "IN_PROGRESS":
tool_call_id = msg.get("tool_call_id")
if tool_call_id and tool_call_id in other_tool_contents:
other_content = other_tool_contents[tool_call_id]
if other_content != "IN_PROGRESS":
resolved.append(tool_call_id)
return resolved
def _compute_tools_diff(self, other: "LLMContext") -> ToolsSchemaDiff:
    """Compute the difference in tools between self and other.

    Args:
        other: The context to compare against (the "after" state).

    Returns:
        A ToolsSchemaDiff in every case (never None). When neither context
        has a ToolsSchema the returned diff is empty.
    """
    before_tools = self._tools
    after_tools = other._tools
    had_tools = isinstance(before_tools, ToolsSchema)
    has_tools = isinstance(after_tools, ToolsSchema)

    if had_tools and has_tools:
        # Both have tools - use ToolsSchema.diff()
        return before_tools.diff(after_tools)

    if has_tools:
        # Tools were added (self has no ToolsSchema, other does)
        return ToolsSchemaDiff(
            standard_tools_added=[tool.name for tool in after_tools.standard_tools],
            custom_tools_changed=after_tools.custom_tools is not None,
        )

    if had_tools:
        # Tools were removed (self has a ToolsSchema, other does not)
        return ToolsSchemaDiff(
            standard_tools_removed=[tool.name for tool in before_tools.standard_tools],
            custom_tools_changed=before_tools.custom_tools is not None,
        )

    # Neither side has a ToolsSchema (both are NOT_GIVEN): nothing changed.
    return ToolsSchemaDiff()

View File

@@ -16,7 +16,7 @@ import json
import warnings
from abc import abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional, Set, Type
from typing import Any, Callable, Dict, List, Literal, Optional, Set, Type
from loguru import logger
@@ -40,6 +40,7 @@ from pipecat.frames.frames import (
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMMessagesAppendFrame,
LLMMessagesTransformFrame,
LLMMessagesUpdateFrame,
LLMRunFrame,
LLMSetToolChoiceFrame,
@@ -231,7 +232,16 @@ class LLMContextAggregator(FrameProcessor):
Returns:
LLMContextFrame containing the current context.
"""
return LLMContextFrame(context=self._context)
# Check if messages were programmatically edited since the last push.
# This flag is stored as a runtime attribute on the shared context
# object so that both user and assistant aggregators can see it.
messages_programmatically_edited = getattr(
self._context, "_pipecat_messages_programmatically_edited", False
)
return LLMContextFrame(
context=self._context,
messages_programmatically_edited=messages_programmatically_edited,
)
async def push_context_frame(self, direction: FrameDirection = FrameDirection.DOWNSTREAM):
"""Push a context frame in the specified direction.
@@ -241,6 +251,9 @@ class LLMContextAggregator(FrameProcessor):
"""
frame = self._get_context_frame()
await self.push_frame(frame, direction)
# Clear the programmatic edit flag after pushing, since the context
# frame now carries this information to downstream processors.
self._context._pipecat_messages_programmatically_edited = False
def add_messages(self, messages):
"""Add messages to the context.
@@ -258,6 +271,17 @@ class LLMContextAggregator(FrameProcessor):
"""
self._context.set_messages(messages)
def transform_messages(
    self, transform: Callable[[List[LLMContextMessage]], List[LLMContextMessage]]
):
    """Apply a transform function to the messages held by the context.

    Args:
        transform: A function that receives the current list of messages and
            returns the list of messages to set in the context.
    """
    context = self._context
    context.transform_messages(transform)
def set_tools(self, tools: ToolsSchema | NotGiven):
"""Set tools in the context.
@@ -458,13 +482,17 @@ class LLMUserAggregator(LLMContextAggregator):
await self._handle_llm_messages_append(frame)
elif isinstance(frame, LLMMessagesUpdateFrame):
await self._handle_llm_messages_update(frame)
elif isinstance(frame, LLMMessagesTransformFrame):
await self._handle_llm_messages_transform(frame)
elif isinstance(frame, LLMSetToolsFrame):
self.set_tools(frame.tools)
# Push the LLMSetToolsFrame as well, since speech-to-speech LLM
# services (like OpenAI Realtime) may need to know about tool
# changes; unlike text-based LLM services they won't just "pick up
# the change" on the next LLM run, as the LLM is continuously
# running.
# Push the LLMSetToolsFrame as well, since some realtime (aka
# speech-to-speech) LLM services (like OpenAI Realtime) may need to
# be directly informed of tool changes; unlike text-based LLM
# services they can't necessarily rely on "picking up the change"
# from the context on the next LLM run, as the LLM is continuously
# running and they may need to apply the change sooner than the
# next context frame.
await self.push_frame(frame, direction)
elif isinstance(frame, LLMSetToolChoiceFrame):
self.set_tool_choice(frame.tool_choice)
@@ -573,11 +601,28 @@ class LLMUserAggregator(LLMContextAggregator):
async def _handle_llm_messages_append(self, frame: LLMMessagesAppendFrame):
    """Append the frame's messages to the context; push a context frame if requested."""
    self.add_messages(frame.messages)
    # Record the programmatic edit on the shared context object (as a
    # runtime attribute) so that both the user and the assistant aggregator
    # can see it.
    self._context._pipecat_messages_programmatically_edited = True
    if not frame.run_llm:
        return
    await self.push_context_frame()
async def _handle_llm_messages_update(self, frame: LLMMessagesUpdateFrame):
    """Replace the context's messages with the frame's; push a context frame if requested."""
    self.set_messages(frame.messages)
    # Record the programmatic edit on the shared context object (as a
    # runtime attribute) so that both the user and the assistant aggregator
    # can see it.
    self._context._pipecat_messages_programmatically_edited = True
    if not frame.run_llm:
        return
    await self.push_context_frame()
async def _handle_llm_messages_transform(self, frame: LLMMessagesTransformFrame):
    """Run the frame's transform over the context messages; push a context frame if requested."""
    self.transform_messages(frame.transform)
    # Record the programmatic edit on the shared context object (as a
    # runtime attribute) so that both the user and the assistant aggregator
    # can see it.
    self._context._pipecat_messages_programmatically_edited = True
    if not frame.run_llm:
        return
    await self.push_context_frame()
@@ -866,6 +911,8 @@ class LLMAssistantAggregator(LLMContextAggregator):
await self._handle_llm_messages_append(frame)
elif isinstance(frame, LLMMessagesUpdateFrame):
await self._handle_llm_messages_update(frame)
elif isinstance(frame, LLMMessagesTransformFrame):
await self._handle_llm_messages_transform(frame)
elif isinstance(frame, LLMSetToolsFrame):
self.set_tools(frame.tools)
elif isinstance(frame, LLMSetToolChoiceFrame):
@@ -909,11 +956,28 @@ class LLMAssistantAggregator(LLMContextAggregator):
async def _handle_llm_messages_append(self, frame: LLMMessagesAppendFrame):
    """Append the frame's messages to the context; push a context frame upstream if requested."""
    self.add_messages(frame.messages)
    # Record the programmatic edit on the shared context object (as a
    # runtime attribute) so that both the user and the assistant aggregator
    # can see it.
    self._context._pipecat_messages_programmatically_edited = True
    if not frame.run_llm:
        return
    await self.push_context_frame(FrameDirection.UPSTREAM)
async def _handle_llm_messages_update(self, frame: LLMMessagesUpdateFrame):
    """Replace the context's messages with the frame's; push a context frame upstream if requested."""
    self.set_messages(frame.messages)
    # Record the programmatic edit on the shared context object (as a
    # runtime attribute) so that both the user and the assistant aggregator
    # can see it.
    self._context._pipecat_messages_programmatically_edited = True
    if not frame.run_llm:
        return
    await self.push_context_frame(FrameDirection.UPSTREAM)
async def _handle_llm_messages_transform(self, frame: LLMMessagesTransformFrame):
    """Run the frame's transform over the context messages; push a context frame upstream if requested."""
    self.transform_messages(frame.transform)
    # Record the programmatic edit on the shared context object (as a
    # runtime attribute) so that both the user and the assistant aggregator
    # can see it.
    self._context._pipecat_messages_programmatically_edited = True
    if not frame.run_llm:
        return
    await self.push_context_frame(FrameDirection.UPSTREAM)

View File

@@ -13,6 +13,7 @@ voice transcription, streaming responses, and tool usage.
import asyncio
import base64
import copy
import io
import time
import uuid
@@ -59,7 +60,12 @@ from pipecat.frames.frames import (
UserStoppedSpeakingFrame,
)
from pipecat.metrics.metrics import LLMTokenUsage
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_context import (
NOT_GIVEN,
LLMContext,
LLMContextDiff,
LLMContextMessage,
)
from pipecat.processors.aggregators.llm_response import (
LLMAssistantAggregatorParams,
LLMUserAggregatorParams,
@@ -684,6 +690,7 @@ class GeminiLiveLLMService(LLMService):
self._audio_input_paused = start_audio_paused
self._video_input_paused = start_video_paused
self._context = None
self._context_snapshot = None
self._api_key = api_key
self._http_options = update_google_client_http_options(http_options)
self._session: AsyncSession = None
@@ -826,6 +833,9 @@ class GeminiLiveLLMService(LLMService):
logger.error("Context already set. Can only set up Gemini Live context once.")
return
self._context = GeminiLiveContext.upgrade(context)
self._context_snapshot = copy.deepcopy(
self._context
) # Take a static snapshot guaranteed not to change
await self._create_initial_response()
#
@@ -915,7 +925,12 @@ class GeminiLiveLLMService(LLMService):
if isinstance(frame, LLMContextFrame)
else LLMContext.from_openai_context(frame.context)
)
await self._handle_context(context)
messages_programmatically_edited = (
frame.messages_programmatically_edited
if isinstance(frame, LLMContextFrame)
else False
)
await self._handle_context(context, messages_programmatically_edited)
elif isinstance(frame, InputTextRawFrame):
await self._send_user_text(frame.text)
await self.push_frame(frame, direction)
@@ -950,14 +965,21 @@ class GeminiLiveLLMService(LLMService):
elif isinstance(frame, LLMUpdateSettingsFrame):
await self._update_settings(frame.settings)
elif isinstance(frame, LLMSetToolsFrame):
await self._update_settings()
# We actually don't need to do anything here; the next time we get
# a context frame (like after a user transcription is appended),
# we'll detect that tools have changed and reconnect then to apply
# the new tools.
pass
else:
await self.push_frame(frame, direction)
async def _handle_context(self, context: LLMContext):
async def _handle_context(self, context: LLMContext, messages_programmatically_edited: bool):
if not self._context:
# We got our initial context
self._context = context
self._context_snapshot = copy.deepcopy(
context
) # Take a static snapshot guaranteed not to change
# If context contains system instruction or tools, reconnect in
# order to apply them.
@@ -981,6 +1003,7 @@ class GeminiLiveLLMService(LLMService):
"Tools provided both at init time and in context; using context-provided value."
)
if system_instruction or tools:
self._session_resumption_handle = None
await self._reconnect()
# Initialize our bookkeeping of already-completed tool calls in
@@ -1011,17 +1034,71 @@ class GeminiLiveLLMService(LLMService):
# We got an updated context.
self._context = context
# Here we assume that the updated context will contain either:
# - new messages (that the Gemini Live service, with its own
# context management, is already aware of), or
# - tool call results (that we need to tell the remote service
# about).
# (In the future, we could do more sophisticated diffing here,
# which would enable the user to programmatically manipulate the
# context).
diff = self._context_snapshot.diff(self._context)
self._context_snapshot = copy.deepcopy(
context
) # Take a static snapshot guaranteed not to change
# Send results for newly-completed function calls, if any.
await self._process_completed_function_calls(send_new_results=True)
if self._context_update_requires_reconnect(diff, messages_programmatically_edited):
logger.debug("Context update requires reconnect.")
# Reconnect
self._session_resumption_handle = None
await self._reconnect()
# Initialize our bookkeeping of already-completed tool calls in
# the context
await self._process_completed_function_calls(send_new_results=False)
# Trigger "initial" response with new connection
await self._create_initial_response()
else:
logger.debug("Context update does not require reconnect.")
# Send results for newly-completed function calls, if any.
await self._process_completed_function_calls(send_new_results=True)
def _context_update_requires_reconnect(
self, diff: LLMContextDiff, messages_programmatically_edited: bool
) -> bool:
"""Check if an update to our LLM context requires reconnection.
Args:
diff: The context diff representing the update
messages_programmatically_edited: Whether context messages were
programmatically edited (e.g. with LLMMessagesAppendFrame)
Returns:
True if reconnection is required, False otherwise.
"""
# We need to reconnect in 3 cases:
# 1. If the conversation history was edited
# 2. If the tools available to the model were changed
# 3. If messages were appended programmatically by the user, e.g. with
# LLMMessagesAppendFrame (NOT if they originated from Gemini Live
# itself, since the remote service's internal bookkeeping is already
# aware of those)
#
# Note that *ideally* in this 3rd case we would just send the newly
# appended messages without reconnecting, but in all my testing so far,
# I haven't been able to get that to work as reliably as I'd like.
# (Commments in the Gemini Live Python SDK also warn against
# programmatically sending new messages after the initial conversation
# history seeding is done and the voice chat has begun.)
if (
diff.history_edited
or diff.tools_diff.has_changes()
or (diff.messages_appended and messages_programmatically_edited)
):
if diff.history_edited:
logger.debug("Context update: history edited.")
if diff.tools_diff.has_changes():
logger.debug("Context update: tools changed.")
if diff.messages_appended and messages_programmatically_edited:
logger.debug(f"Context update: messages programmatically appended.")
return True
return False
async def _process_completed_function_calls(self, send_new_results: bool):
# Check for set of completed function calls in the context
@@ -1042,6 +1119,9 @@ class GeminiLiveLLMService(LLMService):
):
# Found a newly-completed function call - send the result to the service
if send_new_results:
logger.debug(
f"Sending newly-completed tool call result for tool '{tool_name}'"
)
await self._tool_result(
tool_call_id, tool_name, part.function_response.response
)
@@ -1383,7 +1463,9 @@ class GeminiLiveLLMService(LLMService):
return
adapter: GeminiLLMAdapter = self.get_llm_adapter()
messages = adapter.get_llm_invocation_params(self._context).get("messages", [])
messages = adapter.get_llm_invocation_params(
self._context, convert_function_messages_to_text=True
).get("messages", [])
if not messages:
return
@@ -1413,7 +1495,9 @@ class GeminiLiveLLMService(LLMService):
# in the right format
context = LLMContext(messages=messages_list)
adapter: GeminiLLMAdapter = self.get_llm_adapter()
messages = adapter.get_llm_invocation_params(context).get("messages", [])
messages = adapter.get_llm_invocation_params(
context, convert_function_messages_to_text=True
).get("messages", [])
if not messages:
return

View File

@@ -1057,5 +1057,287 @@ class TestLLMAssistantAggregator(
)
class TestLLMContextDiff(unittest.TestCase):
    """Checks what LLMContext.diff() reports for message, tool-call and tool_choice changes."""

    @staticmethod
    def _tool_message(call_id, content):
        """Build a tool-role message dict for the given call id and content."""
        return {"role": "tool", "content": content, "tool_call_id": call_id}

    def test_diff_identical_contexts(self):
        """Contexts built from the same messages yield a diff with nothing set."""
        shared = [{"role": "user", "content": "Hello"}]
        before = LLMContext(messages=shared.copy())
        after = LLMContext(messages=shared.copy())

        result = before.diff(after)

        self.assertFalse(result.has_changes())
        self.assertEqual(result.messages_appended, [])
        self.assertFalse(result.history_edited)
        self.assertEqual(result.tool_calls_resolved, [])
        self.assertFalse(result.tools_diff.has_changes())
        self.assertFalse(result.tool_choice_changed)

    def test_diff_messages_appended(self):
        """A single trailing message is reported as appended, not as a history edit."""
        greeting = {"role": "user", "content": "Hello"}
        reply = {"role": "assistant", "content": "Hi there!"}
        before = LLMContext(messages=[greeting])
        after = LLMContext(messages=[greeting, reply])

        result = before.diff(after)

        self.assertTrue(result.has_changes())
        self.assertEqual(len(result.messages_appended), 1)
        self.assertEqual(result.messages_appended[0], reply)
        self.assertFalse(result.history_edited)

    def test_diff_multiple_messages_appended(self):
        """Several trailing messages are all reported as appended, in order."""
        greeting = {"role": "user", "content": "Hello"}
        reply = {"role": "assistant", "content": "Hi!"}
        follow_up = {"role": "user", "content": "How are you?"}
        before = LLMContext(messages=[greeting])
        after = LLMContext(messages=[greeting, reply, follow_up])

        result = before.diff(after)

        self.assertTrue(result.has_changes())
        self.assertEqual(len(result.messages_appended), 2)
        self.assertEqual(result.messages_appended[0], reply)
        self.assertEqual(result.messages_appended[1], follow_up)
        self.assertFalse(result.history_edited)

    def test_diff_message_removed(self):
        """Dropping a message counts as a history edit."""
        greeting = {"role": "user", "content": "Hello"}
        reply = {"role": "assistant", "content": "Hi!"}
        before = LLMContext(messages=[greeting, reply])
        after = LLMContext(messages=[greeting])

        result = before.diff(after)

        self.assertTrue(result.has_changes())
        # No appended messages are reported once the history counts as edited
        self.assertEqual(result.messages_appended, [])
        self.assertTrue(result.history_edited)

    def test_diff_message_modified(self):
        """Changing an existing message's content counts as a history edit."""
        greeting = {"role": "user", "content": "Hello"}
        reply_before = {"role": "assistant", "content": "Hi!"}
        reply_after = {"role": "assistant", "content": "Hello there!"}
        before = LLMContext(messages=[greeting, reply_before])
        after = LLMContext(messages=[greeting, reply_after])

        result = before.diff(after)

        self.assertTrue(result.has_changes())
        self.assertTrue(result.history_edited)
        self.assertEqual(result.messages_appended, [])

    def test_diff_message_inserted_in_middle(self):
        """Inserting a message between existing ones counts as a history edit."""
        greeting = {"role": "user", "content": "Hello"}
        reply = {"role": "assistant", "content": "Hi!"}
        wedged = {"role": "system", "content": "System message"}
        before = LLMContext(messages=[greeting, reply])
        after = LLMContext(messages=[greeting, wedged, reply])

        result = before.diff(after)

        self.assertTrue(result.has_changes())
        self.assertTrue(result.history_edited)
        self.assertEqual(result.messages_appended, [])

    def test_diff_tool_call_resolved_to_result(self):
        """An IN_PROGRESS tool message replaced by a real result is reported as resolved."""
        question = {"role": "user", "content": "What's the weather?"}
        call = {
            "role": "assistant",
            "tool_calls": [
                {"id": "call_123", "function": {"name": "get_weather", "arguments": "{}"}}
            ],
        }
        pending = self._tool_message("call_123", "IN_PROGRESS")
        resolved = self._tool_message("call_123", '{"temperature": 72}')
        before = LLMContext(messages=[question, call, pending])
        after = LLMContext(messages=[question, call, resolved])

        result = before.diff(after)

        self.assertTrue(result.has_changes())
        self.assertEqual(result.tool_calls_resolved, ["call_123"])
        # The tool message content changed, so the history also counts as edited
        self.assertTrue(result.history_edited)

    def test_diff_tool_call_resolved_to_completed(self):
        """An IN_PROGRESS tool message turning into COMPLETED is reported as resolved."""
        request = {"role": "user", "content": "Do something"}
        pending = self._tool_message("call_456", "IN_PROGRESS")
        finished = self._tool_message("call_456", "COMPLETED")
        before = LLMContext(messages=[request, pending])
        after = LLMContext(messages=[request, finished])

        result = before.diff(after)

        self.assertTrue(result.has_changes())
        self.assertEqual(result.tool_calls_resolved, ["call_456"])

    def test_diff_tool_call_resolved_to_cancelled(self):
        """An IN_PROGRESS tool message turning into CANCELLED is reported as resolved."""
        request = {"role": "user", "content": "Do something"}
        pending = self._tool_message("call_789", "IN_PROGRESS")
        aborted = self._tool_message("call_789", "CANCELLED")
        before = LLMContext(messages=[request, pending])
        after = LLMContext(messages=[request, aborted])

        result = before.diff(after)

        self.assertTrue(result.has_changes())
        self.assertEqual(result.tool_calls_resolved, ["call_789"])

    def test_diff_tool_call_still_in_progress(self):
        """A tool message that stays IN_PROGRESS is not reported as resolved."""
        request = {"role": "user", "content": "Do something"}
        pending = self._tool_message("call_123", "IN_PROGRESS")
        before = LLMContext(messages=[request, pending])
        after = LLMContext(messages=[request, pending])

        result = before.diff(after)

        self.assertFalse(result.has_changes())
        self.assertEqual(result.tool_calls_resolved, [])

    def test_diff_tool_choice_changed(self):
        """Different tool_choice values are flagged as a change."""
        greeting = {"role": "user", "content": "Hello"}
        before = LLMContext(messages=[greeting], tool_choice="auto")
        after = LLMContext(messages=[greeting], tool_choice="none")

        result = before.diff(after)

        self.assertTrue(result.has_changes())
        self.assertTrue(result.tool_choice_changed)

    def test_diff_tool_choice_unchanged(self):
        """Equal tool_choice values are not flagged as a change."""
        greeting = {"role": "user", "content": "Hello"}
        before = LLMContext(messages=[greeting], tool_choice="auto")
        after = LLMContext(messages=[greeting], tool_choice="auto")

        result = before.diff(after)

        self.assertFalse(result.has_changes())
        self.assertFalse(result.tool_choice_changed)

    def test_diff_empty_contexts(self):
        """Two default-constructed contexts produce a diff with no changes."""
        result = LLMContext().diff(LLMContext())

        self.assertFalse(result.has_changes())
class TestLLMContextDiffWithTools(unittest.TestCase):
    """Checks what LLMContext.diff() reports when the tools configuration differs."""

    def _create_tools_schema(self, tool_names: list[str]) -> "ToolsSchema":
        """Build a ToolsSchema holding one parameterless FunctionSchema per name."""
        from pipecat.adapters.schemas.function_schema import FunctionSchema
        from pipecat.adapters.schemas.tools_schema import ToolsSchema

        functions = [
            FunctionSchema(
                name=tool_name, description=f"Test {tool_name}", properties={}, required=[]
            )
            for tool_name in tool_names
        ]
        return ToolsSchema(standard_tools=functions)

    def test_diff_tools_added_from_not_given(self):
        """Going from NOT_GIVEN tools to a schema reports every tool as added."""
        from pipecat.processors.aggregators.llm_context import NOT_GIVEN

        greeting = {"role": "user", "content": "Hello"}
        schema = self._create_tools_schema(["get_weather", "get_time"])
        before = LLMContext(messages=[greeting], tools=NOT_GIVEN)
        after = LLMContext(messages=[greeting], tools=schema)

        result = before.diff(after)

        self.assertTrue(result.has_changes())
        self.assertEqual(
            sorted(result.tools_diff.standard_tools_added), ["get_time", "get_weather"]
        )
        self.assertEqual(result.tools_diff.standard_tools_removed, [])

    def test_diff_tools_removed_to_not_given(self):
        """Going from a schema to NOT_GIVEN tools reports every tool as removed."""
        from pipecat.processors.aggregators.llm_context import NOT_GIVEN

        greeting = {"role": "user", "content": "Hello"}
        schema = self._create_tools_schema(["get_weather", "get_time"])
        before = LLMContext(messages=[greeting], tools=schema)
        after = LLMContext(messages=[greeting], tools=NOT_GIVEN)

        result = before.diff(after)

        self.assertTrue(result.has_changes())
        self.assertEqual(result.tools_diff.standard_tools_added, [])
        self.assertEqual(
            sorted(result.tools_diff.standard_tools_removed), ["get_time", "get_weather"]
        )

    def test_diff_both_not_given(self):
        """With NOT_GIVEN tools on both sides, tools_diff reports no changes."""
        from pipecat.processors.aggregators.llm_context import NOT_GIVEN

        greeting = {"role": "user", "content": "Hello"}
        before = LLMContext(messages=[greeting], tools=NOT_GIVEN)
        after = LLMContext(messages=[greeting], tools=NOT_GIVEN)

        result = before.diff(after)

        self.assertFalse(result.has_changes())
        self.assertFalse(result.tools_diff.has_changes())

    def test_diff_tools_modified(self):
        """A same-named tool with a different definition is flagged as modified."""
        from pipecat.adapters.schemas.function_schema import FunctionSchema
        from pipecat.adapters.schemas.tools_schema import ToolsSchema

        greeting = {"role": "user", "content": "Hello"}
        weather_v1 = FunctionSchema(
            name="get_weather",
            description="Get weather v1",
            properties={"location": {"type": "string"}},
            required=["location"],
        )
        weather_v2 = FunctionSchema(
            name="get_weather",
            description="Get weather v2",
            properties={"city": {"type": "string"}},
            required=["city"],
        )
        before = LLMContext(messages=[greeting], tools=ToolsSchema(standard_tools=[weather_v1]))
        after = LLMContext(messages=[greeting], tools=ToolsSchema(standard_tools=[weather_v2]))

        result = before.diff(after)

        self.assertTrue(result.has_changes())
        self.assertTrue(result.tools_diff.standard_tools_modified)

    def test_diff_tools_unchanged(self):
        """Separately built but identical tool schemas produce no tool changes."""
        greeting = {"role": "user", "content": "Hello"}
        schema_a = self._create_tools_schema(["get_weather"])
        schema_b = self._create_tools_schema(["get_weather"])
        before = LLMContext(messages=[greeting], tools=schema_a)
        after = LLMContext(messages=[greeting], tools=schema_b)

        result = before.diff(after)

        self.assertFalse(result.has_changes())
        self.assertFalse(result.tools_diff.has_changes())
if __name__ == "__main__":
unittest.main()

View File

@@ -18,6 +18,7 @@ from pipecat.frames.frames import (
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMMessagesAppendFrame,
LLMMessagesTransformFrame,
LLMMessagesUpdateFrame,
LLMRunFrame,
LLMTextFrame,
@@ -147,6 +148,52 @@ class TestLLMUserAggregator(unittest.IsolatedAsyncioTestCase):
)
assert context.messages[0]["content"] == "Hi there!"
async def test_llm_messages_transform(self):
    """An LLMMessagesTransformFrame replaces the context messages with the transform's output."""
    context = LLMContext()
    # Seed the context with a short user/assistant exchange
    seed_messages = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there!"},
        {"role": "user", "content": "How are you?"},
    ]
    context.set_messages(seed_messages)

    pipeline = Pipeline([LLMUserAggregator(context)])

    def keep_user_messages(messages):
        # Drop everything that was not said by the user
        return list(filter(lambda message: message["role"] == "user", messages))

    await run_test(
        pipeline,
        frames_to_send=[LLMMessagesTransformFrame(transform=keep_user_messages)],
    )

    # Only the two user messages should remain, in their original order
    remaining = context.messages
    assert len(remaining) == 2
    assert remaining[0]["content"] == "Hello"
    assert remaining[1]["content"] == "How are you?"
async def test_llm_messages_transform_run(self):
    """A transform frame sent with run_llm=True rewrites messages and emits an LLMContextFrame."""
    context = LLMContext()
    # Seed the context with a single user message
    context.set_messages([{"role": "user", "content": "Hello"}])

    pipeline = Pipeline([LLMUserAggregator(context)])

    def uppercase_content(messages):
        # Rebuild every message with its content upper-cased
        shouted = []
        for message in messages:
            shouted.append({"role": message["role"], "content": message["content"].upper()})
        return shouted

    transform_frame = LLMMessagesTransformFrame(transform=uppercase_content, run_llm=True)
    await run_test(
        pipeline,
        frames_to_send=[transform_frame],
        expected_down_frames=[LLMContextFrame],
    )

    assert context.messages[0]["content"] == "HELLO"
async def test_default_user_turn_strategies(self):
context = LLMContext()
user_aggregator = LLMUserAggregator(context)

184
tests/test_tools_schema.py Normal file
View File

@@ -0,0 +1,184 @@
#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import unittest
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import AdapterType, ToolsSchema, ToolsSchemaDiff
class TestToolsSchemaDiff(unittest.TestCase):
    """Checks ToolsSchemaDiff.has_changes() for each individual field."""

    def test_has_changes_empty(self):
        """A default-constructed diff reports no changes."""
        self.assertFalse(ToolsSchemaDiff().has_changes())

    def test_has_changes_with_added(self):
        """A diff listing an added tool reports changes."""
        added_only = ToolsSchemaDiff(standard_tools_added=["tool1"])
        self.assertTrue(added_only.has_changes())

    def test_has_changes_with_removed(self):
        """A diff listing a removed tool reports changes."""
        removed_only = ToolsSchemaDiff(standard_tools_removed=["tool1"])
        self.assertTrue(removed_only.has_changes())

    def test_has_changes_with_modified(self):
        """A diff flagged with standard_tools_modified reports changes."""
        modified_only = ToolsSchemaDiff(standard_tools_modified=True)
        self.assertTrue(modified_only.has_changes())

    def test_has_changes_with_custom_changed(self):
        """A diff flagged with custom_tools_changed reports changes."""
        custom_only = ToolsSchemaDiff(custom_tools_changed=True)
        self.assertTrue(custom_only.has_changes())
class TestToolsSchemaDiffMethod(unittest.TestCase):
    """Checks what ToolsSchema.diff() reports for standard and custom tool changes."""

    def _create_function_schema(
        self, name: str, description: str = "Test function", properties: dict = None
    ) -> FunctionSchema:
        """Build a FunctionSchema with no required parameters.

        Args:
            name: Name of the function.
            description: Description of the function.
            properties: Parameter properties; an empty dict is used when falsy.

        Returns:
            The constructed FunctionSchema.
        """
        schema_properties = properties or {}
        return FunctionSchema(
            name=name,
            description=description,
            properties=schema_properties,
            required=[],
        )

    def test_diff_identical_schemas(self):
        """Schemas holding equal, separately built tools produce an empty diff."""
        weather = self._create_function_schema("get_weather")
        before = ToolsSchema(standard_tools=[weather])
        after = ToolsSchema(standard_tools=[self._create_function_schema("get_weather")])

        result = before.diff(after)

        self.assertFalse(result.has_changes())
        self.assertEqual(result.standard_tools_added, [])
        self.assertEqual(result.standard_tools_removed, [])
        self.assertFalse(result.standard_tools_modified)
        self.assertFalse(result.custom_tools_changed)

    def test_diff_tool_added(self):
        """A tool present only in the other schema is listed as added."""
        weather = self._create_function_schema("get_weather")
        clock = self._create_function_schema("get_time")
        before = ToolsSchema(standard_tools=[weather])
        after = ToolsSchema(standard_tools=[weather, clock])

        result = before.diff(after)

        self.assertTrue(result.has_changes())
        self.assertEqual(result.standard_tools_added, ["get_time"])
        self.assertEqual(result.standard_tools_removed, [])
        self.assertFalse(result.standard_tools_modified)

    def test_diff_tool_removed(self):
        """A tool missing from the other schema is listed as removed."""
        weather = self._create_function_schema("get_weather")
        clock = self._create_function_schema("get_time")
        before = ToolsSchema(standard_tools=[weather, clock])
        after = ToolsSchema(standard_tools=[weather])

        result = before.diff(after)

        self.assertTrue(result.has_changes())
        self.assertEqual(result.standard_tools_added, [])
        self.assertEqual(result.standard_tools_removed, ["get_time"])
        self.assertFalse(result.standard_tools_modified)

    def test_diff_tool_modified(self):
        """A same-named tool with a different definition is flagged as modified only."""
        weather_v1 = self._create_function_schema(
            "get_weather",
            description="Get weather v1",
            properties={"location": {"type": "string"}},
        )
        weather_v2 = self._create_function_schema(
            "get_weather",
            description="Get weather v2",
            properties={"city": {"type": "string"}},
        )
        before = ToolsSchema(standard_tools=[weather_v1])
        after = ToolsSchema(standard_tools=[weather_v2])

        result = before.diff(after)

        self.assertTrue(result.has_changes())
        self.assertEqual(result.standard_tools_added, [])
        self.assertEqual(result.standard_tools_removed, [])
        self.assertTrue(result.standard_tools_modified)

    def test_diff_multiple_changes(self):
        """One tool swapped for another is reported as one addition and one removal."""
        kept = self._create_function_schema("keep_tool")
        dropped = self._create_function_schema("remove_tool")
        introduced = self._create_function_schema("add_tool")
        before = ToolsSchema(standard_tools=[kept, dropped])
        after = ToolsSchema(standard_tools=[kept, introduced])

        result = before.diff(after)

        self.assertTrue(result.has_changes())
        self.assertEqual(result.standard_tools_added, ["add_tool"])
        self.assertEqual(result.standard_tools_removed, ["remove_tool"])

    def test_diff_empty_schemas(self):
        """Two schemas without any standard tools produce no changes."""
        result = ToolsSchema(standard_tools=[]).diff(ToolsSchema(standard_tools=[]))

        self.assertFalse(result.has_changes())

    def test_diff_custom_tools_changed(self):
        """Differing custom tools are flagged while standard tools stay untouched."""
        weather = self._create_function_schema("get_weather")
        custom_before = {AdapterType.GEMINI: [{"name": "search"}]}
        custom_after = {AdapterType.GEMINI: [{"name": "search_v2"}]}
        before = ToolsSchema(standard_tools=[weather], custom_tools=custom_before)
        after = ToolsSchema(standard_tools=[weather], custom_tools=custom_after)

        result = before.diff(after)

        self.assertTrue(result.has_changes())
        self.assertTrue(result.custom_tools_changed)
        # The standard tools are the same on both sides
        self.assertEqual(result.standard_tools_added, [])
        self.assertEqual(result.standard_tools_removed, [])
        self.assertFalse(result.standard_tools_modified)

    def test_diff_custom_tools_added(self):
        """Custom tools present only in the other schema are flagged as a change."""
        weather = self._create_function_schema("get_weather")
        before = ToolsSchema(standard_tools=[weather])
        after = ToolsSchema(
            standard_tools=[weather], custom_tools={AdapterType.GEMINI: [{"name": "search"}]}
        )

        result = before.diff(after)

        self.assertTrue(result.has_changes())
        self.assertTrue(result.custom_tools_changed)

    def test_diff_custom_tools_removed(self):
        """Custom tools missing from the other schema are flagged as a change."""
        weather = self._create_function_schema("get_weather")
        before = ToolsSchema(
            standard_tools=[weather], custom_tools={AdapterType.GEMINI: [{"name": "search"}]}
        )
        after = ToolsSchema(standard_tools=[weather])

        result = before.diff(after)

        self.assertTrue(result.has_changes())
        self.assertTrue(result.custom_tools_changed)
if __name__ == "__main__":
unittest.main()