Update code for tool calls

This commit is contained in:
Xin Wang
2026-02-10 16:28:20 +08:00
parent 539cf2fda2
commit 4b8da32787
7 changed files with 676 additions and 91 deletions

View File

@@ -12,8 +12,9 @@ event-driven design.
"""
import asyncio
import json
import time
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
from loguru import logger
@@ -26,7 +27,7 @@ from models.ws_v1 import ev
from processors.eou import EouDetector
from processors.vad import SileroVAD, VADProcessor
from services.asr import BufferedASRService
from services.base import BaseASRService, BaseLLMService, BaseTTSService
from services.base import BaseASRService, BaseLLMService, BaseTTSService, LLMMessage, LLMStreamEvent
from services.llm import MockLLMService, OpenAILLMService
from services.siliconflow_asr import SiliconFlowASRService
from services.siliconflow_tts import SiliconFlowTTSService
@@ -55,6 +56,60 @@ class DuplexPipeline:
_SENTENCE_TRAILING_CHARS = frozenset({"", "", "", ".", "!", "?", "", "~", "", "\n"})
_SENTENCE_CLOSERS = frozenset({'"', "'", "", "", ")", "]", "}", "", "", "", "", ""})
_MIN_SPLIT_SPOKEN_CHARS = 6
# Max seconds to block waiting for the client to return a tool_call result
# before giving up with a synthetic 504 payload.
_TOOL_WAIT_TIMEOUT_SECONDS = 15.0
# Built-in tool definitions keyed by tool name. A session's runtime tool list
# may reference one of these by bare string; the entry is then expanded into
# an OpenAI-style function-tool schema. Each value holds the function name,
# a human-readable description, and a JSON-Schema "parameters" object.
_DEFAULT_TOOL_SCHEMAS: Dict[str, Dict[str, Any]] = {
    "search": {
        "name": "search",
        "description": "Search the internet for recent information",
        "parameters": {
            "type": "object",
            "properties": {"query": {"type": "string"}},
            "required": ["query"],
        },
    },
    "calculator": {
        "name": "calculator",
        "description": "Evaluate a math expression",
        "parameters": {
            "type": "object",
            "properties": {"expression": {"type": "string"}},
            "required": ["expression"],
        },
    },
    "weather": {
        "name": "weather",
        "description": "Get weather by city name",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
    "translate": {
        "name": "translate",
        "description": "Translate text to target language",
        "parameters": {
            "type": "object",
            "properties": {
                "text": {"type": "string"},
                "target_lang": {"type": "string"},
            },
            "required": ["text", "target_lang"],
        },
    },
    "knowledge": {
        "name": "knowledge",
        "description": "Query knowledge base by question",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {"type": "string"},
                # kb_id is optional — only "query" is required below.
                "kb_id": {"type": "string"},
            },
            "required": ["query"],
        },
    },
}
def __init__(
self,
@@ -158,6 +213,10 @@ class DuplexPipeline:
self._runtime_greeting: Optional[str] = None
self._runtime_knowledge: Dict[str, Any] = {}
self._runtime_knowledge_base_id: Optional[str] = None
self._runtime_tools: List[Any] = []
self._pending_tool_waiters: Dict[str, asyncio.Future] = {}
self._early_tool_results: Dict[str, Dict[str, Any]] = {}
self._completed_tool_call_ids: set[str] = set()
logger.info(f"DuplexPipeline initialized for session {session_id}")
@@ -208,8 +267,14 @@ class DuplexPipeline:
if kb_id:
self._runtime_knowledge_base_id = kb_id
tools_payload = metadata.get("tools")
if isinstance(tools_payload, list):
self._runtime_tools = tools_payload
if self.llm_service and hasattr(self.llm_service, "set_knowledge_config"):
self.llm_service.set_knowledge_config(self._resolved_knowledge_config())
if self.llm_service and hasattr(self.llm_service, "set_tool_schemas"):
self.llm_service.set_tool_schemas(self._resolved_tool_schemas())
async def start(self) -> None:
"""Start the pipeline and connect services."""
@@ -234,6 +299,8 @@ class DuplexPipeline:
if hasattr(self.llm_service, "set_knowledge_config"):
self.llm_service.set_knowledge_config(self._resolved_knowledge_config())
if hasattr(self.llm_service, "set_tool_schemas"):
self.llm_service.set_tool_schemas(self._resolved_tool_schemas())
await self.llm_service.connect()
@@ -585,6 +652,117 @@ class DuplexPipeline:
cfg.setdefault("enabled", True)
return cfg
def _resolved_tool_schemas(self) -> List[Dict[str, Any]]:
schemas: List[Dict[str, Any]] = []
for item in self._runtime_tools:
if isinstance(item, str):
base = self._DEFAULT_TOOL_SCHEMAS.get(item)
if base:
schemas.append(
{
"type": "function",
"function": {
"name": base["name"],
"description": base.get("description") or "",
"parameters": base.get("parameters") or {"type": "object", "properties": {}},
},
}
)
continue
if not isinstance(item, dict):
continue
fn = item.get("function")
if isinstance(fn, dict) and fn.get("name"):
schemas.append(
{
"type": "function",
"function": {
"name": str(fn.get("name")),
"description": str(fn.get("description") or item.get("description") or ""),
"parameters": fn.get("parameters") or {"type": "object", "properties": {}},
},
}
)
continue
if item.get("name"):
schemas.append(
{
"type": "function",
"function": {
"name": str(item.get("name")),
"description": str(item.get("description") or ""),
"parameters": item.get("parameters") or {"type": "object", "properties": {}},
},
}
)
return schemas
async def handle_tool_call_results(self, results: List[Dict[str, Any]]) -> None:
    """Accept tool execution results pushed by the client.

    Each result is delivered to a coroutine currently parked in the
    per-call-id waiter future when one exists; otherwise it is stashed
    in ``_early_tool_results`` until the pipeline asks for it. Results
    without an id, and duplicates of an already-handled id, are dropped.
    """
    if not isinstance(results, list):
        return
    for entry in results:
        if not isinstance(entry, dict):
            continue
        raw_id = entry.get("tool_call_id") or entry.get("id") or ""
        call_id = str(raw_id).strip()
        # Skip unidentifiable results and ids already delivered once.
        if not call_id or call_id in self._completed_tool_call_ids:
            continue
        self._completed_tool_call_ids.add(call_id)
        waiter = self._pending_tool_waiters.get(call_id)
        if waiter is not None and not waiter.done():
            # Someone is awaiting this id right now — hand the payload over.
            waiter.set_result(entry)
        else:
            # Result arrived before anyone waited for it — keep it for later.
            self._early_tool_results[call_id] = entry
async def _wait_for_single_tool_result(self, call_id: str) -> Dict[str, Any]:
if call_id in self._completed_tool_call_ids and call_id not in self._early_tool_results:
return {
"tool_call_id": call_id,
"status": {"code": 208, "message": "tool_call result already handled"},
"output": "",
}
if call_id in self._early_tool_results:
self._completed_tool_call_ids.add(call_id)
return self._early_tool_results.pop(call_id)
loop = asyncio.get_running_loop()
future = loop.create_future()
self._pending_tool_waiters[call_id] = future
try:
return await asyncio.wait_for(future, timeout=self._TOOL_WAIT_TIMEOUT_SECONDS)
except asyncio.TimeoutError:
self._completed_tool_call_ids.add(call_id)
return {
"tool_call_id": call_id,
"status": {"code": 504, "message": "tool_call timeout"},
"output": "",
}
finally:
self._pending_tool_waiters.pop(call_id, None)
def _normalize_stream_event(self, item: Any) -> LLMStreamEvent:
    """Coerce a raw LLM stream item into an ``LLMStreamEvent``.

    Accepts already-typed events, bare strings (treated as text deltas),
    and dicts carrying a recognized ``"type"``. Anything else — including
    dicts with an unknown type — collapses to a terminal ``done`` event.
    """
    if isinstance(item, LLMStreamEvent):
        return item
    if isinstance(item, str):
        # Plain strings come from text-only LLM service implementations.
        return LLMStreamEvent(type="text_delta", text=item)
    if isinstance(item, dict):
        kind = str(item.get("type") or "")
        if kind in {"text_delta", "tool_call", "done"}:
            return LLMStreamEvent(
                type=kind,  # type: ignore[arg-type]
                text=item.get("text"),
                tool_call=item.get("tool_call"),
            )
    return LLMStreamEvent(type="done")
async def _handle_turn(self, user_text: str) -> None:
"""
Handle a complete conversation turn.
@@ -599,109 +777,176 @@ class DuplexPipeline:
self._turn_start_time = time.time()
self._first_audio_sent = False
# Get AI response (streaming)
messages = self.conversation.get_messages()
full_response = ""
messages = self.conversation.get_messages()
max_rounds = 3
await self.conversation.start_assistant_turn()
self._is_bot_speaking = True
self._interrupt_event.clear()
self._drop_outbound_audio = False
# Sentence buffer for streaming TTS
sentence_buffer = ""
pending_punctuation = ""
first_audio_sent = False
spoken_sentence_count = 0
# Stream LLM response and TTS sentence by sentence
async for text_chunk in self.llm_service.generate_stream(messages):
for _ in range(max_rounds):
if self._interrupt_event.is_set():
break
full_response += text_chunk
sentence_buffer += text_chunk
await self.conversation.update_assistant_text(text_chunk)
sentence_buffer = ""
pending_punctuation = ""
round_response = ""
tool_calls: List[Dict[str, Any]] = []
allow_text_output = True
# Send LLM response streaming event to client
await self._send_event({
**ev(
"assistant.response.delta",
trackId=self.session_id,
text=text_chunk,
)
}, priority=40)
# Check for sentence completion - synthesize immediately for low latency
while True:
split_result = extract_tts_sentence(
sentence_buffer,
end_chars=self._SENTENCE_END_CHARS,
trailing_chars=self._SENTENCE_TRAILING_CHARS,
closers=self._SENTENCE_CLOSERS,
min_split_spoken_chars=self._MIN_SPLIT_SPOKEN_CHARS,
hold_trailing_at_buffer_end=True,
force=False,
)
if not split_result:
async for raw_event in self.llm_service.generate_stream(messages):
if self._interrupt_event.is_set():
break
sentence, sentence_buffer = split_result
if not sentence:
event = self._normalize_stream_event(raw_event)
if event.type == "tool_call":
tool_call = event.tool_call if isinstance(event.tool_call, dict) else None
if not tool_call:
continue
allow_text_output = False
tool_calls.append(tool_call)
await self._send_event(
{
**ev(
"assistant.tool_call",
trackId=self.session_id,
tool_call=tool_call,
)
},
priority=22,
)
continue
sentence = f"{pending_punctuation}{sentence}".strip()
pending_punctuation = ""
if not sentence:
if event.type != "text_delta":
continue
# Avoid synthesizing punctuation-only fragments (e.g. standalone "!")
if not has_spoken_content(sentence):
pending_punctuation = sentence
text_chunk = event.text or ""
if not text_chunk:
continue
if not self._interrupt_event.is_set():
# Send track start on first audio
if not first_audio_sent:
await self._send_event({
if not allow_text_output:
continue
full_response += text_chunk
round_response += text_chunk
sentence_buffer += text_chunk
await self.conversation.update_assistant_text(text_chunk)
await self._send_event(
{
**ev(
"assistant.response.delta",
trackId=self.session_id,
text=text_chunk,
)
},
priority=40,
)
while True:
split_result = extract_tts_sentence(
sentence_buffer,
end_chars=self._SENTENCE_END_CHARS,
trailing_chars=self._SENTENCE_TRAILING_CHARS,
closers=self._SENTENCE_CLOSERS,
min_split_spoken_chars=self._MIN_SPLIT_SPOKEN_CHARS,
hold_trailing_at_buffer_end=True,
force=False,
)
if not split_result:
break
sentence, sentence_buffer = split_result
if not sentence:
continue
sentence = f"{pending_punctuation}{sentence}".strip()
pending_punctuation = ""
if not sentence:
continue
if not has_spoken_content(sentence):
pending_punctuation = sentence
continue
if not self._interrupt_event.is_set():
if not first_audio_sent:
await self._send_event(
{
**ev(
"output.audio.start",
trackId=self.session_id,
)
},
priority=10,
)
first_audio_sent = True
await self._speak_sentence(
sentence,
fade_in_ms=0,
fade_out_ms=8,
)
remaining_text = f"{pending_punctuation}{sentence_buffer}".strip()
if remaining_text and has_spoken_content(remaining_text) and not self._interrupt_event.is_set():
if not first_audio_sent:
await self._send_event(
{
**ev(
"output.audio.start",
trackId=self.session_id,
)
}, priority=10)
first_audio_sent = True
await self._speak_sentence(
sentence,
fade_in_ms=0,
fade_out_ms=8,
},
priority=10,
)
spoken_sentence_count += 1
# Send final LLM response event
if full_response and not self._interrupt_event.is_set():
await self._send_event({
**ev(
"assistant.response.final",
trackId=self.session_id,
text=full_response,
first_audio_sent = True
await self._speak_sentence(
remaining_text,
fade_in_ms=0,
fade_out_ms=8,
)
}, priority=20)
# Speak any remaining text
remaining_text = f"{pending_punctuation}{sentence_buffer}".strip()
if remaining_text and has_spoken_content(remaining_text) and not self._interrupt_event.is_set():
if not first_audio_sent:
await self._send_event({
if not tool_calls:
break
tool_results: List[Dict[str, Any]] = []
for call in tool_calls:
call_id = str(call.get("id") or "").strip()
if not call_id:
continue
tool_results.append(await self._wait_for_single_tool_result(call_id))
messages = [
*messages,
LLMMessage(
role="assistant",
content=round_response.strip(),
),
LLMMessage(
role="system",
content=(
"Tool execution results were returned by the client. "
"Continue answering the user naturally using these results. "
"Do not request the same tool again in this turn.\n"
f"tool_calls={json.dumps(tool_calls, ensure_ascii=False)}\n"
f"tool_results={json.dumps(tool_results, ensure_ascii=False)}"
),
),
]
if full_response and not self._interrupt_event.is_set():
await self._send_event(
{
**ev(
"output.audio.start",
"assistant.response.final",
trackId=self.session_id,
text=full_response,
)
}, priority=10)
first_audio_sent = True
await self._speak_sentence(
remaining_text,
fade_in_ms=0,
fade_out_ms=8,
},
priority=20,
)
# Send track end

View File

@@ -28,6 +28,7 @@ from models.ws_v1 import (
SessionStopMessage,
InputTextMessage,
ResponseCancelMessage,
ToolCallResultsMessage,
)
@@ -174,6 +175,8 @@ class Session:
logger.info(f"Session {self.id} graceful response.cancel")
else:
await self.pipeline.interrupt()
elif isinstance(message, ToolCallResultsMessage):
await self.pipeline.handle_tool_call_results(message.results)
elif isinstance(message, SessionStopMessage):
await self._handle_session_stop(message.reason)
else: