Mitigate tool-call-related hallucination

When tools change mid-conversation, LLMs can produce a few different
flavors of tool-call-related hallucination: calling tools that have
been removed, avoiding tools that have been re-added, or hallucinating
output (made-up answers or tool-call-shaped non-tool-calls) when tools
are unavailable.

This change introduces an opt-in ``add_tool_change_messages`` flag on
the LLM aggregators (preferred entry point: ``LLMContextAggregatorPair(
..., add_tool_change_messages=True)``) that appends a developer-role
message to the context whenever ``LLMSetToolsFrame`` changes the set
of advertised standard tools. This helps the LLM stay coherent across
tool changes by spelling out exactly what just became available or
unavailable. Both aggregators participate; whichever handles the
frame first adds the message, and the other (if any) sees an empty diff
against the shared context and stays silent. The result is the same
whether the frame flows downstream or upstream.
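
The diff-and-announce step can be sketched in isolation. This is a
simplified illustration rather than the code in this commit:
``tool_change_message`` and the plain ``set[str]`` inputs are hypothetical
stand-ins for the aggregator method and ``ToolsSchema``, and the message
wording is abbreviated.

```python
from __future__ import annotations


def tool_change_message(old_names: set[str], new_names: set[str]) -> str | None:
    """Describe the change from old_names to new_names, or return None if no change."""
    added = new_names - old_names
    removed = old_names - new_names
    if not added and not removed:
        return None

    def fmt(names: set[str]) -> str:
        # Sorted and backtick-wrapped so the wording is stable across runs.
        return ", ".join(f"`{n}`" for n in sorted(names))

    parts = []
    if added:
        parts.append(f"Added and may now be called: {fmt(added)}.")
    if removed:
        parts.append(f"Removed and should not be called: {fmt(removed)}.")
    return " ".join(parts)


# Shared state (hypothetical): the first handler diffs against the old tool
# set and then updates it; the second handler diffs against the updated set.
advertised = {"get_current_weather"}
first = tool_change_message(advertised, set())
advertised = set()
second = tool_change_message(advertised, set())
```

Because both handlers diff against the same shared state, the second call
yields no message, which is why enabling the flag on both aggregators does
not produce duplicates.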

Also tightens the existing missing-handler path (introduced in #4301):

- Reworded the terminal tool result to a neutral "The function
  ``X`` is not currently available." (overridable via
  ``LLMService.MISSING_FUNCTION_CALL_MESSAGE_TEMPLATE``). Previously
  read "Error: function 'X' is not registered."
- Logs at the call site now distinguish developer error (tool
  advertised but no handler registered → ``logger.error``) from
  hallucination (tool not advertised → ``logger.warning``).
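
The log-level distinction in the second bullet can be sketched as below.
This is an illustration under assumptions: the standard-library ``logging``
module stands in for the ``loguru`` logger used in the codebase, and
``classify_missing_call`` / ``log_missing_call`` are hypothetical names, not
methods added by this commit.

```python
import logging

logger = logging.getLogger("missing_tool_demo")


def classify_missing_call(function_name: str, advertised: set[str]) -> int:
    """Pick a log level for a call to a function that has no handler."""
    if function_name in advertised:
        # Advertised to the LLM but never registered: most likely a bug.
        return logging.ERROR
    # Not advertised at all: the model invented or misremembered the tool.
    return logging.WARNING


def log_missing_call(function_name: str, advertised: set[str]) -> int:
    """Log at the chosen level and return it."""
    level = classify_missing_call(function_name, advertised)
    logger.log(level, "no handler for '%s'", function_name)
    return level
```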

Includes a manual validation harness
(``examples/features/features-add-tool-change-messages.py``) that
exercises the new ``add_tool_change_messages`` mitigation by flipping
tool availability on a turn counter so its effect can be observed
end-to-end with the flag on vs. off.
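
The intended schedule of the harness can be summarized with a small helper.
``REMOVE_AT_TURN`` and ``READD_AT_TURN`` are the constants used in the
harness; ``tool_available`` is a hypothetical function written only for this
illustration.

```python
REMOVE_AT_TURN = 5  # tool gone for turn N onward
READD_AT_TURN = 9  # tool back for turn N onward


def tool_available(turn: int) -> bool:
    """Whether the weather tool is meant to be advertised on a 1-based user turn."""
    return turn < REMOVE_AT_TURN or turn >= READD_AT_TURN
```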
Paul Kompfner
2026-05-05 13:02:43 -04:00
parent a745e8d318
commit e06e0c0282
5 changed files with 745 additions and 15 deletions

@@ -0,0 +1,232 @@
#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

"""Manual validation harness for the ``add_tool_change_messages`` feature.

When tools change mid-conversation, LLMs can produce a few different
flavors of tool-call-related hallucination:

- **Forward hallucination** — calling a tool that has been removed.
- **Negative hallucination** — refusing to call a tool that has been
  re-added (because recent context is full of "I can't" responses).
- **Hallucinated output when tools are unavailable** — making up an
  answer rather than declining gracefully, or producing JSON that
  *looks* like a tool call but is actually just an assistant text
  response.

The ``add_tool_change_messages`` feature mitigates these by appending a
developer-role message to the conversation whenever ``LLMSetToolsFrame``
changes the set of advertised tools, so the LLM stays in sync with what's
actually available.

This harness exercises all of those flavors by flipping the advertised
tool set on a turn counter:

    Phase 0 (turns 1-4): weather tool ACTIVE — confirm baseline.
    Phase 1 (turns 5-8): tool REMOVED — keep asking for weather.
    Phase 2 (turn 9+): tool RE-ADDED — does the LLM call it again?

Set ``ADD_TOOL_CHANGE_MESSAGES=0`` to disable the mitigation and see the
unmitigated behavior. The default is ON so a fresh run shows the feature
working.

Defaults to Llama 3.1 8B Instruct via a locally-running Ollama —
anecdotally one of the more hallucination-prone of the easily accessible
models. Pull the model once with ``ollama pull llama3.1:8b`` and make
sure ``ollama serve`` is running. Swap the LLM service to validate other
providers.

Run with::

    uv run examples/features/features-add-tool-change-messages.py
    ADD_TOOL_CHANGE_MESSAGES=0 uv run examples/features/features-add-tool-change-messages.py
"""

import os

from dotenv import load_dotenv
from loguru import logger

from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import LLMRunFrame, LLMSetToolsFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_context import NOT_GIVEN, LLMContext
from pipecat.processors.aggregators.llm_response_universal import (
    LLMContextAggregatorPair,
    LLMUserAggregatorParams,
)
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.llm_service import FunctionCallParams
from pipecat.services.ollama.llm import OLLamaLLMService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams

load_dotenv(override=True)

# Default ON so a fresh run shows the feature working. Set to "0" to A/B
# against the unmitigated behavior.
ADD_TOOL_CHANGE_MESSAGES = os.environ.get("ADD_TOOL_CHANGE_MESSAGES", "1") == "1"


async def fetch_weather_from_api(params: FunctionCallParams):
    await params.result_callback({"conditions": "nice", "temperature": "75"})


weather_function = FunctionSchema(
    name="get_current_weather",
    description="Get the current weather",
    properties={
        "location": {
            "type": "string",
            "description": "The city and state, e.g. San Francisco, CA",
        },
        "format": {
            "type": "string",
            "enum": ["celsius", "fahrenheit"],
            "description": "The temperature unit to use. Infer this from the user's location.",
        },
    },
    required=["location", "format"],
)

weather_tools = ToolsSchema(standard_tools=[weather_function])

transport_params = {
    "daily": lambda: DailyParams(audio_in_enabled=True, audio_out_enabled=True),
    "twilio": lambda: FastAPIWebsocketParams(audio_in_enabled=True, audio_out_enabled=True),
    "webrtc": lambda: TransportParams(audio_in_enabled=True, audio_out_enabled=True),
}


async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
    logger.info(
        f"Starting add_tool_change_messages demo bot "
        f"(ADD_TOOL_CHANGE_MESSAGES={ADD_TOOL_CHANGE_MESSAGES})"
    )

    stt = DeepgramSTTService(api_key=os.environ["DEEPGRAM_API_KEY"])

    tts = CartesiaTTSService(
        api_key=os.environ["CARTESIA_API_KEY"],
        settings=CartesiaTTSService.Settings(
            voice="71a7ad14-091c-4e8e-a314-022ece01c121",  # British Reading Lady
        ),
    )

    llm = OLLamaLLMService(
        settings=OLLamaLLMService.Settings(
            # Llama 3.1 8B Instruct is anecdotally one of the more
            # hallucination-prone of the easily accessible models — exactly
            # what we want for this validation harness. Pull it with
            # ``ollama pull llama3.1:8b`` and make sure ``ollama serve``
            # is running.
            model="llama3.1:8b",
            system_instruction=(
                "You are a helpful assistant in a voice conversation. Your responses "
                "will be spoken aloud, so avoid emojis, bullet points, or other "
                "formatting that can't be spoken. Respond briefly and naturally. "
                "If the user asks for the current weather, use the `get_current_weather` "
                "function if it's available. IMPORTANT: if you do not have access to the function, "
                "say something along the lines of 'Sorry, I can't check the weather right now.'."
            ),
        ),
    )

    llm.register_function("get_current_weather", fetch_weather_from_api)

    context = LLMContext(tools=weather_tools)
    user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
        context,
        user_params=LLMUserAggregatorParams(vad_analyzer=SileroVADAnalyzer()),
        add_tool_change_messages=ADD_TOOL_CHANGE_MESSAGES,
    )

    pipeline = Pipeline(
        [
            transport.input(),
            stt,
            user_aggregator,
            llm,
            tts,
            transport.output(),
            assistant_aggregator,
        ]
    )

    task = PipelineTask(
        pipeline,
        params=PipelineParams(enable_metrics=True, enable_usage_metrics=True),
        idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
    )

    # Phase controller: roughly 4 turns per phase.
    user_turn_count = 0
    REMOVE_AT_TURN = 5  # tool gone for turn N onward
    READD_AT_TURN = 9  # tool back for turn N onward

    @user_aggregator.event_handler("on_user_turn_stopped")
    async def on_user_turn_stopped(aggregator, strategy, message):
        nonlocal user_turn_count
        user_turn_count += 1
        logger.info(f"=== User turn {user_turn_count} complete ===")
        if user_turn_count == REMOVE_AT_TURN - 1:
            logger.info(
                "=== Phase 1: weather tool REMOVED. Keep asking about the weather "
                "to exercise hallucination scenarios. ==="
            )
            await task.queue_frame(LLMSetToolsFrame(tools=NOT_GIVEN))
        elif user_turn_count == READD_AT_TURN - 1:
            logger.info(
                "=== Phase 2: weather tool RE-ADDED. Ask for the weather again — "
                "does the LLM call it, or keep refusing? (THIS IS THE TEST.) ==="
            )
            await task.queue_frame(LLMSetToolsFrame(tools=weather_tools))

    @transport.event_handler("on_client_connected")
    async def on_client_connected(transport, client):
        logger.info("Client connected")
        logger.info(
            "=== Phase 0: weather tool ACTIVE. Ask for the weather a few times "
            "to confirm it's working. ==="
        )
        context.add_message(
            {
                "role": "developer",
                "content": (
                    "Please introduce yourself briefly to the user, then invite them "
                    "to ask about the weather."
                ),
            }
        )
        await task.queue_frames([LLMRunFrame()])

    @transport.event_handler("on_client_disconnected")
    async def on_client_disconnected(transport, client):
        logger.info("Client disconnected")
        await task.cancel()

    runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)

    await runner.run(task)


async def bot(runner_args: RunnerArguments):
    """Main bot entry point compatible with Pipecat Cloud."""
    transport = await create_transport(runner_args, transport_params)
    await run_bot(transport, runner_args)


if __name__ == "__main__":
    from pipecat.runner.run import main

    main()

@@ -72,6 +72,7 @@ from pipecat.processors.aggregators.llm_context import (
LLMContextMessage,
LLMSpecificMessage,
NotGiven,
is_given,
)
from pipecat.processors.aggregators.llm_context_summarizer import (
LLMContextSummarizer,
@@ -118,6 +119,18 @@ class LLMUserAggregatorParams:
user_turn_completion_config: Configuration for turn completion behavior including
custom instructions, timeouts, and prompts. Only used when
filter_incomplete_user_turns is True.
add_tool_change_messages: When True, on each ``LLMSetToolsFrame`` the
aggregator computes the diff against the currently advertised tools
and appends a developer-role message to the context describing
additions/removals. Helps the LLM stay coherent across
mid-conversation tool changes, mitigating several flavors of
tool-call-related hallucination: calling tools that have been
removed, avoiding tools that have been re-added, and hallucinating
output (made-up answers or tool-call-shaped non-tool-calls) when
tools are unavailable. Only standard tools are diffed; custom
(LLM-specific) tools are ignored. When using
``LLMContextAggregatorPair``, prefer setting this via its
``add_tool_change_messages`` argument instead. Defaults to False.
"""
user_turn_strategies: UserTurnStrategies | None = None
@@ -128,6 +141,7 @@ class LLMUserAggregatorParams:
audio_idle_timeout: float = 1.0
filter_incomplete_user_turns: bool = False
user_turn_completion_config: UserTurnCompletionConfig | None = None
add_tool_change_messages: bool = False
@dataclass
@@ -143,10 +157,23 @@ class LLMAssistantAggregatorParams:
summarization. Controls trigger thresholds, message preservation, and
summarization prompts. If None, uses default
``LLMAutoContextSummarizationConfig`` values.
add_tool_change_messages: When True, on each ``LLMSetToolsFrame`` the
aggregator computes the diff against the currently advertised tools
and appends a developer-role message to the context describing
additions/removals. Helps the LLM stay coherent across
mid-conversation tool changes, mitigating several flavors of
tool-call-related hallucination: calling tools that have been
removed, avoiding tools that have been re-added, and hallucinating
output (made-up answers or tool-call-shaped non-tool-calls) when
tools are unavailable. Only standard tools are diffed; custom
(LLM-specific) tools are ignored. When using
``LLMContextAggregatorPair``, prefer setting this via its
``add_tool_change_messages`` argument instead. Defaults to False.
"""
enable_auto_context_summarization: bool = False
auto_context_summarization_config: LLMAutoContextSummarizationConfig | None = None
add_tool_change_messages: bool = False
# ---------------------------------------------------------------------------
# Deprecated field names — kept for backward compatibility.
@@ -248,20 +275,87 @@ class LLMContextAggregator(FrameProcessor):
common functionality for context-based conversation management.
"""
def __init__(self, *, context: LLMContext, role: str, **kwargs):
# Developer-role messages appended to the context when tools are added/
# removed via ``LLMSetToolsFrame`` (only when ``add_tool_change_messages``
# is enabled on the aggregator's params). ``{function_names}`` is
# substituted with a sorted, comma-separated, backtick-wrapped list.
TOOL_ACTIVATION_MESSAGE_TEMPLATE = (
"The following function(s) have just been added and may now be called: "
"{function_names}. Any previously available functions remain available."
)
TOOL_DEACTIVATION_MESSAGE_TEMPLATE = (
"The following function(s) have just been removed and should not be called: "
"{function_names}. Any previously available functions remain available. "
"The removed function(s) may become available again later, in which case "
"you will be informed."
)
def __init__(
self,
*,
context: LLMContext,
role: str,
add_tool_change_messages: bool = False,
**kwargs,
):
"""Initialize the context response aggregator.
Args:
context: The LLM context to use for conversation storage.
role: The role this aggregator represents (e.g. "user", "assistant").
add_tool_change_messages: See the field of the same name on the
aggregator-specific params dataclasses. Subclasses propagate
this from their ``params``.
**kwargs: Additional arguments passed to parent class.
"""
super().__init__(**kwargs)
self._context = context
self._role = role
self._add_tool_change_messages = add_tool_change_messages
self._aggregation: list[TextPartForConcatenation] = []
    def _maybe_add_tool_change_messages(self, new_tools: ToolsSchema | NotGiven) -> None:
        """Append a developer message describing tool add/remove deltas.

        No-op unless ``add_tool_change_messages`` was enabled on the aggregator,
        and no-op when the diff against the currently advertised tools is empty.
        Custom (LLM-specific) tools are ignored — only standard tools are diffed.

        Both aggregators call this on every ``LLMSetToolsFrame`` they handle.
        Whichever aggregator handles the frame first computes a real diff
        against the shared context and adds the announcement; by the time
        the other aggregator sees it (if at all), the context already
        reflects the new tools, so its diff is empty and no duplicate
        message is added. This is order-independent: it works whether the
        frame flows downstream (user aggregator first) or upstream
        (assistant aggregator first, and consumed without being forwarded).
        """
        if not self._add_tool_change_messages:
            return

        def _names(tools: ToolsSchema | NotGiven) -> set[str]:
            if not is_given(tools):
                return set()
            return {s.name for s in tools.standard_tools}

        old_names = _names(self._context.tools)
        new_names = _names(new_tools)
        added = new_names - old_names
        removed = old_names - new_names
        if not added and not removed:
            return

        parts: list[str] = []
        if added:
            names = ", ".join(f"`{n}`" for n in sorted(added))
            parts.append(self.TOOL_ACTIVATION_MESSAGE_TEMPLATE.format(function_names=names))
        if removed:
            names = ", ".join(f"`{n}`" for n in sorted(removed))
            parts.append(self.TOOL_DEACTIVATION_MESSAGE_TEMPLATE.format(function_names=names))
        self._context.add_message({"role": "developer", "content": " ".join(parts)})
@property
def messages(self) -> list[LLMContextMessage]:
"""Get messages from the LLM context.
@@ -434,8 +528,14 @@ class LLMUserAggregator(LLMContextAggregator):
params: Configuration parameters for aggregation behavior.
**kwargs: Additional arguments.
"""
super().__init__(context=context, role="user", **kwargs)
self._params = params or LLMUserAggregatorParams()
params = params or LLMUserAggregatorParams()
super().__init__(
context=context,
role="user",
add_tool_change_messages=params.add_tool_change_messages,
**kwargs,
)
self._params = params
self._register_event_handler("on_user_turn_started")
self._register_event_handler("on_user_turn_stopped")
@@ -536,6 +636,7 @@ class LLMUserAggregator(LLMContextAggregator):
elif isinstance(frame, LLMMessagesTransformFrame):
await self._handle_llm_messages_transform(frame)
elif isinstance(frame, LLMSetToolsFrame):
self._maybe_add_tool_change_messages(frame.tools)
self.set_tools(frame.tools)
# Push the LLMSetToolsFrame as well, since speech-to-speech LLM
# services (like OpenAI Realtime) may need to know about tool
@@ -843,8 +944,14 @@ class LLMAssistantAggregator(LLMContextAggregator):
params: Configuration parameters for aggregation behavior.
**kwargs: Additional arguments.
"""
super().__init__(context=context, role="assistant", **kwargs)
self._params = params or LLMAssistantAggregatorParams()
params = params or LLMAssistantAggregatorParams()
super().__init__(
context=context,
role="assistant",
add_tool_change_messages=params.add_tool_change_messages,
**kwargs,
)
self._params = params
self._function_calls_in_progress: dict[str, FunctionCallInProgressFrame | None] = {}
self._function_calls_image_results: dict[str, UserImageRawFrame] = {}
@@ -949,6 +1056,7 @@ class LLMAssistantAggregator(LLMContextAggregator):
elif isinstance(frame, LLMMessagesTransformFrame):
await self._handle_llm_messages_transform(frame)
elif isinstance(frame, LLMSetToolsFrame):
self._maybe_add_tool_change_messages(frame.tools)
self.set_tools(frame.tools)
elif isinstance(frame, LLMSetToolChoiceFrame):
self.set_tool_choice(frame.tool_choice)
@@ -1478,6 +1586,7 @@ class LLMContextAggregatorPair:
*,
user_params: LLMUserAggregatorParams | None = None,
assistant_params: LLMAssistantAggregatorParams | None = None,
add_tool_change_messages: bool | None = None,
):
"""Initialize the LLM context aggregator pair.
@@ -1485,9 +1594,22 @@ class LLMContextAggregatorPair:
context: The context to be managed by the aggregators.
user_params: Parameters for the user context aggregator.
assistant_params: Parameters for the assistant context aggregator.
add_tool_change_messages: When provided, sets the field of the
same name on both ``user_params`` and ``assistant_params``,
overriding any value already set on either. This is the
preferred way to enable tool-change announcements: it ensures
both aggregators participate, which makes the feature robust
regardless of which aggregator handles a given
``LLMSetToolsFrame``. The shared context guarantees the
announcement is added exactly once (the second aggregator's
diff is empty by the time it sees the frame). Leave as
``None`` to respect per-params settings.
"""
user_params = user_params or LLMUserAggregatorParams()
assistant_params = assistant_params or LLMAssistantAggregatorParams()
if add_tool_change_messages is not None:
user_params.add_tool_change_messages = add_tool_change_messages
assistant_params.add_tool_change_messages = add_tool_change_messages
self._user = LLMUserAggregator(context, params=user_params)
self._assistant = LLMAssistantAggregator(context, params=assistant_params)

@@ -53,6 +53,7 @@ from pipecat.frames.frames import (
from pipecat.processors.aggregators.llm_context import (
LLMContext,
LLMSpecificMessage,
is_given,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_service import AIService
@@ -243,6 +244,15 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService, Generic[TAdapter]
# However, subclasses should override this with a more specific adapter when necessary.
adapter_class: type[BaseLLMAdapter] = OpenAILLMAdapter
# Returned to the LLM as the tool result when an unavailable function is
# called. Deliberately neutral about future availability so the LLM can
# pick the function up again if it returns (e.g. via the
# ``add_tool_change_messages`` activation message, or silently on a
# later inference). ``{function_name}`` is substituted at runtime.
MISSING_FUNCTION_CALL_MESSAGE_TEMPLATE = (
"The function `{function_name}` is not currently available."
)
def __init__(
self,
run_in_parallel: bool = True,
@@ -764,9 +774,7 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService, Generic[TAdapter]
elif None in self._functions.keys():
item = self._functions[None]
else:
logger.warning(
f"{self} is calling '{function_call.function_name}', but it's not registered."
)
self._log_missing_function_call(function_call.function_name, function_call.context)
item = self._build_missing_function_call_registry_item(function_call.function_name)
runner_items.append(
@@ -835,8 +843,12 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService, Generic[TAdapter]
elif runner_item.registry_item.handler == self._missing_function_call_handler:
item = runner_item.registry_item
else:
# Function was unregistered between queue and execution; the
# registry-item-handler check above already covered the
# missing-from-the-start case.
logger.warning(
f"{self} is calling '{runner_item.function_name}', but it was just unregistered."
f"{self}: '{runner_item.function_name}' was just unregistered "
f"between queueing and execution."
)
item = self._build_missing_function_call_registry_item(runner_item.function_name)
@@ -962,7 +974,45 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService, Generic[TAdapter]
async def _missing_function_call_handler(self, params: FunctionCallParams):
"""Return a terminal tool result when the LLM calls an unknown function."""
await params.result_callback(f"Error: function '{params.function_name}' is not registered.")
await params.result_callback(
self.MISSING_FUNCTION_CALL_MESSAGE_TEMPLATE.format(function_name=params.function_name)
)
    @staticmethod
    def _advertised_tool_names(context) -> set[str]:
        """Return the set of standard-tool names currently advertised to the LLM.

        Custom (LLM-specific) tools are not included, since they have no
        consistent name field across adapters.
        """
        tools = context.tools if context is not None else None
        if tools is None or not is_given(tools):
            return set()
        return {t.name for t in tools.standard_tools}

    def _log_missing_function_call(self, function_name: str, context) -> None:
        """Log an appropriate message when a tool is called with no handler.

        Distinguishes two cases:

        - **Developer error:** the tool is advertised to the LLM but no handler
          was registered (likely a missed ``register_function`` call). Logged
          at error level since this almost always indicates a bug.
        - **Hallucination:** the tool is not in the currently advertised tool
          set. Logged at warning level since this is model behavior the
          application can do little about beyond returning a terminal result.
        """
        if function_name in self._advertised_tool_names(context):
            logger.error(
                f"{self}: tool '{function_name}' is advertised to the LLM "
                f"but has no registered handler — did you forget to call "
                f"register_function()?"
            )
        else:
            logger.warning(
                f"{self}: LLM called '{function_name}', which is not in the "
                f"currently advertised tool set."
            )
def _has_async_tools(self) -> bool:
"""Return True if at least one non-builtin async tool is registered."""

@@ -7,6 +7,8 @@
import json
import unittest
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import AdapterType, ToolsSchema
from pipecat.frames.frames import (
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
@@ -25,6 +27,7 @@ from pipecat.frames.frames import (
LLMMessagesTransformFrame,
LLMMessagesUpdateFrame,
LLMRunFrame,
LLMSetToolsFrame,
LLMTextFrame,
LLMThoughtEndFrame,
LLMThoughtStartFrame,
@@ -46,6 +49,8 @@ from pipecat.processors.aggregators.llm_response_universal import (
AssistantThoughtMessage,
AssistantTurnStoppedMessage,
LLMAssistantAggregator,
LLMAssistantAggregatorParams,
LLMContextAggregatorPair,
LLMUserAggregator,
LLMUserAggregatorParams,
)
@@ -1167,5 +1172,204 @@ class TestLLMAssistantAggregator(unittest.IsolatedAsyncioTestCase):
assert context.messages[0]["content"] == "HELLO"
def _function_schema(name: str) -> FunctionSchema:
    return FunctionSchema(name=name, description="", properties={}, required=[])


def _tools(*names: str) -> ToolsSchema:
    return ToolsSchema(standard_tools=[_function_schema(n) for n in names])


def _developer_messages(context: LLMContext) -> list[str]:
    return [
        m["content"]
        for m in context.messages
        if isinstance(m, dict) and m.get("role") == "developer"
    ]


class TestToolChangeMessages(unittest.IsolatedAsyncioTestCase):
    """Coverage for the opt-in ``add_tool_change_messages`` feature.

    The feature appends a developer-role message to the context whenever
    ``LLMSetToolsFrame`` changes the set of advertised standard tools.
    """

    async def _send_set_tools_to_user_aggregator(self, aggregator, tools):
        # User aggregator forwards LLMSetToolsFrame downstream, so we expect
        # the SpeechControlParamsFrame (emitted on StartFrame) and the
        # forwarded LLMSetToolsFrame.
        await run_test(
            aggregator,
            frames_to_send=[LLMSetToolsFrame(tools=tools)],
            expected_down_frames=[SpeechControlParamsFrame, LLMSetToolsFrame],
        )

    async def test_default_off_adds_no_message(self):
        context = LLMContext(tools=_tools("a"))
        aggregator = LLMUserAggregator(context)
        await self._send_set_tools_to_user_aggregator(aggregator, _tools("a", "b"))
        self.assertEqual(_developer_messages(context), [])

    async def test_user_aggregator_announces_additions(self):
        context = LLMContext(tools=_tools("a"))
        aggregator = LLMUserAggregator(
            context, params=LLMUserAggregatorParams(add_tool_change_messages=True)
        )
        await self._send_set_tools_to_user_aggregator(aggregator, _tools("a", "b", "c"))
        msgs = _developer_messages(context)
        self.assertEqual(len(msgs), 1)
        self.assertIn("just been added", msgs[0])
        self.assertIn("`b`", msgs[0])
        self.assertIn("`c`", msgs[0])
        self.assertNotIn("removed", msgs[0])
        # Sorted, stable order
        self.assertLess(msgs[0].index("`b`"), msgs[0].index("`c`"))

    async def test_user_aggregator_announces_removals(self):
        context = LLMContext(tools=_tools("a", "b", "c"))
        aggregator = LLMUserAggregator(
            context, params=LLMUserAggregatorParams(add_tool_change_messages=True)
        )
        await self._send_set_tools_to_user_aggregator(aggregator, _tools("a"))
        msgs = _developer_messages(context)
        self.assertEqual(len(msgs), 1)
        self.assertIn("just been removed", msgs[0])
        self.assertIn("`b`", msgs[0])
        self.assertIn("`c`", msgs[0])
        self.assertNotIn("just been added", msgs[0])

    async def test_user_aggregator_combined_add_and_remove(self):
        context = LLMContext(tools=_tools("a", "b"))
        aggregator = LLMUserAggregator(
            context, params=LLMUserAggregatorParams(add_tool_change_messages=True)
        )
        await self._send_set_tools_to_user_aggregator(aggregator, _tools("b", "c"))
        msgs = _developer_messages(context)
        self.assertEqual(len(msgs), 1)
        self.assertIn("just been added", msgs[0])
        self.assertIn("`c`", msgs[0])
        self.assertIn("just been removed", msgs[0])
        self.assertIn("`a`", msgs[0])
        # Activation phrase appears before deactivation phrase.
        self.assertLess(msgs[0].index("just been added"), msgs[0].index("just been removed"))

    async def test_no_message_when_diff_is_empty(self):
        context = LLMContext(tools=_tools("a", "b"))
        aggregator = LLMUserAggregator(
            context, params=LLMUserAggregatorParams(add_tool_change_messages=True)
        )
        await self._send_set_tools_to_user_aggregator(aggregator, _tools("a", "b"))
        self.assertEqual(_developer_messages(context), [])

    async def test_set_tools_to_not_given_lists_all_as_removed(self):
        from pipecat.processors.aggregators.llm_context import NOT_GIVEN

        context = LLMContext(tools=_tools("a", "b"))
        aggregator = LLMUserAggregator(
            context, params=LLMUserAggregatorParams(add_tool_change_messages=True)
        )
        await self._send_set_tools_to_user_aggregator(aggregator, NOT_GIVEN)
        msgs = _developer_messages(context)
        self.assertEqual(len(msgs), 1)
        self.assertIn("just been removed", msgs[0])
        self.assertIn("`a`", msgs[0])
        self.assertIn("`b`", msgs[0])

    async def test_set_tools_from_not_given_lists_all_as_added(self):
        context = LLMContext()  # tools default to NOT_GIVEN
        aggregator = LLMUserAggregator(
            context, params=LLMUserAggregatorParams(add_tool_change_messages=True)
        )
        await self._send_set_tools_to_user_aggregator(aggregator, _tools("x", "y"))
        msgs = _developer_messages(context)
        self.assertEqual(len(msgs), 1)
        self.assertIn("just been added", msgs[0])
        self.assertIn("`x`", msgs[0])
        self.assertIn("`y`", msgs[0])

    async def test_custom_tools_only_change_no_message(self):
        # Standard tools identical; only custom tools differ → no announcement.
        context = LLMContext(
            tools=ToolsSchema(
                standard_tools=[_function_schema("a")],
                custom_tools={AdapterType.OPENAI: [{"type": "web_search"}]},
            )
        )
        aggregator = LLMUserAggregator(
            context, params=LLMUserAggregatorParams(add_tool_change_messages=True)
        )
        new_tools = ToolsSchema(
            standard_tools=[_function_schema("a")],
            custom_tools={AdapterType.OPENAI: [{"type": "file_search"}]},
        )
        await self._send_set_tools_to_user_aggregator(aggregator, new_tools)
        self.assertEqual(_developer_messages(context), [])

    async def test_pipeline_with_both_aggregators_announces_once(self):
        """User agg runs first; assistant agg sees no diff and stays silent."""
        context = LLMContext(tools=_tools("a"))
        user, assistant = LLMContextAggregatorPair(context, add_tool_change_messages=True)
        pipeline = Pipeline([user, assistant])
        # The user aggregator forwards LLMSetToolsFrame downstream; the
        # assistant aggregator consumes it (does not forward).
        await run_test(
            pipeline,
            frames_to_send=[LLMSetToolsFrame(tools=_tools("a", "b"))],
            expected_down_frames=[SpeechControlParamsFrame],
        )
        msgs = _developer_messages(context)
        self.assertEqual(len(msgs), 1, f"expected exactly one announcement, got {msgs}")
        self.assertIn("`b`", msgs[0])

    async def test_assistant_aggregator_announces_when_handled_first(self):
        """Order-independence: an upstream LLMSetToolsFrame hits the assistant
        aggregator first (before being consumed). It should announce, and the
        user aggregator (which never sees it) shouldn't matter for correctness.
        """
        context = LLMContext(tools=_tools("a"))
        assistant = LLMAssistantAggregator(
            context,
            params=LLMAssistantAggregatorParams(add_tool_change_messages=True),
        )
        # Send the frame upstream so the assistant aggregator processes it.
        await run_test(
            assistant,
            frames_to_send=[LLMSetToolsFrame(tools=_tools("a", "b"))],
            frames_to_send_direction=FrameDirection.UPSTREAM,
            expected_up_frames=[],
        )
        msgs = _developer_messages(context)
        self.assertEqual(len(msgs), 1)
        self.assertIn("`b`", msgs[0])

    async def test_pair_propagates_flag_to_both(self):
        context = LLMContext()
        pair = LLMContextAggregatorPair(context, add_tool_change_messages=True)
        self.assertTrue(pair.user()._add_tool_change_messages)
        self.assertTrue(pair.assistant()._add_tool_change_messages)

    async def test_pair_arg_overrides_per_params_settings(self):
        context = LLMContext()
        pair = LLMContextAggregatorPair(
            context,
            user_params=LLMUserAggregatorParams(add_tool_change_messages=False),
            assistant_params=LLMAssistantAggregatorParams(add_tool_change_messages=False),
            add_tool_change_messages=True,
        )
        self.assertTrue(pair.user()._add_tool_change_messages)
        self.assertTrue(pair.assistant()._add_tool_change_messages)

    async def test_pair_default_respects_per_params(self):
        context = LLMContext()
        pair = LLMContextAggregatorPair(
            context,
            user_params=LLMUserAggregatorParams(add_tool_change_messages=True),
            assistant_params=LLMAssistantAggregatorParams(add_tool_change_messages=False),
        )
        self.assertTrue(pair.user()._add_tool_change_messages)
        self.assertFalse(pair.assistant()._add_tool_change_messages)
if __name__ == "__main__":
unittest.main()

@@ -8,6 +8,8 @@ import unittest
from unittest.mock import AsyncMock, patch
from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.adapters.services.open_ai_adapter import OpenAILLMAdapter
from pipecat.frames.frames import (
FunctionCallFromLLM,
@@ -21,6 +23,10 @@ from pipecat.services.settings import LLMSettings
from pipecat.turns.user_mute.function_call_user_mute_strategy import FunctionCallUserMuteStrategy
def _expected_missing_tool_message(name: str) -> str:
return LLMService.MISSING_FUNCTION_CALL_MESSAGE_TEMPLATE.format(function_name=name)
class MockLLMService(LLMService):
"""Minimal LLM service for testing function call execution."""
@@ -104,13 +110,14 @@ class TestLLMService(unittest.IsolatedAsyncioTestCase):
self.assertEqual(recorded_frames[1].function_name, "missing_tool")
self.assertEqual(
recorded_frames[2].result,
-"Error: function 'missing_tool' is not registered.",
+_expected_missing_tool_message("missing_tool"),
)
-# Only the queue-time warning should fire; the execution-time
-# "just unregistered" warning must not double-log.
+# The tool was not advertised, so this is treated as a hallucination
+# (warning at queue time). The execution-time "just unregistered"
+# warning must not double-log.
warnings = [c.args[0] for c in mock_logger.warning.call_args_list]
-self.assertTrue(any("not registered" in w for w in warnings))
+self.assertTrue(any("not in the currently advertised tool set" in w for w in warnings))
self.assertFalse(any("just unregistered" in w for w in warnings))
async def test_function_unregistered_between_queue_and_execute(self):
@@ -160,9 +167,124 @@ class TestLLMService(unittest.IsolatedAsyncioTestCase):
)
self.assertEqual(
recorded_frames[2].result,
-"Error: function 'doomed_tool' is not registered.",
+_expected_missing_tool_message("doomed_tool"),
)
async def test_missing_function_call_dev_error_logged_as_error(self):
"""Tool advertised to the LLM but missing a handler → logger.error."""
service = MockLLMService()
service._call_event_handler = AsyncMock()
await self._run_function_calls_inline(service)
service.broadcast_frame = AsyncMock()
context = LLMContext(
tools=ToolsSchema(
standard_tools=[
FunctionSchema(
name="advertised_but_unhandled",
description="",
properties={},
required=[],
)
]
)
)
with patch("pipecat.services.llm_service.logger") as mock_logger:
await service.run_function_calls(
[
FunctionCallFromLLM(
function_name="advertised_but_unhandled",
tool_call_id="call_1",
arguments={},
context=context,
)
]
)
errors = [c.args[0] for c in mock_logger.error.call_args_list]
warnings = [c.args[0] for c in mock_logger.warning.call_args_list]
self.assertTrue(
any(
"advertised" in e and "register_function" in e and "advertised_but_unhandled" in e
for e in errors
),
f"expected dev-error log; got errors={errors}, warnings={warnings}",
)
self.assertFalse(any("not in the currently advertised tool set" in w for w in warnings))
async def test_missing_function_call_hallucination_logged_as_warning(self):
"""Tool not advertised to the LLM → logger.warning (hallucination)."""
service = MockLLMService()
service._call_event_handler = AsyncMock()
await self._run_function_calls_inline(service)
service.broadcast_frame = AsyncMock()
context = LLMContext(
tools=ToolsSchema(
standard_tools=[
FunctionSchema(
name="something_else",
description="",
properties={},
required=[],
)
]
)
)
with patch("pipecat.services.llm_service.logger") as mock_logger:
await service.run_function_calls(
[
FunctionCallFromLLM(
function_name="never_advertised",
tool_call_id="call_1",
arguments={},
context=context,
)
]
)
warnings = [c.args[0] for c in mock_logger.warning.call_args_list]
errors = [c.args[0] for c in mock_logger.error.call_args_list]
self.assertTrue(
any(
"not in the currently advertised tool set" in w and "never_advertised" in w
for w in warnings
),
f"expected hallucination warning; got warnings={warnings}, errors={errors}",
)
self.assertFalse(any("advertised" in e and "register_function" in e for e in errors))
async def test_catch_all_handler_suppresses_missing_warnings(self):
"""register_function(None, ...) suppresses both dev-error and hallucination logs."""
service = MockLLMService()
service._call_event_handler = AsyncMock()
await self._run_function_calls_inline(service)
service.broadcast_frame = AsyncMock()
async def catch_all(params):
await params.result_callback("handled")
service.register_function(None, catch_all)
with patch("pipecat.services.llm_service.logger") as mock_logger:
await service.run_function_calls(
[
FunctionCallFromLLM(
function_name="anything",
tool_call_id="call_1",
arguments={},
context=LLMContext(),
)
]
)
errors = [c.args[0] for c in mock_logger.error.call_args_list]
warnings = [c.args[0] for c in mock_logger.warning.call_args_list]
self.assertFalse(any("register_function" in e for e in errors))
self.assertFalse(any("not in the currently advertised tool set" in w for w in warnings))
async def test_missing_function_call_allows_user_mute_cleanup(self):
service = MockLLMService()
service._call_event_handler = AsyncMock()