fix: fail missing tool calls cleanly
This commit is contained in:
@@ -725,7 +725,11 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
logger.warning(
|
||||
f"{self} is calling '{function_call.function_name}', but it's not registered."
|
||||
)
|
||||
continue
|
||||
item = FunctionCallRegistryItem(
|
||||
function_name=function_call.function_name,
|
||||
handler=self._missing_function_call_handler,
|
||||
cancel_on_interruption=True,
|
||||
)
|
||||
|
||||
runner_items.append(
|
||||
FunctionCallRunnerItem(
|
||||
@@ -782,12 +786,7 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
await self._sequential_runner_queue.put(runner_item)
|
||||
|
||||
async def _run_function_call(self, runner_item: FunctionCallRunnerItem):
|
||||
if runner_item.function_name in self._functions.keys():
|
||||
item = self._functions[runner_item.function_name]
|
||||
elif None in self._functions.keys():
|
||||
item = self._functions[None]
|
||||
else:
|
||||
return
|
||||
item = runner_item.registry_item
|
||||
|
||||
logger.debug(
|
||||
f"{self} Calling function [{runner_item.function_name}:{runner_item.tool_call_id}] with arguments {runner_item.arguments}"
|
||||
@@ -894,6 +893,12 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
if timeout_task and not timeout_task.done():
|
||||
await self.cancel_task(timeout_task)
|
||||
|
||||
async def _missing_function_call_handler(self, params: FunctionCallParams):
|
||||
"""Return a terminal tool result when the LLM calls an unknown function."""
|
||||
await params.result_callback(
|
||||
f"Error: function '{params.function_name}' is not registered."
|
||||
)
|
||||
|
||||
def _has_async_tools(self) -> bool:
|
||||
"""Return True if at least one non-builtin async tool is registered."""
|
||||
return any(
|
||||
|
||||
116
tests/test_llm_service.py
Normal file
116
tests/test_llm_service.py
Normal file
@@ -0,0 +1,116 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import unittest
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
FunctionCallFromLLM,
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
FunctionCallsStartedFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.services.llm_service import LLMService
|
||||
from pipecat.services.settings import LLMSettings
|
||||
from pipecat.turns.user_mute.function_call_user_mute_strategy import FunctionCallUserMuteStrategy
|
||||
|
||||
|
||||
class MockLLMService(LLMService):
|
||||
"""Minimal LLM service for testing function call execution."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
settings = LLMSettings(
|
||||
model="test-model",
|
||||
system_instruction=None,
|
||||
temperature=None,
|
||||
max_tokens=None,
|
||||
top_p=None,
|
||||
top_k=None,
|
||||
frequency_penalty=None,
|
||||
presence_penalty=None,
|
||||
seed=None,
|
||||
filter_incomplete_user_turns=None,
|
||||
user_turn_completion_config=None,
|
||||
)
|
||||
super().__init__(settings=settings, **kwargs)
|
||||
|
||||
|
||||
class TestLLMService(unittest.IsolatedAsyncioTestCase):
|
||||
async def _run_function_calls_inline(self, service: MockLLMService):
|
||||
async def run_inline(runner_items):
|
||||
for runner_item in runner_items:
|
||||
await service._run_function_call(runner_item)
|
||||
|
||||
service._run_parallel_function_calls = run_inline
|
||||
service._run_sequential_function_calls = run_inline
|
||||
|
||||
async def test_missing_function_call_emits_terminal_result(self):
|
||||
service = MockLLMService()
|
||||
service._call_event_handler = AsyncMock()
|
||||
await self._run_function_calls_inline(service)
|
||||
|
||||
recorded_frames = []
|
||||
|
||||
async def mock_broadcast_frame(frame_cls, **kwargs):
|
||||
recorded_frames.append(frame_cls(**kwargs))
|
||||
|
||||
service.broadcast_frame = mock_broadcast_frame
|
||||
|
||||
await service.run_function_calls(
|
||||
[
|
||||
FunctionCallFromLLM(
|
||||
function_name="missing_tool",
|
||||
tool_call_id="call_1",
|
||||
arguments={"query": "weather"},
|
||||
context=LLMContext(),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
[type(frame) for frame in recorded_frames],
|
||||
[
|
||||
FunctionCallsStartedFrame,
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
],
|
||||
)
|
||||
self.assertEqual(recorded_frames[1].function_name, "missing_tool")
|
||||
self.assertEqual(
|
||||
recorded_frames[2].result,
|
||||
"Error: function 'missing_tool' is not registered.",
|
||||
)
|
||||
|
||||
async def test_missing_function_call_allows_user_mute_cleanup(self):
|
||||
service = MockLLMService()
|
||||
service._call_event_handler = AsyncMock()
|
||||
await self._run_function_calls_inline(service)
|
||||
|
||||
recorded_frames = []
|
||||
|
||||
async def mock_broadcast_frame(frame_cls, **kwargs):
|
||||
recorded_frames.append(frame_cls(**kwargs))
|
||||
|
||||
service.broadcast_frame = mock_broadcast_frame
|
||||
|
||||
await service.run_function_calls(
|
||||
[
|
||||
FunctionCallFromLLM(
|
||||
function_name="missing_tool",
|
||||
tool_call_id="call_1",
|
||||
arguments={},
|
||||
context=LLMContext(),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
strategy = FunctionCallUserMuteStrategy()
|
||||
muted = False
|
||||
for frame in recorded_frames:
|
||||
muted = await strategy.process_frame(frame)
|
||||
|
||||
self.assertFalse(muted)
|
||||
Reference in New Issue
Block a user