Mitigate tool-call-related hallucination

When tools change mid-conversation, LLMs can produce a few different
flavors of tool-call-related hallucination: calling tools that have
been removed, avoiding tools that have been re-added, or hallucinating
output (made-up answers or tool-call-shaped non-tool-calls) when tools
are unavailable.

This change introduces an opt-in ``add_tool_change_messages`` flag on
the LLM aggregators (preferred entry point: ``LLMContextAggregatorPair(
..., add_tool_change_messages=True)``) that appends a developer-role
message to the context whenever ``LLMSetToolsFrame`` changes the set
of advertised standard tools. This helps the LLM stay coherent across tool
changes by spelling out exactly what just became available or
unavailable. Both aggregators participate; whichever handles the
frame first wins, and the other (if any) sees an empty diff against
the shared context and stays silent, so the outcome does not depend on
whether the frame flows downstream or upstream.

Also tightens the existing missing-handler path (introduced in #4301):

- Reworded the terminal tool result to a neutral "The function
  ``X`` is not currently available." (overridable via
  ``LLMService.MISSING_FUNCTION_CALL_MESSAGE_TEMPLATE``). Previously it
  read "Error: function 'X' is not registered."
- Logs at the call site now distinguish developer error (tool
  advertised but no handler registered → ``logger.error``) from
  hallucination (tool not advertised → ``logger.warning``).

Includes a manual validation harness
(``examples/features/features-add-tool-change-messages.py``) that
exercises the new ``add_tool_change_messages`` mitigation by flipping
tool availability on a turn counter so its effect can be observed
end-to-end with the flag on vs. off.
This commit is contained in:
Paul Kompfner
2026-05-05 13:02:43 -04:00
parent a745e8d318
commit e06e0c0282
5 changed files with 745 additions and 15 deletions

View File

