Enhance WebSocket session management by requiring assistant_id as a query parameter for connection. Update API reference documentation to reflect changes in message flow and metadata validation rules, including the introduction of whitelists for allowed metadata fields and restrictions on sensitive keys. Refactor client examples to align with the new session initiation process.

This commit is contained in:
Xin Wang
2026-03-01 14:10:38 +08:00
parent b4fa664d73
commit 6a46ec69f4
14 changed files with 725 additions and 424 deletions

View File

@@ -0,0 +1,238 @@
import pytest
from core.session import Session, WsSessionState
from models.ws_v1 import SessionStartMessage, parse_client_message
def _session() -> Session:
session = Session.__new__(Session)
session.id = "sess_test"
session._assistant_id = "assistant_demo"
return session
def test_parse_client_message_rejects_hello_message():
with pytest.raises(ValueError, match="Unknown client message type: hello"):
parse_client_message({"type": "hello", "version": "v1"})
@pytest.mark.asyncio
async def test_handle_text_reports_invalid_message_for_hello():
session = Session.__new__(Session)
session.id = "sess_invalid_hello"
session.ws_state = WsSessionState.WAIT_START
class _Transport:
async def close(self):
return None
session.transport = _Transport()
captured = []
async def _send_error(sender, message, code, **kwargs):
captured.append((sender, code, message, kwargs))
session._send_error = _send_error
await session.handle_text('{"type":"hello","version":"v1"}')
assert captured
sender, code, message, _ = captured[0]
assert sender == "client"
assert code == "protocol.invalid_message"
assert "Unknown client message type: hello" in message
def test_validate_metadata_rejects_services_payload():
session = _session()
sanitized, error = session._validate_and_sanitize_client_metadata({"services": {"llm": {"provider": "openai"}}})
assert sanitized == {}
assert error is not None
assert error["code"] == "protocol.invalid_override"
def test_validate_metadata_rejects_secret_like_override_keys():
session = _session()
sanitized, error = session._validate_and_sanitize_client_metadata(
{
"overrides": {
"output": {"mode": "audio"},
"apiKey": "xxx",
}
}
)
assert sanitized == {}
assert error is not None
assert error["code"] == "protocol.invalid_override"
def test_validate_metadata_ignores_workflow_payload():
session = _session()
sanitized, error = session._validate_and_sanitize_client_metadata(
{
"workflow": {"nodes": [], "edges": []},
"channel": "web_debug",
"overrides": {"output": {"mode": "text"}},
}
)
assert error is None
assert "workflow" not in sanitized
assert sanitized["channel"] == "web_debug"
assert sanitized["overrides"]["output"]["mode"] == "text"
@pytest.mark.asyncio
async def test_load_server_runtime_metadata_returns_not_found_error():
session = _session()
class _Gateway:
async def fetch_assistant_config(self, assistant_id: str):
_ = assistant_id
return {"__error_code": "assistant.not_found"}
session._backend_gateway = _Gateway()
runtime, error = await session._load_server_runtime_metadata("assistant_demo")
assert runtime == {}
assert error is not None
assert error["code"] == "assistant.not_found"
@pytest.mark.asyncio
async def test_load_server_runtime_metadata_returns_config_unavailable_error():
session = _session()
class _Gateway:
async def fetch_assistant_config(self, assistant_id: str):
_ = assistant_id
return None
session._backend_gateway = _Gateway()
runtime, error = await session._load_server_runtime_metadata("assistant_demo")
assert runtime == {}
assert error is not None
assert error["code"] == "assistant.config_unavailable"
@pytest.mark.asyncio
async def test_handle_session_start_requires_assistant_id_and_closes_transport():
session = Session.__new__(Session)
session.id = "sess_missing_assistant"
session.ws_state = WsSessionState.WAIT_START
session._assistant_id = None
class _Transport:
def __init__(self):
self.closed = False
async def close(self):
self.closed = True
transport = _Transport()
session.transport = transport
captured_codes = []
async def _send_error(sender, message, code, **kwargs):
_ = (sender, message, kwargs)
captured_codes.append(code)
session._send_error = _send_error
await session._handle_session_start(SessionStartMessage(type="session.start", metadata={}))
assert captured_codes == ["protocol.assistant_id_required"]
assert transport.closed is True
assert session.ws_state == WsSessionState.STOPPED
@pytest.mark.asyncio
async def test_handle_session_start_applies_whitelisted_overrides_and_ignores_workflow():
session = Session.__new__(Session)
session.id = "sess_start_ok"
session.ws_state = WsSessionState.WAIT_START
session.state = "created"
session._assistant_id = "assistant_demo"
session.current_track_id = Session.TRACK_CONTROL
session._pipeline_started = False
class _Transport:
async def close(self):
return None
class _Pipeline:
def __init__(self):
self.started = False
self.applied = {}
self.conversation = type("Conversation", (), {"system_prompt": ""})()
async def start(self):
self.started = True
async def emit_initial_greeting(self):
return None
def apply_runtime_overrides(self, metadata):
self.applied = dict(metadata)
def resolved_runtime_config(self):
return {
"output": {"mode": "text"},
"services": {"llm": {"provider": "openai", "model": "gpt-4o-mini"}},
"tools": {"allowlist": []},
}
session.transport = _Transport()
session.pipeline = _Pipeline()
events = []
async def _start_history_bridge(_metadata):
return None
async def _load_server_runtime_metadata(_assistant_id):
return (
{
"assistantId": "assistant_demo",
"configVersionId": "cfg_1",
"systemPrompt": "Base prompt",
"greeting": "Base greeting",
"output": {"mode": "audio"},
},
None,
)
async def _send_event(event):
events.append(event)
async def _send_error(sender, message, code, **kwargs):
raise AssertionError(f"Unexpected error: sender={sender} code={code} message={message} kwargs={kwargs}")
session._start_history_bridge = _start_history_bridge
session._load_server_runtime_metadata = _load_server_runtime_metadata
session._send_event = _send_event
session._send_error = _send_error
await session._handle_session_start(
SessionStartMessage(
type="session.start",
metadata={
"workflow": {"nodes": []},
"channel": "web_debug",
"source": "debug_ui",
"history": {"userId": 7},
"overrides": {
"greeting": "Override greeting",
"output": {"mode": "text"},
"tools": [{"name": "calculator"}],
},
},
)
)
assert session.ws_state == WsSessionState.ACTIVE
assert session.pipeline.started is True
assert session.pipeline.applied["assistantId"] == "assistant_demo"
assert session.pipeline.applied["greeting"] == "Override greeting"
assert session.pipeline.applied["output"]["mode"] == "text"
assert session.pipeline.applied["tools"] == [{"name": "calculator"}]
assert not any(str(item.get("type", "")).startswith("workflow.") for item in events)
config_event = next(item for item in events if item.get("type") == "config.resolved")
assert config_event["config"]["appId"] == "assistant_demo"
assert config_event["config"]["channel"] == "web_debug"