# Source path: AI-VideoAssistant/engine/tests/test_tool_call_flow.py
# (Repository-viewer listing: 459 lines, 14 KiB, Python.)

import asyncio
import json
from typing import Any, Dict, List
import pytest
from core.conversation import ConversationState
from core.duplex_pipeline import DuplexPipeline
from models.ws_v1 import ToolCallResultsMessage, parse_client_message
from services.base import LLMStreamEvent
class _DummySileroVAD:
    """VAD stub: takes any constructor arguments and always reports ``0.0``."""

    def __init__(self, *args, **kwargs):
        """Accept and discard whatever positional/keyword arguments are given."""

    def process_audio(self, _pcm: bytes) -> float:
        """Return ``0.0`` regardless of the bytes passed in."""
        constant_value = 0.0
        return constant_value
class _DummyVADProcessor:
    """VAD-processor stub whose ``process`` always yields ``("Silence", 0.0)``."""

    def __init__(self, *args, **kwargs):
        """Accept any constructor arguments and ignore them."""

    def process(self, _speech_prob: float):
        """Return the fixed ``("Silence", 0.0)`` pair; the input is ignored."""
        fixed_result = ("Silence", 0.0)
        return fixed_result
class _DummyEouDetector:
    """End-of-utterance detector stub that never signals an end of utterance."""

    def __init__(self, *args, **kwargs):
        """Accept any constructor arguments and ignore them."""

    def process(self, _vad_status: str) -> bool:
        """Return ``False`` for every status passed in."""
        detected = False
        return detected

    def reset(self) -> None:
        """Do nothing; the stub keeps no state to clear."""
class _FakeTransport:
    """Transport stub whose send methods accept a payload and do nothing."""

    async def send_event(self, _event: Dict[str, Any]) -> None:
        """Discard the event dict."""

    async def send_audio(self, _audio: bytes) -> None:
        """Discard the audio bytes."""
class _FakeTTS:
    """TTS stub whose stream is an async generator that yields nothing."""

    async def synthesize_stream(self, _text: str):
        """Finish immediately without producing any chunk."""
        return
        # Unreachable on purpose: the mere presence of ``yield`` makes this
        # function an async generator rather than a plain coroutine.
        yield None
class _FakeASR:
    """ASR stub exposing only a no-op ``connect`` coroutine."""

    async def connect(self) -> None:
        """Return immediately without connecting to anything."""
class _FakeLLM:
    """Scripted LLM stub: each ``generate_stream`` call replays the next round.

    ``rounds`` is a list of event lists. The n-th call to ``generate_stream``
    yields the n-th list; once the script is exhausted every further call
    yields a single ``LLMStreamEvent(type="done")``.
    """

    def __init__(self, rounds: List[List[LLMStreamEvent]]):
        """Store the scripted rounds and start at round zero."""
        self._rounds = rounds
        self._call_index = 0

    async def generate_stream(self, _messages, temperature=0.7, max_tokens=None):
        """Yield the events scripted for this call; all arguments are ignored."""
        round_no = self._call_index
        self._call_index = round_no + 1
        if round_no < len(self._rounds):
            scripted = self._rounds[round_no]
        else:
            # Script exhausted: end the stream with a lone "done" event.
            scripted = [LLMStreamEvent(type="done")]
        for item in scripted:
            yield item
