Files
AI-VideoAssistant/engine/tests/test_ws_protocol_session_start.py
Xin Wang 6b589a1b7c Enhance session management and logging configuration
- Updated .env.example to clarify audio frame size validation and default codec settings.
- Refactored logging setup in main.py to support JSON serialization based on log format configuration.
- Improved session.py to dynamically compute audio frame bytes and include protocol version in session events.
- Added tests to validate session start events and audio frame handling based on chunk size settings.
2026-03-05 21:44:23 +08:00

416 lines
14 KiB
Python

import pytest
from core.session import Session, WsSessionState
from models.ws_v1 import OutputAudioPlayedMessage, SessionStartMessage, parse_client_message
def _session() -> Session:
    """Return a bare ``Session`` created without running ``Session.__init__``.

    Only ``id`` and ``_assistant_id`` are populated; tests attach any other
    attribute they need themselves.
    """
    bare = Session.__new__(Session)
    for attr_name, attr_value in (
        ("id", "sess_test"),
        ("_assistant_id", "assistant_demo"),
    ):
        setattr(bare, attr_name, attr_value)
    return bare
def test_parse_client_message_rejects_hello_message():
    """A ``hello`` frame raises ``ValueError`` naming the unknown message type."""
    hello_frame = {"type": "hello", "version": "v1"}
    expected_error = "Unknown client message type: hello"

    with pytest.raises(ValueError, match=expected_error):
        parse_client_message(hello_frame)
def test_parse_client_message_accepts_output_audio_played():
    """``output.audio.played`` parses to ``OutputAudioPlayedMessage`` and keeps ``tts_id``."""
    raw_frame = {"type": "output.audio.played", "tts_id": "tts_001"}

    parsed = parse_client_message(raw_frame)

    assert isinstance(parsed, OutputAudioPlayedMessage)
    assert parsed.tts_id == "tts_001"
def test_parse_client_message_rejects_output_audio_played_without_tts_id():
    """An empty ``tts_id`` raises ``ValueError`` whose message mentions the field."""
    frame_with_empty_id = {"type": "output.audio.played", "tts_id": ""}

    with pytest.raises(ValueError, match="tts_id"):
        parse_client_message(frame_with_empty_id)
@pytest.mark.asyncio
async def test_handle_text_reports_invalid_message_for_hello():
    """A ``hello`` text frame in WAIT_START is reported as ``protocol.invalid_message``."""

    class _Transport:
        # Closable stand-in; this test asserts nothing about it being closed.
        async def close(self):
            return None

    reported = []

    async def _send_error(sender, message, code, **kwargs):
        # Replaces the session's error sender so errors are captured in memory.
        reported.append(
            {"sender": sender, "code": code, "message": message, "extra": kwargs}
        )

    session = Session.__new__(Session)
    session.id = "sess_invalid_hello"
    session.ws_state = WsSessionState.WAIT_START
    session.transport = _Transport()
    session._send_error = _send_error

    await session.handle_text('{"type":"hello","version":"v1"}')

    assert reported
    first_error = reported[0]
    assert first_error["sender"] == "client"
    assert first_error["code"] == "protocol.invalid_message"
    assert "Unknown client message type: hello" in first_error["message"]
@pytest.mark.asyncio
async def test_handle_v1_message_routes_output_audio_played_to_pipeline():
    """An ACTIVE session passes ``output.audio.played`` fields to the pipeline as kwargs."""
    playback_fields = {
        "tts_id": "tts_001",
        "response_id": "resp_001",
        "turn_id": "turn_001",
        "played_at_ms": 1730000018450,
        "played_ms": 2520,
    }
    forwarded = {}

    class _Pipeline:
        # Collects whatever keyword arguments the session hands over.
        async def handle_output_audio_played(self, **payload):
            forwarded.update(payload)

    async def _send_error(sender, message, code, **kwargs):
        # Any error on this path means the message was not routed as expected.
        raise AssertionError(f"Unexpected error: sender={sender} code={code} message={message} kwargs={kwargs}")

    session = Session.__new__(Session)
    session.id = "sess_output_audio_played"
    session.ws_state = WsSessionState.ACTIVE
    session.pipeline = _Pipeline()
    session._send_error = _send_error

    played_message = OutputAudioPlayedMessage(type="output.audio.played", **playback_fields)
    await session._handle_v1_message(played_message)

    assert forwarded == playback_fields
