"""Tests for the duplex pipeline's tool-calling, greeting, and output-mode behavior.

The real VAD/EOU components are monkeypatched with inert dummies, and the
transport/TTS/ASR/LLM services are replaced with in-memory fakes so each
test can script the LLM stream and capture the events the pipeline emits.
"""

import asyncio
import json
from typing import Any, Dict, List

import pytest

from core.conversation import ConversationState
from core.duplex_pipeline import DuplexPipeline
from models.ws_v1 import ToolCallResultsMessage, parse_client_message
from services.base import LLMStreamEvent


class _DummySileroVAD:
    """Stand-in VAD model: always reports zero speech probability."""

    def __init__(self, *args, **kwargs):
        pass

    def process_audio(self, _pcm: bytes) -> float:
        return 0.0


class _DummyVADProcessor:
    """Stand-in VAD post-processor: always reports silence."""

    def __init__(self, *args, **kwargs):
        pass

    def process(self, _speech_prob: float):
        return "Silence", 0.0


class _DummyEouDetector:
    """Stand-in end-of-utterance detector: never fires."""

    def __init__(self, *args, **kwargs):
        pass

    def process(self, _vad_status: str) -> bool:
        return False

    def reset(self) -> None:
        return None


class _FakeTransport:
    """Transport double that swallows events and audio."""

    async def send_event(self, _event: Dict[str, Any]) -> None:
        return None

    async def send_audio(self, _audio: bytes) -> None:
        return None


class _FakeTTS:
    """TTS double whose stream yields nothing (the `if False` keeps it a generator)."""

    async def synthesize_stream(self, _text: str):
        if False:
            yield None


class _FakeASR:
    """ASR double with a no-op connect."""

    async def connect(self) -> None:
        return None


class _FakeLLM:
    """LLM double that replays pre-scripted event rounds, one per stream call."""

    def __init__(self, rounds: List[List[LLMStreamEvent]]):
        self._rounds = rounds
        self._call_index = 0

    async def generate(self, _messages, temperature=0.7, max_tokens=None):
        return ""

    async def generate_stream(self, _messages, temperature=0.7, max_tokens=None):
        round_no = self._call_index
        self._call_index += 1
        # Out-of-script calls fall back to a bare "done" so tests never hang.
        if round_no < len(self._rounds):
            scripted = self._rounds[round_no]
        else:
            scripted = [LLMStreamEvent(type="done")]
        for item in scripted:
            yield item


class _CaptureGenerateLLM:
    """LLM double that records the messages passed to generate() and returns a fixed reply."""

    def __init__(self, response: str):
        self.response = response
        self.messages: List[Any] = []

    async def generate(self, messages, temperature=0.7, max_tokens=None):
        self.messages = list(messages)
        return self.response

    async def generate_stream(self, _messages, temperature=0.7, max_tokens=None):
        yield LLMStreamEvent(type="done")


def _build_pipeline(
    monkeypatch, llm_rounds: List[List[LLMStreamEvent]]
) -> tuple[DuplexPipeline, List[Dict[str, Any]]]:
    """Construct a pipeline wired to fakes; return it plus the captured-event list."""
    monkeypatch.setattr("core.duplex_pipeline.SileroVAD", _DummySileroVAD)
    monkeypatch.setattr("core.duplex_pipeline.VADProcessor", _DummyVADProcessor)
    monkeypatch.setattr("core.duplex_pipeline.EouDetector", _DummyEouDetector)

    pipeline = DuplexPipeline(
        transport=_FakeTransport(),
        session_id="s_test",
        llm_service=_FakeLLM(llm_rounds),
        tts_service=_FakeTTS(),
        asr_service=_FakeASR(),
    )

    events: List[Dict[str, Any]] = []

    async def _capture_event(event: Dict[str, Any], priority: int = 20):
        events.append(event)

    async def _noop_speak(_text: str, *args, **kwargs):
        return None

    # Intercept outgoing events and silence TTS so tests inspect events only.
    monkeypatch.setattr(pipeline, "_send_event", _capture_event)
    monkeypatch.setattr(pipeline, "_speak_sentence", _noop_speak)
    return pipeline, events


def test_pipeline_uses_default_tools_from_settings(monkeypatch):
    """Tools declared in settings (strings and dict specs) populate the allowlist and schemas."""
    monkeypatch.setattr(
        "core.duplex_pipeline.settings.tools",
        [
            "current_time",
            "calculator",
            {
                "name": "weather",
                "description": "Get weather by city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
                "executor": "server",
            },
        ],
    )
    pipeline, _events = _build_pipeline(monkeypatch, [[LLMStreamEvent(type="done")]])

    cfg = pipeline.resolved_runtime_config()
    assert cfg["tools"]["allowlist"] == ["calculator", "current_time", "weather"]

    schemas = pipeline._resolved_tool_schemas()
    names = [s.get("function", {}).get("name") for s in schemas if isinstance(s, dict)]
    assert "current_time" in names
    assert "calculator" in names
    assert "weather" in names


