Compare commits

...

3 Commits

Author SHA1 Message Date
Paul Kompfner
bcaa3164f6 Add warning when detecting context messages referring to missing tools 2026-01-30 16:48:05 -05:00
Paul Kompfner
2c16faa662 WIP Gemini Live Pipecat Flows support 2026-01-30 11:15:19 -05:00
Paul Kompfner
bbd7f5e033 WIP Gemini Live Pipecat Flows support 2026-01-29 21:59:31 -05:00
13 changed files with 878 additions and 15 deletions

View File

@@ -129,4 +129,42 @@ class BaseLLMAdapter(ABC, Generic[TLLMInvocationParams]):
# Fallback to return the same tools in case they are not in a standard format
return tools
def _warn_about_orphaned_tool_messages(self, context: LLMContext) -> None:
"""Warn if context contains messages referencing tools that aren't currently available.
This can happen when tools are removed/deactivated but the conversation history
still contains function calls or tool responses for those tools. Such orphaned
messages may cause API errors from the LLM provider.
Args:
context: The LLM context to check.
"""
# Get the set of currently available tool names
available_tool_names: set[str] = set()
if isinstance(context.tools, ToolsSchema):
available_tool_names = {tool.name for tool in context.tools.standard_tools}
# Note: We don't check custom tools as they may have varying formats
# Track orphaned function names found in messages
orphaned_tools: set[str] = set()
for message in self.get_messages(context):
if isinstance(message, LLMSpecificMessage):
# Skip LLM-specific messages for now
continue
# Check for tool_calls in assistant messages
if message.get("tool_calls"):
for tc in message["tool_calls"]:
func_name = tc.get("function", {}).get("name")
if func_name and available_tool_names and func_name not in available_tool_names:
orphaned_tools.add(func_name)
# Log warning for orphaned messages
if orphaned_tools:
logger.warning(
f"Context contains references to tools that are no longer available: "
f"{sorted(orphaned_tools)}. This may cause unexpected behavior or API errors."
)
# TODO: we can move the logic to also handle the Messages here

View File