def test_validate_metadata_rejects_services_payload():
    """Client metadata carrying ``services`` is rejected as an invalid override."""
    client_metadata = {"services": {"llm": {"provider": "openai"}}}

    outcome = _session()._validate_and_sanitize_client_metadata(client_metadata)
    cleaned, failure = outcome

    assert cleaned == {}
    assert failure is not None
    assert failure["code"] == "protocol.invalid_override"
def test_validate_metadata_rejects_secret_like_override_keys():
    """An ``apiKey`` entry inside ``overrides`` is rejected as an invalid override."""
    overrides = {
        "output": {"mode": "audio"},
        "apiKey": "xxx",
    }
    client_metadata = {"overrides": overrides}

    outcome = _session()._validate_and_sanitize_client_metadata(client_metadata)
    cleaned, failure = outcome

    assert cleaned == {}
    assert failure is not None
    assert failure["code"] == "protocol.invalid_override"
def test_validate_metadata_ignores_workflow_payload():
    """``workflow`` is dropped without error while ``channel`` and ``overrides`` survive."""
    client_metadata = {
        "workflow": {"nodes": [], "edges": []},
        "channel": "web_debug",
        "overrides": {"output": {"mode": "text"}},
    }

    outcome = _session()._validate_and_sanitize_client_metadata(client_metadata)
    cleaned, failure = outcome

    assert failure is None
    assert "workflow" not in cleaned
    assert cleaned["channel"] == "web_debug"
    assert cleaned["overrides"]["output"]["mode"] == "text"
@pytest.mark.asyncio
async def test_load_server_runtime_metadata_returns_not_found_error():
    """A gateway reply carrying ``__error_code`` surfaces as ``assistant.not_found``."""

    class _Gateway:
        # Always answers with the not-found error marker, whatever id is asked for.
        async def fetch_assistant_config(self, assistant_id: str):
            del assistant_id
            return {"__error_code": "assistant.not_found"}

    session = _session()
    session._backend_gateway = _Gateway()

    outcome = await session._load_server_runtime_metadata("assistant_demo")
    runtime_config, failure = outcome

    assert runtime_config == {}
    assert failure is not None
    assert failure["code"] == "assistant.not_found"
@pytest.mark.asyncio
async def test_load_server_runtime_metadata_returns_config_unavailable_error():
    """A gateway reply of ``None`` surfaces as ``assistant.config_unavailable``."""

    class _Gateway:
        # Always answers with no configuration at all.
        async def fetch_assistant_config(self, assistant_id: str):
            del assistant_id
            return None

    session = _session()
    session._backend_gateway = _Gateway()

    outcome = await session._load_server_runtime_metadata("assistant_demo")
    runtime_config, failure = outcome

    assert runtime_config == {}
    assert failure is not None
    assert failure["code"] == "assistant.config_unavailable"
@pytest.mark.asyncio
async def test_handle_session_start_requires_assistant_id_and_closes_transport():
    """``session.start`` without an assistant id reports an error, closes, and stops."""

    class _Transport:
        # Remembers whether ``close`` was awaited.
        def __init__(self):
            self.closed = False

        async def close(self):
            self.closed = True

    error_codes = []

    async def _send_error(sender, message, code, **kwargs):
        # Only the error code matters for this test.
        del sender, message, kwargs
        error_codes.append(code)

    fake_transport = _Transport()

    session = Session.__new__(Session)
    session.id = "sess_missing_assistant"
    session.ws_state = WsSessionState.WAIT_START
    session._assistant_id = None
    session.transport = fake_transport
    session._send_error = _send_error

    start_message = SessionStartMessage(type="session.start", metadata={})
    await session._handle_session_start(start_message)

    assert error_codes == ["protocol.assistant_id_required"]
    assert fake_transport.closed is True
    assert session.ws_state == WsSessionState.STOPPED
