diff --git a/CHANGELOG.md b/CHANGELOG.md index 4d6236f7d..d200d58b8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,38 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added support for providing "direct" functions, which don't need an + accompanying `FlowsFunctionSchema` or function definition dict. Instead, + metadata (i.e. `name`, `description`, `properties`, and `required`) are + automatically extracted from a combination of the function signature and + docstring. + + Usage: + + ```python + # "Direct" function + # `params` must be the first parameter + async def do_something(params: FunctionCallParams, foo: int, bar: str = ""): + """ + Do something interesting. + + Args: + foo (int): The foo to do something interesting with. + bar (string): The bar to do something interesting with. + """ + + result = await process(foo, bar) + await params.result_callback({"result": result}) + + # ... + + llm.register_direct_function(do_something) + + # ... + + tools = ToolsSchema(standard_tools=[do_something]) + ``` + - Added `watchdog_coroutine()`. This is a watchdog helper for couroutines. So, if you have a coroutine that is waiting for a result and that takes a long time, you will need to wrap it with `watchdog_coroutine()` so the watchdog diff --git a/examples/foundational/14t-function-calling-direct.py b/examples/foundational/14t-function-calling-direct.py new file mode 100644 index 000000000..7778461f6 --- /dev/null +++ b/examples/foundational/14t-function-calling-direct.py @@ -0,0 +1,146 @@ +# +# Copyright (c) 2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import argparse +import os + +from dotenv import load_dotenv +from loguru import logger + +from pipecat.adapters.schemas.tools_schema import ToolsSchema +from pipecat.audio.vad.silero import SileroVADAnalyzer +from pipecat.frames.frames import TTSSpeakFrame +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineParams, PipelineTask +from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext +from pipecat.services.cartesia.tts import CartesiaTTSService +from pipecat.services.deepgram.stt import DeepgramSTTService +from pipecat.services.llm_service import FunctionCallParams +from pipecat.services.openai.llm import OpenAILLMService +from pipecat.transports.base_transport import BaseTransport, TransportParams +from pipecat.transports.network.fastapi_websocket import FastAPIWebsocketParams +from pipecat.transports.services.daily import DailyParams + +load_dotenv(override=True) + + +async def get_current_weather(params: FunctionCallParams, location: str, format: str): + """ + Get the current weather. + + Args: + location (str): The city and state, e.g. "San Francisco, CA". + format (str): The temperature unit to use. Must be either "celsius" or "fahrenheit". Infer this from the user's location. + """ + await params.result_callback({"conditions": "nice", "temperature": "75"}) + + +async def get_restaurant_recommendation(params: FunctionCallParams, location: str): + """ + Get a restaurant recommendation. + + Args: + location (str): The city and state, e.g. "San Francisco, CA". + """ + await params.result_callback({"name": "The Golden Dragon"}) + + +# We store functions so objects (e.g. SileroVADAnalyzer) don't get +# instantiated. The function will be called when the desired transport gets +# selected. +transport_params = { + "daily": lambda: DailyParams( + audio_in_enabled=True, + audio_out_enabled=True, + vad_analyzer=SileroVADAnalyzer(), + ), + "twilio": lambda: FastAPIWebsocketParams( + audio_in_enabled=True, + audio_out_enabled=True, + vad_analyzer=SileroVADAnalyzer(), + ), + "webrtc": lambda: TransportParams( + audio_in_enabled=True, + audio_out_enabled=True, + vad_analyzer=SileroVADAnalyzer(), + ), +} + + +async def run_example(transport: BaseTransport, _: argparse.Namespace, handle_sigint: bool): + logger.info(f"Starting bot") + + stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY")) + + tts = CartesiaTTSService( + api_key=os.getenv("CARTESIA_API_KEY"), + voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady + ) + + llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY")) + + # You can also register a function_name of None to get all functions + # sent to the same callback with an additional function_name parameter. + llm.register_direct_function(get_current_weather) + llm.register_direct_function(get_restaurant_recommendation) + + @llm.event_handler("on_function_calls_started") + async def on_function_calls_started(service, function_calls): + await tts.queue_frame(TTSSpeakFrame("Let me check on that.")) + + tools = ToolsSchema(standard_tools=[get_current_weather, get_restaurant_recommendation]) + + messages = [ + { + "role": "system", + "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.", + }, + ] + + context = OpenAILLMContext(messages, tools) + context_aggregator = llm.create_context_aggregator(context) + + pipeline = Pipeline( + [ + transport.input(), + stt, + context_aggregator.user(), + llm, + tts, + transport.output(), + context_aggregator.assistant(), + ] + ) + + task = PipelineTask( + pipeline, + params=PipelineParams( + enable_metrics=True, + enable_usage_metrics=True, + ), + ) + + @transport.event_handler("on_client_connected") + async def on_client_connected(transport, client): + logger.info(f"Client connected") + # Kick off the conversation. + await task.queue_frames([context_aggregator.user().get_context_frame()]) + + @transport.event_handler("on_client_disconnected") + async def on_client_disconnected(transport, client): + logger.info(f"Client disconnected") + await task.cancel() + + runner = PipelineRunner(handle_sigint=handle_sigint) + + await runner.run(task) + + +if __name__ == "__main__": + from pipecat.examples.run import main + + main(run_example, transport_params=transport_params) diff --git a/pyproject.toml b/pyproject.toml index 9dfd84ec6..489da06dd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,8 @@ dependencies = [ "pyloudnorm~=0.1.1", "resampy~=0.4.3", "soxr~=0.5.0", - "openai~=1.70.0" + "openai~=1.70.0", + "docstring_parser~=0.16" ] [project.urls] diff --git a/src/pipecat/adapters/schemas/direct_function.py b/src/pipecat/adapters/schemas/direct_function.py new file mode 100644 index 000000000..54763c3d8 --- /dev/null +++ b/src/pipecat/adapters/schemas/direct_function.py @@ -0,0 +1,228 @@ +import inspect +import types +from typing import ( + Any, + Callable, + Dict, + List, + Mapping, + Protocol, + Set, + Tuple, + Union, + get_args, + get_origin, + get_type_hints, +) + +import docstring_parser + +from pipecat.adapters.schemas.function_schema import FunctionSchema + + +class DirectFunction(Protocol): + """Protocol for a "direct" function that handles LLM function calls. + + "Direct" functions' metadata is automatically extracted from their function signature and + docstrings, allowing them to be used without accompanying function configurations (as + `FunctionSchema`s or in provider-specific formats). + """ + + async def __call__(self, params: "FunctionCallParams", **kwargs: Any) -> None: ... + + +class BaseDirectFunctionWrapper: + """ + Base class for a wrapper around a DirectFunction that: + - extracts metadata from the function signature and docstring + - using that metadata, generates a corresponding FunctionSchema + """ + + @classmethod + def special_first_param_name(cls) -> str: + """The name of the "special" first function parameter that is ignored by the metadata + extraction, as it's not relevant to the LLM. + """ + raise NotImplementedError("Subclasses must define the special first parameter name.") + + def __init__(self, function: Callable): + self.__class__.validate_function(function) + self.function = function + self._initialize_metadata() + + @classmethod + def validate_function(cls, function: Callable) -> None: + if not inspect.iscoroutinefunction(function): + raise Exception(f"Direct function {function.__name__} must be async") + params = list(inspect.signature(function).parameters.items()) + special_first_param_name = cls.special_first_param_name() + if len(params) == 0: + raise Exception( + f"Direct function {function.__name__} must have at least one parameter ({special_first_param_name})" + ) + first_param_name = params[0][0] + if first_param_name != special_first_param_name: + raise Exception( + f"Direct function {function.__name__} first parameter must be named '{special_first_param_name}'" + ) + + def to_function_schema(self) -> FunctionSchema: + return FunctionSchema( + name=self.name, + description=self.description, + properties=self.properties, + required=self.required, + ) + + def _initialize_metadata(self): + # Get function name + self.name = self.function.__name__ + + # Parse docstring for description and parameters + docstring = docstring_parser.parse(inspect.getdoc(self.function)) + + # Get function description + self.description = (docstring.description or "").strip() + + # Get function parameters as JSON schemas, and the list of required parameters + self.properties, self.required = self._get_parameters_as_jsonschema( + self.function, docstring.params + ) + + # TODO: maybe to better support things like enums, check if each type is a pydantic type and use its convert-to-jsonschema function + def _get_parameters_as_jsonschema( + self, func: Callable, docstring_params: List[docstring_parser.DocstringParam] + ) -> Tuple[Dict[str, Any], List[str]]: + """ + Get function parameters as a dictionary of JSON schemas and a list of required parameters. + Ignore the first parameter, as it's expected to be the "special" one. + + Args: + func: Function to get parameters from + docstring_params: List of parameters extracted from the function's docstring + + Returns: + A tuple containing: + - A dictionary mapping each function parameter to its JSON schema + - A list of required parameter names + """ + + sig = inspect.signature(func) + hints = get_type_hints(func) + properties = {} + required = [] + + for name, param in sig.parameters.items(): + # Ignore 'self' parameter + if name == "self": + continue + + # Ignore the first parameter, which is expected to be the "special" one + # (We have already validated that this is the case in validate_function()) + is_first_param = name == next(iter(sig.parameters)) + if is_first_param: + continue + + type_hint = hints.get(name) + + # Convert type hint to JSON schema + properties[name] = self._typehint_to_jsonschema(type_hint) + + # Add whether the parameter is required + # If the parameter has no default value, it's required + if param.default is inspect.Parameter.empty: + required.append(name) + + # Add parameter description from docstring + for doc_param in docstring_params: + if doc_param.arg_name == name: + properties[name]["description"] = doc_param.description or "" + + return properties, required + + def _typehint_to_jsonschema(self, type_hint: Any) -> Dict[str, Any]: + """ + Convert a Python type hint to a JSON Schema. + + Args: + type_hint: A Python type hint + + Returns: + A dictionary representing the JSON Schema + """ + if type_hint is None: + return {} + + # Handle basic types + if type_hint is type(None): + return {"type": "null"} + if type_hint is str: + return {"type": "string"} + elif type_hint is int: + return {"type": "integer"} + elif type_hint is float: + return {"type": "number"} + elif type_hint is bool: + return {"type": "boolean"} + elif type_hint is dict or type_hint is Dict: + return {"type": "object"} + elif type_hint is list or type_hint is List: + return {"type": "array"} + + # Get origin and arguments for complex types + origin = get_origin(type_hint) + args = get_args(type_hint) + + # Handle Optional/Union types + if origin is Union or origin is types.UnionType: + return {"anyOf": [self._typehint_to_jsonschema(arg) for arg in args]} + + # Handle List, Tuple, Set with specific item types + if origin in (list, List, tuple, Tuple, set, Set) and args: + return {"type": "array", "items": self._typehint_to_jsonschema(args[0])} + + # Handle Dict with specific key/value types + if origin in (dict, Dict) and len(args) == 2: + # For JSON Schema, keys must be strings + return {"type": "object", "additionalProperties": self._typehint_to_jsonschema(args[1])} + + # Handle TypedDict + if hasattr(type_hint, "__annotations__"): + properties = {} + required = [] + + # NOTE: this does not yet support some fields being required and others not, which could happen when: + # - the base class is a TypedDict with required fields (total=True or not specified) and the derived class has optional fields (total=False) + # - Python 3.11+ NotRequired is used + all_fields_required = getattr(type_hint, "__total__", True) + + for field_name, field_type in get_type_hints(type_hint).items(): + properties[field_name] = self._typehint_to_jsonschema(field_type) + if all_fields_required: + required.append(field_name) + + schema = {"type": "object", "properties": properties} + + if required: + schema["required"] = required + + return schema + + # Default to any type if we can't determine the specific schema + return {} + + +class DirectFunctionWrapper(BaseDirectFunctionWrapper): + """ + Wrapper around a DirectFunction that: + - extracts metadata from the function signature and docstring + - generates a corresponding FunctionSchema + - helps with function invocation + """ + + @classmethod + def special_first_param_name(cls) -> str: + return "params" + + async def invoke(self, args: Mapping[str, Any], params: "FunctionCallParams"): + return await self.function(params=params, **args) diff --git a/src/pipecat/adapters/schemas/tools_schema.py b/src/pipecat/adapters/schemas/tools_schema.py index 7ef582f7e..05710616d 100644 --- a/src/pipecat/adapters/schemas/tools_schema.py +++ b/src/pipecat/adapters/schemas/tools_schema.py @@ -13,6 +13,7 @@ and custom adapter-specific tools in the Pipecat framework. from enum import Enum from typing import Any, Dict, List, Optional +from pipecat.adapters.schemas.direct_function import DirectFunction, DirectFunctionWrapper from pipecat.adapters.schemas.function_schema import FunctionSchema @@ -36,7 +37,7 @@ class ToolsSchema: def __init__( self, - standard_tools: List[FunctionSchema], + standard_tools: List[FunctionSchema | DirectFunction], custom_tools: Optional[Dict[AdapterType, List[Dict[str, Any]]]] = None, ) -> None: """Initialize the tools schema. @@ -46,7 +47,20 @@ class ToolsSchema: custom_tools: Dictionary mapping adapter types to their custom tool definitions. These tools may not follow the FunctionSchema format (e.g., search_tool). """ - self._standard_tools = standard_tools + + def _map_standard_tools(tools): + schemas = [] + for tool in tools: + if isinstance(tool, FunctionSchema): + schemas.append(tool) + elif callable(tool): + wrapper = DirectFunctionWrapper(tool) + schemas.append(wrapper.to_function_schema()) + else: + raise TypeError(f"Unsupported tool type: {type(tool)}") + return schemas + + self._standard_tools = _map_standard_tools(standard_tools) self._custom_tools = custom_tools @property diff --git a/src/pipecat/services/llm_service.py b/src/pipecat/services/llm_service.py index 68a74335b..98c8483fe 100644 --- a/src/pipecat/services/llm_service.py +++ b/src/pipecat/services/llm_service.py @@ -8,12 +8,33 @@ import asyncio import inspect +import types from dataclasses import dataclass -from typing import Any, Awaitable, Callable, Dict, Mapping, Optional, Protocol, Sequence, Type +from typing import ( + Any, + Awaitable, + Callable, + Dict, + List, + Mapping, + Optional, + Protocol, + Sequence, + Set, + Tuple, + Type, + Union, + get_args, + get_origin, + get_type_hints, +) +import docstring_parser from loguru import logger from pipecat.adapters.base_llm_adapter import BaseLLMAdapter +from pipecat.adapters.schemas.direct_function import DirectFunction, DirectFunctionWrapper +from pipecat.adapters.schemas.function_schema import FunctionSchema from pipecat.adapters.services.open_ai_adapter import OpenAILLMAdapter from pipecat.frames.frames import ( CancelFrame, @@ -94,7 +115,7 @@ class FunctionCallRegistryItem: """ function_name: Optional[str] - handler: FunctionCallHandler + handler: FunctionCallHandler | "DirectFunctionWrapper" cancel_on_interruption: bool @@ -285,6 +306,19 @@ class LLMService(AIService): self._start_callbacks[function_name] = start_callback + def register_direct_function( + self, + handler: DirectFunction, + *, + cancel_on_interruption: bool = True, + ): + wrapper = DirectFunctionWrapper(handler) + self._functions[wrapper.name] = FunctionCallRegistryItem( + function_name=wrapper.name, + handler=wrapper, + cancel_on_interruption=cancel_on_interruption, + ) + def unregister_function(self, function_name: Optional[str]): """Remove a registered function handler. @@ -295,6 +329,11 @@ class LLMService(AIService): if self._start_callbacks[function_name]: del self._start_callbacks[function_name] + def unregister_direct_function(self, handler: Any): + wrapper = DirectFunctionWrapper(handler) + del self._functions[wrapper.name] + # Note: no need to remove start callback here, as direct functions don't support start callbacks. + def has_function(self, function_name: str): """Check if a function handler is registered. @@ -474,35 +513,50 @@ class LLMService(AIService): await self.push_frame(result_frame_downstream, FrameDirection.DOWNSTREAM) await self.push_frame(result_frame_upstream, FrameDirection.UPSTREAM) - signature = inspect.signature(item.handler) - if len(signature.parameters) > 1: - import warnings - - with warnings.catch_warnings(): - warnings.simplefilter("always") - warnings.warn( - "Function calls with parameters `(function_name, tool_call_id, arguments, llm, context, result_callback)` are deprecated, use a single `FunctionCallParams` parameter instead.", - DeprecationWarning, - ) - - await item.handler( - runner_item.function_name, - runner_item.tool_call_id, - runner_item.arguments, - self, - runner_item.context, - function_call_result_callback, + if isinstance(item.handler, DirectFunctionWrapper): + # Handler is a DirectFunctionWrapper + await item.handler.invoke( + args=runner_item.arguments, + params=FunctionCallParams( + function_name=runner_item.function_name, + tool_call_id=runner_item.tool_call_id, + arguments=runner_item.arguments, + llm=self, + context=runner_item.context, + result_callback=function_call_result_callback, + ), ) else: - params = FunctionCallParams( - function_name=runner_item.function_name, - tool_call_id=runner_item.tool_call_id, - arguments=runner_item.arguments, - llm=self, - context=runner_item.context, - result_callback=function_call_result_callback, - ) - await item.handler(params) + # Handler is a FunctionCallHandler + signature = inspect.signature(item.handler) + if len(signature.parameters) > 1: + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "Function calls with parameters `(function_name, tool_call_id, arguments, llm, context, result_callback)` are deprecated, use a single `FunctionCallParams` parameter instead.", + DeprecationWarning, + ) + + await item.handler( + runner_item.function_name, + runner_item.tool_call_id, + runner_item.arguments, + self, + runner_item.context, + function_call_result_callback, + ) + else: + params = FunctionCallParams( + function_name=runner_item.function_name, + tool_call_id=runner_item.tool_call_id, + arguments=runner_item.arguments, + llm=self, + context=runner_item.context, + result_callback=function_call_result_callback, + ) + await item.handler(params) async def _cancel_function_call(self, function_name: Optional[str]): cancelled_tasks = set() diff --git a/tests/test_direct_functions.py b/tests/test_direct_functions.py new file mode 100644 index 000000000..4200d86ac --- /dev/null +++ b/tests/test_direct_functions.py @@ -0,0 +1,229 @@ +import asyncio +import unittest +from typing import Optional, TypedDict, Union + +from pipecat.adapters.schemas.direct_function import DirectFunctionWrapper +from pipecat.services.llm_service import FunctionCallParams + +# Copyright (c) 2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + + +class TestDirectFunction(unittest.TestCase): + def test_name_is_set_from_function(self): + async def my_function(params: FunctionCallParams): + return {"status": "success"}, None + + func = DirectFunctionWrapper(function=my_function) + self.assertEqual(func.name, "my_function") + + def test_description_is_set_from_function(self): + async def my_function_short_description(params: FunctionCallParams): + """This is a test function.""" + return {"status": "success"}, None + + func = DirectFunctionWrapper(function=my_function_short_description) + self.assertEqual(func.description, "This is a test function.") + + async def my_function_long_description(params: FunctionCallParams): + """ + This is a test function. + + It does some really cool stuff. + + Trust me, you'll want to use it. + """ + return {"status": "success"}, None + + func = DirectFunctionWrapper(function=my_function_long_description) + self.assertEqual( + func.description, + "This is a test function.\n\nIt does some really cool stuff.\n\nTrust me, you'll want to use it.", + ) + + def test_properties_are_set_from_function(self): + async def my_function_no_params(params: FunctionCallParams): + return {"status": "success"}, None + + func = DirectFunctionWrapper(function=my_function_no_params) + self.assertEqual(func.properties, {}) + + async def my_function_simple_params( + params: FunctionCallParams, name: str, age: int, height: Union[float, None] + ): + return {"status": "success"}, None + + func = DirectFunctionWrapper(function=my_function_simple_params) + self.assertEqual( + func.properties, + { + "name": {"type": "string"}, + "age": {"type": "integer"}, + "height": {"anyOf": [{"type": "number"}, {"type": "null"}]}, + }, + ) + + async def my_function_complex_params( + params: FunctionCallParams, + address_lines: list[str], + nickname: str | int | float, + extra: Optional[dict[str, str]], + ): + return {"status": "success"}, None + + func = DirectFunctionWrapper(function=my_function_complex_params) + self.assertEqual( + func.properties, + { + "address_lines": {"type": "array", "items": {"type": "string"}}, + "nickname": { + "anyOf": [{"type": "string"}, {"type": "integer"}, {"type": "number"}] + }, + "extra": { + "anyOf": [ + {"type": "object", "additionalProperties": {"type": "string"}}, + {"type": "null"}, + ] + }, + }, + ) + + class MyInfo1(TypedDict): + name: str + age: int + + class MyInfo2(TypedDict, total=False): + name: str + age: int + + async def my_function_complex_type_params( + params: FunctionCallParams, info1: MyInfo1, info2: MyInfo2 + ): + return {"status": "success"}, None + + func = DirectFunctionWrapper(function=my_function_complex_type_params) + self.assertEqual( + func.properties, + { + "info1": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + "required": ["name", "age"], + }, + "info2": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + }, + }, + ) + + def test_required_is_set_from_function(self): + async def my_function_no_params(params: FunctionCallParams): + return {"status": "success"}, None + + func = DirectFunctionWrapper(function=my_function_no_params) + self.assertEqual(func.required, []) + + async def my_function_simple_params( + params: FunctionCallParams, name: str, age: int, height: Union[float, None] = None + ): + return {"status": "success"}, None + + func = DirectFunctionWrapper(function=my_function_simple_params) + self.assertEqual(func.required, ["name", "age"]) + + async def my_function_complex_params( + params: FunctionCallParams, + address_lines: Optional[list[str]], + nickname: str | int = "Bud", + extra: Optional[dict[str, str]] = None, + ): + return {"status": "success"}, None + + func = DirectFunctionWrapper(function=my_function_complex_params) + self.assertEqual(func.required, ["address_lines"]) + + def test_property_descriptions_are_set_from_function(self): + async def my_function( + params: FunctionCallParams, name: str, age: int, height: Union[float, None] + ): + """ + This is a test function. + + Args: + name (str): The name of the person. + age (int): The age of the person. + height (float | None): The height of the person in meters. Defaults to None. + """ + return {"status": "success"}, None + + func = DirectFunctionWrapper(function=my_function) + + # Validate that the function description is still set correctly even with the longer docstring + self.assertEqual(func.description, "This is a test function.") + + # Validate that the property descriptions are set correctly + self.assertEqual( + func.properties, + { + "name": {"type": "string", "description": "The name of the person."}, + "age": {"type": "integer", "description": "The age of the person."}, + "height": { + "anyOf": [{"type": "number"}, {"type": "null"}], + "description": "The height of the person in meters. Defaults to None.", + }, + }, + ) + + def test_invalid_functions_fail_validation(self): + def my_function_non_async(params: FunctionCallParams): + return {"status": "success"}, None + + with self.assertRaises(Exception): + DirectFunctionWrapper(function=my_function_non_async) + + async def my_function_missing_params(): + return {"status": "success"}, None + + with self.assertRaises(Exception): + DirectFunctionWrapper(my_function_missing_params) + + async def my_function_misplaced_params(foo: str, params: FunctionCallParams): + return {"status": "success"}, None + + with self.assertRaises(Exception): + DirectFunctionWrapper(my_function_misplaced_params) + + def test_invoke_calls_function_with_args_and_params_object(self): + called = {} + + class DummyParams: + pass + + async def my_function(params: DummyParams, name: str, age: int): + called["params"] = params + called["name"] = name + called["age"] = age + return {"status": "success"}, None + + func = DirectFunctionWrapper(function=my_function) + params = DummyParams() + args = {"name": "Alice", "age": 30} + + result = asyncio.run(func.invoke(args=args, params=params)) + self.assertEqual(result, ({"status": "success"}, None)) + self.assertIs(called["params"], params) + self.assertEqual(called["name"], "Alice") + self.assertEqual(called["age"], 30) + + +if __name__ == "__main__": + unittest.main()