Enhance WebSocket session management by requiring assistant_id as a query parameter for connection. Update API reference documentation to reflect changes in message flow and metadata validation rules, including the introduction of whitelists for allowed metadata fields and restrictions on sensitive keys. Refactor client examples to align with the new session initiation process.
This commit is contained in:
@@ -21,7 +21,6 @@ from services.base import LLMMessage
|
||||
from models.ws_v1 import (
|
||||
parse_client_message,
|
||||
ev,
|
||||
HelloMessage,
|
||||
SessionStartMessage,
|
||||
SessionStopMessage,
|
||||
InputTextMessage,
|
||||
@@ -39,7 +38,6 @@ _SYSTEM_DYNAMIC_VARIABLE_KEYS = {"system__time", "system_utc", "system_timezone"
|
||||
class WsSessionState(str, Enum):
|
||||
"""Protocol state machine for WS sessions."""
|
||||
|
||||
WAIT_HELLO = "wait_hello"
|
||||
WAIT_START = "wait_start"
|
||||
ACTIVE = "active"
|
||||
STOPPED = "stopped"
|
||||
@@ -57,7 +55,15 @@ class Session:
|
||||
TRACK_AUDIO_OUT = "audio_out"
|
||||
TRACK_CONTROL = "control"
|
||||
AUDIO_FRAME_BYTES = 640 # 16k mono pcm_s16le, 20ms
|
||||
_CLIENT_METADATA_OVERRIDES = {
|
||||
_METADATA_ALLOWED_TOP_LEVEL_KEYS = {
|
||||
"overrides",
|
||||
"dynamicVariables",
|
||||
"channel",
|
||||
"source",
|
||||
"history",
|
||||
"workflow", # explicitly ignored for this MVP protocol version
|
||||
}
|
||||
_METADATA_ALLOWED_OVERRIDE_KEYS = {
|
||||
"firstTurnMode",
|
||||
"greeting",
|
||||
"generatedOpenerEnabled",
|
||||
@@ -67,18 +73,22 @@ class Session:
|
||||
"knowledge",
|
||||
"knowledgeBaseId",
|
||||
"openerAudio",
|
||||
"dynamicVariables",
|
||||
"history",
|
||||
"userId",
|
||||
"assistantId",
|
||||
"source",
|
||||
"tools",
|
||||
}
|
||||
_CLIENT_METADATA_ID_KEYS = {
|
||||
_METADATA_FORBIDDEN_TOP_LEVEL_KEYS = {
|
||||
"assistantId",
|
||||
"appId",
|
||||
"app_id",
|
||||
"channel",
|
||||
"configVersionId",
|
||||
"config_version_id",
|
||||
"services",
|
||||
}
|
||||
_METADATA_FORBIDDEN_KEY_TOKENS = {
|
||||
"apikey",
|
||||
"token",
|
||||
"secret",
|
||||
"password",
|
||||
"authorization",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
@@ -87,6 +97,7 @@ class Session:
|
||||
transport: BaseTransport,
|
||||
use_duplex: bool = None,
|
||||
backend_gateway: Optional[Any] = None,
|
||||
assistant_id: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Initialize session.
|
||||
@@ -99,6 +110,7 @@ class Session:
|
||||
self.id = session_id
|
||||
self.transport = transport
|
||||
self.use_duplex = use_duplex if use_duplex is not None else settings.duplex_enabled
|
||||
self._assistant_id = str(assistant_id or "").strip() or None
|
||||
self._backend_gateway = backend_gateway or build_backend_adapter_from_settings()
|
||||
self._history_bridge = SessionHistoryBridge(
|
||||
history_writer=self._backend_gateway,
|
||||
@@ -121,9 +133,8 @@ class Session:
|
||||
# Session state
|
||||
self.created_at = None
|
||||
self.state = "created" # Legacy call state for /call/lists
|
||||
self.ws_state = WsSessionState.WAIT_HELLO
|
||||
self.ws_state = WsSessionState.WAIT_START
|
||||
self._pipeline_started = False
|
||||
self.protocol_version: Optional[str] = None
|
||||
|
||||
# Track IDs
|
||||
self.current_track_id: str = self.TRACK_CONTROL
|
||||
@@ -137,7 +148,12 @@ class Session:
|
||||
self.pipeline.set_event_sequence_provider(self._next_event_seq)
|
||||
self.pipeline.conversation.on_turn_complete(self._on_turn_complete)
|
||||
|
||||
logger.info(f"Session {self.id} created (duplex={self.use_duplex})")
|
||||
logger.info(
|
||||
"Session {} created (duplex={}, assistant_id={})",
|
||||
self.id,
|
||||
self.use_duplex,
|
||||
self._assistant_id or "-",
|
||||
)
|
||||
|
||||
async def handle_text(self, text_data: str) -> None:
|
||||
"""
|
||||
@@ -222,23 +238,19 @@ class Session:
|
||||
msg_type = message.type
|
||||
logger.info(f"Session {self.id} received message: {msg_type}")
|
||||
|
||||
if isinstance(message, HelloMessage):
|
||||
await self._handle_hello(message)
|
||||
return
|
||||
|
||||
# All messages below require hello handshake first
|
||||
if self.ws_state == WsSessionState.WAIT_HELLO:
|
||||
await self._send_error(
|
||||
"client",
|
||||
"Expected hello message first",
|
||||
"protocol.order",
|
||||
)
|
||||
return
|
||||
|
||||
if isinstance(message, SessionStartMessage):
|
||||
await self._handle_session_start(message)
|
||||
return
|
||||
|
||||
# All messages below require session.start first
|
||||
if self.ws_state == WsSessionState.WAIT_START:
|
||||
await self._send_error(
|
||||
"client",
|
||||
"Expected session.start message first",
|
||||
"protocol.order",
|
||||
)
|
||||
return
|
||||
|
||||
# All messages below require active session
|
||||
if self.ws_state != WsSessionState.ACTIVE:
|
||||
await self._send_error(
|
||||
@@ -262,67 +274,56 @@ class Session:
|
||||
else:
|
||||
await self._send_error("client", f"Unsupported message type: {msg_type}", "protocol.unsupported")
|
||||
|
||||
async def _handle_hello(self, message: HelloMessage) -> None:
|
||||
"""Handle initial hello/version negotiation."""
|
||||
if self.ws_state != WsSessionState.WAIT_HELLO:
|
||||
await self._send_error("client", "Duplicate hello", "protocol.order")
|
||||
return
|
||||
|
||||
if message.version != settings.ws_protocol_version:
|
||||
await self._send_error(
|
||||
"client",
|
||||
f"Unsupported protocol version '{message.version}'",
|
||||
"protocol.version_unsupported",
|
||||
)
|
||||
await self.transport.close()
|
||||
self.ws_state = WsSessionState.STOPPED
|
||||
return
|
||||
|
||||
self.protocol_version = message.version
|
||||
self.ws_state = WsSessionState.WAIT_START
|
||||
await self._send_event(
|
||||
ev(
|
||||
"hello.ack",
|
||||
version=self.protocol_version,
|
||||
)
|
||||
)
|
||||
|
||||
async def _handle_session_start(self, message: SessionStartMessage) -> None:
|
||||
"""Handle explicit session start after successful hello."""
|
||||
"""Handle explicit session start."""
|
||||
if self.ws_state != WsSessionState.WAIT_START:
|
||||
await self._send_error("client", "Duplicate session.start", "protocol.order")
|
||||
return
|
||||
|
||||
raw_metadata = message.metadata or {}
|
||||
workflow_runtime = self._bootstrap_workflow(raw_metadata)
|
||||
server_runtime = await self._load_server_runtime_metadata(raw_metadata, workflow_runtime)
|
||||
client_runtime = self._sanitize_client_metadata(raw_metadata)
|
||||
requested_assistant_id = (
|
||||
workflow_runtime.get("assistantId")
|
||||
or raw_metadata.get("assistantId")
|
||||
or raw_metadata.get("appId")
|
||||
or raw_metadata.get("app_id")
|
||||
)
|
||||
if server_runtime:
|
||||
logger.info(
|
||||
"Session {} loaded trusted runtime config from backend "
|
||||
"(requested_assistant_id={}, resolved_assistant_id={}, configVersionId={}, has_services={})",
|
||||
self.id,
|
||||
requested_assistant_id,
|
||||
server_runtime.get("assistantId"),
|
||||
server_runtime.get("configVersionId") or server_runtime.get("config_version_id"),
|
||||
isinstance(server_runtime.get("services"), dict),
|
||||
if not self._assistant_id:
|
||||
await self._send_error(
|
||||
"client",
|
||||
"Missing required query parameter assistant_id",
|
||||
"protocol.assistant_id_required",
|
||||
stage="protocol",
|
||||
retryable=False,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Session {} missing trusted backend runtime config "
|
||||
"(requested_assistant_id={}); falling back to engine defaults + safe client overrides",
|
||||
self.id,
|
||||
requested_assistant_id,
|
||||
await self.transport.close()
|
||||
self.ws_state = WsSessionState.STOPPED
|
||||
return
|
||||
|
||||
sanitized_metadata, metadata_error = self._validate_and_sanitize_client_metadata(raw_metadata)
|
||||
if metadata_error:
|
||||
await self._send_error(
|
||||
"client",
|
||||
metadata_error["message"],
|
||||
metadata_error["code"],
|
||||
stage="protocol",
|
||||
retryable=False,
|
||||
)
|
||||
metadata = self._merge_runtime_metadata(server_runtime, self._sanitize_untrusted_runtime_metadata(workflow_runtime))
|
||||
metadata = self._merge_runtime_metadata(metadata, client_runtime)
|
||||
metadata, dynamic_var_error = self._apply_dynamic_variables(metadata, raw_metadata)
|
||||
await self.transport.close()
|
||||
self.ws_state = WsSessionState.STOPPED
|
||||
return
|
||||
|
||||
server_runtime, runtime_error = await self._load_server_runtime_metadata(self._assistant_id)
|
||||
if runtime_error:
|
||||
await self._send_error(
|
||||
"server",
|
||||
runtime_error["message"],
|
||||
runtime_error["code"],
|
||||
stage="protocol",
|
||||
retryable=False,
|
||||
)
|
||||
await self.transport.close()
|
||||
self.ws_state = WsSessionState.STOPPED
|
||||
return
|
||||
|
||||
metadata = self._merge_runtime_metadata(server_runtime, sanitized_metadata.get("overrides", {}))
|
||||
for key in ("channel", "source", "history"):
|
||||
if key in sanitized_metadata:
|
||||
metadata[key] = sanitized_metadata[key]
|
||||
metadata, dynamic_var_error = self._apply_dynamic_variables(metadata, sanitized_metadata)
|
||||
if dynamic_var_error:
|
||||
await self._send_error(
|
||||
"client",
|
||||
@@ -331,6 +332,8 @@ class Session:
|
||||
stage="protocol",
|
||||
retryable=False,
|
||||
)
|
||||
await self.transport.close()
|
||||
self.ws_state = WsSessionState.STOPPED
|
||||
return
|
||||
|
||||
# Create history call record early so later turn callbacks can append transcripts.
|
||||
@@ -348,7 +351,7 @@ class Session:
|
||||
"(assistantId={}, configVersionId={}, output_mode={}, "
|
||||
"llm={}/{}, asr={}/{}, tts={}/{}, tts_enabled={})",
|
||||
self.id,
|
||||
metadata.get("assistantId") or metadata.get("appId") or metadata.get("app_id"),
|
||||
metadata.get("assistantId"),
|
||||
metadata.get("configVersionId") or metadata.get("config_version_id"),
|
||||
(resolved_preview.get("output") or {}).get("mode") if isinstance(resolved_preview, dict) else None,
|
||||
llm_cfg.get("provider"),
|
||||
@@ -387,24 +390,6 @@ class Session:
|
||||
config=self._build_config_resolved(metadata),
|
||||
)
|
||||
)
|
||||
if self.workflow_runner and self._workflow_initial_node:
|
||||
await self._send_event(
|
||||
ev(
|
||||
"workflow.started",
|
||||
workflowId=self.workflow_runner.workflow_id,
|
||||
workflowName=self.workflow_runner.name,
|
||||
nodeId=self._workflow_initial_node.id,
|
||||
)
|
||||
)
|
||||
await self._send_event(
|
||||
ev(
|
||||
"workflow.node.entered",
|
||||
workflowId=self.workflow_runner.workflow_id,
|
||||
nodeId=self._workflow_initial_node.id,
|
||||
nodeName=self._workflow_initial_node.name,
|
||||
nodeType=self._workflow_initial_node.node_type,
|
||||
)
|
||||
)
|
||||
|
||||
# Emit opener only after frontend has received session.started/config events.
|
||||
await self.pipeline.emit_initial_greeting()
|
||||
@@ -658,7 +643,7 @@ class Session:
|
||||
def _event_source(self, event_type: str) -> str:
|
||||
if event_type.startswith("workflow."):
|
||||
return "system"
|
||||
if event_type.startswith("session.") or event_type.startswith("hello.") or event_type == "heartbeat":
|
||||
if event_type.startswith("session.") or event_type == "heartbeat":
|
||||
return "system"
|
||||
if event_type == "error":
|
||||
return "system"
|
||||
@@ -931,26 +916,41 @@ class Session:
|
||||
|
||||
async def _load_server_runtime_metadata(
|
||||
self,
|
||||
client_metadata: Dict[str, Any],
|
||||
workflow_runtime: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
assistant_id: str,
|
||||
) -> tuple[Dict[str, Any], Optional[Dict[str, str]]]:
|
||||
"""Load trusted runtime metadata from backend assistant config."""
|
||||
assistant_id = (
|
||||
workflow_runtime.get("assistantId")
|
||||
or client_metadata.get("assistantId")
|
||||
or client_metadata.get("appId")
|
||||
or client_metadata.get("app_id")
|
||||
)
|
||||
if assistant_id is None:
|
||||
return {}
|
||||
if not assistant_id:
|
||||
return {}, {
|
||||
"code": "protocol.assistant_id_required",
|
||||
"message": "Missing required query parameter assistant_id",
|
||||
}
|
||||
|
||||
provider = getattr(self._backend_gateway, "fetch_assistant_config", None)
|
||||
if not callable(provider):
|
||||
return {}
|
||||
return {}, {
|
||||
"code": "assistant.config_unavailable",
|
||||
"message": "Assistant config backend unavailable",
|
||||
}
|
||||
|
||||
payload = await provider(str(assistant_id).strip())
|
||||
if isinstance(payload, dict):
|
||||
error_code = str(payload.get("__error_code") or "").strip()
|
||||
if error_code == "assistant.not_found":
|
||||
return {}, {
|
||||
"code": "assistant.not_found",
|
||||
"message": f"Assistant not found: {assistant_id}",
|
||||
}
|
||||
if error_code == "assistant.config_unavailable":
|
||||
return {}, {
|
||||
"code": "assistant.config_unavailable",
|
||||
"message": f"Assistant config unavailable: {assistant_id}",
|
||||
}
|
||||
|
||||
if not isinstance(payload, dict):
|
||||
return {}
|
||||
return {}, {
|
||||
"code": "assistant.config_unavailable",
|
||||
"message": f"Assistant config unavailable: {assistant_id}",
|
||||
}
|
||||
|
||||
assistant_cfg: Dict[str, Any] = {}
|
||||
session_start_cfg = payload.get("sessionStartMetadata")
|
||||
@@ -962,7 +962,10 @@ class Session:
|
||||
assistant_cfg = payload
|
||||
|
||||
if not isinstance(assistant_cfg, dict):
|
||||
return {}
|
||||
return {}, {
|
||||
"code": "assistant.config_unavailable",
|
||||
"message": f"Assistant config unavailable: {assistant_id}",
|
||||
}
|
||||
|
||||
runtime: Dict[str, Any] = {}
|
||||
passthrough_keys = {
|
||||
@@ -1009,36 +1012,110 @@ class Session:
|
||||
|
||||
if runtime.get("config_version_id") is not None and runtime.get("configVersionId") is None:
|
||||
runtime["configVersionId"] = runtime.get("config_version_id")
|
||||
return runtime
|
||||
return runtime, None
|
||||
|
||||
def _sanitize_untrusted_runtime_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Sanitize untrusted metadata sources.
|
||||
def _find_forbidden_secret_key(self, payload: Any, path: str = "metadata") -> Optional[str]:
|
||||
if isinstance(payload, dict):
|
||||
for key, value in payload.items():
|
||||
key_str = str(key)
|
||||
normalized = key_str.lower().replace("_", "").replace("-", "")
|
||||
if any(token in normalized for token in self._METADATA_FORBIDDEN_KEY_TOKENS):
|
||||
return f"{path}.{key_str}"
|
||||
nested = self._find_forbidden_secret_key(value, f"{path}.{key_str}")
|
||||
if nested:
|
||||
return nested
|
||||
return None
|
||||
if isinstance(payload, list):
|
||||
for idx, value in enumerate(payload):
|
||||
nested = self._find_forbidden_secret_key(value, f"{path}[{idx}]")
|
||||
if nested:
|
||||
return nested
|
||||
return None
|
||||
|
||||
This keeps only a small override whitelist and stable config ID fields.
|
||||
"""
|
||||
def _validate_and_sanitize_client_metadata(
|
||||
self,
|
||||
metadata: Dict[str, Any],
|
||||
) -> tuple[Dict[str, Any], Optional[Dict[str, str]]]:
|
||||
if not isinstance(metadata, dict):
|
||||
return {}
|
||||
return {}, {
|
||||
"code": "protocol.invalid_override",
|
||||
"message": "metadata must be an object",
|
||||
}
|
||||
|
||||
sanitized: Dict[str, Any] = {}
|
||||
for key in self._CLIENT_METADATA_ID_KEYS:
|
||||
if key in metadata:
|
||||
sanitized[key] = metadata[key]
|
||||
for key in self._CLIENT_METADATA_OVERRIDES:
|
||||
if key in metadata:
|
||||
sanitized[key] = metadata[key]
|
||||
forbidden_top_level = [key for key in self._METADATA_FORBIDDEN_TOP_LEVEL_KEYS if key in metadata]
|
||||
if forbidden_top_level:
|
||||
return {}, {
|
||||
"code": "protocol.invalid_override",
|
||||
"message": f"Forbidden metadata keys: {', '.join(sorted(forbidden_top_level))}",
|
||||
}
|
||||
|
||||
return sanitized
|
||||
unknown_keys = [
|
||||
key
|
||||
for key in metadata.keys()
|
||||
if key not in self._METADATA_ALLOWED_TOP_LEVEL_KEYS
|
||||
and key not in self._METADATA_FORBIDDEN_TOP_LEVEL_KEYS
|
||||
]
|
||||
if unknown_keys:
|
||||
return {}, {
|
||||
"code": "protocol.invalid_override",
|
||||
"message": f"Unsupported metadata keys: {', '.join(sorted(unknown_keys))}",
|
||||
}
|
||||
|
||||
def _sanitize_client_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Apply client metadata whitelist and remove forbidden secrets."""
|
||||
sanitized = self._sanitize_untrusted_runtime_metadata(metadata)
|
||||
if isinstance(metadata.get("services"), dict):
|
||||
logger.warning(
|
||||
"Session {} provided metadata.services from client; client-side service config is ignored",
|
||||
self.id,
|
||||
)
|
||||
return sanitized
|
||||
if "workflow" in metadata:
|
||||
logger.warning("Session {} received metadata.workflow; workflow payload is ignored in MVP", self.id)
|
||||
|
||||
overrides_raw = metadata.get("overrides")
|
||||
overrides: Dict[str, Any] = {}
|
||||
if overrides_raw is not None:
|
||||
if not isinstance(overrides_raw, dict):
|
||||
return {}, {
|
||||
"code": "protocol.invalid_override",
|
||||
"message": "metadata.overrides must be an object",
|
||||
}
|
||||
unsupported_override_keys = [key for key in overrides_raw.keys() if key not in self._METADATA_ALLOWED_OVERRIDE_KEYS]
|
||||
if unsupported_override_keys:
|
||||
return {}, {
|
||||
"code": "protocol.invalid_override",
|
||||
"message": f"Unsupported metadata.overrides keys: {', '.join(sorted(unsupported_override_keys))}",
|
||||
}
|
||||
overrides = dict(overrides_raw)
|
||||
|
||||
dynamic_variables = metadata.get("dynamicVariables")
|
||||
history_raw = metadata.get("history")
|
||||
history: Dict[str, Any] = {}
|
||||
if history_raw is not None:
|
||||
if not isinstance(history_raw, dict):
|
||||
return {}, {
|
||||
"code": "protocol.invalid_override",
|
||||
"message": "metadata.history must be an object",
|
||||
}
|
||||
unsupported_history_keys = [key for key in history_raw.keys() if key != "userId"]
|
||||
if unsupported_history_keys:
|
||||
return {}, {
|
||||
"code": "protocol.invalid_override",
|
||||
"message": f"Unsupported metadata.history keys: {', '.join(sorted(unsupported_history_keys))}",
|
||||
}
|
||||
if "userId" in history_raw:
|
||||
history["userId"] = history_raw.get("userId")
|
||||
|
||||
sanitized: Dict[str, Any] = {"overrides": overrides}
|
||||
if dynamic_variables is not None:
|
||||
sanitized["dynamicVariables"] = dynamic_variables
|
||||
if "channel" in metadata:
|
||||
sanitized["channel"] = metadata.get("channel")
|
||||
if "source" in metadata:
|
||||
sanitized["source"] = metadata.get("source")
|
||||
if history:
|
||||
sanitized["history"] = history
|
||||
|
||||
forbidden_path = self._find_forbidden_secret_key(sanitized)
|
||||
if forbidden_path:
|
||||
return {}, {
|
||||
"code": "protocol.invalid_override",
|
||||
"message": f"Forbidden secret-like key detected at {forbidden_path}",
|
||||
}
|
||||
|
||||
return sanitized, None
|
||||
|
||||
def _build_config_resolved(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Build public resolved config payload (secrets removed)."""
|
||||
@@ -1047,7 +1124,7 @@ class Session:
|
||||
runtime = self.pipeline.resolved_runtime_config()
|
||||
|
||||
return {
|
||||
"appId": metadata.get("appId") or metadata.get("app_id") or metadata.get("assistantId"),
|
||||
"appId": metadata.get("assistantId"),
|
||||
"channel": metadata.get("channel"),
|
||||
"configVersionId": metadata.get("configVersionId") or metadata.get("config_version_id"),
|
||||
"prompt": {"sha256": prompt_hash},
|
||||
|
||||
Reference in New Issue
Block a user