Update code for tool call
This commit is contained in:
@@ -12,8 +12,9 @@ event-driven design.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
@@ -26,7 +27,7 @@ from models.ws_v1 import ev
|
||||
from processors.eou import EouDetector
|
||||
from processors.vad import SileroVAD, VADProcessor
|
||||
from services.asr import BufferedASRService
|
||||
from services.base import BaseASRService, BaseLLMService, BaseTTSService
|
||||
from services.base import BaseASRService, BaseLLMService, BaseTTSService, LLMMessage, LLMStreamEvent
|
||||
from services.llm import MockLLMService, OpenAILLMService
|
||||
from services.siliconflow_asr import SiliconFlowASRService
|
||||
from services.siliconflow_tts import SiliconFlowTTSService
|
||||
@@ -55,6 +56,60 @@ class DuplexPipeline:
|
||||
_SENTENCE_TRAILING_CHARS = frozenset({"。", "!", "?", ".", "!", "?", "…", "~", "~", "\n"})
|
||||
_SENTENCE_CLOSERS = frozenset({'"', "'", "”", "’", ")", "]", "}", ")", "】", "」", "』", "》"})
|
||||
_MIN_SPLIT_SPOKEN_CHARS = 6
|
||||
_TOOL_WAIT_TIMEOUT_SECONDS = 15.0
|
||||
_DEFAULT_TOOL_SCHEMAS: Dict[str, Dict[str, Any]] = {
|
||||
"search": {
|
||||
"name": "search",
|
||||
"description": "Search the internet for recent information",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"query": {"type": "string"}},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
"calculator": {
|
||||
"name": "calculator",
|
||||
"description": "Evaluate a math expression",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"expression": {"type": "string"}},
|
||||
"required": ["expression"],
|
||||
},
|
||||
},
|
||||
"weather": {
|
||||
"name": "weather",
|
||||
"description": "Get weather by city name",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"city": {"type": "string"}},
|
||||
"required": ["city"],
|
||||
},
|
||||
},
|
||||
"translate": {
|
||||
"name": "translate",
|
||||
"description": "Translate text to target language",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {"type": "string"},
|
||||
"target_lang": {"type": "string"},
|
||||
},
|
||||
"required": ["text", "target_lang"],
|
||||
},
|
||||
},
|
||||
"knowledge": {
|
||||
"name": "knowledge",
|
||||
"description": "Query knowledge base by question",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"},
|
||||
"kb_id": {"type": "string"},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -158,6 +213,10 @@ class DuplexPipeline:
|
||||
self._runtime_greeting: Optional[str] = None
|
||||
self._runtime_knowledge: Dict[str, Any] = {}
|
||||
self._runtime_knowledge_base_id: Optional[str] = None
|
||||
self._runtime_tools: List[Any] = []
|
||||
self._pending_tool_waiters: Dict[str, asyncio.Future] = {}
|
||||
self._early_tool_results: Dict[str, Dict[str, Any]] = {}
|
||||
self._completed_tool_call_ids: set[str] = set()
|
||||
|
||||
logger.info(f"DuplexPipeline initialized for session {session_id}")
|
||||
|
||||
@@ -208,8 +267,14 @@ class DuplexPipeline:
|
||||
if kb_id:
|
||||
self._runtime_knowledge_base_id = kb_id
|
||||
|
||||
tools_payload = metadata.get("tools")
|
||||
if isinstance(tools_payload, list):
|
||||
self._runtime_tools = tools_payload
|
||||
|
||||
if self.llm_service and hasattr(self.llm_service, "set_knowledge_config"):
|
||||
self.llm_service.set_knowledge_config(self._resolved_knowledge_config())
|
||||
if self.llm_service and hasattr(self.llm_service, "set_tool_schemas"):
|
||||
self.llm_service.set_tool_schemas(self._resolved_tool_schemas())
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the pipeline and connect services."""
|
||||
@@ -234,6 +299,8 @@ class DuplexPipeline:
|
||||
|
||||
if hasattr(self.llm_service, "set_knowledge_config"):
|
||||
self.llm_service.set_knowledge_config(self._resolved_knowledge_config())
|
||||
if hasattr(self.llm_service, "set_tool_schemas"):
|
||||
self.llm_service.set_tool_schemas(self._resolved_tool_schemas())
|
||||
|
||||
await self.llm_service.connect()
|
||||
|
||||
@@ -585,6 +652,117 @@ class DuplexPipeline:
|
||||
cfg.setdefault("enabled", True)
|
||||
return cfg
|
||||
|
||||
def _resolved_tool_schemas(self) -> List[Dict[str, Any]]:
|
||||
schemas: List[Dict[str, Any]] = []
|
||||
for item in self._runtime_tools:
|
||||
if isinstance(item, str):
|
||||
base = self._DEFAULT_TOOL_SCHEMAS.get(item)
|
||||
if base:
|
||||
schemas.append(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": base["name"],
|
||||
"description": base.get("description") or "",
|
||||
"parameters": base.get("parameters") or {"type": "object", "properties": {}},
|
||||
},
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
fn = item.get("function")
|
||||
if isinstance(fn, dict) and fn.get("name"):
|
||||
schemas.append(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": str(fn.get("name")),
|
||||
"description": str(fn.get("description") or item.get("description") or ""),
|
||||
"parameters": fn.get("parameters") or {"type": "object", "properties": {}},
|
||||
},
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
if item.get("name"):
|
||||
schemas.append(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": str(item.get("name")),
|
||||
"description": str(item.get("description") or ""),
|
||||
"parameters": item.get("parameters") or {"type": "object", "properties": {}},
|
||||
},
|
||||
}
|
||||
)
|
||||
return schemas
|
||||
|
||||
async def handle_tool_call_results(self, results: List[Dict[str, Any]]) -> None:
|
||||
"""Handle client tool execution results."""
|
||||
if not isinstance(results, list):
|
||||
return
|
||||
|
||||
for item in results:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
call_id = str(item.get("tool_call_id") or item.get("id") or "").strip()
|
||||
if not call_id:
|
||||
continue
|
||||
if call_id in self._completed_tool_call_ids:
|
||||
continue
|
||||
|
||||
waiter = self._pending_tool_waiters.get(call_id)
|
||||
if waiter and not waiter.done():
|
||||
waiter.set_result(item)
|
||||
self._completed_tool_call_ids.add(call_id)
|
||||
continue
|
||||
self._early_tool_results[call_id] = item
|
||||
self._completed_tool_call_ids.add(call_id)
|
||||
|
||||
async def _wait_for_single_tool_result(self, call_id: str) -> Dict[str, Any]:
|
||||
if call_id in self._completed_tool_call_ids and call_id not in self._early_tool_results:
|
||||
return {
|
||||
"tool_call_id": call_id,
|
||||
"status": {"code": 208, "message": "tool_call result already handled"},
|
||||
"output": "",
|
||||
}
|
||||
if call_id in self._early_tool_results:
|
||||
self._completed_tool_call_ids.add(call_id)
|
||||
return self._early_tool_results.pop(call_id)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
future = loop.create_future()
|
||||
self._pending_tool_waiters[call_id] = future
|
||||
try:
|
||||
return await asyncio.wait_for(future, timeout=self._TOOL_WAIT_TIMEOUT_SECONDS)
|
||||
except asyncio.TimeoutError:
|
||||
self._completed_tool_call_ids.add(call_id)
|
||||
return {
|
||||
"tool_call_id": call_id,
|
||||
"status": {"code": 504, "message": "tool_call timeout"},
|
||||
"output": "",
|
||||
}
|
||||
finally:
|
||||
self._pending_tool_waiters.pop(call_id, None)
|
||||
|
||||
def _normalize_stream_event(self, item: Any) -> LLMStreamEvent:
|
||||
if isinstance(item, LLMStreamEvent):
|
||||
return item
|
||||
if isinstance(item, str):
|
||||
return LLMStreamEvent(type="text_delta", text=item)
|
||||
if isinstance(item, dict):
|
||||
event_type = str(item.get("type") or "")
|
||||
if event_type in {"text_delta", "tool_call", "done"}:
|
||||
return LLMStreamEvent(
|
||||
type=event_type, # type: ignore[arg-type]
|
||||
text=item.get("text"),
|
||||
tool_call=item.get("tool_call"),
|
||||
)
|
||||
return LLMStreamEvent(type="done")
|
||||
|
||||
async def _handle_turn(self, user_text: str) -> None:
|
||||
"""
|
||||
Handle a complete conversation turn.
|
||||
@@ -599,109 +777,176 @@ class DuplexPipeline:
|
||||
self._turn_start_time = time.time()
|
||||
self._first_audio_sent = False
|
||||
|
||||
# Get AI response (streaming)
|
||||
messages = self.conversation.get_messages()
|
||||
full_response = ""
|
||||
messages = self.conversation.get_messages()
|
||||
max_rounds = 3
|
||||
|
||||
await self.conversation.start_assistant_turn()
|
||||
self._is_bot_speaking = True
|
||||
self._interrupt_event.clear()
|
||||
self._drop_outbound_audio = False
|
||||
|
||||
# Sentence buffer for streaming TTS
|
||||
sentence_buffer = ""
|
||||
pending_punctuation = ""
|
||||
first_audio_sent = False
|
||||
spoken_sentence_count = 0
|
||||
|
||||
# Stream LLM response and TTS sentence by sentence
|
||||
async for text_chunk in self.llm_service.generate_stream(messages):
|
||||
for _ in range(max_rounds):
|
||||
if self._interrupt_event.is_set():
|
||||
break
|
||||
|
||||
full_response += text_chunk
|
||||
sentence_buffer += text_chunk
|
||||
await self.conversation.update_assistant_text(text_chunk)
|
||||
sentence_buffer = ""
|
||||
pending_punctuation = ""
|
||||
round_response = ""
|
||||
tool_calls: List[Dict[str, Any]] = []
|
||||
allow_text_output = True
|
||||
|
||||
# Send LLM response streaming event to client
|
||||
await self._send_event({
|
||||
**ev(
|
||||
"assistant.response.delta",
|
||||
trackId=self.session_id,
|
||||
text=text_chunk,
|
||||
)
|
||||
}, priority=40)
|
||||
|
||||
# Check for sentence completion - synthesize immediately for low latency
|
||||
while True:
|
||||
split_result = extract_tts_sentence(
|
||||
sentence_buffer,
|
||||
end_chars=self._SENTENCE_END_CHARS,
|
||||
trailing_chars=self._SENTENCE_TRAILING_CHARS,
|
||||
closers=self._SENTENCE_CLOSERS,
|
||||
min_split_spoken_chars=self._MIN_SPLIT_SPOKEN_CHARS,
|
||||
hold_trailing_at_buffer_end=True,
|
||||
force=False,
|
||||
)
|
||||
if not split_result:
|
||||
async for raw_event in self.llm_service.generate_stream(messages):
|
||||
if self._interrupt_event.is_set():
|
||||
break
|
||||
sentence, sentence_buffer = split_result
|
||||
if not sentence:
|
||||
|
||||
event = self._normalize_stream_event(raw_event)
|
||||
if event.type == "tool_call":
|
||||
tool_call = event.tool_call if isinstance(event.tool_call, dict) else None
|
||||
if not tool_call:
|
||||
continue
|
||||
allow_text_output = False
|
||||
tool_calls.append(tool_call)
|
||||
await self._send_event(
|
||||
{
|
||||
**ev(
|
||||
"assistant.tool_call",
|
||||
trackId=self.session_id,
|
||||
tool_call=tool_call,
|
||||
)
|
||||
},
|
||||
priority=22,
|
||||
)
|
||||
continue
|
||||
|
||||
sentence = f"{pending_punctuation}{sentence}".strip()
|
||||
pending_punctuation = ""
|
||||
if not sentence:
|
||||
if event.type != "text_delta":
|
||||
continue
|
||||
|
||||
# Avoid synthesizing punctuation-only fragments (e.g. standalone "!")
|
||||
if not has_spoken_content(sentence):
|
||||
pending_punctuation = sentence
|
||||
text_chunk = event.text or ""
|
||||
if not text_chunk:
|
||||
continue
|
||||
|
||||
if not self._interrupt_event.is_set():
|
||||
# Send track start on first audio
|
||||
if not first_audio_sent:
|
||||
await self._send_event({
|
||||
if not allow_text_output:
|
||||
continue
|
||||
|
||||
full_response += text_chunk
|
||||
round_response += text_chunk
|
||||
sentence_buffer += text_chunk
|
||||
await self.conversation.update_assistant_text(text_chunk)
|
||||
|
||||
await self._send_event(
|
||||
{
|
||||
**ev(
|
||||
"assistant.response.delta",
|
||||
trackId=self.session_id,
|
||||
text=text_chunk,
|
||||
)
|
||||
},
|
||||
priority=40,
|
||||
)
|
||||
|
||||
while True:
|
||||
split_result = extract_tts_sentence(
|
||||
sentence_buffer,
|
||||
end_chars=self._SENTENCE_END_CHARS,
|
||||
trailing_chars=self._SENTENCE_TRAILING_CHARS,
|
||||
closers=self._SENTENCE_CLOSERS,
|
||||
min_split_spoken_chars=self._MIN_SPLIT_SPOKEN_CHARS,
|
||||
hold_trailing_at_buffer_end=True,
|
||||
force=False,
|
||||
)
|
||||
if not split_result:
|
||||
break
|
||||
sentence, sentence_buffer = split_result
|
||||
if not sentence:
|
||||
continue
|
||||
|
||||
sentence = f"{pending_punctuation}{sentence}".strip()
|
||||
pending_punctuation = ""
|
||||
if not sentence:
|
||||
continue
|
||||
|
||||
if not has_spoken_content(sentence):
|
||||
pending_punctuation = sentence
|
||||
continue
|
||||
|
||||
if not self._interrupt_event.is_set():
|
||||
if not first_audio_sent:
|
||||
await self._send_event(
|
||||
{
|
||||
**ev(
|
||||
"output.audio.start",
|
||||
trackId=self.session_id,
|
||||
)
|
||||
},
|
||||
priority=10,
|
||||
)
|
||||
first_audio_sent = True
|
||||
|
||||
await self._speak_sentence(
|
||||
sentence,
|
||||
fade_in_ms=0,
|
||||
fade_out_ms=8,
|
||||
)
|
||||
|
||||
remaining_text = f"{pending_punctuation}{sentence_buffer}".strip()
|
||||
if remaining_text and has_spoken_content(remaining_text) and not self._interrupt_event.is_set():
|
||||
if not first_audio_sent:
|
||||
await self._send_event(
|
||||
{
|
||||
**ev(
|
||||
"output.audio.start",
|
||||
trackId=self.session_id,
|
||||
)
|
||||
}, priority=10)
|
||||
first_audio_sent = True
|
||||
|
||||
await self._speak_sentence(
|
||||
sentence,
|
||||
fade_in_ms=0,
|
||||
fade_out_ms=8,
|
||||
},
|
||||
priority=10,
|
||||
)
|
||||
spoken_sentence_count += 1
|
||||
|
||||
# Send final LLM response event
|
||||
if full_response and not self._interrupt_event.is_set():
|
||||
await self._send_event({
|
||||
**ev(
|
||||
"assistant.response.final",
|
||||
trackId=self.session_id,
|
||||
text=full_response,
|
||||
first_audio_sent = True
|
||||
await self._speak_sentence(
|
||||
remaining_text,
|
||||
fade_in_ms=0,
|
||||
fade_out_ms=8,
|
||||
)
|
||||
}, priority=20)
|
||||
|
||||
# Speak any remaining text
|
||||
remaining_text = f"{pending_punctuation}{sentence_buffer}".strip()
|
||||
if remaining_text and has_spoken_content(remaining_text) and not self._interrupt_event.is_set():
|
||||
if not first_audio_sent:
|
||||
await self._send_event({
|
||||
if not tool_calls:
|
||||
break
|
||||
|
||||
tool_results: List[Dict[str, Any]] = []
|
||||
for call in tool_calls:
|
||||
call_id = str(call.get("id") or "").strip()
|
||||
if not call_id:
|
||||
continue
|
||||
tool_results.append(await self._wait_for_single_tool_result(call_id))
|
||||
|
||||
messages = [
|
||||
*messages,
|
||||
LLMMessage(
|
||||
role="assistant",
|
||||
content=round_response.strip(),
|
||||
),
|
||||
LLMMessage(
|
||||
role="system",
|
||||
content=(
|
||||
"Tool execution results were returned by the client. "
|
||||
"Continue answering the user naturally using these results. "
|
||||
"Do not request the same tool again in this turn.\n"
|
||||
f"tool_calls={json.dumps(tool_calls, ensure_ascii=False)}\n"
|
||||
f"tool_results={json.dumps(tool_results, ensure_ascii=False)}"
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
if full_response and not self._interrupt_event.is_set():
|
||||
await self._send_event(
|
||||
{
|
||||
**ev(
|
||||
"output.audio.start",
|
||||
"assistant.response.final",
|
||||
trackId=self.session_id,
|
||||
text=full_response,
|
||||
)
|
||||
}, priority=10)
|
||||
first_audio_sent = True
|
||||
await self._speak_sentence(
|
||||
remaining_text,
|
||||
fade_in_ms=0,
|
||||
fade_out_ms=8,
|
||||
},
|
||||
priority=20,
|
||||
)
|
||||
|
||||
# Send track end
|
||||
|
||||
@@ -28,6 +28,7 @@ from models.ws_v1 import (
|
||||
SessionStopMessage,
|
||||
InputTextMessage,
|
||||
ResponseCancelMessage,
|
||||
ToolCallResultsMessage,
|
||||
)
|
||||
|
||||
|
||||
@@ -174,6 +175,8 @@ class Session:
|
||||
logger.info(f"Session {self.id} graceful response.cancel")
|
||||
else:
|
||||
await self.pipeline.interrupt()
|
||||
elif isinstance(message, ToolCallResultsMessage):
|
||||
await self.pipeline.handle_tool_call_results(message.results)
|
||||
elif isinstance(message, SessionStopMessage):
|
||||
await self._handle_session_stop(message.reason)
|
||||
else:
|
||||
|
||||
@@ -110,6 +110,24 @@ Rules:
|
||||
}
|
||||
```
|
||||
|
||||
### `tool_call.results`
|
||||
|
||||
Client tool execution results returned to server.
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "tool_call.results",
|
||||
"results": [
|
||||
{
|
||||
"tool_call_id": "call_abc123",
|
||||
"name": "weather",
|
||||
"output": { "temp_c": 21, "condition": "sunny" },
|
||||
"status": { "code": 200, "message": "ok" }
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
## Server -> Client Events
|
||||
|
||||
All server events include:
|
||||
@@ -142,6 +160,8 @@ Common events:
|
||||
- Fields: `trackId`, `text`
|
||||
- `assistant.response.final`
|
||||
- Fields: `trackId`, `text`
|
||||
- `assistant.tool_call`
|
||||
- Fields: `trackId`, `tool_call`
|
||||
- `output.audio.start`
|
||||
- Fields: `trackId`
|
||||
- `output.audio.end`
|
||||
|
||||
@@ -39,12 +39,18 @@ class ResponseCancelMessage(BaseModel):
|
||||
graceful: bool = False
|
||||
|
||||
|
||||
class ToolCallResultsMessage(BaseModel):
|
||||
type: Literal["tool_call.results"]
|
||||
results: list[Dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
CLIENT_MESSAGE_TYPES = {
|
||||
"hello": HelloMessage,
|
||||
"session.start": SessionStartMessage,
|
||||
"session.stop": SessionStopMessage,
|
||||
"input.text": InputTextMessage,
|
||||
"response.cancel": ResponseCancelMessage,
|
||||
"tool_call.results": ToolCallResultsMessage,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ StreamEngine pattern.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import AsyncIterator, Optional, List, Dict, Any
|
||||
from typing import AsyncIterator, Optional, List, Dict, Any, Literal
|
||||
from enum import Enum
|
||||
|
||||
|
||||
@@ -52,6 +52,15 @@ class LLMMessage:
|
||||
return d
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMStreamEvent:
|
||||
"""Structured LLM stream event."""
|
||||
|
||||
type: Literal["text_delta", "tool_call", "done"]
|
||||
text: Optional[str] = None
|
||||
tool_call: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TTSChunk:
|
||||
"""TTS audio chunk."""
|
||||
@@ -170,7 +179,7 @@ class BaseLLMService(ABC):
|
||||
messages: List[LLMMessage],
|
||||
temperature: float = 0.7,
|
||||
max_tokens: Optional[int] = None
|
||||
) -> AsyncIterator[str]:
|
||||
) -> AsyncIterator[LLMStreamEvent]:
|
||||
"""
|
||||
Generate response in streaming mode.
|
||||
|
||||
@@ -180,7 +189,7 @@ class BaseLLMService(ABC):
|
||||
max_tokens: Maximum tokens to generate
|
||||
|
||||
Yields:
|
||||
Text chunks as they are generated
|
||||
Stream events (text delta/tool call/done)
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@@ -6,11 +6,12 @@ for real-time voice conversation.
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import AsyncIterator, Optional, List, Dict, Any
|
||||
from loguru import logger
|
||||
|
||||
from app.backend_client import search_knowledge_context
|
||||
from services.base import BaseLLMService, LLMMessage, ServiceState
|
||||
from services.base import BaseLLMService, LLMMessage, LLMStreamEvent, ServiceState
|
||||
|
||||
# Try to import openai
|
||||
try:
|
||||
@@ -59,6 +60,7 @@ class OpenAILLMService(BaseLLMService):
|
||||
self.client: Optional[AsyncOpenAI] = None
|
||||
self._cancel_event = asyncio.Event()
|
||||
self._knowledge_config: Dict[str, Any] = knowledge_config or {}
|
||||
self._tool_schemas: List[Dict[str, Any]] = []
|
||||
|
||||
_RAG_DEFAULT_RESULTS = 5
|
||||
_RAG_MAX_RESULTS = 8
|
||||
@@ -106,6 +108,29 @@ class OpenAILLMService(BaseLLMService):
|
||||
"""Update runtime knowledge retrieval config."""
|
||||
self._knowledge_config = config or {}
|
||||
|
||||
def set_tool_schemas(self, schemas: Optional[List[Dict[str, Any]]]) -> None:
|
||||
"""Update runtime tool schemas."""
|
||||
self._tool_schemas = []
|
||||
if not isinstance(schemas, list):
|
||||
return
|
||||
for item in schemas:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
fn = item.get("function")
|
||||
if isinstance(fn, dict) and fn.get("name"):
|
||||
self._tool_schemas.append(item)
|
||||
elif item.get("name"):
|
||||
self._tool_schemas.append(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": str(item.get("name")),
|
||||
"description": str(item.get("description") or ""),
|
||||
"parameters": item.get("parameters") or {"type": "object", "properties": {}},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _coerce_int(value: Any, default: int) -> int:
|
||||
try:
|
||||
@@ -258,7 +283,7 @@ class OpenAILLMService(BaseLLMService):
|
||||
messages: List[LLMMessage],
|
||||
temperature: float = 0.7,
|
||||
max_tokens: Optional[int] = None
|
||||
) -> AsyncIterator[str]:
|
||||
) -> AsyncIterator[LLMStreamEvent]:
|
||||
"""
|
||||
Generate response in streaming mode.
|
||||
|
||||
@@ -268,7 +293,7 @@ class OpenAILLMService(BaseLLMService):
|
||||
max_tokens: Maximum tokens to generate
|
||||
|
||||
Yields:
|
||||
Text chunks as they are generated
|
||||
Structured stream events
|
||||
"""
|
||||
if not self.client:
|
||||
raise RuntimeError("LLM service not connected")
|
||||
@@ -276,25 +301,82 @@ class OpenAILLMService(BaseLLMService):
|
||||
rag_messages = await self._with_knowledge_context(messages)
|
||||
prepared = self._prepare_messages(rag_messages)
|
||||
self._cancel_event.clear()
|
||||
tool_accumulator: Dict[int, Dict[str, str]] = {}
|
||||
openai_tools = self._tool_schemas or None
|
||||
|
||||
try:
|
||||
stream = await self.client.chat.completions.create(
|
||||
create_args: Dict[str, Any] = dict(
|
||||
model=self.model,
|
||||
messages=prepared,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
stream=True
|
||||
stream=True,
|
||||
)
|
||||
if openai_tools:
|
||||
create_args["tools"] = openai_tools
|
||||
create_args["tool_choice"] = "auto"
|
||||
stream = await self.client.chat.completions.create(**create_args)
|
||||
|
||||
async for chunk in stream:
|
||||
# Check for cancellation
|
||||
if self._cancel_event.is_set():
|
||||
logger.info("LLM stream cancelled")
|
||||
break
|
||||
|
||||
if chunk.choices and chunk.choices[0].delta.content:
|
||||
content = chunk.choices[0].delta.content
|
||||
yield content
|
||||
|
||||
if not chunk.choices:
|
||||
continue
|
||||
|
||||
choice = chunk.choices[0]
|
||||
delta = getattr(choice, "delta", None)
|
||||
if delta and getattr(delta, "content", None):
|
||||
content = delta.content
|
||||
yield LLMStreamEvent(type="text_delta", text=content)
|
||||
|
||||
# OpenAI streams function calls via incremental tool_calls deltas.
|
||||
tool_calls = getattr(delta, "tool_calls", None) if delta else None
|
||||
if tool_calls:
|
||||
for tc in tool_calls:
|
||||
index = getattr(tc, "index", 0) or 0
|
||||
item = tool_accumulator.setdefault(
|
||||
int(index),
|
||||
{"id": "", "name": "", "arguments": ""},
|
||||
)
|
||||
tc_id = getattr(tc, "id", None)
|
||||
if tc_id:
|
||||
item["id"] = str(tc_id)
|
||||
fn = getattr(tc, "function", None)
|
||||
if fn:
|
||||
fn_name = getattr(fn, "name", None)
|
||||
if fn_name:
|
||||
item["name"] = str(fn_name)
|
||||
fn_args = getattr(fn, "arguments", None)
|
||||
if fn_args:
|
||||
item["arguments"] += str(fn_args)
|
||||
|
||||
finish_reason = getattr(choice, "finish_reason", None)
|
||||
if finish_reason == "tool_calls" and tool_accumulator:
|
||||
for _, payload in sorted(tool_accumulator.items(), key=lambda row: row[0]):
|
||||
call_name = payload.get("name", "").strip()
|
||||
if not call_name:
|
||||
continue
|
||||
call_id = payload.get("id", "").strip() or f"call_{uuid.uuid4().hex[:10]}"
|
||||
yield LLMStreamEvent(
|
||||
type="tool_call",
|
||||
tool_call={
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": call_name,
|
||||
"arguments": payload.get("arguments", "") or "{}",
|
||||
},
|
||||
},
|
||||
)
|
||||
yield LLMStreamEvent(type="done")
|
||||
return
|
||||
|
||||
if finish_reason in {"stop", "length", "content_filter"}:
|
||||
yield LLMStreamEvent(type="done")
|
||||
return
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("LLM stream cancelled via asyncio")
|
||||
@@ -348,13 +430,14 @@ class MockLLMService(BaseLLMService):
|
||||
messages: List[LLMMessage],
|
||||
temperature: float = 0.7,
|
||||
max_tokens: Optional[int] = None
|
||||
) -> AsyncIterator[str]:
|
||||
) -> AsyncIterator[LLMStreamEvent]:
|
||||
response = await self.generate(messages, temperature, max_tokens)
|
||||
|
||||
# Stream word by word
|
||||
words = response.split()
|
||||
for i, word in enumerate(words):
|
||||
if i > 0:
|
||||
yield " "
|
||||
yield word
|
||||
yield LLMStreamEvent(type="text_delta", text=" ")
|
||||
yield LLMStreamEvent(type="text_delta", text=word)
|
||||
await asyncio.sleep(0.05) # Simulate streaming delay
|
||||
yield LLMStreamEvent(type="done")
|
||||
|
||||
219
engine/tests/test_tool_call_flow.py
Normal file
219
engine/tests/test_tool_call_flow.py
Normal file
@@ -0,0 +1,219 @@
|
||||
import asyncio
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pytest
|
||||
|
||||
from core.duplex_pipeline import DuplexPipeline
|
||||
from models.ws_v1 import ToolCallResultsMessage, parse_client_message
|
||||
from services.base import LLMStreamEvent
|
||||
|
||||
|
||||
class _DummySileroVAD:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def process_audio(self, _pcm: bytes) -> float:
|
||||
return 0.0
|
||||
|
||||
|
||||
class _DummyVADProcessor:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def process(self, _speech_prob: float):
|
||||
return "Silence", 0.0
|
||||
|
||||
|
||||
class _DummyEouDetector:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def process(self, _vad_status: str) -> bool:
|
||||
return False
|
||||
|
||||
def reset(self) -> None:
|
||||
return None
|
||||
|
||||
|
||||
class _FakeTransport:
|
||||
async def send_event(self, _event: Dict[str, Any]) -> None:
|
||||
return None
|
||||
|
||||
async def send_audio(self, _audio: bytes) -> None:
|
||||
return None
|
||||
|
||||
|
||||
class _FakeTTS:
|
||||
async def synthesize_stream(self, _text: str):
|
||||
if False:
|
||||
yield None
|
||||
|
||||
|
||||
class _FakeASR:
|
||||
async def connect(self) -> None:
|
||||
return None
|
||||
|
||||
|
||||
class _FakeLLM:
|
||||
def __init__(self, rounds: List[List[LLMStreamEvent]]):
|
||||
self._rounds = rounds
|
||||
self._call_index = 0
|
||||
|
||||
async def generate_stream(self, _messages, temperature=0.7, max_tokens=None):
|
||||
idx = self._call_index
|
||||
self._call_index += 1
|
||||
events = self._rounds[idx] if idx < len(self._rounds) else [LLMStreamEvent(type="done")]
|
||||
for event in events:
|
||||
yield event
|
||||
|
||||
|
||||
def _build_pipeline(monkeypatch, llm_rounds: List[List[LLMStreamEvent]]) -> tuple[DuplexPipeline, List[Dict[str, Any]]]:
|
||||
monkeypatch.setattr("core.duplex_pipeline.SileroVAD", _DummySileroVAD)
|
||||
monkeypatch.setattr("core.duplex_pipeline.VADProcessor", _DummyVADProcessor)
|
||||
monkeypatch.setattr("core.duplex_pipeline.EouDetector", _DummyEouDetector)
|
||||
|
||||
pipeline = DuplexPipeline(
|
||||
transport=_FakeTransport(),
|
||||
session_id="s_test",
|
||||
llm_service=_FakeLLM(llm_rounds),
|
||||
tts_service=_FakeTTS(),
|
||||
asr_service=_FakeASR(),
|
||||
)
|
||||
events: List[Dict[str, Any]] = []
|
||||
|
||||
async def _capture_event(event: Dict[str, Any], priority: int = 20):
|
||||
events.append(event)
|
||||
|
||||
async def _noop_speak(_text: str, fade_in_ms: int = 0, fade_out_ms: int = 8):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(pipeline, "_send_event", _capture_event)
|
||||
monkeypatch.setattr(pipeline, "_speak_sentence", _noop_speak)
|
||||
return pipeline, events
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ws_message_parses_tool_call_results():
|
||||
msg = parse_client_message(
|
||||
{
|
||||
"type": "tool_call.results",
|
||||
"results": [{"tool_call_id": "call_1", "status": {"code": 200, "message": "ok"}}],
|
||||
}
|
||||
)
|
||||
assert isinstance(msg, ToolCallResultsMessage)
|
||||
assert msg.results[0]["tool_call_id"] == "call_1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_turn_without_tool_keeps_streaming(monkeypatch):
|
||||
pipeline, events = _build_pipeline(
|
||||
monkeypatch,
|
||||
[
|
||||
[
|
||||
LLMStreamEvent(type="text_delta", text="hello "),
|
||||
LLMStreamEvent(type="text_delta", text="world."),
|
||||
LLMStreamEvent(type="done"),
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
await pipeline._handle_turn("hi")
|
||||
|
||||
event_types = [e.get("type") for e in events]
|
||||
assert "assistant.response.delta" in event_types
|
||||
assert "assistant.response.final" in event_types
|
||||
assert "assistant.tool_call" not in event_types
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_turn_with_tool_call_then_results(monkeypatch):
|
||||
pipeline, events = _build_pipeline(
|
||||
monkeypatch,
|
||||
[
|
||||
[
|
||||
LLMStreamEvent(type="text_delta", text="let me check."),
|
||||
LLMStreamEvent(
|
||||
type="tool_call",
|
||||
tool_call={
|
||||
"id": "call_ok",
|
||||
"type": "function",
|
||||
"function": {"name": "weather", "arguments": "{\"city\":\"hz\"}"},
|
||||
},
|
||||
),
|
||||
LLMStreamEvent(type="done"),
|
||||
],
|
||||
[
|
||||
LLMStreamEvent(type="text_delta", text="it's sunny."),
|
||||
LLMStreamEvent(type="done"),
|
||||
],
|
||||
],
|
||||
)
|
||||
|
||||
task = asyncio.create_task(pipeline._handle_turn("weather?"))
|
||||
for _ in range(200):
|
||||
if any(e.get("type") == "assistant.tool_call" for e in events):
|
||||
break
|
||||
await asyncio.sleep(0.005)
|
||||
|
||||
await pipeline.handle_tool_call_results(
|
||||
[
|
||||
{
|
||||
"tool_call_id": "call_ok",
|
||||
"name": "weather",
|
||||
"output": {"temp": 21},
|
||||
"status": {"code": 200, "message": "ok"},
|
||||
}
|
||||
]
|
||||
)
|
||||
await task
|
||||
|
||||
assert any(e.get("type") == "assistant.tool_call" for e in events)
|
||||
finals = [e for e in events if e.get("type") == "assistant.response.final"]
|
||||
assert finals
|
||||
assert "it's sunny" in finals[-1].get("text", "")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_turn_with_tool_call_timeout(monkeypatch):
|
||||
pipeline, events = _build_pipeline(
|
||||
monkeypatch,
|
||||
[
|
||||
[
|
||||
LLMStreamEvent(
|
||||
type="tool_call",
|
||||
tool_call={
|
||||
"id": "call_timeout",
|
||||
"type": "function",
|
||||
"function": {"name": "search", "arguments": "{\"query\":\"x\"}"},
|
||||
},
|
||||
),
|
||||
LLMStreamEvent(type="done"),
|
||||
],
|
||||
[
|
||||
LLMStreamEvent(type="text_delta", text="fallback answer."),
|
||||
LLMStreamEvent(type="done"),
|
||||
],
|
||||
],
|
||||
)
|
||||
pipeline._TOOL_WAIT_TIMEOUT_SECONDS = 0.01
|
||||
|
||||
await pipeline._handle_turn("query")
|
||||
|
||||
finals = [e for e in events if e.get("type") == "assistant.response.final"]
|
||||
assert finals
|
||||
assert "fallback answer" in finals[-1].get("text", "")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_duplicate_tool_results_are_ignored(monkeypatch):
|
||||
pipeline, _events = _build_pipeline(monkeypatch, [[LLMStreamEvent(type="done")]])
|
||||
|
||||
await pipeline.handle_tool_call_results(
|
||||
[{"tool_call_id": "call_dup", "output": {"value": 1}, "status": {"code": 200, "message": "ok"}}]
|
||||
)
|
||||
await pipeline.handle_tool_call_results(
|
||||
[{"tool_call_id": "call_dup", "output": {"value": 2}, "status": {"code": 200, "message": "ok"}}]
|
||||
)
|
||||
result = await pipeline._wait_for_single_tool_result("call_dup")
|
||||
|
||||
assert result.get("output", {}).get("value") == 1
|
||||
Reference in New Issue
Block a user