Enhance WebSocket session management by requiring assistant_id as a query parameter for connection. Update API reference documentation to reflect changes in message flow and metadata validation rules, including the introduction of whitelists for allowed metadata fields and restrictions on sensitive keys. Refactor client examples to align with the new session initiation process.
This commit is contained in:
238
engine/tests/test_ws_protocol_session_start.py
Normal file
238
engine/tests/test_ws_protocol_session_start.py
Normal file
@@ -0,0 +1,238 @@
|
||||
import pytest
|
||||
|
||||
from core.session import Session, WsSessionState
|
||||
from models.ws_v1 import SessionStartMessage, parse_client_message
|
||||
|
||||
|
||||
def _session() -> Session:
|
||||
session = Session.__new__(Session)
|
||||
session.id = "sess_test"
|
||||
session._assistant_id = "assistant_demo"
|
||||
return session
|
||||
|
||||
|
||||
def test_parse_client_message_rejects_hello_message():
|
||||
with pytest.raises(ValueError, match="Unknown client message type: hello"):
|
||||
parse_client_message({"type": "hello", "version": "v1"})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_text_reports_invalid_message_for_hello():
|
||||
session = Session.__new__(Session)
|
||||
session.id = "sess_invalid_hello"
|
||||
session.ws_state = WsSessionState.WAIT_START
|
||||
|
||||
class _Transport:
|
||||
async def close(self):
|
||||
return None
|
||||
|
||||
session.transport = _Transport()
|
||||
captured = []
|
||||
|
||||
async def _send_error(sender, message, code, **kwargs):
|
||||
captured.append((sender, code, message, kwargs))
|
||||
|
||||
session._send_error = _send_error
|
||||
|
||||
await session.handle_text('{"type":"hello","version":"v1"}')
|
||||
assert captured
|
||||
sender, code, message, _ = captured[0]
|
||||
assert sender == "client"
|
||||
assert code == "protocol.invalid_message"
|
||||
assert "Unknown client message type: hello" in message
|
||||
|
||||
|
||||
def test_validate_metadata_rejects_services_payload():
|
||||
session = _session()
|
||||
sanitized, error = session._validate_and_sanitize_client_metadata({"services": {"llm": {"provider": "openai"}}})
|
||||
assert sanitized == {}
|
||||
assert error is not None
|
||||
assert error["code"] == "protocol.invalid_override"
|
||||
|
||||
|
||||
def test_validate_metadata_rejects_secret_like_override_keys():
|
||||
session = _session()
|
||||
sanitized, error = session._validate_and_sanitize_client_metadata(
|
||||
{
|
||||
"overrides": {
|
||||
"output": {"mode": "audio"},
|
||||
"apiKey": "xxx",
|
||||
}
|
||||
}
|
||||
)
|
||||
assert sanitized == {}
|
||||
assert error is not None
|
||||
assert error["code"] == "protocol.invalid_override"
|
||||
|
||||
|
||||
def test_validate_metadata_ignores_workflow_payload():
|
||||
session = _session()
|
||||
sanitized, error = session._validate_and_sanitize_client_metadata(
|
||||
{
|
||||
"workflow": {"nodes": [], "edges": []},
|
||||
"channel": "web_debug",
|
||||
"overrides": {"output": {"mode": "text"}},
|
||||
}
|
||||
)
|
||||
assert error is None
|
||||
assert "workflow" not in sanitized
|
||||
assert sanitized["channel"] == "web_debug"
|
||||
assert sanitized["overrides"]["output"]["mode"] == "text"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_server_runtime_metadata_returns_not_found_error():
|
||||
session = _session()
|
||||
|
||||
class _Gateway:
|
||||
async def fetch_assistant_config(self, assistant_id: str):
|
||||
_ = assistant_id
|
||||
return {"__error_code": "assistant.not_found"}
|
||||
|
||||
session._backend_gateway = _Gateway()
|
||||
runtime, error = await session._load_server_runtime_metadata("assistant_demo")
|
||||
assert runtime == {}
|
||||
assert error is not None
|
||||
assert error["code"] == "assistant.not_found"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_server_runtime_metadata_returns_config_unavailable_error():
|
||||
session = _session()
|
||||
|
||||
class _Gateway:
|
||||
async def fetch_assistant_config(self, assistant_id: str):
|
||||
_ = assistant_id
|
||||
return None
|
||||
|
||||
session._backend_gateway = _Gateway()
|
||||
runtime, error = await session._load_server_runtime_metadata("assistant_demo")
|
||||
assert runtime == {}
|
||||
assert error is not None
|
||||
assert error["code"] == "assistant.config_unavailable"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_session_start_requires_assistant_id_and_closes_transport():
|
||||
session = Session.__new__(Session)
|
||||
session.id = "sess_missing_assistant"
|
||||
session.ws_state = WsSessionState.WAIT_START
|
||||
session._assistant_id = None
|
||||
|
||||
class _Transport:
|
||||
def __init__(self):
|
||||
self.closed = False
|
||||
|
||||
async def close(self):
|
||||
self.closed = True
|
||||
|
||||
transport = _Transport()
|
||||
session.transport = transport
|
||||
captured_codes = []
|
||||
|
||||
async def _send_error(sender, message, code, **kwargs):
|
||||
_ = (sender, message, kwargs)
|
||||
captured_codes.append(code)
|
||||
|
||||
session._send_error = _send_error
|
||||
|
||||
await session._handle_session_start(SessionStartMessage(type="session.start", metadata={}))
|
||||
assert captured_codes == ["protocol.assistant_id_required"]
|
||||
assert transport.closed is True
|
||||
assert session.ws_state == WsSessionState.STOPPED
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_session_start_applies_whitelisted_overrides_and_ignores_workflow():
|
||||
session = Session.__new__(Session)
|
||||
session.id = "sess_start_ok"
|
||||
session.ws_state = WsSessionState.WAIT_START
|
||||
session.state = "created"
|
||||
session._assistant_id = "assistant_demo"
|
||||
session.current_track_id = Session.TRACK_CONTROL
|
||||
session._pipeline_started = False
|
||||
|
||||
class _Transport:
|
||||
async def close(self):
|
||||
return None
|
||||
|
||||
class _Pipeline:
|
||||
def __init__(self):
|
||||
self.started = False
|
||||
self.applied = {}
|
||||
self.conversation = type("Conversation", (), {"system_prompt": ""})()
|
||||
|
||||
async def start(self):
|
||||
self.started = True
|
||||
|
||||
async def emit_initial_greeting(self):
|
||||
return None
|
||||
|
||||
def apply_runtime_overrides(self, metadata):
|
||||
self.applied = dict(metadata)
|
||||
|
||||
def resolved_runtime_config(self):
|
||||
return {
|
||||
"output": {"mode": "text"},
|
||||
"services": {"llm": {"provider": "openai", "model": "gpt-4o-mini"}},
|
||||
"tools": {"allowlist": []},
|
||||
}
|
||||
|
||||
session.transport = _Transport()
|
||||
session.pipeline = _Pipeline()
|
||||
events = []
|
||||
|
||||
async def _start_history_bridge(_metadata):
|
||||
return None
|
||||
|
||||
async def _load_server_runtime_metadata(_assistant_id):
|
||||
return (
|
||||
{
|
||||
"assistantId": "assistant_demo",
|
||||
"configVersionId": "cfg_1",
|
||||
"systemPrompt": "Base prompt",
|
||||
"greeting": "Base greeting",
|
||||
"output": {"mode": "audio"},
|
||||
},
|
||||
None,
|
||||
)
|
||||
|
||||
async def _send_event(event):
|
||||
events.append(event)
|
||||
|
||||
async def _send_error(sender, message, code, **kwargs):
|
||||
raise AssertionError(f"Unexpected error: sender={sender} code={code} message={message} kwargs={kwargs}")
|
||||
|
||||
session._start_history_bridge = _start_history_bridge
|
||||
session._load_server_runtime_metadata = _load_server_runtime_metadata
|
||||
session._send_event = _send_event
|
||||
session._send_error = _send_error
|
||||
|
||||
await session._handle_session_start(
|
||||
SessionStartMessage(
|
||||
type="session.start",
|
||||
metadata={
|
||||
"workflow": {"nodes": []},
|
||||
"channel": "web_debug",
|
||||
"source": "debug_ui",
|
||||
"history": {"userId": 7},
|
||||
"overrides": {
|
||||
"greeting": "Override greeting",
|
||||
"output": {"mode": "text"},
|
||||
"tools": [{"name": "calculator"}],
|
||||
},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert session.ws_state == WsSessionState.ACTIVE
|
||||
assert session.pipeline.started is True
|
||||
assert session.pipeline.applied["assistantId"] == "assistant_demo"
|
||||
assert session.pipeline.applied["greeting"] == "Override greeting"
|
||||
assert session.pipeline.applied["output"]["mode"] == "text"
|
||||
assert session.pipeline.applied["tools"] == [{"name": "calculator"}]
|
||||
assert not any(str(item.get("type", "")).startswith("workflow.") for item in events)
|
||||
|
||||
config_event = next(item for item in events if item.get("type") == "config.resolved")
|
||||
assert config_event["config"]["appId"] == "assistant_demo"
|
||||
assert config_event["config"]["channel"] == "web_debug"
|
||||
Reference in New Issue
Block a user