Update code for tool call

This commit is contained in:
Xin Wang
2026-02-10 16:28:20 +08:00
parent 539cf2fda2
commit 4b8da32787
7 changed files with 676 additions and 91 deletions

View File

@@ -12,8 +12,9 @@ event-driven design.
"""
import asyncio
import json
import time
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
from loguru import logger
@@ -26,7 +27,7 @@ from models.ws_v1 import ev
from processors.eou import EouDetector
from processors.vad import SileroVAD, VADProcessor
from services.asr import BufferedASRService
from services.base import BaseASRService, BaseLLMService, BaseTTSService
from services.base import BaseASRService, BaseLLMService, BaseTTSService, LLMMessage, LLMStreamEvent
from services.llm import MockLLMService, OpenAILLMService
from services.siliconflow_asr import SiliconFlowASRService
from services.siliconflow_tts import SiliconFlowTTSService
@@ -55,6 +56,60 @@ class DuplexPipeline:
_SENTENCE_TRAILING_CHARS = frozenset({"", "", "", ".", "!", "?", "", "~", "", "\n"})
_SENTENCE_CLOSERS = frozenset({'"', "'", "", "", ")", "]", "}", "", "", "", "", ""})
_MIN_SPLIT_SPOKEN_CHARS = 6
# Max seconds to wait for the client to return one tool call result
# before synthesizing a 504 timeout payload (see _wait_for_single_tool_result).
_TOOL_WAIT_TIMEOUT_SECONDS = 15.0
# Built-in tool schemas selectable by plain name via session metadata
# ("tools": ["search", ...]). Each entry follows the OpenAI
# function-calling convention: name / description / JSON-schema parameters.
_DEFAULT_TOOL_SCHEMAS: Dict[str, Dict[str, Any]] = {
    "search": {
        "name": "search",
        "description": "Search the internet for recent information",
        "parameters": {
            "type": "object",
            "properties": {"query": {"type": "string"}},
            "required": ["query"],
        },
    },
    "calculator": {
        "name": "calculator",
        "description": "Evaluate a math expression",
        "parameters": {
            "type": "object",
            "properties": {"expression": {"type": "string"}},
            "required": ["expression"],
        },
    },
    "weather": {
        "name": "weather",
        "description": "Get weather by city name",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
    "translate": {
        "name": "translate",
        "description": "Translate text to target language",
        "parameters": {
            "type": "object",
            "properties": {
                "text": {"type": "string"},
                "target_lang": {"type": "string"},
            },
            "required": ["text", "target_lang"],
        },
    },
    "knowledge": {
        "name": "knowledge",
        "description": "Query knowledge base by question",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {"type": "string"},
                "kb_id": {"type": "string"},
            },
            "required": ["query"],
        },
    },
}
def __init__(
self,
@@ -158,6 +213,10 @@ class DuplexPipeline:
self._runtime_greeting: Optional[str] = None
self._runtime_knowledge: Dict[str, Any] = {}
self._runtime_knowledge_base_id: Optional[str] = None
self._runtime_tools: List[Any] = []
self._pending_tool_waiters: Dict[str, asyncio.Future] = {}
self._early_tool_results: Dict[str, Dict[str, Any]] = {}
self._completed_tool_call_ids: set[str] = set()
logger.info(f"DuplexPipeline initialized for session {session_id}")
@@ -208,8 +267,14 @@ class DuplexPipeline:
if kb_id:
self._runtime_knowledge_base_id = kb_id
tools_payload = metadata.get("tools")
if isinstance(tools_payload, list):
self._runtime_tools = tools_payload
if self.llm_service and hasattr(self.llm_service, "set_knowledge_config"):
self.llm_service.set_knowledge_config(self._resolved_knowledge_config())
if self.llm_service and hasattr(self.llm_service, "set_tool_schemas"):
self.llm_service.set_tool_schemas(self._resolved_tool_schemas())
async def start(self) -> None:
"""Start the pipeline and connect services."""
@@ -234,6 +299,8 @@ class DuplexPipeline:
if hasattr(self.llm_service, "set_knowledge_config"):
self.llm_service.set_knowledge_config(self._resolved_knowledge_config())
if hasattr(self.llm_service, "set_tool_schemas"):
self.llm_service.set_tool_schemas(self._resolved_tool_schemas())
await self.llm_service.connect()
@@ -585,6 +652,117 @@ class DuplexPipeline:
cfg.setdefault("enabled", True)
return cfg
def _resolved_tool_schemas(self) -> List[Dict[str, Any]]:
    """Normalize the runtime tool list into OpenAI function-tool schemas.

    ``self._runtime_tools`` may mix three item shapes:
      * a string naming an entry in ``_DEFAULT_TOOL_SCHEMAS``,
      * a dict already wrapped as ``{"function": {...}}``,
      * a flat dict carrying ``name``/``description``/``parameters``.
    Unknown tool names and malformed items are silently dropped.

    Returns:
        A list of ``{"type": "function", "function": {...}}`` dicts
        suitable for the OpenAI ``tools`` request parameter.
    """

    def _wrap(name: Any, description: Any, parameters: Any) -> Dict[str, Any]:
        # Single place that enforces the tool envelope and its defaults;
        # the original repeated this literal in all three branches.
        return {
            "type": "function",
            "function": {
                "name": str(name),
                "description": str(description or ""),
                "parameters": parameters or {"type": "object", "properties": {}},
            },
        }

    schemas: List[Dict[str, Any]] = []
    for item in self._runtime_tools:
        if isinstance(item, str):
            base = self._DEFAULT_TOOL_SCHEMAS.get(item)
            if base:
                schemas.append(
                    _wrap(base["name"], base.get("description"), base.get("parameters"))
                )
            continue
        if not isinstance(item, dict):
            continue
        fn = item.get("function")
        if isinstance(fn, dict) and fn.get("name"):
            # Already-wrapped form; prefer the inner description, fall back
            # to the outer one, exactly as before.
            schemas.append(
                _wrap(
                    fn.get("name"),
                    fn.get("description") or item.get("description"),
                    fn.get("parameters"),
                )
            )
        elif item.get("name"):
            schemas.append(
                _wrap(item.get("name"), item.get("description"), item.get("parameters"))
            )
    return schemas
async def handle_tool_call_results(self, results: List[Dict[str, Any]]) -> None:
    """Handle client tool execution results.

    Each result dict is matched to a pending waiter by ``tool_call_id``
    (falling back to ``id``). The first result for a call id wins:
    duplicates are dropped. Results that arrive before any waiter is
    registered are parked in ``self._early_tool_results`` so a later
    ``_wait_for_single_tool_result`` can pick them up.

    Args:
        results: Result dicts from the client. Non-list payloads,
            non-dict items, and items without a call id are ignored.
    """
    if not isinstance(results, list):
        return
    for item in results:
        if not isinstance(item, dict):
            continue
        call_id = str(item.get("tool_call_id") or item.get("id") or "").strip()
        if not call_id:
            continue
        # Already resolved (delivered, parked early, or timed out) — drop duplicates.
        if call_id in self._completed_tool_call_ids:
            continue
        waiter = self._pending_tool_waiters.get(call_id)
        if waiter and not waiter.done():
            # A coroutine is currently awaiting this call; deliver directly.
            waiter.set_result(item)
            self._completed_tool_call_ids.add(call_id)
            continue
        # No active waiter yet — stash the result for a future wait.
        self._early_tool_results[call_id] = item
        self._completed_tool_call_ids.add(call_id)
async def _wait_for_single_tool_result(self, call_id: str) -> Dict[str, Any]:
    """Wait for the client-reported result of one tool call.

    Resolution order:
      1. call id already consumed (and no parked result) -> 208 placeholder;
      2. result arrived before this wait -> return it immediately;
      3. otherwise register a future and await it up to
         ``_TOOL_WAIT_TIMEOUT_SECONDS``, synthesizing a 504 placeholder
         on timeout.

    Args:
        call_id: The tool call id to wait on.

    Returns:
        The client's result dict, or a synthesized status payload with
        ``tool_call_id``/``status``/``output`` keys.
    """
    if call_id in self._completed_tool_call_ids and call_id not in self._early_tool_results:
        return {
            "tool_call_id": call_id,
            "status": {"code": 208, "message": "tool_call result already handled"},
            "output": "",
        }
    if call_id in self._early_tool_results:
        # Result beat the waiter; consume the parked copy.
        self._completed_tool_call_ids.add(call_id)
        return self._early_tool_results.pop(call_id)
    loop = asyncio.get_running_loop()
    future = loop.create_future()
    self._pending_tool_waiters[call_id] = future
    try:
        return await asyncio.wait_for(future, timeout=self._TOOL_WAIT_TIMEOUT_SECONDS)
    except asyncio.TimeoutError:
        # Mark completed so a late result is ignored instead of resolving
        # a stale waiter.
        self._completed_tool_call_ids.add(call_id)
        return {
            "tool_call_id": call_id,
            "status": {"code": 504, "message": "tool_call timeout"},
            "output": "",
        }
    finally:
        # Always deregister the waiter, on success, timeout, or cancellation.
        self._pending_tool_waiters.pop(call_id, None)
def _normalize_stream_event(self, item: Any) -> LLMStreamEvent:
    """Coerce a raw LLM stream item into a structured LLMStreamEvent.

    Accepts already-structured events, bare strings (treated as text
    deltas), and dicts carrying a recognized ``type``; anything else
    collapses to a terminal "done" event.
    """
    if isinstance(item, LLMStreamEvent):
        return item
    if isinstance(item, str):
        return LLMStreamEvent(type="text_delta", text=item)
    if not isinstance(item, dict):
        return LLMStreamEvent(type="done")
    kind = str(item.get("type") or "")
    if kind not in {"text_delta", "tool_call", "done"}:
        return LLMStreamEvent(type="done")
    return LLMStreamEvent(
        type=kind,  # type: ignore[arg-type]
        text=item.get("text"),
        tool_call=item.get("tool_call"),
    )
async def _handle_turn(self, user_text: str) -> None:
"""
Handle a complete conversation turn.
@@ -599,109 +777,176 @@ class DuplexPipeline:
self._turn_start_time = time.time()
self._first_audio_sent = False
# Get AI response (streaming)
messages = self.conversation.get_messages()
full_response = ""
messages = self.conversation.get_messages()
max_rounds = 3
await self.conversation.start_assistant_turn()
self._is_bot_speaking = True
self._interrupt_event.clear()
self._drop_outbound_audio = False
# Sentence buffer for streaming TTS
sentence_buffer = ""
pending_punctuation = ""
first_audio_sent = False
spoken_sentence_count = 0
# Stream LLM response and TTS sentence by sentence
async for text_chunk in self.llm_service.generate_stream(messages):
for _ in range(max_rounds):
if self._interrupt_event.is_set():
break
full_response += text_chunk
sentence_buffer += text_chunk
await self.conversation.update_assistant_text(text_chunk)
sentence_buffer = ""
pending_punctuation = ""
round_response = ""
tool_calls: List[Dict[str, Any]] = []
allow_text_output = True
# Send LLM response streaming event to client
await self._send_event({
**ev(
"assistant.response.delta",
trackId=self.session_id,
text=text_chunk,
)
}, priority=40)
# Check for sentence completion - synthesize immediately for low latency
while True:
split_result = extract_tts_sentence(
sentence_buffer,
end_chars=self._SENTENCE_END_CHARS,
trailing_chars=self._SENTENCE_TRAILING_CHARS,
closers=self._SENTENCE_CLOSERS,
min_split_spoken_chars=self._MIN_SPLIT_SPOKEN_CHARS,
hold_trailing_at_buffer_end=True,
force=False,
)
if not split_result:
async for raw_event in self.llm_service.generate_stream(messages):
if self._interrupt_event.is_set():
break
sentence, sentence_buffer = split_result
if not sentence:
event = self._normalize_stream_event(raw_event)
if event.type == "tool_call":
tool_call = event.tool_call if isinstance(event.tool_call, dict) else None
if not tool_call:
continue
allow_text_output = False
tool_calls.append(tool_call)
await self._send_event(
{
**ev(
"assistant.tool_call",
trackId=self.session_id,
tool_call=tool_call,
)
},
priority=22,
)
continue
sentence = f"{pending_punctuation}{sentence}".strip()
pending_punctuation = ""
if not sentence:
if event.type != "text_delta":
continue
# Avoid synthesizing punctuation-only fragments (e.g. standalone "!")
if not has_spoken_content(sentence):
pending_punctuation = sentence
text_chunk = event.text or ""
if not text_chunk:
continue
if not self._interrupt_event.is_set():
# Send track start on first audio
if not first_audio_sent:
await self._send_event({
if not allow_text_output:
continue
full_response += text_chunk
round_response += text_chunk
sentence_buffer += text_chunk
await self.conversation.update_assistant_text(text_chunk)
await self._send_event(
{
**ev(
"assistant.response.delta",
trackId=self.session_id,
text=text_chunk,
)
},
priority=40,
)
while True:
split_result = extract_tts_sentence(
sentence_buffer,
end_chars=self._SENTENCE_END_CHARS,
trailing_chars=self._SENTENCE_TRAILING_CHARS,
closers=self._SENTENCE_CLOSERS,
min_split_spoken_chars=self._MIN_SPLIT_SPOKEN_CHARS,
hold_trailing_at_buffer_end=True,
force=False,
)
if not split_result:
break
sentence, sentence_buffer = split_result
if not sentence:
continue
sentence = f"{pending_punctuation}{sentence}".strip()
pending_punctuation = ""
if not sentence:
continue
if not has_spoken_content(sentence):
pending_punctuation = sentence
continue
if not self._interrupt_event.is_set():
if not first_audio_sent:
await self._send_event(
{
**ev(
"output.audio.start",
trackId=self.session_id,
)
},
priority=10,
)
first_audio_sent = True
await self._speak_sentence(
sentence,
fade_in_ms=0,
fade_out_ms=8,
)
remaining_text = f"{pending_punctuation}{sentence_buffer}".strip()
if remaining_text and has_spoken_content(remaining_text) and not self._interrupt_event.is_set():
if not first_audio_sent:
await self._send_event(
{
**ev(
"output.audio.start",
trackId=self.session_id,
)
}, priority=10)
first_audio_sent = True
await self._speak_sentence(
sentence,
fade_in_ms=0,
fade_out_ms=8,
},
priority=10,
)
spoken_sentence_count += 1
# Send final LLM response event
if full_response and not self._interrupt_event.is_set():
await self._send_event({
**ev(
"assistant.response.final",
trackId=self.session_id,
text=full_response,
first_audio_sent = True
await self._speak_sentence(
remaining_text,
fade_in_ms=0,
fade_out_ms=8,
)
}, priority=20)
# Speak any remaining text
remaining_text = f"{pending_punctuation}{sentence_buffer}".strip()
if remaining_text and has_spoken_content(remaining_text) and not self._interrupt_event.is_set():
if not first_audio_sent:
await self._send_event({
if not tool_calls:
break
tool_results: List[Dict[str, Any]] = []
for call in tool_calls:
call_id = str(call.get("id") or "").strip()
if not call_id:
continue
tool_results.append(await self._wait_for_single_tool_result(call_id))
messages = [
*messages,
LLMMessage(
role="assistant",
content=round_response.strip(),
),
LLMMessage(
role="system",
content=(
"Tool execution results were returned by the client. "
"Continue answering the user naturally using these results. "
"Do not request the same tool again in this turn.\n"
f"tool_calls={json.dumps(tool_calls, ensure_ascii=False)}\n"
f"tool_results={json.dumps(tool_results, ensure_ascii=False)}"
),
),
]
if full_response and not self._interrupt_event.is_set():
await self._send_event(
{
**ev(
"output.audio.start",
"assistant.response.final",
trackId=self.session_id,
text=full_response,
)
}, priority=10)
first_audio_sent = True
await self._speak_sentence(
remaining_text,
fade_in_ms=0,
fade_out_ms=8,
},
priority=20,
)
# Send track end

View File

@@ -28,6 +28,7 @@ from models.ws_v1 import (
SessionStopMessage,
InputTextMessage,
ResponseCancelMessage,
ToolCallResultsMessage,
)
@@ -174,6 +175,8 @@ class Session:
logger.info(f"Session {self.id} graceful response.cancel")
else:
await self.pipeline.interrupt()
elif isinstance(message, ToolCallResultsMessage):
await self.pipeline.handle_tool_call_results(message.results)
elif isinstance(message, SessionStopMessage):
await self._handle_session_stop(message.reason)
else:

View File

@@ -110,6 +110,24 @@ Rules:
}
```
### `tool_call.results`
Client tool execution results returned to server.
```json
{
"type": "tool_call.results",
"results": [
{
"tool_call_id": "call_abc123",
"name": "weather",
"output": { "temp_c": 21, "condition": "sunny" },
"status": { "code": 200, "message": "ok" }
}
]
}
```
## Server -> Client Events
All server events include:
@@ -142,6 +160,8 @@ Common events:
- Fields: `trackId`, `text`
- `assistant.response.final`
- Fields: `trackId`, `text`
- `assistant.tool_call`
- Fields: `trackId`, `tool_call`
- `output.audio.start`
- Fields: `trackId`
- `output.audio.end`

View File

@@ -39,12 +39,18 @@ class ResponseCancelMessage(BaseModel):
graceful: bool = False
class ToolCallResultsMessage(BaseModel):
    """Client -> server ``tool_call.results`` message.

    Carries the client-side execution results for tool calls previously
    announced by the server; each entry is matched downstream by its
    ``tool_call_id`` (presumably always present — verify against clients).
    """
    type: Literal["tool_call.results"]
    # Free-form result dicts; shape is validated by the pipeline, not here.
    results: list[Dict[str, Any]] = Field(default_factory=list)
CLIENT_MESSAGE_TYPES = {
"hello": HelloMessage,
"session.start": SessionStartMessage,
"session.stop": SessionStopMessage,
"input.text": InputTextMessage,
"response.cancel": ResponseCancelMessage,
"tool_call.results": ToolCallResultsMessage,
}

View File

@@ -7,7 +7,7 @@ StreamEngine pattern.
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import AsyncIterator, Optional, List, Dict, Any
from typing import AsyncIterator, Optional, List, Dict, Any, Literal
from enum import Enum
@@ -52,6 +52,15 @@ class LLMMessage:
return d
@dataclass
class LLMStreamEvent:
    """Structured LLM stream event.

    Yielded by ``BaseLLMService.generate_stream`` implementations to
    distinguish incremental text, tool invocations, and end of stream.
    """
    # Discriminator for the event payload.
    type: Literal["text_delta", "tool_call", "done"]
    # Text fragment; only meaningful when type == "text_delta".
    text: Optional[str] = None
    # Tool call payload; only meaningful when type == "tool_call".
    tool_call: Optional[Dict[str, Any]] = None
@dataclass
class TTSChunk:
"""TTS audio chunk."""
@@ -170,7 +179,7 @@ class BaseLLMService(ABC):
messages: List[LLMMessage],
temperature: float = 0.7,
max_tokens: Optional[int] = None
) -> AsyncIterator[str]:
) -> AsyncIterator[LLMStreamEvent]:
"""
Generate response in streaming mode.
@@ -180,7 +189,7 @@ class BaseLLMService(ABC):
max_tokens: Maximum tokens to generate
Yields:
Text chunks as they are generated
Stream events (text delta/tool call/done)
"""
pass

View File

@@ -6,11 +6,12 @@ for real-time voice conversation.
import os
import asyncio
import uuid
from typing import AsyncIterator, Optional, List, Dict, Any
from loguru import logger
from app.backend_client import search_knowledge_context
from services.base import BaseLLMService, LLMMessage, ServiceState
from services.base import BaseLLMService, LLMMessage, LLMStreamEvent, ServiceState
# Try to import openai
try:
@@ -59,6 +60,7 @@ class OpenAILLMService(BaseLLMService):
self.client: Optional[AsyncOpenAI] = None
self._cancel_event = asyncio.Event()
self._knowledge_config: Dict[str, Any] = knowledge_config or {}
self._tool_schemas: List[Dict[str, Any]] = []
_RAG_DEFAULT_RESULTS = 5
_RAG_MAX_RESULTS = 8
@@ -106,6 +108,29 @@ class OpenAILLMService(BaseLLMService):
"""Update runtime knowledge retrieval config."""
self._knowledge_config = config or {}
def set_tool_schemas(self, schemas: Optional[List[Dict[str, Any]]]) -> None:
    """Replace the runtime tool schemas with a normalized copy of *schemas*.

    Items already wrapped as ``{"function": {...}}`` with a name are kept
    verbatim; flat dicts with a ``name`` are wrapped into the OpenAI
    function-tool envelope. Everything else is dropped, and any non-list
    input clears the schemas entirely.
    """
    accepted: List[Dict[str, Any]] = []
    for entry in (schemas if isinstance(schemas, list) else []):
        if not isinstance(entry, dict):
            continue
        wrapped = entry.get("function")
        if isinstance(wrapped, dict) and wrapped.get("name"):
            # Already in the wire format — keep as-is.
            accepted.append(entry)
            continue
        if entry.get("name"):
            accepted.append(
                {
                    "type": "function",
                    "function": {
                        "name": str(entry.get("name")),
                        "description": str(entry.get("description") or ""),
                        "parameters": entry.get("parameters") or {"type": "object", "properties": {}},
                    },
                }
            )
    self._tool_schemas = accepted
@staticmethod
def _coerce_int(value: Any, default: int) -> int:
try:
@@ -258,7 +283,7 @@ class OpenAILLMService(BaseLLMService):
messages: List[LLMMessage],
temperature: float = 0.7,
max_tokens: Optional[int] = None
) -> AsyncIterator[str]:
) -> AsyncIterator[LLMStreamEvent]:
"""
Generate response in streaming mode.
@@ -268,7 +293,7 @@ class OpenAILLMService(BaseLLMService):
max_tokens: Maximum tokens to generate
Yields:
Text chunks as they are generated
Structured stream events
"""
if not self.client:
raise RuntimeError("LLM service not connected")
@@ -276,15 +301,21 @@ class OpenAILLMService(BaseLLMService):
rag_messages = await self._with_knowledge_context(messages)
prepared = self._prepare_messages(rag_messages)
self._cancel_event.clear()
tool_accumulator: Dict[int, Dict[str, str]] = {}
openai_tools = self._tool_schemas or None
try:
stream = await self.client.chat.completions.create(
create_args: Dict[str, Any] = dict(
model=self.model,
messages=prepared,
temperature=temperature,
max_tokens=max_tokens,
stream=True
stream=True,
)
if openai_tools:
create_args["tools"] = openai_tools
create_args["tool_choice"] = "auto"
stream = await self.client.chat.completions.create(**create_args)
async for chunk in stream:
# Check for cancellation
@@ -292,9 +323,60 @@ class OpenAILLMService(BaseLLMService):
logger.info("LLM stream cancelled")
break
if chunk.choices and chunk.choices[0].delta.content:
content = chunk.choices[0].delta.content
yield content
if not chunk.choices:
continue
choice = chunk.choices[0]
delta = getattr(choice, "delta", None)
if delta and getattr(delta, "content", None):
content = delta.content
yield LLMStreamEvent(type="text_delta", text=content)
# OpenAI streams function calls via incremental tool_calls deltas.
tool_calls = getattr(delta, "tool_calls", None) if delta else None
if tool_calls:
for tc in tool_calls:
index = getattr(tc, "index", 0) or 0
item = tool_accumulator.setdefault(
int(index),
{"id": "", "name": "", "arguments": ""},
)
tc_id = getattr(tc, "id", None)
if tc_id:
item["id"] = str(tc_id)
fn = getattr(tc, "function", None)
if fn:
fn_name = getattr(fn, "name", None)
if fn_name:
item["name"] = str(fn_name)
fn_args = getattr(fn, "arguments", None)
if fn_args:
item["arguments"] += str(fn_args)
finish_reason = getattr(choice, "finish_reason", None)
if finish_reason == "tool_calls" and tool_accumulator:
for _, payload in sorted(tool_accumulator.items(), key=lambda row: row[0]):
call_name = payload.get("name", "").strip()
if not call_name:
continue
call_id = payload.get("id", "").strip() or f"call_{uuid.uuid4().hex[:10]}"
yield LLMStreamEvent(
type="tool_call",
tool_call={
"id": call_id,
"type": "function",
"function": {
"name": call_name,
"arguments": payload.get("arguments", "") or "{}",
},
},
)
yield LLMStreamEvent(type="done")
return
if finish_reason in {"stop", "length", "content_filter"}:
yield LLMStreamEvent(type="done")
return
except asyncio.CancelledError:
logger.info("LLM stream cancelled via asyncio")
@@ -348,13 +430,14 @@ class MockLLMService(BaseLLMService):
messages: List[LLMMessage],
temperature: float = 0.7,
max_tokens: Optional[int] = None
) -> AsyncIterator[str]:
) -> AsyncIterator[LLMStreamEvent]:
response = await self.generate(messages, temperature, max_tokens)
# Stream word by word
words = response.split()
for i, word in enumerate(words):
if i > 0:
yield " "
yield word
yield LLMStreamEvent(type="text_delta", text=" ")
yield LLMStreamEvent(type="text_delta", text=word)
await asyncio.sleep(0.05) # Simulate streaming delay
yield LLMStreamEvent(type="done")

View File

@@ -0,0 +1,219 @@
import asyncio
from typing import Any, Dict, List
import pytest
from core.duplex_pipeline import DuplexPipeline
from models.ws_v1 import ToolCallResultsMessage, parse_client_message
from services.base import LLMStreamEvent
class _DummySileroVAD:
def __init__(self, *args, **kwargs):
pass
def process_audio(self, _pcm: bytes) -> float:
return 0.0
class _DummyVADProcessor:
def __init__(self, *args, **kwargs):
pass
def process(self, _speech_prob: float):
return "Silence", 0.0
class _DummyEouDetector:
def __init__(self, *args, **kwargs):
pass
def process(self, _vad_status: str) -> bool:
return False
def reset(self) -> None:
return None
class _FakeTransport:
async def send_event(self, _event: Dict[str, Any]) -> None:
return None
async def send_audio(self, _audio: bytes) -> None:
return None
class _FakeTTS:
async def synthesize_stream(self, _text: str):
if False:
yield None
class _FakeASR:
async def connect(self) -> None:
return None
class _FakeLLM:
    """Scripted LLM: each generate_stream call replays the next prepared round."""

    def __init__(self, rounds: List[List[LLMStreamEvent]]):
        self._rounds = rounds
        self._call_index = 0

    async def generate_stream(self, _messages, temperature=0.7, max_tokens=None):
        current = self._call_index
        self._call_index += 1
        try:
            scripted = self._rounds[current]
        except IndexError:
            # Ran out of scripted rounds: emit a lone terminal event.
            scripted = [LLMStreamEvent(type="done")]
        for event in scripted:
            yield event
def _build_pipeline(monkeypatch, llm_rounds: List[List[LLMStreamEvent]]) -> tuple[DuplexPipeline, List[Dict[str, Any]]]:
    """Build a DuplexPipeline wired with stubs for isolated turn tests.

    VAD/EOU classes are patched at the module level so the pipeline
    constructor never loads real models; outbound events are captured
    into the returned list and speech synthesis is silenced.

    Args:
        monkeypatch: pytest monkeypatch fixture.
        llm_rounds: Scripted event lists, one per generate_stream call.

    Returns:
        (pipeline, events) — the pipeline and the list that will collect
        every event the pipeline attempts to send.
    """
    monkeypatch.setattr("core.duplex_pipeline.SileroVAD", _DummySileroVAD)
    monkeypatch.setattr("core.duplex_pipeline.VADProcessor", _DummyVADProcessor)
    monkeypatch.setattr("core.duplex_pipeline.EouDetector", _DummyEouDetector)
    pipeline = DuplexPipeline(
        transport=_FakeTransport(),
        session_id="s_test",
        llm_service=_FakeLLM(llm_rounds),
        tts_service=_FakeTTS(),
        asr_service=_FakeASR(),
    )
    events: List[Dict[str, Any]] = []
    # Capture events instead of sending them over a transport.
    async def _capture_event(event: Dict[str, Any], priority: int = 20):
        events.append(event)
    # Suppress TTS playback entirely.
    async def _noop_speak(_text: str, fade_in_ms: int = 0, fade_out_ms: int = 8):
        return None
    monkeypatch.setattr(pipeline, "_send_event", _capture_event)
    monkeypatch.setattr(pipeline, "_speak_sentence", _noop_speak)
    return pipeline, events
@pytest.mark.asyncio
async def test_ws_message_parses_tool_call_results():
    """A `tool_call.results` wire payload parses into ToolCallResultsMessage."""
    msg = parse_client_message(
        {
            "type": "tool_call.results",
            "results": [{"tool_call_id": "call_1", "status": {"code": 200, "message": "ok"}}],
        }
    )
    assert isinstance(msg, ToolCallResultsMessage)
    assert msg.results[0]["tool_call_id"] == "call_1"
@pytest.mark.asyncio
async def test_turn_without_tool_keeps_streaming(monkeypatch):
    """A plain-text round emits delta and final events but no tool_call event."""
    pipeline, events = _build_pipeline(
        monkeypatch,
        [
            [
                LLMStreamEvent(type="text_delta", text="hello "),
                LLMStreamEvent(type="text_delta", text="world."),
                LLMStreamEvent(type="done"),
            ]
        ],
    )
    await pipeline._handle_turn("hi")
    event_types = [e.get("type") for e in events]
    assert "assistant.response.delta" in event_types
    assert "assistant.response.final" in event_types
    assert "assistant.tool_call" not in event_types
@pytest.mark.asyncio
async def test_turn_with_tool_call_then_results(monkeypatch):
    """A tool_call round followed by a client result yields a second LLM round.

    Round 1 emits a weather tool call; once the pipeline announces it,
    the test injects the client result and expects round 2's text to
    appear in the final response event.
    """
    pipeline, events = _build_pipeline(
        monkeypatch,
        [
            [
                LLMStreamEvent(type="text_delta", text="let me check."),
                LLMStreamEvent(
                    type="tool_call",
                    tool_call={
                        "id": "call_ok",
                        "type": "function",
                        "function": {"name": "weather", "arguments": "{\"city\":\"hz\"}"},
                    },
                ),
                LLMStreamEvent(type="done"),
            ],
            [
                LLMStreamEvent(type="text_delta", text="it's sunny."),
                LLMStreamEvent(type="done"),
            ],
        ],
    )
    task = asyncio.create_task(pipeline._handle_turn("weather?"))
    # Poll (up to ~1s) until the pipeline has announced the tool call,
    # so the result below lands on a registered waiter.
    for _ in range(200):
        if any(e.get("type") == "assistant.tool_call" for e in events):
            break
        await asyncio.sleep(0.005)
    await pipeline.handle_tool_call_results(
        [
            {
                "tool_call_id": "call_ok",
                "name": "weather",
                "output": {"temp": 21},
                "status": {"code": 200, "message": "ok"},
            }
        ]
    )
    await task
    assert any(e.get("type") == "assistant.tool_call" for e in events)
    finals = [e for e in events if e.get("type") == "assistant.response.final"]
    assert finals
    assert "it's sunny" in finals[-1].get("text", "")
@pytest.mark.asyncio
async def test_turn_with_tool_call_timeout(monkeypatch):
    """If the client never returns a tool result, the next round still answers."""
    pipeline, events = _build_pipeline(
        monkeypatch,
        [
            [
                LLMStreamEvent(
                    type="tool_call",
                    tool_call={
                        "id": "call_timeout",
                        "type": "function",
                        "function": {"name": "search", "arguments": "{\"query\":\"x\"}"},
                    },
                ),
                LLMStreamEvent(type="done"),
            ],
            [
                LLMStreamEvent(type="text_delta", text="fallback answer."),
                LLMStreamEvent(type="done"),
            ],
        ],
    )
    # Instance attribute shadows the class-level timeout so the test stays fast.
    pipeline._TOOL_WAIT_TIMEOUT_SECONDS = 0.01
    await pipeline._handle_turn("query")
    finals = [e for e in events if e.get("type") == "assistant.response.final"]
    assert finals
    assert "fallback answer" in finals[-1].get("text", "")
@pytest.mark.asyncio
async def test_duplicate_tool_results_are_ignored(monkeypatch):
    """The first result for a call id wins; a duplicate must not overwrite it."""
    pipeline, _events = _build_pipeline(monkeypatch, [[LLMStreamEvent(type="done")]])
    await pipeline.handle_tool_call_results(
        [{"tool_call_id": "call_dup", "output": {"value": 1}, "status": {"code": 200, "message": "ok"}}]
    )
    # Second delivery for the same id should be dropped entirely.
    await pipeline.handle_tool_call_results(
        [{"tool_call_id": "call_dup", "output": {"value": 2}, "status": {"code": 200, "message": "ok"}}]
    )
    result = await pipeline._wait_for_single_tool_result("call_dup")
    assert result.get("output", {}).get("value") == 1