Addressing PR review comments.
This commit is contained in:
@@ -10,6 +10,7 @@ This module provides the abstract base class for implementing LLM provider-speci
|
||||
adapters that handle tool format conversion and standardization.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Generic, List, Optional, TypeVar
|
||||
|
||||
@@ -49,18 +50,19 @@ class BaseLLMAdapter(ABC, Generic[TLLMInvocationParams]):
|
||||
def __init__(self):
|
||||
"""Initialize the adapter."""
|
||||
self._warned_system_instruction = False
|
||||
self._builtin_tools: List[FunctionSchema] = []
|
||||
self._builtin_tools: Dict[str, FunctionSchema] = {}
|
||||
|
||||
@property
|
||||
def builtin_tools(self) -> List[FunctionSchema]:
|
||||
def builtin_tools(self) -> Dict[str, FunctionSchema]:
|
||||
"""Built-in tools automatically merged into every inference request.
|
||||
|
||||
Mixins (e.g. ``AsyncToolCancellationLLMServiceMixin``) append their
|
||||
tool schemas here so that the tools are injected transparently without
|
||||
the user having to add them to their ``ToolsSchema``.
|
||||
Keyed by tool name for O(1) lookup, insertion, and removal. The
|
||||
service injects tools here so they are sent transparently on every
|
||||
inference request without the user having to add them to their
|
||||
``ToolsSchema``.
|
||||
|
||||
Returns:
|
||||
Mutable list of ``FunctionSchema`` instances.
|
||||
Mutable dict mapping tool name to ``FunctionSchema``.
|
||||
"""
|
||||
return self._builtin_tools
|
||||
|
||||
@@ -150,23 +152,28 @@ class BaseLLMAdapter(ABC, Generic[TLLMInvocationParams]):
|
||||
if self._builtin_tools:
|
||||
if isinstance(tools, ToolsSchema):
|
||||
tools = ToolsSchema(
|
||||
standard_tools=tools.standard_tools + self._builtin_tools,
|
||||
standard_tools=tools.standard_tools + list(self._builtin_tools.values()),
|
||||
custom_tools=tools.custom_tools,
|
||||
)
|
||||
else:
|
||||
# User supplied tools in a legacy/provider-specific format;
|
||||
# we cannot safely merge — build a schema from builtins only.
|
||||
# User supplied tools in a legacy/provider-specific format.
|
||||
# Built-in tools cannot be safely merged, so they will not be injected.
|
||||
# Migrate to ToolsSchema to enable built-in tool support; use custom_tools
|
||||
# as an escape hatch for any provider-specific tools that don't fit the
|
||||
# standard schema.
|
||||
if tools is not None:
|
||||
logger.warning(
|
||||
"Built-in tools could not be merged because the supplied tools are not"
|
||||
" a ToolsSchema instance. Only built-in tools will be sent."
|
||||
warnings.warn(
|
||||
"Built-in tools (e.g. async tool cancellation) could not be injected "
|
||||
"because the supplied tools are not a ToolsSchema instance. "
|
||||
"Migrate to ToolsSchema to enable built-in tool support. "
|
||||
"Use ToolsSchema(custom_tools=...) as an escape hatch for any "
|
||||
"provider-specific tools that don't fit the standard schema.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
tools = ToolsSchema(standard_tools=self._builtin_tools)
|
||||
# Fall through and return the original tools unchanged.
|
||||
|
||||
if isinstance(tools, ToolsSchema):
|
||||
logger.debug(f"Retrieving the tools using the adapter: {type(self)}")
|
||||
tool_names = [tool.name for tool in tools.standard_tools]
|
||||
logger.debug(f"Tool names: {tool_names}")
|
||||
return self.to_provider_tools_format(tools)
|
||||
# Fall back to returning the tools unchanged when they are not in a standard format
|
||||
return tools
|
||||
|
||||
@@ -207,6 +207,7 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
run_in_parallel: bool = True,
|
||||
group_parallel_tools: bool = True,
|
||||
function_call_timeout_secs: Optional[float] = None,
|
||||
enable_async_tool_cancellation: bool = False,
|
||||
settings: Optional[LLMSettings] = None,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -221,6 +222,10 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
arrives. Defaults to True.
|
||||
function_call_timeout_secs: Optional timeout in seconds for deferred function
|
||||
calls.
|
||||
enable_async_tool_cancellation: When True and at least one async function
|
||||
(``cancel_on_interruption=False``) is registered, automatically injects
|
||||
the ``cancel_async_tool_call`` built-in tool and its system instructions
|
||||
so the LLM can cancel stale in-progress calls. Defaults to False.
|
||||
settings: The runtime-updatable settings for the LLM service.
|
||||
**kwargs: Additional arguments passed to the parent AIService.
|
||||
|
||||
@@ -235,8 +240,9 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
self._run_in_parallel = run_in_parallel
|
||||
self._group_parallel_tools = group_parallel_tools
|
||||
self._function_call_timeout_secs = function_call_timeout_secs
|
||||
self._enable_async_tool_cancellation: bool = enable_async_tool_cancellation
|
||||
self._filter_incomplete_user_turns: bool = False
|
||||
self._async_cancellation_enabled: bool = False
|
||||
self._async_tool_cancellation_enabled: bool = False
|
||||
self._base_system_instruction: Optional[str] = None
|
||||
self._adapter = self.adapter_class()
|
||||
self._functions: Dict[Optional[str], FunctionCallRegistryItem] = {}
|
||||
@@ -298,7 +304,7 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
await super().start(frame)
|
||||
if not self._run_in_parallel:
|
||||
await self._create_sequential_runner_task()
|
||||
if self._has_async_functions():
|
||||
if self._enable_async_tool_cancellation and self._has_async_tools():
|
||||
self._setup_async_tool_cancellation()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
@@ -334,7 +340,7 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
parts = [base] if base else []
|
||||
if self._filter_incomplete_user_turns:
|
||||
parts.append(self._user_turn_completion_config.completion_instructions)
|
||||
if self._async_cancellation_enabled:
|
||||
if self._async_tool_cancellation_enabled:
|
||||
parts.append(ASYNC_TOOL_CANCELLATION_INSTRUCTIONS)
|
||||
composed = "\n\n".join(p for p in parts if p)
|
||||
self._settings.system_instruction = composed or None
|
||||
@@ -373,7 +379,7 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
|
||||
if (
|
||||
"system_instruction" in changed
|
||||
and (self._filter_incomplete_user_turns or self._async_cancellation_enabled)
|
||||
and (self._filter_incomplete_user_turns or self._async_tool_cancellation_enabled)
|
||||
and "filter_incomplete_user_turns" not in changed
|
||||
):
|
||||
# system_instruction changed while composition is active.
|
||||
@@ -588,6 +594,11 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
``function_call_timeout_secs`` for this specific function. Defaults to
|
||||
None, which uses the global timeout.
|
||||
"""
|
||||
if function_name == CANCEL_ASYNC_TOOL_NAME:
|
||||
raise ValueError(
|
||||
f"'{CANCEL_ASYNC_TOOL_NAME}' is a reserved built-in tool name and cannot be "
|
||||
"registered by user code."
|
||||
)
|
||||
# Registering a function with the function_name set to None will run
|
||||
# that handler for all functions
|
||||
self._functions[function_name] = FunctionCallRegistryItem(
|
||||
@@ -622,6 +633,11 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
None, which uses the global timeout.
|
||||
"""
|
||||
wrapper = DirectFunctionWrapper(handler)
|
||||
if wrapper.name == CANCEL_ASYNC_TOOL_NAME:
|
||||
raise ValueError(
|
||||
f"'{CANCEL_ASYNC_TOOL_NAME}' is a reserved built-in tool name and cannot be "
|
||||
"registered by user code."
|
||||
)
|
||||
self._functions[wrapper.name] = FunctionCallRegistryItem(
|
||||
function_name=wrapper.name,
|
||||
handler=wrapper,
|
||||
@@ -636,6 +652,8 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
function_name: The name of the function handler to remove.
|
||||
"""
|
||||
del self._functions[function_name]
|
||||
if self._async_tool_cancellation_enabled and not self._has_async_tools():
|
||||
self._teardown_async_tool_cancellation()
|
||||
|
||||
def unregister_direct_function(self, handler: Any):
|
||||
"""Remove a registered direct function handler.
|
||||
@@ -646,6 +664,8 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
wrapper = DirectFunctionWrapper(handler)
|
||||
del self._functions[wrapper.name]
|
||||
# Note: no need to remove start callback here, as direct functions don't support start callbacks.
|
||||
if self._async_tool_cancellation_enabled and not self._has_async_tools():
|
||||
self._teardown_async_tool_cancellation()
|
||||
|
||||
def has_function(self, function_name: str):
|
||||
"""Check if a function handler is registered.
|
||||
@@ -861,8 +881,8 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
if timeout_task and not timeout_task.done():
|
||||
await self.cancel_task(timeout_task)
|
||||
|
||||
def _has_async_functions(self) -> bool:
|
||||
"""Return True if at least one non-builtin async function is registered."""
|
||||
def _has_async_tools(self) -> bool:
|
||||
"""Return True if at least one non-builtin async tool is registered."""
|
||||
return any(
|
||||
not item.cancel_on_interruption
|
||||
for name, item in self._functions.items()
|
||||
@@ -874,19 +894,18 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
|
||||
Saves the base system instruction, recomposes it to include cancellation
|
||||
instructions, registers the built-in ``cancel_async_tool_call`` handler,
|
||||
and injects its schema into the adapter's built-in tool list.
|
||||
and injects its schema into the adapter's built-in tool dict.
|
||||
"""
|
||||
logger.debug(f"{self}: Enabling async tool cancellation")
|
||||
|
||||
self._async_cancellation_enabled = True
|
||||
self._async_tool_cancellation_enabled = True
|
||||
|
||||
if self._base_system_instruction is None:
|
||||
self._base_system_instruction = self._settings.system_instruction
|
||||
|
||||
self._compose_system_instruction()
|
||||
|
||||
if not any(t.name == CANCEL_ASYNC_TOOL_NAME for t in self._adapter.builtin_tools):
|
||||
self._adapter.builtin_tools.append(CANCEL_ASYNC_TOOL_SCHEMA)
|
||||
self._adapter.builtin_tools[CANCEL_ASYNC_TOOL_NAME] = CANCEL_ASYNC_TOOL_SCHEMA
|
||||
|
||||
if CANCEL_ASYNC_TOOL_NAME not in self._functions:
|
||||
self._functions[CANCEL_ASYNC_TOOL_NAME] = FunctionCallRegistryItem(
|
||||
@@ -895,13 +914,26 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
cancel_on_interruption=True,
|
||||
)
|
||||
|
||||
def _teardown_async_tool_cancellation(self):
|
||||
"""Disable async tool cancellation.
|
||||
|
||||
Removes the built-in ``cancel_async_tool_call`` handler and its schema,
|
||||
and recomposes the system instruction without the cancellation instructions.
|
||||
"""
|
||||
logger.debug(f"{self}: Disabling async tool cancellation")
|
||||
|
||||
self._async_tool_cancellation_enabled = False
|
||||
self._adapter.builtin_tools.pop(CANCEL_ASYNC_TOOL_NAME, None)
|
||||
self._functions.pop(CANCEL_ASYNC_TOOL_NAME, None)
|
||||
self._compose_system_instruction()
|
||||
|
||||
async def _cancel_async_tool_call_handler(self, params: "FunctionCallParams"):
|
||||
"""Handle a ``cancel_async_tool_call`` invocation from the LLM.
|
||||
|
||||
Args:
|
||||
params: Function call parameters containing ``tool_call_id`` to cancel.
|
||||
"""
|
||||
logger.info("_cancel_async_tool_call_handler invoked!")
|
||||
logger.debug(f"{self}: cancel_async_tool_call invoked")
|
||||
|
||||
tool_call_id: Optional[str] = params.arguments.get("tool_call_id")
|
||||
if not tool_call_id:
|
||||
|
||||
@@ -19,30 +19,32 @@ CANCEL_ASYNC_TOOL_NAME = "cancel_async_tool_call"
|
||||
ASYNC_TOOL_CANCELLATION_INSTRUCTIONS = """ASYNC TOOL CANCELLATION:
|
||||
Some tool calls run asynchronously in the background. When one starts, a tool response \
|
||||
is added to the conversation whose content is a JSON object with \
|
||||
"type": "tool", "status": "started", and a "tool_call_id" field containing the \
|
||||
exact ID of that call (e.g. {"type": "tool", "status": "started", "tool_call_id": "..."}).
|
||||
"type": "async_tool", "status": "running", and a "tool_call_id" field containing the \
|
||||
exact ID of that call (e.g. {"type": "async_tool", "status": "running", "tool_call_id": "..."}).
|
||||
|
||||
If the user changes topic, explicitly says they no longer need the result, or the pending \
|
||||
result would clearly be stale, call cancel_async_tool_call. \
|
||||
To find the correct tool_call_id: locate the most recent tool response in the conversation \
|
||||
whose content has "status": "started" and whose call has NOT already been cancelled, \
|
||||
whose content has "status": "running" and whose call has NOT already been cancelled, \
|
||||
then copy the "tool_call_id" value from that content exactly as-is. \
|
||||
Never invent or guess a tool_call_id."""
|
||||
|
||||
CANCEL_ASYNC_TOOL_SCHEMA = FunctionSchema(
|
||||
name=CANCEL_ASYNC_TOOL_NAME,
|
||||
description=(
|
||||
"Cancel a single async tool call that is no longer needed. "
|
||||
"Cancel a single async tool call whose results are no longer needed. "
|
||||
"Use this when the user changes topic, indicates a pending result is "
|
||||
"no longer relevant, or when processing the result would produce a "
|
||||
"stale or confusing response. "
|
||||
"The tool_call_id must be the exact 'id' value from the assistant's "
|
||||
"tool call which we wish to cancel, visible in the conversation history."
|
||||
"The tool_call_id must be copied exactly from the 'tool_call_id' field "
|
||||
"in the async tool's 'running' response visible in the conversation history."
|
||||
),
|
||||
properties={
|
||||
"tool_call_id": {
|
||||
"type": "string",
|
||||
"description": ("The exact id of the async call to cancel."),
|
||||
"description": (
|
||||
"The exact tool_call_id from the async tool's 'running' response to cancel."
|
||||
),
|
||||
}
|
||||
},
|
||||
required=["tool_call_id"],
|
||||
|
||||
Reference in New Issue
Block a user