def test_pipeline_exposes_unknown_string_tools_with_fallback_schema(monkeypatch):
    """A string tool with no known schema still gets an object-typed fallback schema."""
    monkeypatch.setattr("core.duplex_pipeline.settings.tools", ["custom_system_cmd"])
    pipeline, _events = _build_pipeline(monkeypatch, [[LLMStreamEvent(type="done")]])

    schemas = pipeline._resolved_tool_schemas()
    tool_schema = next(
        (s for s in schemas if s.get("function", {}).get("name") == "custom_system_cmd"),
        None,
    )
    assert tool_schema is not None
    assert tool_schema.get("function", {}).get("parameters", {}).get("type") == "object"


def test_pipeline_assigns_default_client_executor_for_system_string_tools(monkeypatch):
    """System-style string tools default to the client executor."""
    monkeypatch.setattr("core.duplex_pipeline.settings.tools", ["increase_volume"])
    pipeline, _events = _build_pipeline(monkeypatch, [[LLMStreamEvent(type="done")]])

    tool_call = {
        "type": "function",
        "function": {"name": "increase_volume", "arguments": "{}"},
    }
    assert pipeline._tool_executor(tool_call) == "client"


@pytest.mark.asyncio
async def test_pipeline_applies_default_args_to_tool_call(monkeypatch):
    """defaultArgs from runtime overrides are merged into an empty-argument tool call."""
    pipeline, _events = _build_pipeline(
        monkeypatch,
        [
            [
                LLMStreamEvent(
                    type="tool_call",
                    tool_call={
                        "id": "call_defaults",
                        "type": "function",
                        "function": {"name": "weather", "arguments": "{}"},
                    },
                ),
                LLMStreamEvent(type="done"),
            ],
            [LLMStreamEvent(type="done")],
        ],
    )
    pipeline.apply_runtime_overrides(
        {
            "tools": [
                {
                    "type": "function",
                    "executor": "server",
                    "defaultArgs": {"city": "Hangzhou", "unit": "c"},
                    "function": {
                        "name": "weather",
                        "description": "Get weather",
                        "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
                    },
                }
            ]
        }
    )

    captured: Dict[str, Any] = {}

    async def _server_exec(call: Dict[str, Any]) -> Dict[str, Any]:
        captured["call"] = call
        return {
            "tool_call_id": str(call.get("id") or ""),
            "name": "weather",
            "output": {"ok": True},
            "status": {"code": 200, "message": "ok"},
        }

    monkeypatch.setattr(pipeline, "_server_tool_executor", _server_exec)
    await pipeline._handle_turn("weather?")

    sent_call = captured.get("call")
    assert isinstance(sent_call, dict)
    args_raw = sent_call.get("function", {}).get("arguments")
    args = json.loads(args_raw) if isinstance(args_raw, str) else {}
    assert args.get("city") == "Hangzhou"
    assert args.get("unit") == "c"


@pytest.mark.asyncio
async def test_generated_opener_prompt_uses_system_prompt_only(monkeypatch):
    """Greeting generation feeds only the system prompt — never the runtime dev hint."""
    monkeypatch.setattr("core.duplex_pipeline.SileroVAD", _DummySileroVAD)
    monkeypatch.setattr("core.duplex_pipeline.VADProcessor", _DummyVADProcessor)
    monkeypatch.setattr("core.duplex_pipeline.EouDetector", _DummyEouDetector)

    llm = _CaptureGenerateLLM("你好")
    pipeline = DuplexPipeline(
        transport=_FakeTransport(),
        session_id="s_generated_opener",
        llm_service=llm,
        tts_service=_FakeTTS(),
        asr_service=_FakeASR(),
    )
    pipeline.conversation.system_prompt = "SYSTEM_PROMPT_ONLY"
    pipeline._runtime_greeting = "DEV_HINT_SHOULD_NOT_BE_USED"

    generated = await pipeline._generate_runtime_greeting()

    assert generated == "你好"
    assert len(llm.messages) == 2
    user_prompt = llm.messages[1].content
    assert "SYSTEM_PROMPT_ONLY" in user_prompt
    assert "DEV_HINT_SHOULD_NOT_BE_USED" not in user_prompt
    assert "额外风格提示" not in user_prompt


