Mitigate tool-call-related hallucination
When tools change mid-conversation, LLMs can produce a few different flavors of tool-call-related hallucination: calling tools that have been removed, avoiding tools that have been re-added, or hallucinating output (made-up answers or tool-call-shaped non-tool-calls) when tools are unavailable.

This change introduces an opt-in ``add_tool_change_messages`` flag on the LLM aggregators (preferred entry point: ``LLMContextAggregatorPair(..., add_tool_change_messages=True)``) that appends a developer-role message to the context whenever ``LLMSetToolsFrame`` changes the set of advertised standard tools. This helps the LLM stay coherent across tool changes by spelling out exactly what just became available or unavailable.

Both aggregators participate; whichever handles the frame first wins, and the other (if any) sees an empty diff against the shared context and stays silent. The behavior is order-independent, regardless of whether the frame flows downstream or upstream.

Also tightens the existing missing-handler path (introduced in #4301):

- Reworded the terminal tool result to a neutral "The function ``X`` is not currently available." (overridable via ``LLMService.MISSING_FUNCTION_CALL_MESSAGE_TEMPLATE``). Previously it read "Error: function 'X' is not registered."
- Logs at the call site now distinguish developer error (tool advertised but no handler registered → ``logger.error``) from hallucination (tool not advertised → ``logger.warning``).

Includes a manual validation harness (``examples/features/features-add-tool-change-messages.py``) that exercises the new ``add_tool_change_messages`` mitigation by flipping tool availability on a turn counter, so its effect can be observed end-to-end with the flag on vs. off.
This commit is contained in:
@@ -7,6 +7,8 @@
|
||||
import json
|
||||
import unittest
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import AdapterType, ToolsSchema
|
||||
from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
@@ -25,6 +27,7 @@ from pipecat.frames.frames import (
|
||||
LLMMessagesTransformFrame,
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMRunFrame,
|
||||
LLMSetToolsFrame,
|
||||
LLMTextFrame,
|
||||
LLMThoughtEndFrame,
|
||||
LLMThoughtStartFrame,
|
||||
@@ -46,6 +49,8 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
||||
AssistantThoughtMessage,
|
||||
AssistantTurnStoppedMessage,
|
||||
LLMAssistantAggregator,
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMContextAggregatorPair,
|
||||
LLMUserAggregator,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
@@ -1167,5 +1172,204 @@ class TestLLMAssistantAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
assert context.messages[0]["content"] == "HELLO"
|
||||
|
||||
|
||||
def _function_schema(name: str) -> FunctionSchema:
|
||||
return FunctionSchema(name=name, description="", properties={}, required=[])
|
||||
|
||||
|
||||
def _tools(*names: str) -> ToolsSchema:
|
||||
return ToolsSchema(standard_tools=[_function_schema(n) for n in names])
|
||||
|
||||
|
||||
def _developer_messages(context: LLMContext) -> list[str]:
|
||||
return [
|
||||
m["content"]
|
||||
for m in context.messages
|
||||
if isinstance(m, dict) and m.get("role") == "developer"
|
||||
]
|
||||
|
||||
|
||||
class TestToolChangeMessages(unittest.IsolatedAsyncioTestCase):
|
||||
"""Coverage for the opt-in ``add_tool_change_messages`` feature.
|
||||
|
||||
The feature appends a developer-role message to the context whenever
|
||||
``LLMSetToolsFrame`` changes the set of advertised standard tools.
|
||||
"""
|
||||
|
||||
async def _send_set_tools_to_user_aggregator(self, aggregator, tools):
|
||||
# User aggregator forwards LLMSetToolsFrame downstream, so we expect
|
||||
# the SpeechControlParamsFrame (emitted on StartFrame) and the
|
||||
# forwarded LLMSetToolsFrame.
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=[LLMSetToolsFrame(tools=tools)],
|
||||
expected_down_frames=[SpeechControlParamsFrame, LLMSetToolsFrame],
|
||||
)
|
||||
|
||||
async def test_default_off_adds_no_message(self):
|
||||
context = LLMContext(tools=_tools("a"))
|
||||
aggregator = LLMUserAggregator(context)
|
||||
await self._send_set_tools_to_user_aggregator(aggregator, _tools("a", "b"))
|
||||
self.assertEqual(_developer_messages(context), [])
|
||||
|
||||
async def test_user_aggregator_announces_additions(self):
|
||||
context = LLMContext(tools=_tools("a"))
|
||||
aggregator = LLMUserAggregator(
|
||||
context, params=LLMUserAggregatorParams(add_tool_change_messages=True)
|
||||
)
|
||||
await self._send_set_tools_to_user_aggregator(aggregator, _tools("a", "b", "c"))
|
||||
msgs = _developer_messages(context)
|
||||
self.assertEqual(len(msgs), 1)
|
||||
self.assertIn("just been added", msgs[0])
|
||||
self.assertIn("`b`", msgs[0])
|
||||
self.assertIn("`c`", msgs[0])
|
||||
self.assertNotIn("removed", msgs[0])
|
||||
# Sorted, stable order
|
||||
self.assertLess(msgs[0].index("`b`"), msgs[0].index("`c`"))
|
||||
|
||||
async def test_user_aggregator_announces_removals(self):
|
||||
context = LLMContext(tools=_tools("a", "b", "c"))
|
||||
aggregator = LLMUserAggregator(
|
||||
context, params=LLMUserAggregatorParams(add_tool_change_messages=True)
|
||||
)
|
||||
await self._send_set_tools_to_user_aggregator(aggregator, _tools("a"))
|
||||
msgs = _developer_messages(context)
|
||||
self.assertEqual(len(msgs), 1)
|
||||
self.assertIn("just been removed", msgs[0])
|
||||
self.assertIn("`b`", msgs[0])
|
||||
self.assertIn("`c`", msgs[0])
|
||||
self.assertNotIn("just been added", msgs[0])
|
||||
|
||||
async def test_user_aggregator_combined_add_and_remove(self):
|
||||
context = LLMContext(tools=_tools("a", "b"))
|
||||
aggregator = LLMUserAggregator(
|
||||
context, params=LLMUserAggregatorParams(add_tool_change_messages=True)
|
||||
)
|
||||
await self._send_set_tools_to_user_aggregator(aggregator, _tools("b", "c"))
|
||||
msgs = _developer_messages(context)
|
||||
self.assertEqual(len(msgs), 1)
|
||||
self.assertIn("just been added", msgs[0])
|
||||
self.assertIn("`c`", msgs[0])
|
||||
self.assertIn("just been removed", msgs[0])
|
||||
self.assertIn("`a`", msgs[0])
|
||||
# Activation phrase appears before deactivation phrase.
|
||||
self.assertLess(msgs[0].index("just been added"), msgs[0].index("just been removed"))
|
||||
|
||||
async def test_no_message_when_diff_is_empty(self):
|
||||
context = LLMContext(tools=_tools("a", "b"))
|
||||
aggregator = LLMUserAggregator(
|
||||
context, params=LLMUserAggregatorParams(add_tool_change_messages=True)
|
||||
)
|
||||
await self._send_set_tools_to_user_aggregator(aggregator, _tools("a", "b"))
|
||||
self.assertEqual(_developer_messages(context), [])
|
||||
|
||||
async def test_set_tools_to_not_given_lists_all_as_removed(self):
|
||||
from pipecat.processors.aggregators.llm_context import NOT_GIVEN
|
||||
|
||||
context = LLMContext(tools=_tools("a", "b"))
|
||||
aggregator = LLMUserAggregator(
|
||||
context, params=LLMUserAggregatorParams(add_tool_change_messages=True)
|
||||
)
|
||||
await self._send_set_tools_to_user_aggregator(aggregator, NOT_GIVEN)
|
||||
msgs = _developer_messages(context)
|
||||
self.assertEqual(len(msgs), 1)
|
||||
self.assertIn("just been removed", msgs[0])
|
||||
self.assertIn("`a`", msgs[0])
|
||||
self.assertIn("`b`", msgs[0])
|
||||
|
||||
async def test_set_tools_from_not_given_lists_all_as_added(self):
|
||||
context = LLMContext() # tools default to NOT_GIVEN
|
||||
aggregator = LLMUserAggregator(
|
||||
context, params=LLMUserAggregatorParams(add_tool_change_messages=True)
|
||||
)
|
||||
await self._send_set_tools_to_user_aggregator(aggregator, _tools("x", "y"))
|
||||
msgs = _developer_messages(context)
|
||||
self.assertEqual(len(msgs), 1)
|
||||
self.assertIn("just been added", msgs[0])
|
||||
self.assertIn("`x`", msgs[0])
|
||||
self.assertIn("`y`", msgs[0])
|
||||
|
||||
async def test_custom_tools_only_change_no_message(self):
|
||||
# Standard tools identical; only custom tools differ → no announcement.
|
||||
context = LLMContext(
|
||||
tools=ToolsSchema(
|
||||
standard_tools=[_function_schema("a")],
|
||||
custom_tools={AdapterType.OPENAI: [{"type": "web_search"}]},
|
||||
)
|
||||
)
|
||||
aggregator = LLMUserAggregator(
|
||||
context, params=LLMUserAggregatorParams(add_tool_change_messages=True)
|
||||
)
|
||||
new_tools = ToolsSchema(
|
||||
standard_tools=[_function_schema("a")],
|
||||
custom_tools={AdapterType.OPENAI: [{"type": "file_search"}]},
|
||||
)
|
||||
await self._send_set_tools_to_user_aggregator(aggregator, new_tools)
|
||||
self.assertEqual(_developer_messages(context), [])
|
||||
|
||||
async def test_pipeline_with_both_aggregators_announces_once(self):
|
||||
"""User agg runs first; assistant agg sees no diff and stays silent."""
|
||||
context = LLMContext(tools=_tools("a"))
|
||||
user, assistant = LLMContextAggregatorPair(context, add_tool_change_messages=True)
|
||||
pipeline = Pipeline([user, assistant])
|
||||
# The user aggregator forwards LLMSetToolsFrame downstream; the
|
||||
# assistant aggregator consumes it (does not forward).
|
||||
await run_test(
|
||||
pipeline,
|
||||
frames_to_send=[LLMSetToolsFrame(tools=_tools("a", "b"))],
|
||||
expected_down_frames=[SpeechControlParamsFrame],
|
||||
)
|
||||
msgs = _developer_messages(context)
|
||||
self.assertEqual(len(msgs), 1, f"expected exactly one announcement, got {msgs}")
|
||||
self.assertIn("`b`", msgs[0])
|
||||
|
||||
async def test_assistant_aggregator_announces_when_handled_first(self):
|
||||
"""Order-independence: an upstream LLMSetToolsFrame hits the assistant
|
||||
aggregator first (before being consumed). It should announce, and the
|
||||
user aggregator (which never sees it) shouldn't matter for correctness.
|
||||
"""
|
||||
context = LLMContext(tools=_tools("a"))
|
||||
assistant = LLMAssistantAggregator(
|
||||
context,
|
||||
params=LLMAssistantAggregatorParams(add_tool_change_messages=True),
|
||||
)
|
||||
# Send the frame upstream so the assistant aggregator processes it.
|
||||
await run_test(
|
||||
assistant,
|
||||
frames_to_send=[LLMSetToolsFrame(tools=_tools("a", "b"))],
|
||||
frames_to_send_direction=FrameDirection.UPSTREAM,
|
||||
expected_up_frames=[],
|
||||
)
|
||||
msgs = _developer_messages(context)
|
||||
self.assertEqual(len(msgs), 1)
|
||||
self.assertIn("`b`", msgs[0])
|
||||
|
||||
async def test_pair_propagates_flag_to_both(self):
|
||||
context = LLMContext()
|
||||
pair = LLMContextAggregatorPair(context, add_tool_change_messages=True)
|
||||
self.assertTrue(pair.user()._add_tool_change_messages)
|
||||
self.assertTrue(pair.assistant()._add_tool_change_messages)
|
||||
|
||||
async def test_pair_arg_overrides_per_params_settings(self):
|
||||
context = LLMContext()
|
||||
pair = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(add_tool_change_messages=False),
|
||||
assistant_params=LLMAssistantAggregatorParams(add_tool_change_messages=False),
|
||||
add_tool_change_messages=True,
|
||||
)
|
||||
self.assertTrue(pair.user()._add_tool_change_messages)
|
||||
self.assertTrue(pair.assistant()._add_tool_change_messages)
|
||||
|
||||
async def test_pair_default_respects_per_params(self):
|
||||
context = LLMContext()
|
||||
pair = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(add_tool_change_messages=True),
|
||||
assistant_params=LLMAssistantAggregatorParams(add_tool_change_messages=False),
|
||||
)
|
||||
self.assertTrue(pair.user()._add_tool_change_messages)
|
||||
self.assertFalse(pair.assistant()._add_tool_change_messages)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -8,6 +8,8 @@ import unittest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.adapters.services.open_ai_adapter import OpenAILLMAdapter
|
||||
from pipecat.frames.frames import (
|
||||
FunctionCallFromLLM,
|
||||
@@ -21,6 +23,10 @@ from pipecat.services.settings import LLMSettings
|
||||
from pipecat.turns.user_mute.function_call_user_mute_strategy import FunctionCallUserMuteStrategy
|
||||
|
||||
|
||||
def _expected_missing_tool_message(name: str) -> str:
|
||||
return LLMService.MISSING_FUNCTION_CALL_MESSAGE_TEMPLATE.format(function_name=name)
|
||||
|
||||
|
||||
class MockLLMService(LLMService):
|
||||
"""Minimal LLM service for testing function call execution."""
|
||||
|
||||
@@ -104,13 +110,14 @@ class TestLLMService(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(recorded_frames[1].function_name, "missing_tool")
|
||||
self.assertEqual(
|
||||
recorded_frames[2].result,
|
||||
"Error: function 'missing_tool' is not registered.",
|
||||
_expected_missing_tool_message("missing_tool"),
|
||||
)
|
||||
|
||||
# Only the queue-time warning should fire; the execution-time
|
||||
# "just unregistered" warning must not double-log.
|
||||
# The tool was not advertised, so this is treated as a hallucination
|
||||
# (warning at queue time). The execution-time "just unregistered"
|
||||
# warning must not double-log.
|
||||
warnings = [c.args[0] for c in mock_logger.warning.call_args_list]
|
||||
self.assertTrue(any("not registered" in w for w in warnings))
|
||||
self.assertTrue(any("not in the currently advertised tool set" in w for w in warnings))
|
||||
self.assertFalse(any("just unregistered" in w for w in warnings))
|
||||
|
||||
async def test_function_unregistered_between_queue_and_execute(self):
|
||||
@@ -160,9 +167,124 @@ class TestLLMService(unittest.IsolatedAsyncioTestCase):
|
||||
)
|
||||
self.assertEqual(
|
||||
recorded_frames[2].result,
|
||||
"Error: function 'doomed_tool' is not registered.",
|
||||
_expected_missing_tool_message("doomed_tool"),
|
||||
)
|
||||
|
||||
async def test_missing_function_call_dev_error_logged_as_error(self):
|
||||
"""Tool advertised to the LLM but missing a handler → logger.error."""
|
||||
service = MockLLMService()
|
||||
service._call_event_handler = AsyncMock()
|
||||
await self._run_function_calls_inline(service)
|
||||
service.broadcast_frame = AsyncMock()
|
||||
|
||||
context = LLMContext(
|
||||
tools=ToolsSchema(
|
||||
standard_tools=[
|
||||
FunctionSchema(
|
||||
name="advertised_but_unhandled",
|
||||
description="",
|
||||
properties={},
|
||||
required=[],
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
with patch("pipecat.services.llm_service.logger") as mock_logger:
|
||||
await service.run_function_calls(
|
||||
[
|
||||
FunctionCallFromLLM(
|
||||
function_name="advertised_but_unhandled",
|
||||
tool_call_id="call_1",
|
||||
arguments={},
|
||||
context=context,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
errors = [c.args[0] for c in mock_logger.error.call_args_list]
|
||||
warnings = [c.args[0] for c in mock_logger.warning.call_args_list]
|
||||
self.assertTrue(
|
||||
any(
|
||||
"advertised" in e and "register_function" in e and "advertised_but_unhandled" in e
|
||||
for e in errors
|
||||
),
|
||||
f"expected dev-error log; got errors={errors}, warnings={warnings}",
|
||||
)
|
||||
self.assertFalse(any("not in the currently advertised tool set" in w for w in warnings))
|
||||
|
||||
async def test_missing_function_call_hallucination_logged_as_warning(self):
|
||||
"""Tool not advertised to the LLM → logger.warning (hallucination)."""
|
||||
service = MockLLMService()
|
||||
service._call_event_handler = AsyncMock()
|
||||
await self._run_function_calls_inline(service)
|
||||
service.broadcast_frame = AsyncMock()
|
||||
|
||||
context = LLMContext(
|
||||
tools=ToolsSchema(
|
||||
standard_tools=[
|
||||
FunctionSchema(
|
||||
name="something_else",
|
||||
description="",
|
||||
properties={},
|
||||
required=[],
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
with patch("pipecat.services.llm_service.logger") as mock_logger:
|
||||
await service.run_function_calls(
|
||||
[
|
||||
FunctionCallFromLLM(
|
||||
function_name="never_advertised",
|
||||
tool_call_id="call_1",
|
||||
arguments={},
|
||||
context=context,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
warnings = [c.args[0] for c in mock_logger.warning.call_args_list]
|
||||
errors = [c.args[0] for c in mock_logger.error.call_args_list]
|
||||
self.assertTrue(
|
||||
any(
|
||||
"not in the currently advertised tool set" in w and "never_advertised" in w
|
||||
for w in warnings
|
||||
),
|
||||
f"expected hallucination warning; got warnings={warnings}, errors={errors}",
|
||||
)
|
||||
self.assertFalse(any("advertised" in e and "register_function" in e for e in errors))
|
||||
|
||||
async def test_catch_all_handler_suppresses_missing_warnings(self):
|
||||
"""register_function(None, ...) suppresses both dev-error and hallucination logs."""
|
||||
service = MockLLMService()
|
||||
service._call_event_handler = AsyncMock()
|
||||
await self._run_function_calls_inline(service)
|
||||
service.broadcast_frame = AsyncMock()
|
||||
|
||||
async def catch_all(params):
|
||||
await params.result_callback("handled")
|
||||
|
||||
service.register_function(None, catch_all)
|
||||
|
||||
with patch("pipecat.services.llm_service.logger") as mock_logger:
|
||||
await service.run_function_calls(
|
||||
[
|
||||
FunctionCallFromLLM(
|
||||
function_name="anything",
|
||||
tool_call_id="call_1",
|
||||
arguments={},
|
||||
context=LLMContext(),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
errors = [c.args[0] for c in mock_logger.error.call_args_list]
|
||||
warnings = [c.args[0] for c in mock_logger.warning.call_args_list]
|
||||
self.assertFalse(any("register_function" in e for e in errors))
|
||||
self.assertFalse(any("not in the currently advertised tool set" in w for w in warnings))
|
||||
|
||||
async def test_missing_function_call_allows_user_mute_cleanup(self):
|
||||
service = MockLLMService()
|
||||
service._call_event_handler = AsyncMock()
|
||||
|
||||
Reference in New Issue
Block a user