"""Tests for v1 WebSocket protocol handling in ``core.session.Session``.

Sessions are created with ``Session.__new__`` so ``__init__`` is bypassed; each
test assigns only the attributes and collaborator stubs that the code path under
test reads.
"""

import pytest

from core.session import Session, WsSessionState
from models.ws_v1 import OutputAudioPlayedMessage, SessionStartMessage, parse_client_message


def _session() -> Session:
    """Return a bare session with a fixed id and assistant id (``__init__`` bypassed)."""
    session = Session.__new__(Session)
    session.id = "sess_test"
    session._assistant_id = "assistant_demo"
    return session


def _bare_session(session_id: str, ws_state: WsSessionState) -> Session:
    """Return a bare session with only ``id`` and ``ws_state`` assigned."""
    session = Session.__new__(Session)
    session.id = session_id
    session.ws_state = ws_state
    return session


async def _unexpected_error(sender, message, code, **kwargs):
    """Stand-in for ``Session._send_error`` in happy-path tests: any error fails the test."""
    raise AssertionError(f"Unexpected error: sender={sender} code={code} message={message} kwargs={kwargs}")


class _RecordingTransport:
    """Transport stub that records whether ``close`` was awaited."""

    def __init__(self):
        self.closed = False

    async def close(self):
        self.closed = True


class _StartPipeline:
    """Pipeline stub covering the calls made while handling ``session.start``."""

    def __init__(self):
        self.started = False
        self.applied = {}
        # Minimal object exposing a writable ``system_prompt`` attribute.
        self.conversation = type("Conversation", (), {"system_prompt": ""})()

    async def start(self):
        self.started = True

    async def emit_initial_greeting(self):
        return None

    def apply_runtime_overrides(self, metadata):
        # Copy so the test inspects exactly what was passed at call time.
        self.applied = dict(metadata)

    def resolved_runtime_config(self):
        return {
            "output": {"mode": "text"},
            "services": {"llm": {"provider": "openai", "model": "gpt-4o-mini"}},
            "tools": {"allowlist": ["calculator"]},
        }


def _session_ready_for_start(session_id: str):
    """Build a ``WAIT_START`` session wired with stubs for ``_handle_session_start``.

    Returns ``(session, events)``. ``events`` collects every payload passed to
    ``session._send_event``; ``session._send_error`` fails the test if called.
    The server-side runtime metadata is a fixed assistant config whose output
    mode is ``"audio"``, so client overrides are observable.
    """
    session = Session.__new__(Session)
    session.id = session_id
    session.ws_state = WsSessionState.WAIT_START
    session.state = "created"
    session._assistant_id = "assistant_demo"
    session.current_track_id = Session.TRACK_CONTROL
    session._pipeline_started = False
    session.transport = _RecordingTransport()
    session.pipeline = _StartPipeline()

    events = []

    async def _start_history_bridge(_metadata):
        return None

    async def _load_server_runtime_metadata(_assistant_id):
        return (
            {
                "assistantId": "assistant_demo",
                "configVersionId": "cfg_1",
                "systemPrompt": "Base prompt",
                "greeting": "Base greeting",
                "output": {"mode": "audio"},
            },
            None,
        )

    async def _send_event(event):
        events.append(event)

    session._start_history_bridge = _start_history_bridge
    session._load_server_runtime_metadata = _load_server_runtime_metadata
    session._send_event = _send_event
    session._send_error = _unexpected_error
    return session, events


def _first_event(events, event_type):
    """Return the first event whose ``type`` equals ``event_type``.

    A bare ``next(...)`` inside an async test turns a missing event into
    ``RuntimeError: coroutine raised StopIteration`` (PEP 479); asserting here
    reports which event types were actually emitted instead.
    """
    found = next((item for item in events if item.get("type") == event_type), None)
    assert found is not None, f"no {event_type!r} event; got {[item.get('type') for item in events]}"
    return found


def _gateway_returning(result):
    """Return a runtime-config provider stub whose fetch always yields ``result``."""

    class _Gateway:
        async def fetch_assistant_config(self, assistant_id: str):
            _ = assistant_id
            return result

    return _Gateway()


