Implement runtime tool ID and display name mapping in DuplexPipeline. Enhance Assistants and ToolLibrary components to utilize new mappings for improved tool identification and display. Update DebugDrawer to reflect changes in tool display names during interactions.

This commit is contained in:
Xin Wang
2026-02-27 15:50:43 +08:00
parent 0f1165af64
commit b035e023c4
4 changed files with 126 additions and 8 deletions

View File

@@ -288,6 +288,8 @@ class DuplexPipeline:
self._runtime_tools: List[Any] = list(raw_default_tools)
self._runtime_tool_executor: Dict[str, str] = {}
self._runtime_tool_default_args: Dict[str, Dict[str, Any]] = {}
self._runtime_tool_id_map: Dict[str, str] = {}
self._runtime_tool_display_names: Dict[str, str] = {}
self._pending_tool_waiters: Dict[str, asyncio.Future] = {}
self._early_tool_results: Dict[str, Dict[str, Any]] = {}
self._completed_tool_call_ids: set[str] = set()
@@ -309,6 +311,8 @@ class DuplexPipeline:
self._runtime_tool_executor = self._resolved_tool_executor_map()
self._runtime_tool_default_args = self._resolved_tool_default_args_map()
self._runtime_tool_id_map = self._resolved_tool_id_map()
self._runtime_tool_display_names = self._resolved_tool_display_name_map()
self._initial_greeting_emitted = False
if self._server_tool_executor is None:
@@ -411,10 +415,14 @@ class DuplexPipeline:
self._runtime_tools = tools_payload
self._runtime_tool_executor = self._resolved_tool_executor_map()
self._runtime_tool_default_args = self._resolved_tool_default_args_map()
self._runtime_tool_id_map = self._resolved_tool_id_map()
self._runtime_tool_display_names = self._resolved_tool_display_name_map()
elif "tools" in metadata:
self._runtime_tools = []
self._runtime_tool_executor = {}
self._runtime_tool_default_args = {}
self._runtime_tool_id_map = {}
self._runtime_tool_display_names = {}
if self.llm_service and hasattr(self.llm_service, "set_knowledge_config"):
self.llm_service.set_knowledge_config(self._resolved_knowledge_config())
@@ -1496,6 +1504,47 @@ class DuplexPipeline:
result[name] = dict(raw_defaults)
return result
def _resolved_tool_id_map(self) -> Dict[str, str]:
result: Dict[str, str] = {}
for item in self._runtime_tools:
if not isinstance(item, dict):
continue
fn = item.get("function")
if isinstance(fn, dict) and fn.get("name"):
alias = str(fn.get("name")).strip()
else:
alias = str(item.get("name") or "").strip()
if not alias:
continue
tool_id = str(item.get("toolId") or item.get("tool_id") or alias).strip()
if tool_id:
result[alias] = tool_id
return result
def _resolved_tool_display_name_map(self) -> Dict[str, str]:
result: Dict[str, str] = {}
for item in self._runtime_tools:
if not isinstance(item, dict):
continue
fn = item.get("function")
if isinstance(fn, dict) and fn.get("name"):
name = str(fn.get("name")).strip()
else:
name = str(item.get("name") or "").strip()
if not name:
continue
display_name = str(
item.get("displayName")
or item.get("display_name")
or name
).strip()
if display_name:
result[name] = display_name
tool_id = str(item.get("toolId") or item.get("tool_id") or "").strip()
if tool_id:
result[tool_id] = display_name
return result
def _resolved_tool_allowlist(self) -> List[str]:
names: set[str] = set()
for item in self._runtime_tools:
@@ -1519,6 +1568,12 @@ class DuplexPipeline:
return str(fn.get("name") or "").strip()
return ""
def _tool_id_for_name(self, tool_name: str) -> str:
return str(self._runtime_tool_id_map.get(tool_name) or tool_name).strip()
def _tool_display_name(self, tool_name: str) -> str:
return str(self._runtime_tool_display_names.get(tool_name) or tool_name).strip()
def _tool_executor(self, tool_call: Dict[str, Any]) -> str:
name = self._tool_name(tool_call)
if name and name in self._runtime_tool_executor:
@@ -1556,6 +1611,7 @@ class DuplexPipeline:
status_message = str(status.get("message") or "") if status else ""
tool_call_id = str(result.get("tool_call_id") or result.get("id") or "")
tool_name = str(result.get("name") or "unknown_tool")
tool_display_name = self._tool_display_name(tool_name) or tool_name
ok = bool(200 <= status_code < 300)
retryable = status_code >= 500 or status_code in {429, 408}
error: Optional[Dict[str, Any]] = None
@@ -1568,6 +1624,7 @@ class DuplexPipeline:
return {
"tool_call_id": tool_call_id,
"tool_name": tool_name,
"tool_display_name": tool_display_name,
"ok": ok,
"error": error,
"status": {"code": status_code, "message": status_message},
@@ -1575,6 +1632,7 @@ class DuplexPipeline:
async def _emit_tool_result(self, result: Dict[str, Any], source: str) -> None:
tool_name = str(result.get("name") or "unknown_tool")
tool_display_name = self._tool_display_name(tool_name) or tool_name
call_id = str(result.get("tool_call_id") or result.get("id") or "")
status = result.get("status") if isinstance(result.get("status"), dict) else {}
status_code = int(status.get("code") or 0) if status else 0
@@ -1592,6 +1650,7 @@ class DuplexPipeline:
source=source,
tool_call_id=normalized["tool_call_id"],
tool_name=normalized["tool_name"],
tool_display_name=normalized["tool_display_name"],
ok=normalized["ok"],
error=normalized["error"],
result=result,
@@ -1733,6 +1792,8 @@ class DuplexPipeline:
enriched_tool_call = dict(tool_call)
enriched_tool_call["executor"] = executor
tool_name = self._tool_name(enriched_tool_call) or "unknown_tool"
tool_id = self._tool_id_for_name(tool_name)
tool_display_name = self._tool_display_name(tool_name) or tool_name
call_id = str(enriched_tool_call.get("id") or "").strip()
fn_payload = (
dict(enriched_tool_call.get("function"))
@@ -1764,6 +1825,8 @@ class DuplexPipeline:
trackId=self.track_audio_out,
tool_call_id=call_id,
tool_name=tool_name,
tool_id=tool_id,
tool_display_name=tool_display_name,
arguments=tool_arguments,
executor=executor,
timeout_ms=int(self._TOOL_WAIT_TIMEOUT_SECONDS * 1000),
@@ -1881,6 +1944,7 @@ class DuplexPipeline:
continue
executor = str(call.get("executor") or "server").strip().lower()
tool_name = self._tool_name(call) or "unknown_tool"
tool_id = self._tool_id_for_name(tool_name)
logger.info(f"[Tool] execute start name={tool_name} call_id={call_id} executor={executor}")
if executor == "client":
result = await self._wait_for_single_tool_result(call_id)
@@ -1888,9 +1952,18 @@ class DuplexPipeline:
tool_results.append(result)
continue
call_for_executor = dict(call)
fn_for_executor = (
dict(call_for_executor.get("function"))
if isinstance(call_for_executor.get("function"), dict)
else None
)
if isinstance(fn_for_executor, dict):
fn_for_executor["name"] = tool_id
call_for_executor["function"] = fn_for_executor
try:
result = await asyncio.wait_for(
self._server_tool_executor(call),
self._server_tool_executor(call_for_executor),
timeout=self._SERVER_TOOL_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError: