Update code for tool call

Xin Wang
2026-02-10 16:28:20 +08:00
parent 539cf2fda2
commit 4b8da32787
7 changed files with 676 additions and 91 deletions

View File

@@ -12,8 +12,9 @@ event-driven design.
""" """
import asyncio import asyncio
import json
import time import time
from typing import Any, Dict, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
import numpy as np import numpy as np
from loguru import logger from loguru import logger
@@ -26,7 +27,7 @@ from models.ws_v1 import ev
from processors.eou import EouDetector
from processors.vad import SileroVAD, VADProcessor
from services.asr import BufferedASRService
from services.base import BaseASRService, BaseLLMService, BaseTTSService, LLMMessage, LLMStreamEvent
from services.llm import MockLLMService, OpenAILLMService
from services.siliconflow_asr import SiliconFlowASRService
from services.siliconflow_tts import SiliconFlowTTSService
@@ -55,6 +56,60 @@ class DuplexPipeline:
_SENTENCE_TRAILING_CHARS = frozenset({"", "", "", ".", "!", "?", "", "~", "", "\n"})
_SENTENCE_CLOSERS = frozenset({'"', "'", "", "", ")", "]", "}", "", "", "", "", ""})
_MIN_SPLIT_SPOKEN_CHARS = 6
_TOOL_WAIT_TIMEOUT_SECONDS = 15.0
_DEFAULT_TOOL_SCHEMAS: Dict[str, Dict[str, Any]] = {
"search": {
"name": "search",
"description": "Search the internet for recent information",
"parameters": {
"type": "object",
"properties": {"query": {"type": "string"}},
"required": ["query"],
},
},
"calculator": {
"name": "calculator",
"description": "Evaluate a math expression",
"parameters": {
"type": "object",
"properties": {"expression": {"type": "string"}},
"required": ["expression"],
},
},
"weather": {
"name": "weather",
"description": "Get weather by city name",
"parameters": {
"type": "object",
"properties": {"city": {"type": "string"}},
"required": ["city"],
},
},
"translate": {
"name": "translate",
"description": "Translate text to target language",
"parameters": {
"type": "object",
"properties": {
"text": {"type": "string"},
"target_lang": {"type": "string"},
},
"required": ["text", "target_lang"],
},
},
"knowledge": {
"name": "knowledge",
"description": "Query knowledge base by question",
"parameters": {
"type": "object",
"properties": {
"query": {"type": "string"},
"kb_id": {"type": "string"},
},
"required": ["query"],
},
},
}
def __init__(
self,
@@ -158,6 +213,10 @@ class DuplexPipeline:
self._runtime_greeting: Optional[str] = None
self._runtime_knowledge: Dict[str, Any] = {}
self._runtime_knowledge_base_id: Optional[str] = None
self._runtime_tools: List[Any] = []
self._pending_tool_waiters: Dict[str, asyncio.Future] = {}
self._early_tool_results: Dict[str, Dict[str, Any]] = {}
self._completed_tool_call_ids: set[str] = set()
logger.info(f"DuplexPipeline initialized for session {session_id}") logger.info(f"DuplexPipeline initialized for session {session_id}")
@@ -208,8 +267,14 @@ class DuplexPipeline:
if kb_id:
self._runtime_knowledge_base_id = kb_id
tools_payload = metadata.get("tools")
if isinstance(tools_payload, list):
self._runtime_tools = tools_payload
if self.llm_service and hasattr(self.llm_service, "set_knowledge_config"):
self.llm_service.set_knowledge_config(self._resolved_knowledge_config())
if self.llm_service and hasattr(self.llm_service, "set_tool_schemas"):
self.llm_service.set_tool_schemas(self._resolved_tool_schemas())
async def start(self) -> None:
"""Start the pipeline and connect services."""
@@ -234,6 +299,8 @@ class DuplexPipeline:
if hasattr(self.llm_service, "set_knowledge_config"):
self.llm_service.set_knowledge_config(self._resolved_knowledge_config())
if hasattr(self.llm_service, "set_tool_schemas"):
self.llm_service.set_tool_schemas(self._resolved_tool_schemas())
await self.llm_service.connect()
@@ -585,6 +652,117 @@ class DuplexPipeline:
cfg.setdefault("enabled", True) cfg.setdefault("enabled", True)
return cfg return cfg
def _resolved_tool_schemas(self) -> List[Dict[str, Any]]:
schemas: List[Dict[str, Any]] = []
for item in self._runtime_tools:
if isinstance(item, str):
base = self._DEFAULT_TOOL_SCHEMAS.get(item)
if base:
schemas.append(
{
"type": "function",
"function": {
"name": base["name"],
"description": base.get("description") or "",
"parameters": base.get("parameters") or {"type": "object", "properties": {}},
},
}
)
continue
if not isinstance(item, dict):
continue
fn = item.get("function")
if isinstance(fn, dict) and fn.get("name"):
schemas.append(
{
"type": "function",
"function": {
"name": str(fn.get("name")),
"description": str(fn.get("description") or item.get("description") or ""),
"parameters": fn.get("parameters") or {"type": "object", "properties": {}},
},
}
)
continue
if item.get("name"):
schemas.append(
{
"type": "function",
"function": {
"name": str(item.get("name")),
"description": str(item.get("description") or ""),
"parameters": item.get("parameters") or {"type": "object", "properties": {}},
},
}
)
return schemas
async def handle_tool_call_results(self, results: List[Dict[str, Any]]) -> None:
"""Handle client tool execution results."""
if not isinstance(results, list):
return
for item in results:
if not isinstance(item, dict):
continue
call_id = str(item.get("tool_call_id") or item.get("id") or "").strip()
if not call_id:
continue
if call_id in self._completed_tool_call_ids:
continue
waiter = self._pending_tool_waiters.get(call_id)
if waiter and not waiter.done():
waiter.set_result(item)
self._completed_tool_call_ids.add(call_id)
continue
self._early_tool_results[call_id] = item
self._completed_tool_call_ids.add(call_id)
async def _wait_for_single_tool_result(self, call_id: str) -> Dict[str, Any]:
if call_id in self._completed_tool_call_ids and call_id not in self._early_tool_results:
return {
"tool_call_id": call_id,
"status": {"code": 208, "message": "tool_call result already handled"},
"output": "",
}
if call_id in self._early_tool_results:
self._completed_tool_call_ids.add(call_id)
return self._early_tool_results.pop(call_id)
loop = asyncio.get_running_loop()
future = loop.create_future()
self._pending_tool_waiters[call_id] = future
try:
return await asyncio.wait_for(future, timeout=self._TOOL_WAIT_TIMEOUT_SECONDS)
except asyncio.TimeoutError:
self._completed_tool_call_ids.add(call_id)
return {
"tool_call_id": call_id,
"status": {"code": 504, "message": "tool_call timeout"},
"output": "",
}
finally:
self._pending_tool_waiters.pop(call_id, None)
def _normalize_stream_event(self, item: Any) -> LLMStreamEvent:
if isinstance(item, LLMStreamEvent):
return item
if isinstance(item, str):
return LLMStreamEvent(type="text_delta", text=item)
if isinstance(item, dict):
event_type = str(item.get("type") or "")
if event_type in {"text_delta", "tool_call", "done"}:
return LLMStreamEvent(
type=event_type, # type: ignore[arg-type]
text=item.get("text"),
tool_call=item.get("tool_call"),
)
return LLMStreamEvent(type="done")
async def _handle_turn(self, user_text: str) -> None:
"""
Handle a complete conversation turn.
@@ -599,40 +777,75 @@ class DuplexPipeline:
self._turn_start_time = time.time()
self._first_audio_sent = False
full_response = ""
messages = self.conversation.get_messages()
max_rounds = 3
await self.conversation.start_assistant_turn()
self._is_bot_speaking = True
self._interrupt_event.clear()
self._drop_outbound_audio = False
first_audio_sent = False
for _ in range(max_rounds):
if self._interrupt_event.is_set():
break
sentence_buffer = ""
pending_punctuation = ""
round_response = ""
tool_calls: List[Dict[str, Any]] = []
allow_text_output = True
async for raw_event in self.llm_service.generate_stream(messages):
if self._interrupt_event.is_set():
break
event = self._normalize_stream_event(raw_event)
if event.type == "tool_call":
tool_call = event.tool_call if isinstance(event.tool_call, dict) else None
if not tool_call:
continue
allow_text_output = False
tool_calls.append(tool_call)
await self._send_event(
{
**ev(
"assistant.tool_call",
trackId=self.session_id,
tool_call=tool_call,
)
},
priority=22,
)
continue
if event.type != "text_delta":
continue
text_chunk = event.text or ""
if not text_chunk:
continue
if not allow_text_output:
continue
full_response += text_chunk
round_response += text_chunk
sentence_buffer += text_chunk
await self.conversation.update_assistant_text(text_chunk)
await self._send_event(
{
**ev(
"assistant.response.delta",
trackId=self.session_id,
text=text_chunk,
)
},
priority=40,
)
while True:
split_result = extract_tts_sentence(
sentence_buffer,
@@ -654,20 +867,21 @@ class DuplexPipeline:
if not sentence:
continue
if not has_spoken_content(sentence):
pending_punctuation = sentence
continue
if not self._interrupt_event.is_set():
if not first_audio_sent:
await self._send_event(
{
**ev(
"output.audio.start",
trackId=self.session_id,
)
},
priority=10,
)
first_audio_sent = True
await self._speak_sentence(
@@ -675,28 +889,19 @@ class DuplexPipeline:
fade_in_ms=0,
fade_out_ms=8,
)
remaining_text = f"{pending_punctuation}{sentence_buffer}".strip()
if remaining_text and has_spoken_content(remaining_text) and not self._interrupt_event.is_set():
if not first_audio_sent:
await self._send_event(
{
**ev(
"output.audio.start",
trackId=self.session_id,
)
},
priority=10,
)
first_audio_sent = True
await self._speak_sentence(
remaining_text,
@@ -704,6 +909,46 @@ class DuplexPipeline:
fade_out_ms=8,
)
if not tool_calls:
break
tool_results: List[Dict[str, Any]] = []
for call in tool_calls:
call_id = str(call.get("id") or "").strip()
if not call_id:
continue
tool_results.append(await self._wait_for_single_tool_result(call_id))
messages = [
*messages,
LLMMessage(
role="assistant",
content=round_response.strip(),
),
LLMMessage(
role="system",
content=(
"Tool execution results were returned by the client. "
"Continue answering the user naturally using these results. "
"Do not request the same tool again in this turn.\n"
f"tool_calls={json.dumps(tool_calls, ensure_ascii=False)}\n"
f"tool_results={json.dumps(tool_results, ensure_ascii=False)}"
),
),
]
if full_response and not self._interrupt_event.is_set():
await self._send_event(
{
**ev(
"assistant.response.final",
trackId=self.session_id,
text=full_response,
)
},
priority=20,
)
# Send track end
if first_audio_sent:
await self._send_event({

View File

@@ -28,6 +28,7 @@ from models.ws_v1 import (
SessionStopMessage,
InputTextMessage,
ResponseCancelMessage,
ToolCallResultsMessage,
)
@@ -174,6 +175,8 @@ class Session:
logger.info(f"Session {self.id} graceful response.cancel") logger.info(f"Session {self.id} graceful response.cancel")
else: else:
await self.pipeline.interrupt() await self.pipeline.interrupt()
elif isinstance(message, ToolCallResultsMessage):
await self.pipeline.handle_tool_call_results(message.results)
elif isinstance(message, SessionStopMessage): elif isinstance(message, SessionStopMessage):
await self._handle_session_stop(message.reason) await self._handle_session_stop(message.reason)
else: else:

View File

@@ -110,6 +110,24 @@ Rules:
}
```
### `tool_call.results`
Client tool execution results returned to server.
```json
{
"type": "tool_call.results",
"results": [
{
"tool_call_id": "call_abc123",
"name": "weather",
"output": { "temp_c": 21, "condition": "sunny" },
"status": { "code": 200, "message": "ok" }
}
]
}
```
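
For illustration only (this sketch is not part of the protocol file), a client handler for this round-trip could look like the code below. It assumes a `websockets`-style connection object `ws` with an async `send`, and a hypothetical local `run_tool(name, args)` dispatcher; the message shapes follow the examples above and the `assistant.tool_call` event documented below.

```python
import json
from typing import Any, Dict

async def run_tool(name: str, args: Dict[str, Any]) -> Dict[str, Any]:
    # Hypothetical local dispatcher; a real client would invoke its own tools here.
    if name == "weather":
        return {"temp_c": 21, "condition": "sunny"}
    return {"error": f"unknown tool: {name}"}

async def handle_server_event(ws, event: Dict[str, Any]) -> None:
    # React to assistant.tool_call by running the tool and replying with tool_call.results.
    if event.get("type") != "assistant.tool_call":
        return
    call = event["tool_call"]
    args = json.loads(call["function"].get("arguments") or "{}")
    output = await run_tool(call["function"]["name"], args)
    await ws.send(json.dumps({
        "type": "tool_call.results",
        "results": [{
            "tool_call_id": call["id"],
            "name": call["function"]["name"],
            "output": output,
            "status": {"code": 200, "message": "ok"},
        }],
    }))
```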
## Server -> Client Events
All server events include:
@@ -142,6 +160,8 @@ Common events:
- Fields: `trackId`, `text`
- `assistant.response.final`
- Fields: `trackId`, `text`
- `assistant.tool_call`
- Fields: `trackId`, `tool_call`
- `output.audio.start`
- Fields: `trackId`
- `output.audio.end`

View File

@@ -39,12 +39,18 @@ class ResponseCancelMessage(BaseModel):
graceful: bool = False
class ToolCallResultsMessage(BaseModel):
type: Literal["tool_call.results"]
results: list[Dict[str, Any]] = Field(default_factory=list)
CLIENT_MESSAGE_TYPES = {
"hello": HelloMessage,
"session.start": SessionStartMessage,
"session.stop": SessionStopMessage,
"input.text": InputTextMessage,
"response.cancel": ResponseCancelMessage,
"tool_call.results": ToolCallResultsMessage,
}

View File

@@ -7,7 +7,7 @@ StreamEngine pattern.
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import AsyncIterator, Optional, List, Dict, Any, Literal
from enum import Enum
@@ -52,6 +52,15 @@ class LLMMessage:
return d
@dataclass
class LLMStreamEvent:
"""Structured LLM stream event."""
type: Literal["text_delta", "tool_call", "done"]
text: Optional[str] = None
tool_call: Optional[Dict[str, Any]] = None
@dataclass
class TTSChunk:
"""TTS audio chunk."""
@@ -170,7 +179,7 @@ class BaseLLMService(ABC):
messages: List[LLMMessage],
temperature: float = 0.7,
max_tokens: Optional[int] = None
) -> AsyncIterator[LLMStreamEvent]:
"""
Generate response in streaming mode.
@@ -180,7 +189,7 @@ class BaseLLMService(ABC):
max_tokens: Maximum tokens to generate
Yields:
Stream events (text delta/tool call/done)
"""
pass
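
As a consumption sketch (an illustration, not part of this change), a caller might drain the new event stream like this; `llm` stands for any connected `BaseLLMService` implementation and the one-message history is made up:

```python
# Sketch: collecting text deltas and tool calls from generate_stream().
from typing import Any, Dict, List, Tuple

from services.base import BaseLLMService, LLMMessage

async def collect_reply(llm: BaseLLMService) -> Tuple[str, List[Dict[str, Any]]]:
    history = [LLMMessage(role="user", content="What's the weather in Hangzhou?")]
    text = ""
    tool_calls: List[Dict[str, Any]] = []
    async for event in llm.generate_stream(history):
        if event.type == "text_delta" and event.text:
            text += event.text
        elif event.type == "tool_call" and event.tool_call:
            # The pipeline forwards these to the client and waits for tool_call.results.
            tool_calls.append(event.tool_call)
        elif event.type == "done":
            break
    return text, tool_calls
```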

View File

@@ -6,11 +6,12 @@ for real-time voice conversation.
import os
import asyncio
import uuid
from typing import AsyncIterator, Optional, List, Dict, Any
from loguru import logger
from app.backend_client import search_knowledge_context
from services.base import BaseLLMService, LLMMessage, LLMStreamEvent, ServiceState
# Try to import openai
try:
@@ -59,6 +60,7 @@ class OpenAILLMService(BaseLLMService):
self.client: Optional[AsyncOpenAI] = None
self._cancel_event = asyncio.Event()
self._knowledge_config: Dict[str, Any] = knowledge_config or {}
self._tool_schemas: List[Dict[str, Any]] = []
_RAG_DEFAULT_RESULTS = 5
_RAG_MAX_RESULTS = 8
@@ -106,6 +108,29 @@ class OpenAILLMService(BaseLLMService):
"""Update runtime knowledge retrieval config.""" """Update runtime knowledge retrieval config."""
self._knowledge_config = config or {} self._knowledge_config = config or {}
def set_tool_schemas(self, schemas: Optional[List[Dict[str, Any]]]) -> None:
"""Update runtime tool schemas."""
self._tool_schemas = []
if not isinstance(schemas, list):
return
for item in schemas:
if not isinstance(item, dict):
continue
fn = item.get("function")
if isinstance(fn, dict) and fn.get("name"):
self._tool_schemas.append(item)
elif item.get("name"):
self._tool_schemas.append(
{
"type": "function",
"function": {
"name": str(item.get("name")),
"description": str(item.get("description") or ""),
"parameters": item.get("parameters") or {"type": "object", "properties": {}},
},
}
)
@staticmethod
def _coerce_int(value: Any, default: int) -> int:
try:
@@ -258,7 +283,7 @@ class OpenAILLMService(BaseLLMService):
messages: List[LLMMessage],
temperature: float = 0.7,
max_tokens: Optional[int] = None
) -> AsyncIterator[LLMStreamEvent]:
"""
Generate response in streaming mode.
@@ -268,7 +293,7 @@ class OpenAILLMService(BaseLLMService):
max_tokens: Maximum tokens to generate
Yields:
Structured stream events
"""
if not self.client:
raise RuntimeError("LLM service not connected")
@@ -276,15 +301,21 @@ class OpenAILLMService(BaseLLMService):
rag_messages = await self._with_knowledge_context(messages)
prepared = self._prepare_messages(rag_messages)
self._cancel_event.clear()
tool_accumulator: Dict[int, Dict[str, str]] = {}
openai_tools = self._tool_schemas or None
try:
create_args: Dict[str, Any] = dict(
model=self.model,
messages=prepared,
temperature=temperature,
max_tokens=max_tokens,
stream=True,
)
if openai_tools:
create_args["tools"] = openai_tools
create_args["tool_choice"] = "auto"
stream = await self.client.chat.completions.create(**create_args)
async for chunk in stream:
# Check for cancellation
@@ -292,9 +323,60 @@ class OpenAILLMService(BaseLLMService):
logger.info("LLM stream cancelled") logger.info("LLM stream cancelled")
break break
if chunk.choices and chunk.choices[0].delta.content: if not chunk.choices:
content = chunk.choices[0].delta.content continue
yield content
choice = chunk.choices[0]
delta = getattr(choice, "delta", None)
if delta and getattr(delta, "content", None):
content = delta.content
yield LLMStreamEvent(type="text_delta", text=content)
# OpenAI streams function calls via incremental tool_calls deltas.
tool_calls = getattr(delta, "tool_calls", None) if delta else None
if tool_calls:
for tc in tool_calls:
index = getattr(tc, "index", 0) or 0
item = tool_accumulator.setdefault(
int(index),
{"id": "", "name": "", "arguments": ""},
)
tc_id = getattr(tc, "id", None)
if tc_id:
item["id"] = str(tc_id)
fn = getattr(tc, "function", None)
if fn:
fn_name = getattr(fn, "name", None)
if fn_name:
item["name"] = str(fn_name)
fn_args = getattr(fn, "arguments", None)
if fn_args:
item["arguments"] += str(fn_args)
finish_reason = getattr(choice, "finish_reason", None)
if finish_reason == "tool_calls" and tool_accumulator:
for _, payload in sorted(tool_accumulator.items(), key=lambda row: row[0]):
call_name = payload.get("name", "").strip()
if not call_name:
continue
call_id = payload.get("id", "").strip() or f"call_{uuid.uuid4().hex[:10]}"
yield LLMStreamEvent(
type="tool_call",
tool_call={
"id": call_id,
"type": "function",
"function": {
"name": call_name,
"arguments": payload.get("arguments", "") or "{}",
},
},
)
yield LLMStreamEvent(type="done")
return
if finish_reason in {"stop", "length", "content_filter"}:
yield LLMStreamEvent(type="done")
return
except asyncio.CancelledError:
logger.info("LLM stream cancelled via asyncio")
@@ -348,13 +430,14 @@ class MockLLMService(BaseLLMService):
messages: List[LLMMessage],
temperature: float = 0.7,
max_tokens: Optional[int] = None
) -> AsyncIterator[LLMStreamEvent]:
response = await self.generate(messages, temperature, max_tokens)
# Stream word by word
words = response.split()
for i, word in enumerate(words):
if i > 0:
yield LLMStreamEvent(type="text_delta", text=" ")
yield LLMStreamEvent(type="text_delta", text=word)
await asyncio.sleep(0.05) # Simulate streaming delay
yield LLMStreamEvent(type="done")

View File

@@ -0,0 +1,219 @@
import asyncio
from typing import Any, Dict, List
import pytest
from core.duplex_pipeline import DuplexPipeline
from models.ws_v1 import ToolCallResultsMessage, parse_client_message
from services.base import LLMStreamEvent
class _DummySileroVAD:
def __init__(self, *args, **kwargs):
pass
def process_audio(self, _pcm: bytes) -> float:
return 0.0
class _DummyVADProcessor:
def __init__(self, *args, **kwargs):
pass
def process(self, _speech_prob: float):
return "Silence", 0.0
class _DummyEouDetector:
def __init__(self, *args, **kwargs):
pass
def process(self, _vad_status: str) -> bool:
return False
def reset(self) -> None:
return None
class _FakeTransport:
async def send_event(self, _event: Dict[str, Any]) -> None:
return None
async def send_audio(self, _audio: bytes) -> None:
return None
class _FakeTTS:
async def synthesize_stream(self, _text: str):
if False:
yield None
class _FakeASR:
async def connect(self) -> None:
return None
class _FakeLLM:
def __init__(self, rounds: List[List[LLMStreamEvent]]):
self._rounds = rounds
self._call_index = 0
async def generate_stream(self, _messages, temperature=0.7, max_tokens=None):
idx = self._call_index
self._call_index += 1
events = self._rounds[idx] if idx < len(self._rounds) else [LLMStreamEvent(type="done")]
for event in events:
yield event
def _build_pipeline(monkeypatch, llm_rounds: List[List[LLMStreamEvent]]) -> tuple[DuplexPipeline, List[Dict[str, Any]]]:
monkeypatch.setattr("core.duplex_pipeline.SileroVAD", _DummySileroVAD)
monkeypatch.setattr("core.duplex_pipeline.VADProcessor", _DummyVADProcessor)
monkeypatch.setattr("core.duplex_pipeline.EouDetector", _DummyEouDetector)
pipeline = DuplexPipeline(
transport=_FakeTransport(),
session_id="s_test",
llm_service=_FakeLLM(llm_rounds),
tts_service=_FakeTTS(),
asr_service=_FakeASR(),
)
events: List[Dict[str, Any]] = []
async def _capture_event(event: Dict[str, Any], priority: int = 20):
events.append(event)
async def _noop_speak(_text: str, fade_in_ms: int = 0, fade_out_ms: int = 8):
return None
monkeypatch.setattr(pipeline, "_send_event", _capture_event)
monkeypatch.setattr(pipeline, "_speak_sentence", _noop_speak)
return pipeline, events
@pytest.mark.asyncio
async def test_ws_message_parses_tool_call_results():
msg = parse_client_message(
{
"type": "tool_call.results",
"results": [{"tool_call_id": "call_1", "status": {"code": 200, "message": "ok"}}],
}
)
assert isinstance(msg, ToolCallResultsMessage)
assert msg.results[0]["tool_call_id"] == "call_1"
@pytest.mark.asyncio
async def test_turn_without_tool_keeps_streaming(monkeypatch):
pipeline, events = _build_pipeline(
monkeypatch,
[
[
LLMStreamEvent(type="text_delta", text="hello "),
LLMStreamEvent(type="text_delta", text="world."),
LLMStreamEvent(type="done"),
]
],
)
await pipeline._handle_turn("hi")
event_types = [e.get("type") for e in events]
assert "assistant.response.delta" in event_types
assert "assistant.response.final" in event_types
assert "assistant.tool_call" not in event_types
@pytest.mark.asyncio
async def test_turn_with_tool_call_then_results(monkeypatch):
pipeline, events = _build_pipeline(
monkeypatch,
[
[
LLMStreamEvent(type="text_delta", text="let me check."),
LLMStreamEvent(
type="tool_call",
tool_call={
"id": "call_ok",
"type": "function",
"function": {"name": "weather", "arguments": "{\"city\":\"hz\"}"},
},
),
LLMStreamEvent(type="done"),
],
[
LLMStreamEvent(type="text_delta", text="it's sunny."),
LLMStreamEvent(type="done"),
],
],
)
task = asyncio.create_task(pipeline._handle_turn("weather?"))
for _ in range(200):
if any(e.get("type") == "assistant.tool_call" for e in events):
break
await asyncio.sleep(0.005)
await pipeline.handle_tool_call_results(
[
{
"tool_call_id": "call_ok",
"name": "weather",
"output": {"temp": 21},
"status": {"code": 200, "message": "ok"},
}
]
)
await task
assert any(e.get("type") == "assistant.tool_call" for e in events)
finals = [e for e in events if e.get("type") == "assistant.response.final"]
assert finals
assert "it's sunny" in finals[-1].get("text", "")
@pytest.mark.asyncio
async def test_turn_with_tool_call_timeout(monkeypatch):
pipeline, events = _build_pipeline(
monkeypatch,
[
[
LLMStreamEvent(
type="tool_call",
tool_call={
"id": "call_timeout",
"type": "function",
"function": {"name": "search", "arguments": "{\"query\":\"x\"}"},
},
),
LLMStreamEvent(type="done"),
],
[
LLMStreamEvent(type="text_delta", text="fallback answer."),
LLMStreamEvent(type="done"),
],
],
)
pipeline._TOOL_WAIT_TIMEOUT_SECONDS = 0.01
await pipeline._handle_turn("query")
finals = [e for e in events if e.get("type") == "assistant.response.final"]
assert finals
assert "fallback answer" in finals[-1].get("text", "")
@pytest.mark.asyncio
async def test_duplicate_tool_results_are_ignored(monkeypatch):
pipeline, _events = _build_pipeline(monkeypatch, [[LLMStreamEvent(type="done")]])
await pipeline.handle_tool_call_results(
[{"tool_call_id": "call_dup", "output": {"value": 1}, "status": {"code": 200, "message": "ok"}}]
)
await pipeline.handle_tool_call_results(
[{"tool_call_id": "call_dup", "output": {"value": 2}, "status": {"code": 200, "message": "ok"}}]
)
result = await pipeline._wait_for_single_tool_result("call_dup")
assert result.get("output", {}).get("value") == 1