def test_parse_client_message_rejects_hello_message():
    with pytest.raises(ValueError, match="Unknown client message type: hello"):
        parse_client_message({"type": "hello", "version": "v1"})


def test_parse_client_message_accepts_output_audio_played():
    message = parse_client_message({"type": "output.audio.played", "tts_id": "tts_001"})
    assert isinstance(message, OutputAudioPlayedMessage)
    assert message.tts_id == "tts_001"


def test_parse_client_message_rejects_output_audio_played_without_tts_id():
    with pytest.raises(ValueError, match="tts_id"):
        parse_client_message({"type": "output.audio.played", "tts_id": ""})


@pytest.mark.asyncio
async def test_handle_text_reports_invalid_message_for_hello():
    session = _bare_session("sess_invalid_hello", WsSessionState.WAIT_START)
    session.transport = _RecordingTransport()

    captured = []

    async def _send_error(sender, message, code, **kwargs):
        captured.append((sender, code, message, kwargs))

    session._send_error = _send_error

    await session.handle_text('{"type":"hello","version":"v1"}')

    assert captured
    sender, code, message, _ = captured[0]
    assert sender == "client"
    assert code == "protocol.invalid_message"
    assert "Unknown client message type: hello" in message


@pytest.mark.asyncio
async def test_handle_v1_message_routes_output_audio_played_to_pipeline():
    session = _bare_session("sess_output_audio_played", WsSessionState.ACTIVE)
    received = {}

    class _Pipeline:
        async def handle_output_audio_played(self, **payload):
            received.update(payload)

    session.pipeline = _Pipeline()
    session._send_error = _unexpected_error

    await session._handle_v1_message(
        OutputAudioPlayedMessage(
            type="output.audio.played",
            tts_id="tts_001",
            response_id="resp_001",
            turn_id="turn_001",
            played_at_ms=1730000018450,
            played_ms=2520,
        )
    )

    assert received == {
        "tts_id": "tts_001",
        "response_id": "resp_001",
        "turn_id": "turn_001",
        "played_at_ms": 1730000018450,
        "played_ms": 2520,
    }


def test_validate_metadata_rejects_services_payload():
    session = _session()
    sanitized, error = session._validate_and_sanitize_client_metadata({"services": {"llm": {"provider": "openai"}}})
    assert sanitized == {}
    assert error is not None
    assert error["code"] == "protocol.invalid_override"


def test_validate_metadata_rejects_secret_like_override_keys():
    session = _session()
    sanitized, error = session._validate_and_sanitize_client_metadata(
        {
            "overrides": {
                "output": {"mode": "audio"},
                "apiKey": "xxx",
            }
        }
    )
    assert sanitized == {}
    assert error is not None
    assert error["code"] == "protocol.invalid_override"


def test_validate_metadata_ignores_workflow_payload():
    session = _session()
    sanitized, error = session._validate_and_sanitize_client_metadata(
        {
            "workflow": {"nodes": [], "edges": []},
            "channel": "web_debug",
            "overrides": {"output": {"mode": "text"}},
        }
    )
    assert error is None
    assert "workflow" not in sanitized
    assert sanitized["channel"] == "web_debug"
    assert sanitized["overrides"]["output"]["mode"] == "text"


@pytest.mark.asyncio
async def test_load_server_runtime_metadata_returns_not_found_error():
    session = _session()
    session._runtime_config_provider = _gateway_returning({"__error_code": "assistant.not_found"})

    runtime, error = await session._load_server_runtime_metadata("assistant_demo")

    assert runtime == {}
    assert error is not None
    assert error["code"] == "assistant.not_found"


@pytest.mark.asyncio
async def test_load_server_runtime_metadata_returns_config_unavailable_error():
    session = _session()
    session._runtime_config_provider = _gateway_returning(None)

    runtime, error = await session._load_server_runtime_metadata("assistant_demo")

    assert runtime == {}
    assert error is not None
    assert error["code"] == "assistant.config_unavailable"


