fix: re-resolve registry item at execution time

Address review feedback: a function may be unregistered between when
run_function_calls queues it and when _run_function_call executes it.
Restore the live lookup, falling back to the missing-function handler
when the entry is gone, so the call still terminates with a normal
tool result. Factor the missing-handler item construction into a
helper since it's now built in two places.
This commit is contained in:
borislav
2026-04-27 17:22:30 +02:00
parent 14cf783647
commit 822392b0d4
2 changed files with 77 additions and 5 deletions

View File

@@ -725,10 +725,8 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
logger.warning(
f"{self} is calling '{function_call.function_name}', but it's not registered."
)
item = FunctionCallRegistryItem(
function_name=function_call.function_name,
handler=self._missing_function_call_handler,
cancel_on_interruption=True,
item = self._build_missing_function_call_registry_item(
function_call.function_name
)
runner_items.append(
@@ -786,7 +784,21 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
await self._sequential_runner_queue.put(runner_item)
async def _run_function_call(self, runner_item: FunctionCallRunnerItem):
item = runner_item.registry_item
# Re-resolve the registry item at execution time. The function may have
# been unregistered between queuing and execution, in which case we
# fall back to the missing-function handler so the call still terminates
# with a normal tool result.
if runner_item.function_name in self._functions.keys():
item = self._functions[runner_item.function_name]
elif None in self._functions.keys():
item = self._functions[None]
elif runner_item.registry_item.handler is self._missing_function_call_handler:
item = runner_item.registry_item
else:
logger.warning(
f"{self} is calling '{runner_item.function_name}', but it was just unregistered."
)
item = self._build_missing_function_call_registry_item(runner_item.function_name)
logger.debug(
f"{self} Calling function [{runner_item.function_name}:{runner_item.tool_call_id}] with arguments {runner_item.arguments}"
@@ -893,6 +905,16 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
if timeout_task and not timeout_task.done():
await self.cancel_task(timeout_task)
def _build_missing_function_call_registry_item(
self, function_name: str
) -> FunctionCallRegistryItem:
"""Build a registry item that routes to the missing-function handler."""
return FunctionCallRegistryItem(
function_name=function_name,
handler=self._missing_function_call_handler,
cancel_on_interruption=True,
)
async def _missing_function_call_handler(self, params: FunctionCallParams):
"""Return a terminal tool result when the LLM calls an unknown function."""
await params.result_callback(

View File

@@ -85,6 +85,56 @@ class TestLLMService(unittest.IsolatedAsyncioTestCase):
"Error: function 'missing_tool' is not registered.",
)
async def test_function_unregistered_between_queue_and_execute(self):
"""Function unregistered between queuing and execution still terminates."""
service = MockLLMService()
service._call_event_handler = AsyncMock()
async def real_handler(params):
await params.result_callback("should not be called")
service.register_function("doomed_tool", real_handler)
recorded_frames = []
async def mock_broadcast_frame(frame_cls, **kwargs):
recorded_frames.append(frame_cls(**kwargs))
service.broadcast_frame = mock_broadcast_frame
async def run_inline(runner_items):
# Simulate the function being unregistered after queuing but before execution.
service.unregister_function("doomed_tool")
for runner_item in runner_items:
await service._run_function_call(runner_item)
service._run_parallel_function_calls = run_inline
service._run_sequential_function_calls = run_inline
await service.run_function_calls(
[
FunctionCallFromLLM(
function_name="doomed_tool",
tool_call_id="call_1",
arguments={},
context=LLMContext(),
)
]
)
self.assertEqual(
[type(frame) for frame in recorded_frames],
[
FunctionCallsStartedFrame,
FunctionCallInProgressFrame,
FunctionCallResultFrame,
],
)
self.assertEqual(
recorded_frames[2].result,
"Error: function 'doomed_tool' is not registered.",
)
async def test_missing_function_call_allows_user_mute_cleanup(self):
service = MockLLMService()
service._call_event_handler = AsyncMock()