def _build_pipeline(monkeypatch, llm_rounds: List[List[LLMStreamEvent]]) -> tuple[DuplexPipeline, List[Dict[str, Any]]]:
    """Create a ``DuplexPipeline`` wired to this file's fakes.

    The VAD / EOU classes referenced by ``core.duplex_pipeline`` are replaced
    with the dummy classes before the pipeline is constructed. Afterwards the
    pipeline's ``_send_event`` is patched to record events and its
    ``_speak_sentence`` is patched to a no-op.

    Returns:
        ``(pipeline, events)`` where ``events`` is the list that receives every
        event dict passed to the patched ``_send_event``.
    """
    replacements = (
        ("core.duplex_pipeline.SileroVAD", _DummySileroVAD),
        ("core.duplex_pipeline.VADProcessor", _DummyVADProcessor),
        ("core.duplex_pipeline.EouDetector", _DummyEouDetector),
    )
    for target, stub_cls in replacements:
        monkeypatch.setattr(target, stub_cls)

    pipeline = DuplexPipeline(
        transport=_FakeTransport(),
        session_id="s_test",
        llm_service=_FakeLLM(llm_rounds),
        tts_service=_FakeTTS(),
        asr_service=_FakeASR(),
    )

    recorded: List[Dict[str, Any]] = []

    async def _noop_speak(_text: str, fade_in_ms: int = 0, fade_out_ms: int = 8):
        # Speaking is skipped entirely in these tests.
        return None

    async def _capture_event(event: Dict[str, Any], priority: int = 20):
        # Priority is accepted for signature compatibility but not used.
        recorded.append(event)

    monkeypatch.setattr(pipeline, "_send_event", _capture_event)
    monkeypatch.setattr(pipeline, "_speak_sentence", _noop_speak)
    return pipeline, recorded
def test_pipeline_uses_default_tools_from_settings(monkeypatch):
    """Tools configured in settings show up in the allowlist and tool schemas."""
    weather_tool = {
        "name": "weather",
        "description": "Get weather by city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
        "executor": "server",
    }
    monkeypatch.setattr(
        "core.duplex_pipeline.settings.tools",
        ["current_time", "calculator", weather_tool],
    )
    pipeline, _ = _build_pipeline(monkeypatch, [[LLMStreamEvent(type="done")]])

    runtime_cfg = pipeline.resolved_runtime_config()
    assert runtime_cfg["tools"]["allowlist"] == ["calculator", "current_time", "weather"]

    # Collect the function name of every dict-shaped schema.
    exposed_names = []
    for schema in pipeline._resolved_tool_schemas():
        if isinstance(schema, dict):
            exposed_names.append(schema.get("function", {}).get("name"))

    for expected in ("current_time", "calculator", "weather"):
        assert expected in exposed_names
def test_pipeline_exposes_unknown_string_tools_with_fallback_schema(monkeypatch):
    """A bare tool name in settings still gets a schema with object parameters."""
    monkeypatch.setattr("core.duplex_pipeline.settings.tools", ["custom_system_cmd"])
    pipeline, _ = _build_pipeline(monkeypatch, [[LLMStreamEvent(type="done")]])

    # Find the first schema whose function name matches the configured tool.
    tool_schema = None
    for candidate in pipeline._resolved_tool_schemas():
        if candidate.get("function", {}).get("name") == "custom_system_cmd":
            tool_schema = candidate
            break

    assert tool_schema is not None
    function_spec = tool_schema.get("function", {})
    assert function_spec.get("parameters", {}).get("type") == "object"
def test_pipeline_assigns_default_client_executor_for_system_string_tools(monkeypatch):
    """The ``increase_volume`` string tool resolves to the ``"client"`` executor."""
    monkeypatch.setattr("core.duplex_pipeline.settings.tools", ["increase_volume"])
    pipeline, _ = _build_pipeline(monkeypatch, [[LLMStreamEvent(type="done")]])

    function_part = {"name": "increase_volume", "arguments": "{}"}
    tool_call = {"type": "function", "function": function_part}

    resolved_executor = pipeline._tool_executor(tool_call)
    assert resolved_executor == "client"
