Update LLMService docstrings
This commit is contained in:
@@ -4,6 +4,8 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Base classes for Large Language Model services with function calling support."""
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
@@ -41,9 +43,21 @@ FunctionCallHandler = Callable[["FunctionCallParams"], Awaitable[None]]
|
||||
|
||||
# Type alias for a callback function that handles the result of an LLM function call.
|
||||
class FunctionCallResultCallback(Protocol):
|
||||
"""Protocol for function call result callbacks.
|
||||
|
||||
Handles the result of an LLM function call execution.
|
||||
"""
|
||||
|
||||
async def __call__(
|
||||
self, result: Any, *, properties: Optional[FunctionCallResultProperties] = None
|
||||
) -> None: ...
|
||||
) -> None:
|
||||
"""Call the result callback.
|
||||
|
||||
Args:
|
||||
result: The result of the function call.
|
||||
properties: Optional properties for the result.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -51,13 +65,12 @@ class FunctionCallParams:
|
||||
"""Parameters for a function call.
|
||||
|
||||
Attributes:
|
||||
function_name (str): The name of the function being called.
|
||||
arguments (Mapping[str, Any]): The arguments for the function.
|
||||
tool_call_id (str): A unique identifier for the function call.
|
||||
llm (LLMService): The LLMService instance being used.
|
||||
context (OpenAILLMContext): The LLM context.
|
||||
result_callback (FunctionCallResultCallback): Callback to handle the result of the function call.
|
||||
|
||||
function_name: The name of the function being called.
|
||||
tool_call_id: A unique identifier for the function call.
|
||||
arguments: The arguments for the function.
|
||||
llm: The LLMService instance being used.
|
||||
context: The LLM context.
|
||||
result_callback: Callback to handle the result of the function call.
|
||||
"""
|
||||
|
||||
function_name: str
|
||||
@@ -70,14 +83,14 @@ class FunctionCallParams:
|
||||
|
||||
@dataclass
|
||||
class FunctionCallRegistryItem:
|
||||
"""Represents an entry in our function call registry. This is what the user
|
||||
registers.
|
||||
"""Represents an entry in the function call registry.
|
||||
|
||||
This is what the user registers when calling register_function.
|
||||
|
||||
Attributes:
|
||||
function_name (Optional[str]): The name of the function.
|
||||
handler (FunctionCallHandler): The handler for processing function call parameters.
|
||||
cancel_on_interruption (bool): Flag indicating whether to cancel the call on interruption.
|
||||
|
||||
function_name: The name of the function (None for catch-all handler).
|
||||
handler: The handler for processing function call parameters.
|
||||
cancel_on_interruption: Whether to cancel the call on interruption.
|
||||
"""
|
||||
|
||||
function_name: Optional[str]
|
||||
@@ -87,16 +100,17 @@ class FunctionCallRegistryItem:
|
||||
|
||||
@dataclass
|
||||
class FunctionCallRunnerItem:
|
||||
"""Represents an internal function call entry to our function call
|
||||
runner. The runner executes function calls in order.
|
||||
"""Internal function call entry for the function call runner.
|
||||
|
||||
The runner executes function calls in order.
|
||||
|
||||
Attributes:
|
||||
registry_name (Optional[str]): The function call name registration (could be None).
|
||||
function_name (str): The name of the function.
|
||||
tool_call_id (str): A unique identifier for the function call.
|
||||
arguments (Mapping[str, Any]): The arguments for the function.
|
||||
context (OpenAILLMContext): The LLM context.
|
||||
|
||||
registry_item: The registry item containing handler information.
|
||||
function_name: The name of the function.
|
||||
tool_call_id: A unique identifier for the function call.
|
||||
arguments: The arguments for the function.
|
||||
context: The LLM context.
|
||||
run_llm: Optional flag to control LLM execution after function call.
|
||||
"""
|
||||
|
||||
registry_item: FunctionCallRegistryItem
|
||||
@@ -108,22 +122,32 @@ class FunctionCallRunnerItem:
|
||||
|
||||
|
||||
class LLMService(AIService):
|
||||
"""This is the base class for all LLM services. It handles function calling
|
||||
registration and execution. The class also provides event handlers.
|
||||
"""Base class for all LLM services.
|
||||
|
||||
An event to know when an LLM service completion timeout occurs:
|
||||
Handles function calling registration and execution with support for both
|
||||
parallel and sequential execution modes. Provides event handlers for
|
||||
completion timeouts and function call lifecycle events.
|
||||
|
||||
@task.event_handler("on_completion_timeout")
|
||||
async def on_completion_timeout(service):
|
||||
...
|
||||
Args:
|
||||
run_in_parallel: Whether to run function calls in parallel or sequentially.
|
||||
Defaults to True.
|
||||
**kwargs: Additional arguments passed to the parent AIService.
|
||||
|
||||
And an event to know that function calls have been received from the LLM
|
||||
service and that we are going to start executing them:
|
||||
Event handlers:
|
||||
on_completion_timeout: Called when an LLM completion timeout occurs.
|
||||
on_function_calls_started: Called when function calls are received and
|
||||
execution is about to start.
|
||||
|
||||
@task.event_handler("on_function_calls_started")
|
||||
async def on_function_calls_started(service, function_calls: Sequence[FunctionCallFromLLM]):
|
||||
...
|
||||
Example:
|
||||
```python
|
||||
@task.event_handler("on_completion_timeout")
|
||||
async def on_completion_timeout(service):
|
||||
logger.warning("LLM completion timed out")
|
||||
|
||||
@task.event_handler("on_function_calls_started")
|
||||
async def on_function_calls_started(service, function_calls):
|
||||
logger.info(f"Starting {len(function_calls)} function calls")
|
||||
```
|
||||
"""
|
||||
|
||||
# OpenAILLMAdapter is used as the default adapter since it aligns with most LLM implementations.
|
||||
@@ -143,6 +167,11 @@ class LLMService(AIService):
|
||||
self._register_event_handler("on_completion_timeout")
|
||||
|
||||
def get_llm_adapter(self) -> BaseLLMAdapter:
|
||||
"""Get the LLM adapter instance.
|
||||
|
||||
Returns:
|
||||
The adapter instance used for LLM communication.
|
||||
"""
|
||||
return self._adapter
|
||||
|
||||
def create_context_aggregator(
|
||||
@@ -152,24 +181,57 @@ class LLMService(AIService):
|
||||
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
|
||||
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
|
||||
) -> Any:
|
||||
"""Create a context aggregator for managing LLM conversation context.
|
||||
|
||||
Must be implemented by subclasses.
|
||||
|
||||
Args:
|
||||
context: The LLM context to create an aggregator for.
|
||||
user_params: Parameters for user message aggregation.
|
||||
assistant_params: Parameters for assistant message aggregation.
|
||||
|
||||
Returns:
|
||||
A context aggregator instance.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the LLM service.
|
||||
|
||||
Args:
|
||||
frame: The start frame.
|
||||
"""
|
||||
await super().start(frame)
|
||||
if not self._run_in_parallel:
|
||||
await self._create_sequential_runner_task()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the LLM service.
|
||||
|
||||
Args:
|
||||
frame: The end frame.
|
||||
"""
|
||||
await super().stop(frame)
|
||||
if not self._run_in_parallel:
|
||||
await self._cancel_sequential_runner_task()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the LLM service.
|
||||
|
||||
Args:
|
||||
frame: The cancel frame.
|
||||
"""
|
||||
await super().cancel(frame)
|
||||
if not self._run_in_parallel:
|
||||
await self._cancel_sequential_runner_task()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process a frame.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction of frame processing.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
@@ -188,6 +250,18 @@ class LLMService(AIService):
|
||||
*,
|
||||
cancel_on_interruption: bool = True,
|
||||
):
|
||||
"""Register a function handler for LLM function calls.
|
||||
|
||||
Args:
|
||||
function_name: The name of the function to handle. Use None to handle
|
||||
all function calls with a catch-all handler.
|
||||
handler: The function handler. Should accept a single FunctionCallParams
|
||||
parameter.
|
||||
start_callback: Legacy callback function (deprecated). Put initialization
|
||||
code at the top of your handler instead.
|
||||
cancel_on_interruption: Whether to cancel this function call when an
|
||||
interruption occurs. Defaults to True.
|
||||
"""
|
||||
# Registering a function with the function_name set to None will run
|
||||
# that handler for all functions
|
||||
self._functions[function_name] = FunctionCallRegistryItem(
|
||||
@@ -210,16 +284,38 @@ class LLMService(AIService):
|
||||
self._start_callbacks[function_name] = start_callback
|
||||
|
||||
def unregister_function(self, function_name: Optional[str]):
|
||||
"""Remove a registered function handler.
|
||||
|
||||
Args:
|
||||
function_name: The name of the function handler to remove.
|
||||
"""
|
||||
del self._functions[function_name]
|
||||
if self._start_callbacks[function_name]:
|
||||
del self._start_callbacks[function_name]
|
||||
|
||||
def has_function(self, function_name: str):
|
||||
"""Check if a function handler is registered.
|
||||
|
||||
Args:
|
||||
function_name: The name of the function to check.
|
||||
|
||||
Returns:
|
||||
True if the function is registered or if a catch-all handler (None)
|
||||
is registered.
|
||||
"""
|
||||
if None in self._functions.keys():
|
||||
return True
|
||||
return function_name in self._functions.keys()
|
||||
|
||||
async def run_function_calls(self, function_calls: Sequence[FunctionCallFromLLM]):
|
||||
"""Execute a sequence of function calls from the LLM.
|
||||
|
||||
Triggers the on_function_calls_started event and executes functions
|
||||
either in parallel or sequentially based on the run_in_parallel setting.
|
||||
|
||||
Args:
|
||||
function_calls: The function calls to execute.
|
||||
"""
|
||||
if len(function_calls) == 0:
|
||||
return
|
||||
|
||||
@@ -272,6 +368,18 @@ class LLMService(AIService):
|
||||
text_content: Optional[str] = None,
|
||||
video_source: Optional[str] = None,
|
||||
):
|
||||
"""Request an image from a user.
|
||||
|
||||
Pushes a UserImageRequestFrame upstream to request an image from the
|
||||
specified user.
|
||||
|
||||
Args:
|
||||
user_id: The ID of the user to request an image from.
|
||||
function_name: Optional function name associated with the request.
|
||||
tool_call_id: Optional tool call ID associated with the request.
|
||||
text_content: Optional text content/context for the image request.
|
||||
video_source: Optional video source identifier.
|
||||
"""
|
||||
await self.push_frame(
|
||||
UserImageRequestFrame(
|
||||
user_id=user_id,
|
||||
|
||||
Reference in New Issue
Block a user