@pytest.mark.asyncio
async def test_handle_session_start_applies_whitelisted_overrides_and_ignores_workflow(monkeypatch):
    """``session.start`` applies client overrides and emits no workflow events.

    The server-side config supplies ``assistantId`` and a base greeting/output;
    the client overrides greeting, output mode and tools. With
    ``ws_emit_config_resolved`` off, no ``config.resolved`` event may be sent.
    """
    monkeypatch.setattr("core.session.settings.ws_emit_config_resolved", False)

    class _Transport:
        async def close(self):
            return None

    class _Pipeline:
        # Records the metadata it is asked to apply and whether it was started.
        def __init__(self):
            self.started = False
            self.applied = {}
            self.conversation = type("Conversation", (), {"system_prompt": ""})()

        async def start(self):
            self.started = True

        async def emit_initial_greeting(self):
            return None

        def apply_runtime_overrides(self, metadata):
            self.applied = dict(metadata)

        def resolved_runtime_config(self):
            return {
                "output": {"mode": "text"},
                "services": {"llm": {"provider": "openai", "model": "gpt-4o-mini"}},
                "tools": {"allowlist": ["calculator"]},
            }

    sent_events = []

    async def _start_history_bridge(_metadata):
        return None

    async def _load_server_runtime_metadata(_assistant_id):
        # Fresh dict on every call, paired with "no error".
        server_config = {
            "assistantId": "assistant_demo",
            "configVersionId": "cfg_1",
            "systemPrompt": "Base prompt",
            "greeting": "Base greeting",
            "output": {"mode": "audio"},
        }
        return (server_config, None)

    async def _send_event(event):
        sent_events.append(event)

    async def _send_error(sender, message, code, **kwargs):
        raise AssertionError(f"Unexpected error: sender={sender} code={code} message={message} kwargs={kwargs}")

    session = Session.__new__(Session)
    session.id = "sess_start_ok"
    session.ws_state = WsSessionState.WAIT_START
    session.state = "created"
    session._assistant_id = "assistant_demo"
    session.current_track_id = Session.TRACK_CONTROL
    session._pipeline_started = False
    session.transport = _Transport()
    session.pipeline = _Pipeline()
    session._start_history_bridge = _start_history_bridge
    session._load_server_runtime_metadata = _load_server_runtime_metadata
    session._send_event = _send_event
    session._send_error = _send_error

    client_overrides = {
        "greeting": "Override greeting",
        "output": {"mode": "text"},
        "tools": [{"name": "calculator"}],
    }
    start_message = SessionStartMessage(
        type="session.start",
        metadata={
            "workflow": {"nodes": []},
            "channel": "web_debug",
            "source": "debug_ui",
            "history": {"userId": 7},
            "overrides": client_overrides,
        },
    )
    await session._handle_session_start(start_message)

    assert session.ws_state == WsSessionState.ACTIVE

    pipeline = session.pipeline
    assert pipeline.started is True
    applied = pipeline.applied
    assert applied["assistantId"] == "assistant_demo"
    assert applied["greeting"] == "Override greeting"
    assert applied["output"]["mode"] == "text"
    assert applied["tools"] == [{"name": "calculator"}]

    assert all(
        not str(event.get("type", "")).startswith("workflow.") for event in sent_events
    )
    assert "config.resolved" not in [event.get("type") for event in sent_events]
