fix: re-resolve registry item at execution time
Address review feedback: a function may be unregistered between when run_function_calls queues it and when _run_function_call executes it. Restore the live lookup, falling back to the missing-function handler when the entry is gone, so the call still terminates with a normal tool result. Factor the missing-handler item construction into a helper since it's now built in two places.
This commit is contained in:
@@ -725,10 +725,8 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
logger.warning(
|
||||
f"{self} is calling '{function_call.function_name}', but it's not registered."
|
||||
)
|
||||
item = FunctionCallRegistryItem(
|
||||
function_name=function_call.function_name,
|
||||
handler=self._missing_function_call_handler,
|
||||
cancel_on_interruption=True,
|
||||
item = self._build_missing_function_call_registry_item(
|
||||
function_call.function_name
|
||||
)
|
||||
|
||||
runner_items.append(
|
||||
@@ -786,7 +784,21 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
await self._sequential_runner_queue.put(runner_item)
|
||||
|
||||
async def _run_function_call(self, runner_item: FunctionCallRunnerItem):
|
||||
item = runner_item.registry_item
|
||||
# Re-resolve the registry item at execution time. The function may have
|
||||
# been unregistered between queuing and execution, in which case we
|
||||
# fall back to the missing-function handler so the call still terminates
|
||||
# with a normal tool result.
|
||||
if runner_item.function_name in self._functions.keys():
|
||||
item = self._functions[runner_item.function_name]
|
||||
elif None in self._functions.keys():
|
||||
item = self._functions[None]
|
||||
elif runner_item.registry_item.handler is self._missing_function_call_handler:
|
||||
item = runner_item.registry_item
|
||||
else:
|
||||
logger.warning(
|
||||
f"{self} is calling '{runner_item.function_name}', but it was just unregistered."
|
||||
)
|
||||
item = self._build_missing_function_call_registry_item(runner_item.function_name)
|
||||
|
||||
logger.debug(
|
||||
f"{self} Calling function [{runner_item.function_name}:{runner_item.tool_call_id}] with arguments {runner_item.arguments}"
|
||||
@@ -893,6 +905,16 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
if timeout_task and not timeout_task.done():
|
||||
await self.cancel_task(timeout_task)
|
||||
|
||||
def _build_missing_function_call_registry_item(
|
||||
self, function_name: str
|
||||
) -> FunctionCallRegistryItem:
|
||||
"""Build a registry item that routes to the missing-function handler."""
|
||||
return FunctionCallRegistryItem(
|
||||
function_name=function_name,
|
||||
handler=self._missing_function_call_handler,
|
||||
cancel_on_interruption=True,
|
||||
)
|
||||
|
||||
async def _missing_function_call_handler(self, params: FunctionCallParams):
|
||||
"""Return a terminal tool result when the LLM calls an unknown function."""
|
||||
await params.result_callback(
|
||||
|
||||
@@ -85,6 +85,56 @@ class TestLLMService(unittest.IsolatedAsyncioTestCase):
|
||||
"Error: function 'missing_tool' is not registered.",
|
||||
)
|
||||
|
||||
async def test_function_unregistered_between_queue_and_execute(self):
|
||||
"""Function unregistered between queuing and execution still terminates."""
|
||||
service = MockLLMService()
|
||||
service._call_event_handler = AsyncMock()
|
||||
|
||||
async def real_handler(params):
|
||||
await params.result_callback("should not be called")
|
||||
|
||||
service.register_function("doomed_tool", real_handler)
|
||||
|
||||
recorded_frames = []
|
||||
|
||||
async def mock_broadcast_frame(frame_cls, **kwargs):
|
||||
recorded_frames.append(frame_cls(**kwargs))
|
||||
|
||||
service.broadcast_frame = mock_broadcast_frame
|
||||
|
||||
async def run_inline(runner_items):
|
||||
# Simulate the function being unregistered after queuing but before execution.
|
||||
service.unregister_function("doomed_tool")
|
||||
for runner_item in runner_items:
|
||||
await service._run_function_call(runner_item)
|
||||
|
||||
service._run_parallel_function_calls = run_inline
|
||||
service._run_sequential_function_calls = run_inline
|
||||
|
||||
await service.run_function_calls(
|
||||
[
|
||||
FunctionCallFromLLM(
|
||||
function_name="doomed_tool",
|
||||
tool_call_id="call_1",
|
||||
arguments={},
|
||||
context=LLMContext(),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
[type(frame) for frame in recorded_frames],
|
||||
[
|
||||
FunctionCallsStartedFrame,
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
],
|
||||
)
|
||||
self.assertEqual(
|
||||
recorded_frames[2].result,
|
||||
"Error: function 'doomed_tool' is not registered.",
|
||||
)
|
||||
|
||||
async def test_missing_function_call_allows_user_mute_cleanup(self):
|
||||
service = MockLLMService()
|
||||
service._call_event_handler = AsyncMock()
|
||||
|
||||
Reference in New Issue
Block a user