Addressing PR review comments.

This commit is contained in:
filipi87
2026-04-09 17:11:04 -03:00
parent 772fb57090
commit 5cf90cba98
3 changed files with 75 additions and 34 deletions

View File

@@ -10,6 +10,7 @@ This module provides the abstract base class for implementing LLM provider-speci
adapters that handle tool format conversion and standardization.
"""
import warnings
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, List, Optional, TypeVar
@@ -49,18 +50,19 @@ class BaseLLMAdapter(ABC, Generic[TLLMInvocationParams]):
def __init__(self):
"""Initialize the adapter."""
self._warned_system_instruction = False
self._builtin_tools: List[FunctionSchema] = []
self._builtin_tools: Dict[str, FunctionSchema] = {}
@property
def builtin_tools(self) -> List[FunctionSchema]:
def builtin_tools(self) -> Dict[str, FunctionSchema]:
"""Built-in tools automatically merged into every inference request.
Mixins (e.g. ``AsyncToolCancellationLLMServiceMixin``) append their
tool schemas here so that the tools are injected transparently without
the user having to add them to their ``ToolsSchema``.
Keyed by tool name for O(1) lookup, insertion, and removal. The
service injects tools here so they are sent transparently on every
inference request without the user having to add them to their
``ToolsSchema``.
Returns:
Mutable list of ``FunctionSchema`` instances.
Mutable dict mapping tool name to ``FunctionSchema``.
"""
return self._builtin_tools
@@ -150,23 +152,28 @@ class BaseLLMAdapter(ABC, Generic[TLLMInvocationParams]):
if self._builtin_tools:
if isinstance(tools, ToolsSchema):
tools = ToolsSchema(
standard_tools=tools.standard_tools + self._builtin_tools,
standard_tools=tools.standard_tools + list(self._builtin_tools.values()),
custom_tools=tools.custom_tools,
)
else:
# User supplied tools in a legacy/provider-specific format;
# we cannot safely merge — build a schema from builtins only.
# User supplied tools in a legacy/provider-specific format.
# Built-in tools cannot be safely merged, so they will not be injected.
# Migrate to ToolsSchema to enable built-in tool support; use custom_tools
# as an escape hatch for any provider-specific tools that don't fit the
# standard schema.
if tools is not None:
logger.warning(
"Built-in tools could not be merged because the supplied tools are not"
" a ToolsSchema instance. Only built-in tools will be sent."
warnings.warn(
"Built-in tools (e.g. async tool cancellation) could not be injected "
"because the supplied tools are not a ToolsSchema instance. "
"Migrate to ToolsSchema to enable built-in tool support. "
"Use ToolsSchema(custom_tools=...) as an escape hatch for any "
"provider-specific tools that don't fit the standard schema.",
DeprecationWarning,
stacklevel=2,
)
tools = ToolsSchema(standard_tools=self._builtin_tools)
# Fall through and return the original tools unchanged.
if isinstance(tools, ToolsSchema):
logger.debug(f"Retrieving the tools using the adapter: {type(self)}")
tool_names = [tool.name for tool in tools.standard_tools]
logger.debug(f"Tool names: {tool_names}")
return self.to_provider_tools_format(tools)
# Fallback to return the same tools in case they are not in a standard format
return tools

View File

@@ -207,6 +207,7 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
run_in_parallel: bool = True,
group_parallel_tools: bool = True,
function_call_timeout_secs: Optional[float] = None,
enable_async_tool_cancellation: bool = False,
settings: Optional[LLMSettings] = None,
**kwargs,
):
@@ -221,6 +222,10 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
arrives. Defaults to True.
function_call_timeout_secs: Optional timeout in seconds for deferred function
calls.
enable_async_tool_cancellation: When True and at least one async function
(``cancel_on_interruption=False``) is registered, automatically injects
the ``cancel_async_tool_call`` built-in tool and its system instructions
so the LLM can cancel stale in-progress calls. Defaults to False.
settings: The runtime-updatable settings for the LLM service.
**kwargs: Additional arguments passed to the parent AIService.
@@ -235,8 +240,9 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
self._run_in_parallel = run_in_parallel
self._group_parallel_tools = group_parallel_tools
self._function_call_timeout_secs = function_call_timeout_secs
self._enable_async_tool_cancellation: bool = enable_async_tool_cancellation
self._filter_incomplete_user_turns: bool = False
self._async_cancellation_enabled: bool = False
self._async_tool_cancellation_enabled: bool = False
self._base_system_instruction: Optional[str] = None
self._adapter = self.adapter_class()
self._functions: Dict[Optional[str], FunctionCallRegistryItem] = {}
@@ -298,7 +304,7 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
await super().start(frame)
if not self._run_in_parallel:
await self._create_sequential_runner_task()
if self._has_async_functions():
if self._enable_async_tool_cancellation and self._has_async_tools():
self._setup_async_tool_cancellation()
async def stop(self, frame: EndFrame):
@@ -334,7 +340,7 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
parts = [base] if base else []
if self._filter_incomplete_user_turns:
parts.append(self._user_turn_completion_config.completion_instructions)
if self._async_cancellation_enabled:
if self._async_tool_cancellation_enabled:
parts.append(ASYNC_TOOL_CANCELLATION_INSTRUCTIONS)
composed = "\n\n".join(p for p in parts if p)
self._settings.system_instruction = composed or None
@@ -373,7 +379,7 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
if (
"system_instruction" in changed
and (self._filter_incomplete_user_turns or self._async_cancellation_enabled)
and (self._filter_incomplete_user_turns or self._async_tool_cancellation_enabled)
and "filter_incomplete_user_turns" not in changed
):
# system_instruction changed while composition is active.
@@ -588,6 +594,11 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
``function_call_timeout_secs`` for this specific function. Defaults to
None, which uses the global timeout.
"""
if function_name == CANCEL_ASYNC_TOOL_NAME:
raise ValueError(
f"'{CANCEL_ASYNC_TOOL_NAME}' is a reserved built-in tool name and cannot be "
"registered by user code."
)
# Registering a function with the function_name set to None will run
# that handler for all functions
self._functions[function_name] = FunctionCallRegistryItem(
@@ -622,6 +633,11 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
None, which uses the global timeout.
"""
wrapper = DirectFunctionWrapper(handler)
if wrapper.name == CANCEL_ASYNC_TOOL_NAME:
raise ValueError(
f"'{CANCEL_ASYNC_TOOL_NAME}' is a reserved built-in tool name and cannot be "
"registered by user code."
)
self._functions[wrapper.name] = FunctionCallRegistryItem(
function_name=wrapper.name,
handler=wrapper,
@@ -636,6 +652,8 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
function_name: The name of the function handler to remove.
"""
del self._functions[function_name]
if self._async_tool_cancellation_enabled and not self._has_async_tools():
self._teardown_async_tool_cancellation()
def unregister_direct_function(self, handler: Any):
"""Remove a registered direct function handler.
@@ -646,6 +664,8 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
wrapper = DirectFunctionWrapper(handler)
del self._functions[wrapper.name]
# Note: no need to remove start callback here, as direct functions don't support start callbacks.
if self._async_tool_cancellation_enabled and not self._has_async_tools():
self._teardown_async_tool_cancellation()
def has_function(self, function_name: str):
"""Check if a function handler is registered.
@@ -861,8 +881,8 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
if timeout_task and not timeout_task.done():
await self.cancel_task(timeout_task)
def _has_async_functions(self) -> bool:
"""Return True if at least one non-builtin async function is registered."""
def _has_async_tools(self) -> bool:
"""Return True if at least one non-builtin async tool is registered."""
return any(
not item.cancel_on_interruption
for name, item in self._functions.items()
@@ -874,19 +894,18 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
Saves the base system instruction, recomposes to include cancellation
instructions, registers the built-in ``cancel_async_tool_call`` handler,
and injects its schema into the adapter's built-in tool list.
and injects its schema into the adapter's built-in tool dict.
"""
logger.debug(f"{self}: Enabling async tool cancellation")
self._async_cancellation_enabled = True
self._async_tool_cancellation_enabled = True
if self._base_system_instruction is None:
self._base_system_instruction = self._settings.system_instruction
self._compose_system_instruction()
if not any(t.name == CANCEL_ASYNC_TOOL_NAME for t in self._adapter.builtin_tools):
self._adapter.builtin_tools.append(CANCEL_ASYNC_TOOL_SCHEMA)
self._adapter.builtin_tools[CANCEL_ASYNC_TOOL_NAME] = CANCEL_ASYNC_TOOL_SCHEMA
if CANCEL_ASYNC_TOOL_NAME not in self._functions:
self._functions[CANCEL_ASYNC_TOOL_NAME] = FunctionCallRegistryItem(
@@ -895,13 +914,26 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
cancel_on_interruption=True,
)
def _teardown_async_tool_cancellation(self):
"""Disable async tool cancellation.
Removes the built-in ``cancel_async_tool_call`` handler and its schema,
recomposes the system instruction without cancellation instructions.
"""
logger.debug(f"{self}: Disabling async tool cancellation")
self._async_tool_cancellation_enabled = False
self._adapter.builtin_tools.pop(CANCEL_ASYNC_TOOL_NAME, None)
self._functions.pop(CANCEL_ASYNC_TOOL_NAME, None)
self._compose_system_instruction()
async def _cancel_async_tool_call_handler(self, params: "FunctionCallParams"):
"""Handle a ``cancel_async_tool_call`` invocation from the LLM.
Args:
params: Function call parameters containing ``tool_call_id`` to cancel.
"""
logger.info("_cancel_async_tool_call_handler invoked!")
logger.debug(f"{self}: cancel_async_tool_call invoked")
tool_call_id: Optional[str] = params.arguments.get("tool_call_id")
if not tool_call_id:

View File

@@ -19,30 +19,32 @@ CANCEL_ASYNC_TOOL_NAME = "cancel_async_tool_call"
ASYNC_TOOL_CANCELLATION_INSTRUCTIONS = """ASYNC TOOL CANCELLATION:
Some tool calls run asynchronously in the background. When one starts, a tool response \
is added to the conversation whose content is a JSON object with \
"type": "tool", "status": "started", and a "tool_call_id" field containing the \
exact ID of that call (e.g. {"type": "tool", "status": "started", "tool_call_id": "..."}).
"type": "async_tool", "status": "running", and a "tool_call_id" field containing the \
exact ID of that call (e.g. {"type": "async_tool", "status": "running", "tool_call_id": "..."}).
If the user changes topic, explicitly says they no longer need the result, or the pending \
result would clearly be stale, call cancel_async_tool_call. \
To find the correct tool_call_id: locate the most recent tool response in the conversation \
whose content has "status": "started" and whose call has NOT already been cancelled, \
whose content has "status": "running" and whose call has NOT already been cancelled, \
then copy the "tool_call_id" value from that content exactly as-is. \
Never invent or guess a tool_call_id."""
CANCEL_ASYNC_TOOL_SCHEMA = FunctionSchema(
name=CANCEL_ASYNC_TOOL_NAME,
description=(
"Cancel a single async tool call that is no longer needed. "
"Cancel a single async tool call whose results are no longer needed. "
"Use this when the user changes topic, indicates a pending result is "
"no longer relevant, or when processing the result would produce a "
"stale or confusing response. "
"The tool_call_id must be the exact 'id' value from the assistant's "
"tool call which we wish to cancel, visible in the conversation history."
"The tool_call_id must be copied exactly from the 'tool_call_id' field "
"in the async tool's 'running' response visible in the conversation history."
),
properties={
"tool_call_id": {
"type": "string",
"description": ("The exact id of the async call to cancel."),
"description": (
"The exact tool_call_id from the async tool's 'running' response to cancel."
),
}
},
required=["tool_call_id"],