@@ -7,6 +7,8 @@
import json
import unittest
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import AdapterType, ToolsSchema
from pipecat.frames.frames import (
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
@@ -25,6 +27,7 @@ from pipecat.frames.frames import (
LLMMessagesTransformFrame,
LLMMessagesUpdateFrame,
LLMRunFrame,
LLMSetToolsFrame,
LLMTextFrame,
LLMThoughtEndFrame,
LLMThoughtStartFrame,
@@ -46,6 +49,8 @@ from pipecat.processors.aggregators.llm_response_universal import (
AssistantThoughtMessage,
AssistantTurnStoppedMessage,
LLMAssistantAggregator,
LLMAssistantAggregatorParams,
LLMContextAggregatorPair,
LLMUserAggregator,
LLMUserAggregatorParams,
)
@@ -1167,5 +1172,204 @@ class TestLLMAssistantAggregator(unittest.IsolatedAsyncioTestCase):
assert context.messages[0]["content"] == "HELLO"
def _function_schema(name: str) -> FunctionSchema:
    """Return a bare ``FunctionSchema`` called ``name``.

    The description is empty and the schema declares no properties and no
    required fields; only the name is meaningful to the tests that use it.
    """
    blank_fields = {"description": "", "properties": {}, "required": []}
    return FunctionSchema(name=name, **blank_fields)
def _tools(*names: str) -> ToolsSchema:
    """Return a ``ToolsSchema`` with one standard tool per entry in ``names``.

    Tools appear in the order the names were passed; each one is built with
    ``_function_schema``.
    """
    schemas = list(map(_function_schema, names))
    return ToolsSchema(standard_tools=schemas)
def _developer_messages(context: LLMContext) -> list[str]:
    """Return the ``content`` of each developer-role message in ``context``.

    Only entries of ``context.messages`` that are plain dicts with
    ``role == "developer"`` are considered; their order is preserved.
    """
    contents = []
    for message in context.messages:
        # Non-dict entries are ignored rather than inspected.
        if not isinstance(message, dict):
            continue
        if message.get("role") == "developer":
            contents.append(message["content"])
    return contents
class TestToolChangeMessages(unittest.IsolatedAsyncioTestCase):
    """Exercise the opt-in ``add_tool_change_messages`` feature.

    With the flag enabled, an ``LLMSetToolsFrame`` that changes the set of
    advertised standard tools is expected to append a developer-role message
    to the context. These tests pin when that message appears and which tool
    names it mentions.
    """

    @staticmethod
    def _flagged_user_aggregator(context):
        """Build a user aggregator with ``add_tool_change_messages`` enabled."""
        flagged = LLMUserAggregatorParams(add_tool_change_messages=True)
        return LLMUserAggregator(context, params=flagged)

    def _only_announcement(self, context):
        """Assert ``context`` holds exactly one developer message; return it."""
        found = _developer_messages(context)
        self.assertEqual(len(found), 1)
        return found[0]

    async def _send_set_tools_to_user_aggregator(self, aggregator, tools):
        """Run a single ``LLMSetToolsFrame`` carrying ``tools`` through ``aggregator``.

        The user aggregator forwards the frame downstream, so the expected
        downstream output is the ``SpeechControlParamsFrame`` (emitted on
        StartFrame) followed by the forwarded ``LLMSetToolsFrame``.
        """
        set_tools_frame = LLMSetToolsFrame(tools=tools)
        await run_test(
            aggregator,
            frames_to_send=[set_tools_frame],
            expected_down_frames=[SpeechControlParamsFrame, LLMSetToolsFrame],
        )

    async def test_default_off_adds_no_message(self):
        """An aggregator built without the flag adds no developer message."""
        ctx = LLMContext(tools=_tools("a"))
        plain_aggregator = LLMUserAggregator(ctx)
        await self._send_set_tools_to_user_aggregator(plain_aggregator, _tools("a", "b"))
        self.assertEqual(_developer_messages(ctx), [])

    async def test_user_aggregator_announces_additions(self):
        """Newly added tools are named, in sorted order, with no removal text."""
        ctx = LLMContext(tools=_tools("a"))
        flagged_aggregator = self._flagged_user_aggregator(ctx)
        await self._send_set_tools_to_user_aggregator(flagged_aggregator, _tools("a", "b", "c"))
        text = self._only_announcement(ctx)
        for expected in ("just been added", "`b`", "`c`"):
            self.assertIn(expected, text)
        self.assertNotIn("removed", text)
        # Sorted, stable order
        self.assertLess(text.index("`b`"), text.index("`c`"))

    async def test_user_aggregator_announces_removals(self):
        """Dropped tools are named, with no addition text."""
        ctx = LLMContext(tools=_tools("a", "b", "c"))
        flagged_aggregator = self._flagged_user_aggregator(ctx)
        await self._send_set_tools_to_user_aggregator(flagged_aggregator, _tools("a"))
        text = self._only_announcement(ctx)
        for expected in ("just been removed", "`b`", "`c`"):
            self.assertIn(expected, text)
        self.assertNotIn("just been added", text)

    async def test_user_aggregator_combined_add_and_remove(self):
        """A simultaneous add and remove yields one message, additions first."""
        ctx = LLMContext(tools=_tools("a", "b"))
        flagged_aggregator = self._flagged_user_aggregator(ctx)
        await self._send_set_tools_to_user_aggregator(flagged_aggregator, _tools("b", "c"))
        text = self._only_announcement(ctx)
        for expected in ("just been added", "`c`", "just been removed", "`a`"):
            self.assertIn(expected, text)
        # Activation phrase appears before deactivation phrase.
        added_at = text.index("just been added")
        removed_at = text.index("just been removed")
        self.assertLess(added_at, removed_at)

    async def test_no_message_when_diff_is_empty(self):
        """Re-sending the same standard tools adds no developer message."""
        ctx = LLMContext(tools=_tools("a", "b"))
        flagged_aggregator = self._flagged_user_aggregator(ctx)
        await self._send_set_tools_to_user_aggregator(flagged_aggregator, _tools("a", "b"))
        self.assertEqual(_developer_messages(ctx), [])

    async def test_set_tools_to_not_given_lists_all_as_removed(self):
        """Setting tools to ``NOT_GIVEN`` reports every previous tool as removed."""
        from pipecat.processors.aggregators.llm_context import NOT_GIVEN

        ctx = LLMContext(tools=_tools("a", "b"))
        flagged_aggregator = self._flagged_user_aggregator(ctx)
        await self._send_set_tools_to_user_aggregator(flagged_aggregator, NOT_GIVEN)
        text = self._only_announcement(ctx)
        for expected in ("just been removed", "`a`", "`b`"):
            self.assertIn(expected, text)

    async def test_set_tools_from_not_given_lists_all_as_added(self):
        """Starting from a context with no tools reports every new tool as added."""
        ctx = LLMContext()  # tools default to NOT_GIVEN
        flagged_aggregator = self._flagged_user_aggregator(ctx)
        await self._send_set_tools_to_user_aggregator(flagged_aggregator, _tools("x", "y"))
        text = self._only_announcement(ctx)
        for expected in ("just been added", "`x`", "`y`"):
            self.assertIn(expected, text)

    async def test_custom_tools_only_change_no_message(self):
        """A change limited to custom tools produces no announcement."""

        def schema_with_custom(tool_type):
            # Standard tools are identical across calls; only the custom
            # OpenAI tool entry varies.
            return ToolsSchema(
                standard_tools=[_function_schema("a")],
                custom_tools={AdapterType.OPENAI: [{"type": tool_type}]},
            )

        ctx = LLMContext(tools=schema_with_custom("web_search"))
        flagged_aggregator = self._flagged_user_aggregator(ctx)
        replacement = schema_with_custom("file_search")
        await self._send_set_tools_to_user_aggregator(flagged_aggregator, replacement)
        self.assertEqual(_developer_messages(ctx), [])

    async def test_pipeline_with_both_aggregators_announces_once(self):
        """With both aggregators in a pipeline, only one announcement is added."""
        ctx = LLMContext(tools=_tools("a"))
        user_agg, assistant_agg = LLMContextAggregatorPair(ctx, add_tool_change_messages=True)
        both = Pipeline([user_agg, assistant_agg])
        # The user aggregator forwards LLMSetToolsFrame downstream; the
        # assistant aggregator consumes it, so only the
        # SpeechControlParamsFrame is expected to come out the far end.
        await run_test(
            both,
            frames_to_send=[LLMSetToolsFrame(tools=_tools("a", "b"))],
            expected_down_frames=[SpeechControlParamsFrame],
        )
        found = _developer_messages(ctx)
        self.assertEqual(len(found), 1, f"expected exactly one announcement, got {found}")
        self.assertIn("`b`", found[0])

    async def test_assistant_aggregator_announces_when_handled_first(self):
        """An assistant aggregator that receives the frame upstream announces it.

        Only the assistant aggregator is under test here; no user aggregator
        sees the frame, and nothing is expected to be pushed further upstream.
        """
        ctx = LLMContext(tools=_tools("a"))
        flagged = LLMAssistantAggregatorParams(add_tool_change_messages=True)
        assistant_agg = LLMAssistantAggregator(ctx, params=flagged)
        # Send the frame upstream so the assistant aggregator processes it.
        await run_test(
            assistant_agg,
            frames_to_send=[LLMSetToolsFrame(tools=_tools("a", "b"))],
            frames_to_send_direction=FrameDirection.UPSTREAM,
            expected_up_frames=[],
        )
        text = self._only_announcement(ctx)
        self.assertIn("`b`", text)

    async def test_pair_propagates_flag_to_both(self):
        """The pair-level flag is applied to both aggregators."""
        pair = LLMContextAggregatorPair(LLMContext(), add_tool_change_messages=True)
        self.assertTrue(pair.user()._add_tool_change_messages)
        self.assertTrue(pair.assistant()._add_tool_change_messages)

    async def test_pair_arg_overrides_per_params_settings(self):
        """A pair-level ``True`` wins over ``False`` in the individual params."""
        user_off = LLMUserAggregatorParams(add_tool_change_messages=False)
        assistant_off = LLMAssistantAggregatorParams(add_tool_change_messages=False)
        ctx = LLMContext()
        pair = LLMContextAggregatorPair(
            ctx,
            user_params=user_off,
            assistant_params=assistant_off,
            add_tool_change_messages=True,
        )
        self.assertTrue(pair.user()._add_tool_change_messages)
        self.assertTrue(pair.assistant()._add_tool_change_messages)

    async def test_pair_default_respects_per_params(self):
        """Without a pair-level flag, each aggregator keeps its own params value."""
        user_on = LLMUserAggregatorParams(add_tool_change_messages=True)
        assistant_off = LLMAssistantAggregatorParams(add_tool_change_messages=False)
        ctx = LLMContext()
        pair = LLMContextAggregatorPair(
            ctx,
            user_params=user_on,
            assistant_params=assistant_off,
        )
        self.assertTrue(pair.user()._add_tool_change_messages)
        self.assertFalse(pair.assistant()._add_tool_change_messages)
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()

View File

@@ -8,6 +8,8 @@ import unittest
from unittest.mock import AsyncMock, patch
from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.adapters.services.open_ai_adapter import OpenAILLMAdapter
from pipecat.frames.frames import (
FunctionCallFromLLM,
@@ -21,6 +23,10 @@ from pipecat.services.settings import LLMSettings
from pipecat.turns.user_mute.function_call_user_mute_strategy import FunctionCallUserMuteStrategy
def _expected_missing_tool_message(name: str) -> str:
    """Return the missing-tool result text expected for function ``name``.

    The text is produced from
    ``LLMService.MISSING_FUNCTION_CALL_MESSAGE_TEMPLATE`` so the tests follow
    the template instead of hard-coding its wording.
    """
    template = LLMService.MISSING_FUNCTION_CALL_MESSAGE_TEMPLATE
    return template.format(function_name=name)
class MockLLMService(LLMService):
"""Minimal LLM service for testing function call execution."""
@@ -104,13 +110,14 @@ class TestLLMService(unittest.IsolatedAsyncioTestCase):
self.assertEqual(recorded_frames[1].function_name, "missing_tool")
self.assertEqual(
recorded_frames[2].result,
"Error: function 'missing_tool' is not registered.",
_expected_missing_tool_message("missing_tool"),
)
# Only the queue-time warning should fire; the execution-time
# "just unregistered" warning must not double-log.
# The tool was not advertised, so this is treated as a hallucination
# (warning at queue time). The execution-time "just unregistered"
# warning must not double-log.
warnings = [c.args[0] for c in mock_logger.warning.call_args_list]
self.assertTrue(any("not registered" in w for w in warnings))
self.assertTrue(any("not in the currently advertised tool set" in w for w in warnings))
self.assertFalse(any("just unregistered" in w for w in warnings))
async def test_function_unregistered_between_queue_and_execute(self):
@@ -160,9 +167,124 @@ class TestLLMService(unittest.IsolatedAsyncioTestCase):
)
self.assertEqual(
recorded_frames[2].result,
"Error: function 'doomed_tool' is not registered.",
_expected_missing_tool_message("doomed_tool"),
)
async def test_missing_function_call_dev_error_logged_as_error(self):
    """Tool advertised to the LLM but missing a handler → logger.error."""
    tool_name = "advertised_but_unhandled"

    service = MockLLMService()
    service._call_event_handler = AsyncMock()
    await self._run_function_calls_inline(service)
    service.broadcast_frame = AsyncMock()

    # The context advertises the tool, but no handler is registered for it.
    advertised = FunctionSchema(
        name=tool_name,
        description="",
        properties={},
        required=[],
    )
    context = LLMContext(tools=ToolsSchema(standard_tools=[advertised]))

    with patch("pipecat.services.llm_service.logger") as mock_logger:
        pending_call = FunctionCallFromLLM(
            function_name=tool_name,
            tool_call_id="call_1",
            arguments={},
            context=context,
        )
        await service.run_function_calls([pending_call])

    error_logs = [c.args[0] for c in mock_logger.error.call_args_list]
    warning_logs = [c.args[0] for c in mock_logger.warning.call_args_list]

    def is_dev_error(entry):
        # The error must mention the advertisement, the fix, and the tool.
        return "advertised" in entry and "register_function" in entry and tool_name in entry

    self.assertTrue(
        any(is_dev_error(entry) for entry in error_logs),
        f"expected dev-error log; got errors={error_logs}, warnings={warning_logs}",
    )
    hallucination_marker = "not in the currently advertised tool set"
    self.assertFalse(any(hallucination_marker in entry for entry in warning_logs))
async def test_missing_function_call_hallucination_logged_as_warning(self):
    """Tool not advertised to the LLM → logger.warning (hallucination)."""
    called_name = "never_advertised"

    service = MockLLMService()
    service._call_event_handler = AsyncMock()
    await self._run_function_calls_inline(service)
    service.broadcast_frame = AsyncMock()

    # The context advertises a different tool than the one being called.
    unrelated = FunctionSchema(
        name="something_else",
        description="",
        properties={},
        required=[],
    )
    context = LLMContext(tools=ToolsSchema(standard_tools=[unrelated]))

    with patch("pipecat.services.llm_service.logger") as mock_logger:
        pending_call = FunctionCallFromLLM(
            function_name=called_name,
            tool_call_id="call_1",
            arguments={},
            context=context,
        )
        await service.run_function_calls([pending_call])

    warning_logs = [c.args[0] for c in mock_logger.warning.call_args_list]
    error_logs = [c.args[0] for c in mock_logger.error.call_args_list]

    def is_hallucination_warning(entry):
        return "not in the currently advertised tool set" in entry and called_name in entry

    self.assertTrue(
        any(is_hallucination_warning(entry) for entry in warning_logs),
        f"expected hallucination warning; got warnings={warning_logs}, errors={error_logs}",
    )

    def is_dev_error(entry):
        return "advertised" in entry and "register_function" in entry

    self.assertFalse(any(is_dev_error(entry) for entry in error_logs))
async def test_catch_all_handler_suppresses_missing_warnings(self):
    """register_function(None, ...) suppresses both dev-error and hallucination logs."""
    service = MockLLMService()
    service._call_event_handler = AsyncMock()
    await self._run_function_calls_inline(service)
    service.broadcast_frame = AsyncMock()

    async def catch_all(params):
        await params.result_callback("handled")

    # Registering under the name ``None`` installs the catch-all handler.
    service.register_function(None, catch_all)

    with patch("pipecat.services.llm_service.logger") as mock_logger:
        pending_call = FunctionCallFromLLM(
            function_name="anything",
            tool_call_id="call_1",
            arguments={},
            context=LLMContext(),
        )
        await service.run_function_calls([pending_call])

    error_logs = [c.args[0] for c in mock_logger.error.call_args_list]
    warning_logs = [c.args[0] for c in mock_logger.warning.call_args_list]

    dev_error_seen = any("register_function" in entry for entry in error_logs)
    self.assertFalse(dev_error_seen)
    hallucination_seen = any(
        "not in the currently advertised tool set" in entry for entry in warning_logs
    )
    self.assertFalse(hallucination_seen)
async def test_missing_function_call_allows_user_mute_cleanup(self):
service = MockLLMService()
service._call_event_handler = AsyncMock()