Now we have server tools and client tools

This commit is contained in:
Xin Wang
2026-02-10 19:13:54 +08:00
parent 54eb48fb74
commit 6cac24918d
5 changed files with 257 additions and 82 deletions

View File

@@ -22,6 +22,7 @@ from loguru import logger
from app.config import settings
from core.conversation import ConversationManager, ConversationState
from core.events import get_event_bus
from core.tool_executor import execute_server_tool
from core.transports import BaseTransport
from models.ws_v1 import ev
from processors.eou import EouDetector
@@ -214,6 +215,7 @@ class DuplexPipeline:
self._runtime_knowledge: Dict[str, Any] = {}
self._runtime_knowledge_base_id: Optional[str] = None
self._runtime_tools: List[Any] = []
self._runtime_tool_executor: Dict[str, str] = {}
self._pending_tool_waiters: Dict[str, asyncio.Future] = {}
self._early_tool_results: Dict[str, Dict[str, Any]] = {}
self._completed_tool_call_ids: set[str] = set()
@@ -270,6 +272,10 @@ class DuplexPipeline:
tools_payload = metadata.get("tools")
if isinstance(tools_payload, list):
self._runtime_tools = tools_payload
self._runtime_tool_executor = self._resolved_tool_executor_map()
elif "tools" in metadata:
self._runtime_tools = []
self._runtime_tool_executor = {}
if self.llm_service and hasattr(self.llm_service, "set_knowledge_config"):
self.llm_service.set_knowledge_config(self._resolved_knowledge_config())
@@ -675,6 +681,10 @@ class DuplexPipeline:
fn = item.get("function")
if isinstance(fn, dict) and fn.get("name"):
fn_name = str(fn.get("name"))
executor = str(item.get("executor") or item.get("run_on") or "").strip().lower()
if executor in {"client", "server"}:
self._runtime_tool_executor[fn_name] = executor
schemas.append(
{
"type": "function",
@@ -688,6 +698,10 @@ class DuplexPipeline:
continue
if item.get("name"):
fn_name = str(item.get("name"))
executor = str(item.get("executor") or item.get("run_on") or "").strip().lower()
if executor in {"client", "server"}:
self._runtime_tool_executor[fn_name] = executor
schemas.append(
{
"type": "function",
@@ -700,6 +714,49 @@ class DuplexPipeline:
)
return schemas
def _resolved_tool_executor_map(self) -> Dict[str, str]:
result: Dict[str, str] = {}
for item in self._runtime_tools:
if not isinstance(item, dict):
continue
fn = item.get("function")
if isinstance(fn, dict) and fn.get("name"):
name = str(fn.get("name"))
else:
name = str(item.get("name") or "").strip()
if not name:
continue
executor = str(item.get("executor") or item.get("run_on") or "").strip().lower()
if executor in {"client", "server"}:
result[name] = executor
return result
def _tool_name(self, tool_call: Dict[str, Any]) -> str:
fn = tool_call.get("function")
if isinstance(fn, dict):
return str(fn.get("name") or "").strip()
return ""
def _tool_executor(self, tool_call: Dict[str, Any]) -> str:
name = self._tool_name(tool_call)
if name and name in self._runtime_tool_executor:
return self._runtime_tool_executor[name]
# Default to server execution unless explicitly marked as client.
return "server"
async def _emit_tool_result(self, result: Dict[str, Any], source: str) -> None:
    """Publish a tool execution result on the session's event stream.

    `source` labels who ran the tool (e.g. "client" or "server").
    """
    payload = ev(
        "assistant.tool_result",
        trackId=self.session_id,
        source=source,
        result=result,
    )
    # Same priority (22) as the assistant.tool_call events, so results keep
    # ordering alongside the calls they answer.
    await self._send_event({**payload}, priority=22)
async def handle_tool_call_results(self, results: List[Dict[str, Any]]) -> None:
"""Handle client tool execution results."""
if not isinstance(results, list):
@@ -807,13 +864,16 @@ class DuplexPipeline:
if not tool_call:
continue
allow_text_output = False
tool_calls.append(tool_call)
executor = self._tool_executor(tool_call)
enriched_tool_call = dict(tool_call)
enriched_tool_call["executor"] = executor
tool_calls.append(enriched_tool_call)
await self._send_event(
{
**ev(
"assistant.tool_call",
trackId=self.session_id,
tool_call=tool_call,
tool_call=enriched_tool_call,
)
},
priority=22,
@@ -917,7 +977,16 @@ class DuplexPipeline:
call_id = str(call.get("id") or "").strip()
if not call_id:
continue
tool_results.append(await self._wait_for_single_tool_result(call_id))
executor = str(call.get("executor") or "server").strip().lower()
if executor == "client":
result = await self._wait_for_single_tool_result(call_id)
await self._emit_tool_result(result, source="client")
tool_results.append(result)
continue
result = await execute_server_tool(call)
await self._emit_tool_result(result, source="server")
tool_results.append(result)
messages = [
*messages,
@@ -928,7 +997,7 @@ class DuplexPipeline:
LLMMessage(
role="system",
content=(
"Tool execution results were returned by the client. "
"Tool execution results are available. "
"Continue answering the user naturally using these results. "
"Do not request the same tool again in this turn.\n"
f"tool_calls={json.dumps(tool_calls, ensure_ascii=False)}\n"