added custom_tools support for OpenAI adapters

This commit is contained in:
Om Chauhan
2026-04-07 10:08:04 +05:30
parent 0acfb4dd49
commit 4bef85e363
5 changed files with 153 additions and 10 deletions

View File

@@ -21,10 +21,12 @@ class AdapterType(Enum):
"""Supported adapter types for custom tools.
Parameters:
GEMINI: Google Gemini adapter - currently the only service supporting custom tools.
GEMINI: Google Gemini adapter.
OPENAI: OpenAI adapter (Chat Completions and Responses API).
"""
GEMINI = "gemini" # that is the only service where we are able to add custom tools for now
GEMINI = "gemini"
OPENAI = "openai"
class ToolsSchema:

View File

@@ -17,7 +17,7 @@ from openai.types.chat import (
)
from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.adapters.schemas.tools_schema import AdapterType, ToolsSchema
from pipecat.processors.aggregators.llm_context import (
LLMContext,
LLMContextMessage,
@@ -96,7 +96,7 @@ class OpenAILLMAdapter(BaseLLMAdapter[OpenAILLMInvocationParams]):
"tool_choice": context.tool_choice,
}
def to_provider_tools_format(self, tools_schema: ToolsSchema) -> List[ChatCompletionToolParam]:
def to_provider_tools_format(self, tools_schema: ToolsSchema) -> List[Any]:
"""Convert function schemas to OpenAI's function-calling format.
Args:
@@ -107,10 +107,14 @@ class OpenAILLMAdapter(BaseLLMAdapter[OpenAILLMInvocationParams]):
with ChatCompletion API.
"""
functions_schema = tools_schema.standard_tools
return [
formatted_standard_tools = [
ChatCompletionToolParam(type="function", function=func.to_default_dict())
for func in functions_schema
]
custom_openai_tools = []
if tools_schema.custom_tools:
custom_openai_tools = tools_schema.custom_tools.get(AdapterType.OPENAI, [])
return formatted_standard_tools + custom_openai_tools
def get_messages_for_logging(self, context: LLMContext) -> List[Dict[str, Any]]:
"""Get messages from a universal LLM context in a format ready for logging about OpenAI.

View File

@@ -15,7 +15,7 @@ from loguru import logger
from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.adapters.schemas.tools_schema import AdapterType, ToolsSchema
from pipecat.processors.aggregators.llm_context import LLMContext, LLMContextMessage
from pipecat.services.openai.realtime import events
@@ -236,4 +236,10 @@ class OpenAIRealtimeLLMAdapter(BaseLLMAdapter):
List of function definitions in OpenAI Realtime format.
"""
functions_schema = tools_schema.standard_tools
return [self._to_openai_realtime_function_format(func) for func in functions_schema]
formatted_standard_tools = [
self._to_openai_realtime_function_format(func) for func in functions_schema
]
custom_openai_tools = []
if tools_schema.custom_tools:
custom_openai_tools = tools_schema.custom_tools.get(AdapterType.OPENAI, [])
return formatted_standard_tools + custom_openai_tools

View File

@@ -13,7 +13,7 @@ from openai._types import NotGiven as OpenAINotGiven
from openai.types.responses import FunctionToolParam, ResponseInputItemParam
from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.adapters.schemas.tools_schema import AdapterType, ToolsSchema
from pipecat.processors.aggregators.llm_context import (
LLMContext,
LLMContextMessage,
@@ -106,7 +106,7 @@ class OpenAIResponsesLLMAdapter(BaseLLMAdapter[OpenAIResponsesLLMInvocationParam
return params
def to_provider_tools_format(self, tools_schema: ToolsSchema) -> List[FunctionToolParam]:
def to_provider_tools_format(self, tools_schema: ToolsSchema) -> List[Any]:
"""Convert function schemas to Responses API function tool format.
Args:
@@ -128,7 +128,10 @@ class OpenAIResponsesLLMAdapter(BaseLLMAdapter[OpenAIResponsesLLMInvocationParam
if "description" in d:
tool["description"] = d["description"]
result.append(tool)
return result
custom_openai_tools = []
if tools_schema.custom_tools:
custom_openai_tools = tools_schema.custom_tools.get(AdapterType.OPENAI, [])
return result + custom_openai_tools
def get_messages_for_logging(self, context: LLMContext) -> List[Dict[str, Any]]:
"""Get messages from context in a format ready for logging.

View File

@@ -15,6 +15,7 @@ from pipecat.adapters.services.bedrock_adapter import AWSBedrockLLMAdapter
from pipecat.adapters.services.gemini_adapter import GeminiLLMAdapter
from pipecat.adapters.services.open_ai_adapter import OpenAILLMAdapter
from pipecat.adapters.services.open_ai_realtime_adapter import OpenAIRealtimeLLMAdapter
from pipecat.adapters.services.open_ai_responses_adapter import OpenAIResponsesLLMAdapter
class TestFunctionAdapters(unittest.TestCase):
@@ -176,6 +177,133 @@ class TestFunctionAdapters(unittest.TestCase):
tools_def.custom_tools = {AdapterType.GEMINI: [search_tool]}
assert GeminiLLMAdapter().to_provider_tools_format(tools_def) == expected
def test_openai_adapter_with_custom_tools(self):
"""Test OpenAI adapter appends custom tools."""
tool_search = {"type": "tool_search"}
expected = [
ChatCompletionToolParam(
type="function",
function={
"name": "get_weather",
"description": "Get the weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city, e.g. San Francisco",
},
"format": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The temperature unit to use.",
},
},
"required": ["location", "format"],
},
},
),
tool_search,
]
tools_def = self.tools_def
tools_def.custom_tools = {AdapterType.OPENAI: [tool_search]}
assert OpenAILLMAdapter().to_provider_tools_format(tools_def) == expected
def test_openai_responses_adapter_with_custom_tools(self):
"""Test OpenAI Responses adapter appends custom tools."""
tool_search = {"type": "tool_search"}
expected = [
{
"type": "function",
"name": "get_weather",
"description": "Get the weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city, e.g. San Francisco",
},
"format": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The temperature unit to use.",
},
},
"required": ["location", "format"],
},
"strict": None,
},
tool_search,
]
tools_def = self.tools_def
tools_def.custom_tools = {AdapterType.OPENAI: [tool_search]}
assert OpenAIResponsesLLMAdapter().to_provider_tools_format(tools_def) == expected
def test_openai_responses_adapter(self):
"""Test OpenAI Responses adapter format transformation."""
expected = [
{
"type": "function",
"name": "get_weather",
"description": "Get the weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city, e.g. San Francisco",
},
"format": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The temperature unit to use.",
},
},
"required": ["location", "format"],
},
"strict": None,
}
]
assert OpenAIResponsesLLMAdapter().to_provider_tools_format(self.tools_def) == expected
def test_openai_realtime_adapter_with_custom_tools(self):
"""Test OpenAI Realtime adapter appends custom tools."""
tool_search = {"type": "tool_search"}
expected = [
{
"type": "function",
"name": "get_weather",
"description": "Get the weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city, e.g. San Francisco",
},
"format": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The temperature unit to use.",
},
},
"required": ["location", "format"],
},
},
tool_search,
]
tools_def = self.tools_def
tools_def.custom_tools = {AdapterType.OPENAI: [tool_search]}
assert OpenAIRealtimeLLMAdapter().to_provider_tools_format(tools_def) == expected
def test_openai_adapter_ignores_other_adapter_custom_tools(self):
"""Test that OpenAI adapter ignores custom tools for other adapters."""
tools_def = self.tools_def
tools_def.custom_tools = {AdapterType.GEMINI: [{"google_search": {}}]}
result = OpenAILLMAdapter().to_provider_tools_format(tools_def)
assert len(result) == 1
def test_bedrock_adapter(self):
"""Test AWS Bedrock adapter format transformation."""
expected = [