Merge pull request #2051 from pipecat-ai/pk/direct-functions

Implement "direct functions", which allow you to bypass specifying a …
This commit is contained in:
kompfner
2025-07-01 14:19:33 -04:00
committed by GitHub
7 changed files with 736 additions and 32 deletions

View File

@@ -9,6 +9,38 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Added support for providing "direct" functions, which don't need an
accompanying `FlowsFunctionSchema` or function definition dict. Instead,
metadata (i.e. `name`, `description`, `properties`, and `required`) are
automatically extracted from a combination of the function signature and
docstring.
Usage:
```python
# "Direct" function
# `params` must be the first parameter
async def do_something(params: FunctionCallParams, foo: int, bar: str = ""):
"""
Do something interesting.
Args:
foo (int): The foo to do something interesting with.
bar (string): The bar to do something interesting with.
"""
result = await process(foo, bar)
await params.result_callback({"result": result})
# ...
llm.register_direct_function(do_something)
# ...
tools = ToolsSchema(standard_tools=[do_something])
```
- Added `watchdog_coroutine()`. This is a watchdog helper for couroutines. So,
if you have a coroutine that is waiting for a result and that takes a long
time, you will need to wrap it with `watchdog_coroutine()` so the watchdog

View File

@@ -0,0 +1,146 @@
#
# Copyright (c) 2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import argparse
import os
from dotenv import load_dotenv
from loguru import logger
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import TTSSpeakFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.llm_service import FunctionCallParams
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.network.fastapi_websocket import FastAPIWebsocketParams
from pipecat.transports.services.daily import DailyParams
load_dotenv(override=True)
async def get_current_weather(params: FunctionCallParams, location: str, format: str):
"""
Get the current weather.
Args:
location (str): The city and state, e.g. "San Francisco, CA".
format (str): The temperature unit to use. Must be either "celsius" or "fahrenheit". Infer this from the user's location.
"""
await params.result_callback({"conditions": "nice", "temperature": "75"})
async def get_restaurant_recommendation(params: FunctionCallParams, location: str):
"""
Get a restaurant recommendation.
Args:
location (str): The city and state, e.g. "San Francisco, CA".
"""
await params.result_callback({"name": "The Golden Dragon"})
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
# instantiated. The function will be called when the desired transport gets
# selected.
transport_params = {
"daily": lambda: DailyParams(
audio_in_enabled=True,
audio_out_enabled=True,
vad_analyzer=SileroVADAnalyzer(),
),
"twilio": lambda: FastAPIWebsocketParams(
audio_in_enabled=True,
audio_out_enabled=True,
vad_analyzer=SileroVADAnalyzer(),
),
"webrtc": lambda: TransportParams(
audio_in_enabled=True,
audio_out_enabled=True,
vad_analyzer=SileroVADAnalyzer(),
),
}
async def run_example(transport: BaseTransport, _: argparse.Namespace, handle_sigint: bool):
logger.info(f"Starting bot")
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
tts = CartesiaTTSService(
api_key=os.getenv("CARTESIA_API_KEY"),
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
)
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
# You can also register a function_name of None to get all functions
# sent to the same callback with an additional function_name parameter.
llm.register_direct_function(get_current_weather)
llm.register_direct_function(get_restaurant_recommendation)
@llm.event_handler("on_function_calls_started")
async def on_function_calls_started(service, function_calls):
await tts.queue_frame(TTSSpeakFrame("Let me check on that."))
tools = ToolsSchema(standard_tools=[get_current_weather, get_restaurant_recommendation])
messages = [
{
"role": "system",
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
},
]
context = OpenAILLMContext(messages, tools)
context_aggregator = llm.create_context_aggregator(context)
pipeline = Pipeline(
[
transport.input(),
stt,
context_aggregator.user(),
llm,
tts,
transport.output(),
context_aggregator.assistant(),
]
)
task = PipelineTask(
pipeline,
params=PipelineParams(
enable_metrics=True,
enable_usage_metrics=True,
),
)
@transport.event_handler("on_client_connected")
async def on_client_connected(transport, client):
logger.info(f"Client connected")
# Kick off the conversation.
await task.queue_frames([context_aggregator.user().get_context_frame()])
@transport.event_handler("on_client_disconnected")
async def on_client_disconnected(transport, client):
logger.info(f"Client disconnected")
await task.cancel()
runner = PipelineRunner(handle_sigint=handle_sigint)
await runner.run(task)
if __name__ == "__main__":
from pipecat.examples.run import main
main(run_example, transport_params=transport_params)

