Add parameter schema and defaults to ToolResource model and schemas. Implement runtime tool resolution in assistants and tools routers, ensuring proper handling of tool parameters. Update tests to validate new functionality and ensure correct integration of parameter handling in the API.
This commit is contained in:
@@ -287,6 +287,7 @@ class DuplexPipeline:
|
||||
raw_default_tools = settings.tools if isinstance(settings.tools, list) else []
|
||||
self._runtime_tools: List[Any] = list(raw_default_tools)
|
||||
self._runtime_tool_executor: Dict[str, str] = {}
|
||||
self._runtime_tool_default_args: Dict[str, Dict[str, Any]] = {}
|
||||
self._pending_tool_waiters: Dict[str, asyncio.Future] = {}
|
||||
self._early_tool_results: Dict[str, Dict[str, Any]] = {}
|
||||
self._completed_tool_call_ids: set[str] = set()
|
||||
@@ -307,6 +308,7 @@ class DuplexPipeline:
|
||||
self._last_llm_delta_emit_ms: float = 0.0
|
||||
|
||||
self._runtime_tool_executor = self._resolved_tool_executor_map()
|
||||
self._runtime_tool_default_args = self._resolved_tool_default_args_map()
|
||||
self._initial_greeting_emitted = False
|
||||
|
||||
if self._server_tool_executor is None:
|
||||
@@ -408,9 +410,11 @@ class DuplexPipeline:
|
||||
if isinstance(tools_payload, list):
|
||||
self._runtime_tools = tools_payload
|
||||
self._runtime_tool_executor = self._resolved_tool_executor_map()
|
||||
self._runtime_tool_default_args = self._resolved_tool_default_args_map()
|
||||
elif "tools" in metadata:
|
||||
self._runtime_tools = []
|
||||
self._runtime_tool_executor = {}
|
||||
self._runtime_tool_default_args = {}
|
||||
|
||||
if self.llm_service and hasattr(self.llm_service, "set_knowledge_config"):
|
||||
self.llm_service.set_knowledge_config(self._resolved_knowledge_config())
|
||||
@@ -1473,6 +1477,25 @@ class DuplexPipeline:
|
||||
result[name] = executor
|
||||
return result
|
||||
|
||||
def _resolved_tool_default_args_map(self) -> Dict[str, Dict[str, Any]]:
|
||||
result: Dict[str, Dict[str, Any]] = {}
|
||||
for item in self._runtime_tools:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
fn = item.get("function")
|
||||
if isinstance(fn, dict) and fn.get("name"):
|
||||
name = str(fn.get("name")).strip()
|
||||
else:
|
||||
name = str(item.get("name") or "").strip()
|
||||
if not name:
|
||||
continue
|
||||
raw_defaults = item.get("defaultArgs")
|
||||
if raw_defaults is None:
|
||||
raw_defaults = item.get("default_args")
|
||||
if isinstance(raw_defaults, dict):
|
||||
result[name] = dict(raw_defaults)
|
||||
return result
|
||||
|
||||
def _resolved_tool_allowlist(self) -> List[str]:
|
||||
names: set[str] = set()
|
||||
for item in self._runtime_tools:
|
||||
@@ -1518,6 +1541,15 @@ class DuplexPipeline:
|
||||
return {"raw": raw}
|
||||
return {}
|
||||
|
||||
def _apply_tool_default_args(self, tool_name: str, args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
defaults = self._runtime_tool_default_args.get(tool_name)
|
||||
if not isinstance(defaults, dict) or not defaults:
|
||||
return args
|
||||
merged = dict(defaults)
|
||||
if isinstance(args, dict):
|
||||
merged.update(args)
|
||||
return merged
|
||||
|
||||
def _normalize_tool_result(self, result: Dict[str, Any]) -> Dict[str, Any]:
|
||||
status = result.get("status") if isinstance(result.get("status"), dict) else {}
|
||||
status_code = int(status.get("code") or 0) if status else 0
|
||||
@@ -1702,15 +1734,27 @@ class DuplexPipeline:
|
||||
enriched_tool_call["executor"] = executor
|
||||
tool_name = self._tool_name(enriched_tool_call) or "unknown_tool"
|
||||
call_id = str(enriched_tool_call.get("id") or "").strip()
|
||||
fn_payload = enriched_tool_call.get("function")
|
||||
fn_payload = (
|
||||
dict(enriched_tool_call.get("function"))
|
||||
if isinstance(enriched_tool_call.get("function"), dict)
|
||||
else None
|
||||
)
|
||||
raw_args = str(fn_payload.get("arguments") or "") if isinstance(fn_payload, dict) else ""
|
||||
tool_arguments = self._tool_arguments(enriched_tool_call)
|
||||
merged_tool_arguments = self._apply_tool_default_args(tool_name, tool_arguments)
|
||||
try:
|
||||
merged_args_text = json.dumps(merged_tool_arguments, ensure_ascii=False)
|
||||
except Exception:
|
||||
merged_args_text = raw_args if raw_args else "{}"
|
||||
if isinstance(fn_payload, dict):
|
||||
fn_payload["arguments"] = merged_args_text
|
||||
enriched_tool_call["function"] = fn_payload
|
||||
args_preview = raw_args if len(raw_args) <= 160 else f"{raw_args[:160]}..."
|
||||
logger.info(
|
||||
f"[Tool] call requested name={tool_name} call_id={call_id} "
|
||||
f"executor={executor} args={args_preview}"
|
||||
f"executor={executor} args={args_preview} merged_args={merged_args_text}"
|
||||
)
|
||||
tool_calls.append(enriched_tool_call)
|
||||
tool_arguments = self._tool_arguments(enriched_tool_call)
|
||||
if executor == "client" and call_id:
|
||||
self._pending_client_tool_call_ids.add(call_id)
|
||||
await self._send_event(
|
||||
|
||||
@@ -187,6 +187,17 @@ async def execute_server_tool(
|
||||
tool_name = _extract_tool_name(tool_call)
|
||||
args = _extract_tool_args(tool_call)
|
||||
resource_fetcher = tool_resource_fetcher or fetch_tool_resource
|
||||
resource: Optional[Dict[str, Any]] = None
|
||||
if tool_name and tool_name not in {"calculator", "code_interpreter", "current_time"}:
|
||||
try:
|
||||
resource = await resource_fetcher(tool_name)
|
||||
except Exception:
|
||||
resource = None
|
||||
defaults = resource.get("parameter_defaults") if isinstance(resource, dict) else None
|
||||
if isinstance(defaults, dict) and defaults:
|
||||
merged_args = dict(defaults)
|
||||
merged_args.update(args)
|
||||
args = merged_args
|
||||
|
||||
if tool_name == "calculator":
|
||||
expression = str(args.get("expression") or "").strip()
|
||||
@@ -269,7 +280,6 @@ async def execute_server_tool(
|
||||
}
|
||||
|
||||
if tool_name and tool_name not in {"calculator", "code_interpreter", "current_time"}:
|
||||
resource = await resource_fetcher(tool_name)
|
||||
if resource and str(resource.get("category") or "") == "query":
|
||||
method = str(resource.get("http_method") or "GET").strip().upper()
|
||||
if method not in {"GET", "POST", "PUT", "PATCH", "DELETE"}:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pytest
|
||||
@@ -143,6 +144,64 @@ def test_pipeline_assigns_default_client_executor_for_system_string_tools(monkey
|
||||
assert pipeline._tool_executor(tool_call) == "client"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_applies_default_args_to_tool_call(monkeypatch):
|
||||
pipeline, _events = _build_pipeline(
|
||||
monkeypatch,
|
||||
[
|
||||
[
|
||||
LLMStreamEvent(
|
||||
type="tool_call",
|
||||
tool_call={
|
||||
"id": "call_defaults",
|
||||
"type": "function",
|
||||
"function": {"name": "weather", "arguments": "{}"},
|
||||
},
|
||||
),
|
||||
LLMStreamEvent(type="done"),
|
||||
],
|
||||
[LLMStreamEvent(type="done")],
|
||||
],
|
||||
)
|
||||
pipeline.apply_runtime_overrides(
|
||||
{
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"executor": "server",
|
||||
"defaultArgs": {"city": "Hangzhou", "unit": "c"},
|
||||
"function": {
|
||||
"name": "weather",
|
||||
"description": "Get weather",
|
||||
"parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
captured: Dict[str, Any] = {}
|
||||
|
||||
async def _server_exec(call: Dict[str, Any]) -> Dict[str, Any]:
|
||||
captured["call"] = call
|
||||
return {
|
||||
"tool_call_id": str(call.get("id") or ""),
|
||||
"name": "weather",
|
||||
"output": {"ok": True},
|
||||
"status": {"code": 200, "message": "ok"},
|
||||
}
|
||||
|
||||
monkeypatch.setattr(pipeline, "_server_tool_executor", _server_exec)
|
||||
await pipeline._handle_turn("weather?")
|
||||
|
||||
sent_call = captured.get("call")
|
||||
assert isinstance(sent_call, dict)
|
||||
args_raw = sent_call.get("function", {}).get("arguments")
|
||||
args = json.loads(args_raw) if isinstance(args_raw, str) else {}
|
||||
assert args.get("city") == "Hangzhou"
|
||||
assert args.get("unit") == "c"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ws_message_parses_tool_call_results():
|
||||
msg = parse_client_message(
|
||||
|
||||
Reference in New Issue
Block a user