@pytest.mark.asyncio
async def test_handle_session_start_requires_assistant_id_and_closes_transport():
    session = _bare_session("sess_missing_assistant", WsSessionState.WAIT_START)
    session._assistant_id = None
    transport = _RecordingTransport()
    session.transport = transport

    captured_codes = []

    async def _send_error(sender, message, code, **kwargs):
        _ = (sender, message, kwargs)
        captured_codes.append(code)

    session._send_error = _send_error

    await session._handle_session_start(SessionStartMessage(type="session.start", metadata={}))

    assert captured_codes == ["protocol.assistant_id_required"]
    assert transport.closed is True
    assert session.ws_state == WsSessionState.STOPPED


@pytest.mark.asyncio
async def test_handle_session_start_applies_whitelisted_overrides_and_ignores_workflow(monkeypatch):
    monkeypatch.setattr("core.session.settings.ws_emit_config_resolved", False)
    session, events = _session_ready_for_start("sess_start_ok")

    await session._handle_session_start(
        SessionStartMessage(
            type="session.start",
            metadata={
                "workflow": {"nodes": []},
                "channel": "web_debug",
                "source": "debug_ui",
                "history": {"userId": 7},
                "overrides": {
                    "greeting": "Override greeting",
                    "output": {"mode": "text"},
                    "tools": [{"name": "calculator"}],
                },
            },
        )
    )

    assert session.ws_state == WsSessionState.ACTIVE
    assert session.pipeline.started is True
    assert session.pipeline.applied["assistantId"] == "assistant_demo"
    assert session.pipeline.applied["greeting"] == "Override greeting"
    assert session.pipeline.applied["output"]["mode"] == "text"
    assert session.pipeline.applied["tools"] == [{"name": "calculator"}]
    assert not any(str(item.get("type", "")).startswith("workflow.") for item in events)
    assert not any(item.get("type") == "config.resolved" for item in events)


@pytest.mark.asyncio
async def test_handle_session_start_emits_config_resolved_when_enabled(monkeypatch):
    monkeypatch.setattr("core.session.settings.ws_emit_config_resolved", True)
    monkeypatch.setattr("core.session.settings.ws_protocol_version", "v1-custom")
    monkeypatch.setattr("core.session.settings.default_codec", "pcmu")
    session, events = _session_ready_for_start("sess_start_emit_config")

    await session._handle_session_start(
        SessionStartMessage(
            type="session.start",
            metadata={
                "channel": "web_debug",
                "overrides": {
                    "output": {"mode": "text"},
                },
            },
        )
    )

    config_event = _first_event(events, "config.resolved")
    session_started_event = _first_event(events, "session.started")
    assert session_started_event["protocolVersion"] == "v1-custom"
    assert "appId" not in config_event["config"]
    assert "configVersionId" not in config_event["config"]
    assert "services" not in config_event["config"]
    assert config_event["config"]["protocolVersion"] == "v1-custom"
    assert config_event["config"]["channel"] == "web_debug"
    assert config_event["config"]["output"]["mode"] == "text"
    assert config_event["config"]["output"]["codec"] == "pcmu"
    assert config_event["config"]["tools"]["enabled"] is True
    assert config_event["config"]["tools"]["count"] == 1


@pytest.mark.asyncio
async def test_handle_audio_uses_chunk_size_for_frame_validation(monkeypatch):
    monkeypatch.setattr("core.session.settings.sample_rate", 16000)
    monkeypatch.setattr("core.session.settings.chunk_size_ms", 10)
    session = _bare_session("sess_chunk_frame", WsSessionState.ACTIVE)

    class _Pipeline:
        def __init__(self):
            self.frames = []

        async def process_audio(self, frame: bytes):
            self.frames.append(frame)

    session.pipeline = _Pipeline()
    errors = []

    async def _send_error(sender, message, code, **kwargs):
        _ = (sender, kwargs)
        errors.append((code, message))

    session._send_error = _send_error

    # 16000 Hz * 10 ms = 160 samples * 2 bytes = 320 bytes per frame,
    # so this 640-byte payload is exactly 2 frames.
    payload = b"\x00\x01" * 320
    await session.handle_audio(payload)

    assert errors == []
    assert len(session.pipeline.frames) == 2
    assert all(len(frame) == 320 for frame in session.pipeline.frames)