@pytest.mark.asyncio
async def test_generated_opener_uses_tool_capable_turn_when_tools_available(monkeypatch):
    """When tools exist and generated openers are on, the greeting routes through a tool-capable turn."""
    pipeline, _events = _build_pipeline(monkeypatch, [[LLMStreamEvent(type="done")]])
    pipeline.apply_runtime_overrides(
        {
            "generatedOpenerEnabled": True,
            "tools": [
                {
                    "type": "function",
                    "executor": "client",
                    "function": {
                        "name": "text_msg_prompt",
                        "description": "Show a prompt",
                        "parameters": {"type": "object", "properties": {}},
                    },
                }
            ],
        }
    )

    called: Dict[str, Any] = {}
    waiter = asyncio.Event()

    async def _fake_handle_turn(user_text: str) -> None:
        called["user_text"] = user_text
        waiter.set()

    monkeypatch.setattr(pipeline, "_handle_turn", _fake_handle_turn)
    pipeline.conversation.greeting = "fallback greeting"

    await pipeline.emit_initial_greeting()
    await asyncio.wait_for(waiter.wait(), timeout=1.0)
    # An empty user text marks a greeting-initiated turn.
    assert called.get("user_text") == ""


@pytest.mark.asyncio
async def test_ws_message_parses_tool_call_results():
    """The ws-v1 parser recognizes tool_call.results payloads."""
    msg = parse_client_message(
        {
            "type": "tool_call.results",
            "results": [{"tool_call_id": "call_1", "name": "weather", "status": {"code": 200, "message": "ok"}}],
        }
    )
    assert isinstance(msg, ToolCallResultsMessage)
    assert msg.results[0].tool_call_id == "call_1"


@pytest.mark.asyncio
async def test_turn_without_tool_keeps_streaming(monkeypatch):
    """A plain text turn emits delta + final events and no tool-call event."""
    pipeline, events = _build_pipeline(
        monkeypatch,
        [
            [
                LLMStreamEvent(type="text_delta", text="hello "),
                LLMStreamEvent(type="text_delta", text="world."),
                LLMStreamEvent(type="done"),
            ]
        ],
    )
    await pipeline._handle_turn("hi")

    event_types = [e.get("type") for e in events]
    assert "assistant.response.delta" in event_types
    assert "assistant.response.final" in event_types
    assert "assistant.tool_call" not in event_types


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "metadata",
    [
        {"output": {"mode": "text"}},
        {"services": {"tts": {"enabled": False}}},
    ],
)
async def test_text_output_mode_skips_audio_events(monkeypatch, metadata):
    """Text-only output (or disabled TTS) suppresses the output.audio.* events."""
    pipeline, events = _build_pipeline(
        monkeypatch,
        [
            [
                LLMStreamEvent(type="text_delta", text="hello "),
                LLMStreamEvent(type="text_delta", text="world."),
                LLMStreamEvent(type="done"),
            ]
        ],
    )
    pipeline.apply_runtime_overrides(metadata)
    await pipeline._handle_turn("hi")

    event_types = [e.get("type") for e in events]
    assert "assistant.response.delta" in event_types
    assert "assistant.response.final" in event_types
    assert "output.audio.start" not in event_types
    assert "output.audio.end" not in event_types


@pytest.mark.asyncio
async def test_turn_with_tool_call_then_results(monkeypatch):
    """A client tool call pauses the turn until results arrive, then the follow-up round runs."""
    pipeline, events = _build_pipeline(
        monkeypatch,
        [
            [
                LLMStreamEvent(type="text_delta", text="let me check."),
                LLMStreamEvent(
                    type="tool_call",
                    tool_call={
                        "id": "call_ok",
                        "executor": "client",
                        "type": "function",
                        "function": {"name": "weather", "arguments": "{\"city\":\"hz\"}"},
                    },
                ),
                LLMStreamEvent(type="done"),
            ],
            [
                LLMStreamEvent(type="text_delta", text="it's sunny."),
                LLMStreamEvent(type="done"),
            ],
        ],
    )

    task = asyncio.create_task(pipeline._handle_turn("weather?"))
    # Poll (bounded) until the tool-call event is emitted before posting results.
    for _ in range(200):
        if any(e.get("type") == "assistant.tool_call" for e in events):
            break
        await asyncio.sleep(0.005)

    await pipeline.handle_tool_call_results(
        [
            {
                "tool_call_id": "call_ok",
                "name": "weather",
                "output": {"temp": 21},
                "status": {"code": 200, "message": "ok"},
            }
        ]
    )
    await task

    assert any(e.get("type") == "assistant.tool_call" for e in events)
    finals = [e for e in events if e.get("type") == "assistant.response.final"]
    assert finals
    assert "it's sunny" in finals[-1].get("text", "")


