From 4b8da32787d49a51b3019a7fdcfb9d79df741ec6 Mon Sep 17 00:00:00 2001 From: Xin Wang Date: Tue, 10 Feb 2026 16:28:20 +0800 Subject: [PATCH] Update code for tool call --- engine/core/duplex_pipeline.py | 397 ++++++++++++++++++++++------ engine/core/session.py | 3 + engine/docs/ws_v1_schema.md | 20 ++ engine/models/ws_v1.py | 6 + engine/services/base.py | 15 +- engine/services/llm.py | 107 +++++++- engine/tests/test_tool_call_flow.py | 219 +++++++++++++++ 7 files changed, 676 insertions(+), 91 deletions(-) create mode 100644 engine/tests/test_tool_call_flow.py diff --git a/engine/core/duplex_pipeline.py b/engine/core/duplex_pipeline.py index 3d2a1dd..3b683ab 100644 --- a/engine/core/duplex_pipeline.py +++ b/engine/core/duplex_pipeline.py @@ -12,8 +12,9 @@ event-driven design. """ import asyncio +import json import time -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import numpy as np from loguru import logger @@ -26,7 +27,7 @@ from models.ws_v1 import ev from processors.eou import EouDetector from processors.vad import SileroVAD, VADProcessor from services.asr import BufferedASRService -from services.base import BaseASRService, BaseLLMService, BaseTTSService +from services.base import BaseASRService, BaseLLMService, BaseTTSService, LLMMessage, LLMStreamEvent from services.llm import MockLLMService, OpenAILLMService from services.siliconflow_asr import SiliconFlowASRService from services.siliconflow_tts import SiliconFlowTTSService @@ -55,6 +56,60 @@ class DuplexPipeline: _SENTENCE_TRAILING_CHARS = frozenset({"。", "!", "?", ".", "!", "?", "…", "~", "~", "\n"}) _SENTENCE_CLOSERS = frozenset({'"', "'", "”", "’", ")", "]", "}", ")", "】", "」", "』", "》"}) _MIN_SPLIT_SPOKEN_CHARS = 6 + _TOOL_WAIT_TIMEOUT_SECONDS = 15.0 + _DEFAULT_TOOL_SCHEMAS: Dict[str, Dict[str, Any]] = { + "search": { + "name": "search", + "description": "Search the internet for recent information", + "parameters": { + "type": "object", + "properties": {"query": {"type": "string"}}, + "required": ["query"], + }, + }, + "calculator": { + "name": "calculator", + "description": "Evaluate a math expression", + "parameters": { + "type": "object", + "properties": {"expression": {"type": "string"}}, + "required": ["expression"], + }, + }, + "weather": { + "name": "weather", + "description": "Get weather by city name", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + }, + "translate": { + "name": "translate", + "description": "Translate text to target language", + "parameters": { + "type": "object", + "properties": { + "text": {"type": "string"}, + "target_lang": {"type": "string"}, + }, + "required": ["text", "target_lang"], + }, + }, + "knowledge": { + "name": "knowledge", + "description": "Query knowledge base by question", + "parameters": { + "type": "object", + "properties": { + "query": {"type": "string"}, + "kb_id": {"type": "string"}, + }, + "required": ["query"], + }, + }, + } def __init__( self, @@ -158,6 +213,10 @@ class DuplexPipeline: self._runtime_greeting: Optional[str] = None self._runtime_knowledge: Dict[str, Any] = {} self._runtime_knowledge_base_id: Optional[str] = None + self._runtime_tools: List[Any] = [] + self._pending_tool_waiters: Dict[str, asyncio.Future] = {} + self._early_tool_results: Dict[str, Dict[str, Any]] = {} + self._completed_tool_call_ids: set[str] = set() logger.info(f"DuplexPipeline initialized for session {session_id}") @@ -208,8 +267,14 @@ class DuplexPipeline: if kb_id: self._runtime_knowledge_base_id = kb_id + tools_payload = metadata.get("tools") + if isinstance(tools_payload, list): + self._runtime_tools = tools_payload + if self.llm_service and hasattr(self.llm_service, "set_knowledge_config"): self.llm_service.set_knowledge_config(self._resolved_knowledge_config()) + if self.llm_service and hasattr(self.llm_service, "set_tool_schemas"): + self.llm_service.set_tool_schemas(self._resolved_tool_schemas()) async def start(self) -> None: """Start the pipeline and connect services.""" @@ -234,6 +299,8 @@ class DuplexPipeline: if hasattr(self.llm_service, "set_knowledge_config"): self.llm_service.set_knowledge_config(self._resolved_knowledge_config()) + if hasattr(self.llm_service, "set_tool_schemas"): + self.llm_service.set_tool_schemas(self._resolved_tool_schemas()) await self.llm_service.connect() @@ -585,6 +652,117 @@ class DuplexPipeline: cfg.setdefault("enabled", True) return cfg + def _resolved_tool_schemas(self) -> List[Dict[str, Any]]: + schemas: List[Dict[str, Any]] = [] + for item in self._runtime_tools: + if isinstance(item, str): + base = self._DEFAULT_TOOL_SCHEMAS.get(item) + if base: + schemas.append( + { + "type": "function", + "function": { + "name": base["name"], + "description": base.get("description") or "", + "parameters": base.get("parameters") or {"type": "object", "properties": {}}, + }, + } + ) + continue + + if not isinstance(item, dict): + continue + + fn = item.get("function") + if isinstance(fn, dict) and fn.get("name"): + schemas.append( + { + "type": "function", + "function": { + "name": str(fn.get("name")), + "description": str(fn.get("description") or item.get("description") or ""), + "parameters": fn.get("parameters") or {"type": "object", "properties": {}}, + }, + } + ) + continue + + if item.get("name"): + schemas.append( + { + "type": "function", + "function": { + "name": str(item.get("name")), + "description": str(item.get("description") or ""), + "parameters": item.get("parameters") or {"type": "object", "properties": {}}, + }, + } + ) + return schemas + + async def handle_tool_call_results(self, results: List[Dict[str, Any]]) -> None: + """Handle client tool execution results.""" + if not isinstance(results, list): + return + + for item in results: + if not isinstance(item, dict): + continue + call_id = str(item.get("tool_call_id") or item.get("id") or "").strip() + if not call_id: + continue + if call_id in self._completed_tool_call_ids: + continue + + waiter = self._pending_tool_waiters.get(call_id) + if waiter and not waiter.done(): + waiter.set_result(item) + self._completed_tool_call_ids.add(call_id) + continue + self._early_tool_results[call_id] = item + self._completed_tool_call_ids.add(call_id) + + async def _wait_for_single_tool_result(self, call_id: str) -> Dict[str, Any]: + if call_id in self._completed_tool_call_ids and call_id not in self._early_tool_results: + return { + "tool_call_id": call_id, + "status": {"code": 208, "message": "tool_call result already handled"}, + "output": "", + } + if call_id in self._early_tool_results: + self._completed_tool_call_ids.add(call_id) + return self._early_tool_results.pop(call_id) + + loop = asyncio.get_running_loop() + future = loop.create_future() + self._pending_tool_waiters[call_id] = future + try: + return await asyncio.wait_for(future, timeout=self._TOOL_WAIT_TIMEOUT_SECONDS) + except asyncio.TimeoutError: + self._completed_tool_call_ids.add(call_id) + return { + "tool_call_id": call_id, + "status": {"code": 504, "message": "tool_call timeout"}, + "output": "", + } + finally: + self._pending_tool_waiters.pop(call_id, None) + + def _normalize_stream_event(self, item: Any) -> LLMStreamEvent: + if isinstance(item, LLMStreamEvent): + return item + if isinstance(item, str): + return LLMStreamEvent(type="text_delta", text=item) + if isinstance(item, dict): + event_type = str(item.get("type") or "") + if event_type in {"text_delta", "tool_call", "done"}: + return LLMStreamEvent( + type=event_type, # type: ignore[arg-type] + text=item.get("text"), + tool_call=item.get("tool_call"), + ) + return LLMStreamEvent(type="done") + async def _handle_turn(self, user_text: str) -> None: """ Handle a complete conversation turn. @@ -599,109 +777,176 @@ class DuplexPipeline: self._turn_start_time = time.time() self._first_audio_sent = False - # Get AI response (streaming) - messages = self.conversation.get_messages() full_response = "" + messages = self.conversation.get_messages() + max_rounds = 3 await self.conversation.start_assistant_turn() self._is_bot_speaking = True self._interrupt_event.clear() self._drop_outbound_audio = False - # Sentence buffer for streaming TTS - sentence_buffer = "" - pending_punctuation = "" first_audio_sent = False - spoken_sentence_count = 0 - - # Stream LLM response and TTS sentence by sentence - async for text_chunk in self.llm_service.generate_stream(messages): + for _ in range(max_rounds): if self._interrupt_event.is_set(): break - full_response += text_chunk - sentence_buffer += text_chunk - await self.conversation.update_assistant_text(text_chunk) + sentence_buffer = "" + pending_punctuation = "" + round_response = "" + tool_calls: List[Dict[str, Any]] = [] + allow_text_output = True - # Send LLM response streaming event to client - await self._send_event({ - **ev( - "assistant.response.delta", - trackId=self.session_id, - text=text_chunk, - ) - }, priority=40) - - # Check for sentence completion - synthesize immediately for low latency - while True: - split_result = extract_tts_sentence( - sentence_buffer, - end_chars=self._SENTENCE_END_CHARS, - trailing_chars=self._SENTENCE_TRAILING_CHARS, - closers=self._SENTENCE_CLOSERS, - min_split_spoken_chars=self._MIN_SPLIT_SPOKEN_CHARS, - hold_trailing_at_buffer_end=True, - force=False, - ) - if not split_result: + async for raw_event in self.llm_service.generate_stream(messages): + if self._interrupt_event.is_set(): break - sentence, sentence_buffer = split_result - if not sentence: + + event = self._normalize_stream_event(raw_event) + if event.type == "tool_call": + tool_call = event.tool_call if isinstance(event.tool_call, dict) else None + if not tool_call: + continue + allow_text_output = False + tool_calls.append(tool_call) + await self._send_event( + { + **ev( + "assistant.tool_call", + trackId=self.session_id, + tool_call=tool_call, + ) + }, + priority=22, + ) continue - sentence = f"{pending_punctuation}{sentence}".strip() - pending_punctuation = "" - if not sentence: + if event.type != "text_delta": continue - # Avoid synthesizing punctuation-only fragments (e.g. standalone "!") - if not has_spoken_content(sentence): - pending_punctuation = sentence + text_chunk = event.text or "" + if not text_chunk: continue - if not self._interrupt_event.is_set(): - # Send track start on first audio - if not first_audio_sent: - await self._send_event({ + if not allow_text_output: + continue + + full_response += text_chunk + round_response += text_chunk + sentence_buffer += text_chunk + await self.conversation.update_assistant_text(text_chunk) + + await self._send_event( + { + **ev( + "assistant.response.delta", + trackId=self.session_id, + text=text_chunk, + ) + }, + priority=40, + ) + + while True: + split_result = extract_tts_sentence( + sentence_buffer, + end_chars=self._SENTENCE_END_CHARS, + trailing_chars=self._SENTENCE_TRAILING_CHARS, + closers=self._SENTENCE_CLOSERS, + min_split_spoken_chars=self._MIN_SPLIT_SPOKEN_CHARS, + hold_trailing_at_buffer_end=True, + force=False, + ) + if not split_result: + break + sentence, sentence_buffer = split_result + if not sentence: + continue + + sentence = f"{pending_punctuation}{sentence}".strip() + pending_punctuation = "" + if not sentence: + continue + + if not has_spoken_content(sentence): + pending_punctuation = sentence + continue + + if not self._interrupt_event.is_set(): + if not first_audio_sent: + await self._send_event( + { + **ev( + "output.audio.start", + trackId=self.session_id, + ) + }, + priority=10, + ) + first_audio_sent = True + + await self._speak_sentence( + sentence, + fade_in_ms=0, + fade_out_ms=8, + ) + + remaining_text = f"{pending_punctuation}{sentence_buffer}".strip() + if remaining_text and has_spoken_content(remaining_text) and not self._interrupt_event.is_set(): + if not first_audio_sent: + await self._send_event( + { **ev( "output.audio.start", trackId=self.session_id, ) - }, priority=10) - first_audio_sent = True - - await self._speak_sentence( - sentence, - fade_in_ms=0, - fade_out_ms=8, + }, + priority=10, ) - spoken_sentence_count += 1 - - # Send final LLM response event - if full_response and not self._interrupt_event.is_set(): - await self._send_event({ - **ev( - "assistant.response.final", - trackId=self.session_id, - text=full_response, + first_audio_sent = True + await self._speak_sentence( + remaining_text, + fade_in_ms=0, + fade_out_ms=8, ) - }, priority=20) - # Speak any remaining text - remaining_text = f"{pending_punctuation}{sentence_buffer}".strip() - if remaining_text and has_spoken_content(remaining_text) and not self._interrupt_event.is_set(): - if not first_audio_sent: - await self._send_event({ + if not tool_calls: + break + + tool_results: List[Dict[str, Any]] = [] + for call in tool_calls: + call_id = str(call.get("id") or "").strip() + if not call_id: + continue + tool_results.append(await self._wait_for_single_tool_result(call_id)) + + messages = [ + *messages, + LLMMessage( + role="assistant", + content=round_response.strip(), + ), + LLMMessage( + role="system", + content=( + "Tool execution results were returned by the client. " + "Continue answering the user naturally using these results. " + "Do not request the same tool again in this turn.\n" + f"tool_calls={json.dumps(tool_calls, ensure_ascii=False)}\n" + f"tool_results={json.dumps(tool_results, ensure_ascii=False)}" + ), + ), + ] + + if full_response and not self._interrupt_event.is_set(): + await self._send_event( + { **ev( - "output.audio.start", + "assistant.response.final", trackId=self.session_id, + text=full_response, ) - }, priority=10) - first_audio_sent = True - await self._speak_sentence( - remaining_text, - fade_in_ms=0, - fade_out_ms=8, + }, + priority=20, ) # Send track end diff --git a/engine/core/session.py b/engine/core/session.py index c4fb193..3f8f18d 100644 --- a/engine/core/session.py +++ b/engine/core/session.py @@ -28,6 +28,7 @@ from models.ws_v1 import ( SessionStopMessage, InputTextMessage, ResponseCancelMessage, + ToolCallResultsMessage, ) @@ -174,6 +175,8 @@ class Session: logger.info(f"Session {self.id} graceful response.cancel") else: await self.pipeline.interrupt() + elif isinstance(message, ToolCallResultsMessage): + await self.pipeline.handle_tool_call_results(message.results) elif isinstance(message, SessionStopMessage): await self._handle_session_stop(message.reason) else: diff --git a/engine/docs/ws_v1_schema.md b/engine/docs/ws_v1_schema.md index de2cc2c..c84ce5a 100644 --- a/engine/docs/ws_v1_schema.md +++ b/engine/docs/ws_v1_schema.md @@ -110,6 +110,24 @@ Rules: } ``` +### `tool_call.results` + +Client tool execution results returned to server. + +```json +{ + "type": "tool_call.results", + "results": [ + { + "tool_call_id": "call_abc123", + "name": "weather", + "output": { "temp_c": 21, "condition": "sunny" }, + "status": { "code": 200, "message": "ok" } + } + ] +} +``` + ## Server -> Client Events All server events include: @@ -142,6 +160,8 @@ Common events: - Fields: `trackId`, `text` - `assistant.response.final` - Fields: `trackId`, `text` +- `assistant.tool_call` + - Fields: `trackId`, `tool_call` - `output.audio.start` - Fields: `trackId` - `output.audio.end` diff --git a/engine/models/ws_v1.py b/engine/models/ws_v1.py index 7c51a90..b8f5524 100644 --- a/engine/models/ws_v1.py +++ b/engine/models/ws_v1.py @@ -39,12 +39,18 @@ class ResponseCancelMessage(BaseModel): graceful: bool = False +class ToolCallResultsMessage(BaseModel): + type: Literal["tool_call.results"] + results: list[Dict[str, Any]] = Field(default_factory=list) + + CLIENT_MESSAGE_TYPES = { "hello": HelloMessage, "session.start": SessionStartMessage, "session.stop": SessionStopMessage, "input.text": InputTextMessage, "response.cancel": ResponseCancelMessage, + "tool_call.results": ToolCallResultsMessage, } diff --git a/engine/services/base.py b/engine/services/base.py index 420428b..7238416 100644 --- a/engine/services/base.py +++ b/engine/services/base.py @@ -7,7 +7,7 @@ StreamEngine pattern. from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import AsyncIterator, Optional, List, Dict, Any +from typing import AsyncIterator, Optional, List, Dict, Any, Literal from enum import Enum @@ -52,6 +52,15 @@ class LLMMessage: return d +@dataclass +class LLMStreamEvent: + """Structured LLM stream event.""" + + type: Literal["text_delta", "tool_call", "done"] + text: Optional[str] = None + tool_call: Optional[Dict[str, Any]] = None + + @dataclass class TTSChunk: """TTS audio chunk.""" @@ -170,7 +179,7 @@ class BaseLLMService(ABC): messages: List[LLMMessage], temperature: float = 0.7, max_tokens: Optional[int] = None - ) -> AsyncIterator[str]: + ) -> AsyncIterator[LLMStreamEvent]: """ Generate response in streaming mode. @@ -180,7 +189,7 @@ class BaseLLMService(ABC): max_tokens: Maximum tokens to generate Yields: - Text chunks as they are generated + Stream events (text delta/tool call/done) """ pass diff --git a/engine/services/llm.py b/engine/services/llm.py index 6496a69..a25ff26 100644 --- a/engine/services/llm.py +++ b/engine/services/llm.py @@ -6,11 +6,12 @@ for real-time voice conversation. import os import asyncio +import uuid from typing import AsyncIterator, Optional, List, Dict, Any from loguru import logger from app.backend_client import search_knowledge_context -from services.base import BaseLLMService, LLMMessage, ServiceState +from services.base import BaseLLMService, LLMMessage, LLMStreamEvent, ServiceState # Try to import openai try: @@ -59,6 +60,7 @@ class OpenAILLMService(BaseLLMService): self.client: Optional[AsyncOpenAI] = None self._cancel_event = asyncio.Event() self._knowledge_config: Dict[str, Any] = knowledge_config or {} + self._tool_schemas: List[Dict[str, Any]] = [] _RAG_DEFAULT_RESULTS = 5 _RAG_MAX_RESULTS = 8 @@ -106,6 +108,29 @@ class OpenAILLMService(BaseLLMService): """Update runtime knowledge retrieval config.""" self._knowledge_config = config or {} + def set_tool_schemas(self, schemas: Optional[List[Dict[str, Any]]]) -> None: + """Update runtime tool schemas.""" + self._tool_schemas = [] + if not isinstance(schemas, list): + return + for item in schemas: + if not isinstance(item, dict): + continue + fn = item.get("function") + if isinstance(fn, dict) and fn.get("name"): + self._tool_schemas.append(item) + elif item.get("name"): + self._tool_schemas.append( + { + "type": "function", + "function": { + "name": str(item.get("name")), + "description": str(item.get("description") or ""), + "parameters": item.get("parameters") or {"type": "object", "properties": {}}, + }, + } + ) + @staticmethod def _coerce_int(value: Any, default: int) -> int: try: @@ -258,7 +283,7 @@ class OpenAILLMService(BaseLLMService): messages: List[LLMMessage], temperature: float = 0.7, max_tokens: Optional[int] = None - ) -> AsyncIterator[str]: + ) -> AsyncIterator[LLMStreamEvent]: """ Generate response in streaming mode. @@ -268,7 +293,7 @@ class OpenAILLMService(BaseLLMService): max_tokens: Maximum tokens to generate Yields: - Text chunks as they are generated + Structured stream events """ if not self.client: raise RuntimeError("LLM service not connected") @@ -276,25 +301,82 @@ class OpenAILLMService(BaseLLMService): rag_messages = await self._with_knowledge_context(messages) prepared = self._prepare_messages(rag_messages) self._cancel_event.clear() + tool_accumulator: Dict[int, Dict[str, str]] = {} + openai_tools = self._tool_schemas or None try: - stream = await self.client.chat.completions.create( + create_args: Dict[str, Any] = dict( model=self.model, messages=prepared, temperature=temperature, max_tokens=max_tokens, - stream=True + stream=True, ) + if openai_tools: + create_args["tools"] = openai_tools + create_args["tool_choice"] = "auto" + stream = await self.client.chat.completions.create(**create_args) async for chunk in stream: # Check for cancellation if self._cancel_event.is_set(): logger.info("LLM stream cancelled") break - - if chunk.choices and chunk.choices[0].delta.content: - content = chunk.choices[0].delta.content - yield content + + if not chunk.choices: + continue + + choice = chunk.choices[0] + delta = getattr(choice, "delta", None) + if delta and getattr(delta, "content", None): + content = delta.content + yield LLMStreamEvent(type="text_delta", text=content) + + # OpenAI streams function calls via incremental tool_calls deltas. + tool_calls = getattr(delta, "tool_calls", None) if delta else None + if tool_calls: + for tc in tool_calls: + index = getattr(tc, "index", 0) or 0 + item = tool_accumulator.setdefault( + int(index), + {"id": "", "name": "", "arguments": ""}, + ) + tc_id = getattr(tc, "id", None) + if tc_id: + item["id"] = str(tc_id) + fn = getattr(tc, "function", None) + if fn: + fn_name = getattr(fn, "name", None) + if fn_name: + item["name"] = str(fn_name) + fn_args = getattr(fn, "arguments", None) + if fn_args: + item["arguments"] += str(fn_args) + + finish_reason = getattr(choice, "finish_reason", None) + if finish_reason == "tool_calls" and tool_accumulator: + for _, payload in sorted(tool_accumulator.items(), key=lambda row: row[0]): + call_name = payload.get("name", "").strip() + if not call_name: + continue + call_id = payload.get("id", "").strip() or f"call_{uuid.uuid4().hex[:10]}" + yield LLMStreamEvent( + type="tool_call", + tool_call={ + "id": call_id, + "type": "function", + "function": { + "name": call_name, + "arguments": payload.get("arguments", "") or "{}", + }, + }, + ) + yield LLMStreamEvent(type="done") + return + + if finish_reason in {"stop", "length", "content_filter"}: + yield LLMStreamEvent(type="done") + return except asyncio.CancelledError: logger.info("LLM stream cancelled via asyncio") @@ -348,13 +430,14 @@ class MockLLMService(BaseLLMService): messages: List[LLMMessage], temperature: float = 0.7, max_tokens: Optional[int] = None - ) -> AsyncIterator[str]: + ) -> AsyncIterator[LLMStreamEvent]: response = await self.generate(messages, temperature, max_tokens) # Stream word by word words = response.split() for i, word in enumerate(words): if i > 0: - yield " " - yield word + yield LLMStreamEvent(type="text_delta", text=" ") + yield LLMStreamEvent(type="text_delta", text=word) await asyncio.sleep(0.05) # Simulate streaming delay + yield LLMStreamEvent(type="done") diff --git a/engine/tests/test_tool_call_flow.py b/engine/tests/test_tool_call_flow.py new file mode 100644 index 0000000..bdf2889 --- /dev/null +++ b/engine/tests/test_tool_call_flow.py @@ -0,0 +1,219 @@ +import asyncio +from typing import Any, Dict, List + +import pytest + +from core.duplex_pipeline import DuplexPipeline +from models.ws_v1 import ToolCallResultsMessage, parse_client_message +from services.base import LLMStreamEvent + + +class _DummySileroVAD: + def __init__(self, *args, **kwargs): + pass + + def process_audio(self, _pcm: bytes) -> float: + return 0.0 + + +class _DummyVADProcessor: + def __init__(self, *args, **kwargs): + pass + + def process(self, _speech_prob: float): + return "Silence", 0.0 + + +class _DummyEouDetector: + def __init__(self, *args, **kwargs): + pass + + def process(self, _vad_status: str) -> bool: + return False + + def reset(self) -> None: + return None + + +class _FakeTransport: + async def send_event(self, _event: Dict[str, Any]) -> None: + return None + + async def send_audio(self, _audio: bytes) -> None: + return None + + +class _FakeTTS: + async def synthesize_stream(self, _text: str): + if False: + yield None + + +class _FakeASR: + async def connect(self) -> None: + return None + + +class _FakeLLM: + def __init__(self, rounds: List[List[LLMStreamEvent]]): + self._rounds = rounds + self._call_index = 0 + + async def generate_stream(self, _messages, temperature=0.7, max_tokens=None): + idx = self._call_index + self._call_index += 1 + events = self._rounds[idx] if idx < len(self._rounds) else [LLMStreamEvent(type="done")] + for event in events: + yield event + + +def _build_pipeline(monkeypatch, llm_rounds: List[List[LLMStreamEvent]]) -> tuple[DuplexPipeline, List[Dict[str, Any]]]: + monkeypatch.setattr("core.duplex_pipeline.SileroVAD", _DummySileroVAD) + monkeypatch.setattr("core.duplex_pipeline.VADProcessor", _DummyVADProcessor) + monkeypatch.setattr("core.duplex_pipeline.EouDetector", _DummyEouDetector) + + pipeline = DuplexPipeline( + transport=_FakeTransport(), + session_id="s_test", + llm_service=_FakeLLM(llm_rounds), + tts_service=_FakeTTS(), + asr_service=_FakeASR(), + ) + events: List[Dict[str, Any]] = [] + + async def _capture_event(event: Dict[str, Any], priority: int = 20): + events.append(event) + + async def _noop_speak(_text: str, fade_in_ms: int = 0, fade_out_ms: int = 8): + return None + + monkeypatch.setattr(pipeline, "_send_event", _capture_event) + monkeypatch.setattr(pipeline, "_speak_sentence", _noop_speak) + return pipeline, events + + +@pytest.mark.asyncio +async def test_ws_message_parses_tool_call_results(): + msg = parse_client_message( + { + "type": "tool_call.results", + "results": [{"tool_call_id": "call_1", "status": {"code": 200, "message": "ok"}}], + } + ) + assert isinstance(msg, ToolCallResultsMessage) + assert msg.results[0]["tool_call_id"] == "call_1" + + +@pytest.mark.asyncio +async def test_turn_without_tool_keeps_streaming(monkeypatch): + pipeline, events = _build_pipeline( + monkeypatch, + [ + [ + LLMStreamEvent(type="text_delta", text="hello "), + LLMStreamEvent(type="text_delta", text="world."), + LLMStreamEvent(type="done"), + ] + ], + ) + + await pipeline._handle_turn("hi") + + event_types = [e.get("type") for e in events] + assert "assistant.response.delta" in event_types + assert "assistant.response.final" in event_types + assert "assistant.tool_call" not in event_types + + +@pytest.mark.asyncio +async def test_turn_with_tool_call_then_results(monkeypatch): + pipeline, events = _build_pipeline( + monkeypatch, + [ + [ + LLMStreamEvent(type="text_delta", text="let me check."), + LLMStreamEvent( + type="tool_call", + tool_call={ + "id": "call_ok", + "type": "function", + "function": {"name": "weather", "arguments": "{\"city\":\"hz\"}"}, + }, + ), + LLMStreamEvent(type="done"), + ], + [ + LLMStreamEvent(type="text_delta", text="it's sunny."), + LLMStreamEvent(type="done"), + ], + ], + ) + + task = asyncio.create_task(pipeline._handle_turn("weather?")) + for _ in range(200): + if any(e.get("type") == "assistant.tool_call" for e in events): + break + await asyncio.sleep(0.005) + + await pipeline.handle_tool_call_results( + [ + { + "tool_call_id": "call_ok", + "name": "weather", + "output": {"temp": 21}, + "status": {"code": 200, "message": "ok"}, + } + ] + ) + await task + + assert any(e.get("type") == "assistant.tool_call" for e in events) + finals = [e for e in events if e.get("type") == "assistant.response.final"] + assert finals + assert "it's sunny" in finals[-1].get("text", "") + + +@pytest.mark.asyncio +async def test_turn_with_tool_call_timeout(monkeypatch): + pipeline, events = _build_pipeline( + monkeypatch, + [ + [ + LLMStreamEvent( + type="tool_call", + tool_call={ + "id": "call_timeout", + "type": "function", + "function": {"name": "search", "arguments": "{\"query\":\"x\"}"}, + }, + ), + LLMStreamEvent(type="done"), + ], + [ + LLMStreamEvent(type="text_delta", text="fallback answer."), + LLMStreamEvent(type="done"), + ], + ], + ) + pipeline._TOOL_WAIT_TIMEOUT_SECONDS = 0.01 + + await pipeline._handle_turn("query") + + finals = [e for e in events if e.get("type") == "assistant.response.final"] + assert finals + assert "fallback answer" in finals[-1].get("text", "") + + +@pytest.mark.asyncio +async def test_duplicate_tool_results_are_ignored(monkeypatch): + pipeline, _events = _build_pipeline(monkeypatch, [[LLMStreamEvent(type="done")]]) + + await pipeline.handle_tool_call_results( + [{"tool_call_id": "call_dup", "output": {"value": 1}, "status": {"code": 200, "message": "ok"}}] + ) + await pipeline.handle_tool_call_results( + [{"tool_call_id": "call_dup", "output": {"value": 2}, "status": {"code": 200, "message": "ok"}}] + ) + result = await pipeline._wait_for_single_tool_result("call_dup") + + assert result.get("output", {}).get("value") == 1