Update code for tool call
This commit is contained in:
@@ -12,8 +12,9 @@ event-driven design.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import json
|
||||||
import time
|
import time
|
||||||
from typing import Any, Dict, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
@@ -26,7 +27,7 @@ from models.ws_v1 import ev
|
|||||||
from processors.eou import EouDetector
|
from processors.eou import EouDetector
|
||||||
from processors.vad import SileroVAD, VADProcessor
|
from processors.vad import SileroVAD, VADProcessor
|
||||||
from services.asr import BufferedASRService
|
from services.asr import BufferedASRService
|
||||||
from services.base import BaseASRService, BaseLLMService, BaseTTSService
|
from services.base import BaseASRService, BaseLLMService, BaseTTSService, LLMMessage, LLMStreamEvent
|
||||||
from services.llm import MockLLMService, OpenAILLMService
|
from services.llm import MockLLMService, OpenAILLMService
|
||||||
from services.siliconflow_asr import SiliconFlowASRService
|
from services.siliconflow_asr import SiliconFlowASRService
|
||||||
from services.siliconflow_tts import SiliconFlowTTSService
|
from services.siliconflow_tts import SiliconFlowTTSService
|
||||||
@@ -55,6 +56,60 @@ class DuplexPipeline:
|
|||||||
_SENTENCE_TRAILING_CHARS = frozenset({"。", "!", "?", ".", "!", "?", "…", "~", "~", "\n"})
|
_SENTENCE_TRAILING_CHARS = frozenset({"。", "!", "?", ".", "!", "?", "…", "~", "~", "\n"})
|
||||||
_SENTENCE_CLOSERS = frozenset({'"', "'", "”", "’", ")", "]", "}", ")", "】", "」", "』", "》"})
|
_SENTENCE_CLOSERS = frozenset({'"', "'", "”", "’", ")", "]", "}", ")", "】", "」", "』", "》"})
|
||||||
_MIN_SPLIT_SPOKEN_CHARS = 6
|
_MIN_SPLIT_SPOKEN_CHARS = 6
|
||||||
|
# Seconds to wait for the client to return one tool result before giving up.
_TOOL_WAIT_TIMEOUT_SECONDS = 15.0