View File

@@ -31,7 +31,8 @@ dependencies = [
"pyloudnorm~=0.1.1",
"resampy~=0.4.3",
"soxr~=0.5.0",
"openai~=1.70.0"
"openai~=1.70.0",
"docstring_parser~=0.16"
]
[project.urls]

View File

@@ -0,0 +1,228 @@
import inspect
import types
from typing import (
Any,
Callable,
Dict,
List,
Mapping,
Protocol,
Set,
Tuple,
Union,
get_args,
get_origin,
get_type_hints,
)
import docstring_parser
from pipecat.adapters.schemas.function_schema import FunctionSchema
class DirectFunction(Protocol):
"""Protocol for a "direct" function that handles LLM function calls.
"Direct" functions' metadata is automatically extracted from their function signature and
docstrings, allowing them to be used without accompanying function configurations (as
`FunctionSchema`s or in provider-specific formats).
"""
async def __call__(self, params: "FunctionCallParams", **kwargs: Any) -> None: ...
class BaseDirectFunctionWrapper:
"""
Base class for a wrapper around a DirectFunction that:
- extracts metadata from the function signature and docstring
- using that metadata, generates a corresponding FunctionSchema
"""
@classmethod
def special_first_param_name(cls) -> str:
"""The name of the "special" first function parameter that is ignored by the metadata
extraction, as it's not relevant to the LLM.
"""
raise NotImplementedError("Subclasses must define the special first parameter name.")
def __init__(self, function: Callable):
self.__class__.validate_function(function)
self.function = function
self._initialize_metadata()
@classmethod
def validate_function(cls, function: Callable) -> None:
if not inspect.iscoroutinefunction(function):
raise Exception(f"Direct function {function.__name__} must be async")
params = list(inspect.signature(function).parameters.items())
special_first_param_name = cls.special_first_param_name()
if len(params) == 0:
raise Exception(
f"Direct function {function.__name__} must have at least one parameter ({special_first_param_name})"
)
first_param_name = params[0][0]
if first_param_name != special_first_param_name:
raise Exception(
f"Direct function {function.__name__} first parameter must be named '{special_first_param_name}'"
)
def to_function_schema(self) -> FunctionSchema:
return FunctionSchema(
name=self.name,
description=self.description,
properties=self.properties,
required=self.required,
)
def _initialize_metadata(self):
# Get function name
self.name = self.function.__name__
# Parse docstring for description and parameters
docstring = docstring_parser.parse(inspect.getdoc(self.function))
# Get function description
self.description = (docstring.description or "").strip()
# Get function parameters as JSON schemas, and the list of required parameters
self.properties, self.required = self._get_parameters_as_jsonschema(
self.function, docstring.params
)
# TODO: maybe to better support things like enums, check if each type is a pydantic type and use its convert-to-jsonschema function
def _get_parameters_as_jsonschema(
self, func: Callable, docstring_params: List[docstring_parser.DocstringParam]
) -> Tuple[Dict[str, Any], List[str]]:
"""
Get function parameters as a dictionary of JSON schemas and a list of required parameters.
Ignore the first parameter, as it's expected to be the "special" one.
Args:
func: Function to get parameters from
docstring_params: List of parameters extracted from the function's docstring
Returns:
A tuple containing:
- A dictionary mapping each function parameter to its JSON schema
- A list of required parameter names
"""
sig = inspect.signature(func)
hints = get_type_hints(func)
properties = {}
required = []
for name, param in sig.parameters.items():
# Ignore 'self' parameter
if name == "self":
continue
# Ignore the first parameter, which is expected to be the "special" one
# (We have already validated that this is the case in validate_function())
is_first_param = name == next(iter(sig.parameters))
if is_first_param:
continue
type_hint = hints.get(name)
# Convert type hint to JSON schema
properties[name] = self._typehint_to_jsonschema(type_hint)
# Add whether the parameter is required
# If the parameter has no default value, it's required
if param.default is inspect.Parameter.empty:
required.append(name)
# Add parameter description from docstring
for doc_param in docstring_params:
if doc_param.arg_name == name:
properties[name]["description"] = doc_param.description or ""
return properties, required
def _typehint_to_jsonschema(self, type_hint: Any) -> Dict[str, Any]:
"""
Convert a Python type hint to a JSON Schema.
Args:
type_hint: A Python type hint
Returns:
A dictionary representing the JSON Schema
"""
if type_hint is None:
return {}
# Handle basic types
if type_hint is type(None):
return {"type": "null"}
if type_hint is str:
return {"type": "string"}
elif type_hint is int:
return {"type": "integer"}
elif type_hint is float:
return {"type": "number"}
elif type_hint is bool:
return {"type": "boolean"}
elif type_hint is dict or type_hint is Dict:
return {"type": "object"}
elif type_hint is list or type_hint is List:
return {"type": "array"}
# Get origin and arguments for complex types
origin = get_origin(type_hint)
args = get_args(type_hint)
# Handle Optional/Union types
if origin is Union or origin is types.UnionType:
return {"anyOf": [self._typehint_to_jsonschema(arg) for arg in args]}
# Handle List, Tuple, Set with specific item types
if origin in (list, List, tuple, Tuple, set, Set) and args:
return {"type": "array", "items": self._typehint_to_jsonschema(args[0])}
# Handle Dict with specific key/value types
if origin in (dict, Dict) and len(args) == 2:
# For JSON Schema, keys must be strings
return {"type": "object", "additionalProperties": self._typehint_to_jsonschema(args[1])}
# Handle TypedDict
if hasattr(type_hint, "__annotations__"):
properties = {}
required = []
# NOTE: this does not yet support some fields being required and others not, which could happen when:
# - the base class is a TypedDict with required fields (total=True or not specified) and the derived class has optional fields (total=False)
# - Python 3.11+ NotRequired is used
all_fields_required = getattr(type_hint, "__total__", True)
for field_name, field_type in get_type_hints(type_hint).items():
properties[field_name] = self._typehint_to_jsonschema(field_type)
if all_fields_required:
required.append(field_name)
schema = {"type": "object", "properties": properties}
if required:
schema["required"] = required
return schema
# Default to any type if we can't determine the specific schema
return {}
class DirectFunctionWrapper(BaseDirectFunctionWrapper):
"""
Wrapper around a DirectFunction that:
- extracts metadata from the function signature and docstring
- generates a corresponding FunctionSchema
- helps with function invocation
"""
@classmethod
def special_first_param_name(cls) -> str:
return "params"
async def invoke(self, args: Mapping[str, Any], params: "FunctionCallParams"):
return await self.function(params=params, **args)

