Mitigate tool-call-related hallucination
When tools change mid-conversation, LLMs can produce a few different flavors of tool-call-related hallucination: calling tools that have been removed, avoiding tools that have been re-added, or hallucinating output (made-up answers or tool-call-shaped non-tool-calls) when tools are unavailable. This change introduces an opt-in ``add_tool_change_messages`` flag on the LLM aggregators (preferred entry point: ``LLMContextAggregatorPair( ..., add_tool_change_messages=True)``) that appends a developer-role message to the context whenever ``LLMSetToolsFrame`` changes the set of advertised standard tools. Helps the LLM stay coherent across tool changes by spelling out exactly what just became available or unavailable. Both aggregators participate; whichever handles the frame first wins, and the other (if any) sees an empty diff against the shared context and stays silent — order-independent regardless of whether the frame flows downstream or upstream. Also tightens the existing missing-handler path (introduced in #4301): - Reworded the terminal tool result to a neutral "The function ``X`` is not currently available." (overridable via ``LLMService.MISSING_FUNCTION_CALL_MESSAGE_TEMPLATE``). Previously read "Error: function 'X' is not registered." - Logs at the call site now distinguish developer error (tool advertised but no handler registered → ``logger.error``) from hallucination (tool not advertised → ``logger.warning``). Includes a manual validation harness (``examples/features/features-add-tool-change-messages.py``) that exercises the new ``add_tool_change_messages`` mitigation by flipping tool availability on a turn counter so its effect can be observed end-to-end with the flag on vs. off.
This commit is contained in:
232
examples/features/features-add-tool-change-messages.py
Normal file
232
examples/features/features-add-tool-change-messages.py
Normal file
@@ -0,0 +1,232 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Manual validation harness for the ``add_tool_change_messages`` feature.
|
||||
|
||||
When tools change mid-conversation, LLMs can produce a few different
|
||||
flavors of tool-call-related hallucination:
|
||||
|
||||
- **Forward hallucination** — calling a tool that has been removed.
|
||||
- **Negative hallucination** — refusing to call a tool that has been
|
||||
re-added (because recent context is full of "I can't" responses).
|
||||
- **Hallucinated output when tools are unavailable** — making up an
|
||||
answer rather than declining gracefully, or producing JSON that
|
||||
*looks* like a tool call but is actually just an assistant text
|
||||
response.
|
||||
|
||||
The ``add_tool_change_messages`` feature mitigates these by appending a
|
||||
developer-role message to the conversation whenever ``LLMSetToolsFrame``
|
||||
changes the set of advertised tools, so the LLM stays in sync with what's
|
||||
actually available.
|
||||
|
||||
This harness exercises all of those flavors by flipping the advertised
|
||||
tool set on a turn counter:
|
||||
|
||||
Phase 0 (turns 1–4): weather tool ACTIVE — confirm baseline.
|
||||
Phase 1 (turns 5–8): tool REMOVED — keep asking for weather.
|
||||
Phase 2 (turn 9+): tool RE-ADDED — does the LLM call it again?
|
||||
|
||||
Set ``ADD_TOOL_CHANGE_MESSAGES=0`` to disable the mitigation and see the
|
||||
unmitigated behavior. The default is ON so a fresh run shows the feature
|
||||
working.
|
||||
|
||||
Defaults to Llama 3.1 8B Instruct via a locally-running Ollama —
|
||||
anecdotally one of the more hallucination-prone of the easily accessible
|
||||
models. Pull the model once with ``ollama pull llama3.1:8b`` and make
|
||||
sure ``ollama serve`` is running. Swap the LLM service to validate other
|
||||
providers.
|
||||
|
||||
Run with::
|
||||
|
||||
uv run examples/features/features-add-tool-change-messages.py
|
||||
ADD_TOOL_CHANGE_MESSAGES=0 uv run examples/features/features-add-tool-change-messages.py
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import LLMRunFrame, LLMSetToolsFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import NOT_GIVEN, LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.services.ollama.llm import OLLamaLLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
# Default ON so a fresh run shows the feature working. Set to "0" to A/B
|
||||
# against the unmitigated behavior.
|
||||
ADD_TOOL_CHANGE_MESSAGES = os.environ.get("ADD_TOOL_CHANGE_MESSAGES", "1") == "1"
|
||||
|
||||
|
||||
async def fetch_weather_from_api(params: FunctionCallParams):
|
||||
await params.result_callback({"conditions": "nice", "temperature": "75"})
|
||||
|
||||
|
||||
weather_function = FunctionSchema(
|
||||
name="get_current_weather",
|
||||
description="Get the current weather",
|
||||
properties={
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the user's location.",
|
||||
},
|
||||
},
|
||||
required=["location", "format"],
|
||||
)
|
||||
weather_tools = ToolsSchema(standard_tools=[weather_function])
|
||||
|
||||
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(audio_in_enabled=True, audio_out_enabled=True),
|
||||
"twilio": lambda: FastAPIWebsocketParams(audio_in_enabled=True, audio_out_enabled=True),
|
||||
"webrtc": lambda: TransportParams(audio_in_enabled=True, audio_out_enabled=True),
|
||||
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(
|
||||
f"Starting add_tool_change_messages demo bot "
|
||||
f"(ADD_TOOL_CHANGE_MESSAGES={ADD_TOOL_CHANGE_MESSAGES})"
|
||||
)
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.environ["DEEPGRAM_API_KEY"])
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.environ["CARTESIA_API_KEY"],
|
||||
settings=CartesiaTTSService.Settings(
|
||||
voice="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
),
|
||||
)
|
||||
|
||||
llm = OLLamaLLMService(
|
||||
settings=OLLamaLLMService.Settings(
|
||||
# Llama 3.1 8B Instruct is anecdotally one of the more
|
||||
# hallucination-prone of the easily accessible models — exactly
|
||||
# what we want for this validation harness. Pull it with
|
||||
# ``ollama pull llama3.1:8b`` and make sure ``ollama serve``
|
||||
# is running.
|
||||
model="llama3.1:8b",
|
||||
system_instruction=(
|
||||
"You are a helpful assistant in a voice conversation. Your responses "
|
||||
"will be spoken aloud, so avoid emojis, bullet points, or other "
|
||||
"formatting that can't be spoken. Respond briefly and naturally. "
|
||||
"If the user asks for the current weather, use the `get_current_weather` "
|
||||
"function if it's available. IMPORTANT: if you do not have access to the function, "
|
||||
"say something along the lines of 'Sorry, I can't check the weather right now.'."
|
||||
),
|
||||
),
|
||||
)
|
||||
llm.register_function("get_current_weather", fetch_weather_from_api)
|
||||
|
||||
context = LLMContext(tools=weather_tools)
|
||||
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(vad_analyzer=SileroVADAnalyzer()),
|
||||
add_tool_change_messages=ADD_TOOL_CHANGE_MESSAGES,
|
||||
)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(),
|
||||
stt,
|
||||
user_aggregator,
|
||||
llm,
|
||||
tts,
|
||||
transport.output(),
|
||||
assistant_aggregator,
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(enable_metrics=True, enable_usage_metrics=True),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
# Phase controller: roughly 4 turns per phase.
|
||||
user_turn_count = 0
|
||||
REMOVE_AT_TURN = 5 # tool gone for turn N onward
|
||||
READD_AT_TURN = 9 # tool back for turn N onward
|
||||
|
||||
@user_aggregator.event_handler("on_user_turn_stopped")
|
||||
async def on_user_turn_stopped(aggregator, strategy, message):
|
||||
nonlocal user_turn_count
|
||||
user_turn_count += 1
|
||||
logger.info(f"=== User turn {user_turn_count} complete ===")
|
||||
|
||||
if user_turn_count == REMOVE_AT_TURN - 1:
|
||||
logger.info(
|
||||
"=== Phase 1: weather tool REMOVED. Keep asking about the weather "
|
||||
"to exercise hallucination scenarios. ==="
|
||||
)
|
||||
await task.queue_frame(LLMSetToolsFrame(tools=NOT_GIVEN))
|
||||
elif user_turn_count == READD_AT_TURN - 1:
|
||||
logger.info(
|
||||
"=== Phase 2: weather tool RE-ADDED. Ask for the weather again — "
|
||||
"does the LLM call it, or keep refusing? (THIS IS THE TEST.) ==="
|
||||
)
|
||||
await task.queue_frame(LLMSetToolsFrame(tools=weather_tools))
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info("Client connected")
|
||||
logger.info(
|
||||
"=== Phase 0: weather tool ACTIVE. Ask for the weather a few times "
|
||||
"to confirm it's working. ==="
|
||||
)
|
||||
context.add_message(
|
||||
{
|
||||
"role": "developer",
|
||||
"content": (
|
||||
"Please introduce yourself briefly to the user, then invite them "
|
||||
"to ask about the weather."
|
||||
),
|
||||
}
|
||||
)
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info("Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -72,6 +72,7 @@ from pipecat.processors.aggregators.llm_context import (
|
||||
LLMContextMessage,
|
||||
LLMSpecificMessage,
|
||||
NotGiven,
|
||||
is_given,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_context_summarizer import (
|
||||
LLMContextSummarizer,
|
||||
@@ -118,6 +119,18 @@ class LLMUserAggregatorParams:
|
||||
user_turn_completion_config: Configuration for turn completion behavior including
|
||||
custom instructions, timeouts, and prompts. Only used when
|
||||
filter_incomplete_user_turns is True.
|
||||
add_tool_change_messages: When True, on each ``LLMSetToolsFrame`` the
|
||||
aggregator computes the diff against the currently advertised tools
|
||||
and appends a developer-role message to the context describing
|
||||
additions/removals. Helps the LLM stay coherent across
|
||||
mid-conversation tool changes, mitigating several flavors of
|
||||
tool-call-related hallucination: calling tools that have been
|
||||
removed, avoiding tools that have been re-added, and hallucinating
|
||||
output (made-up answers or tool-call-shaped non-tool-calls) when
|
||||
tools are unavailable. Only standard tools are diffed; custom
|
||||
(LLM-specific) tools are ignored. When using
|
||||
``LLMContextAggregatorPair``, prefer setting this via its
|
||||
``add_tool_change_messages`` argument instead. Defaults to False.
|
||||
"""
|
||||
|
||||
user_turn_strategies: UserTurnStrategies | None = None
|
||||
@@ -128,6 +141,7 @@ class LLMUserAggregatorParams:
|
||||
audio_idle_timeout: float = 1.0
|
||||
filter_incomplete_user_turns: bool = False
|
||||
user_turn_completion_config: UserTurnCompletionConfig | None = None
|
||||
add_tool_change_messages: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -143,10 +157,23 @@ class LLMAssistantAggregatorParams:
|
||||
summarization. Controls trigger thresholds, message preservation, and
|
||||
summarization prompts. If None, uses default
|
||||
``LLMAutoContextSummarizationConfig`` values.
|
||||
add_tool_change_messages: When True, on each ``LLMSetToolsFrame`` the
|
||||
aggregator computes the diff against the currently advertised tools
|
||||
and appends a developer-role message to the context describing
|
||||
additions/removals. Helps the LLM stay coherent across
|
||||
mid-conversation tool changes, mitigating several flavors of
|
||||
tool-call-related hallucination: calling tools that have been
|
||||
removed, avoiding tools that have been re-added, and hallucinating
|
||||
output (made-up answers or tool-call-shaped non-tool-calls) when
|
||||
tools are unavailable. Only standard tools are diffed; custom
|
||||
(LLM-specific) tools are ignored. When using
|
||||
``LLMContextAggregatorPair``, prefer setting this via its
|
||||
``add_tool_change_messages`` argument instead. Defaults to False.
|
||||
"""
|
||||
|
||||
enable_auto_context_summarization: bool = False
|
||||
auto_context_summarization_config: LLMAutoContextSummarizationConfig | None = None
|
||||
add_tool_change_messages: bool = False
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Deprecated field names — kept for backward compatibility.
|
||||
@@ -248,20 +275,87 @@ class LLMContextAggregator(FrameProcessor):
|
||||
common functionality for context-based conversation management.
|
||||
"""
|
||||
|
||||
def __init__(self, *, context: LLMContext, role: str, **kwargs):
|
||||
# Developer-role messages appended to the context when tools are added/
|
||||
# removed via ``LLMSetToolsFrame`` (only when ``add_tool_change_messages``
|
||||
# is enabled on the aggregator's params). ``{function_names}`` is
|
||||
# substituted with a sorted, comma-separated, backtick-wrapped list.
|
||||
TOOL_ACTIVATION_MESSAGE_TEMPLATE = (
|
||||
"The following function(s) have just been added and may now be called: "
|
||||
"{function_names}. Any previously available functions remain available."
|
||||
)
|
||||
TOOL_DEACTIVATION_MESSAGE_TEMPLATE = (
|
||||
"The following function(s) have just been removed and should not be called: "
|
||||
"{function_names}. Any previously available functions remain available. "
|
||||
"The removed function(s) may become available again later, in which case "
|
||||
"you will be informed."
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
context: LLMContext,
|
||||
role: str,
|
||||
add_tool_change_messages: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the context response aggregator.
|
||||
|
||||
Args:
|
||||
context: The LLM context to use for conversation storage.
|
||||
role: The role this aggregator represents (e.g. "user", "assistant").
|
||||
add_tool_change_messages: See the field of the same name on the
|
||||
aggregator-specific params dataclasses. Subclasses propagate
|
||||
this from their ``params``.
|
||||
**kwargs: Additional arguments passed to parent class.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._context = context
|
||||
self._role = role
|
||||
self._add_tool_change_messages = add_tool_change_messages
|
||||
|
||||
self._aggregation: list[TextPartForConcatenation] = []
|
||||
|
||||
def _maybe_add_tool_change_messages(self, new_tools: ToolsSchema | NotGiven) -> None:
|
||||
"""Append a developer message describing tool add/remove deltas.
|
||||
|
||||
No-op unless ``add_tool_change_messages`` was enabled on the aggregator,
|
||||
and no-op when the diff against the currently advertised tools is empty.
|
||||
Custom (LLM-specific) tools are ignored — only standard tools are diffed.
|
||||
|
||||
Both aggregators call this on every ``LLMSetToolsFrame`` they handle.
|
||||
Whichever aggregator handles the frame first computes a real diff
|
||||
against the shared context and adds the announcement; by the time
|
||||
the other aggregator sees it (if at all), the context already
|
||||
reflects the new tools, so its diff is empty and no duplicate
|
||||
message is added. This is order-independent: it works whether the
|
||||
frame flows downstream (user aggregator first) or upstream
|
||||
(assistant aggregator first, and consumed without being forwarded).
|
||||
"""
|
||||
if not self._add_tool_change_messages:
|
||||
return
|
||||
|
||||
def _names(tools: ToolsSchema | NotGiven) -> set[str]:
|
||||
if not is_given(tools):
|
||||
return set()
|
||||
return {s.name for s in tools.standard_tools}
|
||||
|
||||
old_names = _names(self._context.tools)
|
||||
new_names = _names(new_tools)
|
||||
added = new_names - old_names
|
||||
removed = old_names - new_names
|
||||
if not added and not removed:
|
||||
return
|
||||
|
||||
parts: list[str] = []
|
||||
if added:
|
||||
names = ", ".join(f"`{n}`" for n in sorted(added))
|
||||
parts.append(self.TOOL_ACTIVATION_MESSAGE_TEMPLATE.format(function_names=names))
|
||||
if removed:
|
||||
names = ", ".join(f"`{n}`" for n in sorted(removed))
|
||||
parts.append(self.TOOL_DEACTIVATION_MESSAGE_TEMPLATE.format(function_names=names))
|
||||
|
||||
self._context.add_message({"role": "developer", "content": " ".join(parts)})
|
||||
|
||||
@property
|
||||
def messages(self) -> list[LLMContextMessage]:
|
||||
"""Get messages from the LLM context.
|
||||
@@ -434,8 +528,14 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
params: Configuration parameters for aggregation behavior.
|
||||
**kwargs: Additional arguments.
|
||||
"""
|
||||
super().__init__(context=context, role="user", **kwargs)
|
||||
self._params = params or LLMUserAggregatorParams()
|
||||
params = params or LLMUserAggregatorParams()
|
||||
super().__init__(
|
||||
context=context,
|
||||
role="user",
|
||||
add_tool_change_messages=params.add_tool_change_messages,
|
||||
**kwargs,
|
||||
)
|
||||
self._params = params
|
||||
|
||||
self._register_event_handler("on_user_turn_started")
|
||||
self._register_event_handler("on_user_turn_stopped")
|
||||
@@ -536,6 +636,7 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
elif isinstance(frame, LLMMessagesTransformFrame):
|
||||
await self._handle_llm_messages_transform(frame)
|
||||
elif isinstance(frame, LLMSetToolsFrame):
|
||||
self._maybe_add_tool_change_messages(frame.tools)
|
||||
self.set_tools(frame.tools)
|
||||
# Push the LLMSetToolsFrame as well, since speech-to-speech LLM
|
||||
# services (like OpenAI Realtime) may need to know about tool
|
||||
@@ -843,8 +944,14 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
params: Configuration parameters for aggregation behavior.
|
||||
**kwargs: Additional arguments.
|
||||
"""
|
||||
super().__init__(context=context, role="assistant", **kwargs)
|
||||
self._params = params or LLMAssistantAggregatorParams()
|
||||
params = params or LLMAssistantAggregatorParams()
|
||||
super().__init__(
|
||||
context=context,
|
||||
role="assistant",
|
||||
add_tool_change_messages=params.add_tool_change_messages,
|
||||
**kwargs,
|
||||
)
|
||||
self._params = params
|
||||
|
||||
self._function_calls_in_progress: dict[str, FunctionCallInProgressFrame | None] = {}
|
||||
self._function_calls_image_results: dict[str, UserImageRawFrame] = {}
|
||||
@@ -949,6 +1056,7 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
elif isinstance(frame, LLMMessagesTransformFrame):
|
||||
await self._handle_llm_messages_transform(frame)
|
||||
elif isinstance(frame, LLMSetToolsFrame):
|
||||
self._maybe_add_tool_change_messages(frame.tools)
|
||||
self.set_tools(frame.tools)
|
||||
elif isinstance(frame, LLMSetToolChoiceFrame):
|
||||
self.set_tool_choice(frame.tool_choice)
|
||||
@@ -1478,6 +1586,7 @@ class LLMContextAggregatorPair:
|
||||
*,
|
||||
user_params: LLMUserAggregatorParams | None = None,
|
||||
assistant_params: LLMAssistantAggregatorParams | None = None,
|
||||
add_tool_change_messages: bool | None = None,
|
||||
):
|
||||
"""Initialize the LLM context aggregator pair.
|
||||
|
||||
@@ -1485,9 +1594,22 @@ class LLMContextAggregatorPair:
|
||||
context: The context to be managed by the aggregators.
|
||||
user_params: Parameters for the user context aggregator.
|
||||
assistant_params: Parameters for the assistant context aggregator.
|
||||
add_tool_change_messages: When provided, sets the field of the
|
||||
same name on both ``user_params`` and ``assistant_params``,
|
||||
overriding any value already set on either. This is the
|
||||
preferred way to enable tool-change announcements: it ensures
|
||||
both aggregators participate, which makes the feature robust
|
||||
regardless of which aggregator handles a given
|
||||
``LLMSetToolsFrame``. The shared context guarantees the
|
||||
announcement is added exactly once (the second aggregator's
|
||||
diff is empty by the time it sees the frame). Leave as
|
||||
``None`` to respect per-params settings.
|
||||
"""
|
||||
user_params = user_params or LLMUserAggregatorParams()
|
||||
assistant_params = assistant_params or LLMAssistantAggregatorParams()
|
||||
if add_tool_change_messages is not None:
|
||||
user_params.add_tool_change_messages = add_tool_change_messages
|
||||
assistant_params.add_tool_change_messages = add_tool_change_messages
|
||||
self._user = LLMUserAggregator(context, params=user_params)
|
||||
self._assistant = LLMAssistantAggregator(context, params=assistant_params)
|
||||
|
||||
|
||||
@@ -53,6 +53,7 @@ from pipecat.frames.frames import (
|
||||
from pipecat.processors.aggregators.llm_context import (
|
||||
LLMContext,
|
||||
LLMSpecificMessage,
|
||||
is_given,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_service import AIService
|
||||
@@ -243,6 +244,15 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService, Generic[TAdapter]
|
||||
# However, subclasses should override this with a more specific adapter when necessary.
|
||||
adapter_class: type[BaseLLMAdapter] = OpenAILLMAdapter
|
||||
|
||||
# Returned to the LLM as the tool result when an unavailable function is
|
||||
# called. Deliberately neutral about future availability so the LLM can
|
||||
# pick the function up again if it returns (e.g. via the
|
||||
# ``add_tool_change_messages`` activation message, or silently on a
|
||||
# later inference). ``{function_name}`` is substituted at runtime.
|
||||
MISSING_FUNCTION_CALL_MESSAGE_TEMPLATE = (
|
||||
"The function `{function_name}` is not currently available."
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
run_in_parallel: bool = True,
|
||||
@@ -764,9 +774,7 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService, Generic[TAdapter]
|
||||
elif None in self._functions.keys():
|
||||
item = self._functions[None]
|
||||
else:
|
||||
logger.warning(
|
||||
f"{self} is calling '{function_call.function_name}', but it's not registered."
|
||||
)
|
||||
self._log_missing_function_call(function_call.function_name, function_call.context)
|
||||
item = self._build_missing_function_call_registry_item(function_call.function_name)
|
||||
|
||||
runner_items.append(
|
||||
@@ -835,8 +843,12 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService, Generic[TAdapter]
|
||||
elif runner_item.registry_item.handler == self._missing_function_call_handler:
|
||||
item = runner_item.registry_item
|
||||
else:
|
||||
# Function was unregistered between queue and execution; the
|
||||
# registry-item-handler check above already covered the
|
||||
# missing-from-the-start case.
|
||||
logger.warning(
|
||||
f"{self} is calling '{runner_item.function_name}', but it was just unregistered."
|
||||
f"{self}: '{runner_item.function_name}' was just unregistered "
|
||||
f"between queueing and execution."
|
||||
)
|
||||
item = self._build_missing_function_call_registry_item(runner_item.function_name)
|
||||
|
||||
@@ -962,7 +974,45 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService, Generic[TAdapter]
|
||||
|
||||
async def _missing_function_call_handler(self, params: FunctionCallParams):
|
||||
"""Return a terminal tool result when the LLM calls an unknown function."""
|
||||
await params.result_callback(f"Error: function '{params.function_name}' is not registered.")
|
||||
await params.result_callback(
|
||||
self.MISSING_FUNCTION_CALL_MESSAGE_TEMPLATE.format(function_name=params.function_name)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _advertised_tool_names(context) -> set[str]:
|
||||
"""Return the set of standard-tool names currently advertised to the LLM.
|
||||
|
||||
Custom (LLM-specific) tools are not included, since they have no
|
||||
consistent name field across adapters.
|
||||
"""
|
||||
tools = context.tools if context is not None else None
|
||||
if tools is None or not is_given(tools):
|
||||
return set()
|
||||
return {t.name for t in tools.standard_tools}
|
||||
|
||||
def _log_missing_function_call(self, function_name: str, context) -> None:
|
||||
"""Log an appropriate message when a tool is called with no handler.
|
||||
|
||||
Distinguishes two cases:
|
||||
|
||||
- **Developer error:** the tool is advertised to the LLM but no handler
|
||||
was registered (likely a missed ``register_function`` call). Logged
|
||||
at error level since this almost always indicates a bug.
|
||||
- **Hallucination:** the tool is not in the currently advertised tool
|
||||
set. Logged at warning level since this is model behavior the
|
||||
application can do little about beyond returning a terminal result.
|
||||
"""
|
||||
if function_name in self._advertised_tool_names(context):
|
||||
logger.error(
|
||||
f"{self}: tool '{function_name}' is advertised to the LLM "
|
||||
f"but has no registered handler — did you forget to call "
|
||||
f"register_function()?"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"{self}: LLM called '{function_name}', which is not in the "
|
||||
f"currently advertised tool set."
|
||||
)
|
||||
|
||||
def _has_async_tools(self) -> bool:
|
||||
"""Return True if at least one non-builtin async tool is registered."""
|
||||
|
||||
@@ -7,6 +7,8 @@
|
||||
import json
|
||||
import unittest
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import AdapterType, ToolsSchema
|
||||
from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
@@ -25,6 +27,7 @@ from pipecat.frames.frames import (
|
||||
LLMMessagesTransformFrame,
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMRunFrame,
|
||||
LLMSetToolsFrame,
|
||||
LLMTextFrame,
|
||||
LLMThoughtEndFrame,
|
||||
LLMThoughtStartFrame,
|
||||
@@ -46,6 +49,8 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
||||
AssistantThoughtMessage,
|
||||
AssistantTurnStoppedMessage,
|
||||
LLMAssistantAggregator,
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMContextAggregatorPair,
|
||||
LLMUserAggregator,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
@@ -1167,5 +1172,204 @@ class TestLLMAssistantAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
assert context.messages[0]["content"] == "HELLO"
|
||||
|
||||
|
||||
def _function_schema(name: str) -> FunctionSchema:
|
||||
return FunctionSchema(name=name, description="", properties={}, required=[])
|
||||
|
||||
|
||||
def _tools(*names: str) -> ToolsSchema:
|
||||
return ToolsSchema(standard_tools=[_function_schema(n) for n in names])
|
||||
|
||||
|
||||
def _developer_messages(context: LLMContext) -> list[str]:
|
||||
return [
|
||||
m["content"]
|
||||
for m in context.messages
|
||||
if isinstance(m, dict) and m.get("role") == "developer"
|
||||
]
|
||||
|
||||
|
||||
class TestToolChangeMessages(unittest.IsolatedAsyncioTestCase):
|
||||
"""Coverage for the opt-in ``add_tool_change_messages`` feature.
|
||||
|
||||
The feature appends a developer-role message to the context whenever
|
||||
``LLMSetToolsFrame`` changes the set of advertised standard tools.
|
||||
"""
|
||||
|
||||
async def _send_set_tools_to_user_aggregator(self, aggregator, tools):
|
||||
# User aggregator forwards LLMSetToolsFrame downstream, so we expect
|
||||
# the SpeechControlParamsFrame (emitted on StartFrame) and the
|
||||
# forwarded LLMSetToolsFrame.
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=[LLMSetToolsFrame(tools=tools)],
|
||||
expected_down_frames=[SpeechControlParamsFrame, LLMSetToolsFrame],
|
||||
)
|
||||
|
||||
async def test_default_off_adds_no_message(self):
|
||||
context = LLMContext(tools=_tools("a"))
|
||||
aggregator = LLMUserAggregator(context)
|
||||
await self._send_set_tools_to_user_aggregator(aggregator, _tools("a", "b"))
|
||||
self.assertEqual(_developer_messages(context), [])
|
||||
|
||||
async def test_user_aggregator_announces_additions(self):
|
||||
context = LLMContext(tools=_tools("a"))
|
||||
aggregator = LLMUserAggregator(
|
||||
context, params=LLMUserAggregatorParams(add_tool_change_messages=True)
|
||||
)
|
||||
await self._send_set_tools_to_user_aggregator(aggregator, _tools("a", "b", "c"))
|
||||
msgs = _developer_messages(context)
|
||||
self.assertEqual(len(msgs), 1)
|
||||
self.assertIn("just been added", msgs[0])
|
||||
self.assertIn("`b`", msgs[0])
|
||||
self.assertIn("`c`", msgs[0])
|
||||
self.assertNotIn("removed", msgs[0])
|
||||
# Sorted, stable order
|
||||
self.assertLess(msgs[0].index("`b`"), msgs[0].index("`c`"))
|
||||
|
||||
async def test_user_aggregator_announces_removals(self):
|
||||
context = LLMContext(tools=_tools("a", "b", "c"))
|
||||
aggregator = LLMUserAggregator(
|
||||
context, params=LLMUserAggregatorParams(add_tool_change_messages=True)
|
||||
)
|
||||
await self._send_set_tools_to_user_aggregator(aggregator, _tools("a"))
|
||||
msgs = _developer_messages(context)
|
||||
self.assertEqual(len(msgs), 1)
|
||||
self.assertIn("just been removed", msgs[0])
|
||||
self.assertIn("`b`", msgs[0])
|
||||
self.assertIn("`c`", msgs[0])
|
||||
self.assertNotIn("just been added", msgs[0])
|
||||
|
||||
async def test_user_aggregator_combined_add_and_remove(self):
|
||||
context = LLMContext(tools=_tools("a", "b"))
|
||||
aggregator = LLMUserAggregator(
|
||||
context, params=LLMUserAggregatorParams(add_tool_change_messages=True)
|
||||
)
|
||||
await self._send_set_tools_to_user_aggregator(aggregator, _tools("b", "c"))
|
||||
msgs = _developer_messages(context)
|
||||
self.assertEqual(len(msgs), 1)
|
||||
self.assertIn("just been added", msgs[0])
|
||||
self.assertIn("`c`", msgs[0])
|
||||
self.assertIn("just been removed", msgs[0])
|
||||
self.assertIn("`a`", msgs[0])
|
||||
# Activation phrase appears before deactivation phrase.
|
||||
self.assertLess(msgs[0].index("just been added"), msgs[0].index("just been removed"))
|
||||
|
||||
async def test_no_message_when_diff_is_empty(self):
|
||||
context = LLMContext(tools=_tools("a", "b"))
|
||||
aggregator = LLMUserAggregator(
|
||||
context, params=LLMUserAggregatorParams(add_tool_change_messages=True)
|
||||
)
|
||||
await self._send_set_tools_to_user_aggregator(aggregator, _tools("a", "b"))
|
||||
self.assertEqual(_developer_messages(context), [])
|
||||
|
||||
async def test_set_tools_to_not_given_lists_all_as_removed(self):
|
||||
from pipecat.processors.aggregators.llm_context import NOT_GIVEN
|
||||
|
||||
context = LLMContext(tools=_tools("a", "b"))
|
||||
aggregator = LLMUserAggregator(
|
||||
context, params=LLMUserAggregatorParams(add_tool_change_messages=True)
|
||||
)
|
||||
await self._send_set_tools_to_user_aggregator(aggregator, NOT_GIVEN)
|
||||
msgs = _developer_messages(context)
|
||||
self.assertEqual(len(msgs), 1)
|
||||
self.assertIn("just been removed", msgs[0])
|
||||
self.assertIn("`a`", msgs[0])
|
||||
self.assertIn("`b`", msgs[0])
|
||||
|
||||
async def test_set_tools_from_not_given_lists_all_as_added(self):
|
||||
context = LLMContext() # tools default to NOT_GIVEN
|
||||
aggregator = LLMUserAggregator(
|
||||
context, params=LLMUserAggregatorParams(add_tool_change_messages=True)
|
||||
)
|
||||
await self._send_set_tools_to_user_aggregator(aggregator, _tools("x", "y"))
|
||||
msgs = _developer_messages(context)
|
||||
self.assertEqual(len(msgs), 1)
|
||||
self.assertIn("just been added", msgs[0])
|
||||
self.assertIn("`x`", msgs[0])
|
||||
self.assertIn("`y`", msgs[0])
|
||||
|
||||
async def test_custom_tools_only_change_no_message(self):
|
||||
# Standard tools identical; only custom tools differ → no announcement.
|
||||
context = LLMContext(
|
||||
tools=ToolsSchema(
|
||||
standard_tools=[_function_schema("a")],
|
||||
custom_tools={AdapterType.OPENAI: [{"type": "web_search"}]},
|
||||
)
|
||||
)
|
||||
aggregator = LLMUserAggregator(
|
||||
context, params=LLMUserAggregatorParams(add_tool_change_messages=True)
|
||||
)
|
||||
new_tools = ToolsSchema(
|
||||
standard_tools=[_function_schema("a")],
|
||||
custom_tools={AdapterType.OPENAI: [{"type": "file_search"}]},
|
||||
)
|
||||
await self._send_set_tools_to_user_aggregator(aggregator, new_tools)
|
||||
self.assertEqual(_developer_messages(context), [])
|
||||
|
||||
async def test_pipeline_with_both_aggregators_announces_once(self):
|
||||
"""User agg runs first; assistant agg sees no diff and stays silent."""
|
||||
context = LLMContext(tools=_tools("a"))
|
||||
user, assistant = LLMContextAggregatorPair(context, add_tool_change_messages=True)
|
||||
pipeline = Pipeline([user, assistant])
|
||||
# The user aggregator forwards LLMSetToolsFrame downstream; the
|
||||
# assistant aggregator consumes it (does not forward).
|
||||
await run_test(
|
||||
pipeline,
|
||||
frames_to_send=[LLMSetToolsFrame(tools=_tools("a", "b"))],
|
||||
expected_down_frames=[SpeechControlParamsFrame],
|
||||
)
|
||||
msgs = _developer_messages(context)
|
||||
self.assertEqual(len(msgs), 1, f"expected exactly one announcement, got {msgs}")
|
||||
self.assertIn("`b`", msgs[0])
|
||||
|
||||
async def test_assistant_aggregator_announces_when_handled_first(self):
|
||||
"""Order-independence: an upstream LLMSetToolsFrame hits the assistant
|
||||
aggregator first (before being consumed). It should announce, and the
|
||||
user aggregator (which never sees it) shouldn't matter for correctness.
|
||||
"""
|
||||
context = LLMContext(tools=_tools("a"))
|
||||
assistant = LLMAssistantAggregator(
|
||||
context,
|
||||
params=LLMAssistantAggregatorParams(add_tool_change_messages=True),
|
||||
)
|
||||
# Send the frame upstream so the assistant aggregator processes it.
|
||||
await run_test(
|
||||
assistant,
|
||||
frames_to_send=[LLMSetToolsFrame(tools=_tools("a", "b"))],
|
||||
frames_to_send_direction=FrameDirection.UPSTREAM,
|
||||
expected_up_frames=[],
|
||||
)
|
||||
msgs = _developer_messages(context)
|
||||
self.assertEqual(len(msgs), 1)
|
||||
self.assertIn("`b`", msgs[0])
|
||||
|
||||
async def test_pair_propagates_flag_to_both(self):
|
||||
context = LLMContext()
|
||||
pair = LLMContextAggregatorPair(context, add_tool_change_messages=True)
|
||||
self.assertTrue(pair.user()._add_tool_change_messages)
|
||||
self.assertTrue(pair.assistant()._add_tool_change_messages)
|
||||
|
||||
async def test_pair_arg_overrides_per_params_settings(self):
|
||||
context = LLMContext()
|
||||
pair = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(add_tool_change_messages=False),
|
||||
assistant_params=LLMAssistantAggregatorParams(add_tool_change_messages=False),
|
||||
add_tool_change_messages=True,
|
||||
)
|
||||
self.assertTrue(pair.user()._add_tool_change_messages)
|
||||
self.assertTrue(pair.assistant()._add_tool_change_messages)
|
||||
|
||||
async def test_pair_default_respects_per_params(self):
|
||||
context = LLMContext()
|
||||
pair = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(add_tool_change_messages=True),
|
||||
assistant_params=LLMAssistantAggregatorParams(add_tool_change_messages=False),
|
||||
)
|
||||
self.assertTrue(pair.user()._add_tool_change_messages)
|
||||
self.assertFalse(pair.assistant()._add_tool_change_messages)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -8,6 +8,8 @@ import unittest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.adapters.services.open_ai_adapter import OpenAILLMAdapter
|
||||
from pipecat.frames.frames import (
|
||||
FunctionCallFromLLM,
|
||||
@@ -21,6 +23,10 @@ from pipecat.services.settings import LLMSettings
|
||||
from pipecat.turns.user_mute.function_call_user_mute_strategy import FunctionCallUserMuteStrategy
|
||||
|
||||
|
||||
def _expected_missing_tool_message(name: str) -> str:
|
||||
return LLMService.MISSING_FUNCTION_CALL_MESSAGE_TEMPLATE.format(function_name=name)
|
||||
|
||||
|
||||
class MockLLMService(LLMService):
|
||||
"""Minimal LLM service for testing function call execution."""
|
||||
|
||||
@@ -104,13 +110,14 @@ class TestLLMService(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(recorded_frames[1].function_name, "missing_tool")
|
||||
self.assertEqual(
|
||||
recorded_frames[2].result,
|
||||
"Error: function 'missing_tool' is not registered.",
|
||||
_expected_missing_tool_message("missing_tool"),
|
||||
)
|
||||
|
||||
# Only the queue-time warning should fire; the execution-time
|
||||
# "just unregistered" warning must not double-log.
|
||||
# The tool was not advertised, so this is treated as a hallucination
|
||||
# (warning at queue time). The execution-time "just unregistered"
|
||||
# warning must not double-log.
|
||||
warnings = [c.args[0] for c in mock_logger.warning.call_args_list]
|
||||
self.assertTrue(any("not registered" in w for w in warnings))
|
||||
self.assertTrue(any("not in the currently advertised tool set" in w for w in warnings))
|
||||
self.assertFalse(any("just unregistered" in w for w in warnings))
|
||||
|
||||
async def test_function_unregistered_between_queue_and_execute(self):
|
||||
@@ -160,9 +167,124 @@ class TestLLMService(unittest.IsolatedAsyncioTestCase):
|
||||
)
|
||||
self.assertEqual(
|
||||
recorded_frames[2].result,
|
||||
"Error: function 'doomed_tool' is not registered.",
|
||||
_expected_missing_tool_message("doomed_tool"),
|
||||
)
|
||||
|
||||
async def test_missing_function_call_dev_error_logged_as_error(self):
|
||||
"""Tool advertised to the LLM but missing a handler → logger.error."""
|
||||
service = MockLLMService()
|
||||
service._call_event_handler = AsyncMock()
|
||||
await self._run_function_calls_inline(service)
|
||||
service.broadcast_frame = AsyncMock()
|
||||
|
||||
context = LLMContext(
|
||||
tools=ToolsSchema(
|
||||
standard_tools=[
|
||||
FunctionSchema(
|
||||
name="advertised_but_unhandled",
|
||||
description="",
|
||||
properties={},
|
||||
required=[],
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
with patch("pipecat.services.llm_service.logger") as mock_logger:
|
||||
await service.run_function_calls(
|
||||
[
|
||||
FunctionCallFromLLM(
|
||||
function_name="advertised_but_unhandled",
|
||||
tool_call_id="call_1",
|
||||
arguments={},
|
||||
context=context,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
errors = [c.args[0] for c in mock_logger.error.call_args_list]
|
||||
warnings = [c.args[0] for c in mock_logger.warning.call_args_list]
|
||||
self.assertTrue(
|
||||
any(
|
||||
"advertised" in e and "register_function" in e and "advertised_but_unhandled" in e
|
||||
for e in errors
|
||||
),
|
||||
f"expected dev-error log; got errors={errors}, warnings={warnings}",
|
||||
)
|
||||
self.assertFalse(any("not in the currently advertised tool set" in w for w in warnings))
|
||||
|
||||
async def test_missing_function_call_hallucination_logged_as_warning(self):
|
||||
"""Tool not advertised to the LLM → logger.warning (hallucination)."""
|
||||
service = MockLLMService()
|
||||
service._call_event_handler = AsyncMock()
|
||||
await self._run_function_calls_inline(service)
|
||||
service.broadcast_frame = AsyncMock()
|
||||
|
||||
context = LLMContext(
|
||||
tools=ToolsSchema(
|
||||
standard_tools=[
|
||||
FunctionSchema(
|
||||
name="something_else",
|
||||
description="",
|
||||
properties={},
|
||||
required=[],
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
with patch("pipecat.services.llm_service.logger") as mock_logger:
|
||||
await service.run_function_calls(
|
||||
[
|
||||
FunctionCallFromLLM(
|
||||
function_name="never_advertised",
|
||||
tool_call_id="call_1",
|
||||
arguments={},
|
||||
context=context,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
warnings = [c.args[0] for c in mock_logger.warning.call_args_list]
|
||||
errors = [c.args[0] for c in mock_logger.error.call_args_list]
|
||||
self.assertTrue(
|
||||
any(
|
||||
"not in the currently advertised tool set" in w and "never_advertised" in w
|
||||
for w in warnings
|
||||
),
|
||||
f"expected hallucination warning; got warnings={warnings}, errors={errors}",
|
||||
)
|
||||
self.assertFalse(any("advertised" in e and "register_function" in e for e in errors))
|
||||
|
||||
async def test_catch_all_handler_suppresses_missing_warnings(self):
|
||||
"""register_function(None, ...) suppresses both dev-error and hallucination logs."""
|
||||
service = MockLLMService()
|
||||
service._call_event_handler = AsyncMock()
|
||||
await self._run_function_calls_inline(service)
|
||||
service.broadcast_frame = AsyncMock()
|
||||
|
||||
async def catch_all(params):
|
||||
await params.result_callback("handled")
|
||||
|
||||
service.register_function(None, catch_all)
|
||||
|
||||
with patch("pipecat.services.llm_service.logger") as mock_logger:
|
||||
await service.run_function_calls(
|
||||
[
|
||||
FunctionCallFromLLM(
|
||||
function_name="anything",
|
||||
tool_call_id="call_1",
|
||||
arguments={},
|
||||
context=LLMContext(),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
errors = [c.args[0] for c in mock_logger.error.call_args_list]
|
||||
warnings = [c.args[0] for c in mock_logger.warning.call_args_list]
|
||||
self.assertFalse(any("register_function" in e for e in errors))
|
||||
self.assertFalse(any("not in the currently advertised tool set" in w for w in warnings))
|
||||
|
||||
async def test_missing_function_call_allows_user_mute_cleanup(self):
|
||||
service = MockLLMService()
|
||||
service._call_event_handler = AsyncMock()
|
||||
|
||||
Reference in New Issue
Block a user