Compare commits
3 Commits
pk/optiona
...
pk/flows-r
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bcaa3164f6 | ||
|
|
2c16faa662 | ||
|
|
bbd7f5e033 |
@@ -129,4 +129,42 @@ class BaseLLMAdapter(ABC, Generic[TLLMInvocationParams]):
|
||||
# Fallback to return the same tools in case they are not in a standard format
|
||||
return tools
|
||||
|
||||
def _warn_about_orphaned_tool_messages(self, context: LLMContext) -> None:
|
||||
"""Warn if context contains messages referencing tools that aren't currently available.
|
||||
|
||||
This can happen when tools are removed/deactivated but the conversation history
|
||||
still contains function calls or tool responses for those tools. Such orphaned
|
||||
messages may cause API errors from the LLM provider.
|
||||
|
||||
Args:
|
||||
context: The LLM context to check.
|
||||
"""
|
||||
# Get the set of currently available tool names
|
||||
available_tool_names: set[str] = set()
|
||||
if isinstance(context.tools, ToolsSchema):
|
||||
available_tool_names = {tool.name for tool in context.tools.standard_tools}
|
||||
# Note: We don't check custom tools as they may have varying formats
|
||||
|
||||
# Track orphaned function names found in messages
|
||||
orphaned_tools: set[str] = set()
|
||||
|
||||
for message in self.get_messages(context):
|
||||
if isinstance(message, LLMSpecificMessage):
|
||||
# Skip LLM-specific messages for now
|
||||
continue
|
||||
|
||||
# Check for tool_calls in assistant messages
|
||||
if message.get("tool_calls"):
|
||||
for tc in message["tool_calls"]:
|
||||
func_name = tc.get("function", {}).get("name")
|
||||
if func_name and available_tool_names and func_name not in available_tool_names:
|
||||
orphaned_tools.add(func_name)
|
||||
|
||||
# Log warning for orphaned messages
|
||||
if orphaned_tools:
|
||||
logger.warning(
|
||||
f"Context contains references to tools that are no longer available: "
|
||||
f"{sorted(orphaned_tools)}. This may cause unexpected behavior or API errors."
|
||||
)
|
||||
|
||||
# TODO: we can move the logic to also handle the Messages here
|
||||
|
||||
@@ -10,6 +10,7 @@ This module provides schemas for managing both standardized function tools
|
||||
and custom adapter-specific tools in the Pipecat framework.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
@@ -17,6 +18,53 @@ from pipecat.adapters.schemas.direct_function import DirectFunction, DirectFunct
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolsSchemaDiff:
|
||||
"""Represents the differences between two ToolsSchema instances.
|
||||
|
||||
Parameters:
|
||||
standard_tools_added: Names of newly added standard tools.
|
||||
standard_tools_removed: Names of removed standard tools.
|
||||
standard_tools_modified: True if any existing standard tool's definition changed.
|
||||
custom_tools_changed: True if the custom_tools dictionary differs.
|
||||
"""
|
||||
|
||||
standard_tools_added: List[str] = field(default_factory=list)
|
||||
standard_tools_removed: List[str] = field(default_factory=list)
|
||||
standard_tools_modified: bool = False
|
||||
custom_tools_changed: bool = False
|
||||
|
||||
def has_changes(self) -> bool:
|
||||
"""Check if there are any differences.
|
||||
|
||||
Returns:
|
||||
True if any field indicates a change, False otherwise.
|
||||
"""
|
||||
return bool(
|
||||
self.standard_tools_added
|
||||
or self.standard_tools_removed
|
||||
or self.standard_tools_modified
|
||||
or self.custom_tools_changed
|
||||
)
|
||||
|
||||
def change_description(self) -> str:
|
||||
"""Generate a human-readable description of the differences.
|
||||
|
||||
Returns:
|
||||
A string summarizing the changes.
|
||||
"""
|
||||
changes = []
|
||||
if self.standard_tools_added:
|
||||
changes.append(f"Added standard tools: {', '.join(self.standard_tools_added)}")
|
||||
if self.standard_tools_removed:
|
||||
changes.append(f"Removed standard tools: {', '.join(self.standard_tools_removed)}")
|
||||
if self.standard_tools_modified:
|
||||
changes.append("Modified definitions of existing standard tools")
|
||||
if self.custom_tools_changed:
|
||||
changes.append("Custom tools changed")
|
||||
return "; ".join(changes) if changes else "No changes"
|
||||
|
||||
|
||||
class AdapterType(Enum):
|
||||
"""Supported adapter types for custom tools.
|
||||
|
||||
@@ -92,3 +140,44 @@ class ToolsSchema:
|
||||
value: Dictionary mapping adapter types to their custom tool definitions.
|
||||
"""
|
||||
self._custom_tools = value
|
||||
|
||||
def diff(self, other: "ToolsSchema") -> ToolsSchemaDiff:
|
||||
"""Compare this ToolsSchema to another and return the differences.
|
||||
|
||||
Args:
|
||||
other: The ToolsSchema to compare against (the "after" state).
|
||||
|
||||
Returns:
|
||||
ToolsSchemaDiff containing the differences between self and other.
|
||||
"""
|
||||
result = ToolsSchemaDiff()
|
||||
|
||||
# Build maps of tool name -> FunctionSchema for comparison
|
||||
self_tools_by_name: Dict[str, FunctionSchema] = {
|
||||
tool.name: tool for tool in self._standard_tools
|
||||
}
|
||||
other_tools_by_name: Dict[str, FunctionSchema] = {
|
||||
tool.name: tool for tool in other._standard_tools
|
||||
}
|
||||
|
||||
self_names = set(self_tools_by_name.keys())
|
||||
other_names = set(other_tools_by_name.keys())
|
||||
|
||||
# Find added and removed tools
|
||||
result.standard_tools_added = sorted(other_names - self_names)
|
||||
result.standard_tools_removed = sorted(self_names - other_names)
|
||||
|
||||
# Check for modified tools (same name, different definition)
|
||||
common_names = self_names & other_names
|
||||
for name in common_names:
|
||||
self_tool = self_tools_by_name[name]
|
||||
other_tool = other_tools_by_name[name]
|
||||
# Compare using to_default_dict() for full schema comparison
|
||||
if self_tool.to_default_dict() != other_tool.to_default_dict():
|
||||
result.standard_tools_modified = True
|
||||
break
|
||||
|
||||
# Compare custom tools
|
||||
result.custom_tools_changed = self._custom_tools != other._custom_tools
|
||||
|
||||
return result
|
||||
|
||||
@@ -59,6 +59,9 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
|
||||
Returns:
|
||||
Dictionary of parameters for invoking Anthropic's LLM API.
|
||||
"""
|
||||
# Warn about orphaned tool-related messages
|
||||
self._warn_about_orphaned_tool_messages(context)
|
||||
|
||||
messages = self._from_universal_context_messages(self.get_messages(context))
|
||||
return {
|
||||
"system": messages.system,
|
||||
|
||||
@@ -83,6 +83,9 @@ class AWSNovaSonicLLMAdapter(BaseLLMAdapter[AWSNovaSonicLLMInvocationParams]):
|
||||
Returns:
|
||||
Dictionary of parameters for invoking AWS Nova Sonic's LLM API.
|
||||
"""
|
||||
# Warn about orphaned tool-related messages
|
||||
self._warn_about_orphaned_tool_messages(context)
|
||||
|
||||
messages = self._from_universal_context_messages(self.get_messages(context))
|
||||
return {
|
||||
"system_instruction": messages.system_instruction,
|
||||
|
||||
@@ -56,6 +56,9 @@ class AWSBedrockLLMAdapter(BaseLLMAdapter[AWSBedrockLLMInvocationParams]):
|
||||
Returns:
|
||||
Dictionary of parameters for invoking AWS Bedrock's LLM API.
|
||||
"""
|
||||
# Warn about orphaned tool-related messages
|
||||
self._warn_about_orphaned_tool_messages(context)
|
||||
|
||||
messages = self._from_universal_context_messages(self.get_messages(context))
|
||||
return {
|
||||
"system": messages.system,
|
||||
|
||||
@@ -62,6 +62,9 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
Returns:
|
||||
Dictionary of parameters for Gemini's API.
|
||||
"""
|
||||
# Warn about orphaned tool-related messages
|
||||
self._warn_about_orphaned_tool_messages(context)
|
||||
|
||||
messages = self._from_universal_context_messages(self.get_messages(context))
|
||||
return {
|
||||
"system_instruction": messages.system_instruction,
|
||||
|
||||
@@ -59,6 +59,9 @@ class GrokRealtimeLLMAdapter(BaseLLMAdapter):
|
||||
Returns:
|
||||
Dictionary of parameters for invoking Grok's Voice Agent API.
|
||||
"""
|
||||
# Warn about orphaned tool-related messages
|
||||
self._warn_about_orphaned_tool_messages(context)
|
||||
|
||||
messages = self._from_universal_context_messages(self.get_messages(context))
|
||||
return {
|
||||
"system_instruction": messages.system_instruction,
|
||||
|
||||
@@ -60,6 +60,9 @@ class OpenAILLMAdapter(BaseLLMAdapter[OpenAILLMInvocationParams]):
|
||||
Returns:
|
||||
Dictionary of parameters for OpenAI's ChatCompletion API.
|
||||
"""
|
||||
# Warn about orphaned tool-related messages
|
||||
self._warn_about_orphaned_tool_messages(context)
|
||||
|
||||
return {
|
||||
"messages": self._from_universal_context_messages(self.get_messages(context)),
|
||||
# NOTE; LLMContext's tools are guaranteed to be a ToolsSchema (or NOT_GIVEN)
|
||||
|
||||
@@ -54,6 +54,9 @@ class OpenAIRealtimeLLMAdapter(BaseLLMAdapter):
|
||||
Returns:
|
||||
Dictionary of parameters for invoking OpenAI Realtime's API.
|
||||
"""
|
||||
# Warn about orphaned tool-related messages
|
||||
self._warn_about_orphaned_tool_messages(context)
|
||||
|
||||
messages = self._from_universal_context_messages(self.get_messages(context))
|
||||
return {
|
||||
"system_instruction": messages.system_instruction,
|
||||
|
||||
@@ -18,8 +18,8 @@ import asyncio
|
||||
import base64
|
||||
import io
|
||||
import wave
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, TypeAlias, Union
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypeAlias, Union
|
||||
|
||||
from loguru import logger
|
||||
from openai._types import NOT_GIVEN as OPEN_AI_NOT_GIVEN
|
||||
@@ -30,7 +30,7 @@ from openai.types.chat import (
|
||||
)
|
||||
from PIL import Image
|
||||
|
||||
from pipecat.adapters.schemas.tools_schema import AdapterType, ToolsSchema
|
||||
from pipecat.adapters.schemas.tools_schema import AdapterType, ToolsSchema, ToolsSchemaDiff
|
||||
from pipecat.frames.frames import AudioRawFrame
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -47,6 +47,58 @@ NOT_GIVEN = OPEN_AI_NOT_GIVEN
|
||||
NotGiven = OpenAINotGiven
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMContextDiff:
|
||||
"""Represents the differences between two LLMContext instances.
|
||||
|
||||
Parameters:
|
||||
messages_appended: New messages appended at the end. Empty if history_edited is True.
|
||||
history_edited: True if earlier messages were changed, inserted, or removed.
|
||||
tool_calls_resolved: List of tool_call_ids that changed from "IN_PROGRESS" to a result.
|
||||
tools_diff: Differences in tools configuration, or None if unchanged or both NOT_GIVEN.
|
||||
tool_choice_changed: True if the tool_choice setting differs.
|
||||
"""
|
||||
|
||||
messages_appended: List["LLMContextMessage"] = field(default_factory=list)
|
||||
history_edited: bool = False
|
||||
tool_calls_resolved: List[str] = field(default_factory=list)
|
||||
tools_diff: ToolsSchemaDiff = field(default_factory=ToolsSchemaDiff)
|
||||
tool_choice_changed: bool = False
|
||||
|
||||
def has_changes(self) -> bool:
|
||||
"""Check if there are any differences.
|
||||
|
||||
Returns:
|
||||
True if any field indicates a change, False otherwise.
|
||||
"""
|
||||
return bool(
|
||||
self.messages_appended
|
||||
or self.history_edited
|
||||
or self.tool_calls_resolved
|
||||
or self.tools_diff.has_changes()
|
||||
or self.tool_choice_changed
|
||||
)
|
||||
|
||||
def change_description(self) -> str:
|
||||
"""Get a human-readable description of the changes.
|
||||
|
||||
Returns:
|
||||
A string summarizing the changes.
|
||||
"""
|
||||
changes = []
|
||||
if self.history_edited:
|
||||
changes.append("history edited")
|
||||
if self.messages_appended:
|
||||
changes.append(f"{len(self.messages_appended)} messages appended")
|
||||
if self.tool_calls_resolved:
|
||||
changes.append(f"{len(self.tool_calls_resolved)} tool calls resolved")
|
||||
if self.tools_diff.has_changes():
|
||||
changes.append(f"tools changed: ({self.tools_diff.change_description()})")
|
||||
if self.tool_choice_changed:
|
||||
changes.append("tool choice changed")
|
||||
return ", ".join(changes) if changes else "no changes"
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMSpecificMessage:
|
||||
"""A container for a context message that is specific to a particular LLM service.
|
||||
@@ -410,3 +462,138 @@ class LLMContext:
|
||||
raise TypeError(
|
||||
f"In LLMContext, tools must be a ToolsSchema object or NOT_GIVEN. Got type: {type(tools)}",
|
||||
)
|
||||
|
||||
def diff(self, other: "LLMContext") -> LLMContextDiff:
|
||||
"""Compare this context to another and return the differences.
|
||||
|
||||
Compares self (the "before" state) to other (the "after" state) and
|
||||
identifies what has changed.
|
||||
|
||||
Args:
|
||||
other: The LLMContext to compare against (the "after" state).
|
||||
|
||||
Returns:
|
||||
ContextDiff containing the differences between self and other.
|
||||
"""
|
||||
result = LLMContextDiff()
|
||||
|
||||
# Compare messages
|
||||
self_messages = self._messages
|
||||
other_messages = other._messages
|
||||
self_len = len(self_messages)
|
||||
other_len = len(other_messages)
|
||||
|
||||
# Check if history was edited (messages removed, modified, or inserted in the middle)
|
||||
if other_len < self_len:
|
||||
# Messages were removed
|
||||
result.history_edited = True
|
||||
else:
|
||||
# Check if the prefix matches (first self_len messages should be identical)
|
||||
for i in range(self_len):
|
||||
if not self._messages_equal(self_messages[i], other_messages[i]):
|
||||
result.history_edited = True
|
||||
break
|
||||
|
||||
# If history wasn't edited, capture appended messages
|
||||
if not result.history_edited and other_len > self_len:
|
||||
result.messages_appended = other_messages[self_len:]
|
||||
|
||||
# Find resolved tool calls (IN_PROGRESS -> something else)
|
||||
result.tool_calls_resolved = self._find_resolved_tool_calls(other)
|
||||
|
||||
# Compare tools
|
||||
result.tools_diff = self._compute_tools_diff(other)
|
||||
|
||||
# Compare tool_choice
|
||||
# print("[pk] comparing tool choices: ", self._tool_choice, other._tool_choice, type(self._tool_choice), type(other._tool_choice), self._tool_choice == other._tool_choice, self._tool_choice != other._tool_choice)
|
||||
# (For some reason if they're both NOT_GIVEN, equality check returns False?)
|
||||
if not self._tool_choice and not other._tool_choice:
|
||||
result.tool_choice_changed = False
|
||||
else:
|
||||
result.tool_choice_changed = self._tool_choice != other._tool_choice
|
||||
|
||||
return result
|
||||
|
||||
def _messages_equal(self, msg1: LLMContextMessage, msg2: LLMContextMessage) -> bool:
|
||||
"""Compare two messages for equality.
|
||||
|
||||
Args:
|
||||
msg1: First message to compare.
|
||||
msg2: Second message to compare.
|
||||
|
||||
Returns:
|
||||
True if the messages are equal, False otherwise.
|
||||
"""
|
||||
# Handle LLMSpecificMessage
|
||||
if isinstance(msg1, LLMSpecificMessage) and isinstance(msg2, LLMSpecificMessage):
|
||||
return msg1.llm == msg2.llm and msg1.message == msg2.message
|
||||
elif isinstance(msg1, LLMSpecificMessage) or isinstance(msg2, LLMSpecificMessage):
|
||||
return False
|
||||
|
||||
# Both are standard messages (dicts)
|
||||
return msg1 == msg2
|
||||
|
||||
def _find_resolved_tool_calls(self, other: "LLMContext") -> List[str]:
|
||||
"""Find tool calls that changed from IN_PROGRESS to a resolved state.
|
||||
|
||||
Args:
|
||||
other: The context to compare against (the "after" state).
|
||||
|
||||
Returns:
|
||||
List of tool_call_ids that were resolved.
|
||||
"""
|
||||
resolved: List[str] = []
|
||||
|
||||
# Build a map of tool_call_id -> content for "other" context
|
||||
other_tool_contents: Dict[str, Any] = {}
|
||||
for msg in other._messages:
|
||||
if isinstance(msg, dict) and msg.get("role") == "tool":
|
||||
tool_call_id = msg.get("tool_call_id")
|
||||
if tool_call_id:
|
||||
other_tool_contents[tool_call_id] = msg.get("content")
|
||||
|
||||
# Find tool messages in self that are IN_PROGRESS but resolved in other
|
||||
for msg in self._messages:
|
||||
if isinstance(msg, dict) and msg.get("role") == "tool":
|
||||
if msg.get("content") == "IN_PROGRESS":
|
||||
tool_call_id = msg.get("tool_call_id")
|
||||
if tool_call_id and tool_call_id in other_tool_contents:
|
||||
other_content = other_tool_contents[tool_call_id]
|
||||
if other_content != "IN_PROGRESS":
|
||||
resolved.append(tool_call_id)
|
||||
|
||||
return resolved
|
||||
|
||||
def _compute_tools_diff(self, other: "LLMContext") -> ToolsSchemaDiff:
|
||||
"""Compute the difference in tools between self and other.
|
||||
|
||||
Args:
|
||||
other: The context to compare against (the "after" state).
|
||||
|
||||
Returns:
|
||||
ToolsSchemaDiff if there are changes, None if both are NOT_GIVEN or identical.
|
||||
"""
|
||||
self_has_tools = isinstance(self._tools, ToolsSchema)
|
||||
other_has_tools = isinstance(other._tools, ToolsSchema)
|
||||
|
||||
if not self_has_tools and not other_has_tools:
|
||||
# Both are NOT_GIVEN
|
||||
return ToolsSchemaDiff()
|
||||
|
||||
if self_has_tools and other_has_tools:
|
||||
# Both have tools - use ToolsSchema.diff()
|
||||
diff = self._tools.diff(other._tools)
|
||||
return diff
|
||||
|
||||
if not self_has_tools and other_has_tools:
|
||||
# Tools were added (self is NOT_GIVEN, other has tools)
|
||||
return ToolsSchemaDiff(
|
||||
standard_tools_added=[tool.name for tool in other._tools.standard_tools],
|
||||
custom_tools_changed=other._tools.custom_tools is not None,
|
||||
)
|
||||
|
||||
# Tools were removed (self has tools, other is NOT_GIVEN)
|
||||
return ToolsSchemaDiff(
|
||||
standard_tools_removed=[tool.name for tool in self._tools.standard_tools],
|
||||
custom_tools_changed=self._tools.custom_tools is not None,
|
||||
)
|
||||
|
||||
@@ -13,6 +13,7 @@ voice transcription, streaming responses, and tool usage.
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import copy
|
||||
import io
|
||||
import time
|
||||
import uuid
|
||||
@@ -59,7 +60,7 @@ from pipecat.frames.frames import (
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_context import NOT_GIVEN, LLMContext, LLMContextDiff
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMUserAggregatorParams,
|
||||
@@ -684,6 +685,7 @@ class GeminiLiveLLMService(LLMService):
|
||||
self._audio_input_paused = start_audio_paused
|
||||
self._video_input_paused = start_video_paused
|
||||
self._context = None
|
||||
self._context_snapshot = None
|
||||
self._api_key = api_key
|
||||
self._http_options = update_google_client_http_options(http_options)
|
||||
self._session: AsyncSession = None
|
||||
@@ -826,6 +828,9 @@ class GeminiLiveLLMService(LLMService):
|
||||
logger.error("Context already set. Can only set up Gemini Live context once.")
|
||||
return
|
||||
self._context = GeminiLiveContext.upgrade(context)
|
||||
self._context_snapshot = copy.deepcopy(
|
||||
self._context
|
||||
) # Take a static snapshot guaranteed not to change
|
||||
await self._create_initial_response()
|
||||
|
||||
#
|
||||
@@ -950,7 +955,13 @@ class GeminiLiveLLMService(LLMService):
|
||||
elif isinstance(frame, LLMUpdateSettingsFrame):
|
||||
await self._update_settings(frame.settings)
|
||||
elif isinstance(frame, LLMSetToolsFrame):
|
||||
await self._update_settings()
|
||||
# TODO: you are here - setting tools doesn't work yet (requires reconnection)
|
||||
# Do we have reference to previous tools to compare?
|
||||
# New tools should already have been set on the context by the user aggregator.
|
||||
# LLMSetToolsFrame without a user aggregator is not supported.
|
||||
# If tools have changed, update context snapshot.
|
||||
# await self._update_settings()
|
||||
pass
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -958,6 +969,9 @@ class GeminiLiveLLMService(LLMService):
|
||||
if not self._context:
|
||||
# We got our initial context
|
||||
self._context = context
|
||||
self._context_snapshot = copy.deepcopy(
|
||||
context
|
||||
) # Take a static snapshot guaranteed not to change
|
||||
|
||||
# If context contains system instruction or tools, reconnect in
|
||||
# order to apply them.
|
||||
@@ -1011,17 +1025,64 @@ class GeminiLiveLLMService(LLMService):
|
||||
# We got an updated context.
|
||||
self._context = context
|
||||
|
||||
# Here we assume that the updated context will contain either:
|
||||
# - new messages (that the Gemini Live service, with its own
|
||||
# context management, is already aware of), or
|
||||
# - tool call results (that we need to tell the remote service
|
||||
# about).
|
||||
# (In the future, we could do more sophisticated diffing here,
|
||||
# which would enable the user to programmatically manipulate the
|
||||
# context).
|
||||
diff = self._context_snapshot.diff(self._context)
|
||||
self._context_snapshot = copy.deepcopy(
|
||||
context
|
||||
) # Take a static snapshot guaranteed not to change
|
||||
|
||||
# Send results for newly-completed function calls, if any.
|
||||
await self._process_completed_function_calls(send_new_results=True)
|
||||
if self._context_update_requires_reconnect(diff):
|
||||
# Reconnect
|
||||
print("[pk] Context update requires reconnect. Reconnecting...")
|
||||
|
||||
# TODO: necessary?
|
||||
self._session_resumption_handle = None
|
||||
|
||||
# TODO: do something special here to handle the context like it's the initial one again?
|
||||
|
||||
# Reconnect
|
||||
await self._reconnect()
|
||||
|
||||
# Initialize our bookkeeping of already-completed tool calls in
|
||||
# the context
|
||||
await self._process_completed_function_calls(send_new_results=False)
|
||||
|
||||
# Trigger "initial" response with new connection
|
||||
await self._create_initial_response()
|
||||
else:
|
||||
# Send results for newly-completed function calls, if any.
|
||||
print(
|
||||
"[pk] Context update does not require reconnect. Sending any newly-completed function call results..."
|
||||
)
|
||||
await self._process_completed_function_calls(send_new_results=True)
|
||||
|
||||
def _context_update_requires_reconnect(self, diff: LLMContextDiff) -> bool:
|
||||
"""Check if an update to our LLM context requires reconnection.
|
||||
|
||||
Args:
|
||||
diff: The context diff representing the update
|
||||
Returns:
|
||||
True if reconnection is required, False otherwise.
|
||||
"""
|
||||
# During the course of a "normal" conversation with no external context
|
||||
# manipulation (like adding tools, changing system instructions,
|
||||
# rewriting history), context updates will contain either:
|
||||
# - new messages (that the Gemini Live service, with its own internal
|
||||
# context management, is already aware of), or
|
||||
# - tool call results (that we need to tell the remote service
|
||||
# about).
|
||||
# Any other changes to the context require reconnection.
|
||||
# TODO: the below
|
||||
# Note that it's possible that the developer is trying to
|
||||
# programmatically append messages, in which case we'd miss that
|
||||
# update...maybe we can check the number of appended messages...
|
||||
if diff.history_edited or diff.tools_diff.has_changes() or diff.tool_choice_changed:
|
||||
if diff.history_edited:
|
||||
print("[pk] Context diff: history edited.")
|
||||
if diff.tools_diff.has_changes():
|
||||
print("[pk] Context diff: tools changed.")
|
||||
if diff.tool_choice_changed:
|
||||
print("[pk] Context diff: tool choice changed.")
|
||||
return True
|
||||
|
||||
async def _process_completed_function_calls(self, send_new_results: bool):
|
||||
# Check for set of completed function calls in the context
|
||||
@@ -1446,6 +1507,7 @@ class GeminiLiveLLMService(LLMService):
|
||||
|
||||
@traced_gemini_live(operation="llm_setup")
|
||||
async def _handle_session_ready(self, session: AsyncSession):
|
||||
print("[pk] Handling session ready...")
|
||||
"""Handle the session being ready."""
|
||||
self._session = session
|
||||
# If we were just waititng for the session to be ready to run the LLM,
|
||||
|
||||
@@ -1055,3 +1055,285 @@ class TestLLMAssistantAggregator(
|
||||
0,
|
||||
"Hello Pipecat. Here's some code: ```python\nprint('Hello, World!')\n``` ```javascript\nconsole.log('Hello, World!');\n``` And some more: ```html\n<div>Hello, World!</div>\n``` Hope that helps!",
|
||||
)
|
||||
|
||||
|
||||
class TestLLMContextDiff(unittest.TestCase):
|
||||
"""Tests for the LLMContext.diff() method."""
|
||||
|
||||
def test_diff_identical_contexts(self):
|
||||
"""Test diff of two identical contexts returns no changes."""
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
context1 = LLMContext(messages=messages.copy())
|
||||
context2 = LLMContext(messages=messages.copy())
|
||||
|
||||
diff = context1.diff(context2)
|
||||
self.assertFalse(diff.has_changes())
|
||||
self.assertEqual(diff.messages_appended, [])
|
||||
self.assertFalse(diff.history_edited)
|
||||
self.assertEqual(diff.tool_calls_resolved, [])
|
||||
self.assertFalse(diff.tools_diff.has_changes())
|
||||
self.assertFalse(diff.tool_choice_changed)
|
||||
|
||||
def test_diff_messages_appended(self):
|
||||
"""Test diff detects appended messages."""
|
||||
msg1 = {"role": "user", "content": "Hello"}
|
||||
msg2 = {"role": "assistant", "content": "Hi there!"}
|
||||
|
||||
context1 = LLMContext(messages=[msg1])
|
||||
context2 = LLMContext(messages=[msg1, msg2])
|
||||
|
||||
diff = context1.diff(context2)
|
||||
self.assertTrue(diff.has_changes())
|
||||
self.assertEqual(len(diff.messages_appended), 1)
|
||||
self.assertEqual(diff.messages_appended[0], msg2)
|
||||
self.assertFalse(diff.history_edited)
|
||||
|
||||
def test_diff_multiple_messages_appended(self):
|
||||
"""Test diff detects multiple appended messages."""
|
||||
msg1 = {"role": "user", "content": "Hello"}
|
||||
msg2 = {"role": "assistant", "content": "Hi!"}
|
||||
msg3 = {"role": "user", "content": "How are you?"}
|
||||
|
||||
context1 = LLMContext(messages=[msg1])
|
||||
context2 = LLMContext(messages=[msg1, msg2, msg3])
|
||||
|
||||
diff = context1.diff(context2)
|
||||
self.assertTrue(diff.has_changes())
|
||||
self.assertEqual(len(diff.messages_appended), 2)
|
||||
self.assertEqual(diff.messages_appended[0], msg2)
|
||||
self.assertEqual(diff.messages_appended[1], msg3)
|
||||
self.assertFalse(diff.history_edited)
|
||||
|
||||
def test_diff_message_removed(self):
|
||||
"""Test diff detects message removal as history edit."""
|
||||
msg1 = {"role": "user", "content": "Hello"}
|
||||
msg2 = {"role": "assistant", "content": "Hi!"}
|
||||
|
||||
context1 = LLMContext(messages=[msg1, msg2])
|
||||
context2 = LLMContext(messages=[msg1])
|
||||
|
||||
diff = context1.diff(context2)
|
||||
self.assertTrue(diff.has_changes())
|
||||
self.assertEqual(diff.messages_appended, []) # Empty when history edited
|
||||
self.assertTrue(diff.history_edited)
|
||||
|
||||
def test_diff_message_modified(self):
|
||||
"""Test diff detects message modification as history edit."""
|
||||
msg1 = {"role": "user", "content": "Hello"}
|
||||
msg2_v1 = {"role": "assistant", "content": "Hi!"}
|
||||
msg2_v2 = {"role": "assistant", "content": "Hello there!"}
|
||||
|
||||
context1 = LLMContext(messages=[msg1, msg2_v1])
|
||||
context2 = LLMContext(messages=[msg1, msg2_v2])
|
||||
|
||||
diff = context1.diff(context2)
|
||||
self.assertTrue(diff.has_changes())
|
||||
self.assertTrue(diff.history_edited)
|
||||
self.assertEqual(diff.messages_appended, [])
|
||||
|
||||
def test_diff_message_inserted_in_middle(self):
|
||||
"""Test diff detects message insertion in middle as history edit."""
|
||||
msg1 = {"role": "user", "content": "Hello"}
|
||||
msg2 = {"role": "assistant", "content": "Hi!"}
|
||||
msg_inserted = {"role": "system", "content": "System message"}
|
||||
|
||||
context1 = LLMContext(messages=[msg1, msg2])
|
||||
context2 = LLMContext(messages=[msg1, msg_inserted, msg2])
|
||||
|
||||
diff = context1.diff(context2)
|
||||
self.assertTrue(diff.has_changes())
|
||||
self.assertTrue(diff.history_edited)
|
||||
self.assertEqual(diff.messages_appended, [])
|
||||
|
||||
def test_diff_tool_call_resolved_to_result(self):
|
||||
"""Test diff detects tool call resolution to actual result."""
|
||||
msg1 = {"role": "user", "content": "What's the weather?"}
|
||||
msg2 = {
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{"id": "call_123", "function": {"name": "get_weather", "arguments": "{}"}}
|
||||
],
|
||||
}
|
||||
tool_in_progress = {"role": "tool", "content": "IN_PROGRESS", "tool_call_id": "call_123"}
|
||||
tool_resolved = {
|
||||
"role": "tool",
|
||||
"content": '{"temperature": 72}',
|
||||
"tool_call_id": "call_123",
|
||||
}
|
||||
|
||||
context1 = LLMContext(messages=[msg1, msg2, tool_in_progress])
|
||||
context2 = LLMContext(messages=[msg1, msg2, tool_resolved])
|
||||
|
||||
diff = context1.diff(context2)
|
||||
self.assertTrue(diff.has_changes())
|
||||
self.assertEqual(diff.tool_calls_resolved, ["call_123"])
|
||||
# Note: the tool message content changed, so history is edited
|
||||
self.assertTrue(diff.history_edited)
|
||||
|
||||
def test_diff_tool_call_resolved_to_completed(self):
|
||||
"""Test diff detects tool call resolution to COMPLETED."""
|
||||
msg1 = {"role": "user", "content": "Do something"}
|
||||
tool_in_progress = {"role": "tool", "content": "IN_PROGRESS", "tool_call_id": "call_456"}
|
||||
tool_completed = {"role": "tool", "content": "COMPLETED", "tool_call_id": "call_456"}
|
||||
|
||||
context1 = LLMContext(messages=[msg1, tool_in_progress])
|
||||
context2 = LLMContext(messages=[msg1, tool_completed])
|
||||
|
||||
diff = context1.diff(context2)
|
||||
self.assertTrue(diff.has_changes())
|
||||
self.assertEqual(diff.tool_calls_resolved, ["call_456"])
|
||||
|
||||
def test_diff_tool_call_resolved_to_cancelled(self):
|
||||
"""Test diff detects tool call resolution to CANCELLED."""
|
||||
msg1 = {"role": "user", "content": "Do something"}
|
||||
tool_in_progress = {"role": "tool", "content": "IN_PROGRESS", "tool_call_id": "call_789"}
|
||||
tool_cancelled = {"role": "tool", "content": "CANCELLED", "tool_call_id": "call_789"}
|
||||
|
||||
context1 = LLMContext(messages=[msg1, tool_in_progress])
|
||||
context2 = LLMContext(messages=[msg1, tool_cancelled])
|
||||
|
||||
diff = context1.diff(context2)
|
||||
self.assertTrue(diff.has_changes())
|
||||
self.assertEqual(diff.tool_calls_resolved, ["call_789"])
|
||||
|
||||
def test_diff_tool_call_still_in_progress(self):
|
||||
"""Test diff does not report tool call as resolved if still IN_PROGRESS."""
|
||||
msg1 = {"role": "user", "content": "Do something"}
|
||||
tool_in_progress = {"role": "tool", "content": "IN_PROGRESS", "tool_call_id": "call_123"}
|
||||
|
||||
context1 = LLMContext(messages=[msg1, tool_in_progress])
|
||||
context2 = LLMContext(messages=[msg1, tool_in_progress])
|
||||
|
||||
diff = context1.diff(context2)
|
||||
self.assertFalse(diff.has_changes())
|
||||
self.assertEqual(diff.tool_calls_resolved, [])
|
||||
|
||||
def test_diff_tool_choice_changed(self):
|
||||
"""Test diff detects tool_choice changes."""
|
||||
msg1 = {"role": "user", "content": "Hello"}
|
||||
|
||||
context1 = LLMContext(messages=[msg1], tool_choice="auto")
|
||||
context2 = LLMContext(messages=[msg1], tool_choice="none")
|
||||
|
||||
diff = context1.diff(context2)
|
||||
self.assertTrue(diff.has_changes())
|
||||
self.assertTrue(diff.tool_choice_changed)
|
||||
|
||||
def test_diff_tool_choice_unchanged(self):
|
||||
"""Test diff reports no change when tool_choice is the same."""
|
||||
msg1 = {"role": "user", "content": "Hello"}
|
||||
|
||||
context1 = LLMContext(messages=[msg1], tool_choice="auto")
|
||||
context2 = LLMContext(messages=[msg1], tool_choice="auto")
|
||||
|
||||
diff = context1.diff(context2)
|
||||
self.assertFalse(diff.has_changes())
|
||||
self.assertFalse(diff.tool_choice_changed)
|
||||
|
||||
def test_diff_empty_contexts(self):
|
||||
"""Test diff of two empty contexts returns no changes."""
|
||||
context1 = LLMContext()
|
||||
context2 = LLMContext()
|
||||
|
||||
diff = context1.diff(context2)
|
||||
self.assertFalse(diff.has_changes())
|
||||
|
||||
|
||||
class TestLLMContextDiffWithTools(unittest.TestCase):
|
||||
"""Tests for LLMContext.diff() with tools configuration changes."""
|
||||
|
||||
def _create_tools_schema(self, tool_names: list[str]) -> "ToolsSchema":
|
||||
"""Helper to create a ToolsSchema with named tools."""
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
|
||||
tools = [
|
||||
FunctionSchema(name=name, description=f"Test {name}", properties={}, required=[])
|
||||
for name in tool_names
|
||||
]
|
||||
return ToolsSchema(standard_tools=tools)
|
||||
|
||||
def test_diff_tools_added_from_not_given(self):
|
||||
"""Test diff detects tools being added when self has no tools."""
|
||||
from pipecat.processors.aggregators.llm_context import NOT_GIVEN
|
||||
|
||||
msg1 = {"role": "user", "content": "Hello"}
|
||||
tools = self._create_tools_schema(["get_weather", "get_time"])
|
||||
|
||||
context1 = LLMContext(messages=[msg1], tools=NOT_GIVEN)
|
||||
context2 = LLMContext(messages=[msg1], tools=tools)
|
||||
|
||||
diff = context1.diff(context2)
|
||||
self.assertTrue(diff.has_changes())
|
||||
self.assertEqual(sorted(diff.tools_diff.standard_tools_added), ["get_time", "get_weather"])
|
||||
self.assertEqual(diff.tools_diff.standard_tools_removed, [])
|
||||
|
||||
def test_diff_tools_removed_to_not_given(self):
|
||||
"""Test diff detects tools being removed when other has no tools."""
|
||||
from pipecat.processors.aggregators.llm_context import NOT_GIVEN
|
||||
|
||||
msg1 = {"role": "user", "content": "Hello"}
|
||||
tools = self._create_tools_schema(["get_weather", "get_time"])
|
||||
|
||||
context1 = LLMContext(messages=[msg1], tools=tools)
|
||||
context2 = LLMContext(messages=[msg1], tools=NOT_GIVEN)
|
||||
|
||||
diff = context1.diff(context2)
|
||||
self.assertTrue(diff.has_changes())
|
||||
self.assertEqual(diff.tools_diff.standard_tools_added, [])
|
||||
self.assertEqual(
|
||||
sorted(diff.tools_diff.standard_tools_removed), ["get_time", "get_weather"]
|
||||
)
|
||||
|
||||
def test_diff_both_not_given(self):
|
||||
"""Test diff returns None tools_diff when both have no tools."""
|
||||
from pipecat.processors.aggregators.llm_context import NOT_GIVEN
|
||||
|
||||
msg1 = {"role": "user", "content": "Hello"}
|
||||
|
||||
context1 = LLMContext(messages=[msg1], tools=NOT_GIVEN)
|
||||
context2 = LLMContext(messages=[msg1], tools=NOT_GIVEN)
|
||||
|
||||
diff = context1.diff(context2)
|
||||
self.assertFalse(diff.has_changes())
|
||||
self.assertFalse(diff.tools_diff.has_changes())
|
||||
|
||||
def test_diff_tools_modified(self):
|
||||
"""Test diff detects tool modification via ToolsSchema.diff()."""
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
|
||||
msg1 = {"role": "user", "content": "Hello"}
|
||||
|
||||
tool_v1 = FunctionSchema(
|
||||
name="get_weather",
|
||||
description="Get weather v1",
|
||||
properties={"location": {"type": "string"}},
|
||||
required=["location"],
|
||||
)
|
||||
tool_v2 = FunctionSchema(
|
||||
name="get_weather",
|
||||
description="Get weather v2",
|
||||
properties={"city": {"type": "string"}},
|
||||
required=["city"],
|
||||
)
|
||||
|
||||
context1 = LLMContext(messages=[msg1], tools=ToolsSchema(standard_tools=[tool_v1]))
|
||||
context2 = LLMContext(messages=[msg1], tools=ToolsSchema(standard_tools=[tool_v2]))
|
||||
|
||||
diff = context1.diff(context2)
|
||||
self.assertTrue(diff.has_changes())
|
||||
self.assertTrue(diff.tools_diff.standard_tools_modified)
|
||||
|
||||
def test_diff_tools_unchanged(self):
|
||||
"""Test diff returns None tools_diff when tools are identical."""
|
||||
msg1 = {"role": "user", "content": "Hello"}
|
||||
tools1 = self._create_tools_schema(["get_weather"])
|
||||
tools2 = self._create_tools_schema(["get_weather"])
|
||||
|
||||
context1 = LLMContext(messages=[msg1], tools=tools1)
|
||||
context2 = LLMContext(messages=[msg1], tools=tools2)
|
||||
|
||||
diff = context1.diff(context2)
|
||||
self.assertFalse(diff.has_changes())
|
||||
self.assertFalse(diff.tools_diff.has_changes())
|
||||
|
||||
184
tests/test_tools_schema.py
Normal file
184
tests/test_tools_schema.py
Normal file
@@ -0,0 +1,184 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import unittest
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import AdapterType, ToolsSchema, ToolsSchemaDiff
|
||||
|
||||
|
||||
class TestToolsSchemaDiff(unittest.TestCase):
|
||||
"""Tests for the ToolsSchemaDiff dataclass."""
|
||||
|
||||
def test_has_changes_empty(self):
|
||||
"""Test has_changes returns False for empty diff."""
|
||||
diff = ToolsSchemaDiff()
|
||||
self.assertFalse(diff.has_changes())
|
||||
|
||||
def test_has_changes_with_added(self):
|
||||
"""Test has_changes returns True when tools are added."""
|
||||
diff = ToolsSchemaDiff(standard_tools_added=["tool1"])
|
||||
self.assertTrue(diff.has_changes())
|
||||
|
||||
def test_has_changes_with_removed(self):
|
||||
"""Test has_changes returns True when tools are removed."""
|
||||
diff = ToolsSchemaDiff(standard_tools_removed=["tool1"])
|
||||
self.assertTrue(diff.has_changes())
|
||||
|
||||
def test_has_changes_with_modified(self):
|
||||
"""Test has_changes returns True when tools are modified."""
|
||||
diff = ToolsSchemaDiff(standard_tools_modified=True)
|
||||
self.assertTrue(diff.has_changes())
|
||||
|
||||
def test_has_changes_with_custom_changed(self):
|
||||
"""Test has_changes returns True when custom tools changed."""
|
||||
diff = ToolsSchemaDiff(custom_tools_changed=True)
|
||||
self.assertTrue(diff.has_changes())
|
||||
|
||||
|
||||
class TestToolsSchemaDiffMethod(unittest.TestCase):
|
||||
"""Tests for the ToolsSchema.diff() method."""
|
||||
|
||||
def _create_function_schema(
|
||||
self, name: str, description: str = "Test function", properties: dict = None
|
||||
) -> FunctionSchema:
|
||||
"""Helper to create a FunctionSchema."""
|
||||
return FunctionSchema(
|
||||
name=name,
|
||||
description=description,
|
||||
properties=properties or {},
|
||||
required=[],
|
||||
)
|
||||
|
||||
def test_diff_identical_schemas(self):
|
||||
"""Test diff of two identical schemas returns no changes."""
|
||||
tool1 = self._create_function_schema("get_weather")
|
||||
schema1 = ToolsSchema(standard_tools=[tool1])
|
||||
schema2 = ToolsSchema(standard_tools=[self._create_function_schema("get_weather")])
|
||||
|
||||
diff = schema1.diff(schema2)
|
||||
self.assertFalse(diff.has_changes())
|
||||
self.assertEqual(diff.standard_tools_added, [])
|
||||
self.assertEqual(diff.standard_tools_removed, [])
|
||||
self.assertFalse(diff.standard_tools_modified)
|
||||
self.assertFalse(diff.custom_tools_changed)
|
||||
|
||||
def test_diff_tool_added(self):
|
||||
"""Test diff detects added tools."""
|
||||
tool1 = self._create_function_schema("get_weather")
|
||||
tool2 = self._create_function_schema("get_time")
|
||||
|
||||
schema1 = ToolsSchema(standard_tools=[tool1])
|
||||
schema2 = ToolsSchema(standard_tools=[tool1, tool2])
|
||||
|
||||
diff = schema1.diff(schema2)
|
||||
self.assertTrue(diff.has_changes())
|
||||
self.assertEqual(diff.standard_tools_added, ["get_time"])
|
||||
self.assertEqual(diff.standard_tools_removed, [])
|
||||
self.assertFalse(diff.standard_tools_modified)
|
||||
|
||||
def test_diff_tool_removed(self):
|
||||
"""Test diff detects removed tools."""
|
||||
tool1 = self._create_function_schema("get_weather")
|
||||
tool2 = self._create_function_schema("get_time")
|
||||
|
||||
schema1 = ToolsSchema(standard_tools=[tool1, tool2])
|
||||
schema2 = ToolsSchema(standard_tools=[tool1])
|
||||
|
||||
diff = schema1.diff(schema2)
|
||||
self.assertTrue(diff.has_changes())
|
||||
self.assertEqual(diff.standard_tools_added, [])
|
||||
self.assertEqual(diff.standard_tools_removed, ["get_time"])
|
||||
self.assertFalse(diff.standard_tools_modified)
|
||||
|
||||
def test_diff_tool_modified(self):
|
||||
"""Test diff detects modified tools (same name, different definition)."""
|
||||
tool1_v1 = self._create_function_schema(
|
||||
"get_weather", description="Get weather v1", properties={"location": {"type": "string"}}
|
||||
)
|
||||
tool1_v2 = self._create_function_schema(
|
||||
"get_weather",
|
||||
description="Get weather v2",
|
||||
properties={"city": {"type": "string"}},
|
||||
)
|
||||
|
||||
schema1 = ToolsSchema(standard_tools=[tool1_v1])
|
||||
schema2 = ToolsSchema(standard_tools=[tool1_v2])
|
||||
|
||||
diff = schema1.diff(schema2)
|
||||
self.assertTrue(diff.has_changes())
|
||||
self.assertEqual(diff.standard_tools_added, [])
|
||||
self.assertEqual(diff.standard_tools_removed, [])
|
||||
self.assertTrue(diff.standard_tools_modified)
|
||||
|
||||
def test_diff_multiple_changes(self):
|
||||
"""Test diff with multiple types of changes."""
|
||||
tool_keep = self._create_function_schema("keep_tool")
|
||||
tool_remove = self._create_function_schema("remove_tool")
|
||||
tool_add = self._create_function_schema("add_tool")
|
||||
|
||||
schema1 = ToolsSchema(standard_tools=[tool_keep, tool_remove])
|
||||
schema2 = ToolsSchema(standard_tools=[tool_keep, tool_add])
|
||||
|
||||
diff = schema1.diff(schema2)
|
||||
self.assertTrue(diff.has_changes())
|
||||
self.assertEqual(diff.standard_tools_added, ["add_tool"])
|
||||
self.assertEqual(diff.standard_tools_removed, ["remove_tool"])
|
||||
|
||||
def test_diff_empty_schemas(self):
|
||||
"""Test diff of two empty schemas returns no changes."""
|
||||
schema1 = ToolsSchema(standard_tools=[])
|
||||
schema2 = ToolsSchema(standard_tools=[])
|
||||
|
||||
diff = schema1.diff(schema2)
|
||||
self.assertFalse(diff.has_changes())
|
||||
|
||||
def test_diff_custom_tools_changed(self):
|
||||
"""Test diff detects custom tools changes."""
|
||||
tool1 = self._create_function_schema("get_weather")
|
||||
custom1 = {AdapterType.GEMINI: [{"name": "search"}]}
|
||||
custom2 = {AdapterType.GEMINI: [{"name": "search_v2"}]}
|
||||
|
||||
schema1 = ToolsSchema(standard_tools=[tool1], custom_tools=custom1)
|
||||
schema2 = ToolsSchema(standard_tools=[tool1], custom_tools=custom2)
|
||||
|
||||
diff = schema1.diff(schema2)
|
||||
self.assertTrue(diff.has_changes())
|
||||
self.assertTrue(diff.custom_tools_changed)
|
||||
# Standard tools unchanged
|
||||
self.assertEqual(diff.standard_tools_added, [])
|
||||
self.assertEqual(diff.standard_tools_removed, [])
|
||||
self.assertFalse(diff.standard_tools_modified)
|
||||
|
||||
def test_diff_custom_tools_added(self):
|
||||
"""Test diff detects custom tools being added."""
|
||||
tool1 = self._create_function_schema("get_weather")
|
||||
|
||||
schema1 = ToolsSchema(standard_tools=[tool1])
|
||||
schema2 = ToolsSchema(
|
||||
standard_tools=[tool1], custom_tools={AdapterType.GEMINI: [{"name": "search"}]}
|
||||
)
|
||||
|
||||
diff = schema1.diff(schema2)
|
||||
self.assertTrue(diff.has_changes())
|
||||
self.assertTrue(diff.custom_tools_changed)
|
||||
|
||||
def test_diff_custom_tools_removed(self):
|
||||
"""Test diff detects custom tools being removed."""
|
||||
tool1 = self._create_function_schema("get_weather")
|
||||
|
||||
schema1 = ToolsSchema(
|
||||
standard_tools=[tool1], custom_tools={AdapterType.GEMINI: [{"name": "search"}]}
|
||||
)
|
||||
schema2 = ToolsSchema(standard_tools=[tool1])
|
||||
|
||||
diff = schema1.diff(schema2)
|
||||
self.assertTrue(diff.has_changes())
|
||||
self.assertTrue(diff.custom_tools_changed)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user