@pytest.mark.asyncio
async def test_pipeline_applies_default_args_to_tool_call(monkeypatch):
    """A tool's ``defaultArgs`` are present in the call handed to the server executor.

    The scripted LLM emits a ``weather`` call with empty arguments; the runtime
    override declares ``defaultArgs`` for that tool. The patched server
    executor records the call it receives so its arguments can be inspected.
    """
    weather_call_event = LLMStreamEvent(
        type="tool_call",
        tool_call={
            "id": "call_defaults",
            "type": "function",
            "function": {"name": "weather", "arguments": "{}"},
        },
    )
    pipeline, _ = _build_pipeline(
        monkeypatch,
        [
            [weather_call_event, LLMStreamEvent(type="done")],
            [LLMStreamEvent(type="done")],
        ],
    )

    weather_tool = {
        "type": "function",
        "executor": "server",
        "defaultArgs": {"city": "Hangzhou", "unit": "c"},
        "function": {
            "name": "weather",
            "description": "Get weather",
            "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
        },
    }
    pipeline.apply_runtime_overrides({"tools": [weather_tool]})

    received_calls: List[Dict[str, Any]] = []

    async def _server_exec(call: Dict[str, Any]) -> Dict[str, Any]:
        received_calls.append(call)
        return {
            "tool_call_id": str(call.get("id") or ""),
            "name": "weather",
            "output": {"ok": True},
            "status": {"code": 200, "message": "ok"},
        }

    monkeypatch.setattr(pipeline, "_server_tool_executor", _server_exec)
    await pipeline._handle_turn("weather?")

    # Inspect the most recent call seen by the executor (None if never called).
    sent_call = received_calls[-1] if received_calls else None
    assert isinstance(sent_call, dict)

    args_raw = sent_call.get("function", {}).get("arguments")
    if isinstance(args_raw, str):
        args = json.loads(args_raw)
    else:
        args = {}
    assert args.get("city") == "Hangzhou"
    assert args.get("unit") == "c"
@pytest.mark.asyncio
async def test_ws_message_parses_tool_call_results():
    """A ``tool_call.results`` payload parses into ``ToolCallResultsMessage``."""
    single_result = {
        "tool_call_id": "call_1",
        "name": "weather",
        "status": {"code": 200, "message": "ok"},
    }
    payload = {"type": "tool_call.results", "results": [single_result]}

    msg = parse_client_message(payload)

    assert isinstance(msg, ToolCallResultsMessage)
    first_result = msg.results[0]
    assert first_result.tool_call_id == "call_1"
@pytest.mark.asyncio
async def test_turn_without_tool_keeps_streaming(monkeypatch):
    """A text-only turn emits delta and final events and no tool-call event."""
    text_only_round = [
        LLMStreamEvent(type="text_delta", text="hello "),
        LLMStreamEvent(type="text_delta", text="world."),
        LLMStreamEvent(type="done"),
    ]
    pipeline, events = _build_pipeline(monkeypatch, [text_only_round])

    await pipeline._handle_turn("hi")

    event_types = []
    for emitted in events:
        event_types.append(emitted.get("type"))

    for required in ("assistant.response.delta", "assistant.response.final"):
        assert required in event_types
    assert "assistant.tool_call" not in event_types
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "metadata",
    [
        {"output": {"mode": "text"}},
        {"services": {"tts": {"enabled": False}}},
    ],
)
async def test_text_output_mode_skips_audio_events(monkeypatch, metadata):
    """With either override applied, a turn emits text events but no audio events.

    Each parametrized ``metadata`` dict is passed to ``apply_runtime_overrides``
    before the turn runs; the captured events must include the delta/final
    response events and must not include ``output.audio.start`` / ``.end``.
    """
    text_only_round = [
        LLMStreamEvent(type="text_delta", text="hello "),
        LLMStreamEvent(type="text_delta", text="world."),
        LLMStreamEvent(type="done"),
    ]
    pipeline, events = _build_pipeline(monkeypatch, [text_only_round])
    pipeline.apply_runtime_overrides(metadata)

    await pipeline._handle_turn("hi")

    event_types = []
    for emitted in events:
        event_types.append(emitted.get("type"))

    for required in ("assistant.response.delta", "assistant.response.final"):
        assert required in event_types
    for forbidden in ("output.audio.start", "output.audio.end"):
        assert forbidden not in event_types
