Add manual opener tool calls to Assistant model and API

- Introduced `manual_opener_tool_calls` field in the Assistant model to support custom tool calls.
- Updated AssistantBase and AssistantUpdate schemas to include the new field.
- Implemented normalization and migration logic for handling manual opener tool calls in the API.
- Enhanced runtime metadata to include manual opener tool calls in responses.
- Updated tests to validate the new functionality and ensure proper handling of tool calls.
- Refactored tool ID normalization to support legacy tool names for backward compatibility.
This commit is contained in:
Xin Wang
2026-03-02 12:34:42 +08:00
parent b5cdb76e52
commit 00b88c5afa
14 changed files with 806 additions and 74 deletions

View File

@@ -146,8 +146,8 @@ class DuplexPipeline:
"required": [],
},
},
"voice_message_prompt": {
"name": "voice_message_prompt",
"voice_msg_prompt": {
"name": "voice_msg_prompt",
"description": "Speak a message prompt on client side",
"parameters": {
"type": "object",
@@ -238,11 +238,21 @@ class DuplexPipeline:
"turn_off_camera",
"increase_volume",
"decrease_volume",
"voice_message_prompt",
"voice_msg_prompt",
"text_msg_prompt",
"voice_choice_prompt",
"text_choice_prompt",
})
_TOOL_NAME_ALIASES = {
"voice_message_prompt": "voice_msg_prompt",
}
@classmethod
def _normalize_tool_name(cls, raw_name: Any) -> str:
name = str(raw_name or "").strip()
if not name:
return ""
return cls._TOOL_NAME_ALIASES.get(name, name)
def __init__(
self,
@@ -369,6 +379,7 @@ class DuplexPipeline:
self._runtime_first_turn_mode: str = "bot_first"
self._runtime_greeting: Optional[str] = None
self._runtime_generated_opener_enabled: Optional[bool] = None
self._runtime_manual_opener_tool_calls: List[Any] = []
self._runtime_opener_audio: Dict[str, Any] = {}
self._runtime_barge_in_enabled: Optional[bool] = None
self._runtime_barge_in_min_duration_ms: Optional[int] = None
@@ -463,6 +474,9 @@ class DuplexPipeline:
generated_opener_flag = self._coerce_bool(metadata.get("generatedOpenerEnabled"))
if generated_opener_flag is not None:
self._runtime_generated_opener_enabled = generated_opener_flag
if "manualOpenerToolCalls" in metadata:
manual_calls = metadata.get("manualOpenerToolCalls")
self._runtime_manual_opener_tool_calls = manual_calls if isinstance(manual_calls, list) else []
services = metadata.get("services") or {}
if isinstance(services, dict):
@@ -571,6 +585,10 @@ class DuplexPipeline:
"tools": {
"allowlist": self._resolved_tool_allowlist(),
},
"opener": {
"generated": self._generated_opener_enabled(),
"manualToolCallCount": len(self._resolved_manual_opener_tool_calls()),
},
"tracks": {
"audio_in": self.track_audio_in,
"audio_out": self.track_audio_out,
@@ -965,6 +983,11 @@ class DuplexPipeline:
logger.info("Initial generated opener started with tool-calling path")
return
if not self._generated_opener_enabled() and self._resolved_manual_opener_tool_calls():
self._start_turn()
self._start_response()
await self._execute_manual_opener_tool_calls()
greeting_to_speak = self.conversation.greeting
if self._generated_opener_enabled():
generated_greeting = await self._generate_runtime_greeting()
@@ -975,8 +998,10 @@ class DuplexPipeline:
if not greeting_to_speak:
return
self._start_turn()
self._start_response()
if not self._current_turn_id:
self._start_turn()
if not self._current_response_id:
self._start_response()
await self._send_event(
ev(
"assistant.response.final",
@@ -1551,7 +1576,7 @@ class DuplexPipeline:
seen: set[str] = set()
for item in self._runtime_tools:
if isinstance(item, str):
tool_name = item.strip()
tool_name = self._normalize_tool_name(item)
if not tool_name or tool_name in seen:
continue
seen.add(tool_name)
@@ -1585,7 +1610,7 @@ class DuplexPipeline:
fn = item.get("function")
if isinstance(fn, dict) and fn.get("name"):
fn_name = str(fn.get("name")).strip()
fn_name = self._normalize_tool_name(fn.get("name"))
if not fn_name or fn_name in seen:
continue
seen.add(fn_name)
@@ -1602,7 +1627,7 @@ class DuplexPipeline:
continue
if item.get("name"):
item_name = str(item.get("name")).strip()
item_name = self._normalize_tool_name(item.get("name"))
if not item_name or item_name in seen:
continue
seen.add(item_name)
@@ -1622,7 +1647,7 @@ class DuplexPipeline:
result: Dict[str, str] = {}
for item in self._runtime_tools:
if isinstance(item, str):
name = item.strip()
name = self._normalize_tool_name(item)
if name in self._DEFAULT_CLIENT_EXECUTORS:
result[name] = "client"
continue
@@ -1630,9 +1655,9 @@ class DuplexPipeline:
continue
fn = item.get("function")
if isinstance(fn, dict) and fn.get("name"):
name = str(fn.get("name"))
name = self._normalize_tool_name(fn.get("name"))
else:
name = str(item.get("name") or "").strip()
name = self._normalize_tool_name(item.get("name"))
if not name:
continue
executor = str(item.get("executor") or item.get("run_on") or "").strip().lower()
@@ -1647,9 +1672,9 @@ class DuplexPipeline:
continue
fn = item.get("function")
if isinstance(fn, dict) and fn.get("name"):
name = str(fn.get("name")).strip()
name = self._normalize_tool_name(fn.get("name"))
else:
name = str(item.get("name") or "").strip()
name = self._normalize_tool_name(item.get("name"))
if not name:
continue
raw_defaults = item.get("defaultArgs")
@@ -1666,9 +1691,9 @@ class DuplexPipeline:
continue
fn = item.get("function")
if isinstance(fn, dict) and fn.get("name"):
name = str(fn.get("name")).strip()
name = self._normalize_tool_name(fn.get("name"))
else:
name = str(item.get("name") or "").strip()
name = self._normalize_tool_name(item.get("name"))
if not name:
continue
raw_wait = item.get("waitForResponse")
@@ -1685,12 +1710,12 @@ class DuplexPipeline:
continue
fn = item.get("function")
if isinstance(fn, dict) and fn.get("name"):
alias = str(fn.get("name")).strip()
alias = self._normalize_tool_name(fn.get("name"))
else:
alias = str(item.get("name") or "").strip()
alias = self._normalize_tool_name(item.get("name"))
if not alias:
continue
tool_id = str(item.get("toolId") or item.get("tool_id") or alias).strip()
tool_id = self._normalize_tool_name(item.get("toolId") or item.get("tool_id") or alias)
if tool_id:
result[alias] = tool_id
return result
@@ -1702,9 +1727,9 @@ class DuplexPipeline:
continue
fn = item.get("function")
if isinstance(fn, dict) and fn.get("name"):
name = str(fn.get("name")).strip()
name = self._normalize_tool_name(fn.get("name"))
else:
name = str(item.get("name") or "").strip()
name = self._normalize_tool_name(item.get("name"))
if not name:
continue
display_name = str(
@@ -1714,7 +1739,7 @@ class DuplexPipeline:
).strip()
if display_name:
result[name] = display_name
tool_id = str(item.get("toolId") or item.get("tool_id") or "").strip()
tool_id = self._normalize_tool_name(item.get("toolId") or item.get("tool_id") or "")
if tool_id:
result[tool_id] = display_name
return result
@@ -1723,7 +1748,7 @@ class DuplexPipeline:
names: set[str] = set()
for item in self._runtime_tools:
if isinstance(item, str):
name = item.strip()
name = self._normalize_tool_name(item)
if name:
names.add(name)
continue
@@ -1731,25 +1756,57 @@ class DuplexPipeline:
continue
fn = item.get("function")
if isinstance(fn, dict) and fn.get("name"):
names.add(str(fn.get("name")).strip())
names.add(self._normalize_tool_name(fn.get("name")))
elif item.get("name"):
names.add(str(item.get("name")).strip())
names.add(self._normalize_tool_name(item.get("name")))
return sorted([name for name in names if name])
def _resolved_manual_opener_tool_calls(self) -> List[Dict[str, Any]]:
result: List[Dict[str, Any]] = []
for item in self._runtime_manual_opener_tool_calls:
if not isinstance(item, dict):
continue
tool_name = self._normalize_tool_name(str(
item.get("toolName")
or item.get("tool_name")
or item.get("name")
or ""
).strip())
if not tool_name:
continue
args_raw = item.get("arguments")
args: Dict[str, Any] = {}
if isinstance(args_raw, dict):
args = dict(args_raw)
elif isinstance(args_raw, str):
text_value = args_raw.strip()
if text_value:
try:
parsed = json.loads(text_value)
if isinstance(parsed, dict):
args = parsed
except Exception:
logger.warning(f"[OpenerTool] ignore invalid JSON args for tool={tool_name}")
result.append({"toolName": tool_name, "arguments": args})
return result[:8]
def _tool_name(self, tool_call: Dict[str, Any]) -> str:
fn = tool_call.get("function")
if isinstance(fn, dict):
return str(fn.get("name") or "").strip()
return self._normalize_tool_name(fn.get("name"))
return ""
def _tool_id_for_name(self, tool_name: str) -> str:
return str(self._runtime_tool_id_map.get(tool_name) or tool_name).strip()
normalized = self._normalize_tool_name(tool_name)
return self._normalize_tool_name(self._runtime_tool_id_map.get(normalized) or normalized)
def _tool_display_name(self, tool_name: str) -> str:
return str(self._runtime_tool_display_names.get(tool_name) or tool_name).strip()
normalized = self._normalize_tool_name(tool_name)
return str(self._runtime_tool_display_names.get(normalized) or normalized).strip()
def _tool_wait_for_response(self, tool_name: str) -> bool:
return bool(self._runtime_tool_wait_for_response.get(tool_name, False))
normalized = self._normalize_tool_name(tool_name)
return bool(self._runtime_tool_wait_for_response.get(normalized, False))
def _tool_executor(self, tool_call: Dict[str, Any]) -> str:
name = self._tool_name(tool_call)
@@ -1774,7 +1831,8 @@ class DuplexPipeline:
return {}
def _apply_tool_default_args(self, tool_name: str, args: Dict[str, Any]) -> Dict[str, Any]:
defaults = self._runtime_tool_default_args.get(tool_name)
normalized_tool_name = self._normalize_tool_name(tool_name)
defaults = self._runtime_tool_default_args.get(normalized_tool_name)
if not isinstance(defaults, dict) or not defaults:
return args
merged = dict(defaults)
@@ -1782,6 +1840,84 @@ class DuplexPipeline:
merged.update(args)
return merged
async def _execute_manual_opener_tool_calls(self) -> None:
calls = self._resolved_manual_opener_tool_calls()
if not calls:
return
for call in calls:
tool_name = str(call.get("toolName") or "").strip()
if not tool_name:
continue
tool_id = self._tool_id_for_name(tool_name)
tool_display_name = self._tool_display_name(tool_name) or tool_name
tool_arguments = call.get("arguments") if isinstance(call.get("arguments"), dict) else {}
merged_tool_arguments = self._apply_tool_default_args(tool_name, tool_arguments)
call_id = f"call_opener_{uuid.uuid4().hex[:10]}"
wait_for_response = self._tool_wait_for_response(tool_name)
tool_call = {
"id": call_id,
"type": "function",
"function": {
"name": tool_name,
"arguments": json.dumps(merged_tool_arguments, ensure_ascii=False),
},
}
executor = self._tool_executor(tool_call)
await self._send_event(
{
**ev(
"assistant.tool_call",
trackId=self.track_audio_out,
tool_call_id=call_id,
tool_name=tool_name,
tool_id=tool_id,
tool_display_name=tool_display_name,
wait_for_response=wait_for_response,
arguments=merged_tool_arguments,
executor=executor,
timeout_ms=int(self._TOOL_WAIT_TIMEOUT_SECONDS * 1000),
tool_call={**tool_call, "executor": executor, "wait_for_response": wait_for_response},
)
},
priority=22,
)
logger.info(
f"[OpenerTool] execute name={tool_name} call_id={call_id} executor={executor} "
f"wait_for_response={wait_for_response}"
)
if executor == "client":
self._pending_client_tool_call_ids.add(call_id)
if wait_for_response:
result = await self._wait_for_single_tool_result(call_id)
await self._emit_tool_result(result, source="client")
continue
call_for_executor = dict(tool_call)
fn_for_executor = (
dict(call_for_executor.get("function"))
if isinstance(call_for_executor.get("function"), dict)
else None
)
if isinstance(fn_for_executor, dict):
fn_for_executor["name"] = tool_id
call_for_executor["function"] = fn_for_executor
try:
result = await asyncio.wait_for(
self._server_tool_executor(call_for_executor),
timeout=self._SERVER_TOOL_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
result = {
"tool_call_id": call_id,
"name": tool_name,
"output": {"message": "server tool timeout"},
"status": {"code": 504, "message": "server_tool_timeout"},
}
await self._emit_tool_result(result, source="server")
def _normalize_tool_result(self, result: Dict[str, Any]) -> Dict[str, Any]:
status = result.get("status") if isinstance(result.get("status"), dict) else {}
status_code = int(status.get("code") or 0) if status else 0