Update LLMService docstrings

This commit is contained in:
Mark Backman
2025-06-25 15:57:43 -04:00
parent 202055a9b8
commit fb12bf9b4c

View File

@@ -4,6 +4,8 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Base classes for Large Language Model services with function calling support."""
import asyncio
import inspect
from dataclasses import dataclass
@@ -41,9 +43,21 @@ FunctionCallHandler = Callable[["FunctionCallParams"], Awaitable[None]]
# Type alias for a callback function that handles the result of an LLM function call.
class FunctionCallResultCallback(Protocol):
"""Protocol for function call result callbacks.
Handles the result of an LLM function call execution.
"""
async def __call__(
self, result: Any, *, properties: Optional[FunctionCallResultProperties] = None
) -> None: ...
) -> None:
"""Call the result callback.
Args:
result: The result of the function call.
properties: Optional properties for the result.
"""
...
@dataclass
@@ -51,13 +65,12 @@ class FunctionCallParams:
"""Parameters for a function call.
Attributes:
function_name (str): The name of the function being called.
arguments (Mapping[str, Any]): The arguments for the function.
tool_call_id (str): A unique identifier for the function call.
llm (LLMService): The LLMService instance being used.
context (OpenAILLMContext): The LLM context.
result_callback (FunctionCallResultCallback): Callback to handle the result of the function call.
function_name: The name of the function being called.
tool_call_id: A unique identifier for the function call.
arguments: The arguments for the function.
llm: The LLMService instance being used.
context: The LLM context.
result_callback: Callback to handle the result of the function call.
"""
function_name: str
@@ -70,14 +83,14 @@ class FunctionCallParams:
@dataclass
class FunctionCallRegistryItem:
"""Represents an entry in our function call registry. This is what the user
registers.
"""Represents an entry in the function call registry.
This is what the user registers when calling register_function.
Attributes:
function_name (Optional[str]): The name of the function.
handler (FunctionCallHandler): The handler for processing function call parameters.
cancel_on_interruption (bool): Flag indicating whether to cancel the call on interruption.
function_name: The name of the function (None for catch-all handler).
handler: The handler for processing function call parameters.
cancel_on_interruption: Whether to cancel the call on interruption.
"""
function_name: Optional[str]
@@ -87,16 +100,17 @@ class FunctionCallRegistryItem:
@dataclass
class FunctionCallRunnerItem:
"""Represents an internal function call entry to our function call
runner. The runner executes function calls in order.
"""Internal function call entry for the function call runner.
The runner executes function calls in order.
Attributes:
registry_name (Optional[str]): The function call name registration (could be None).
function_name (str): The name of the function.
tool_call_id (str): A unique identifier for the function call.
arguments (Mapping[str, Any]): The arguments for the function.
context (OpenAILLMContext): The LLM context.
registry_item: The registry item containing handler information.
function_name: The name of the function.
tool_call_id: A unique identifier for the function call.
arguments: The arguments for the function.
context: The LLM context.
run_llm: Optional flag to control LLM execution after function call.
"""
registry_item: FunctionCallRegistryItem
@@ -108,22 +122,32 @@ class FunctionCallRunnerItem:
class LLMService(AIService):
"""This is the base class for all LLM services. It handles function calling
registration and execution. The class also provides event handlers.
"""Base class for all LLM services.
An event to know when an LLM service completion timeout occurs:
Handles function calling registration and execution with support for both
parallel and sequential execution modes. Provides event handlers for
completion timeouts and function call lifecycle events.
@task.event_handler("on_completion_timeout")
async def on_completion_timeout(service):
...
Args:
run_in_parallel: Whether to run function calls in parallel or sequentially.
Defaults to True.
**kwargs: Additional arguments passed to the parent AIService.
And an event to know that function calls have been received from the LLM
service and that we are going to start executing them:
Event handlers:
on_completion_timeout: Called when an LLM completion timeout occurs.
on_function_calls_started: Called when function calls are received and
execution is about to start.
@task.event_handler("on_function_calls_started")
async def on_function_calls_started(service, function_calls: Sequence[FunctionCallFromLLM]):
...
Example:
```python
@task.event_handler("on_completion_timeout")
async def on_completion_timeout(service):
logger.warning("LLM completion timed out")
@task.event_handler("on_function_calls_started")
async def on_function_calls_started(service, function_calls):
logger.info(f"Starting {len(function_calls)} function calls")
```
"""
# OpenAILLMAdapter is used as the default adapter since it aligns with most LLM implementations.
@@ -143,6 +167,11 @@ class LLMService(AIService):
self._register_event_handler("on_completion_timeout")
def get_llm_adapter(self) -> BaseLLMAdapter:
"""Get the LLM adapter instance.
Returns:
The adapter instance used for LLM communication.
"""
return self._adapter
def create_context_aggregator(
@@ -152,24 +181,57 @@ class LLMService(AIService):
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
) -> Any:
"""Create a context aggregator for managing LLM conversation context.
Must be implemented by subclasses.
Args:
context: The LLM context to create an aggregator for.
user_params: Parameters for user message aggregation.
assistant_params: Parameters for assistant message aggregation.
Returns:
A context aggregator instance.
"""
pass
async def start(self, frame: StartFrame):
"""Start the LLM service.
Args:
frame: The start frame.
"""
await super().start(frame)
if not self._run_in_parallel:
await self._create_sequential_runner_task()
async def stop(self, frame: EndFrame):
"""Stop the LLM service.
Args:
frame: The end frame.
"""
await super().stop(frame)
if not self._run_in_parallel:
await self._cancel_sequential_runner_task()
async def cancel(self, frame: CancelFrame):
"""Cancel the LLM service.
Args:
frame: The cancel frame.
"""
await super().cancel(frame)
if not self._run_in_parallel:
await self._cancel_sequential_runner_task()
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process a frame.
Args:
frame: The frame to process.
direction: The direction of frame processing.
"""
await super().process_frame(frame, direction)
if isinstance(frame, StartInterruptionFrame):
@@ -188,6 +250,18 @@ class LLMService(AIService):
*,
cancel_on_interruption: bool = True,
):
"""Register a function handler for LLM function calls.
Args:
function_name: The name of the function to handle. Use None to handle
all function calls with a catch-all handler.
handler: The function handler. Should accept a single FunctionCallParams
parameter.
start_callback: Legacy callback function (deprecated). Put initialization
code at the top of your handler instead.
cancel_on_interruption: Whether to cancel this function call when an
interruption occurs. Defaults to True.
"""
# Registering a function with the function_name set to None will run
# that handler for all functions
self._functions[function_name] = FunctionCallRegistryItem(
@@ -210,16 +284,38 @@ class LLMService(AIService):
self._start_callbacks[function_name] = start_callback
def unregister_function(self, function_name: Optional[str]):
"""Remove a registered function handler.
Args:
function_name: The name of the function handler to remove.
"""
del self._functions[function_name]
if self._start_callbacks[function_name]:
del self._start_callbacks[function_name]
def has_function(self, function_name: str):
"""Check if a function handler is registered.
Args:
function_name: The name of the function to check.
Returns:
True if the function is registered or if a catch-all handler (None)
is registered.
"""
if None in self._functions.keys():
return True
return function_name in self._functions.keys()
async def run_function_calls(self, function_calls: Sequence[FunctionCallFromLLM]):
"""Execute a sequence of function calls from the LLM.
Triggers the on_function_calls_started event and executes functions
either in parallel or sequentially based on the run_in_parallel setting.
Args:
function_calls: The function calls to execute.
"""
if len(function_calls) == 0:
return
@@ -272,6 +368,18 @@ class LLMService(AIService):
text_content: Optional[str] = None,
video_source: Optional[str] = None,
):
"""Request an image from a user.
Pushes a UserImageRequestFrame upstream to request an image from the
specified user.
Args:
user_id: The ID of the user to request an image from.
function_name: Optional function name associated with the request.
tool_call_id: Optional tool call ID associated with the request.
text_content: Optional text content/context for the image request.
video_source: Optional video source identifier.
"""
await self.push_frame(
UserImageRequestFrame(
user_id=user_id,