@pytest.mark.asyncio
async def test_handle_session_start_emits_config_resolved_when_enabled(monkeypatch):
    """``session.start`` emits ``config.resolved`` when the setting is enabled.

    Patches ``ws_emit_config_resolved``, ``ws_protocol_version`` and
    ``default_codec`` on ``core.session.settings``, then checks that both
    ``session.started`` and ``config.resolved`` carry the patched protocol
    version, and that the resolved config omits ``appId``, ``configVersionId``
    and ``services`` while reporting channel, output mode/codec and a tool
    summary.
    """
    monkeypatch.setattr("core.session.settings.ws_emit_config_resolved", True)
    monkeypatch.setattr("core.session.settings.ws_protocol_version", "v1-custom")
    monkeypatch.setattr("core.session.settings.default_codec", "pcmu")

    session = Session.__new__(Session)
    session.id = "sess_start_emit_config"
    session.ws_state = WsSessionState.WAIT_START
    session.state = "created"
    session._assistant_id = "assistant_demo"
    session.current_track_id = Session.TRACK_CONTROL
    session._pipeline_started = False

    class _Transport:
        async def close(self):
            return None

    class _Pipeline:
        # Minimal pipeline stub: records applied metadata and start state.
        def __init__(self):
            self.started = False
            self.applied = {}
            self.conversation = type("Conversation", (), {"system_prompt": ""})()

        async def start(self):
            self.started = True

        async def emit_initial_greeting(self):
            return None

        def apply_runtime_overrides(self, metadata):
            self.applied = dict(metadata)

        def resolved_runtime_config(self):
            return {
                "output": {"mode": "text"},
                "services": {"llm": {"provider": "openai", "model": "gpt-4o-mini"}},
                "tools": {"allowlist": ["calculator"]},
            }

    session.transport = _Transport()
    session.pipeline = _Pipeline()

    events = []

    async def _start_history_bridge(_metadata):
        return None

    async def _load_server_runtime_metadata(_assistant_id):
        return (
            {
                "assistantId": "assistant_demo",
                "configVersionId": "cfg_1",
                "systemPrompt": "Base prompt",
                "greeting": "Base greeting",
                "output": {"mode": "audio"},
            },
            None,
        )

    async def _send_event(event):
        events.append(event)

    async def _send_error(sender, message, code, **kwargs):
        raise AssertionError(f"Unexpected error: sender={sender} code={code} message={message} kwargs={kwargs}")

    session._start_history_bridge = _start_history_bridge
    session._load_server_runtime_metadata = _load_server_runtime_metadata
    session._send_event = _send_event
    session._send_error = _send_error

    await session._handle_session_start(
        SessionStartMessage(
            type="session.start",
            metadata={
                "channel": "web_debug",
                "overrides": {
                    "output": {"mode": "text"},
                },
            },
        )
    )

    # Use a ``None`` default: a bare ``next()`` that finds nothing raises
    # StopIteration, which a coroutine turns into an opaque
    # "RuntimeError: coroutine raised StopIteration" instead of a test failure.
    config_event = next(
        (item for item in events if item.get("type") == "config.resolved"), None
    )
    assert config_event is not None, "config.resolved event was not emitted"
    session_started_event = next(
        (item for item in events if item.get("type") == "session.started"), None
    )
    assert session_started_event is not None, "session.started event was not emitted"

    assert session_started_event["protocolVersion"] == "v1-custom"

    resolved = config_event["config"]
    assert "appId" not in resolved
    assert "configVersionId" not in resolved
    assert "services" not in resolved
    assert resolved["protocolVersion"] == "v1-custom"
    assert resolved["channel"] == "web_debug"
    assert resolved["output"]["mode"] == "text"
    assert resolved["output"]["codec"] == "pcmu"
    assert resolved["tools"]["enabled"] is True
    assert resolved["tools"]["count"] == 1
@pytest.mark.asyncio
async def test_handle_audio_uses_chunk_size_for_frame_validation(monkeypatch):
    """With 16 kHz / 10 ms settings, a 640-byte payload reaches the pipeline as two 320-byte frames."""
    for setting_path, setting_value in (
        ("core.session.settings.sample_rate", 16000),
        ("core.session.settings.chunk_size_ms", 10),
    ):
        monkeypatch.setattr(setting_path, setting_value)

    class _Pipeline:
        # Collects every frame the session forwards.
        def __init__(self):
            self.frames = []

        async def process_audio(self, frame: bytes):
            self.frames.append(frame)

    reported_errors = []

    async def _send_error(sender, message, code, **kwargs):
        del sender, kwargs
        reported_errors.append((code, message))

    session = Session.__new__(Session)
    session.id = "sess_chunk_frame"
    session.ws_state = WsSessionState.ACTIVE
    session.pipeline = _Pipeline()
    session._send_error = _send_error

    # 640 bytes in total, i.e. two frames when chunk_size_ms=10.
    audio_payload = b"\x00\x01" * 320
    await session.handle_audio(audio_payload)

    assert reported_errors == []
    frame_sizes = [len(chunk) for chunk in session.pipeline.frames]
    assert frame_sizes == [320, 320]