@pytest.mark.asyncio
async def test_turn_with_tool_call_timeout(monkeypatch):
    """If client tool results never arrive, the turn times out and still produces a final answer."""
    pipeline, events = _build_pipeline(
        monkeypatch,
        [
            [
                LLMStreamEvent(
                    type="tool_call",
                    tool_call={
                        "id": "call_timeout",
                        "executor": "client",
                        "type": "function",
                        "function": {"name": "search", "arguments": "{\"query\":\"x\"}"},
                    },
                ),
                LLMStreamEvent(type="done"),
            ],
            [
                LLMStreamEvent(type="text_delta", text="fallback answer."),
                LLMStreamEvent(type="done"),
            ],
        ],
    )
    pipeline._TOOL_WAIT_TIMEOUT_SECONDS = 0.01  # keep the test fast
    await pipeline._handle_turn("query")

    finals = [e for e in events if e.get("type") == "assistant.response.final"]
    assert finals
    assert "fallback answer" in finals[-1].get("text", "")


@pytest.mark.asyncio
async def test_duplicate_tool_results_are_ignored(monkeypatch):
    """Only the first result for a given tool_call_id is retained."""
    pipeline, _events = _build_pipeline(monkeypatch, [[LLMStreamEvent(type="done")]])
    await pipeline.handle_tool_call_results(
        [{"tool_call_id": "call_dup", "output": {"value": 1}, "status": {"code": 200, "message": "ok"}}]
    )
    await pipeline.handle_tool_call_results(
        [{"tool_call_id": "call_dup", "output": {"value": 2}, "status": {"code": 200, "message": "ok"}}]
    )
    result = await pipeline._wait_for_single_tool_result("call_dup")
    assert result.get("output", {}).get("value") == 1


@pytest.mark.asyncio
async def test_server_calculator_emits_tool_result(monkeypatch):
    """A server-executed calculator call emits a 200 tool_result with the computed value."""
    pipeline, events = _build_pipeline(
        monkeypatch,
        [
            [
                LLMStreamEvent(
                    type="tool_call",
                    tool_call={
                        "id": "call_calc",
                        "executor": "server",
                        "type": "function",
                        "function": {"name": "calculator", "arguments": "{\"expression\":\"1+2\"}"},
                    },
                ),
                LLMStreamEvent(type="done"),
            ],
            [
                LLMStreamEvent(type="text_delta", text="done."),
                LLMStreamEvent(type="done"),
            ],
        ],
    )
    await pipeline._handle_turn("calc")

    tool_results = [e for e in events if e.get("type") == "assistant.tool_result"]
    assert tool_results
    payload = tool_results[-1].get("result", {})
    assert payload.get("status", {}).get("code") == 200
    assert payload.get("output", {}).get("result") == 3


@pytest.mark.asyncio
async def test_server_tool_timeout_emits_504_and_continues(monkeypatch):
    """A slow server tool yields a 504 tool_result, and the turn still finishes with text."""

    async def _slow_execute(_call):
        await asyncio.sleep(0.05)  # exceeds the 0.01s timeout set below
        return {
            "tool_call_id": "call_slow",
            "name": "weather",
            "output": {"ok": True},
            "status": {"code": 200, "message": "ok"},
        }

    monkeypatch.setattr("core.duplex_pipeline.execute_server_tool", _slow_execute)
    pipeline, events = _build_pipeline(
        monkeypatch,
        [
            [
                LLMStreamEvent(
                    type="tool_call",
                    tool_call={
                        "id": "call_slow",
                        "executor": "server",
                        "type": "function",
                        "function": {"name": "weather", "arguments": "{\"city\":\"hz\"}"},
                    },
                ),
                LLMStreamEvent(type="done"),
            ],
            [
                LLMStreamEvent(type="text_delta", text="timeout fallback."),
                LLMStreamEvent(type="done"),
            ],
        ],
    )
    pipeline._SERVER_TOOL_TIMEOUT_SECONDS = 0.01
    await pipeline._handle_turn("weather?")

    tool_results = [e for e in events if e.get("type") == "assistant.tool_result"]
    assert tool_results
    payload = tool_results[-1].get("result", {})
    assert payload.get("status", {}).get("code") == 504

    finals = [e for e in events if e.get("type") == "assistant.response.final"]
    assert finals
    assert "timeout fallback" in finals[-1].get("text", "")


@pytest.mark.asyncio
async def test_eou_early_return_clears_stale_asr_capture(monkeypatch):
    """An EOU arriving while PROCESSING resets the stale ASR capture state."""
    pipeline, _events = _build_pipeline(monkeypatch, [[LLMStreamEvent(type="done")]])
    await pipeline.conversation.set_state(ConversationState.PROCESSING)
    pipeline._asr_capture_active = True
    pipeline._asr_capture_started_ms = 1234.0
    pipeline._pending_speech_audio = b"stale"

    await pipeline._on_end_of_utterance()

    assert pipeline._asr_capture_active is False
    assert pipeline._asr_capture_started_ms == 0.0
    assert pipeline._pending_speech_audio == b""