diff --git a/src/pipecat/adapters/base_llm_adapter.py b/src/pipecat/adapters/base_llm_adapter.py index 7d210d1f5..b1080d197 100644 --- a/src/pipecat/adapters/base_llm_adapter.py +++ b/src/pipecat/adapters/base_llm_adapter.py @@ -10,6 +10,7 @@ This module provides the abstract base class for implementing LLM provider-speci adapters that handle tool format conversion and standardization. """ +import warnings from abc import ABC, abstractmethod from typing import Any, Dict, Generic, List, Optional, TypeVar @@ -49,18 +50,19 @@ class BaseLLMAdapter(ABC, Generic[TLLMInvocationParams]): def __init__(self): """Initialize the adapter.""" self._warned_system_instruction = False - self._builtin_tools: List[FunctionSchema] = [] + self._builtin_tools: Dict[str, FunctionSchema] = {} @property - def builtin_tools(self) -> List[FunctionSchema]: + def builtin_tools(self) -> Dict[str, FunctionSchema]: """Built-in tools automatically merged into every inference request. - Mixins (e.g. ``AsyncToolCancellationLLMServiceMixin``) append their - tool schemas here so that the tools are injected transparently without - the user having to add them to their ``ToolsSchema``. + Keyed by tool name for O(1) lookup, insertion, and removal. The + service injects tools here so they are sent transparently on every + inference request without the user having to add them to their + ``ToolsSchema``. Returns: - Mutable list of ``FunctionSchema`` instances. + Mutable dict mapping tool name to ``FunctionSchema``. """ return self._builtin_tools @@ -150,23 +152,28 @@ class BaseLLMAdapter(ABC, Generic[TLLMInvocationParams]): if self._builtin_tools: if isinstance(tools, ToolsSchema): tools = ToolsSchema( - standard_tools=tools.standard_tools + self._builtin_tools, + standard_tools=tools.standard_tools + list(self._builtin_tools.values()), custom_tools=tools.custom_tools, ) else: - # User supplied tools in a legacy/provider-specific format; - # we cannot safely merge — build a schema from builtins only. + # User supplied tools in a legacy/provider-specific format. + # Built-in tools cannot be safely merged, so they will not be injected. + # Migrate to ToolsSchema to enable built-in tool support; use custom_tools + # as an escape hatch for any provider-specific tools that don't fit the + # standard schema. if tools is not None: - logger.warning( - "Built-in tools could not be merged because the supplied tools are not" - " a ToolsSchema instance. Only built-in tools will be sent." + warnings.warn( + "Built-in tools (e.g. async tool cancellation) could not be injected " + "because the supplied tools are not a ToolsSchema instance. " + "Migrate to ToolsSchema to enable built-in tool support. " + "Use ToolsSchema(custom_tools=...) as an escape hatch for any " + "provider-specific tools that don't fit the standard schema.", + DeprecationWarning, + stacklevel=2, ) - tools = ToolsSchema(standard_tools=self._builtin_tools) + # Fall through and return the original tools unchanged. if isinstance(tools, ToolsSchema): - logger.debug(f"Retrieving the tools using the adapter: {type(self)}") - tool_names = [tool.name for tool in tools.standard_tools] - logger.debug(f"Tool names: {tool_names}") return self.to_provider_tools_format(tools) # Fallback to return the same tools in case they are not in a standard format return tools diff --git a/src/pipecat/services/llm_service.py b/src/pipecat/services/llm_service.py index 127b9d60e..3602e8eda 100644 --- a/src/pipecat/services/llm_service.py +++ b/src/pipecat/services/llm_service.py @@ -207,6 +207,7 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService): run_in_parallel: bool = True, group_parallel_tools: bool = True, function_call_timeout_secs: Optional[float] = None, + enable_async_tool_cancellation: bool = False, settings: Optional[LLMSettings] = None, **kwargs, ): @@ -221,6 +222,10 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService): arrives. Defaults to True. function_call_timeout_secs: Optional timeout in seconds for deferred function calls. + enable_async_tool_cancellation: When True and at least one async function + (``cancel_on_interruption=False``) is registered, automatically injects + the ``cancel_async_tool_call`` built-in tool and its system instructions + so the LLM can cancel stale in-progress calls. Defaults to False. settings: The runtime-updatable settings for the LLM service. **kwargs: Additional arguments passed to the parent AIService. @@ -235,8 +240,9 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService): self._run_in_parallel = run_in_parallel self._group_parallel_tools = group_parallel_tools self._function_call_timeout_secs = function_call_timeout_secs + self._enable_async_tool_cancellation: bool = enable_async_tool_cancellation self._filter_incomplete_user_turns: bool = False - self._async_cancellation_enabled: bool = False + self._async_tool_cancellation_enabled: bool = False self._base_system_instruction: Optional[str] = None self._adapter = self.adapter_class() self._functions: Dict[Optional[str], FunctionCallRegistryItem] = {} @@ -298,7 +304,7 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService): await super().start(frame) if not self._run_in_parallel: await self._create_sequential_runner_task() - if self._has_async_functions(): + if self._enable_async_tool_cancellation and self._has_async_tools(): self._setup_async_tool_cancellation() async def stop(self, frame: EndFrame): @@ -334,7 +340,7 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService): parts = [base] if base else [] if self._filter_incomplete_user_turns: parts.append(self._user_turn_completion_config.completion_instructions) - if self._async_cancellation_enabled: + if self._async_tool_cancellation_enabled: parts.append(ASYNC_TOOL_CANCELLATION_INSTRUCTIONS) composed = "\n\n".join(p for p in parts if p) self._settings.system_instruction = composed or None @@ -373,7 +379,7 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService): if ( "system_instruction" in changed - and (self._filter_incomplete_user_turns or self._async_cancellation_enabled) + and (self._filter_incomplete_user_turns or self._async_tool_cancellation_enabled) and "filter_incomplete_user_turns" not in changed ): # system_instruction changed while composition is active. @@ -588,6 +594,11 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService): ``function_call_timeout_secs`` for this specific function. Defaults to None, which uses the global timeout. """ + if function_name == CANCEL_ASYNC_TOOL_NAME: + raise ValueError( + f"'{CANCEL_ASYNC_TOOL_NAME}' is a reserved built-in tool name and cannot be " + "registered by user code." + ) # Registering a function with the function_name set to None will run # that handler for all functions self._functions[function_name] = FunctionCallRegistryItem( @@ -622,6 +633,11 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService): None, which uses the global timeout. """ wrapper = DirectFunctionWrapper(handler) + if wrapper.name == CANCEL_ASYNC_TOOL_NAME: + raise ValueError( + f"'{CANCEL_ASYNC_TOOL_NAME}' is a reserved built-in tool name and cannot be " + "registered by user code." + ) self._functions[wrapper.name] = FunctionCallRegistryItem( function_name=wrapper.name, handler=wrapper, @@ -636,6 +652,8 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService): function_name: The name of the function handler to remove. """ del self._functions[function_name] + if self._async_tool_cancellation_enabled and not self._has_async_tools(): + self._teardown_async_tool_cancellation() def unregister_direct_function(self, handler: Any): """Remove a registered direct function handler. @@ -646,6 +664,8 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService): wrapper = DirectFunctionWrapper(handler) del self._functions[wrapper.name] # Note: no need to remove start callback here, as direct functions don't support start callbacks. + if self._async_tool_cancellation_enabled and not self._has_async_tools(): + self._teardown_async_tool_cancellation() def has_function(self, function_name: str): """Check if a function handler is registered. @@ -861,8 +881,8 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService): if timeout_task and not timeout_task.done(): await self.cancel_task(timeout_task) - def _has_async_functions(self) -> bool: - """Return True if at least one non-builtin async function is registered.""" + def _has_async_tools(self) -> bool: + """Return True if at least one non-builtin async tool is registered.""" return any( not item.cancel_on_interruption for name, item in self._functions.items() @@ -874,19 +894,18 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService): Saves the base system instruction, recomposes to include cancellation instructions, registers the built-in ``cancel_async_tool_call`` handler, - and injects its schema into the adapter's built-in tool list. + and injects its schema into the adapter's built-in tool dict. """ logger.debug(f"{self}: Enabling async tool cancellation") - self._async_cancellation_enabled = True + self._async_tool_cancellation_enabled = True if self._base_system_instruction is None: self._base_system_instruction = self._settings.system_instruction self._compose_system_instruction() - if not any(t.name == CANCEL_ASYNC_TOOL_NAME for t in self._adapter.builtin_tools): - self._adapter.builtin_tools.append(CANCEL_ASYNC_TOOL_SCHEMA) + self._adapter.builtin_tools[CANCEL_ASYNC_TOOL_NAME] = CANCEL_ASYNC_TOOL_SCHEMA if CANCEL_ASYNC_TOOL_NAME not in self._functions: self._functions[CANCEL_ASYNC_TOOL_NAME] = FunctionCallRegistryItem( @@ -895,13 +914,26 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService): cancel_on_interruption=True, ) + def _teardown_async_tool_cancellation(self): + """Disable async tool cancellation. + + Removes the built-in ``cancel_async_tool_call`` handler and its schema, + recomposes the system instruction without cancellation instructions. + """ + logger.debug(f"{self}: Disabling async tool cancellation") + + self._async_tool_cancellation_enabled = False + self._adapter.builtin_tools.pop(CANCEL_ASYNC_TOOL_NAME, None) + self._functions.pop(CANCEL_ASYNC_TOOL_NAME, None) + self._compose_system_instruction() + async def _cancel_async_tool_call_handler(self, params: "FunctionCallParams"): """Handle a ``cancel_async_tool_call`` invocation from the LLM. Args: params: Function call parameters containing ``tool_call_id`` to cancel. """ - logger.info("_cancel_async_tool_call_handler invoked!") + logger.debug(f"{self}: cancel_async_tool_call invoked") tool_call_id: Optional[str] = params.arguments.get("tool_call_id") if not tool_call_id: diff --git a/src/pipecat/utils/async_tool_cancellation.py b/src/pipecat/utils/async_tool_cancellation.py index e00741508..c32dc1a29 100644 --- a/src/pipecat/utils/async_tool_cancellation.py +++ b/src/pipecat/utils/async_tool_cancellation.py @@ -19,30 +19,32 @@ CANCEL_ASYNC_TOOL_NAME = "cancel_async_tool_call" ASYNC_TOOL_CANCELLATION_INSTRUCTIONS = """ASYNC TOOL CANCELLATION: Some tool calls run asynchronously in the background. When one starts, a tool response \ is added to the conversation whose content is a JSON object with \ -"type": "tool", "status": "started", and a "tool_call_id" field containing the \ -exact ID of that call (e.g. {"type": "tool", "status": "started", "tool_call_id": "..."}). +"type": "async_tool", "status": "running", and a "tool_call_id" field containing the \ +exact ID of that call (e.g. {"type": "async_tool", "status": "running", "tool_call_id": "..."}). If the user changes topic, explicitly says they no longer need the result, or the pending \ result would clearly be stale, call cancel_async_tool_call. \ To find the correct tool_call_id: locate the most recent tool response in the conversation \ -whose content has "status": "started" and whose call has NOT already been cancelled, \ +whose content has "status": "running" and whose call has NOT already been cancelled, \ then copy the "tool_call_id" value from that content exactly as-is. \ Never invent or guess a tool_call_id.""" CANCEL_ASYNC_TOOL_SCHEMA = FunctionSchema( name=CANCEL_ASYNC_TOOL_NAME, description=( - "Cancel a single async tool call that is no longer needed. " + "Cancel a single async tool call whose results are no longer needed. " "Use this when the user changes topic, indicates a pending result is " "no longer relevant, or when processing the result would produce a " "stale or confusing response. " - "The tool_call_id must be the exact 'id' value from the assistant's " - "tool call which we wish to cancel, visible in the conversation history." + "The tool_call_id must be copied exactly from the 'tool_call_id' field " + "in the async tool's 'running' response visible in the conversation history." ), properties={ "tool_call_id": { "type": "string", - "description": ("The exact id of the async call to cancel."), + "description": ( + "The exact tool_call_id from the async tool's 'running' response to cancel." + ), } }, required=["tool_call_id"],