220 lines
6.4 KiB
Python
220 lines
6.4 KiB
Python
import asyncio
|
|
from typing import Any, Dict, List
|
|
|
|
import pytest
|
|
|
|
from core.duplex_pipeline import DuplexPipeline
|
|
from models.ws_v1 import ToolCallResultsMessage, parse_client_message
|
|
from services.base import LLMStreamEvent
|
|
|
|
|
|
class _DummySileroVAD:
    """Stand-in for ``SileroVAD`` that never detects speech.

    Any constructor arguments are accepted and ignored, so the class can be
    swapped in for the real one with ``monkeypatch.setattr``.
    """

    def __init__(self, *args, **kwargs):
        # No model to load; the arguments are intentionally discarded.
        del args, kwargs

    def process_audio(self, _pcm: bytes) -> float:
        """Return a speech probability of ``0.0`` for any audio frame."""
        silence_probability = 0.0
        return silence_probability
|
|
|
|
|
|
class _DummyVADProcessor:
    """Stand-in for ``VADProcessor`` that always reports silence."""

    def __init__(self, *args, **kwargs):
        """Accept and discard whatever configuration the caller passes."""

    def process(self, _speech_prob: float):
        """Return the ``("Silence", 0.0)`` pair regardless of the input probability."""
        status, probability = "Silence", 0.0
        return status, probability
|
|
|
|
|
|
class _DummyEouDetector:
    """Stand-in for ``EouDetector`` that never reports an end of utterance."""

    def __init__(self, *args, **kwargs):
        """Ignore all configuration passed by the caller."""

    def process(self, _vad_status: str) -> bool:
        """Return ``False`` for every VAD status."""
        end_of_utterance = False
        return end_of_utterance

    def reset(self) -> None:
        """No internal state exists, so there is nothing to reset."""
|
|
|
|
|
|
class _FakeTransport:
    """Transport double whose send methods discard everything they receive."""

    async def send_event(self, _event: Dict[str, Any]) -> None:
        """Drop the event without sending it anywhere."""

    async def send_audio(self, _audio: bytes) -> None:
        """Drop the audio payload without sending it anywhere."""
|
|
|
|
|
|
class _FakeTTS:
    """TTS double that produces no audio at all."""

    async def synthesize_stream(self, _text: str):
        """Async generator that finishes immediately without yielding a chunk."""
        # Iterating an empty tuple keeps this an async generator that yields nothing.
        for chunk in ():
            yield chunk
|
|
|
|
|
|
class _FakeASR:
    """ASR double whose ``connect`` completes without contacting any backend."""

    async def connect(self) -> None:
        """Pretend to connect; this is a no-op."""
|
|
|
|
|
|
class _FakeLLM:
    """Scripted LLM double that replays one list of stream events per call.

    The ``i``-th call to ``generate_stream`` yields ``rounds[i]``; once the
    script is exhausted, every further call yields a single ``done`` event.
    """

    def __init__(self, rounds: List[List[LLMStreamEvent]]):
        self._rounds = rounds
        self._call_index = 0

    async def generate_stream(self, _messages, temperature=0.7, max_tokens=None):
        """Yield the next scripted round of events and advance the call counter.

        ``_messages``, ``temperature`` and ``max_tokens`` are accepted only for
        signature compatibility and are ignored.
        """
        current = self._call_index
        self._call_index = current + 1
        if current < len(self._rounds):
            script = self._rounds[current]
        else:
            # Script exhausted: end the stream immediately.
            script = [LLMStreamEvent(type="done")]
        for item in script:
            yield item
|
|
|
|
|
|
def _build_pipeline(monkeypatch, llm_rounds: List[List[LLMStreamEvent]]) -> tuple[DuplexPipeline, List[Dict[str, Any]]]:
    """Create a ``DuplexPipeline`` wired to in-memory fakes.

    ``SileroVAD``, ``VADProcessor`` and ``EouDetector`` in
    ``core.duplex_pipeline`` are replaced with inert dummies, and the LLM
    replays ``llm_rounds``. ``_send_event`` is redirected into a list, and
    ``_speak_sentence`` becomes a no-op.

    Returns:
        A ``(pipeline, events)`` pair, where ``events`` collects every event
        passed to ``_send_event``.
    """
    stubs = (
        ("core.duplex_pipeline.SileroVAD", _DummySileroVAD),
        ("core.duplex_pipeline.VADProcessor", _DummyVADProcessor),
        ("core.duplex_pipeline.EouDetector", _DummyEouDetector),
    )
    for target, replacement in stubs:
        monkeypatch.setattr(target, replacement)

    pipeline = DuplexPipeline(
        transport=_FakeTransport(),
        session_id="s_test",
        llm_service=_FakeLLM(llm_rounds),
        tts_service=_FakeTTS(),
        asr_service=_FakeASR(),
    )

    captured: List[Dict[str, Any]] = []

    async def _record(event: Dict[str, Any], priority: int = 20):
        # `priority` presumably mirrors the real _send_event signature; it is ignored here.
        captured.append(event)

    async def _silent_speak(_text: str, fade_in_ms: int = 0, fade_out_ms: int = 8):
        # Skip sentence playback entirely.
        return None

    monkeypatch.setattr(pipeline, "_send_event", _record)
    monkeypatch.setattr(pipeline, "_speak_sentence", _silent_speak)
    return pipeline, captured
