Merge pull request #2051 from pipecat-ai/pk/direct-functions
Implement "direct functions", which allow you to bypass specifying a …
This commit is contained in:
32
CHANGELOG.md
32
CHANGELOG.md
@@ -9,6 +9,38 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Added
|
||||
|
||||
- Added support for providing "direct" functions, which don't need an
|
||||
accompanying `FlowsFunctionSchema` or function definition dict. Instead,
|
||||
metadata (i.e. `name`, `description`, `properties`, and `required`) are
|
||||
automatically extracted from a combination of the function signature and
|
||||
docstring.
|
||||
|
||||
Usage:
|
||||
|
||||
```python
|
||||
# "Direct" function
|
||||
# `params` must be the first parameter
|
||||
async def do_something(params: FunctionCallParams, foo: int, bar: str = ""):
|
||||
"""
|
||||
Do something interesting.
|
||||
|
||||
Args:
|
||||
foo (int): The foo to do something interesting with.
|
||||
bar (string): The bar to do something interesting with.
|
||||
"""
|
||||
|
||||
result = await process(foo, bar)
|
||||
await params.result_callback({"result": result})
|
||||
|
||||
# ...
|
||||
|
||||
llm.register_direct_function(do_something)
|
||||
|
||||
# ...
|
||||
|
||||
tools = ToolsSchema(standard_tools=[do_something])
|
||||
```
|
||||
|
||||
- Added `watchdog_coroutine()`. This is a watchdog helper for couroutines. So,
|
||||
if you have a coroutine that is waiting for a result and that takes a long
|
||||
time, you will need to wrap it with `watchdog_coroutine()` so the watchdog
|
||||
|
||||
146
examples/foundational/14t-function-calling-direct.py
Normal file
146
examples/foundational/14t-function-calling-direct.py
Normal file
@@ -0,0 +1,146 @@
|
||||
#
|
||||
# Copyright (c) 2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import TTSSpeakFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.network.fastapi_websocket import FastAPIWebsocketParams
|
||||
from pipecat.transports.services.daily import DailyParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
async def get_current_weather(params: FunctionCallParams, location: str, format: str):
|
||||
"""
|
||||
Get the current weather.
|
||||
|
||||
Args:
|
||||
location (str): The city and state, e.g. "San Francisco, CA".
|
||||
format (str): The temperature unit to use. Must be either "celsius" or "fahrenheit". Infer this from the user's location.
|
||||
"""
|
||||
await params.result_callback({"conditions": "nice", "temperature": "75"})
|
||||
|
||||
|
||||
async def get_restaurant_recommendation(params: FunctionCallParams, location: str):
|
||||
"""
|
||||
Get a restaurant recommendation.
|
||||
|
||||
Args:
|
||||
location (str): The city and state, e.g. "San Francisco, CA".
|
||||
"""
|
||||
await params.result_callback({"name": "The Golden Dragon"})
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def run_example(transport: BaseTransport, _: argparse.Namespace, handle_sigint: bool):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
)
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
|
||||
# You can also register a function_name of None to get all functions
|
||||
# sent to the same callback with an additional function_name parameter.
|
||||
llm.register_direct_function(get_current_weather)
|
||||
llm.register_direct_function(get_restaurant_recommendation)
|
||||
|
||||
@llm.event_handler("on_function_calls_started")
|
||||
async def on_function_calls_started(service, function_calls):
|
||||
await tts.queue_frame(TTSSpeakFrame("Let me check on that."))
|
||||
|
||||
tools = ToolsSchema(standard_tools=[get_current_weather, get_restaurant_recommendation])
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
|
||||
},
|
||||
]
|
||||
|
||||
context = OpenAILLMContext(messages, tools)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(),
|
||||
stt,
|
||||
context_aggregator.user(),
|
||||
llm,
|
||||
tts,
|
||||
transport.output(),
|
||||
context_aggregator.assistant(),
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([context_aggregator.user().get_context_frame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.examples.run import main
|
||||
|
||||
main(run_example, transport_params=transport_params)
|
||||
@@ -31,7 +31,8 @@ dependencies = [
|
||||
"pyloudnorm~=0.1.1",
|
||||
"resampy~=0.4.3",
|
||||
"soxr~=0.5.0",
|
||||
"openai~=1.70.0"
|
||||
"openai~=1.70.0",
|
||||
"docstring_parser~=0.16"
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
|
||||
228
src/pipecat/adapters/schemas/direct_function.py
Normal file
228
src/pipecat/adapters/schemas/direct_function.py
Normal file
@@ -0,0 +1,228 @@
|
||||
import inspect
|
||||
import types
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
Protocol,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
get_args,
|
||||
get_origin,
|
||||
get_type_hints,
|
||||
)
|
||||
|
||||
import docstring_parser
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
|
||||
|
||||
class DirectFunction(Protocol):
|
||||
"""Protocol for a "direct" function that handles LLM function calls.
|
||||
|
||||
"Direct" functions' metadata is automatically extracted from their function signature and
|
||||
docstrings, allowing them to be used without accompanying function configurations (as
|
||||
`FunctionSchema`s or in provider-specific formats).
|
||||
"""
|
||||
|
||||
async def __call__(self, params: "FunctionCallParams", **kwargs: Any) -> None: ...
|
||||
|
||||
|
||||
class BaseDirectFunctionWrapper:
|
||||
"""
|
||||
Base class for a wrapper around a DirectFunction that:
|
||||
- extracts metadata from the function signature and docstring
|
||||
- using that metadata, generates a corresponding FunctionSchema
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def special_first_param_name(cls) -> str:
|
||||
"""The name of the "special" first function parameter that is ignored by the metadata
|
||||
extraction, as it's not relevant to the LLM.
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must define the special first parameter name.")
|
||||
|
||||
def __init__(self, function: Callable):
|
||||
self.__class__.validate_function(function)
|
||||
self.function = function
|
||||
self._initialize_metadata()
|
||||
|
||||
@classmethod
|
||||
def validate_function(cls, function: Callable) -> None:
|
||||
if not inspect.iscoroutinefunction(function):
|
||||
raise Exception(f"Direct function {function.__name__} must be async")
|
||||
params = list(inspect.signature(function).parameters.items())
|
||||
special_first_param_name = cls.special_first_param_name()
|
||||
if len(params) == 0:
|
||||
raise Exception(
|
||||
f"Direct function {function.__name__} must have at least one parameter ({special_first_param_name})"
|
||||
)
|
||||
first_param_name = params[0][0]
|
||||
if first_param_name != special_first_param_name:
|
||||
raise Exception(
|
||||
f"Direct function {function.__name__} first parameter must be named '{special_first_param_name}'"
|
||||
)
|
||||
|
||||
def to_function_schema(self) -> FunctionSchema:
|
||||
return FunctionSchema(
|
||||
name=self.name,
|
||||
description=self.description,
|
||||
properties=self.properties,
|
||||
required=self.required,
|
||||
)
|
||||
|
||||
def _initialize_metadata(self):
|
||||
# Get function name
|
||||
self.name = self.function.__name__
|
||||
|
||||
# Parse docstring for description and parameters
|
||||
docstring = docstring_parser.parse(inspect.getdoc(self.function))
|
||||
|
||||
# Get function description
|
||||
self.description = (docstring.description or "").strip()
|
||||
|
||||
# Get function parameters as JSON schemas, and the list of required parameters
|
||||
self.properties, self.required = self._get_parameters_as_jsonschema(
|
||||
self.function, docstring.params
|
||||
)
|
||||
|
||||
# TODO: maybe to better support things like enums, check if each type is a pydantic type and use its convert-to-jsonschema function
|
||||
def _get_parameters_as_jsonschema(
|
||||
self, func: Callable, docstring_params: List[docstring_parser.DocstringParam]
|
||||
) -> Tuple[Dict[str, Any], List[str]]:
|
||||
"""
|
||||
Get function parameters as a dictionary of JSON schemas and a list of required parameters.
|
||||
Ignore the first parameter, as it's expected to be the "special" one.
|
||||
|
||||
Args:
|
||||
func: Function to get parameters from
|
||||
docstring_params: List of parameters extracted from the function's docstring
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
- A dictionary mapping each function parameter to its JSON schema
|
||||
- A list of required parameter names
|
||||
"""
|
||||
|
||||
sig = inspect.signature(func)
|
||||
hints = get_type_hints(func)
|
||||
properties = {}
|
||||
required = []
|
||||
|
||||
for name, param in sig.parameters.items():
|
||||
# Ignore 'self' parameter
|
||||
if name == "self":
|
||||
continue
|
||||
|
||||
# Ignore the first parameter, which is expected to be the "special" one
|
||||
# (We have already validated that this is the case in validate_function())
|
||||
is_first_param = name == next(iter(sig.parameters))
|
||||
if is_first_param:
|
||||
continue
|
||||
|
||||
type_hint = hints.get(name)
|
||||
|
||||
# Convert type hint to JSON schema
|
||||
properties[name] = self._typehint_to_jsonschema(type_hint)
|
||||
|
||||
# Add whether the parameter is required
|
||||
# If the parameter has no default value, it's required
|
||||
if param.default is inspect.Parameter.empty:
|
||||
required.append(name)
|
||||
|
||||
# Add parameter description from docstring
|
||||
for doc_param in docstring_params:
|
||||
if doc_param.arg_name == name:
|
||||
properties[name]["description"] = doc_param.description or ""
|
||||
|
||||
return properties, required
|
||||
|
||||
def _typehint_to_jsonschema(self, type_hint: Any) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert a Python type hint to a JSON Schema.
|
||||
|
||||
Args:
|
||||
type_hint: A Python type hint
|
||||
|
||||
Returns:
|
||||
A dictionary representing the JSON Schema
|
||||
"""
|
||||
if type_hint is None:
|
||||
return {}
|
||||
|
||||
# Handle basic types
|
||||
if type_hint is type(None):
|
||||
return {"type": "null"}
|
||||
if type_hint is str:
|
||||
return {"type": "string"}
|
||||
elif type_hint is int:
|
||||
return {"type": "integer"}
|
||||
elif type_hint is float:
|
||||
return {"type": "number"}
|
||||
elif type_hint is bool:
|
||||
return {"type": "boolean"}
|
||||
elif type_hint is dict or type_hint is Dict:
|
||||
return {"type": "object"}
|
||||
elif type_hint is list or type_hint is List:
|
||||
return {"type": "array"}
|
||||
|
||||
# Get origin and arguments for complex types
|
||||
origin = get_origin(type_hint)
|
||||
args = get_args(type_hint)
|
||||
|
||||
# Handle Optional/Union types
|
||||
if origin is Union or origin is types.UnionType:
|
||||
return {"anyOf": [self._typehint_to_jsonschema(arg) for arg in args]}
|
||||
|
||||
# Handle List, Tuple, Set with specific item types
|
||||
if origin in (list, List, tuple, Tuple, set, Set) and args:
|
||||
return {"type": "array", "items": self._typehint_to_jsonschema(args[0])}
|
||||
|
||||
# Handle Dict with specific key/value types
|
||||
if origin in (dict, Dict) and len(args) == 2:
|
||||
# For JSON Schema, keys must be strings
|
||||
return {"type": "object", "additionalProperties": self._typehint_to_jsonschema(args[1])}
|
||||
|
||||
# Handle TypedDict
|
||||
if hasattr(type_hint, "__annotations__"):
|
||||
properties = {}
|
||||
required = []
|
||||
|
||||
# NOTE: this does not yet support some fields being required and others not, which could happen when:
|
||||
# - the base class is a TypedDict with required fields (total=True or not specified) and the derived class has optional fields (total=False)
|
||||
# - Python 3.11+ NotRequired is used
|
||||
all_fields_required = getattr(type_hint, "__total__", True)
|
||||
|
||||
for field_name, field_type in get_type_hints(type_hint).items():
|
||||
properties[field_name] = self._typehint_to_jsonschema(field_type)
|
||||
if all_fields_required:
|
||||
required.append(field_name)
|
||||
|
||||
schema = {"type": "object", "properties": properties}
|
||||
|
||||
if required:
|
||||
schema["required"] = required
|
||||
|
||||
return schema
|
||||
|
||||
# Default to any type if we can't determine the specific schema
|
||||
return {}
|
||||
|
||||
|
||||
class DirectFunctionWrapper(BaseDirectFunctionWrapper):
|
||||
"""
|
||||
Wrapper around a DirectFunction that:
|
||||
- extracts metadata from the function signature and docstring
|
||||
- generates a corresponding FunctionSchema
|
||||
- helps with function invocation
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def special_first_param_name(cls) -> str:
|
||||
return "params"
|
||||
|
||||
async def invoke(self, args: Mapping[str, Any], params: "FunctionCallParams"):
|
||||
return await self.function(params=params, **args)
|
||||
@@ -13,6 +13,7 @@ and custom adapter-specific tools in the Pipecat framework.
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pipecat.adapters.schemas.direct_function import DirectFunction, DirectFunctionWrapper
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
|
||||
|
||||
@@ -36,7 +37,7 @@ class ToolsSchema:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
standard_tools: List[FunctionSchema],
|
||||
standard_tools: List[FunctionSchema | DirectFunction],
|
||||
custom_tools: Optional[Dict[AdapterType, List[Dict[str, Any]]]] = None,
|
||||
) -> None:
|
||||
"""Initialize the tools schema.
|
||||
@@ -46,7 +47,20 @@ class ToolsSchema:
|
||||
custom_tools: Dictionary mapping adapter types to their custom tool definitions.
|
||||
These tools may not follow the FunctionSchema format (e.g., search_tool).
|
||||
"""
|
||||
self._standard_tools = standard_tools
|
||||
|
||||
def _map_standard_tools(tools):
|
||||
schemas = []
|
||||
for tool in tools:
|
||||
if isinstance(tool, FunctionSchema):
|
||||
schemas.append(tool)
|
||||
elif callable(tool):
|
||||
wrapper = DirectFunctionWrapper(tool)
|
||||
schemas.append(wrapper.to_function_schema())
|
||||
else:
|
||||
raise TypeError(f"Unsupported tool type: {type(tool)}")
|
||||
return schemas
|
||||
|
||||
self._standard_tools = _map_standard_tools(standard_tools)
|
||||
self._custom_tools = custom_tools
|
||||
|
||||
@property
|
||||
|
||||
@@ -8,12 +8,33 @@
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import types
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Callable, Dict, Mapping, Optional, Protocol, Sequence, Type
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Protocol,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
get_args,
|
||||
get_origin,
|
||||
get_type_hints,
|
||||
)
|
||||
|
||||
import docstring_parser
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
|
||||
from pipecat.adapters.schemas.direct_function import DirectFunction, DirectFunctionWrapper
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.services.open_ai_adapter import OpenAILLMAdapter
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
@@ -94,7 +115,7 @@ class FunctionCallRegistryItem:
|
||||
"""
|
||||
|
||||
function_name: Optional[str]
|
||||
handler: FunctionCallHandler
|
||||
handler: FunctionCallHandler | "DirectFunctionWrapper"
|
||||
cancel_on_interruption: bool
|
||||
|
||||
|
||||
@@ -285,6 +306,19 @@ class LLMService(AIService):
|
||||
|
||||
self._start_callbacks[function_name] = start_callback
|
||||
|
||||
def register_direct_function(
|
||||
self,
|
||||
handler: DirectFunction,
|
||||
*,
|
||||
cancel_on_interruption: bool = True,
|
||||
):
|
||||
wrapper = DirectFunctionWrapper(handler)
|
||||
self._functions[wrapper.name] = FunctionCallRegistryItem(
|
||||
function_name=wrapper.name,
|
||||
handler=wrapper,
|
||||
cancel_on_interruption=cancel_on_interruption,
|
||||
)
|
||||
|
||||
def unregister_function(self, function_name: Optional[str]):
|
||||
"""Remove a registered function handler.
|
||||
|
||||
@@ -295,6 +329,11 @@ class LLMService(AIService):
|
||||
if self._start_callbacks[function_name]:
|
||||
del self._start_callbacks[function_name]
|
||||
|
||||
def unregister_direct_function(self, handler: Any):
|
||||
wrapper = DirectFunctionWrapper(handler)
|
||||
del self._functions[wrapper.name]
|
||||
# Note: no need to remove start callback here, as direct functions don't support start callbacks.
|
||||
|
||||
def has_function(self, function_name: str):
|
||||
"""Check if a function handler is registered.
|
||||
|
||||
@@ -474,35 +513,50 @@ class LLMService(AIService):
|
||||
await self.push_frame(result_frame_downstream, FrameDirection.DOWNSTREAM)
|
||||
await self.push_frame(result_frame_upstream, FrameDirection.UPSTREAM)
|
||||
|
||||
signature = inspect.signature(item.handler)
|
||||
if len(signature.parameters) > 1:
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Function calls with parameters `(function_name, tool_call_id, arguments, llm, context, result_callback)` are deprecated, use a single `FunctionCallParams` parameter instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
await item.handler(
|
||||
runner_item.function_name,
|
||||
runner_item.tool_call_id,
|
||||
runner_item.arguments,
|
||||
self,
|
||||
runner_item.context,
|
||||
function_call_result_callback,
|
||||
if isinstance(item.handler, DirectFunctionWrapper):
|
||||
# Handler is a DirectFunctionWrapper
|
||||
await item.handler.invoke(
|
||||
args=runner_item.arguments,
|
||||
params=FunctionCallParams(
|
||||
function_name=runner_item.function_name,
|
||||
tool_call_id=runner_item.tool_call_id,
|
||||
arguments=runner_item.arguments,
|
||||
llm=self,
|
||||
context=runner_item.context,
|
||||
result_callback=function_call_result_callback,
|
||||
),
|
||||
)
|
||||
else:
|
||||
params = FunctionCallParams(
|
||||
function_name=runner_item.function_name,
|
||||
tool_call_id=runner_item.tool_call_id,
|
||||
arguments=runner_item.arguments,
|
||||
llm=self,
|
||||
context=runner_item.context,
|
||||
result_callback=function_call_result_callback,
|
||||
)
|
||||
await item.handler(params)
|
||||
# Handler is a FunctionCallHandler
|
||||
signature = inspect.signature(item.handler)
|
||||
if len(signature.parameters) > 1:
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Function calls with parameters `(function_name, tool_call_id, arguments, llm, context, result_callback)` are deprecated, use a single `FunctionCallParams` parameter instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
await item.handler(
|
||||
runner_item.function_name,
|
||||
runner_item.tool_call_id,
|
||||
runner_item.arguments,
|
||||
self,
|
||||
runner_item.context,
|
||||
function_call_result_callback,
|
||||
)
|
||||
else:
|
||||
params = FunctionCallParams(
|
||||
function_name=runner_item.function_name,
|
||||
tool_call_id=runner_item.tool_call_id,
|
||||
arguments=runner_item.arguments,
|
||||
llm=self,
|
||||
context=runner_item.context,
|
||||
result_callback=function_call_result_callback,
|
||||
)
|
||||
await item.handler(params)
|
||||
|
||||
async def _cancel_function_call(self, function_name: Optional[str]):
|
||||
cancelled_tasks = set()
|
||||
|
||||
229
tests/test_direct_functions.py
Normal file
229
tests/test_direct_functions.py
Normal file
@@ -0,0 +1,229 @@
|
||||
import asyncio
|
||||
import unittest
|
||||
from typing import Optional, TypedDict, Union
|
||||
|
||||
from pipecat.adapters.schemas.direct_function import DirectFunctionWrapper
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
|
||||
# Copyright (c) 2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
|
||||
class TestDirectFunction(unittest.TestCase):
|
||||
def test_name_is_set_from_function(self):
|
||||
async def my_function(params: FunctionCallParams):
|
||||
return {"status": "success"}, None
|
||||
|
||||
func = DirectFunctionWrapper(function=my_function)
|
||||
self.assertEqual(func.name, "my_function")
|
||||
|
||||
def test_description_is_set_from_function(self):
|
||||
async def my_function_short_description(params: FunctionCallParams):
|
||||
"""This is a test function."""
|
||||
return {"status": "success"}, None
|
||||
|
||||
func = DirectFunctionWrapper(function=my_function_short_description)
|
||||
self.assertEqual(func.description, "This is a test function.")
|
||||
|
||||
async def my_function_long_description(params: FunctionCallParams):
|
||||
"""
|
||||
This is a test function.
|
||||
|
||||
It does some really cool stuff.
|
||||
|
||||
Trust me, you'll want to use it.
|
||||
"""
|
||||
return {"status": "success"}, None
|
||||
|
||||
func = DirectFunctionWrapper(function=my_function_long_description)
|
||||
self.assertEqual(
|
||||
func.description,
|
||||
"This is a test function.\n\nIt does some really cool stuff.\n\nTrust me, you'll want to use it.",
|
||||
)
|
||||
|
||||
def test_properties_are_set_from_function(self):
|
||||
async def my_function_no_params(params: FunctionCallParams):
|
||||
return {"status": "success"}, None
|
||||
|
||||
func = DirectFunctionWrapper(function=my_function_no_params)
|
||||
self.assertEqual(func.properties, {})
|
||||
|
||||
async def my_function_simple_params(
|
||||
params: FunctionCallParams, name: str, age: int, height: Union[float, None]
|
||||
):
|
||||
return {"status": "success"}, None
|
||||
|
||||
func = DirectFunctionWrapper(function=my_function_simple_params)
|
||||
self.assertEqual(
|
||||
func.properties,
|
||||
{
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "integer"},
|
||||
"height": {"anyOf": [{"type": "number"}, {"type": "null"}]},
|
||||
},
|
||||
)
|
||||
|
||||
async def my_function_complex_params(
|
||||
params: FunctionCallParams,
|
||||
address_lines: list[str],
|
||||
nickname: str | int | float,
|
||||
extra: Optional[dict[str, str]],
|
||||
):
|
||||
return {"status": "success"}, None
|
||||
|
||||
func = DirectFunctionWrapper(function=my_function_complex_params)
|
||||
self.assertEqual(
|
||||
func.properties,
|
||||
{
|
||||
"address_lines": {"type": "array", "items": {"type": "string"}},
|
||||
"nickname": {
|
||||
"anyOf": [{"type": "string"}, {"type": "integer"}, {"type": "number"}]
|
||||
},
|
||||
"extra": {
|
||||
"anyOf": [
|
||||
{"type": "object", "additionalProperties": {"type": "string"}},
|
||||
{"type": "null"},
|
||||
]
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
class MyInfo1(TypedDict):
|
||||
name: str
|
||||
age: int
|
||||
|
||||
class MyInfo2(TypedDict, total=False):
|
||||
name: str
|
||||
age: int
|
||||
|
||||
async def my_function_complex_type_params(
|
||||
params: FunctionCallParams, info1: MyInfo1, info2: MyInfo2
|
||||
):
|
||||
return {"status": "success"}, None
|
||||
|
||||
func = DirectFunctionWrapper(function=my_function_complex_type_params)
|
||||
self.assertEqual(
|
||||
func.properties,
|
||||
{
|
||||
"info1": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "integer"},
|
||||
},
|
||||
"required": ["name", "age"],
|
||||
},
|
||||
"info2": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "integer"},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
def test_required_is_set_from_function(self):
|
||||
async def my_function_no_params(params: FunctionCallParams):
|
||||
return {"status": "success"}, None
|
||||
|
||||
func = DirectFunctionWrapper(function=my_function_no_params)
|
||||
self.assertEqual(func.required, [])
|
||||
|
||||
async def my_function_simple_params(
|
||||
params: FunctionCallParams, name: str, age: int, height: Union[float, None] = None
|
||||
):
|
||||
return {"status": "success"}, None
|
||||
|
||||
func = DirectFunctionWrapper(function=my_function_simple_params)
|
||||
self.assertEqual(func.required, ["name", "age"])
|
||||
|
||||
async def my_function_complex_params(
|
||||
params: FunctionCallParams,
|
||||
address_lines: Optional[list[str]],
|
||||
nickname: str | int = "Bud",
|
||||
extra: Optional[dict[str, str]] = None,
|
||||
):
|
||||
return {"status": "success"}, None
|
||||
|
||||
func = DirectFunctionWrapper(function=my_function_complex_params)
|
||||
self.assertEqual(func.required, ["address_lines"])
|
||||
|
||||
def test_property_descriptions_are_set_from_function(self):
|
||||
async def my_function(
|
||||
params: FunctionCallParams, name: str, age: int, height: Union[float, None]
|
||||
):
|
||||
"""
|
||||
This is a test function.
|
||||
|
||||
Args:
|
||||
name (str): The name of the person.
|
||||
age (int): The age of the person.
|
||||
height (float | None): The height of the person in meters. Defaults to None.
|
||||
"""
|
||||
return {"status": "success"}, None
|
||||
|
||||
func = DirectFunctionWrapper(function=my_function)
|
||||
|
||||
# Validate that the function description is still set correctly even with the longer docstring
|
||||
self.assertEqual(func.description, "This is a test function.")
|
||||
|
||||
# Validate that the property descriptions are set correctly
|
||||
self.assertEqual(
|
||||
func.properties,
|
||||
{
|
||||
"name": {"type": "string", "description": "The name of the person."},
|
||||
"age": {"type": "integer", "description": "The age of the person."},
|
||||
"height": {
|
||||
"anyOf": [{"type": "number"}, {"type": "null"}],
|
||||
"description": "The height of the person in meters. Defaults to None.",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
def test_invalid_functions_fail_validation(self):
|
||||
def my_function_non_async(params: FunctionCallParams):
|
||||
return {"status": "success"}, None
|
||||
|
||||
with self.assertRaises(Exception):
|
||||
DirectFunctionWrapper(function=my_function_non_async)
|
||||
|
||||
async def my_function_missing_params():
|
||||
return {"status": "success"}, None
|
||||
|
||||
with self.assertRaises(Exception):
|
||||
DirectFunctionWrapper(my_function_missing_params)
|
||||
|
||||
async def my_function_misplaced_params(foo: str, params: FunctionCallParams):
|
||||
return {"status": "success"}, None
|
||||
|
||||
with self.assertRaises(Exception):
|
||||
DirectFunctionWrapper(my_function_misplaced_params)
|
||||
|
||||
def test_invoke_calls_function_with_args_and_params_object(self):
|
||||
called = {}
|
||||
|
||||
class DummyParams:
|
||||
pass
|
||||
|
||||
async def my_function(params: DummyParams, name: str, age: int):
|
||||
called["params"] = params
|
||||
called["name"] = name
|
||||
called["age"] = age
|
||||
return {"status": "success"}, None
|
||||
|
||||
func = DirectFunctionWrapper(function=my_function)
|
||||
params = DummyParams()
|
||||
args = {"name": "Alice", "age": 30}
|
||||
|
||||
result = asyncio.run(func.invoke(args=args, params=params))
|
||||
self.assertEqual(result, ({"status": "success"}, None))
|
||||
self.assertIs(called["params"], params)
|
||||
self.assertEqual(called["name"], "Alice")
|
||||
self.assertEqual(called["age"], 30)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user