@@ -10,6 +10,7 @@ This module provides schemas for managing both standardized function tools
and custom adapter-specific tools in the Pipecat framework.
"""
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional
@@ -17,6 +18,53 @@ from pipecat.adapters.schemas.direct_function import DirectFunction, DirectFunct
from pipecat.adapters.schemas.function_schema import FunctionSchema
@dataclass
class ToolsSchemaDiff:
"""Represents the differences between two ToolsSchema instances.
Parameters:
standard_tools_added: Names of newly added standard tools.
standard_tools_removed: Names of removed standard tools.
standard_tools_modified: True if any existing standard tool's definition changed.
custom_tools_changed: True if the custom_tools dictionary differs.
"""
standard_tools_added: List[str] = field(default_factory=list)
standard_tools_removed: List[str] = field(default_factory=list)
standard_tools_modified: bool = False
custom_tools_changed: bool = False
def has_changes(self) -> bool:
"""Check if there are any differences.
Returns:
True if any field indicates a change, False otherwise.
"""
return bool(
self.standard_tools_added
or self.standard_tools_removed
or self.standard_tools_modified
or self.custom_tools_changed
)
def change_description(self) -> str:
"""Generate a human-readable description of the differences.
Returns:
A string summarizing the changes.
"""
changes = []
if self.standard_tools_added:
changes.append(f"Added standard tools: {', '.join(self.standard_tools_added)}")
if self.standard_tools_removed:
changes.append(f"Removed standard tools: {', '.join(self.standard_tools_removed)}")
if self.standard_tools_modified:
changes.append("Modified definitions of existing standard tools")
if self.custom_tools_changed:
changes.append("Custom tools changed")
return "; ".join(changes) if changes else "No changes"
class AdapterType(Enum):
"""Supported adapter types for custom tools.
@@ -92,3 +140,44 @@ class ToolsSchema:
value: Dictionary mapping adapter types to their custom tool definitions.
"""
self._custom_tools = value
def diff(self, other: "ToolsSchema") -> ToolsSchemaDiff:
"""Compare this ToolsSchema to another and return the differences.
Args:
other: The ToolsSchema to compare against (the "after" state).
Returns:
ToolsSchemaDiff containing the differences between self and other.
"""
result = ToolsSchemaDiff()
# Build maps of tool name -> FunctionSchema for comparison
self_tools_by_name: Dict[str, FunctionSchema] = {
tool.name: tool for tool in self._standard_tools
}
other_tools_by_name: Dict[str, FunctionSchema] = {
tool.name: tool for tool in other._standard_tools
}
self_names = set(self_tools_by_name.keys())
other_names = set(other_tools_by_name.keys())
# Find added and removed tools
result.standard_tools_added = sorted(other_names - self_names)
result.standard_tools_removed = sorted(self_names - other_names)
# Check for modified tools (same name, different definition)
common_names = self_names & other_names
for name in common_names:
self_tool = self_tools_by_name[name]
other_tool = other_tools_by_name[name]
# Compare using to_default_dict() for full schema comparison
if self_tool.to_default_dict() != other_tool.to_default_dict():
result.standard_tools_modified = True
break
# Compare custom tools
result.custom_tools_changed = self._custom_tools != other._custom_tools
return result

View File

@@ -59,6 +59,9 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
Returns:
Dictionary of parameters for invoking Anthropic's LLM API.
"""
# Warn about orphaned tool-related messages
self._warn_about_orphaned_tool_messages(context)
messages = self._from_universal_context_messages(self.get_messages(context))
return {
"system": messages.system,

View File

@@ -83,6 +83,9 @@ class AWSNovaSonicLLMAdapter(BaseLLMAdapter[AWSNovaSonicLLMInvocationParams]):
Returns:
Dictionary of parameters for invoking AWS Nova Sonic's LLM API.
"""
# Warn about orphaned tool-related messages
self._warn_about_orphaned_tool_messages(context)
messages = self._from_universal_context_messages(self.get_messages(context))
return {
"system_instruction": messages.system_instruction,

View File

@@ -56,6 +56,9 @@ class AWSBedrockLLMAdapter(BaseLLMAdapter[AWSBedrockLLMInvocationParams]):
Returns:
Dictionary of parameters for invoking AWS Bedrock's LLM API.
"""
# Warn about orphaned tool-related messages
self._warn_about_orphaned_tool_messages(context)
messages = self._from_universal_context_messages(self.get_messages(context))
return {
"system": messages.system,

View File

@@ -62,6 +62,9 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
Returns:
Dictionary of parameters for Gemini's API.
"""
# Warn about orphaned tool-related messages
self._warn_about_orphaned_tool_messages(context)
messages = self._from_universal_context_messages(self.get_messages(context))
return {
"system_instruction": messages.system_instruction,

View File

@@ -59,6 +59,9 @@ class GrokRealtimeLLMAdapter(BaseLLMAdapter):
Returns:
Dictionary of parameters for invoking Grok's Voice Agent API.
"""
# Warn about orphaned tool-related messages
self._warn_about_orphaned_tool_messages(context)
messages = self._from_universal_context_messages(self.get_messages(context))
return {
"system_instruction": messages.system_instruction,

View File

@@ -60,6 +60,9 @@ class OpenAILLMAdapter(BaseLLMAdapter[OpenAILLMInvocationParams]):
Returns:
Dictionary of parameters for OpenAI's ChatCompletion API.
"""
# Warn about orphaned tool-related messages
self._warn_about_orphaned_tool_messages(context)
return {
"messages": self._from_universal_context_messages(self.get_messages(context)),
# NOTE; LLMContext's tools are guaranteed to be a ToolsSchema (or NOT_GIVEN)

View File

@@ -54,6 +54,9 @@ class OpenAIRealtimeLLMAdapter(BaseLLMAdapter):
Returns:
Dictionary of parameters for invoking OpenAI Realtime's API.
"""
# Warn about orphaned tool-related messages
self._warn_about_orphaned_tool_messages(context)
messages = self._from_universal_context_messages(self.get_messages(context))
return {
"system_instruction": messages.system_instruction,

View File

@@ -18,8 +18,8 @@ import asyncio
import base64
import io
import wave
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, List, Optional, TypeAlias, Union
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypeAlias, Union
from loguru import logger
from openai._types import NOT_GIVEN as OPEN_AI_NOT_GIVEN
@@ -30,7 +30,7 @@ from openai.types.chat import (
)
from PIL import Image
from pipecat.adapters.schemas.tools_schema import AdapterType, ToolsSchema
from pipecat.adapters.schemas.tools_schema import AdapterType, ToolsSchema, ToolsSchemaDiff
from pipecat.frames.frames import AudioRawFrame
if TYPE_CHECKING:
@@ -47,6 +47,58 @@ NOT_GIVEN = OPEN_AI_NOT_GIVEN
NotGiven = OpenAINotGiven
@dataclass
class LLMContextDiff:
"""Represents the differences between two LLMContext instances.
Parameters:
messages_appended: New messages appended at the end. Empty if history_edited is True.
history_edited: True if earlier messages were changed, inserted, or removed.
tool_calls_resolved: List of tool_call_ids that changed from "IN_PROGRESS" to a result.
tools_diff: Differences in tools configuration, or None if unchanged or both NOT_GIVEN.
tool_choice_changed: True if the tool_choice setting differs.
"""
messages_appended: List["LLMContextMessage"] = field(default_factory=list)
history_edited: bool = False
tool_calls_resolved: List[str] = field(default_factory=list)
tools_diff: ToolsSchemaDiff = field(default_factory=ToolsSchemaDiff)
tool_choice_changed: bool = False
def has_changes(self) -> bool:
"""Check if there are any differences.
Returns:
True if any field indicates a change, False otherwise.
"""
return bool(
self.messages_appended
or self.history_edited
or self.tool_calls_resolved
or self.tools_diff.has_changes()
or self.tool_choice_changed
)
def change_description(self) -> str:
"""Get a human-readable description of the changes.
Returns:
A string summarizing the changes.
"""
changes = []
if self.history_edited:
changes.append("history edited")
if self.messages_appended:
changes.append(f"{len(self.messages_appended)} messages appended")
if self.tool_calls_resolved:
changes.append(f"{len(self.tool_calls_resolved)} tool calls resolved")
if self.tools_diff.has_changes():
changes.append(f"tools changed: ({self.tools_diff.change_description()})")
if self.tool_choice_changed:
changes.append("tool choice changed")
return ", ".join(changes) if changes else "no changes"
@dataclass
class LLMSpecificMessage:
"""A container for a context message that is specific to a particular LLM service.
@@ -410,3 +462,138 @@ class LLMContext:
raise TypeError(
f"In LLMContext, tools must be a ToolsSchema object or NOT_GIVEN. Got type: {type(tools)}",
)
def diff(self, other: "LLMContext") -> LLMContextDiff:
"""Compare this context to another and return the differences.
Compares self (the "before" state) to other (the "after" state) and
identifies what has changed.
Args:
other: The LLMContext to compare against (the "after" state).
Returns:
ContextDiff containing the differences between self and other.
"""
result = LLMContextDiff()
# Compare messages
self_messages = self._messages
other_messages = other._messages
self_len = len(self_messages)
other_len = len(other_messages)
# Check if history was edited (messages removed, modified, or inserted in the middle)
if other_len < self_len:
# Messages were removed
result.history_edited = True
else:
# Check if the prefix matches (first self_len messages should be identical)
for i in range(self_len):
if not self._messages_equal(self_messages[i], other_messages[i]):
result.history_edited = True
break
# If history wasn't edited, capture appended messages
if not result.history_edited and other_len > self_len:
result.messages_appended = other_messages[self_len:]
# Find resolved tool calls (IN_PROGRESS -> something else)
result.tool_calls_resolved = self._find_resolved_tool_calls(other)
# Compare tools
result.tools_diff = self._compute_tools_diff(other)
# Compare tool_choice
# print("[pk] comparing tool choices: ", self._tool_choice, other._tool_choice, type(self._tool_choice), type(other._tool_choice), self._tool_choice == other._tool_choice, self._tool_choice != other._tool_choice)
# (For some reason if they're both NOT_GIVEN, equality check returns False?)
if not self._tool_choice and not other._tool_choice:
result.tool_choice_changed = False
else:
result.tool_choice_changed = self._tool_choice != other._tool_choice
return result
def _messages_equal(self, msg1: LLMContextMessage, msg2: LLMContextMessage) -> bool:
"""Compare two messages for equality.
Args:
msg1: First message to compare.
msg2: Second message to compare.
Returns:
True if the messages are equal, False otherwise.
"""
# Handle LLMSpecificMessage
if isinstance(msg1, LLMSpecificMessage) and isinstance(msg2, LLMSpecificMessage):
return msg1.llm == msg2.llm and msg1.message == msg2.message
elif isinstance(msg1, LLMSpecificMessage) or isinstance(msg2, LLMSpecificMessage):
return False
# Both are standard messages (dicts)
return msg1 == msg2
def _find_resolved_tool_calls(self, other: "LLMContext") -> List[str]:
"""Find tool calls that changed from IN_PROGRESS to a resolved state.
Args:
other: The context to compare against (the "after" state).
Returns:
List of tool_call_ids that were resolved.
"""
resolved: List[str] = []
# Build a map of tool_call_id -> content for "other" context
other_tool_contents: Dict[str, Any] = {}
for msg in other._messages:
if isinstance(msg, dict) and msg.get("role") == "tool":
tool_call_id = msg.get("tool_call_id")
if tool_call_id:
other_tool_contents[tool_call_id] = msg.get("content")
# Find tool messages in self that are IN_PROGRESS but resolved in other
for msg in self._messages:
if isinstance(msg, dict) and msg.get("role") == "tool":
if msg.get("content") == "IN_PROGRESS":
tool_call_id = msg.get("tool_call_id")
if tool_call_id and tool_call_id in other_tool_contents:
other_content = other_tool_contents[tool_call_id]
if other_content != "IN_PROGRESS":
resolved.append(tool_call_id)
return resolved
def _compute_tools_diff(self, other: "LLMContext") -> ToolsSchemaDiff:
"""Compute the difference in tools between self and other.
Args:
other: The context to compare against (the "after" state).
Returns:
ToolsSchemaDiff if there are changes, None if both are NOT_GIVEN or identical.
"""
self_has_tools = isinstance(self._tools, ToolsSchema)
other_has_tools = isinstance(other._tools, ToolsSchema)
if not self_has_tools and not other_has_tools:
# Both are NOT_GIVEN
return ToolsSchemaDiff()
if self_has_tools and other_has_tools:
# Both have tools - use ToolsSchema.diff()
diff = self._tools.diff(other._tools)
return diff
if not self_has_tools and other_has_tools:
# Tools were added (self is NOT_GIVEN, other has tools)
return ToolsSchemaDiff(
standard_tools_added=[tool.name for tool in other._tools.standard_tools],
custom_tools_changed=other._tools.custom_tools is not None,
)
# Tools were removed (self has tools, other is NOT_GIVEN)
return ToolsSchemaDiff(
standard_tools_removed=[tool.name for tool in self._tools.standard_tools],
custom_tools_changed=self._tools.custom_tools is not None,
)

View File

@@ -13,6 +13,7 @@ voice transcription, streaming responses, and tool usage.
import asyncio
import base64
import copy
import io
import time
import uuid
@@ -59,7 +60,7 @@ from pipecat.frames.frames import (
UserStoppedSpeakingFrame,
)
from pipecat.metrics.metrics import LLMTokenUsage
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_context import NOT_GIVEN, LLMContext, LLMContextDiff
from pipecat.processors.aggregators.llm_response import (
LLMAssistantAggregatorParams,
LLMUserAggregatorParams,
@@ -684,6 +685,7 @@ class GeminiLiveLLMService(LLMService):
self._audio_input_paused = start_audio_paused
self._video_input_paused = start_video_paused
self._context = None
self._context_snapshot = None
self._api_key = api_key
self._http_options = update_google_client_http_options(http_options)
self._session: AsyncSession = None
@@ -826,6 +828,9 @@ class GeminiLiveLLMService(LLMService):
logger.error("Context already set. Can only set up Gemini Live context once.")
return
self._context = GeminiLiveContext.upgrade(context)
self._context_snapshot = copy.deepcopy(
self._context
) # Take a static snapshot guaranteed not to change
await self._create_initial_response()
#
@@ -950,7 +955,13 @@ class GeminiLiveLLMService(LLMService):
elif isinstance(frame, LLMUpdateSettingsFrame):
await self._update_settings(frame.settings)
elif isinstance(frame, LLMSetToolsFrame):
await self._update_settings()
# TODO: you are here - setting tools doesn't work yet (requires reconnection)
# Do we have reference to previous tools to compare?
# New tools should already have been set on the context by the user aggregator.
# LLMSetToolsFrame without a user aggregator is not supported.
# If tools have changed, update context snapshot.
# await self._update_settings()
pass
else:
await self.push_frame(frame, direction)
@@ -958,6 +969,9 @@ class GeminiLiveLLMService(LLMService):
if not self._context:
# We got our initial context
self._context = context
self._context_snapshot = copy.deepcopy(
context
) # Take a static snapshot guaranteed not to change
# If context contains system instruction or tools, reconnect in
# order to apply them.
@@ -1011,17 +1025,64 @@ class GeminiLiveLLMService(LLMService):
# We got an updated context.
self._context = context
# Here we assume that the updated context will contain either:
# - new messages (that the Gemini Live service, with its own
# context management, is already aware of), or
# - tool call results (that we need to tell the remote service
# about).
# (In the future, we could do more sophisticated diffing here,
# which would enable the user to programmatically manipulate the
# context).
diff = self._context_snapshot.diff(self._context)
self._context_snapshot = copy.deepcopy(
context
) # Take a static snapshot guaranteed not to change
# Send results for newly-completed function calls, if any.
await self._process_completed_function_calls(send_new_results=True)
if self._context_update_requires_reconnect(diff):
# Reconnect
print("[pk] Context update requires reconnect. Reconnecting...")
# TODO: necessary?
self._session_resumption_handle = None
# TODO: do something special here to handle the context like it's the initial one again?
# Reconnect
await self._reconnect()
# Initialize our bookkeeping of already-completed tool calls in
# the context
await self._process_completed_function_calls(send_new_results=False)
# Trigger "initial" response with new connection
await self._create_initial_response()
else:
# Send results for newly-completed function calls, if any.
print(
"[pk] Context update does not require reconnect. Sending any newly-completed function call results..."
)
await self._process_completed_function_calls(send_new_results=True)
def _context_update_requires_reconnect(self, diff: LLMContextDiff) -> bool:
"""Check if an update to our LLM context requires reconnection.
Args:
diff: The context diff representing the update
Returns:
True if reconnection is required, False otherwise.
"""
# During the course of a "normal" conversation with no external context
# manipulation (like adding tools, changing system instructions,
# rewriting history), context updates will contain either:
# - new messages (that the Gemini Live service, with its own internal
# context management, is already aware of), or
# - tool call results (that we need to tell the remote service
# about).
# Any other changes to the context require reconnection.
# TODO: the below
# Note that it's possible that the developer is trying to
# programmatically append messages, in which case we'd miss that
# update...maybe we can check the number of appended messages...
if diff.history_edited or diff.tools_diff.has_changes() or diff.tool_choice_changed:
if diff.history_edited:
print("[pk] Context diff: history edited.")
if diff.tools_diff.has_changes():
print("[pk] Context diff: tools changed.")
if diff.tool_choice_changed:
print("[pk] Context diff: tool choice changed.")
return True
async def _process_completed_function_calls(self, send_new_results: bool):
# Check for set of completed function calls in the context
@@ -1446,6 +1507,7 @@ class GeminiLiveLLMService(LLMService):
@traced_gemini_live(operation="llm_setup")
async def _handle_session_ready(self, session: AsyncSession):
print("[pk] Handling session ready...")
"""Handle the session being ready."""
self._session = session
# If we were just waititng for the session to be ready to run the LLM,

View File

@@ -1055,3 +1055,285 @@ class TestLLMAssistantAggregator(
0,
"Hello Pipecat. Here's some code: ```python\nprint('Hello, World!')\n``` ```javascript\nconsole.log('Hello, World!');\n``` And some more: ```html\n<div>Hello, World!</div>\n``` Hope that helps!",
)
class TestLLMContextDiff(unittest.TestCase):
"""Tests for the LLMContext.diff() method."""
def test_diff_identical_contexts(self):
"""Test diff of two identical contexts returns no changes."""
messages = [{"role": "user", "content": "Hello"}]
context1 = LLMContext(messages=messages.copy())
context2 = LLMContext(messages=messages.copy())
diff = context1.diff(context2)
self.assertFalse(diff.has_changes())
self.assertEqual(diff.messages_appended, [])
self.assertFalse(diff.history_edited)
self.assertEqual(diff.tool_calls_resolved, [])
self.assertFalse(diff.tools_diff.has_changes())
self.assertFalse(diff.tool_choice_changed)
def test_diff_messages_appended(self):
"""Test diff detects appended messages."""
msg1 = {"role": "user", "content": "Hello"}
msg2 = {"role": "assistant", "content": "Hi there!"}
context1 = LLMContext(messages=[msg1])
context2 = LLMContext(messages=[msg1, msg2])
diff = context1.diff(context2)
self.assertTrue(diff.has_changes())
self.assertEqual(len(diff.messages_appended), 1)
self.assertEqual(diff.messages_appended[0], msg2)
self.assertFalse(diff.history_edited)
def test_diff_multiple_messages_appended(self):
"""Test diff detects multiple appended messages."""
msg1 = {"role": "user", "content": "Hello"}
msg2 = {"role": "assistant", "content": "Hi!"}
msg3 = {"role": "user", "content": "How are you?"}
context1 = LLMContext(messages=[msg1])
context2 = LLMContext(messages=[msg1, msg2, msg3])
diff = context1.diff(context2)
self.assertTrue(diff.has_changes())
self.assertEqual(len(diff.messages_appended), 2)
self.assertEqual(diff.messages_appended[0], msg2)
self.assertEqual(diff.messages_appended[1], msg3)
self.assertFalse(diff.history_edited)
def test_diff_message_removed(self):
"""Test diff detects message removal as history edit."""
msg1 = {"role": "user", "content": "Hello"}
msg2 = {"role": "assistant", "content": "Hi!"}
context1 = LLMContext(messages=[msg1, msg2])
context2 = LLMContext(messages=[msg1])
diff = context1.diff(context2)
self.assertTrue(diff.has_changes())
self.assertEqual(diff.messages_appended, []) # Empty when history edited
self.assertTrue(diff.history_edited)
def test_diff_message_modified(self):
"""Test diff detects message modification as history edit."""
msg1 = {"role": "user", "content": "Hello"}
msg2_v1 = {"role": "assistant", "content": "Hi!"}
msg2_v2 = {"role": "assistant", "content": "Hello there!"}
context1 = LLMContext(messages=[msg1, msg2_v1])
context2 = LLMContext(messages=[msg1, msg2_v2])
diff = context1.diff(context2)
self.assertTrue(diff.has_changes())
self.assertTrue(diff.history_edited)
self.assertEqual(diff.messages_appended, [])
def test_diff_message_inserted_in_middle(self):
"""Test diff detects message insertion in middle as history edit."""
msg1 = {"role": "user", "content": "Hello"}
msg2 = {"role": "assistant", "content": "Hi!"}
msg_inserted = {"role": "system", "content": "System message"}
context1 = LLMContext(messages=[msg1, msg2])
context2 = LLMContext(messages=[msg1, msg_inserted, msg2])
diff = context1.diff(context2)
self.assertTrue(diff.has_changes())
self.assertTrue(diff.history_edited)
self.assertEqual(diff.messages_appended, [])
def test_diff_tool_call_resolved_to_result(self):
"""Test diff detects tool call resolution to actual result."""
msg1 = {"role": "user", "content": "What's the weather?"}
msg2 = {
"role": "assistant",
"tool_calls": [
{"id": "call_123", "function": {"name": "get_weather", "arguments": "{}"}}
],
}
tool_in_progress = {"role": "tool", "content": "IN_PROGRESS", "tool_call_id": "call_123"}
tool_resolved = {
"role": "tool",
"content": '{"temperature": 72}',
"tool_call_id": "call_123",
}
context1 = LLMContext(messages=[msg1, msg2, tool_in_progress])
context2 = LLMContext(messages=[msg1, msg2, tool_resolved])
diff = context1.diff(context2)
self.assertTrue(diff.has_changes())
self.assertEqual(diff.tool_calls_resolved, ["call_123"])
# Note: the tool message content changed, so history is edited
self.assertTrue(diff.history_edited)
def test_diff_tool_call_resolved_to_completed(self):
"""Test diff detects tool call resolution to COMPLETED."""
msg1 = {"role": "user", "content": "Do something"}
tool_in_progress = {"role": "tool", "content": "IN_PROGRESS", "tool_call_id": "call_456"}
tool_completed = {"role": "tool", "content": "COMPLETED", "tool_call_id": "call_456"}
context1 = LLMContext(messages=[msg1, tool_in_progress])
context2 = LLMContext(messages=[msg1, tool_completed])
diff = context1.diff(context2)
self.assertTrue(diff.has_changes())
self.assertEqual(diff.tool_calls_resolved, ["call_456"])
def test_diff_tool_call_resolved_to_cancelled(self):
"""Test diff detects tool call resolution to CANCELLED."""
msg1 = {"role": "user", "content": "Do something"}
tool_in_progress = {"role": "tool", "content": "IN_PROGRESS", "tool_call_id": "call_789"}
tool_cancelled = {"role": "tool", "content": "CANCELLED", "tool_call_id": "call_789"}
context1 = LLMContext(messages=[msg1, tool_in_progress])
context2 = LLMContext(messages=[msg1, tool_cancelled])
diff = context1.diff(context2)
self.assertTrue(diff.has_changes())
self.assertEqual(diff.tool_calls_resolved, ["call_789"])
def test_diff_tool_call_still_in_progress(self):
"""Test diff does not report tool call as resolved if still IN_PROGRESS."""
msg1 = {"role": "user", "content": "Do something"}
tool_in_progress = {"role": "tool", "content": "IN_PROGRESS", "tool_call_id": "call_123"}
context1 = LLMContext(messages=[msg1, tool_in_progress])
context2 = LLMContext(messages=[msg1, tool_in_progress])
diff = context1.diff(context2)
self.assertFalse(diff.has_changes())
self.assertEqual(diff.tool_calls_resolved, [])
def test_diff_tool_choice_changed(self):
"""Test diff detects tool_choice changes."""
msg1 = {"role": "user", "content": "Hello"}
context1 = LLMContext(messages=[msg1], tool_choice="auto")
context2 = LLMContext(messages=[msg1], tool_choice="none")
diff = context1.diff(context2)
self.assertTrue(diff.has_changes())
self.assertTrue(diff.tool_choice_changed)
def test_diff_tool_choice_unchanged(self):
"""Test diff reports no change when tool_choice is the same."""
msg1 = {"role": "user", "content": "Hello"}
context1 = LLMContext(messages=[msg1], tool_choice="auto")
context2 = LLMContext(messages=[msg1], tool_choice="auto")
diff = context1.diff(context2)
self.assertFalse(diff.has_changes())
self.assertFalse(diff.tool_choice_changed)
def test_diff_empty_contexts(self):
"""Test diff of two empty contexts returns no changes."""
context1 = LLMContext()
context2 = LLMContext()
diff = context1.diff(context2)
self.assertFalse(diff.has_changes())
class TestLLMContextDiffWithTools(unittest.TestCase):
"""Tests for LLMContext.diff() with tools configuration changes."""
def _create_tools_schema(self, tool_names: list[str]) -> "ToolsSchema":
"""Helper to create a ToolsSchema with named tools."""
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
tools = [
FunctionSchema(name=name, description=f"Test {name}", properties={}, required=[])
for name in tool_names
]
return ToolsSchema(standard_tools=tools)
def test_diff_tools_added_from_not_given(self):
"""Test diff detects tools being added when self has no tools."""
from pipecat.processors.aggregators.llm_context import NOT_GIVEN
msg1 = {"role": "user", "content": "Hello"}
tools = self._create_tools_schema(["get_weather", "get_time"])
context1 = LLMContext(messages=[msg1], tools=NOT_GIVEN)
context2 = LLMContext(messages=[msg1], tools=tools)
diff = context1.diff(context2)
self.assertTrue(diff.has_changes())
self.assertEqual(sorted(diff.tools_diff.standard_tools_added), ["get_time", "get_weather"])
self.assertEqual(diff.tools_diff.standard_tools_removed, [])
def test_diff_tools_removed_to_not_given(self):
"""Test diff detects tools being removed when other has no tools."""
from pipecat.processors.aggregators.llm_context import NOT_GIVEN
msg1 = {"role": "user", "content": "Hello"}
tools = self._create_tools_schema(["get_weather", "get_time"])
context1 = LLMContext(messages=[msg1], tools=tools)
context2 = LLMContext(messages=[msg1], tools=NOT_GIVEN)
diff = context1.diff(context2)
self.assertTrue(diff.has_changes())
self.assertEqual(diff.tools_diff.standard_tools_added, [])
self.assertEqual(
sorted(diff.tools_diff.standard_tools_removed), ["get_time", "get_weather"]
)
def test_diff_both_not_given(self):
"""Test diff returns None tools_diff when both have no tools."""
from pipecat.processors.aggregators.llm_context import NOT_GIVEN
msg1 = {"role": "user", "content": "Hello"}
context1 = LLMContext(messages=[msg1], tools=NOT_GIVEN)
context2 = LLMContext(messages=[msg1], tools=NOT_GIVEN)
diff = context1.diff(context2)
self.assertFalse(diff.has_changes())
self.assertFalse(diff.tools_diff.has_changes())
def test_diff_tools_modified(self):
"""Test diff detects tool modification via ToolsSchema.diff()."""
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
msg1 = {"role": "user", "content": "Hello"}
tool_v1 = FunctionSchema(
name="get_weather",
description="Get weather v1",
properties={"location": {"type": "string"}},
required=["location"],
)
tool_v2 = FunctionSchema(
name="get_weather",
description="Get weather v2",
properties={"city": {"type": "string"}},
required=["city"],
)
context1 = LLMContext(messages=[msg1], tools=ToolsSchema(standard_tools=[tool_v1]))
context2 = LLMContext(messages=[msg1], tools=ToolsSchema(standard_tools=[tool_v2]))
diff = context1.diff(context2)
self.assertTrue(diff.has_changes())
self.assertTrue(diff.tools_diff.standard_tools_modified)
def test_diff_tools_unchanged(self):
"""Test diff returns None tools_diff when tools are identical."""
msg1 = {"role": "user", "content": "Hello"}
tools1 = self._create_tools_schema(["get_weather"])
tools2 = self._create_tools_schema(["get_weather"])
context1 = LLMContext(messages=[msg1], tools=tools1)
context2 = LLMContext(messages=[msg1], tools=tools2)
diff = context1.diff(context2)
self.assertFalse(diff.has_changes())
self.assertFalse(diff.tools_diff.has_changes())

184
tests/test_tools_schema.py Normal file
View File

@@ -0,0 +1,184 @@
#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import unittest
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import AdapterType, ToolsSchema, ToolsSchemaDiff
class TestToolsSchemaDiff(unittest.TestCase):
"""Tests for the ToolsSchemaDiff dataclass."""
def test_has_changes_empty(self):
"""Test has_changes returns False for empty diff."""
diff = ToolsSchemaDiff()
self.assertFalse(diff.has_changes())
def test_has_changes_with_added(self):
"""Test has_changes returns True when tools are added."""
diff = ToolsSchemaDiff(standard_tools_added=["tool1"])
self.assertTrue(diff.has_changes())
def test_has_changes_with_removed(self):
"""Test has_changes returns True when tools are removed."""
diff = ToolsSchemaDiff(standard_tools_removed=["tool1"])
self.assertTrue(diff.has_changes())
def test_has_changes_with_modified(self):
"""Test has_changes returns True when tools are modified."""
diff = ToolsSchemaDiff(standard_tools_modified=True)
self.assertTrue(diff.has_changes())
def test_has_changes_with_custom_changed(self):
"""Test has_changes returns True when custom tools changed."""
diff = ToolsSchemaDiff(custom_tools_changed=True)
self.assertTrue(diff.has_changes())
class TestToolsSchemaDiffMethod(unittest.TestCase):
"""Tests for the ToolsSchema.diff() method."""
def _create_function_schema(
self, name: str, description: str = "Test function", properties: dict = None
) -> FunctionSchema:
"""Helper to create a FunctionSchema."""
return FunctionSchema(
name=name,
description=description,
properties=properties or {},
required=[],
)
def test_diff_identical_schemas(self):
"""Test diff of two identical schemas returns no changes."""
tool1 = self._create_function_schema("get_weather")
schema1 = ToolsSchema(standard_tools=[tool1])
schema2 = ToolsSchema(standard_tools=[self._create_function_schema("get_weather")])
diff = schema1.diff(schema2)
self.assertFalse(diff.has_changes())
self.assertEqual(diff.standard_tools_added, [])
self.assertEqual(diff.standard_tools_removed, [])
self.assertFalse(diff.standard_tools_modified)
self.assertFalse(diff.custom_tools_changed)
def test_diff_tool_added(self):
"""Test diff detects added tools."""
tool1 = self._create_function_schema("get_weather")
tool2 = self._create_function_schema("get_time")
schema1 = ToolsSchema(standard_tools=[tool1])
schema2 = ToolsSchema(standard_tools=[tool1, tool2])
diff = schema1.diff(schema2)
self.assertTrue(diff.has_changes())
self.assertEqual(diff.standard_tools_added, ["get_time"])
self.assertEqual(diff.standard_tools_removed, [])
self.assertFalse(diff.standard_tools_modified)
def test_diff_tool_removed(self):
"""Test diff detects removed tools."""
tool1 = self._create_function_schema("get_weather")
tool2 = self._create_function_schema("get_time")
schema1 = ToolsSchema(standard_tools=[tool1, tool2])
schema2 = ToolsSchema(standard_tools=[tool1])
diff = schema1.diff(schema2)
self.assertTrue(diff.has_changes())
self.assertEqual(diff.standard_tools_added, [])
self.assertEqual(diff.standard_tools_removed, ["get_time"])
self.assertFalse(diff.standard_tools_modified)
def test_diff_tool_modified(self):
"""Test diff detects modified tools (same name, different definition)."""
tool1_v1 = self._create_function_schema(
"get_weather", description="Get weather v1", properties={"location": {"type": "string"}}
)
tool1_v2 = self._create_function_schema(
"get_weather",
description="Get weather v2",
properties={"city": {"type": "string"}},
)
schema1 = ToolsSchema(standard_tools=[tool1_v1])
schema2 = ToolsSchema(standard_tools=[tool1_v2])
diff = schema1.diff(schema2)
self.assertTrue(diff.has_changes())
self.assertEqual(diff.standard_tools_added, [])
self.assertEqual(diff.standard_tools_removed, [])
self.assertTrue(diff.standard_tools_modified)
def test_diff_multiple_changes(self):
"""Test diff with multiple types of changes."""
tool_keep = self._create_function_schema("keep_tool")
tool_remove = self._create_function_schema("remove_tool")
tool_add = self._create_function_schema("add_tool")
schema1 = ToolsSchema(standard_tools=[tool_keep, tool_remove])
schema2 = ToolsSchema(standard_tools=[tool_keep, tool_add])
diff = schema1.diff(schema2)
self.assertTrue(diff.has_changes())
self.assertEqual(diff.standard_tools_added, ["add_tool"])
self.assertEqual(diff.standard_tools_removed, ["remove_tool"])
def test_diff_empty_schemas(self):
"""Test diff of two empty schemas returns no changes."""
schema1 = ToolsSchema(standard_tools=[])
schema2 = ToolsSchema(standard_tools=[])
diff = schema1.diff(schema2)
self.assertFalse(diff.has_changes())
def test_diff_custom_tools_changed(self):
"""Test diff detects custom tools changes."""
tool1 = self._create_function_schema("get_weather")
custom1 = {AdapterType.GEMINI: [{"name": "search"}]}
custom2 = {AdapterType.GEMINI: [{"name": "search_v2"}]}
schema1 = ToolsSchema(standard_tools=[tool1], custom_tools=custom1)
schema2 = ToolsSchema(standard_tools=[tool1], custom_tools=custom2)
diff = schema1.diff(schema2)
self.assertTrue(diff.has_changes())
self.assertTrue(diff.custom_tools_changed)
# Standard tools unchanged
self.assertEqual(diff.standard_tools_added, [])
self.assertEqual(diff.standard_tools_removed, [])
self.assertFalse(diff.standard_tools_modified)
def test_diff_custom_tools_added(self):
"""Test diff detects custom tools being added."""
tool1 = self._create_function_schema("get_weather")
schema1 = ToolsSchema(standard_tools=[tool1])
schema2 = ToolsSchema(
standard_tools=[tool1], custom_tools={AdapterType.GEMINI: [{"name": "search"}]}
)
diff = schema1.diff(schema2)
self.assertTrue(diff.has_changes())
self.assertTrue(diff.custom_tools_changed)
def test_diff_custom_tools_removed(self):
"""Test diff detects custom tools being removed."""
tool1 = self._create_function_schema("get_weather")
schema1 = ToolsSchema(
standard_tools=[tool1], custom_tools={AdapterType.GEMINI: [{"name": "search"}]}
)
schema2 = ToolsSchema(standard_tools=[tool1])
diff = schema1.diff(schema2)
self.assertTrue(diff.has_changes())
self.assertTrue(diff.custom_tools_changed)
if __name__ == "__main__":
unittest.main()