Update code for tool call
This commit is contained in:
@@ -12,8 +12,9 @@ event-driven design.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
@@ -26,7 +27,7 @@ from models.ws_v1 import ev
|
||||
from processors.eou import EouDetector
|
||||
from processors.vad import SileroVAD, VADProcessor
|
||||
from services.asr import BufferedASRService
|
||||
from services.base import BaseASRService, BaseLLMService, BaseTTSService
|
||||
from services.base import BaseASRService, BaseLLMService, BaseTTSService, LLMMessage, LLMStreamEvent
|
||||
from services.llm import MockLLMService, OpenAILLMService
|
||||
from services.siliconflow_asr import SiliconFlowASRService
|
||||
from services.siliconflow_tts import SiliconFlowTTSService
|
||||
@@ -55,6 +56,60 @@ class DuplexPipeline:
|
||||
_SENTENCE_TRAILING_CHARS = frozenset({"。", "!", "?", ".", "!", "?", "…", "~", "~", "\n"})
|
||||
_SENTENCE_CLOSERS = frozenset({'"', "'", "”", "’", ")", "]", "}", ")", "】", "」", "』", "》"})
|
||||
_MIN_SPLIT_SPOKEN_CHARS = 6
|
||||
_TOOL_WAIT_TIMEOUT_SECONDS = 15.0
|
||||
# Built-in tool definitions, keyed by tool name. Every parameter of the
# built-in tools is a plain string, so each spec row only carries:
# (name, description, parameter names in declaration order, required names).
_DEFAULT_TOOL_SCHEMAS: Dict[str, Dict[str, Any]] = {
    tool: {
        "name": tool,
        "description": summary,
        "parameters": {
            "type": "object",
            # A fresh {"type": "string"} dict per field: no objects are shared.
            "properties": {field: {"type": "string"} for field in fields},
            "required": list(mandatory),
        },
    }
    for tool, summary, fields, mandatory in (
        ("search", "Search the internet for recent information", ("query",), ("query",)),
        ("calculator", "Evaluate a math expression", ("expression",), ("expression",)),
        ("weather", "Get weather by city name", ("city",), ("city",)),
        (
            "translate",
            "Translate text to target language",
            ("text", "target_lang"),
            ("text", "target_lang"),
        ),
        ("knowledge", "Query knowledge base by question", ("query", "kb_id"), ("query",)),
    )
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -158,6 +213,10 @@ class DuplexPipeline:
|
||||
self._runtime_greeting: Optional[str] = None
|
||||
self._runtime_knowledge: Dict[str, Any] = {}
|
||||
self._runtime_knowledge_base_id: Optional[str] = None
|
||||
self._runtime_tools: List[Any] = []
|
||||
self._pending_tool_waiters: Dict[str, asyncio.Future] = {}
|
||||
self._early_tool_results: Dict[str, Dict[str, Any]] = {}
|
||||
self._completed_tool_call_ids: set[str] = set()
|
||||
|
||||
logger.info(f"DuplexPipeline initialized for session {session_id}")
|
||||
|
||||
@@ -208,8 +267,14 @@ class DuplexPipeline:
|
||||
if kb_id:
|
||||
self._runtime_knowledge_base_id = kb_id
|
||||
|
||||
tools_payload = metadata.get("tools")
|
||||
if isinstance(tools_payload, list):
|
||||
self._runtime_tools = tools_payload
|
||||
|
||||
if self.llm_service and hasattr(self.llm_service, "set_knowledge_config"):
|
||||
self.llm_service.set_knowledge_config(self._resolved_knowledge_config())
|
||||
if self.llm_service and hasattr(self.llm_service, "set_tool_schemas"):
|
||||
self.llm_service.set_tool_schemas(self._resolved_tool_schemas())
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the pipeline and connect services."""
|
||||
@@ -234,6 +299,8 @@ class DuplexPipeline:
|
||||
|
||||
if hasattr(self.llm_service, "set_knowledge_config"):
|
||||
self.llm_service.set_knowledge_config(self._resolved_knowledge_config())
|
||||
if hasattr(self.llm_service, "set_tool_schemas"):
|
||||
self.llm_service.set_tool_schemas(self._resolved_tool_schemas())
|
||||
|
||||
await self.llm_service.connect()
|
||||
|
||||
@@ -585,6 +652,117 @@ class DuplexPipeline:
|
||||
cfg.setdefault("enabled", True)
|
||||
return cfg
|
||||
|
||||
def _resolved_tool_schemas(self) -> List[Dict[str, Any]]:
    """Build function-calling schemas from the runtime tool list.

    Each entry of ``self._runtime_tools`` may be:

    * a ``str`` naming a built-in tool in ``self._DEFAULT_TOOL_SCHEMAS``
      (unknown names are ignored);
    * a ``dict`` with a nested ``"function"`` dict that has a ``"name"``;
      the description falls back to the outer dict's ``"description"``;
    * a flat ``dict`` with a top-level ``"name"``.

    Anything else is skipped.

    Returns:
        A list of ``{"type": "function", "function": {...}}`` dicts in the
        same order as the accepted entries. Every ``"parameters"`` value is
        an independent deep copy, so callers may mutate the result without
        touching the class-level defaults or the stored runtime payload.
    """
    import copy  # function-scope import keeps this block self-contained

    def build(name: Any, description: Any, parameters: Any) -> Dict[str, Any]:
        # Missing, empty, or non-dict parameters collapse to an empty object
        # schema instead of being forwarded as an invalid payload.
        if not isinstance(parameters, dict) or not parameters:
            parameters = {"type": "object", "properties": {}}
        return {
            "type": "function",
            "function": {
                "name": str(name),
                "description": str(description or ""),
                "parameters": copy.deepcopy(parameters),
            },
        }

    schemas: List[Dict[str, Any]] = []
    for item in self._runtime_tools:
        if isinstance(item, str):
            base = self._DEFAULT_TOOL_SCHEMAS.get(item)
            if base:
                schemas.append(
                    build(base["name"], base.get("description"), base.get("parameters"))
                )
            continue

        if not isinstance(item, dict):
            continue

        fn = item.get("function")
        if isinstance(fn, dict) and fn.get("name"):
            schemas.append(
                build(
                    fn.get("name"),
                    fn.get("description") or item.get("description"),
                    fn.get("parameters"),
                )
            )
        elif item.get("name"):
            schemas.append(
                build(item.get("name"), item.get("description"), item.get("parameters"))
            )
    return schemas
|
||||
|
||||
async def handle_tool_call_results(self, results: List[Dict[str, Any]]) -> None:
    """Route tool execution results sent by the client.

    A non-list ``results`` is ignored. Each dict entry is identified by its
    ``"tool_call_id"`` (falling back to ``"id"``), stripped of surrounding
    whitespace. Entries that are not dicts, have no id, or whose id was
    already recorded as completed are skipped.

    For an accepted entry: if an unresolved future is registered under that
    id in ``self._pending_tool_waiters`` it is resolved with the entry;
    otherwise the entry is parked in ``self._early_tool_results``. In both
    cases the id is added to ``self._completed_tool_call_ids``.
    """
    if not isinstance(results, list):
        return

    for payload in results:
        if not isinstance(payload, dict):
            continue

        raw_id = payload.get("tool_call_id") or payload.get("id") or ""
        call_id = str(raw_id).strip()
        if not call_id or call_id in self._completed_tool_call_ids:
            continue

        pending = self._pending_tool_waiters.get(call_id)
        if pending and not pending.done():
            # Someone is already awaiting this call: hand the result over.
            pending.set_result(payload)
        else:
            # No live waiter: keep the result until it is asked for.
            self._early_tool_results[call_id] = payload
        self._completed_tool_call_ids.add(call_id)
|
||||
|
||||
async def _wait_for_single_tool_result(self, call_id: str) -> Dict[str, Any]:
    """Return the client's result for ``call_id``, waiting if necessary.

    Resolution order:

    1. A result already parked in ``self._early_tool_results`` is removed
       and returned (the id is recorded as completed).
    2. An id that is recorded as completed but has no parked result yields
       a synthetic status-208 response.
    3. Otherwise a future is registered in ``self._pending_tool_waiters``
       and awaited for up to ``self._TOOL_WAIT_TIMEOUT_SECONDS``; on timeout
       the id is recorded as completed and a synthetic status-504 response
       is returned.

    The registered future is always removed again, whichever way the wait
    ends.
    """

    def synthetic(code: int, message: str) -> Dict[str, Any]:
        # Placeholder response used when no real client result is available.
        return {
            "tool_call_id": call_id,
            "status": {"code": code, "message": message},
            "output": "",
        }

    absent = object()  # sentinel: distinguishes "no entry" from any stored value
    parked = self._early_tool_results.pop(call_id, absent)
    if parked is not absent:
        self._completed_tool_call_ids.add(call_id)
        return parked

    if call_id in self._completed_tool_call_ids:
        return synthetic(208, "tool_call result already handled")

    waiter = asyncio.get_running_loop().create_future()
    self._pending_tool_waiters[call_id] = waiter
    try:
        outcome = await asyncio.wait_for(waiter, timeout=self._TOOL_WAIT_TIMEOUT_SECONDS)
    except asyncio.TimeoutError:
        self._completed_tool_call_ids.add(call_id)
        return synthetic(504, "tool_call timeout")
    finally:
        self._pending_tool_waiters.pop(call_id, None)
    return outcome
|
||||
|
||||
def _normalize_stream_event(self, item: Any) -> LLMStreamEvent:
    """Coerce one item yielded by the LLM stream into an ``LLMStreamEvent``.

    * An ``LLMStreamEvent`` is returned unchanged.
    * A ``str`` becomes a ``"text_delta"`` event carrying that text.
    * A ``dict`` whose ``"type"`` is ``"text_delta"``, ``"tool_call"`` or
      ``"done"`` becomes an event of that type, with ``text`` and
      ``tool_call`` taken from the dict's keys of the same name.
    * Everything else becomes a bare ``"done"`` event.
    """
    if isinstance(item, LLMStreamEvent):
        return item
    if isinstance(item, str):
        return LLMStreamEvent(type="text_delta", text=item)

    # Only dicts can carry a recognised type; anything else maps to "".
    kind = str(item.get("type") or "") if isinstance(item, dict) else ""
    if kind not in ("text_delta", "tool_call", "done"):
        return LLMStreamEvent(type="done")

    return LLMStreamEvent(
        type=kind,  # type: ignore[arg-type]
        text=item.get("text"),
        tool_call=item.get("tool_call"),
    )
|
||||
|
||||
async def _handle_turn(self, user_text: str) -> None:
|
||||
"""
|
||||
Handle a complete conversation turn.
|
||||
@@ -599,109 +777,176 @@ class DuplexPipeline:
|
||||
self._turn_start_time = time.time()
|
||||
self._first_audio_sent = False
|
||||
|
||||
# Get AI response (streaming)
|
||||
messages = self.conversation.get_messages()
|
||||
full_response = ""
|
||||
messages = self.conversation.get_messages()
|
||||
max_rounds = 3
|
||||
|
||||
await self.conversation.start_assistant_turn()
|
||||
self._is_bot_speaking = True
|
||||
self._interrupt_event.clear()
|
||||
self._drop_outbound_audio = False
|
||||
|
||||
# Sentence buffer for streaming TTS
|
||||
sentence_buffer = ""
|
||||
pending_punctuation = ""
|
||||
first_audio_sent = False
|
||||
spoken_sentence_count = 0
|
||||
|
||||
# Stream LLM response and TTS sentence by sentence
|
||||
async for text_chunk in self.llm_service.generate_stream(messages):
|
||||
for _ in range(max_rounds):
|
||||
if self._interrupt_event.is_set():
|
||||
break
|
||||
|
||||
full_response += text_chunk
|
||||
sentence_buffer += text_chunk
|
||||
await self.conversation.update_assistant_text(text_chunk)
|
||||
sentence_buffer = ""
|
||||
pending_punctuation = ""
|
||||
round_response = ""
|
||||
tool_calls: List[Dict[str, Any]] = []
|
||||
allow_text_output = True
|
||||
|
||||
# Send LLM response streaming event to client
|
||||
await self._send_event({
|
||||
**ev(
|
||||
"assistant.response.delta",
|
||||
trackId=self.session_id,
|
||||
text=text_chunk,
|
||||
)
|
||||
}, priority=40)
|
||||
|
||||
# Check for sentence completion - synthesize immediately for low latency
|
||||
while True:
|
||||
split_result = extract_tts_sentence(
|
||||
sentence_buffer,
|
||||
end_chars=self._SENTENCE_END_CHARS,
|
||||
trailing_chars=self._SENTENCE_TRAILING_CHARS,
|
||||
closers=self._SENTENCE_CLOSERS,
|
||||
min_split_spoken_chars=self._MIN_SPLIT_SPOKEN_CHARS,
|
||||
hold_trailing_at_buffer_end=True,
|
||||
force=False,
|
||||
)
|
||||
if not split_result:
|
||||
async for raw_event in self.llm_service.generate_stream(messages):
|
||||
if self._interrupt_event.is_set():
|
||||
break
|
||||
sentence, sentence_buffer = split_result
|
||||
if not sentence:
|
||||
|
||||
event = self._normalize_stream_event(raw_event)
|
||||
if event.type == "tool_call":
|
||||
tool_call = event.tool_call if isinstance(event.tool_call, dict) else None
|
||||
if not tool_call:
|
||||
continue
|
||||
allow_text_output = False
|
||||
tool_calls.append(tool_call)
|
||||
await self._send_event(
|
||||
{
|
||||
**ev(
|
||||
"assistant.tool_call",
|
||||
trackId=self.session_id,
|
||||
tool_call=tool_call,
|
||||
)
|
||||
},
|
||||
priority=22,
|
||||
)
|
||||
continue
|
||||
|
||||
sentence = f"{pending_punctuation}{sentence}".strip()
|
||||
pending_punctuation = ""
|
||||
if not sentence:
|
||||
if event.type != "text_delta":
|
||||
continue
|
||||
|
||||
# Avoid synthesizing punctuation-only fragments (e.g. standalone "!")
|
||||
if not has_spoken_content(sentence):
|
||||
pending_punctuation = sentence
|
||||
text_chunk = event.text or ""
|
||||
if not text_chunk:
|
||||
continue
|
||||
|
||||
if not self._interrupt_event.is_set():
|
||||
# Send track start on first audio
|
||||
if not first_audio_sent:
|
||||
await self._send_event({
|
||||
if not allow_text_output:
|
||||
continue
|
||||
|
||||
full_response += text_chunk
|
||||
round_response += text_chunk
|
||||
sentence_buffer += text_chunk
|
||||
await self.conversation.update_assistant_text(text_chunk)
|
||||
|
||||
await self._send_event(
|
||||
{
|
||||
**ev(
|
||||
"assistant.response.delta",
|
||||
trackId=self.session_id,
|
||||
text=text_chunk,
|
||||
)
|
||||
},
|
||||
priority=40,
|
||||
)
|
||||
|
||||
while True:
|
||||
split_result = extract_tts_sentence(
|
||||
sentence_buffer,
|
||||
end_chars=self._SENTENCE_END_CHARS,
|
||||
trailing_chars=self._SENTENCE_TRAILING_CHARS,
|
||||
closers=self._SENTENCE_CLOSERS,
|
||||
min_split_spoken_chars=self._MIN_SPLIT_SPOKEN_CHARS,
|
||||
hold_trailing_at_buffer_end=True,
|
||||
force=False,
|
||||
)
|
||||
if not split_result:
|
||||
break
|
||||
sentence, sentence_buffer = split_result
|
||||
if not sentence:
|
||||
continue
|
||||
|
||||
sentence = f"{pending_punctuation}{sentence}".strip()
|
||||
pending_punctuation = ""
|
||||
if not sentence:
|
||||
continue
|
||||
|
||||
if not has_spoken_content(sentence):
|
||||
pending_punctuation = sentence
|
||||
continue
|
||||
|
||||
if not self._interrupt_event.is_set():
|
||||
if not first_audio_sent:
|
||||
await self._send_event(
|
||||
{
|
||||
**ev(
|
||||
"output.audio.start",
|
||||
trackId=self.session_id,
|
||||
)
|
||||
},
|
||||
priority=10,
|
||||
)
|
||||
first_audio_sent = True
|
||||
|
||||
await self._speak_sentence(
|
||||
sentence,
|
||||
fade_in_ms=0,
|
||||
fade_out_ms=8,
|
||||
)
|
||||
|
||||
remaining_text = f"{pending_punctuation}{sentence_buffer}".strip()
|
||||
if remaining_text and has_spoken_content(remaining_text) and not self._interrupt_event.is_set():
|
||||
if not first_audio_sent:
|
||||
await self._send_event(
|
||||
{
|
||||
**ev(
|
||||
"output.audio.start",
|
||||
trackId=self.session_id,
|
||||
)
|
||||
}, priority=10)
|
||||
first_audio_sent = True
|
||||
|
||||
await self._speak_sentence(
|
||||
sentence,
|
||||
fade_in_ms=0,
|
||||
fade_out_ms=8,
|
||||
},
|
||||
priority=10,
|
||||
)
|
||||
spoken_sentence_count += 1
|
||||
|
||||
# Send final LLM response event
|
||||
if full_response and not self._interrupt_event.is_set():
|
||||
await self._send_event({
|
||||
**ev(
|
||||
"assistant.response.final",
|
||||
trackId=self.session_id,
|
||||
text=full_response,
|
||||
first_audio_sent = True
|
||||
await self._speak_sentence(
|
||||
remaining_text,
|
||||
fade_in_ms=0,
|
||||
fade_out_ms=8,
|
||||
)
|
||||
}, priority=20)
|
||||
|
||||
# Speak any remaining text
|
||||
remaining_text = f"{pending_punctuation}{sentence_buffer}".strip()
|
||||
if remaining_text and has_spoken_content(remaining_text) and not self._interrupt_event.is_set():
|
||||
if not first_audio_sent:
|
||||
await self._send_event({
|
||||
if not tool_calls:
|
||||
break
|
||||
|
||||
tool_results: List[Dict[str, Any]] = []
|
||||
for call in tool_calls:
|
||||
call_id = str(call.get("id") or "").strip()
|
||||
if not call_id:
|
||||
continue
|
||||
tool_results.append(await self._wait_for_single_tool_result(call_id))
|
||||
|
||||
messages = [
|
||||
*messages,
|
||||
LLMMessage(
|
||||
role="assistant",
|
||||
content=round_response.strip(),
|
||||
),
|
||||
LLMMessage(
|
||||
role="system",
|
||||
content=(
|
||||
"Tool execution results were returned by the client. "
|
||||
"Continue answering the user naturally using these results. "
|
||||
"Do not request the same tool again in this turn.\n"
|
||||
f"tool_calls={json.dumps(tool_calls, ensure_ascii=False)}\n"
|
||||
f"tool_results={json.dumps(tool_results, ensure_ascii=False)}"
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
if full_response and not self._interrupt_event.is_set():
|
||||
await self._send_event(
|
||||
{
|
||||
**ev(
|
||||
"output.audio.start",
|
||||
"assistant.response.final",
|
||||
trackId=self.session_id,
|
||||
text=full_response,
|
||||
)
|
||||
}, priority=10)
|
||||
first_audio_sent = True
|
||||
await self._speak_sentence(
|
||||
remaining_text,
|
||||
fade_in_ms=0,
|
||||
fade_out_ms=8,
|
||||
},
|
||||
priority=20,
|
||||
)
|
||||
|
||||
# Send track end
|
||||
|
||||
@@ -28,6 +28,7 @@ from models.ws_v1 import (
|
||||
SessionStopMessage,
|
||||
InputTextMessage,
|
||||
ResponseCancelMessage,
|
||||
ToolCallResultsMessage,
|
||||
)
|
||||
|
||||
|
||||
@@ -174,6 +175,8 @@ class Session:
|
||||
logger.info(f"Session {self.id} graceful response.cancel")
|
||||
else:
|
||||
await self.pipeline.interrupt()
|
||||
elif isinstance(message, ToolCallResultsMessage):
|
||||
await self.pipeline.handle_tool_call_results(message.results)
|
||||
elif isinstance(message, SessionStopMessage):
|
||||
await self._handle_session_stop(message.reason)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user