# Built-in tool definitions, keyed by the short name a session may reference.
# Each value carries a name, a description and a JSON-Schema "parameters" object.
_DEFAULT_TOOL_SCHEMAS: Dict[str, Dict[str, Any]] = {
    "search": {
        "name": "search",
        "description": "Search the internet for recent information",
        "parameters": {
            "type": "object",
            "properties": {"query": {"type": "string"}},
            "required": ["query"],
        },
    },
    "calculator": {
        "name": "calculator",
        "description": "Evaluate a math expression",
        "parameters": {
            "type": "object",
            "properties": {"expression": {"type": "string"}},
            "required": ["expression"],
        },
    },
    "weather": {
        "name": "weather",
        "description": "Get weather by city name",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
    "translate": {
        "name": "translate",
        "description": "Translate text to target language",
        "parameters": {
            "type": "object",
            "properties": {"text": {"type": "string"}, "target_lang": {"type": "string"}},
            "required": ["text", "target_lang"],
        },
    },
    "knowledge": {
        "name": "knowledge",
        "description": "Query knowledge base by question",
        "parameters": {
            "type": "object",
            "properties": {"query": {"type": "string"}, "kb_id": {"type": "string"}},
            "required": ["query"],
        },
    },
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -158,6 +213,10 @@ class DuplexPipeline:
|
|||||||
self._runtime_greeting: Optional[str] = None
|
self._runtime_greeting: Optional[str] = None
|
||||||
self._runtime_knowledge: Dict[str, Any] = {}
|
self._runtime_knowledge: Dict[str, Any] = {}
|
||||||
self._runtime_knowledge_base_id: Optional[str] = None
|
self._runtime_knowledge_base_id: Optional[str] = None
|
||||||
|
self._runtime_tools: List[Any] = []
|
||||||
|
self._pending_tool_waiters: Dict[str, asyncio.Future] = {}
|
||||||
|
self._early_tool_results: Dict[str, Dict[str, Any]] = {}
|
||||||
|
self._completed_tool_call_ids: set[str] = set()
|
||||||
|
|
||||||
logger.info(f"DuplexPipeline initialized for session {session_id}")
|
logger.info(f"DuplexPipeline initialized for session {session_id}")
|
||||||
|
|
||||||
@@ -208,8 +267,14 @@ class DuplexPipeline:
|
|||||||
if kb_id:
|
if kb_id:
|
||||||
self._runtime_knowledge_base_id = kb_id
|
self._runtime_knowledge_base_id = kb_id
|
||||||
|
|
||||||
|
tools_payload = metadata.get("tools")
|
||||||
|
if isinstance(tools_payload, list):
|
||||||
|
self._runtime_tools = tools_payload
|
||||||
|
|
||||||
if self.llm_service and hasattr(self.llm_service, "set_knowledge_config"):
|
if self.llm_service and hasattr(self.llm_service, "set_knowledge_config"):
|
||||||
self.llm_service.set_knowledge_config(self._resolved_knowledge_config())
|
self.llm_service.set_knowledge_config(self._resolved_knowledge_config())
|
||||||
|
if self.llm_service and hasattr(self.llm_service, "set_tool_schemas"):
|
||||||
|
self.llm_service.set_tool_schemas(self._resolved_tool_schemas())
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
"""Start the pipeline and connect services."""
|
"""Start the pipeline and connect services."""
|
||||||
@@ -234,6 +299,8 @@ class DuplexPipeline:
|
|||||||
|
|
||||||
if hasattr(self.llm_service, "set_knowledge_config"):
|
if hasattr(self.llm_service, "set_knowledge_config"):
|
||||||
self.llm_service.set_knowledge_config(self._resolved_knowledge_config())
|
self.llm_service.set_knowledge_config(self._resolved_knowledge_config())
|
||||||
|
if hasattr(self.llm_service, "set_tool_schemas"):
|
||||||
|
self.llm_service.set_tool_schemas(self._resolved_tool_schemas())
|
||||||
|
|
||||||
await self.llm_service.connect()
|
await self.llm_service.connect()
|
||||||
|
|
||||||
@@ -585,6 +652,117 @@ class DuplexPipeline:
|
|||||||
cfg.setdefault("enabled", True)
|
cfg.setdefault("enabled", True)
|
||||||
return cfg
|
return cfg
|
||||||
|
|
||||||
|
def _resolved_tool_schemas(self) -> List[Dict[str, Any]]:
    """Build the function-tool schema list from ``self._runtime_tools``.

    Each runtime entry may be one of:

    * a string naming an entry of ``self._DEFAULT_TOOL_SCHEMAS``;
    * a dict with a nested ``"function"`` dict that has a ``name``;
    * a flat dict with top-level ``name`` / ``description`` / ``parameters``.

    Unknown names, non-str/non-dict entries and dicts without a name are
    skipped.

    Returns:
        A list of ``{"type": "function", "function": {...}}`` dicts. The
        ``parameters`` objects are deep copies, so callers may mutate the
        result without touching the class-level defaults or the payload the
        client sent.
    """
    import copy  # function-scope import keeps this block self-contained

    def _wrap(name: Any, description: Any, parameters: Any) -> Dict[str, Any]:
        # Single place that shapes one schema entry.
        if not isinstance(parameters, dict) or not parameters:
            # Missing, empty or malformed parameters fall back to an
            # argument-less object schema.
            parameters = {"type": "object", "properties": {}}
        return {
            "type": "function",
            "function": {
                "name": str(name),
                "description": str(description or ""),
                "parameters": copy.deepcopy(parameters),
            },
        }

    schemas: List[Dict[str, Any]] = []
    for item in self._runtime_tools:
        if isinstance(item, str):
            base = self._DEFAULT_TOOL_SCHEMAS.get(item)
            if base:
                schemas.append(_wrap(base["name"], base.get("description"), base.get("parameters")))
            continue

        if not isinstance(item, dict):
            continue

        fn = item.get("function")
        if isinstance(fn, dict) and fn.get("name"):
            # Nested form: the outer description is only a fallback.
            description = fn.get("description") or item.get("description")
            schemas.append(_wrap(fn.get("name"), description, fn.get("parameters")))
        elif item.get("name"):
            schemas.append(_wrap(item.get("name"), item.get("description"), item.get("parameters")))
    return schemas
|
||||||
|
|
||||||
|
async def handle_tool_call_results(self, results: List[Dict[str, Any]]) -> None:
    """Route tool execution results sent by the client to their waiters.

    Each result dict is keyed by ``tool_call_id`` (falling back to ``id``).
    Non-dict entries, entries without an id and ids that were already
    handled are ignored. For every newly accepted id:

    * if a waiter is registered and still pending, it is resolved with the
      result;
    * if no waiter is registered yet, the result is parked in
      ``self._early_tool_results`` for a later wait to pick up;
    * if a waiter is registered but already done (for example cancelled by
      a timeout and not yet unregistered), the result is dropped. Parking
      it would leave an entry that nothing ever consumes.

    Args:
        results: List of result dicts; any other type is ignored.
    """
    if not isinstance(results, list):
        return

    for item in results:
        if not isinstance(item, dict):
            continue
        call_id = str(item.get("tool_call_id") or item.get("id") or "").strip()
        if not call_id or call_id in self._completed_tool_call_ids:
            continue

        # Mark first so duplicate deliveries of the same id are ignored.
        self._completed_tool_call_ids.add(call_id)

        waiter = self._pending_tool_waiters.get(call_id)
        if waiter is None:
            self._early_tool_results[call_id] = item
        elif not waiter.done():
            waiter.set_result(item)
        # else: the wait already ended; nothing will read this result.
|
||||||
|
|
||||||
|
async def _wait_for_single_tool_result(self, call_id: str) -> Dict[str, Any]:
    """Return the client's result for one tool call, waiting if necessary.

    Resolution order:

    1. A result already parked in ``self._early_tool_results`` is popped
       and returned.
    2. An id already marked completed with nothing parked yields a
       synthetic status-208 result.
    3. Otherwise a future is registered in ``self._pending_tool_waiters``
       and awaited for up to ``self._TOOL_WAIT_TIMEOUT_SECONDS``; on
       timeout the id is marked completed and a synthetic status-504
       result is returned.

    The waiter registration is always removed before returning or raising.
    """
    parked = self._early_tool_results
    if call_id in parked:
        self._completed_tool_call_ids.add(call_id)
        return parked.pop(call_id)

    if call_id in self._completed_tool_call_ids:
        return {
            "tool_call_id": call_id,
            "status": {"code": 208, "message": "tool_call result already handled"},
            "output": "",
        }

    waiter = asyncio.get_running_loop().create_future()
    self._pending_tool_waiters[call_id] = waiter
    try:
        outcome = await asyncio.wait_for(waiter, timeout=self._TOOL_WAIT_TIMEOUT_SECONDS)
    except asyncio.TimeoutError:
        # Mark as completed so a late delivery for this id is ignored.
        self._completed_tool_call_ids.add(call_id)
        outcome = {
            "tool_call_id": call_id,
            "status": {"code": 504, "message": "tool_call timeout"},
            "output": "",
        }
    finally:
        self._pending_tool_waiters.pop(call_id, None)
    return outcome
|
||||||
|
|
||||||
|
def _normalize_stream_event(self, item: Any) -> LLMStreamEvent:
    """Coerce whatever an LLM stream yields into an ``LLMStreamEvent``.

    * ``LLMStreamEvent`` instances pass through unchanged.
    * A plain string becomes a ``text_delta`` event carrying that string.
    * A dict whose ``type`` is one of the three known event types is
      converted field by field (``text``, ``tool_call``).
    * Anything else is mapped to a ``done`` event.
    """
    if isinstance(item, LLMStreamEvent):
        return item
    if isinstance(item, str):
        return LLMStreamEvent(type="text_delta", text=item)

    # Only dicts can declare a type; everything else has none.
    declared = str(item.get("type") or "") if isinstance(item, dict) else ""
    if declared not in {"text_delta", "tool_call", "done"}:
        return LLMStreamEvent(type="done")

    return LLMStreamEvent(
        type=declared,  # type: ignore[arg-type]
        text=item.get("text"),
        tool_call=item.get("tool_call"),
    )
|
||||||
|
|
||||||
async def _handle_turn(self, user_text: str) -> None:
|
async def _handle_turn(self, user_text: str) -> None:
|
||||||
"""
|
"""
|
||||||
Handle a complete conversation turn.
|
Handle a complete conversation turn.
|
||||||
@@ -599,40 +777,75 @@ class DuplexPipeline:
|
|||||||
self._turn_start_time = time.time()
|
self._turn_start_time = time.time()
|
||||||
self._first_audio_sent = False
|
self._first_audio_sent = False
|
||||||
|
|
||||||
# Get AI response (streaming)
|
|
||||||
messages = self.conversation.get_messages()
|
|
||||||
full_response = ""
|
full_response = ""
|
||||||
|
messages = self.conversation.get_messages()
|
||||||
|
max_rounds = 3
|
||||||
|
|
||||||
await self.conversation.start_assistant_turn()
|
await self.conversation.start_assistant_turn()
|
||||||
self._is_bot_speaking = True
|
self._is_bot_speaking = True
|
||||||
self._interrupt_event.clear()
|
self._interrupt_event.clear()
|
||||||
self._drop_outbound_audio = False
|
self._drop_outbound_audio = False
|
||||||
|
|
||||||
# Sentence buffer for streaming TTS
|
|
||||||
sentence_buffer = ""
|
|
||||||
pending_punctuation = ""
|
|
||||||
first_audio_sent = False
|
first_audio_sent = False
|
||||||
spoken_sentence_count = 0
|
for _ in range(max_rounds):
|
||||||
|
|
||||||
# Stream LLM response and TTS sentence by sentence
|
|
||||||
async for text_chunk in self.llm_service.generate_stream(messages):
|
|
||||||
if self._interrupt_event.is_set():
|
if self._interrupt_event.is_set():
|
||||||
break
|
break
|
||||||
|
|
||||||
|
sentence_buffer = ""
|
||||||
|
pending_punctuation = ""
|
||||||
|
round_response = ""
|
||||||
|
tool_calls: List[Dict[str, Any]] = []
|
||||||
|
allow_text_output = True
|
||||||
|
|
||||||
|
async for raw_event in self.llm_service.generate_stream(messages):
|
||||||
|
if self._interrupt_event.is_set():
|
||||||
|
break
|
||||||
|
|
||||||
|
event = self._normalize_stream_event(raw_event)
|
||||||
|
if event.type == "tool_call":
|
||||||
|
tool_call = event.tool_call if isinstance(event.tool_call, dict) else None
|
||||||
|
if not tool_call:
|
||||||
|
continue
|
||||||
|
allow_text_output = False
|
||||||
|
tool_calls.append(tool_call)
|
||||||
|
await self._send_event(
|
||||||
|
{
|
||||||
|
**ev(
|
||||||
|
"assistant.tool_call",
|
||||||
|
trackId=self.session_id,
|
||||||
|
tool_call=tool_call,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
priority=22,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if event.type != "text_delta":
|
||||||
|
continue
|
||||||
|
|
||||||
|
text_chunk = event.text or ""
|
||||||
|
if not text_chunk:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not allow_text_output:
|
||||||
|
continue
|
||||||
|
|
||||||
full_response += text_chunk
|
full_response += text_chunk
|
||||||
|
round_response += text_chunk
|
||||||
sentence_buffer += text_chunk
|
sentence_buffer += text_chunk
|
||||||
await self.conversation.update_assistant_text(text_chunk)
|
await self.conversation.update_assistant_text(text_chunk)
|
||||||
|
|
||||||
# Send LLM response streaming event to client
|
await self._send_event(
|
||||||
await self._send_event({
|
{
|
||||||
**ev(
|
**ev(
|
||||||
"assistant.response.delta",
|
"assistant.response.delta",
|
||||||
trackId=self.session_id,
|
trackId=self.session_id,
|
||||||
text=text_chunk,
|
text=text_chunk,
|
||||||
)
|
)
|
||||||
}, priority=40)
|
},
|
||||||
|
priority=40,
|
||||||
|
)
|
||||||
|
|
||||||
# Check for sentence completion - synthesize immediately for low latency
|
|
||||||
while True:
|
while True:
|
||||||
split_result = extract_tts_sentence(
|
split_result = extract_tts_sentence(
|
||||||
sentence_buffer,
|
sentence_buffer,
|
||||||
@@ -654,20 +867,21 @@ class DuplexPipeline:
|
|||||||
if not sentence:
|
if not sentence:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Avoid synthesizing punctuation-only fragments (e.g. standalone "!")
|
|
||||||
if not has_spoken_content(sentence):
|
if not has_spoken_content(sentence):
|
||||||
pending_punctuation = sentence
|
pending_punctuation = sentence
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if not self._interrupt_event.is_set():
|
if not self._interrupt_event.is_set():
|
||||||
# Send track start on first audio
|
|
||||||
if not first_audio_sent:
|
if not first_audio_sent:
|
||||||
await self._send_event({
|
await self._send_event(
|
||||||
|
{
|
||||||
**ev(
|
**ev(
|
||||||
"output.audio.start",
|
"output.audio.start",
|
||||||
trackId=self.session_id,
|
trackId=self.session_id,
|
||||||
)
|
)
|
||||||
}, priority=10)
|
},
|
||||||
|
priority=10,
|
||||||
|
)
|
||||||
first_audio_sent = True
|
first_audio_sent = True
|
||||||
|
|
||||||
await self._speak_sentence(
|
await self._speak_sentence(
|
||||||
@@ -675,28 +889,19 @@ class DuplexPipeline:
|
|||||||
fade_in_ms=0,
|
fade_in_ms=0,
|
||||||
fade_out_ms=8,
|
fade_out_ms=8,
|
||||||
)
|
)
|
||||||
spoken_sentence_count += 1
|
|
||||||
|
|
||||||
# Send final LLM response event
|
|
||||||
if full_response and not self._interrupt_event.is_set():
|
|
||||||
await self._send_event({
|
|
||||||
**ev(
|
|
||||||
"assistant.response.final",
|
|
||||||
trackId=self.session_id,
|
|
||||||
text=full_response,
|
|
||||||
)
|
|
||||||
}, priority=20)
|
|
||||||
|
|
||||||
# Speak any remaining text
|
|
||||||
remaining_text = f"{pending_punctuation}{sentence_buffer}".strip()
|
remaining_text = f"{pending_punctuation}{sentence_buffer}".strip()
|
||||||
if remaining_text and has_spoken_content(remaining_text) and not self._interrupt_event.is_set():
|
if remaining_text and has_spoken_content(remaining_text) and not self._interrupt_event.is_set():
|
||||||
if not first_audio_sent:
|
if not first_audio_sent:
|
||||||
await self._send_event({
|
await self._send_event(
|
||||||
|
{
|
||||||
**ev(
|
**ev(
|
||||||
"output.audio.start",
|
"output.audio.start",
|
||||||
trackId=self.session_id,
|
trackId=self.session_id,
|
||||||
)
|
)
|
||||||
}, priority=10)
|
},
|
||||||
|
priority=10,
|
||||||
|
)
|
||||||
first_audio_sent = True
|
first_audio_sent = True
|
||||||
await self._speak_sentence(
|
await self._speak_sentence(
|
||||||
remaining_text,
|
remaining_text,
|
||||||
@@ -704,6 +909,46 @@ class DuplexPipeline:
|
|||||||
fade_out_ms=8,
|
fade_out_ms=8,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not tool_calls:
|
||||||
|
break
|
||||||
|
|
||||||
|
tool_results: List[Dict[str, Any]] = []
|
||||||
|
for call in tool_calls:
|
||||||
|
call_id = str(call.get("id") or "").strip()
|
||||||
|
if not call_id:
|
||||||
|
continue
|
||||||
|
tool_results.append(await self._wait_for_single_tool_result(call_id))
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
*messages,
|
||||||
|
LLMMessage(
|
||||||
|
role="assistant",
|
||||||
|
content=round_response.strip(),
|
||||||
|
),
|
||||||
|
LLMMessage(
|
||||||
|
role="system",
|
||||||
|
content=(
|
||||||
|
"Tool execution results were returned by the client. "
|
||||||
|
"Continue answering the user naturally using these results. "
|
||||||
|
"Do not request the same tool again in this turn.\n"
|
||||||
|
f"tool_calls={json.dumps(tool_calls, ensure_ascii=False)}\n"
|
||||||
|
f"tool_results={json.dumps(tool_results, ensure_ascii=False)}"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
if full_response and not self._interrupt_event.is_set():
|
||||||
|
await self._send_event(
|
||||||
|
{
|
||||||
|
**ev(
|
||||||
|
"assistant.response.final",
|
||||||
|
trackId=self.session_id,
|
||||||
|
text=full_response,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
priority=20,
|
||||||
|
)
|
||||||
|
|
||||||
# Send track end
|
# Send track end
|
||||||
if first_audio_sent:
|
if first_audio_sent:
|
||||||
await self._send_event({
|
await self._send_event({
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from models.ws_v1 import (
|
|||||||
SessionStopMessage,
|
SessionStopMessage,
|
||||||
InputTextMessage,
|
InputTextMessage,
|
||||||
ResponseCancelMessage,
|
ResponseCancelMessage,
|
||||||
|
ToolCallResultsMessage,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -174,6 +175,8 @@ class Session:
|
|||||||
logger.info(f"Session {self.id} graceful response.cancel")
|
logger.info(f"Session {self.id} graceful response.cancel")
|
||||||
else:
|
else:
|
||||||
await self.pipeline.interrupt()
|
await self.pipeline.interrupt()
|
||||||
|
elif isinstance(message, ToolCallResultsMessage):
|
||||||
|
await self.pipeline.handle_tool_call_results(message.results)
|
||||||
elif isinstance(message, SessionStopMessage):
|
elif isinstance(message, SessionStopMessage):
|
||||||
await self._handle_session_stop(message.reason)
|
await self._handle_session_stop(message.reason)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -110,6 +110,24 @@ Rules:
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### `tool_call.results`
|
||||||
|
|
||||||
|
Client tool execution results returned to server.
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"type": "tool_call.results",
|
||||||
|
"results": [
|
||||||
|
{
|
||||||
|
"tool_call_id": "call_abc123",
|
||||||
|
"name": "weather",
|
||||||
|
"output": { "temp_c": 21, "condition": "sunny" },
|
||||||
|
"status": { "code": 200, "message": "ok" }
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
## Server -> Client Events
|
## Server -> Client Events
|
||||||
|
|
||||||
All server events include:
|
All server events include:
|
||||||
@@ -142,6 +160,8 @@ Common events:
|
|||||||
- Fields: `trackId`, `text`
|
- Fields: `trackId`, `text`
|
||||||
- `assistant.response.final`
|
- `assistant.response.final`
|
||||||
- Fields: `trackId`, `text`
|
- Fields: `trackId`, `text`
|
||||||
|
- `assistant.tool_call`
|
||||||
|
- Fields: `trackId`, `tool_call`
|
||||||
- `output.audio.start`
|
- `output.audio.start`
|
||||||
- Fields: `trackId`
|
- Fields: `trackId`
|
||||||
- `output.audio.end`
|
- `output.audio.end`
|
||||||
|
|||||||
@@ -39,12 +39,18 @@ class ResponseCancelMessage(BaseModel):
|
|||||||
graceful: bool = False
|
graceful: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class ToolCallResultsMessage(BaseModel):
    """Client message of type ``tool_call.results``.

    Carries the results of tool calls the client executed.
    """

    # Wire discriminator for this message type.
    type: Literal["tool_call.results"]
    # One dict per executed tool call; defaults to an empty list.
    results: list[Dict[str, Any]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
CLIENT_MESSAGE_TYPES = {
|
CLIENT_MESSAGE_TYPES = {
|
||||||
"hello": HelloMessage,
|
"hello": HelloMessage,
|
||||||
"session.start": SessionStartMessage,
|
"session.start": SessionStartMessage,
|
||||||
"session.stop": SessionStopMessage,
|
"session.stop": SessionStopMessage,
|
||||||
"input.text": InputTextMessage,
|
"input.text": InputTextMessage,
|
||||||
"response.cancel": ResponseCancelMessage,
|
"response.cancel": ResponseCancelMessage,
|
||||||
|
"tool_call.results": ToolCallResultsMessage,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ StreamEngine pattern.
|
|||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import AsyncIterator, Optional, List, Dict, Any
|
from typing import AsyncIterator, Optional, List, Dict, Any, Literal
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
@@ -52,6 +52,15 @@ class LLMMessage:
|
|||||||
return d
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class LLMStreamEvent:
    """One item produced by an LLM service's streaming generator.

    ``type`` says which kind of event this is; the optional fields carry
    the payload for that kind.
    """

    # Event discriminator.
    type: Literal["text_delta", "tool_call", "done"]
    # Text fragment, populated for "text_delta" events.
    text: Optional[str] = None
    # Tool call payload, populated for "tool_call" events.
    tool_call: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TTSChunk:
|
class TTSChunk:
|
||||||
"""TTS audio chunk."""
|
"""TTS audio chunk."""
|
||||||
@@ -170,7 +179,7 @@ class BaseLLMService(ABC):
|
|||||||
messages: List[LLMMessage],
|
messages: List[LLMMessage],
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
max_tokens: Optional[int] = None
|
max_tokens: Optional[int] = None
|
||||||
) -> AsyncIterator[str]:
|
) -> AsyncIterator[LLMStreamEvent]:
|
||||||
"""
|
"""
|
||||||
Generate response in streaming mode.
|
Generate response in streaming mode.
|
||||||
|
|
||||||
@@ -180,7 +189,7 @@ class BaseLLMService(ABC):
|
|||||||
max_tokens: Maximum tokens to generate
|
max_tokens: Maximum tokens to generate
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
Text chunks as they are generated
|
Stream events (text delta/tool call/done)
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -6,11 +6,12 @@ for real-time voice conversation.
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import uuid
|
||||||
from typing import AsyncIterator, Optional, List, Dict, Any
|
from typing import AsyncIterator, Optional, List, Dict, Any
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from app.backend_client import search_knowledge_context
|
from app.backend_client import search_knowledge_context
|
||||||
from services.base import BaseLLMService, LLMMessage, ServiceState
|
from services.base import BaseLLMService, LLMMessage, LLMStreamEvent, ServiceState
|
||||||
|
|
||||||
# Try to import openai
|
# Try to import openai
|
||||||
try:
|
try:
|
||||||
@@ -59,6 +60,7 @@ class OpenAILLMService(BaseLLMService):
|
|||||||
self.client: Optional[AsyncOpenAI] = None
|
self.client: Optional[AsyncOpenAI] = None
|
||||||
self._cancel_event = asyncio.Event()
|
self._cancel_event = asyncio.Event()
|
||||||
self._knowledge_config: Dict[str, Any] = knowledge_config or {}
|
self._knowledge_config: Dict[str, Any] = knowledge_config or {}
|
||||||
|
self._tool_schemas: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
_RAG_DEFAULT_RESULTS = 5
|
_RAG_DEFAULT_RESULTS = 5
|
||||||
_RAG_MAX_RESULTS = 8
|
_RAG_MAX_RESULTS = 8
|
||||||
@@ -106,6 +108,29 @@ class OpenAILLMService(BaseLLMService):
|
|||||||
"""Update runtime knowledge retrieval config."""
|
"""Update runtime knowledge retrieval config."""
|
||||||
self._knowledge_config = config or {}
|
self._knowledge_config = config or {}
|
||||||
|
|
||||||
|
def set_tool_schemas(self, schemas: Optional[List[Dict[str, Any]]]) -> None:
    """Replace the tool schemas stored in ``self._tool_schemas``.

    Accepted entry shapes:

    * ``{"function": {"name": ...}, ...}``: stored as a shallow copy, with
      ``"type"`` defaulted to ``"function"`` when missing or empty;
    * ``{"name": ..., "description": ..., "parameters": ...}``: wrapped
      into the function-tool form.

    Non-dict entries and entries without a name are skipped. Passing
    anything other than a list clears the stored schemas.

    Args:
        schemas: New schema list, or ``None`` to clear.
    """
    normalized: List[Dict[str, Any]] = []
    for item in schemas if isinstance(schemas, list) else ():
        if not isinstance(item, dict):
            continue
        fn = item.get("function")
        if isinstance(fn, dict) and fn.get("name"):
            # Copy so later edits to the caller's dict do not leak into the
            # stored schemas, and make sure the entry is tagged as a function.
            entry = dict(item)
            if not entry.get("type"):
                entry["type"] = "function"
            normalized.append(entry)
        elif item.get("name"):
            normalized.append(
                {
                    "type": "function",
                    "function": {
                        "name": str(item.get("name")),
                        "description": str(item.get("description") or ""),
                        "parameters": item.get("parameters") or {"type": "object", "properties": {}},
                    },
                }
            )
    self._tool_schemas = normalized
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _coerce_int(value: Any, default: int) -> int:
|
def _coerce_int(value: Any, default: int) -> int:
|
||||||
try:
|
try:
|
||||||
@@ -258,7 +283,7 @@ class OpenAILLMService(BaseLLMService):
|
|||||||
messages: List[LLMMessage],
|
messages: List[LLMMessage],
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
max_tokens: Optional[int] = None
|
max_tokens: Optional[int] = None
|
||||||
) -> AsyncIterator[str]:
|
) -> AsyncIterator[LLMStreamEvent]:
|
||||||
"""
|
"""
|
||||||
Generate response in streaming mode.
|
Generate response in streaming mode.
|
||||||
|
|
||||||
@@ -268,7 +293,7 @@ class OpenAILLMService(BaseLLMService):
|
|||||||
max_tokens: Maximum tokens to generate
|
max_tokens: Maximum tokens to generate
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
Text chunks as they are generated
|
Structured stream events
|
||||||
"""
|
"""
|
||||||
if not self.client:
|
if not self.client:
|
||||||
raise RuntimeError("LLM service not connected")
|
raise RuntimeError("LLM service not connected")
|
||||||
@@ -276,15 +301,21 @@ class OpenAILLMService(BaseLLMService):
|
|||||||
rag_messages = await self._with_knowledge_context(messages)
|
rag_messages = await self._with_knowledge_context(messages)
|
||||||
prepared = self._prepare_messages(rag_messages)
|
prepared = self._prepare_messages(rag_messages)
|
||||||
self._cancel_event.clear()
|
self._cancel_event.clear()
|
||||||
|
tool_accumulator: Dict[int, Dict[str, str]] = {}
|
||||||
|
openai_tools = self._tool_schemas or None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
stream = await self.client.chat.completions.create(
|
create_args: Dict[str, Any] = dict(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=prepared,
|
messages=prepared,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
stream=True
|
stream=True,
|
||||||
)
|
)
|
||||||
|
if openai_tools:
|
||||||
|
create_args["tools"] = openai_tools
|
||||||
|
create_args["tool_choice"] = "auto"
|
||||||
|
stream = await self.client.chat.completions.create(**create_args)
|
||||||
|
|
||||||
async for chunk in stream:
|
async for chunk in stream:
|
||||||
# Check for cancellation
|
# Check for cancellation
|
||||||
@@ -292,9 +323,60 @@ class OpenAILLMService(BaseLLMService):
|
|||||||
logger.info("LLM stream cancelled")
|
logger.info("LLM stream cancelled")
|
||||||
break
|
break
|
||||||
|
|
||||||
if chunk.choices and chunk.choices[0].delta.content:
|
if not chunk.choices:
|
||||||
content = chunk.choices[0].delta.content
|
continue
|
||||||
yield content
|
|
||||||
|
choice = chunk.choices[0]
|
||||||
|
delta = getattr(choice, "delta", None)
|
||||||
|
if delta and getattr(delta, "content", None):
|
||||||
|
content = delta.content
|
||||||
|
yield LLMStreamEvent(type="text_delta", text=content)
|
||||||
|
|
||||||
|
# OpenAI streams function calls via incremental tool_calls deltas.
|
||||||
|
tool_calls = getattr(delta, "tool_calls", None) if delta else None
|
||||||
|
if tool_calls:
|
||||||
|
for tc in tool_calls:
|
||||||
|
index = getattr(tc, "index", 0) or 0
|
||||||
|
item = tool_accumulator.setdefault(
|
||||||
|
int(index),
|
||||||
|
{"id": "", "name": "", "arguments": ""},
|
||||||
|
)
|
||||||
|
tc_id = getattr(tc, "id", None)
|
||||||
|
if tc_id:
|
||||||
|
item["id"] = str(tc_id)
|
||||||
|
fn = getattr(tc, "function", None)
|
||||||
|
if fn:
|
||||||
|
fn_name = getattr(fn, "name", None)
|
||||||
|
if fn_name:
|
||||||
|
item["name"] = str(fn_name)
|
||||||
|
fn_args = getattr(fn, "arguments", None)
|
||||||
|
if fn_args:
|
||||||
|
item["arguments"] += str(fn_args)
|
||||||
|
|
||||||
|
finish_reason = getattr(choice, "finish_reason", None)
|
||||||
|
if finish_reason == "tool_calls" and tool_accumulator:
|
||||||
|
for _, payload in sorted(tool_accumulator.items(), key=lambda row: row[0]):
|
||||||
|
call_name = payload.get("name", "").strip()
|
||||||
|
if not call_name:
|
||||||
|
continue
|
||||||
|
call_id = payload.get("id", "").strip() or f"call_{uuid.uuid4().hex[:10]}"
|
||||||
|
yield LLMStreamEvent(
|
||||||
|
type="tool_call",
|
||||||
|
tool_call={
|
||||||
|
"id": call_id,
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": call_name,
|
||||||
|
"arguments": payload.get("arguments", "") or "{}",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
yield LLMStreamEvent(type="done")
|
||||||
|
return
|
||||||
|
|
||||||
|
if finish_reason in {"stop", "length", "content_filter"}:
|
||||||
|
yield LLMStreamEvent(type="done")
|
||||||
|
return
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.info("LLM stream cancelled via asyncio")
|
logger.info("LLM stream cancelled via asyncio")
|
||||||
@@ -348,13 +430,14 @@ class MockLLMService(BaseLLMService):
|
|||||||
messages: List[LLMMessage],
|
messages: List[LLMMessage],
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
max_tokens: Optional[int] = None
|
max_tokens: Optional[int] = None
|
||||||
) -> AsyncIterator[str]:
|
) -> AsyncIterator[LLMStreamEvent]:
|
||||||
response = await self.generate(messages, temperature, max_tokens)
|
response = await self.generate(messages, temperature, max_tokens)
|
||||||
|
|
||||||
# Stream word by word
|
# Stream word by word
|
||||||
words = response.split()
|
words = response.split()
|
||||||
for i, word in enumerate(words):
|
for i, word in enumerate(words):
|
||||||
if i > 0:
|
if i > 0:
|
||||||
yield " "
|
yield LLMStreamEvent(type="text_delta", text=" ")
|
||||||
yield word
|
yield LLMStreamEvent(type="text_delta", text=word)
|
||||||
await asyncio.sleep(0.05) # Simulate streaming delay
|
await asyncio.sleep(0.05) # Simulate streaming delay
|
||||||
|
yield LLMStreamEvent(type="done")
|
||||||
|
|||||||
219
engine/tests/test_tool_call_flow.py
Normal file
219
engine/tests/test_tool_call_flow.py
Normal file
@@ -0,0 +1,219 @@
|
|||||||
|
import asyncio
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.duplex_pipeline import DuplexPipeline
|
||||||
|
from models.ws_v1 import ToolCallResultsMessage, parse_client_message
|
||||||
|
from services.base import LLMStreamEvent
|
||||||
|
|
||||||
|
|
||||||
|
class _DummySileroVAD:
    """Test double patched in for ``SileroVAD``; never reports speech."""

    def __init__(self, *args, **kwargs):
        """Accept and ignore any constructor arguments."""

    def process_audio(self, _pcm: bytes) -> float:
        """Return a speech probability of 0.0 for every input."""
        return 0.0
|
||||||
|
|
||||||
|
|
||||||
|
class _DummyVADProcessor:
    """Test double patched in for ``VADProcessor``; always reports silence."""

    def __init__(self, *args, **kwargs):
        """Accept and ignore any constructor arguments."""

    def process(self, _speech_prob: float):
        """Return the fixed ``("Silence", 0.0)`` pair regardless of input."""
        status, value = "Silence", 0.0
        return status, value
|
||||||
|
|
||||||
|
|
||||||
|
class _DummyEouDetector:
    """Test double patched in for ``core.duplex_pipeline.EouDetector``.

    Accepts and discards any constructor arguments.
    """

    def __init__(self, *args, **kwargs):
        """Ignore all constructor arguments; the stub holds no state."""

    def process(self, _vad_status: str) -> bool:
        """Return ``False`` for every status passed in."""
        detected = False
        return detected

    def reset(self) -> None:
        """Do nothing; there is no state to clear."""
        return
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeTransport:
    """Transport stand-in handed to ``DuplexPipeline``; both senders discard input."""

    async def send_event(self, _event: Dict[str, Any]) -> None:
        """Accept an event dict and drop it."""
        return

    async def send_audio(self, _audio: bytes) -> None:
        """Accept an audio payload and drop it."""
        return
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeTTS:
    """TTS stand-in whose stream produces no chunks at all."""

    async def synthesize_stream(self, _text: str):
        """Async generator that finishes immediately without yielding anything."""
        return
        # Unreachable on purpose: the presence of ``yield`` is what makes this
        # function an async generator rather than a plain coroutine.
        yield None
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeASR:
    """ASR stand-in handed to ``DuplexPipeline``; connecting is a no-op."""

    async def connect(self) -> None:
        """Complete immediately without opening any connection."""
        return
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeLLM:
    """Scripted LLM: each ``generate_stream`` run replays the next event list."""

    def __init__(self, rounds: List[List[LLMStreamEvent]]):
        """Store the scripted rounds and start at the first one.

        Args:
            rounds: One list of stream events per expected ``generate_stream`` run.
        """
        self._rounds = rounds
        self._call_index = 0

    async def generate_stream(self, _messages, temperature=0.7, max_tokens=None):
        """Yield the events scripted for the current round, then advance.

        The messages, temperature and max_tokens arguments are accepted but
        not used. Once every scripted round has been consumed, further runs
        yield a single ``done`` event.
        """
        position = self._call_index
        self._call_index = position + 1
        if position < len(self._rounds):
            scripted = self._rounds[position]
        else:
            # Script exhausted: end the stream right away.
            scripted = [LLMStreamEvent(type="done")]
        for item in scripted:
            yield item
|
||||||
|
|
||||||
|
|
||||||
|
def _build_pipeline(monkeypatch, llm_rounds: List[List[LLMStreamEvent]]) -> tuple[DuplexPipeline, List[Dict[str, Any]]]:
    """Create a ``DuplexPipeline`` wired to in-memory fakes plus an event recorder.

    The ``SileroVAD``, ``VADProcessor`` and ``EouDetector`` names in
    ``core.duplex_pipeline`` are replaced with this module's dummy classes
    before the pipeline is constructed. After construction, the pipeline's
    ``_send_event`` is replaced by a recorder and ``_speak_sentence`` by a
    no-op.

    Args:
        monkeypatch: pytest's ``monkeypatch`` fixture.
        llm_rounds: Scripted event lists passed to ``_FakeLLM``.

    Returns:
        The pipeline and the list that collects every event given to
        ``_send_event``.
    """
    stubbed_targets = (
        ("core.duplex_pipeline.SileroVAD", _DummySileroVAD),
        ("core.duplex_pipeline.VADProcessor", _DummyVADProcessor),
        ("core.duplex_pipeline.EouDetector", _DummyEouDetector),
    )
    for target, stub in stubbed_targets:
        monkeypatch.setattr(target, stub)

    fakes = {
        "transport": _FakeTransport(),
        "llm_service": _FakeLLM(llm_rounds),
        "tts_service": _FakeTTS(),
        "asr_service": _FakeASR(),
    }
    pipeline = DuplexPipeline(session_id="s_test", **fakes)

    recorded: List[Dict[str, Any]] = []

    async def _record(event: Dict[str, Any], priority: int = 20):
        # Keep the event for later assertions; the priority is ignored.
        recorded.append(event)

    async def _silent(_text: str, fade_in_ms: int = 0, fade_out_ms: int = 8):
        # Skip speech output entirely.
        return None

    monkeypatch.setattr(pipeline, "_send_event", _record)
    monkeypatch.setattr(pipeline, "_speak_sentence", _silent)
    return pipeline, recorded
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_ws_message_parses_tool_call_results():
    """A ``tool_call.results`` payload parses into ``ToolCallResultsMessage``."""
    single_result = {"tool_call_id": "call_1", "status": {"code": 200, "message": "ok"}}
    payload = {
        "type": "tool_call.results",
        "results": [single_result],
    }

    msg = parse_client_message(payload)

    assert isinstance(msg, ToolCallResultsMessage)
    first = msg.results[0]
    assert first["tool_call_id"] == "call_1"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_turn_without_tool_keeps_streaming(monkeypatch):
    """A text-only LLM round emits delta and final events and no tool call."""
    text_only_round = [
        LLMStreamEvent(type="text_delta", text="hello "),
        LLMStreamEvent(type="text_delta", text="world."),
        LLMStreamEvent(type="done"),
    ]
    pipeline, events = _build_pipeline(monkeypatch, [text_only_round])

    await pipeline._handle_turn("hi")

    seen_types = [recorded.get("type") for recorded in events]
    for expected in ("assistant.response.delta", "assistant.response.final"):
        assert expected in seen_types
    assert "assistant.tool_call" not in seen_types
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_turn_with_tool_call_then_results(monkeypatch):
    """A turn that requests a tool resumes once the client supplies its result.

    The scripted LLM's first round emits text plus a ``weather`` tool call;
    the second round emits the follow-up answer. The turn runs as a background
    task while this test waits for the ``assistant.tool_call`` event, delivers
    the tool result, and then checks the last final response.
    """
    pipeline, events = _build_pipeline(
        monkeypatch,
        [
            [
                LLMStreamEvent(type="text_delta", text="let me check."),
                LLMStreamEvent(
                    type="tool_call",
                    tool_call={
                        "id": "call_ok",
                        "type": "function",
                        "function": {"name": "weather", "arguments": "{\"city\":\"hz\"}"},
                    },
                ),
                LLMStreamEvent(type="done"),
            ],
            [
                LLMStreamEvent(type="text_delta", text="it's sunny."),
                LLMStreamEvent(type="done"),
            ],
        ],
    )

    task = asyncio.create_task(pipeline._handle_turn("weather?"))
    # Poll (up to 200 x 5 ms) until the pipeline announces the tool call.
    for _ in range(200):
        if any(e.get("type") == "assistant.tool_call" for e in events):
            break
        await asyncio.sleep(0.005)
    else:
        # The event never arrived: stop the background turn and fail here
        # rather than delivering results for a call that was never observed.
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
        pytest.fail("assistant.tool_call was not emitted before the polling deadline")

    await pipeline.handle_tool_call_results(
        [
            {
                "tool_call_id": "call_ok",
                "name": "weather",
                "output": {"temp": 21},
                "status": {"code": 200, "message": "ok"},
            }
        ]
    )
    await task

    assert any(e.get("type") == "assistant.tool_call" for e in events)
    finals = [e for e in events if e.get("type") == "assistant.response.final"]
    assert finals
    assert "it's sunny" in finals[-1].get("text", "")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_turn_with_tool_call_timeout(monkeypatch):
    """With no tool result delivered, the turn still ends with the second round's text.

    The tool-wait timeout on the pipeline instance is lowered to 0.01 s so the
    test does not sit through the default wait.
    """
    search_call = LLMStreamEvent(
        type="tool_call",
        tool_call={
            "id": "call_timeout",
            "type": "function",
            "function": {"name": "search", "arguments": "{\"query\":\"x\"}"},
        },
    )
    tool_round = [search_call, LLMStreamEvent(type="done")]
    fallback_round = [
        LLMStreamEvent(type="text_delta", text="fallback answer."),
        LLMStreamEvent(type="done"),
    ]
    pipeline, events = _build_pipeline(monkeypatch, [tool_round, fallback_round])
    pipeline._TOOL_WAIT_TIMEOUT_SECONDS = 0.01

    await pipeline._handle_turn("query")

    finals = [item for item in events if item.get("type") == "assistant.response.final"]
    assert finals
    last_final = finals[-1]
    assert "fallback answer" in last_final.get("text", "")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_duplicate_tool_results_are_ignored(monkeypatch):
    """The first result submitted for a tool_call_id wins; a later one is dropped."""
    pipeline, _events = _build_pipeline(monkeypatch, [[LLMStreamEvent(type="done")]])

    # Submit two results for the same id, differing only in the output value.
    for submitted_value in (1, 2):
        await pipeline.handle_tool_call_results(
            [
                {
                    "tool_call_id": "call_dup",
                    "output": {"value": submitted_value},
                    "status": {"code": 200, "message": "ok"},
                }
            ]
        )

    result = await pipeline._wait_for_single_tool_result("call_dup")

    output = result.get("output", {})
    assert output.get("value") == 1
|
||||||
Reference in New Issue
Block a user