View File

@@ -13,6 +13,7 @@ and custom adapter-specific tools in the Pipecat framework.
from enum import Enum
from typing import Any, Dict, List, Optional
from pipecat.adapters.schemas.direct_function import DirectFunction, DirectFunctionWrapper
from pipecat.adapters.schemas.function_schema import FunctionSchema
@@ -36,7 +37,7 @@ class ToolsSchema:
def __init__(
self,
standard_tools: List[FunctionSchema],
standard_tools: List[FunctionSchema | DirectFunction],
custom_tools: Optional[Dict[AdapterType, List[Dict[str, Any]]]] = None,
) -> None:
"""Initialize the tools schema.
@@ -46,7 +47,20 @@ class ToolsSchema:
custom_tools: Dictionary mapping adapter types to their custom tool definitions.
These tools may not follow the FunctionSchema format (e.g., search_tool).
"""
self._standard_tools = standard_tools
def _map_standard_tools(tools):
schemas = []
for tool in tools:
if isinstance(tool, FunctionSchema):
schemas.append(tool)
elif callable(tool):
wrapper = DirectFunctionWrapper(tool)
schemas.append(wrapper.to_function_schema())
else:
raise TypeError(f"Unsupported tool type: {type(tool)}")
return schemas
self._standard_tools = _map_standard_tools(standard_tools)
self._custom_tools = custom_tools
@property

View File

