Enhance WebSocket session management by requiring assistant_id as a query parameter for connection. Update API reference documentation to reflect changes in message flow and metadata validation rules, including the introduction of whitelists for allowed metadata fields and restrictions on sensitive keys. Refactor client examples to align with the new session initiation process.

This commit is contained in:
Xin Wang
2026-03-01 14:10:38 +08:00
parent b4fa664d73
commit 6a46ec69f4
14 changed files with 725 additions and 424 deletions

View File

@@ -21,7 +21,6 @@ from services.base import LLMMessage
from models.ws_v1 import (
parse_client_message,
ev,
HelloMessage,
SessionStartMessage,
SessionStopMessage,
InputTextMessage,
@@ -39,7 +38,6 @@ _SYSTEM_DYNAMIC_VARIABLE_KEYS = {"system__time", "system_utc", "system_timezone"
class WsSessionState(str, Enum):
"""Protocol state machine for WS sessions."""
WAIT_HELLO = "wait_hello"
WAIT_START = "wait_start"
ACTIVE = "active"
STOPPED = "stopped"
@@ -57,7 +55,15 @@ class Session:
TRACK_AUDIO_OUT = "audio_out"
TRACK_CONTROL = "control"
AUDIO_FRAME_BYTES = 640 # 16k mono pcm_s16le, 20ms
_CLIENT_METADATA_OVERRIDES = {
_METADATA_ALLOWED_TOP_LEVEL_KEYS = {
"overrides",
"dynamicVariables",
"channel",
"source",
"history",
"workflow", # explicitly ignored for this MVP protocol version
}
_METADATA_ALLOWED_OVERRIDE_KEYS = {
"firstTurnMode",
"greeting",
"generatedOpenerEnabled",
@@ -67,18 +73,22 @@ class Session:
"knowledge",
"knowledgeBaseId",
"openerAudio",
"dynamicVariables",
"history",
"userId",
"assistantId",
"source",
"tools",
}
_CLIENT_METADATA_ID_KEYS = {
_METADATA_FORBIDDEN_TOP_LEVEL_KEYS = {
"assistantId",
"appId",
"app_id",
"channel",
"configVersionId",
"config_version_id",
"services",
}
_METADATA_FORBIDDEN_KEY_TOKENS = {
"apikey",
"token",
"secret",
"password",
"authorization",
}
def __init__(
@@ -87,6 +97,7 @@ class Session:
transport: BaseTransport,
use_duplex: bool = None,
backend_gateway: Optional[Any] = None,
assistant_id: Optional[str] = None,
):
"""
Initialize session.
@@ -99,6 +110,7 @@ class Session:
self.id = session_id
self.transport = transport
self.use_duplex = use_duplex if use_duplex is not None else settings.duplex_enabled
self._assistant_id = str(assistant_id or "").strip() or None
self._backend_gateway = backend_gateway or build_backend_adapter_from_settings()
self._history_bridge = SessionHistoryBridge(
history_writer=self._backend_gateway,
@@ -121,9 +133,8 @@ class Session:
# Session state
self.created_at = None
self.state = "created" # Legacy call state for /call/lists
self.ws_state = WsSessionState.WAIT_HELLO
self.ws_state = WsSessionState.WAIT_START
self._pipeline_started = False
self.protocol_version: Optional[str] = None
# Track IDs
self.current_track_id: str = self.TRACK_CONTROL
@@ -137,7 +148,12 @@ class Session:
self.pipeline.set_event_sequence_provider(self._next_event_seq)
self.pipeline.conversation.on_turn_complete(self._on_turn_complete)
logger.info(f"Session {self.id} created (duplex={self.use_duplex})")
logger.info(
"Session {} created (duplex={}, assistant_id={})",
self.id,
self.use_duplex,
self._assistant_id or "-",
)
async def handle_text(self, text_data: str) -> None:
"""
@@ -222,23 +238,19 @@ class Session:
msg_type = message.type
logger.info(f"Session {self.id} received message: {msg_type}")
if isinstance(message, HelloMessage):
await self._handle_hello(message)
return
# All messages below require hello handshake first
if self.ws_state == WsSessionState.WAIT_HELLO:
await self._send_error(
"client",
"Expected hello message first",
"protocol.order",
)
return
if isinstance(message, SessionStartMessage):
await self._handle_session_start(message)
return
# All messages below require session.start first
if self.ws_state == WsSessionState.WAIT_START:
await self._send_error(
"client",
"Expected session.start message first",
"protocol.order",
)
return
# All messages below require active session
if self.ws_state != WsSessionState.ACTIVE:
await self._send_error(
@@ -262,67 +274,56 @@ class Session:
else:
await self._send_error("client", f"Unsupported message type: {msg_type}", "protocol.unsupported")
async def _handle_hello(self, message: HelloMessage) -> None:
"""Handle initial hello/version negotiation."""
if self.ws_state != WsSessionState.WAIT_HELLO:
await self._send_error("client", "Duplicate hello", "protocol.order")
return
if message.version != settings.ws_protocol_version:
await self._send_error(
"client",
f"Unsupported protocol version '{message.version}'",
"protocol.version_unsupported",
)
await self.transport.close()
self.ws_state = WsSessionState.STOPPED
return
self.protocol_version = message.version
self.ws_state = WsSessionState.WAIT_START
await self._send_event(
ev(
"hello.ack",
version=self.protocol_version,
)
)
async def _handle_session_start(self, message: SessionStartMessage) -> None:
"""Handle explicit session start after successful hello."""
"""Handle explicit session start."""
if self.ws_state != WsSessionState.WAIT_START:
await self._send_error("client", "Duplicate session.start", "protocol.order")
return
raw_metadata = message.metadata or {}
workflow_runtime = self._bootstrap_workflow(raw_metadata)
server_runtime = await self._load_server_runtime_metadata(raw_metadata, workflow_runtime)
client_runtime = self._sanitize_client_metadata(raw_metadata)
requested_assistant_id = (
workflow_runtime.get("assistantId")
or raw_metadata.get("assistantId")
or raw_metadata.get("appId")
or raw_metadata.get("app_id")
)
if server_runtime:
logger.info(
"Session {} loaded trusted runtime config from backend "
"(requested_assistant_id={}, resolved_assistant_id={}, configVersionId={}, has_services={})",
self.id,
requested_assistant_id,
server_runtime.get("assistantId"),
server_runtime.get("configVersionId") or server_runtime.get("config_version_id"),
isinstance(server_runtime.get("services"), dict),
if not self._assistant_id:
await self._send_error(
"client",
"Missing required query parameter assistant_id",
"protocol.assistant_id_required",
stage="protocol",
retryable=False,
)
else:
logger.warning(
"Session {} missing trusted backend runtime config "
"(requested_assistant_id={}); falling back to engine defaults + safe client overrides",
self.id,
requested_assistant_id,
await self.transport.close()
self.ws_state = WsSessionState.STOPPED
return
sanitized_metadata, metadata_error = self._validate_and_sanitize_client_metadata(raw_metadata)
if metadata_error:
await self._send_error(
"client",
metadata_error["message"],
metadata_error["code"],
stage="protocol",
retryable=False,
)
metadata = self._merge_runtime_metadata(server_runtime, self._sanitize_untrusted_runtime_metadata(workflow_runtime))
metadata = self._merge_runtime_metadata(metadata, client_runtime)
metadata, dynamic_var_error = self._apply_dynamic_variables(metadata, raw_metadata)
await self.transport.close()
self.ws_state = WsSessionState.STOPPED
return
server_runtime, runtime_error = await self._load_server_runtime_metadata(self._assistant_id)
if runtime_error:
await self._send_error(
"server",
runtime_error["message"],
runtime_error["code"],
stage="protocol",
retryable=False,
)
await self.transport.close()
self.ws_state = WsSessionState.STOPPED
return
metadata = self._merge_runtime_metadata(server_runtime, sanitized_metadata.get("overrides", {}))
for key in ("channel", "source", "history"):
if key in sanitized_metadata:
metadata[key] = sanitized_metadata[key]
metadata, dynamic_var_error = self._apply_dynamic_variables(metadata, sanitized_metadata)
if dynamic_var_error:
await self._send_error(
"client",
@@ -331,6 +332,8 @@ class Session:
stage="protocol",
retryable=False,
)
await self.transport.close()
self.ws_state = WsSessionState.STOPPED
return
# Create history call record early so later turn callbacks can append transcripts.
@@ -348,7 +351,7 @@ class Session:
"(assistantId={}, configVersionId={}, output_mode={}, "
"llm={}/{}, asr={}/{}, tts={}/{}, tts_enabled={})",
self.id,
metadata.get("assistantId") or metadata.get("appId") or metadata.get("app_id"),
metadata.get("assistantId"),
metadata.get("configVersionId") or metadata.get("config_version_id"),
(resolved_preview.get("output") or {}).get("mode") if isinstance(resolved_preview, dict) else None,
llm_cfg.get("provider"),
@@ -387,24 +390,6 @@ class Session:
config=self._build_config_resolved(metadata),
)
)
if self.workflow_runner and self._workflow_initial_node:
await self._send_event(
ev(
"workflow.started",
workflowId=self.workflow_runner.workflow_id,
workflowName=self.workflow_runner.name,
nodeId=self._workflow_initial_node.id,
)
)
await self._send_event(
ev(
"workflow.node.entered",
workflowId=self.workflow_runner.workflow_id,
nodeId=self._workflow_initial_node.id,
nodeName=self._workflow_initial_node.name,
nodeType=self._workflow_initial_node.node_type,
)
)
# Emit opener only after frontend has received session.started/config events.
await self.pipeline.emit_initial_greeting()
@@ -658,7 +643,7 @@ class Session:
def _event_source(self, event_type: str) -> str:
if event_type.startswith("workflow."):
return "system"
if event_type.startswith("session.") or event_type.startswith("hello.") or event_type == "heartbeat":
if event_type.startswith("session.") or event_type == "heartbeat":
return "system"
if event_type == "error":
return "system"
@@ -931,26 +916,41 @@ class Session:
async def _load_server_runtime_metadata(
self,
client_metadata: Dict[str, Any],
workflow_runtime: Dict[str, Any],
) -> Dict[str, Any]:
assistant_id: str,
) -> tuple[Dict[str, Any], Optional[Dict[str, str]]]:
"""Load trusted runtime metadata from backend assistant config."""
assistant_id = (
workflow_runtime.get("assistantId")
or client_metadata.get("assistantId")
or client_metadata.get("appId")
or client_metadata.get("app_id")
)
if assistant_id is None:
return {}
if not assistant_id:
return {}, {
"code": "protocol.assistant_id_required",
"message": "Missing required query parameter assistant_id",
}
provider = getattr(self._backend_gateway, "fetch_assistant_config", None)
if not callable(provider):
return {}
return {}, {
"code": "assistant.config_unavailable",
"message": "Assistant config backend unavailable",
}
payload = await provider(str(assistant_id).strip())
if isinstance(payload, dict):
error_code = str(payload.get("__error_code") or "").strip()
if error_code == "assistant.not_found":
return {}, {
"code": "assistant.not_found",
"message": f"Assistant not found: {assistant_id}",
}
if error_code == "assistant.config_unavailable":
return {}, {
"code": "assistant.config_unavailable",
"message": f"Assistant config unavailable: {assistant_id}",
}
if not isinstance(payload, dict):
return {}
return {}, {
"code": "assistant.config_unavailable",
"message": f"Assistant config unavailable: {assistant_id}",
}
assistant_cfg: Dict[str, Any] = {}
session_start_cfg = payload.get("sessionStartMetadata")
@@ -962,7 +962,10 @@ class Session:
assistant_cfg = payload
if not isinstance(assistant_cfg, dict):
return {}
return {}, {
"code": "assistant.config_unavailable",
"message": f"Assistant config unavailable: {assistant_id}",
}
runtime: Dict[str, Any] = {}
passthrough_keys = {
@@ -1009,36 +1012,110 @@ class Session:
if runtime.get("config_version_id") is not None and runtime.get("configVersionId") is None:
runtime["configVersionId"] = runtime.get("config_version_id")
return runtime
return runtime, None
def _sanitize_untrusted_runtime_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
"""
Sanitize untrusted metadata sources.
def _find_forbidden_secret_key(self, payload: Any, path: str = "metadata") -> Optional[str]:
    """Return the path of the first secret-like key found in ``payload``.

    Nested dicts and lists are walked depth-first in iteration order. A key
    is treated as secret-like when its name, lowercased and with ``_`` and
    ``-`` removed, contains any token from
    ``self._METADATA_FORBIDDEN_KEY_TOKENS``. This is a substring match, so
    ``api_key`` and ``X-Api-Key`` both match the token ``apikey``.

    Args:
        payload: Value to inspect. Anything that is not a dict or a list is
            treated as a leaf and ignored.
        path: Prefix used when building the reported location.

    Returns:
        The dotted/indexed path of the first offending key, for example
        ``metadata.overrides.tools[0].apiKey``, or ``None`` when no key
        matches.
    """
    forbidden_tokens = self._METADATA_FORBIDDEN_KEY_TOKENS

    # An explicit stack is used instead of recursion: the payload is
    # client-supplied, so its nesting depth must not be able to exhaust the
    # interpreter's recursion limit and raise RecursionError.
    # Each entry is (location, value, key_name); key_name is None for the
    # root value and for list items, which have no key to check.
    pending = [(path, payload, None)]
    while pending:
        location, value, key_name = pending.pop()

        if key_name is not None:
            normalized = key_name.lower().replace("_", "").replace("-", "")
            if any(token in normalized for token in forbidden_tokens):
                return location

        if isinstance(value, dict):
            children = []
            for key, child in value.items():
                child_key = str(key)
                children.append((f"{location}.{child_key}", child, child_key))
        elif isinstance(value, list):
            children = [
                (f"{location}[{idx}]", child, None)
                for idx, child in enumerate(value)
            ]
        else:
            continue

        # Push in reverse so the first child is popped (visited) first; this
        # keeps the reported match identical to a recursive pre-order walk.
        pending.extend(reversed(children))

    return None
This keeps only a small override whitelist and stable config ID fields.
"""
def _validate_and_sanitize_client_metadata(
self,
metadata: Dict[str, Any],
) -> tuple[Dict[str, Any], Optional[Dict[str, str]]]:
if not isinstance(metadata, dict):
return {}
return {}, {
"code": "protocol.invalid_override",
"message": "metadata must be an object",
}
sanitized: Dict[str, Any] = {}
for key in self._CLIENT_METADATA_ID_KEYS:
if key in metadata:
sanitized[key] = metadata[key]
for key in self._CLIENT_METADATA_OVERRIDES:
if key in metadata:
sanitized[key] = metadata[key]
forbidden_top_level = [key for key in self._METADATA_FORBIDDEN_TOP_LEVEL_KEYS if key in metadata]
if forbidden_top_level:
return {}, {
"code": "protocol.invalid_override",
"message": f"Forbidden metadata keys: {', '.join(sorted(forbidden_top_level))}",
}
return sanitized
unknown_keys = [
key
for key in metadata.keys()
if key not in self._METADATA_ALLOWED_TOP_LEVEL_KEYS
and key not in self._METADATA_FORBIDDEN_TOP_LEVEL_KEYS
]
if unknown_keys:
return {}, {
"code": "protocol.invalid_override",
"message": f"Unsupported metadata keys: {', '.join(sorted(unknown_keys))}",
}
def _sanitize_client_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
"""Apply client metadata whitelist and remove forbidden secrets."""
sanitized = self._sanitize_untrusted_runtime_metadata(metadata)
if isinstance(metadata.get("services"), dict):
logger.warning(
"Session {} provided metadata.services from client; client-side service config is ignored",
self.id,
)
return sanitized
if "workflow" in metadata:
logger.warning("Session {} received metadata.workflow; workflow payload is ignored in MVP", self.id)
overrides_raw = metadata.get("overrides")
overrides: Dict[str, Any] = {}
if overrides_raw is not None:
if not isinstance(overrides_raw, dict):
return {}, {
"code": "protocol.invalid_override",
"message": "metadata.overrides must be an object",
}
unsupported_override_keys = [key for key in overrides_raw.keys() if key not in self._METADATA_ALLOWED_OVERRIDE_KEYS]
if unsupported_override_keys:
return {}, {
"code": "protocol.invalid_override",
"message": f"Unsupported metadata.overrides keys: {', '.join(sorted(unsupported_override_keys))}",
}
overrides = dict(overrides_raw)
dynamic_variables = metadata.get("dynamicVariables")
history_raw = metadata.get("history")
history: Dict[str, Any] = {}
if history_raw is not None:
if not isinstance(history_raw, dict):
return {}, {
"code": "protocol.invalid_override",
"message": "metadata.history must be an object",
}
unsupported_history_keys = [key for key in history_raw.keys() if key != "userId"]
if unsupported_history_keys:
return {}, {
"code": "protocol.invalid_override",
"message": f"Unsupported metadata.history keys: {', '.join(sorted(unsupported_history_keys))}",
}
if "userId" in history_raw:
history["userId"] = history_raw.get("userId")
sanitized: Dict[str, Any] = {"overrides": overrides}
if dynamic_variables is not None:
sanitized["dynamicVariables"] = dynamic_variables
if "channel" in metadata:
sanitized["channel"] = metadata.get("channel")
if "source" in metadata:
sanitized["source"] = metadata.get("source")
if history:
sanitized["history"] = history
forbidden_path = self._find_forbidden_secret_key(sanitized)
if forbidden_path:
return {}, {
"code": "protocol.invalid_override",
"message": f"Forbidden secret-like key detected at {forbidden_path}",
}
return sanitized, None
def _build_config_resolved(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
"""Build public resolved config payload (secrets removed)."""
@@ -1047,7 +1124,7 @@ class Session:
runtime = self.pipeline.resolved_runtime_config()
return {
"appId": metadata.get("appId") or metadata.get("app_id") or metadata.get("assistantId"),
"appId": metadata.get("assistantId"),
"channel": metadata.get("channel"),
"configVersionId": metadata.get("configVersionId") or metadata.get("config_version_id"),
"prompt": {"sha256": prompt_hash},