@pytest.mark.asyncio
async def test_turn_with_tool_call_then_results(monkeypatch):
    """A client tool call followed by delivered results ends with the second reply.

    The turn runs as a background task. The test waits for the
    ``assistant.tool_call`` event to be captured, delivers a result for
    ``call_ok`` via ``handle_tool_call_results``, then awaits the turn and
    checks the last final response contains the second round's text.
    """
    pipeline, events = _build_pipeline(
        monkeypatch,
        [
            [
                LLMStreamEvent(type="text_delta", text="let me check."),
                LLMStreamEvent(
                    type="tool_call",
                    tool_call={
                        "id": "call_ok",
                        "executor": "client",
                        "type": "function",
                        "function": {"name": "weather", "arguments": "{\"city\":\"hz\"}"},
                    },
                ),
                LLMStreamEvent(type="done"),
            ],
            [
                LLMStreamEvent(type="text_delta", text="it's sunny."),
                LLMStreamEvent(type="done"),
            ],
        ],
    )

    task = asyncio.create_task(pipeline._handle_turn("weather?"))
    try:
        # Poll for up to ~1s (200 x 5ms) until the tool call has been emitted.
        # Also stop as soon as the turn task has finished: if it crashed or
        # ended early there is nothing left to wait for, and ``await task``
        # below surfaces any exception immediately instead of after the
        # full polling budget.
        for _ in range(200):
            if task.done() or any(e.get("type") == "assistant.tool_call" for e in events):
                break
            await asyncio.sleep(0.005)

        await pipeline.handle_tool_call_results(
            [
                {
                    "tool_call_id": "call_ok",
                    "name": "weather",
                    "output": {"temp": 21},
                    "status": {"code": 200, "message": "ok"},
                }
            ]
        )
        await task
    finally:
        # If anything above raised, do not leave the background turn pending:
        # cancel it and drain it so the event loop is clean for later tests.
        if not task.done():
            task.cancel()
            await asyncio.gather(task, return_exceptions=True)

    assert any(e.get("type") == "assistant.tool_call" for e in events)
    finals = [e for e in events if e.get("type") == "assistant.response.final"]
    assert finals
    assert "it's sunny" in finals[-1].get("text", "")
@pytest.mark.asyncio
async def test_turn_with_tool_call_timeout(monkeypatch):
    """When no client result arrives, the turn still ends with the second round's text.

    ``_TOOL_WAIT_TIMEOUT_SECONDS`` is set to 0.01 on the pipeline instance and
    no tool result is ever delivered for ``call_timeout``.
    """
    search_call_event = LLMStreamEvent(
        type="tool_call",
        tool_call={
            "id": "call_timeout",
            "executor": "client",
            "type": "function",
            "function": {"name": "search", "arguments": "{\"query\":\"x\"}"},
        },
    )
    first_round = [search_call_event, LLMStreamEvent(type="done")]
    second_round = [
        LLMStreamEvent(type="text_delta", text="fallback answer."),
        LLMStreamEvent(type="done"),
    ]
    pipeline, events = _build_pipeline(monkeypatch, [first_round, second_round])
    pipeline._TOOL_WAIT_TIMEOUT_SECONDS = 0.01

    await pipeline._handle_turn("query")

    finals = []
    for emitted in events:
        if emitted.get("type") == "assistant.response.final":
            finals.append(emitted)
    assert finals
    last_final_text = finals[-1].get("text", "")
    assert "fallback answer" in last_final_text
@pytest.mark.asyncio
async def test_duplicate_tool_results_are_ignored(monkeypatch):
    """Of two results sent for the same call id, the first one is what a waiter gets."""
    pipeline, _ = _build_pipeline(monkeypatch, [[LLMStreamEvent(type="done")]])

    # Deliver value 1, then value 2, both under the same tool_call_id.
    for value in (1, 2):
        await pipeline.handle_tool_call_results(
            [{"tool_call_id": "call_dup", "output": {"value": value}, "status": {"code": 200, "message": "ok"}}]
        )

    result = await pipeline._wait_for_single_tool_result("call_dup")
    delivered_output = result.get("output", {})
    assert delivered_output.get("value") == 1