|
|
|
|
|
|
def test_ws_message_parses_tool_call_results():
    """``parse_client_message`` turns a ``tool_call.results`` payload into its message model.

    ``parse_client_message`` is called without ``await``, so this is a plain
    synchronous test; an asyncio marker and coroutine would add nothing.
    """
    msg = parse_client_message(
        {
            "type": "tool_call.results",
            "results": [{"tool_call_id": "call_1", "status": {"code": 200, "message": "ok"}}],
        }
    )

    assert isinstance(msg, ToolCallResultsMessage)
    # Exactly the one submitted result should survive parsing.
    assert len(msg.results) == 1
    assert msg.results[0]["tool_call_id"] == "call_1"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_turn_without_tool_keeps_streaming(monkeypatch):
    """A turn whose LLM output contains no tool call streams deltas and a final reply."""
    plain_round = [
        LLMStreamEvent(type="text_delta", text="hello "),
        LLMStreamEvent(type="text_delta", text="world."),
        LLMStreamEvent(type="done"),
    ]
    pipeline, events = _build_pipeline(monkeypatch, [plain_round])

    await pipeline._handle_turn("hi")

    seen = {event.get("type") for event in events}
    assert "assistant.response.delta" in seen
    assert "assistant.response.final" in seen
    assert "assistant.tool_call" not in seen
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_turn_with_tool_call_then_results(monkeypatch):
    """A turn that requests a tool finishes with the LLM's post-result answer.

    Round 1 of the scripted LLM requests the ``weather`` tool. After the test
    observes the ``assistant.tool_call`` event, it submits the tool result.
    The text from round 2 must then appear in the last final response.
    """
    pipeline, events = _build_pipeline(
        monkeypatch,
        [
            [
                LLMStreamEvent(type="text_delta", text="let me check."),
                LLMStreamEvent(
                    type="tool_call",
                    tool_call={
                        "id": "call_ok",
                        "type": "function",
                        "function": {"name": "weather", "arguments": "{\"city\":\"hz\"}"},
                    },
                ),
                LLMStreamEvent(type="done"),
            ],
            [
                LLMStreamEvent(type="text_delta", text="it's sunny."),
                LLMStreamEvent(type="done"),
            ],
        ],
    )

    def _tool_call_emitted() -> bool:
        return any(e.get("type") == "assistant.tool_call" for e in events)

    task = asyncio.create_task(pipeline._handle_turn("weather?"))
    try:
        # Poll for up to ~1 s until the pipeline announces the tool call.
        for _ in range(200):
            if _tool_call_emitted():
                break
            if task.done():
                # Re-raise any error from the turn instead of polling on blindly.
                task.result()
                pytest.fail("turn completed before emitting assistant.tool_call")
            await asyncio.sleep(0.005)
        else:
            pytest.fail("assistant.tool_call was not emitted within the polling window")

        await pipeline.handle_tool_call_results(
            [
                {
                    "tool_call_id": "call_ok",
                    "name": "weather",
                    "output": {"temp": 21},
                    "status": {"code": 200, "message": "ok"},
                }
            ]
        )
        # Bound the wait so a lost result fails the test instead of hanging it.
        await asyncio.wait_for(task, timeout=5)
    finally:
        # Never leak the turn task if an assertion or timeout aborted the test.
        if not task.done():
            task.cancel()
            await asyncio.gather(task, return_exceptions=True)

    assert _tool_call_emitted()
    finals = [e for e in events if e.get("type") == "assistant.response.final"]
    assert finals
    assert "it's sunny" in finals[-1].get("text", "")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_turn_with_tool_call_timeout(monkeypatch):
    """If no tool result arrives within the wait timeout, the turn still completes.

    The scripted second LLM round ("fallback answer.") must be the text
    reflected in the last final response.
    """
    pipeline, events = _build_pipeline(
        monkeypatch,
        [
            [
                LLMStreamEvent(
                    type="tool_call",
                    tool_call={
                        "id": "call_timeout",
                        "type": "function",
                        "function": {"name": "search", "arguments": "{\"query\":\"x\"}"},
                    },
                ),
                LLMStreamEvent(type="done"),
            ],
            [
                LLMStreamEvent(type="text_delta", text="fallback answer."),
                LLMStreamEvent(type="done"),
            ],
        ],
    )
    # monkeypatch restores the value afterwards. With its default raising=True it
    # also fails loudly if the attribute is renamed, instead of silently leaving
    # the real (longer) timeout in effect.
    monkeypatch.setattr(pipeline, "_TOOL_WAIT_TIMEOUT_SECONDS", 0.01)

    # If the override is not honored, fail fast rather than sit out the real timeout.
    await asyncio.wait_for(pipeline._handle_turn("query"), timeout=5)

    assert any(e.get("type") == "assistant.tool_call" for e in events)
    finals = [e for e in events if e.get("type") == "assistant.response.final"]
    assert finals
    assert "fallback answer" in finals[-1].get("text", "")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_duplicate_tool_results_are_ignored(monkeypatch):
    """A second result for an already-answered ``tool_call_id`` does not replace the first."""
    pipeline, _events = _build_pipeline(monkeypatch, [[LLMStreamEvent(type="done")]])

    first = {"tool_call_id": "call_dup", "output": {"value": 1}, "status": {"code": 200, "message": "ok"}}
    duplicate = {"tool_call_id": "call_dup", "output": {"value": 2}, "status": {"code": 200, "message": "ok"}}
    await pipeline.handle_tool_call_results([first])
    await pipeline.handle_tool_call_results([duplicate])

    # Both results were delivered before anyone waited, so the first one is
    # presumably buffered already. The bound makes a regression fail fast
    # instead of sitting out the pipeline's full tool-wait timeout.
    result = await asyncio.wait_for(pipeline._wait_for_single_tool_result("call_dup"), timeout=5)

    assert result.get("output", {}).get("value") == 1
|