Add manual opener tool calls to Assistant model and API
- Introduced `manual_opener_tool_calls` field in the Assistant model to support custom tool calls. - Updated AssistantBase and AssistantUpdate schemas to include the new field. - Implemented normalization and migration logic for handling manual opener tool calls in the API. - Enhanced runtime metadata to include manual opener tool calls in responses. - Updated tests to validate the new functionality and ensure proper handling of tool calls. - Refactored tool ID normalization to support legacy tool names for backward compatibility.
This commit is contained in:
@@ -146,8 +146,8 @@ class DuplexPipeline:
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
"voice_message_prompt": {
|
||||
"name": "voice_message_prompt",
|
||||
"voice_msg_prompt": {
|
||||
"name": "voice_msg_prompt",
|
||||
"description": "Speak a message prompt on client side",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
@@ -238,11 +238,21 @@ class DuplexPipeline:
|
||||
"turn_off_camera",
|
||||
"increase_volume",
|
||||
"decrease_volume",
|
||||
"voice_message_prompt",
|
||||
"voice_msg_prompt",
|
||||
"text_msg_prompt",
|
||||
"voice_choice_prompt",
|
||||
"text_choice_prompt",
|
||||
})
|
||||
_TOOL_NAME_ALIASES = {
|
||||
"voice_message_prompt": "voice_msg_prompt",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _normalize_tool_name(cls, raw_name: Any) -> str:
|
||||
name = str(raw_name or "").strip()
|
||||
if not name:
|
||||
return ""
|
||||
return cls._TOOL_NAME_ALIASES.get(name, name)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -369,6 +379,7 @@ class DuplexPipeline:
|
||||
self._runtime_first_turn_mode: str = "bot_first"
|
||||
self._runtime_greeting: Optional[str] = None
|
||||
self._runtime_generated_opener_enabled: Optional[bool] = None
|
||||
self._runtime_manual_opener_tool_calls: List[Any] = []
|
||||
self._runtime_opener_audio: Dict[str, Any] = {}
|
||||
self._runtime_barge_in_enabled: Optional[bool] = None
|
||||
self._runtime_barge_in_min_duration_ms: Optional[int] = None
|
||||
@@ -463,6 +474,9 @@ class DuplexPipeline:
|
||||
generated_opener_flag = self._coerce_bool(metadata.get("generatedOpenerEnabled"))
|
||||
if generated_opener_flag is not None:
|
||||
self._runtime_generated_opener_enabled = generated_opener_flag
|
||||
if "manualOpenerToolCalls" in metadata:
|
||||
manual_calls = metadata.get("manualOpenerToolCalls")
|
||||
self._runtime_manual_opener_tool_calls = manual_calls if isinstance(manual_calls, list) else []
|
||||
|
||||
services = metadata.get("services") or {}
|
||||
if isinstance(services, dict):
|
||||
@@ -571,6 +585,10 @@ class DuplexPipeline:
|
||||
"tools": {
|
||||
"allowlist": self._resolved_tool_allowlist(),
|
||||
},
|
||||
"opener": {
|
||||
"generated": self._generated_opener_enabled(),
|
||||
"manualToolCallCount": len(self._resolved_manual_opener_tool_calls()),
|
||||
},
|
||||
"tracks": {
|
||||
"audio_in": self.track_audio_in,
|
||||
"audio_out": self.track_audio_out,
|
||||
@@ -965,6 +983,11 @@ class DuplexPipeline:
|
||||
logger.info("Initial generated opener started with tool-calling path")
|
||||
return
|
||||
|
||||
if not self._generated_opener_enabled() and self._resolved_manual_opener_tool_calls():
|
||||
self._start_turn()
|
||||
self._start_response()
|
||||
await self._execute_manual_opener_tool_calls()
|
||||
|
||||
greeting_to_speak = self.conversation.greeting
|
||||
if self._generated_opener_enabled():
|
||||
generated_greeting = await self._generate_runtime_greeting()
|
||||
@@ -975,8 +998,10 @@ class DuplexPipeline:
|
||||
if not greeting_to_speak:
|
||||
return
|
||||
|
||||
self._start_turn()
|
||||
self._start_response()
|
||||
if not self._current_turn_id:
|
||||
self._start_turn()
|
||||
if not self._current_response_id:
|
||||
self._start_response()
|
||||
await self._send_event(
|
||||
ev(
|
||||
"assistant.response.final",
|
||||
@@ -1551,7 +1576,7 @@ class DuplexPipeline:
|
||||
seen: set[str] = set()
|
||||
for item in self._runtime_tools:
|
||||
if isinstance(item, str):
|
||||
tool_name = item.strip()
|
||||
tool_name = self._normalize_tool_name(item)
|
||||
if not tool_name or tool_name in seen:
|
||||
continue
|
||||
seen.add(tool_name)
|
||||
@@ -1585,7 +1610,7 @@ class DuplexPipeline:
|
||||
|
||||
fn = item.get("function")
|
||||
if isinstance(fn, dict) and fn.get("name"):
|
||||
fn_name = str(fn.get("name")).strip()
|
||||
fn_name = self._normalize_tool_name(fn.get("name"))
|
||||
if not fn_name or fn_name in seen:
|
||||
continue
|
||||
seen.add(fn_name)
|
||||
@@ -1602,7 +1627,7 @@ class DuplexPipeline:
|
||||
continue
|
||||
|
||||
if item.get("name"):
|
||||
item_name = str(item.get("name")).strip()
|
||||
item_name = self._normalize_tool_name(item.get("name"))
|
||||
if not item_name or item_name in seen:
|
||||
continue
|
||||
seen.add(item_name)
|
||||
@@ -1622,7 +1647,7 @@ class DuplexPipeline:
|
||||
result: Dict[str, str] = {}
|
||||
for item in self._runtime_tools:
|
||||
if isinstance(item, str):
|
||||
name = item.strip()
|
||||
name = self._normalize_tool_name(item)
|
||||
if name in self._DEFAULT_CLIENT_EXECUTORS:
|
||||
result[name] = "client"
|
||||
continue
|
||||
@@ -1630,9 +1655,9 @@ class DuplexPipeline:
|
||||
continue
|
||||
fn = item.get("function")
|
||||
if isinstance(fn, dict) and fn.get("name"):
|
||||
name = str(fn.get("name"))
|
||||
name = self._normalize_tool_name(fn.get("name"))
|
||||
else:
|
||||
name = str(item.get("name") or "").strip()
|
||||
name = self._normalize_tool_name(item.get("name"))
|
||||
if not name:
|
||||
continue
|
||||
executor = str(item.get("executor") or item.get("run_on") or "").strip().lower()
|
||||
@@ -1647,9 +1672,9 @@ class DuplexPipeline:
|
||||
continue
|
||||
fn = item.get("function")
|
||||
if isinstance(fn, dict) and fn.get("name"):
|
||||
name = str(fn.get("name")).strip()
|
||||
name = self._normalize_tool_name(fn.get("name"))
|
||||
else:
|
||||
name = str(item.get("name") or "").strip()
|
||||
name = self._normalize_tool_name(item.get("name"))
|
||||
if not name:
|
||||
continue
|
||||
raw_defaults = item.get("defaultArgs")
|
||||
@@ -1666,9 +1691,9 @@ class DuplexPipeline:
|
||||
continue
|
||||
fn = item.get("function")
|
||||
if isinstance(fn, dict) and fn.get("name"):
|
||||
name = str(fn.get("name")).strip()
|
||||
name = self._normalize_tool_name(fn.get("name"))
|
||||
else:
|
||||
name = str(item.get("name") or "").strip()
|
||||
name = self._normalize_tool_name(item.get("name"))
|
||||
if not name:
|
||||
continue
|
||||
raw_wait = item.get("waitForResponse")
|
||||
@@ -1685,12 +1710,12 @@ class DuplexPipeline:
|
||||
continue
|
||||
fn = item.get("function")
|
||||
if isinstance(fn, dict) and fn.get("name"):
|
||||
alias = str(fn.get("name")).strip()
|
||||
alias = self._normalize_tool_name(fn.get("name"))
|
||||
else:
|
||||
alias = str(item.get("name") or "").strip()
|
||||
alias = self._normalize_tool_name(item.get("name"))
|
||||
if not alias:
|
||||
continue
|
||||
tool_id = str(item.get("toolId") or item.get("tool_id") or alias).strip()
|
||||
tool_id = self._normalize_tool_name(item.get("toolId") or item.get("tool_id") or alias)
|
||||
if tool_id:
|
||||
result[alias] = tool_id
|
||||
return result
|
||||
@@ -1702,9 +1727,9 @@ class DuplexPipeline:
|
||||
continue
|
||||
fn = item.get("function")
|
||||
if isinstance(fn, dict) and fn.get("name"):
|
||||
name = str(fn.get("name")).strip()
|
||||
name = self._normalize_tool_name(fn.get("name"))
|
||||
else:
|
||||
name = str(item.get("name") or "").strip()
|
||||
name = self._normalize_tool_name(item.get("name"))
|
||||
if not name:
|
||||
continue
|
||||
display_name = str(
|
||||
@@ -1714,7 +1739,7 @@ class DuplexPipeline:
|
||||
).strip()
|
||||
if display_name:
|
||||
result[name] = display_name
|
||||
tool_id = str(item.get("toolId") or item.get("tool_id") or "").strip()
|
||||
tool_id = self._normalize_tool_name(item.get("toolId") or item.get("tool_id") or "")
|
||||
if tool_id:
|
||||
result[tool_id] = display_name
|
||||
return result
|
||||
@@ -1723,7 +1748,7 @@ class DuplexPipeline:
|
||||
names: set[str] = set()
|
||||
for item in self._runtime_tools:
|
||||
if isinstance(item, str):
|
||||
name = item.strip()
|
||||
name = self._normalize_tool_name(item)
|
||||
if name:
|
||||
names.add(name)
|
||||
continue
|
||||
@@ -1731,25 +1756,57 @@ class DuplexPipeline:
|
||||
continue
|
||||
fn = item.get("function")
|
||||
if isinstance(fn, dict) and fn.get("name"):
|
||||
names.add(str(fn.get("name")).strip())
|
||||
names.add(self._normalize_tool_name(fn.get("name")))
|
||||
elif item.get("name"):
|
||||
names.add(str(item.get("name")).strip())
|
||||
names.add(self._normalize_tool_name(item.get("name")))
|
||||
return sorted([name for name in names if name])
|
||||
|
||||
def _resolved_manual_opener_tool_calls(self) -> List[Dict[str, Any]]:
|
||||
result: List[Dict[str, Any]] = []
|
||||
for item in self._runtime_manual_opener_tool_calls:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
tool_name = self._normalize_tool_name(str(
|
||||
item.get("toolName")
|
||||
or item.get("tool_name")
|
||||
or item.get("name")
|
||||
or ""
|
||||
).strip())
|
||||
if not tool_name:
|
||||
continue
|
||||
args_raw = item.get("arguments")
|
||||
args: Dict[str, Any] = {}
|
||||
if isinstance(args_raw, dict):
|
||||
args = dict(args_raw)
|
||||
elif isinstance(args_raw, str):
|
||||
text_value = args_raw.strip()
|
||||
if text_value:
|
||||
try:
|
||||
parsed = json.loads(text_value)
|
||||
if isinstance(parsed, dict):
|
||||
args = parsed
|
||||
except Exception:
|
||||
logger.warning(f"[OpenerTool] ignore invalid JSON args for tool={tool_name}")
|
||||
result.append({"toolName": tool_name, "arguments": args})
|
||||
return result[:8]
|
||||
|
||||
def _tool_name(self, tool_call: Dict[str, Any]) -> str:
|
||||
fn = tool_call.get("function")
|
||||
if isinstance(fn, dict):
|
||||
return str(fn.get("name") or "").strip()
|
||||
return self._normalize_tool_name(fn.get("name"))
|
||||
return ""
|
||||
|
||||
def _tool_id_for_name(self, tool_name: str) -> str:
|
||||
return str(self._runtime_tool_id_map.get(tool_name) or tool_name).strip()
|
||||
normalized = self._normalize_tool_name(tool_name)
|
||||
return self._normalize_tool_name(self._runtime_tool_id_map.get(normalized) or normalized)
|
||||
|
||||
def _tool_display_name(self, tool_name: str) -> str:
|
||||
return str(self._runtime_tool_display_names.get(tool_name) or tool_name).strip()
|
||||
normalized = self._normalize_tool_name(tool_name)
|
||||
return str(self._runtime_tool_display_names.get(normalized) or normalized).strip()
|
||||
|
||||
def _tool_wait_for_response(self, tool_name: str) -> bool:
|
||||
return bool(self._runtime_tool_wait_for_response.get(tool_name, False))
|
||||
normalized = self._normalize_tool_name(tool_name)
|
||||
return bool(self._runtime_tool_wait_for_response.get(normalized, False))
|
||||
|
||||
def _tool_executor(self, tool_call: Dict[str, Any]) -> str:
|
||||
name = self._tool_name(tool_call)
|
||||
@@ -1774,7 +1831,8 @@ class DuplexPipeline:
|
||||
return {}
|
||||
|
||||
def _apply_tool_default_args(self, tool_name: str, args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
defaults = self._runtime_tool_default_args.get(tool_name)
|
||||
normalized_tool_name = self._normalize_tool_name(tool_name)
|
||||
defaults = self._runtime_tool_default_args.get(normalized_tool_name)
|
||||
if not isinstance(defaults, dict) or not defaults:
|
||||
return args
|
||||
merged = dict(defaults)
|
||||
@@ -1782,6 +1840,84 @@ class DuplexPipeline:
|
||||
merged.update(args)
|
||||
return merged
|
||||
|
||||
async def _execute_manual_opener_tool_calls(self) -> None:
|
||||
calls = self._resolved_manual_opener_tool_calls()
|
||||
if not calls:
|
||||
return
|
||||
|
||||
for call in calls:
|
||||
tool_name = str(call.get("toolName") or "").strip()
|
||||
if not tool_name:
|
||||
continue
|
||||
tool_id = self._tool_id_for_name(tool_name)
|
||||
tool_display_name = self._tool_display_name(tool_name) or tool_name
|
||||
tool_arguments = call.get("arguments") if isinstance(call.get("arguments"), dict) else {}
|
||||
merged_tool_arguments = self._apply_tool_default_args(tool_name, tool_arguments)
|
||||
call_id = f"call_opener_{uuid.uuid4().hex[:10]}"
|
||||
wait_for_response = self._tool_wait_for_response(tool_name)
|
||||
tool_call = {
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_name,
|
||||
"arguments": json.dumps(merged_tool_arguments, ensure_ascii=False),
|
||||
},
|
||||
}
|
||||
executor = self._tool_executor(tool_call)
|
||||
|
||||
await self._send_event(
|
||||
{
|
||||
**ev(
|
||||
"assistant.tool_call",
|
||||
trackId=self.track_audio_out,
|
||||
tool_call_id=call_id,
|
||||
tool_name=tool_name,
|
||||
tool_id=tool_id,
|
||||
tool_display_name=tool_display_name,
|
||||
wait_for_response=wait_for_response,
|
||||
arguments=merged_tool_arguments,
|
||||
executor=executor,
|
||||
timeout_ms=int(self._TOOL_WAIT_TIMEOUT_SECONDS * 1000),
|
||||
tool_call={**tool_call, "executor": executor, "wait_for_response": wait_for_response},
|
||||
)
|
||||
},
|
||||
priority=22,
|
||||
)
|
||||
logger.info(
|
||||
f"[OpenerTool] execute name={tool_name} call_id={call_id} executor={executor} "
|
||||
f"wait_for_response={wait_for_response}"
|
||||
)
|
||||
|
||||
if executor == "client":
|
||||
self._pending_client_tool_call_ids.add(call_id)
|
||||
if wait_for_response:
|
||||
result = await self._wait_for_single_tool_result(call_id)
|
||||
await self._emit_tool_result(result, source="client")
|
||||
continue
|
||||
|
||||
call_for_executor = dict(tool_call)
|
||||
fn_for_executor = (
|
||||
dict(call_for_executor.get("function"))
|
||||
if isinstance(call_for_executor.get("function"), dict)
|
||||
else None
|
||||
)
|
||||
if isinstance(fn_for_executor, dict):
|
||||
fn_for_executor["name"] = tool_id
|
||||
call_for_executor["function"] = fn_for_executor
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
self._server_tool_executor(call_for_executor),
|
||||
timeout=self._SERVER_TOOL_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
result = {
|
||||
"tool_call_id": call_id,
|
||||
"name": tool_name,
|
||||
"output": {"message": "server tool timeout"},
|
||||
"status": {"code": 504, "message": "server_tool_timeout"},
|
||||
}
|
||||
await self._emit_tool_result(result, source="server")
|
||||
|
||||
def _normalize_tool_result(self, result: Dict[str, Any]) -> Dict[str, Any]:
|
||||
status = result.get("status") if isinstance(result.get("status"), dict) else {}
|
||||
status_code = int(status.get("code") or 0) if status else 0
|
||||
|
||||
Reference in New Issue
Block a user