@pytest.mark.asyncio
async def test_server_calculator_emits_tool_result(monkeypatch):
    """A server-executed ``calculator`` call for ``1+2`` emits a 200 result of 3."""
    calc_call_event = LLMStreamEvent(
        type="tool_call",
        tool_call={
            "id": "call_calc",
            "executor": "server",
            "type": "function",
            "function": {"name": "calculator", "arguments": "{\"expression\":\"1+2\"}"},
        },
    )
    first_round = [calc_call_event, LLMStreamEvent(type="done")]
    second_round = [
        LLMStreamEvent(type="text_delta", text="done."),
        LLMStreamEvent(type="done"),
    ]
    pipeline, events = _build_pipeline(monkeypatch, [first_round, second_round])

    await pipeline._handle_turn("calc")

    tool_results = []
    for emitted in events:
        if emitted.get("type") == "assistant.tool_result":
            tool_results.append(emitted)
    assert tool_results

    # Only the most recent tool-result event is inspected.
    payload = tool_results[-1].get("result", {})
    status = payload.get("status", {})
    output = payload.get("output", {})
    assert status.get("code") == 200
    assert output.get("result") == 3
@pytest.mark.asyncio
async def test_server_tool_timeout_emits_504_and_continues(monkeypatch):
    """A server tool slower than the timeout yields a 504 result, then the turn goes on.

    ``execute_server_tool`` is replaced by a coroutine that sleeps 0.05s while
    ``_SERVER_TOOL_TIMEOUT_SECONDS`` is set to 0.01 on the pipeline instance.
    """

    async def _slow_execute(_call):
        await asyncio.sleep(0.05)
        slow_result = {
            "tool_call_id": "call_slow",
            "name": "weather",
            "output": {"ok": True},
            "status": {"code": 200, "message": "ok"},
        }
        return slow_result

    # Patch the module-level executor before the pipeline is built.
    monkeypatch.setattr("core.duplex_pipeline.execute_server_tool", _slow_execute)

    slow_call_event = LLMStreamEvent(
        type="tool_call",
        tool_call={
            "id": "call_slow",
            "executor": "server",
            "type": "function",
            "function": {"name": "weather", "arguments": "{\"city\":\"hz\"}"},
        },
    )
    first_round = [slow_call_event, LLMStreamEvent(type="done")]
    second_round = [
        LLMStreamEvent(type="text_delta", text="timeout fallback."),
        LLMStreamEvent(type="done"),
    ]
    pipeline, events = _build_pipeline(monkeypatch, [first_round, second_round])
    pipeline._SERVER_TOOL_TIMEOUT_SECONDS = 0.01

    await pipeline._handle_turn("weather?")

    tool_results = []
    finals = []
    for emitted in events:
        kind = emitted.get("type")
        if kind == "assistant.tool_result":
            tool_results.append(emitted)
        if kind == "assistant.response.final":
            finals.append(emitted)

    assert tool_results
    payload = tool_results[-1].get("result", {})
    assert payload.get("status", {}).get("code") == 504

    assert finals
    last_final_text = finals[-1].get("text", "")
    assert "timeout fallback" in last_final_text
@pytest.mark.asyncio
async def test_eou_early_return_clears_stale_asr_capture(monkeypatch):
    """``_on_end_of_utterance`` resets ASR capture fields while state is PROCESSING.

    The three capture attributes are seeded with non-default values, then
    ``_on_end_of_utterance`` is awaited and each attribute is expected to be
    back at ``False`` / ``0.0`` / ``b""``.
    """
    pipeline, _ = _build_pipeline(monkeypatch, [[LLMStreamEvent(type="done")]])
    await pipeline.conversation.set_state(ConversationState.PROCESSING)

    # Seed leftover capture state from a supposed earlier utterance.
    pipeline._pending_speech_audio = b"stale"
    pipeline._asr_capture_started_ms = 1234.0
    pipeline._asr_capture_active = True

    await pipeline._on_end_of_utterance()

    assert pipeline._asr_capture_active is False
    assert pipeline._asr_capture_started_ms == 0.0
    assert pipeline._pending_speech_audio == b""