@@ -8,12 +8,33 @@
import asyncio
import inspect
import types
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Dict, Mapping, Optional, Protocol, Sequence, Type
from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
Mapping,
Optional,
Protocol,
Sequence,
Set,
Tuple,
Type,
Union,
get_args,
get_origin,
get_type_hints,
)
import docstring_parser
from loguru import logger
from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
from pipecat.adapters.schemas.direct_function import DirectFunction, DirectFunctionWrapper
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.services.open_ai_adapter import OpenAILLMAdapter
from pipecat.frames.frames import (
CancelFrame,
@@ -94,7 +115,7 @@ class FunctionCallRegistryItem:
"""
function_name: Optional[str]
handler: FunctionCallHandler
handler: FunctionCallHandler | "DirectFunctionWrapper"
cancel_on_interruption: bool
@@ -285,6 +306,19 @@ class LLMService(AIService):
self._start_callbacks[function_name] = start_callback
def register_direct_function(
self,
handler: DirectFunction,
*,
cancel_on_interruption: bool = True,
):
wrapper = DirectFunctionWrapper(handler)
self._functions[wrapper.name] = FunctionCallRegistryItem(
function_name=wrapper.name,
handler=wrapper,
cancel_on_interruption=cancel_on_interruption,
)
def unregister_function(self, function_name: Optional[str]):
"""Remove a registered function handler.
@@ -295,6 +329,11 @@ class LLMService(AIService):
if self._start_callbacks[function_name]:
del self._start_callbacks[function_name]
def unregister_direct_function(self, handler: Any):
wrapper = DirectFunctionWrapper(handler)
del self._functions[wrapper.name]
# Note: no need to remove start callback here, as direct functions don't support start callbacks.
def has_function(self, function_name: str):
"""Check if a function handler is registered.
@@ -474,35 +513,50 @@ class LLMService(AIService):
await self.push_frame(result_frame_downstream, FrameDirection.DOWNSTREAM)
await self.push_frame(result_frame_upstream, FrameDirection.UPSTREAM)
signature = inspect.signature(item.handler)
if len(signature.parameters) > 1:
import warnings
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
"Function calls with parameters `(function_name, tool_call_id, arguments, llm, context, result_callback)` are deprecated, use a single `FunctionCallParams` parameter instead.",
DeprecationWarning,
)
await item.handler(
runner_item.function_name,
runner_item.tool_call_id,
runner_item.arguments,
self,
runner_item.context,
function_call_result_callback,
if isinstance(item.handler, DirectFunctionWrapper):
# Handler is a DirectFunctionWrapper
await item.handler.invoke(
args=runner_item.arguments,
params=FunctionCallParams(
function_name=runner_item.function_name,
tool_call_id=runner_item.tool_call_id,
arguments=runner_item.arguments,
llm=self,
context=runner_item.context,
result_callback=function_call_result_callback,
),
)
else:
params = FunctionCallParams(
function_name=runner_item.function_name,
tool_call_id=runner_item.tool_call_id,
arguments=runner_item.arguments,
llm=self,
context=runner_item.context,
result_callback=function_call_result_callback,
)
await item.handler(params)
# Handler is a FunctionCallHandler
signature = inspect.signature(item.handler)
if len(signature.parameters) > 1:
import warnings
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
"Function calls with parameters `(function_name, tool_call_id, arguments, llm, context, result_callback)` are deprecated, use a single `FunctionCallParams` parameter instead.",
DeprecationWarning,
)
await item.handler(
runner_item.function_name,
runner_item.tool_call_id,
runner_item.arguments,
self,
runner_item.context,
function_call_result_callback,
)
else:
params = FunctionCallParams(
function_name=runner_item.function_name,
tool_call_id=runner_item.tool_call_id,
arguments=runner_item.arguments,
llm=self,
context=runner_item.context,
result_callback=function_call_result_callback,
)
await item.handler(params)
async def _cancel_function_call(self, function_name: Optional[str]):
cancelled_tasks = set()

View File

@@ -0,0 +1,229 @@
import asyncio
import unittest
from typing import Optional, TypedDict, Union
from pipecat.adapters.schemas.direct_function import DirectFunctionWrapper
from pipecat.services.llm_service import FunctionCallParams
# Copyright (c) 2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
class TestDirectFunction(unittest.TestCase):
def test_name_is_set_from_function(self):
async def my_function(params: FunctionCallParams):
return {"status": "success"}, None
func = DirectFunctionWrapper(function=my_function)
self.assertEqual(func.name, "my_function")
def test_description_is_set_from_function(self):
async def my_function_short_description(params: FunctionCallParams):
"""This is a test function."""
return {"status": "success"}, None
func = DirectFunctionWrapper(function=my_function_short_description)
self.assertEqual(func.description, "This is a test function.")
async def my_function_long_description(params: FunctionCallParams):
"""
This is a test function.
It does some really cool stuff.
Trust me, you'll want to use it.
"""
return {"status": "success"}, None
func = DirectFunctionWrapper(function=my_function_long_description)
self.assertEqual(
func.description,
"This is a test function.\n\nIt does some really cool stuff.\n\nTrust me, you'll want to use it.",
)
def test_properties_are_set_from_function(self):
async def my_function_no_params(params: FunctionCallParams):
return {"status": "success"}, None
func = DirectFunctionWrapper(function=my_function_no_params)
self.assertEqual(func.properties, {})
async def my_function_simple_params(
params: FunctionCallParams, name: str, age: int, height: Union[float, None]
):
return {"status": "success"}, None
func = DirectFunctionWrapper(function=my_function_simple_params)
self.assertEqual(
func.properties,
{
"name": {"type": "string"},
"age": {"type": "integer"},
"height": {"anyOf": [{"type": "number"}, {"type": "null"}]},
},
)
async def my_function_complex_params(
params: FunctionCallParams,
address_lines: list[str],
nickname: str | int | float,
extra: Optional[dict[str, str]],
):
return {"status": "success"}, None
func = DirectFunctionWrapper(function=my_function_complex_params)
self.assertEqual(
func.properties,
{
"address_lines": {"type": "array", "items": {"type": "string"}},
"nickname": {
"anyOf": [{"type": "string"}, {"type": "integer"}, {"type": "number"}]
},
"extra": {
"anyOf": [
{"type": "object", "additionalProperties": {"type": "string"}},
{"type": "null"},
]
},
},
)
class MyInfo1(TypedDict):
name: str
age: int
class MyInfo2(TypedDict, total=False):
name: str
age: int
async def my_function_complex_type_params(
params: FunctionCallParams, info1: MyInfo1, info2: MyInfo2
):
return {"status": "success"}, None
func = DirectFunctionWrapper(function=my_function_complex_type_params)
self.assertEqual(
func.properties,
{
"info1": {
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"},
},
"required": ["name", "age"],
},
"info2": {
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"},
},
},
},
)
def test_required_is_set_from_function(self):
async def my_function_no_params(params: FunctionCallParams):
return {"status": "success"}, None
func = DirectFunctionWrapper(function=my_function_no_params)
self.assertEqual(func.required, [])
async def my_function_simple_params(
params: FunctionCallParams, name: str, age: int, height: Union[float, None] = None
):
return {"status": "success"}, None
func = DirectFunctionWrapper(function=my_function_simple_params)
self.assertEqual(func.required, ["name", "age"])
async def my_function_complex_params(
params: FunctionCallParams,
address_lines: Optional[list[str]],
nickname: str | int = "Bud",
extra: Optional[dict[str, str]] = None,
):
return {"status": "success"}, None
func = DirectFunctionWrapper(function=my_function_complex_params)
self.assertEqual(func.required, ["address_lines"])
def test_property_descriptions_are_set_from_function(self):
async def my_function(
params: FunctionCallParams, name: str, age: int, height: Union[float, None]
):
"""
This is a test function.
Args:
name (str): The name of the person.
age (int): The age of the person.
height (float | None): The height of the person in meters. Defaults to None.
"""
return {"status": "success"}, None
func = DirectFunctionWrapper(function=my_function)
# Validate that the function description is still set correctly even with the longer docstring
self.assertEqual(func.description, "This is a test function.")
# Validate that the property descriptions are set correctly
self.assertEqual(
func.properties,
{
"name": {"type": "string", "description": "The name of the person."},
"age": {"type": "integer", "description": "The age of the person."},
"height": {
"anyOf": [{"type": "number"}, {"type": "null"}],
"description": "The height of the person in meters. Defaults to None.",
},
},
)
def test_invalid_functions_fail_validation(self):
def my_function_non_async(params: FunctionCallParams):
return {"status": "success"}, None
with self.assertRaises(Exception):
DirectFunctionWrapper(function=my_function_non_async)
async def my_function_missing_params():
return {"status": "success"}, None
with self.assertRaises(Exception):
DirectFunctionWrapper(my_function_missing_params)
async def my_function_misplaced_params(foo: str, params: FunctionCallParams):
return {"status": "success"}, None
with self.assertRaises(Exception):
DirectFunctionWrapper(my_function_misplaced_params)
def test_invoke_calls_function_with_args_and_params_object(self):
called = {}
class DummyParams:
pass
async def my_function(params: DummyParams, name: str, age: int):
called["params"] = params
called["name"] = name
called["age"] = age
return {"status": "success"}, None
func = DirectFunctionWrapper(function=my_function)
params = DummyParams()
args = {"name": "Alice", "age": 30}
result = asyncio.run(func.invoke(args=args, params=params))
self.assertEqual(result, ({"status": "success"}, None))
self.assertIs(called["params"], params)
self.assertEqual(called["name"], "Alice")
self.assertEqual(called["age"], 30)
if __name__ == "__main__":
unittest.main()