From e06e0c02823d9d09a5ca8aa869549b84bbdaff46 Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Tue, 5 May 2026 13:02:43 -0400 Subject: [PATCH] Mitigate tool-call-related hallucination MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When tools change mid-conversation, LLMs can produce a few different flavors of tool-call-related hallucination: calling tools that have been removed, avoiding tools that have been re-added, or hallucinating output (made-up answers or tool-call-shaped non-tool-calls) when tools are unavailable. This change introduces an opt-in ``add_tool_change_messages`` flag on the LLM aggregators (preferred entry point: ``LLMContextAggregatorPair(..., add_tool_change_messages=True)``) that appends a developer-role message to the context whenever ``LLMSetToolsFrame`` changes the set of advertised standard tools. The message helps the LLM stay coherent across tool changes by spelling out exactly which tools just became available or unavailable. Both aggregators participate; whichever handles the frame first wins, and the other (if any) sees an empty diff against the shared context and stays silent. The behavior is therefore order-independent, regardless of whether the frame flows downstream or upstream. This change also tightens the existing missing-handler path (introduced in #4301): - Reworded the terminal tool result to a neutral "The function `X` is not currently available." (overridable via ``LLMService.MISSING_FUNCTION_CALL_MESSAGE_TEMPLATE``). It previously read "Error: function 'X' is not registered." - Logs at the call site now distinguish developer error (tool advertised but no handler registered → ``logger.error``) from hallucination (tool not advertised → ``logger.warning``). The change also includes a manual validation harness (``examples/features/features-add-tool-change-messages.py``) that exercises the new ``add_tool_change_messages`` mitigation by flipping tool availability on a turn counter so its effect can be observed end-to-end with the flag on vs. off. 
--- .../features-add-tool-change-messages.py | 232 ++++++++++++++++++ .../aggregators/llm_response_universal.py | 132 +++++++++- src/pipecat/services/llm_service.py | 60 ++++- tests/test_context_aggregators_universal.py | 204 +++++++++++++++ tests/test_llm_service.py | 132 +++++++++- 5 files changed, 745 insertions(+), 15 deletions(-) create mode 100644 examples/features/features-add-tool-change-messages.py diff --git a/examples/features/features-add-tool-change-messages.py b/examples/features/features-add-tool-change-messages.py new file mode 100644 index 000000000..4ad69a819 --- /dev/null +++ b/examples/features/features-add-tool-change-messages.py @@ -0,0 +1,232 @@ +# +# Copyright (c) 2024-2026, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +"""Manual validation harness for the ``add_tool_change_messages`` feature. + +When tools change mid-conversation, LLMs can produce a few different +flavors of tool-call-related hallucination: + +- **Forward hallucination** — calling a tool that has been removed. +- **Negative hallucination** — refusing to call a tool that has been + re-added (because recent context is full of "I can't" responses). +- **Hallucinated output when tools are unavailable** — making up an + answer rather than declining gracefully, or producing JSON that + *looks* like a tool call but is actually just an assistant text + response. + +The ``add_tool_change_messages`` feature mitigates these by appending a +developer-role message to the conversation whenever ``LLMSetToolsFrame`` +changes the set of advertised tools, so the LLM stays in sync with what's +actually available. + +This harness exercises all of those flavors by flipping the advertised +tool set on a turn counter: + + Phase 0 (turns 1–4): weather tool ACTIVE — confirm baseline. + Phase 1 (turns 5–8): tool REMOVED — keep asking for weather. + Phase 2 (turn 9+): tool RE-ADDED — does the LLM call it again? 
+ +Set ``ADD_TOOL_CHANGE_MESSAGES=0`` to disable the mitigation and see the +unmitigated behavior. The default is ON so a fresh run shows the feature +working. + +Defaults to Llama 3.1 8B Instruct via a locally-running Ollama — +anecdotally one of the more hallucination-prone of the easily accessible +models. Pull the model once with ``ollama pull llama3.1:8b`` and make +sure ``ollama serve`` is running. Swap the LLM service to validate other +providers. + +Run with:: + + uv run examples/features/features-add-tool-change-messages.py + ADD_TOOL_CHANGE_MESSAGES=0 uv run examples/features/features-add-tool-change-messages.py +""" + +import os + +from dotenv import load_dotenv +from loguru import logger + +from pipecat.adapters.schemas.function_schema import FunctionSchema +from pipecat.adapters.schemas.tools_schema import ToolsSchema +from pipecat.audio.vad.silero import SileroVADAnalyzer +from pipecat.frames.frames import LLMRunFrame, LLMSetToolsFrame +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineParams, PipelineTask +from pipecat.processors.aggregators.llm_context import NOT_GIVEN, LLMContext +from pipecat.processors.aggregators.llm_response_universal import ( + LLMContextAggregatorPair, + LLMUserAggregatorParams, +) +from pipecat.runner.types import RunnerArguments +from pipecat.runner.utils import create_transport +from pipecat.services.cartesia.tts import CartesiaTTSService +from pipecat.services.deepgram.stt import DeepgramSTTService +from pipecat.services.llm_service import FunctionCallParams +from pipecat.services.ollama.llm import OLLamaLLMService +from pipecat.transports.base_transport import BaseTransport, TransportParams +from pipecat.transports.daily.transport import DailyParams +from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams + +load_dotenv(override=True) + +# Default ON so a fresh run shows the feature working. 
Set to "0" to A/B +# against the unmitigated behavior. +ADD_TOOL_CHANGE_MESSAGES = os.environ.get("ADD_TOOL_CHANGE_MESSAGES", "1") == "1" + + +async def fetch_weather_from_api(params: FunctionCallParams): + await params.result_callback({"conditions": "nice", "temperature": "75"}) + + +weather_function = FunctionSchema( + name="get_current_weather", + description="Get the current weather", + properties={ + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use. Infer this from the user's location.", + }, + }, + required=["location", "format"], +) +weather_tools = ToolsSchema(standard_tools=[weather_function]) + + +transport_params = { + "daily": lambda: DailyParams(audio_in_enabled=True, audio_out_enabled=True), + "twilio": lambda: FastAPIWebsocketParams(audio_in_enabled=True, audio_out_enabled=True), + "webrtc": lambda: TransportParams(audio_in_enabled=True, audio_out_enabled=True), +} + + +async def run_bot(transport: BaseTransport, runner_args: RunnerArguments): + logger.info( + f"Starting add_tool_change_messages demo bot " + f"(ADD_TOOL_CHANGE_MESSAGES={ADD_TOOL_CHANGE_MESSAGES})" + ) + + stt = DeepgramSTTService(api_key=os.environ["DEEPGRAM_API_KEY"]) + + tts = CartesiaTTSService( + api_key=os.environ["CARTESIA_API_KEY"], + settings=CartesiaTTSService.Settings( + voice="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady + ), + ) + + llm = OLLamaLLMService( + settings=OLLamaLLMService.Settings( + # Llama 3.1 8B Instruct is anecdotally one of the more + # hallucination-prone of the easily accessible models — exactly + # what we want for this validation harness. Pull it with + # ``ollama pull llama3.1:8b`` and make sure ``ollama serve`` + # is running. + model="llama3.1:8b", + system_instruction=( + "You are a helpful assistant in a voice conversation. 
Your responses " + "will be spoken aloud, so avoid emojis, bullet points, or other " + "formatting that can't be spoken. Respond briefly and naturally. " + "If the user asks for the current weather, use the `get_current_weather` " + "function if it's available. IMPORTANT: if you do not have access to the function, " + "say something along the lines of 'Sorry, I can't check the weather right now.'." + ), + ), + ) + llm.register_function("get_current_weather", fetch_weather_from_api) + + context = LLMContext(tools=weather_tools) + user_aggregator, assistant_aggregator = LLMContextAggregatorPair( + context, + user_params=LLMUserAggregatorParams(vad_analyzer=SileroVADAnalyzer()), + add_tool_change_messages=ADD_TOOL_CHANGE_MESSAGES, + ) + + pipeline = Pipeline( + [ + transport.input(), + stt, + user_aggregator, + llm, + tts, + transport.output(), + assistant_aggregator, + ] + ) + + task = PipelineTask( + pipeline, + params=PipelineParams(enable_metrics=True, enable_usage_metrics=True), + idle_timeout_secs=runner_args.pipeline_idle_timeout_secs, + ) + + # Phase controller: roughly 4 turns per phase. + user_turn_count = 0 + REMOVE_AT_TURN = 5 # tool gone for turn N onward + READD_AT_TURN = 9 # tool back for turn N onward + + @user_aggregator.event_handler("on_user_turn_stopped") + async def on_user_turn_stopped(aggregator, strategy, message): + nonlocal user_turn_count + user_turn_count += 1 + logger.info(f"=== User turn {user_turn_count} complete ===") + + if user_turn_count == REMOVE_AT_TURN - 1: + logger.info( + "=== Phase 1: weather tool REMOVED. Keep asking about the weather " + "to exercise hallucination scenarios. ===" + ) + await task.queue_frame(LLMSetToolsFrame(tools=NOT_GIVEN)) + elif user_turn_count == READD_AT_TURN - 1: + logger.info( + "=== Phase 2: weather tool RE-ADDED. Ask for the weather again — " + "does the LLM call it, or keep refusing? (THIS IS THE TEST.) 
===" + ) + await task.queue_frame(LLMSetToolsFrame(tools=weather_tools)) + + @transport.event_handler("on_client_connected") + async def on_client_connected(transport, client): + logger.info("Client connected") + logger.info( + "=== Phase 0: weather tool ACTIVE. Ask for the weather a few times " + "to confirm it's working. ===" + ) + context.add_message( + { + "role": "developer", + "content": ( + "Please introduce yourself briefly to the user, then invite them " + "to ask about the weather." + ), + } + ) + await task.queue_frames([LLMRunFrame()]) + + @transport.event_handler("on_client_disconnected") + async def on_client_disconnected(transport, client): + logger.info("Client disconnected") + await task.cancel() + + runner = PipelineRunner(handle_sigint=runner_args.handle_sigint) + await runner.run(task) + + +async def bot(runner_args: RunnerArguments): + """Main bot entry point compatible with Pipecat Cloud.""" + transport = await create_transport(runner_args, transport_params) + await run_bot(transport, runner_args) + + +if __name__ == "__main__": + from pipecat.runner.run import main + + main() diff --git a/src/pipecat/processors/aggregators/llm_response_universal.py b/src/pipecat/processors/aggregators/llm_response_universal.py index bf910a0c4..bedb7ec92 100644 --- a/src/pipecat/processors/aggregators/llm_response_universal.py +++ b/src/pipecat/processors/aggregators/llm_response_universal.py @@ -72,6 +72,7 @@ from pipecat.processors.aggregators.llm_context import ( LLMContextMessage, LLMSpecificMessage, NotGiven, + is_given, ) from pipecat.processors.aggregators.llm_context_summarizer import ( LLMContextSummarizer, @@ -118,6 +119,18 @@ class LLMUserAggregatorParams: user_turn_completion_config: Configuration for turn completion behavior including custom instructions, timeouts, and prompts. Only used when filter_incomplete_user_turns is True. 
+ add_tool_change_messages: When True, on each ``LLMSetToolsFrame`` the + aggregator computes the diff against the currently advertised tools + and appends a developer-role message to the context describing + additions/removals. Helps the LLM stay coherent across + mid-conversation tool changes, mitigating several flavors of + tool-call-related hallucination: calling tools that have been + removed, avoiding tools that have been re-added, and hallucinating + output (made-up answers or tool-call-shaped non-tool-calls) when + tools are unavailable. Only standard tools are diffed; custom + (LLM-specific) tools are ignored. When using + ``LLMContextAggregatorPair``, prefer setting this via its + ``add_tool_change_messages`` argument instead. Defaults to False. """ user_turn_strategies: UserTurnStrategies | None = None @@ -128,6 +141,7 @@ class LLMUserAggregatorParams: audio_idle_timeout: float = 1.0 filter_incomplete_user_turns: bool = False user_turn_completion_config: UserTurnCompletionConfig | None = None + add_tool_change_messages: bool = False @dataclass @@ -143,10 +157,23 @@ class LLMAssistantAggregatorParams: summarization. Controls trigger thresholds, message preservation, and summarization prompts. If None, uses default ``LLMAutoContextSummarizationConfig`` values. + add_tool_change_messages: When True, on each ``LLMSetToolsFrame`` the + aggregator computes the diff against the currently advertised tools + and appends a developer-role message to the context describing + additions/removals. Helps the LLM stay coherent across + mid-conversation tool changes, mitigating several flavors of + tool-call-related hallucination: calling tools that have been + removed, avoiding tools that have been re-added, and hallucinating + output (made-up answers or tool-call-shaped non-tool-calls) when + tools are unavailable. Only standard tools are diffed; custom + (LLM-specific) tools are ignored. 
When using + ``LLMContextAggregatorPair``, prefer setting this via its + ``add_tool_change_messages`` argument instead. Defaults to False. """ enable_auto_context_summarization: bool = False auto_context_summarization_config: LLMAutoContextSummarizationConfig | None = None + add_tool_change_messages: bool = False # --------------------------------------------------------------------------- # Deprecated field names — kept for backward compatibility. @@ -248,20 +275,87 @@ class LLMContextAggregator(FrameProcessor): common functionality for context-based conversation management. """ - def __init__(self, *, context: LLMContext, role: str, **kwargs): + # Developer-role messages appended to the context when tools are added/ + # removed via ``LLMSetToolsFrame`` (only when ``add_tool_change_messages`` + # is enabled on the aggregator's params). ``{function_names}`` is + # substituted with a sorted, comma-separated, backtick-wrapped list. + TOOL_ACTIVATION_MESSAGE_TEMPLATE = ( + "The following function(s) have just been added and may now be called: " + "{function_names}. Any previously available functions remain available." + ) + TOOL_DEACTIVATION_MESSAGE_TEMPLATE = ( + "The following function(s) have just been removed and should not be called: " + "{function_names}. Any previously available functions remain available. " + "The removed function(s) may become available again later, in which case " + "you will be informed." + ) + + def __init__( + self, + *, + context: LLMContext, + role: str, + add_tool_change_messages: bool = False, + **kwargs, + ): """Initialize the context response aggregator. Args: context: The LLM context to use for conversation storage. role: The role this aggregator represents (e.g. "user", "assistant"). + add_tool_change_messages: See the field of the same name on the + aggregator-specific params dataclasses. Subclasses propagate + this from their ``params``. **kwargs: Additional arguments passed to parent class. 
""" super().__init__(**kwargs) self._context = context self._role = role + self._add_tool_change_messages = add_tool_change_messages self._aggregation: list[TextPartForConcatenation] = [] + def _maybe_add_tool_change_messages(self, new_tools: ToolsSchema | NotGiven) -> None: + """Append a developer message describing tool add/remove deltas. + + No-op unless ``add_tool_change_messages`` was enabled on the aggregator, + and no-op when the diff against the currently advertised tools is empty. + Custom (LLM-specific) tools are ignored — only standard tools are diffed. + + Both aggregators call this on every ``LLMSetToolsFrame`` they handle. + Whichever aggregator handles the frame first computes a real diff + against the shared context and adds the announcement; by the time + the other aggregator sees it (if at all), the context already + reflects the new tools, so its diff is empty and no duplicate + message is added. This is order-independent: it works whether the + frame flows downstream (user aggregator first) or upstream + (assistant aggregator first, and consumed without being forwarded). 
+ """ + if not self._add_tool_change_messages: + return + + def _names(tools: ToolsSchema | NotGiven) -> set[str]: + if not is_given(tools): + return set() + return {s.name for s in tools.standard_tools} + + old_names = _names(self._context.tools) + new_names = _names(new_tools) + added = new_names - old_names + removed = old_names - new_names + if not added and not removed: + return + + parts: list[str] = [] + if added: + names = ", ".join(f"`{n}`" for n in sorted(added)) + parts.append(self.TOOL_ACTIVATION_MESSAGE_TEMPLATE.format(function_names=names)) + if removed: + names = ", ".join(f"`{n}`" for n in sorted(removed)) + parts.append(self.TOOL_DEACTIVATION_MESSAGE_TEMPLATE.format(function_names=names)) + + self._context.add_message({"role": "developer", "content": " ".join(parts)}) + @property def messages(self) -> list[LLMContextMessage]: """Get messages from the LLM context. @@ -434,8 +528,14 @@ class LLMUserAggregator(LLMContextAggregator): params: Configuration parameters for aggregation behavior. **kwargs: Additional arguments. 
""" - super().__init__(context=context, role="user", **kwargs) - self._params = params or LLMUserAggregatorParams() + params = params or LLMUserAggregatorParams() + super().__init__( + context=context, + role="user", + add_tool_change_messages=params.add_tool_change_messages, + **kwargs, + ) + self._params = params self._register_event_handler("on_user_turn_started") self._register_event_handler("on_user_turn_stopped") @@ -536,6 +636,7 @@ class LLMUserAggregator(LLMContextAggregator): elif isinstance(frame, LLMMessagesTransformFrame): await self._handle_llm_messages_transform(frame) elif isinstance(frame, LLMSetToolsFrame): + self._maybe_add_tool_change_messages(frame.tools) self.set_tools(frame.tools) # Push the LLMSetToolsFrame as well, since speech-to-speech LLM # services (like OpenAI Realtime) may need to know about tool @@ -843,8 +944,14 @@ class LLMAssistantAggregator(LLMContextAggregator): params: Configuration parameters for aggregation behavior. **kwargs: Additional arguments. 
""" - super().__init__(context=context, role="assistant", **kwargs) - self._params = params or LLMAssistantAggregatorParams() + params = params or LLMAssistantAggregatorParams() + super().__init__( + context=context, + role="assistant", + add_tool_change_messages=params.add_tool_change_messages, + **kwargs, + ) + self._params = params self._function_calls_in_progress: dict[str, FunctionCallInProgressFrame | None] = {} self._function_calls_image_results: dict[str, UserImageRawFrame] = {} @@ -949,6 +1056,7 @@ class LLMAssistantAggregator(LLMContextAggregator): elif isinstance(frame, LLMMessagesTransformFrame): await self._handle_llm_messages_transform(frame) elif isinstance(frame, LLMSetToolsFrame): + self._maybe_add_tool_change_messages(frame.tools) self.set_tools(frame.tools) elif isinstance(frame, LLMSetToolChoiceFrame): self.set_tool_choice(frame.tool_choice) @@ -1478,6 +1586,7 @@ class LLMContextAggregatorPair: *, user_params: LLMUserAggregatorParams | None = None, assistant_params: LLMAssistantAggregatorParams | None = None, + add_tool_change_messages: bool | None = None, ): """Initialize the LLM context aggregator pair. @@ -1485,9 +1594,22 @@ class LLMContextAggregatorPair: context: The context to be managed by the aggregators. user_params: Parameters for the user context aggregator. assistant_params: Parameters for the assistant context aggregator. + add_tool_change_messages: When provided, sets the field of the + same name on both ``user_params`` and ``assistant_params``, + overriding any value already set on either. This is the + preferred way to enable tool-change announcements: it ensures + both aggregators participate, which makes the feature robust + regardless of which aggregator handles a given + ``LLMSetToolsFrame``. The shared context guarantees the + announcement is added exactly once (the second aggregator's + diff is empty by the time it sees the frame). Leave as + ``None`` to respect per-params settings. 
""" user_params = user_params or LLMUserAggregatorParams() assistant_params = assistant_params or LLMAssistantAggregatorParams() + if add_tool_change_messages is not None: + user_params.add_tool_change_messages = add_tool_change_messages + assistant_params.add_tool_change_messages = add_tool_change_messages self._user = LLMUserAggregator(context, params=user_params) self._assistant = LLMAssistantAggregator(context, params=assistant_params) diff --git a/src/pipecat/services/llm_service.py b/src/pipecat/services/llm_service.py index 95bf2c762..56eac7676 100644 --- a/src/pipecat/services/llm_service.py +++ b/src/pipecat/services/llm_service.py @@ -53,6 +53,7 @@ from pipecat.frames.frames import ( from pipecat.processors.aggregators.llm_context import ( LLMContext, LLMSpecificMessage, + is_given, ) from pipecat.processors.frame_processor import FrameDirection from pipecat.services.ai_service import AIService @@ -243,6 +244,15 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService, Generic[TAdapter] # However, subclasses should override this with a more specific adapter when necessary. adapter_class: type[BaseLLMAdapter] = OpenAILLMAdapter + # Returned to the LLM as the tool result when an unavailable function is + # called. Deliberately neutral about future availability so the LLM can + # pick the function up again if it returns (e.g. via the + # ``add_tool_change_messages`` activation message, or silently on a + # later inference). ``{function_name}`` is substituted at runtime. + MISSING_FUNCTION_CALL_MESSAGE_TEMPLATE = ( + "The function `{function_name}` is not currently available." + ) + def __init__( self, run_in_parallel: bool = True, @@ -764,9 +774,7 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService, Generic[TAdapter] elif None in self._functions.keys(): item = self._functions[None] else: - logger.warning( - f"{self} is calling '{function_call.function_name}', but it's not registered." 
- ) + self._log_missing_function_call(function_call.function_name, function_call.context) item = self._build_missing_function_call_registry_item(function_call.function_name) runner_items.append( @@ -835,8 +843,12 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService, Generic[TAdapter] elif runner_item.registry_item.handler == self._missing_function_call_handler: item = runner_item.registry_item else: + # Function was unregistered between queue and execution; the + # registry-item-handler check above already covered the + # missing-from-the-start case. logger.warning( - f"{self} is calling '{runner_item.function_name}', but it was just unregistered." + f"{self}: '{runner_item.function_name}' was just unregistered " + f"between queueing and execution." ) item = self._build_missing_function_call_registry_item(runner_item.function_name) @@ -962,7 +974,45 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService, Generic[TAdapter] async def _missing_function_call_handler(self, params: FunctionCallParams): """Return a terminal tool result when the LLM calls an unknown function.""" - await params.result_callback(f"Error: function '{params.function_name}' is not registered.") + await params.result_callback( + self.MISSING_FUNCTION_CALL_MESSAGE_TEMPLATE.format(function_name=params.function_name) + ) + + @staticmethod + def _advertised_tool_names(context) -> set[str]: + """Return the set of standard-tool names currently advertised to the LLM. + + Custom (LLM-specific) tools are not included, since they have no + consistent name field across adapters. + """ + tools = context.tools if context is not None else None + if tools is None or not is_given(tools): + return set() + return {t.name for t in tools.standard_tools} + + def _log_missing_function_call(self, function_name: str, context) -> None: + """Log an appropriate message when a tool is called with no handler. 
+ + Distinguishes two cases: + + - **Developer error:** the tool is advertised to the LLM but no handler + was registered (likely a missed ``register_function`` call). Logged + at error level since this almost always indicates a bug. + - **Hallucination:** the tool is not in the currently advertised tool + set. Logged at warning level since this is model behavior the + application can do little about beyond returning a terminal result. + """ + if function_name in self._advertised_tool_names(context): + logger.error( + f"{self}: tool '{function_name}' is advertised to the LLM " + f"but has no registered handler — did you forget to call " + f"register_function()?" + ) + else: + logger.warning( + f"{self}: LLM called '{function_name}', which is not in the " + f"currently advertised tool set." + ) def _has_async_tools(self) -> bool: """Return True if at least one non-builtin async tool is registered.""" diff --git a/tests/test_context_aggregators_universal.py b/tests/test_context_aggregators_universal.py index 4cd195bee..1b9579129 100644 --- a/tests/test_context_aggregators_universal.py +++ b/tests/test_context_aggregators_universal.py @@ -7,6 +7,8 @@ import json import unittest +from pipecat.adapters.schemas.function_schema import FunctionSchema +from pipecat.adapters.schemas.tools_schema import AdapterType, ToolsSchema from pipecat.frames.frames import ( BotStartedSpeakingFrame, BotStoppedSpeakingFrame, @@ -25,6 +27,7 @@ from pipecat.frames.frames import ( LLMMessagesTransformFrame, LLMMessagesUpdateFrame, LLMRunFrame, + LLMSetToolsFrame, LLMTextFrame, LLMThoughtEndFrame, LLMThoughtStartFrame, @@ -46,6 +49,8 @@ from pipecat.processors.aggregators.llm_response_universal import ( AssistantThoughtMessage, AssistantTurnStoppedMessage, LLMAssistantAggregator, + LLMAssistantAggregatorParams, + LLMContextAggregatorPair, LLMUserAggregator, LLMUserAggregatorParams, ) @@ -1167,5 +1172,204 @@ class TestLLMAssistantAggregator(unittest.IsolatedAsyncioTestCase): assert 
context.messages[0]["content"] == "HELLO" +def _function_schema(name: str) -> FunctionSchema: + return FunctionSchema(name=name, description="", properties={}, required=[]) + + +def _tools(*names: str) -> ToolsSchema: + return ToolsSchema(standard_tools=[_function_schema(n) for n in names]) + + +def _developer_messages(context: LLMContext) -> list[str]: + return [ + m["content"] + for m in context.messages + if isinstance(m, dict) and m.get("role") == "developer" + ] + + +class TestToolChangeMessages(unittest.IsolatedAsyncioTestCase): + """Coverage for the opt-in ``add_tool_change_messages`` feature. + + The feature appends a developer-role message to the context whenever + ``LLMSetToolsFrame`` changes the set of advertised standard tools. + """ + + async def _send_set_tools_to_user_aggregator(self, aggregator, tools): + # User aggregator forwards LLMSetToolsFrame downstream, so we expect + # the SpeechControlParamsFrame (emitted on StartFrame) and the + # forwarded LLMSetToolsFrame. + await run_test( + aggregator, + frames_to_send=[LLMSetToolsFrame(tools=tools)], + expected_down_frames=[SpeechControlParamsFrame, LLMSetToolsFrame], + ) + + async def test_default_off_adds_no_message(self): + context = LLMContext(tools=_tools("a")) + aggregator = LLMUserAggregator(context) + await self._send_set_tools_to_user_aggregator(aggregator, _tools("a", "b")) + self.assertEqual(_developer_messages(context), []) + + async def test_user_aggregator_announces_additions(self): + context = LLMContext(tools=_tools("a")) + aggregator = LLMUserAggregator( + context, params=LLMUserAggregatorParams(add_tool_change_messages=True) + ) + await self._send_set_tools_to_user_aggregator(aggregator, _tools("a", "b", "c")) + msgs = _developer_messages(context) + self.assertEqual(len(msgs), 1) + self.assertIn("just been added", msgs[0]) + self.assertIn("`b`", msgs[0]) + self.assertIn("`c`", msgs[0]) + self.assertNotIn("removed", msgs[0]) + # Sorted, stable order + 
self.assertLess(msgs[0].index("`b`"), msgs[0].index("`c`")) + + async def test_user_aggregator_announces_removals(self): + context = LLMContext(tools=_tools("a", "b", "c")) + aggregator = LLMUserAggregator( + context, params=LLMUserAggregatorParams(add_tool_change_messages=True) + ) + await self._send_set_tools_to_user_aggregator(aggregator, _tools("a")) + msgs = _developer_messages(context) + self.assertEqual(len(msgs), 1) + self.assertIn("just been removed", msgs[0]) + self.assertIn("`b`", msgs[0]) + self.assertIn("`c`", msgs[0]) + self.assertNotIn("just been added", msgs[0]) + + async def test_user_aggregator_combined_add_and_remove(self): + context = LLMContext(tools=_tools("a", "b")) + aggregator = LLMUserAggregator( + context, params=LLMUserAggregatorParams(add_tool_change_messages=True) + ) + await self._send_set_tools_to_user_aggregator(aggregator, _tools("b", "c")) + msgs = _developer_messages(context) + self.assertEqual(len(msgs), 1) + self.assertIn("just been added", msgs[0]) + self.assertIn("`c`", msgs[0]) + self.assertIn("just been removed", msgs[0]) + self.assertIn("`a`", msgs[0]) + # Activation phrase appears before deactivation phrase. 
+ self.assertLess(msgs[0].index("just been added"), msgs[0].index("just been removed")) + + async def test_no_message_when_diff_is_empty(self): + context = LLMContext(tools=_tools("a", "b")) + aggregator = LLMUserAggregator( + context, params=LLMUserAggregatorParams(add_tool_change_messages=True) + ) + await self._send_set_tools_to_user_aggregator(aggregator, _tools("a", "b")) + self.assertEqual(_developer_messages(context), []) + + async def test_set_tools_to_not_given_lists_all_as_removed(self): + from pipecat.processors.aggregators.llm_context import NOT_GIVEN + + context = LLMContext(tools=_tools("a", "b")) + aggregator = LLMUserAggregator( + context, params=LLMUserAggregatorParams(add_tool_change_messages=True) + ) + await self._send_set_tools_to_user_aggregator(aggregator, NOT_GIVEN) + msgs = _developer_messages(context) + self.assertEqual(len(msgs), 1) + self.assertIn("just been removed", msgs[0]) + self.assertIn("`a`", msgs[0]) + self.assertIn("`b`", msgs[0]) + + async def test_set_tools_from_not_given_lists_all_as_added(self): + context = LLMContext() # tools default to NOT_GIVEN + aggregator = LLMUserAggregator( + context, params=LLMUserAggregatorParams(add_tool_change_messages=True) + ) + await self._send_set_tools_to_user_aggregator(aggregator, _tools("x", "y")) + msgs = _developer_messages(context) + self.assertEqual(len(msgs), 1) + self.assertIn("just been added", msgs[0]) + self.assertIn("`x`", msgs[0]) + self.assertIn("`y`", msgs[0]) + + async def test_custom_tools_only_change_no_message(self): + # Standard tools identical; only custom tools differ → no announcement. 
+ context = LLMContext( + tools=ToolsSchema( + standard_tools=[_function_schema("a")], + custom_tools={AdapterType.OPENAI: [{"type": "web_search"}]}, + ) + ) + aggregator = LLMUserAggregator( + context, params=LLMUserAggregatorParams(add_tool_change_messages=True) + ) + new_tools = ToolsSchema( + standard_tools=[_function_schema("a")], + custom_tools={AdapterType.OPENAI: [{"type": "file_search"}]}, + ) + await self._send_set_tools_to_user_aggregator(aggregator, new_tools) + self.assertEqual(_developer_messages(context), []) + + async def test_pipeline_with_both_aggregators_announces_once(self): + """User agg runs first; assistant agg sees no diff and stays silent.""" + context = LLMContext(tools=_tools("a")) + user, assistant = LLMContextAggregatorPair(context, add_tool_change_messages=True) + pipeline = Pipeline([user, assistant]) + # The user aggregator forwards LLMSetToolsFrame downstream; the + # assistant aggregator consumes it (does not forward). + await run_test( + pipeline, + frames_to_send=[LLMSetToolsFrame(tools=_tools("a", "b"))], + expected_down_frames=[SpeechControlParamsFrame], + ) + msgs = _developer_messages(context) + self.assertEqual(len(msgs), 1, f"expected exactly one announcement, got {msgs}") + self.assertIn("`b`", msgs[0]) + + async def test_assistant_aggregator_announces_when_handled_first(self): + """Order-independence: an upstream LLMSetToolsFrame hits the assistant + aggregator first (before being consumed). It should announce, and the + user aggregator (which never sees it) shouldn't matter for correctness. + """ + context = LLMContext(tools=_tools("a")) + assistant = LLMAssistantAggregator( + context, + params=LLMAssistantAggregatorParams(add_tool_change_messages=True), + ) + # Send the frame upstream so the assistant aggregator processes it. 
+ await run_test( + assistant, + frames_to_send=[LLMSetToolsFrame(tools=_tools("a", "b"))], + frames_to_send_direction=FrameDirection.UPSTREAM, + expected_up_frames=[], + ) + msgs = _developer_messages(context) + self.assertEqual(len(msgs), 1) + self.assertIn("`b`", msgs[0]) + + async def test_pair_propagates_flag_to_both(self): + context = LLMContext() + pair = LLMContextAggregatorPair(context, add_tool_change_messages=True) + self.assertTrue(pair.user()._add_tool_change_messages) + self.assertTrue(pair.assistant()._add_tool_change_messages) + + async def test_pair_arg_overrides_per_params_settings(self): + context = LLMContext() + pair = LLMContextAggregatorPair( + context, + user_params=LLMUserAggregatorParams(add_tool_change_messages=False), + assistant_params=LLMAssistantAggregatorParams(add_tool_change_messages=False), + add_tool_change_messages=True, + ) + self.assertTrue(pair.user()._add_tool_change_messages) + self.assertTrue(pair.assistant()._add_tool_change_messages) + + async def test_pair_default_respects_per_params(self): + context = LLMContext() + pair = LLMContextAggregatorPair( + context, + user_params=LLMUserAggregatorParams(add_tool_change_messages=True), + assistant_params=LLMAssistantAggregatorParams(add_tool_change_messages=False), + ) + self.assertTrue(pair.user()._add_tool_change_messages) + self.assertFalse(pair.assistant()._add_tool_change_messages) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_llm_service.py b/tests/test_llm_service.py index 707c255f3..9476f42a1 100644 --- a/tests/test_llm_service.py +++ b/tests/test_llm_service.py @@ -8,6 +8,8 @@ import unittest from unittest.mock import AsyncMock, patch from pipecat.adapters.base_llm_adapter import BaseLLMAdapter +from pipecat.adapters.schemas.function_schema import FunctionSchema +from pipecat.adapters.schemas.tools_schema import ToolsSchema from pipecat.adapters.services.open_ai_adapter import OpenAILLMAdapter from pipecat.frames.frames import ( 
FunctionCallFromLLM, @@ -21,6 +23,10 @@ from pipecat.services.settings import LLMSettings from pipecat.turns.user_mute.function_call_user_mute_strategy import FunctionCallUserMuteStrategy +def _expected_missing_tool_message(name: str) -> str: + return LLMService.MISSING_FUNCTION_CALL_MESSAGE_TEMPLATE.format(function_name=name) + + class MockLLMService(LLMService): """Minimal LLM service for testing function call execution.""" @@ -104,13 +110,14 @@ class TestLLMService(unittest.IsolatedAsyncioTestCase): self.assertEqual(recorded_frames[1].function_name, "missing_tool") self.assertEqual( recorded_frames[2].result, - "Error: function 'missing_tool' is not registered.", + _expected_missing_tool_message("missing_tool"), ) - # Only the queue-time warning should fire; the execution-time - # "just unregistered" warning must not double-log. + # The tool was not advertised, so this is treated as a hallucination + # (warning at queue time). The execution-time "just unregistered" + # warning must not double-log. 
warnings = [c.args[0] for c in mock_logger.warning.call_args_list] - self.assertTrue(any("not registered" in w for w in warnings)) + self.assertTrue(any("not in the currently advertised tool set" in w for w in warnings)) self.assertFalse(any("just unregistered" in w for w in warnings)) async def test_function_unregistered_between_queue_and_execute(self): @@ -160,9 +167,124 @@ class TestLLMService(unittest.IsolatedAsyncioTestCase): ) self.assertEqual( recorded_frames[2].result, - "Error: function 'doomed_tool' is not registered.", + _expected_missing_tool_message("doomed_tool"), ) + async def test_missing_function_call_dev_error_logged_as_error(self): + """Tool advertised to the LLM but missing a handler → logger.error.""" + service = MockLLMService() + service._call_event_handler = AsyncMock() + await self._run_function_calls_inline(service) + service.broadcast_frame = AsyncMock() + + context = LLMContext( + tools=ToolsSchema( + standard_tools=[ + FunctionSchema( + name="advertised_but_unhandled", + description="", + properties={}, + required=[], + ) + ] + ) + ) + + with patch("pipecat.services.llm_service.logger") as mock_logger: + await service.run_function_calls( + [ + FunctionCallFromLLM( + function_name="advertised_but_unhandled", + tool_call_id="call_1", + arguments={}, + context=context, + ) + ] + ) + + errors = [c.args[0] for c in mock_logger.error.call_args_list] + warnings = [c.args[0] for c in mock_logger.warning.call_args_list] + self.assertTrue( + any( + "advertised" in e and "register_function" in e and "advertised_but_unhandled" in e + for e in errors + ), + f"expected dev-error log; got errors={errors}, warnings={warnings}", + ) + self.assertFalse(any("not in the currently advertised tool set" in w for w in warnings)) + + async def test_missing_function_call_hallucination_logged_as_warning(self): + """Tool not advertised to the LLM → logger.warning (hallucination).""" + service = MockLLMService() + service._call_event_handler = AsyncMock() + await 
self._run_function_calls_inline(service) + service.broadcast_frame = AsyncMock() + + context = LLMContext( + tools=ToolsSchema( + standard_tools=[ + FunctionSchema( + name="something_else", + description="", + properties={}, + required=[], + ) + ] + ) + ) + + with patch("pipecat.services.llm_service.logger") as mock_logger: + await service.run_function_calls( + [ + FunctionCallFromLLM( + function_name="never_advertised", + tool_call_id="call_1", + arguments={}, + context=context, + ) + ] + ) + + warnings = [c.args[0] for c in mock_logger.warning.call_args_list] + errors = [c.args[0] for c in mock_logger.error.call_args_list] + self.assertTrue( + any( + "not in the currently advertised tool set" in w and "never_advertised" in w + for w in warnings + ), + f"expected hallucination warning; got warnings={warnings}, errors={errors}", + ) + self.assertFalse(any("advertised" in e and "register_function" in e for e in errors)) + + async def test_catch_all_handler_suppresses_missing_warnings(self): + """register_function(None, ...) 
suppresses both dev-error and hallucination logs.""" + service = MockLLMService() + service._call_event_handler = AsyncMock() + await self._run_function_calls_inline(service) + service.broadcast_frame = AsyncMock() + + async def catch_all(params): + await params.result_callback("handled") + + service.register_function(None, catch_all) + + with patch("pipecat.services.llm_service.logger") as mock_logger: + await service.run_function_calls( + [ + FunctionCallFromLLM( + function_name="anything", + tool_call_id="call_1", + arguments={}, + context=LLMContext(), + ) + ] + ) + + errors = [c.args[0] for c in mock_logger.error.call_args_list] + warnings = [c.args[0] for c in mock_logger.warning.call_args_list] + self.assertFalse(any("register_function" in e for e in errors)) + self.assertFalse(any("not in the currently advertised tool set" in w for w in warnings)) + async def test_missing_function_call_allows_user_mute_cleanup(self): service = MockLLMService() service._call_event_handler = AsyncMock()