From 6a46ec69f4a03b075fdc18234d74578768f8b027 Mon Sep 17 00:00:00 2001 From: Xin Wang Date: Sun, 1 Mar 2026 14:10:38 +0800 Subject: [PATCH] Enhance WebSocket session management by requiring `assistant_id` as a query parameter for connection. Update API reference documentation to reflect changes in message flow and metadata validation rules, including the introduction of whitelists for allowed metadata fields and restrictions on sensitive keys. Refactor client examples to align with the new session initiation process. --- docs/content/api-reference.md | 86 +++-- engine/app/backend_adapters.py | 6 +- engine/app/main.py | 20 +- engine/core/session.py | 353 +++++++++++------- engine/docs/ws_v1_schema.md | 58 ++- engine/examples/mic_client.py | 42 +-- engine/examples/simple_client.py | 36 +- engine/examples/test_websocket.py | 17 +- engine/examples/wav_client.py | 42 +-- engine/examples/web_client.html | 35 +- engine/models/ws_v1.py | 7 - engine/tests/test_tool_call_flow.py | 2 +- .../tests/test_ws_protocol_session_start.py | 238 ++++++++++++ web/pages/Assistants.tsx | 207 +++++----- 14 files changed, 725 insertions(+), 424 deletions(-) create mode 100644 engine/tests/test_ws_protocol_session_start.py diff --git a/docs/content/api-reference.md b/docs/content/api-reference.md index 20b6fa1..d1a4366 100644 --- a/docs/content/api-reference.md +++ b/docs/content/api-reference.md @@ -11,9 +11,11 @@ WebSocket 端点提供双向实时语音对话能力,支持音频流输入输 ### 连接地址 ``` -ws:///ws +ws:///ws?assistant_id= ``` +- `assistant_id` 为必填 query 参数,用于从数据库加载该助手的运行时配置。 + ### 传输规则 - **文本帧**:JSON 格式控制消息 @@ -25,8 +27,6 @@ ws:///ws ### 消息流程 ``` -Client -> hello -Server <- hello.ack Client -> session.start Server <- session.started Server <- config.resolved @@ -44,27 +44,9 @@ Server <- session.stopped ### 客户端 -> 服务端消息 -#### 1. Handshake: `hello` +#### 1. Session Start: `session.start` -客户端连接后发送的第一个消息,用于协议版本协商。 - -```json -{ - "type": "hello", - "version": "v1" -} -``` - -| 字段 | 类型 | 必填 | 说明 | -|---|---|---|---| -| `type` | string | 是 | 固定为 `"hello"` | -| `version` | string | 是 | 协议版本,固定为 `"v1"` | - ---- - -#### 2. Session Start: `session.start` - -握手成功后发送的第二个消息,用于启动对话会话。 +客户端连接后发送的第一个消息,用于启动对话会话。 ```json { @@ -75,13 +57,17 @@ Server <- session.stopped "channels": 1 }, "metadata": { - "appId": "assistant_123", "channel": "web", - "configVersionId": "cfg_20260217_01", - "systemPrompt": "你是简洁助手", - "greeting": "你好,我能帮你什么?", - "output": { - "mode": "audio" + "source": "web_debug", + "history": { + "userId": 1 + }, + "overrides": { + "systemPrompt": "你是简洁助手", + "greeting": "你好,我能帮你什么?", + "output": { + "mode": "audio" + } }, "dynamicVariables": { "customer_name": "Alice", @@ -101,17 +87,33 @@ Server <- session.stopped | `metadata` | object | 否 | 运行时配置 | **metadata 支持的字段**: -- `appId` / `app_id` - 应用 ID - `channel` - 渠道标识 -- `configVersionId` / `config_version_id` - 配置版本 -- `systemPrompt` - 系统提示词 -- `greeting` - 开场白 -- `output.mode` - 输出模式 (`audio` / `text`) +- `source` - 来源标识 +- `history.userId` - 历史记录用户 ID +- `overrides` - 可覆盖字段(仅限安全白名单) - `dynamicVariables` - 动态变量(支持 `{{variable}}` 占位符) +**`metadata.overrides` 白名单字段**: +- `systemPrompt` +- `greeting` +- `firstTurnMode` +- `generatedOpenerEnabled` +- `output` +- `bargeIn` +- `knowledgeBaseId` +- `knowledge` +- `tools` +- `openerAudio` + +**限制**: +- `metadata.workflow` 会被忽略(不触发 workflow 事件) +- 禁止提交 `metadata.services` +- 禁止提交 `assistantId` / `appId` / `app_id` / `configVersionId` / `config_version_id` +- 禁止提交包含密钥语义的字段(如 `apiKey` / `token` / `secret` / `password` / `authorization`) + --- -#### 3. Text Input: `input.text` +#### 2. Text Input: `input.text` 发送文本输入,跳过 ASR 识别,直接触发 LLM 回复。 @@ -129,7 +131,7 @@ Server <- session.stopped --- -#### 4. Response Cancel: `response.cancel` +#### 3. Response Cancel: `response.cancel` 请求中断当前回答。 @@ -147,7 +149,7 @@ Server <- session.stopped --- -#### 5. Tool Call Results: `tool_call.results` +#### 4. Tool Call Results: `tool_call.results` 回传客户端执行的工具结果。 @@ -176,7 +178,7 @@ Server <- session.stopped --- -#### 6. Session Stop: `session.stop` +#### 5. Session Stop: `session.stop` 结束对话会话。 @@ -194,7 +196,7 @@ Server <- session.stopped --- -#### 7. Binary Audio +#### 6. Binary Audio 在 `session.started` 之后可持续发送二进制 PCM 音频。 @@ -239,7 +241,6 @@ Server <- session.stopped | 事件 | 说明 | |---|---| -| `hello.ack` | 握手成功响应 | | `session.started` | 会话启动成功 | | `config.resolved` | 服务端最终配置快照 | | `heartbeat` | 保活心跳(默认 50 秒间隔) | @@ -281,7 +282,10 @@ Server <- session.stopped | `protocol.invalid_json` | JSON 格式错误 | | `protocol.invalid_message` | 消息格式错误 | | `protocol.order` | 消息顺序错误 | -| `protocol.version_unsupported` | 协议版本不支持 | +| `protocol.assistant_id_required` | 缺少 `assistant_id` query 参数 | +| `protocol.invalid_override` | metadata 覆盖字段不合法 | +| `assistant.not_found` | assistant 不存在 | +| `assistant.config_unavailable` | assistant 配置不可用 | | `audio.invalid_pcm` | PCM 数据无效 | | `audio.frame_size_mismatch` | 音频帧大小不匹配 | | `server.internal` | 服务端内部错误 | diff --git a/engine/app/backend_adapters.py b/engine/app/backend_adapters.py index a05bd8f..6ff2716 100644 --- a/engine/app/backend_adapters.py +++ b/engine/app/backend_adapters.py @@ -157,16 +157,16 @@ class HttpBackendAdapter: async with session.get(url) as resp: if resp.status == 404: logger.warning(f"Assistant config not found: {assistant_id}") - return None + return {"__error_code": "assistant.not_found", "assistantId": assistant_id} resp.raise_for_status() payload = await resp.json() if not isinstance(payload, dict): logger.warning("Assistant config payload is not a dict; ignoring") - return None + return {"__error_code": "assistant.config_unavailable", "assistantId": assistant_id} return payload except Exception as exc: logger.warning(f"Failed to fetch assistant config ({assistant_id}): {exc}") - return None + return {"__error_code": "assistant.config_unavailable", "assistantId": assistant_id} async def create_call_record( self, diff --git a/engine/app/main.py b/engine/app/main.py index c13daba..b8a39bb 100644 --- a/engine/app/main.py +++ b/engine/app/main.py @@ -163,13 +163,19 @@ async def websocket_endpoint(websocket: WebSocket): """ await websocket.accept() session_id = str(uuid.uuid4()) + assistant_id = str(websocket.query_params.get("assistant_id") or "").strip() or None # Create transport and session transport = SocketTransport(websocket) - session = Session(session_id, transport, backend_gateway=backend_gateway) + session = Session( + session_id, + transport, + backend_gateway=backend_gateway, + assistant_id=assistant_id, + ) active_sessions[session_id] = session - logger.info(f"WebSocket connection established: {session_id}") + logger.info(f"WebSocket connection established: {session_id} assistant_id={assistant_id or '-'}") last_received_at: List[float] = [time.monotonic()] last_heartbeat_at: List[float] = [0.0] @@ -239,16 +245,22 @@ async def webrtc_endpoint(websocket: WebSocket): return await websocket.accept() session_id = str(uuid.uuid4()) + assistant_id = str(websocket.query_params.get("assistant_id") or "").strip() or None # Create WebRTC peer connection pc = RTCPeerConnection() # Create transport and session transport = WebRtcTransport(websocket, pc) - session = Session(session_id, transport, backend_gateway=backend_gateway) + session = Session( + session_id, + transport, + backend_gateway=backend_gateway, + assistant_id=assistant_id, + ) active_sessions[session_id] = session - logger.info(f"WebRTC connection established: {session_id}") + logger.info(f"WebRTC connection established: {session_id} assistant_id={assistant_id or '-'}") last_received_at: List[float] = [time.monotonic()] last_heartbeat_at: List[float] = [0.0] diff --git a/engine/core/session.py b/engine/core/session.py index dda0382..ac365ce 100644 --- a/engine/core/session.py +++ b/engine/core/session.py @@ -21,7 +21,6 @@ from services.base import LLMMessage from models.ws_v1 import ( parse_client_message, ev, - HelloMessage, SessionStartMessage, SessionStopMessage, InputTextMessage, @@ -39,7 +38,6 @@ _SYSTEM_DYNAMIC_VARIABLE_KEYS = {"system__time", "system_utc", "system_timezone" class WsSessionState(str, Enum): """Protocol state machine for WS sessions.""" - WAIT_HELLO = "wait_hello" WAIT_START = "wait_start" ACTIVE = "active" STOPPED = "stopped" @@ -57,7 +55,15 @@ class Session: TRACK_AUDIO_OUT = "audio_out" TRACK_CONTROL = "control" AUDIO_FRAME_BYTES = 640 # 16k mono pcm_s16le, 20ms - _CLIENT_METADATA_OVERRIDES = { + _METADATA_ALLOWED_TOP_LEVEL_KEYS = { + "overrides", + "dynamicVariables", + "channel", + "source", + "history", + "workflow", # explicitly ignored for this MVP protocol version + } + _METADATA_ALLOWED_OVERRIDE_KEYS = { "firstTurnMode", "greeting", "generatedOpenerEnabled", @@ -67,18 +73,22 @@ class Session: "knowledge", "knowledgeBaseId", "openerAudio", - "dynamicVariables", - "history", - "userId", - "assistantId", - "source", + "tools", } - _CLIENT_METADATA_ID_KEYS = { + _METADATA_FORBIDDEN_TOP_LEVEL_KEYS = { + "assistantId", "appId", "app_id", - "channel", "configVersionId", "config_version_id", + "services", + } + _METADATA_FORBIDDEN_KEY_TOKENS = { + "apikey", + "token", + "secret", + "password", + "authorization", } def __init__( @@ -87,6 +97,7 @@ class Session: transport: BaseTransport, use_duplex: bool = None, backend_gateway: Optional[Any] = None, + assistant_id: Optional[str] = None, ): """ Initialize session. @@ -99,6 +110,7 @@ class Session: self.id = session_id self.transport = transport self.use_duplex = use_duplex if use_duplex is not None else settings.duplex_enabled + self._assistant_id = str(assistant_id or "").strip() or None self._backend_gateway = backend_gateway or build_backend_adapter_from_settings() self._history_bridge = SessionHistoryBridge( history_writer=self._backend_gateway, @@ -121,9 +133,8 @@ class Session: # Session state self.created_at = None self.state = "created" # Legacy call state for /call/lists - self.ws_state = WsSessionState.WAIT_HELLO + self.ws_state = WsSessionState.WAIT_START self._pipeline_started = False - self.protocol_version: Optional[str] = None # Track IDs self.current_track_id: str = self.TRACK_CONTROL @@ -137,7 +148,12 @@ class Session: self.pipeline.set_event_sequence_provider(self._next_event_seq) self.pipeline.conversation.on_turn_complete(self._on_turn_complete) - logger.info(f"Session {self.id} created (duplex={self.use_duplex})") + logger.info( + "Session {} created (duplex={}, assistant_id={})", + self.id, + self.use_duplex, + self._assistant_id or "-", + ) async def handle_text(self, text_data: str) -> None: """ @@ -222,23 +238,19 @@ class Session: msg_type = message.type logger.info(f"Session {self.id} received message: {msg_type}") - if isinstance(message, HelloMessage): - await self._handle_hello(message) - return - - # All messages below require hello handshake first - if self.ws_state == WsSessionState.WAIT_HELLO: - await self._send_error( - "client", - "Expected hello message first", - "protocol.order", - ) - return - if isinstance(message, SessionStartMessage): await self._handle_session_start(message) return + # All messages below require session.start first + if self.ws_state == WsSessionState.WAIT_START: + await self._send_error( + "client", + "Expected session.start message first", + "protocol.order", + ) + return + # All messages below require active session if self.ws_state != WsSessionState.ACTIVE: await self._send_error( @@ -262,67 +274,56 @@ class Session: else: await self._send_error("client", f"Unsupported message type: {msg_type}", "protocol.unsupported") - async def _handle_hello(self, message: HelloMessage) -> None: - """Handle initial hello/version negotiation.""" - if self.ws_state != WsSessionState.WAIT_HELLO: - await self._send_error("client", "Duplicate hello", "protocol.order") - return - - if message.version != settings.ws_protocol_version: - await self._send_error( - "client", - f"Unsupported protocol version '{message.version}'", - "protocol.version_unsupported", - ) - await self.transport.close() - self.ws_state = WsSessionState.STOPPED - return - - self.protocol_version = message.version - self.ws_state = WsSessionState.WAIT_START - await self._send_event( - ev( - "hello.ack", - version=self.protocol_version, - ) - ) - async def _handle_session_start(self, message: SessionStartMessage) -> None: - """Handle explicit session start after successful hello.""" + """Handle explicit session start.""" if self.ws_state != WsSessionState.WAIT_START: await self._send_error("client", "Duplicate session.start", "protocol.order") return raw_metadata = message.metadata or {} - workflow_runtime = self._bootstrap_workflow(raw_metadata) - server_runtime = await self._load_server_runtime_metadata(raw_metadata, workflow_runtime) - client_runtime = self._sanitize_client_metadata(raw_metadata) - requested_assistant_id = ( - workflow_runtime.get("assistantId") - or raw_metadata.get("assistantId") - or raw_metadata.get("appId") - or raw_metadata.get("app_id") - ) - if server_runtime: - logger.info( - "Session {} loaded trusted runtime config from backend " - "(requested_assistant_id={}, resolved_assistant_id={}, configVersionId={}, has_services={})", - self.id, - requested_assistant_id, - server_runtime.get("assistantId"), - server_runtime.get("configVersionId") or server_runtime.get("config_version_id"), - isinstance(server_runtime.get("services"), dict), + if not self._assistant_id: + await self._send_error( + "client", + "Missing required query parameter assistant_id", + "protocol.assistant_id_required", + stage="protocol", + retryable=False, ) - else: - logger.warning( - "Session {} missing trusted backend runtime config " - "(requested_assistant_id={}); falling back to engine defaults + safe client overrides", - self.id, - requested_assistant_id, + await self.transport.close() + self.ws_state = WsSessionState.STOPPED + return + + sanitized_metadata, metadata_error = self._validate_and_sanitize_client_metadata(raw_metadata) + if metadata_error: + await self._send_error( + "client", + metadata_error["message"], + metadata_error["code"], + stage="protocol", + retryable=False, ) - metadata = self._merge_runtime_metadata(server_runtime, self._sanitize_untrusted_runtime_metadata(workflow_runtime)) - metadata = self._merge_runtime_metadata(metadata, client_runtime) - metadata, dynamic_var_error = self._apply_dynamic_variables(metadata, raw_metadata) + await self.transport.close() + self.ws_state = WsSessionState.STOPPED + return + + server_runtime, runtime_error = await self._load_server_runtime_metadata(self._assistant_id) + if runtime_error: + await self._send_error( + "server", + runtime_error["message"], + runtime_error["code"], + stage="protocol", + retryable=False, + ) + await self.transport.close() + self.ws_state = WsSessionState.STOPPED + return + + metadata = self._merge_runtime_metadata(server_runtime, sanitized_metadata.get("overrides", {})) + for key in ("channel", "source", "history"): + if key in sanitized_metadata: + metadata[key] = sanitized_metadata[key] + metadata, dynamic_var_error = self._apply_dynamic_variables(metadata, sanitized_metadata) if dynamic_var_error: await self._send_error( "client", @@ -331,6 +332,8 @@ class Session: stage="protocol", retryable=False, ) + await self.transport.close() + self.ws_state = WsSessionState.STOPPED return # Create history call record early so later turn callbacks can append transcripts. @@ -348,7 +351,7 @@ class Session: "(assistantId={}, configVersionId={}, output_mode={}, " "llm={}/{}, asr={}/{}, tts={}/{}, tts_enabled={})", self.id, - metadata.get("assistantId") or metadata.get("appId") or metadata.get("app_id"), + metadata.get("assistantId"), metadata.get("configVersionId") or metadata.get("config_version_id"), (resolved_preview.get("output") or {}).get("mode") if isinstance(resolved_preview, dict) else None, llm_cfg.get("provider"), @@ -387,24 +390,6 @@ class Session: config=self._build_config_resolved(metadata), ) ) - if self.workflow_runner and self._workflow_initial_node: - await self._send_event( - ev( - "workflow.started", - workflowId=self.workflow_runner.workflow_id, - workflowName=self.workflow_runner.name, - nodeId=self._workflow_initial_node.id, - ) - ) - await self._send_event( - ev( - "workflow.node.entered", - workflowId=self.workflow_runner.workflow_id, - nodeId=self._workflow_initial_node.id, - nodeName=self._workflow_initial_node.name, - nodeType=self._workflow_initial_node.node_type, - ) - ) # Emit opener only after frontend has received session.started/config events. await self.pipeline.emit_initial_greeting() @@ -658,7 +643,7 @@ class Session: def _event_source(self, event_type: str) -> str: if event_type.startswith("workflow."): return "system" - if event_type.startswith("session.") or event_type.startswith("hello.") or event_type == "heartbeat": + if event_type.startswith("session.") or event_type == "heartbeat": return "system" if event_type == "error": return "system" @@ -931,26 +916,41 @@ class Session: async def _load_server_runtime_metadata( self, - client_metadata: Dict[str, Any], - workflow_runtime: Dict[str, Any], - ) -> Dict[str, Any]: + assistant_id: str, + ) -> tuple[Dict[str, Any], Optional[Dict[str, str]]]: """Load trusted runtime metadata from backend assistant config.""" - assistant_id = ( - workflow_runtime.get("assistantId") - or client_metadata.get("assistantId") - or client_metadata.get("appId") - or client_metadata.get("app_id") - ) - if assistant_id is None: - return {} + if not assistant_id: + return {}, { + "code": "protocol.assistant_id_required", + "message": "Missing required query parameter assistant_id", + } provider = getattr(self._backend_gateway, "fetch_assistant_config", None) if not callable(provider): - return {} + return {}, { + "code": "assistant.config_unavailable", + "message": "Assistant config backend unavailable", + } payload = await provider(str(assistant_id).strip()) + if isinstance(payload, dict): + error_code = str(payload.get("__error_code") or "").strip() + if error_code == "assistant.not_found": + return {}, { + "code": "assistant.not_found", + "message": f"Assistant not found: {assistant_id}", + } + if error_code == "assistant.config_unavailable": + return {}, { + "code": "assistant.config_unavailable", + "message": f"Assistant config unavailable: {assistant_id}", + } + if not isinstance(payload, dict): - return {} + return {}, { + "code": "assistant.config_unavailable", + "message": f"Assistant config unavailable: {assistant_id}", + } assistant_cfg: Dict[str, Any] = {} session_start_cfg = payload.get("sessionStartMetadata") @@ -962,7 +962,10 @@ class Session: assistant_cfg = payload if not isinstance(assistant_cfg, dict): - return {} + return {}, { + "code": "assistant.config_unavailable", + "message": f"Assistant config unavailable: {assistant_id}", + } runtime: Dict[str, Any] = {} passthrough_keys = { @@ -1009,36 +1012,110 @@ class Session: if runtime.get("config_version_id") is not None and runtime.get("configVersionId") is None: runtime["configVersionId"] = runtime.get("config_version_id") - return runtime + return runtime, None - def _sanitize_untrusted_runtime_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]: - """ - Sanitize untrusted metadata sources. + def _find_forbidden_secret_key(self, payload: Any, path: str = "metadata") -> Optional[str]: + if isinstance(payload, dict): + for key, value in payload.items(): + key_str = str(key) + normalized = key_str.lower().replace("_", "").replace("-", "") + if any(token in normalized for token in self._METADATA_FORBIDDEN_KEY_TOKENS): + return f"{path}.{key_str}" + nested = self._find_forbidden_secret_key(value, f"{path}.{key_str}") + if nested: + return nested + return None + if isinstance(payload, list): + for idx, value in enumerate(payload): + nested = self._find_forbidden_secret_key(value, f"{path}[{idx}]") + if nested: + return nested + return None - This keeps only a small override whitelist and stable config ID fields. - """ + def _validate_and_sanitize_client_metadata( + self, + metadata: Dict[str, Any], + ) -> tuple[Dict[str, Any], Optional[Dict[str, str]]]: if not isinstance(metadata, dict): - return {} + return {}, { + "code": "protocol.invalid_override", + "message": "metadata must be an object", + } - sanitized: Dict[str, Any] = {} - for key in self._CLIENT_METADATA_ID_KEYS: - if key in metadata: - sanitized[key] = metadata[key] - for key in self._CLIENT_METADATA_OVERRIDES: - if key in metadata: - sanitized[key] = metadata[key] + forbidden_top_level = [key for key in self._METADATA_FORBIDDEN_TOP_LEVEL_KEYS if key in metadata] + if forbidden_top_level: + return {}, { + "code": "protocol.invalid_override", + "message": f"Forbidden metadata keys: {', '.join(sorted(forbidden_top_level))}", + } - return sanitized + unknown_keys = [ + key + for key in metadata.keys() + if key not in self._METADATA_ALLOWED_TOP_LEVEL_KEYS + and key not in self._METADATA_FORBIDDEN_TOP_LEVEL_KEYS + ] + if unknown_keys: + return {}, { + "code": "protocol.invalid_override", + "message": f"Unsupported metadata keys: {', '.join(sorted(unknown_keys))}", + } - def _sanitize_client_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]: - """Apply client metadata whitelist and remove forbidden secrets.""" - sanitized = self._sanitize_untrusted_runtime_metadata(metadata) - if isinstance(metadata.get("services"), dict): - logger.warning( - "Session {} provided metadata.services from client; client-side service config is ignored", - self.id, - ) - return sanitized + if "workflow" in metadata: + logger.warning("Session {} received metadata.workflow; workflow payload is ignored in MVP", self.id) + + overrides_raw = metadata.get("overrides") + overrides: Dict[str, Any] = {} + if overrides_raw is not None: + if not isinstance(overrides_raw, dict): + return {}, { + "code": "protocol.invalid_override", + "message": "metadata.overrides must be an object", + } + unsupported_override_keys = [key for key in overrides_raw.keys() if key not in self._METADATA_ALLOWED_OVERRIDE_KEYS] + if unsupported_override_keys: + return {}, { + "code": "protocol.invalid_override", + "message": f"Unsupported metadata.overrides keys: {', '.join(sorted(unsupported_override_keys))}", + } + overrides = dict(overrides_raw) + + dynamic_variables = metadata.get("dynamicVariables") + history_raw = metadata.get("history") + history: Dict[str, Any] = {} + if history_raw is not None: + if not isinstance(history_raw, dict): + return {}, { + "code": "protocol.invalid_override", + "message": "metadata.history must be an object", + } + unsupported_history_keys = [key for key in history_raw.keys() if key != "userId"] + if unsupported_history_keys: + return {}, { + "code": "protocol.invalid_override", + "message": f"Unsupported metadata.history keys: {', '.join(sorted(unsupported_history_keys))}", + } + if "userId" in history_raw: + history["userId"] = history_raw.get("userId") + + sanitized: Dict[str, Any] = {"overrides": overrides} + if dynamic_variables is not None: + sanitized["dynamicVariables"] = dynamic_variables + if "channel" in metadata: + sanitized["channel"] = metadata.get("channel") + if "source" in metadata: + sanitized["source"] = metadata.get("source") + if history: + sanitized["history"] = history + + forbidden_path = self._find_forbidden_secret_key(sanitized) + if forbidden_path: + return {}, { + "code": "protocol.invalid_override", + "message": f"Forbidden secret-like key detected at {forbidden_path}", + } + + return sanitized, None def _build_config_resolved(self, metadata: Dict[str, Any]) -> Dict[str, Any]: """Build public resolved config payload (secrets removed).""" @@ -1047,7 +1124,7 @@ class Session: runtime = self.pipeline.resolved_runtime_config() return { - "appId": metadata.get("appId") or metadata.get("app_id") or metadata.get("assistantId"), + "appId": metadata.get("assistantId"), "channel": metadata.get("channel"), "configVersionId": metadata.get("configVersionId") or metadata.get("config_version_id"), "prompt": {"sha256": prompt_hash}, diff --git a/engine/docs/ws_v1_schema.md b/engine/docs/ws_v1_schema.md index 8cf3ebc..eb9db62 100644 --- a/engine/docs/ws_v1_schema.md +++ b/engine/docs/ws_v1_schema.md @@ -5,7 +5,7 @@ This document defines the public WebSocket protocol for the `/ws` endpoint. Validation policy: - WS v1 JSON control messages are validated strictly. - Unknown top-level fields are rejected for all defined client message types. -- `hello.version` is fixed to `"v1"`. +- `assistant_id` query parameter is required on `/ws`. ## Transport @@ -17,29 +17,16 @@ Validation policy: Required message order: -1. Client sends `hello`. -2. Server replies `hello.ack`. -3. Client sends `session.start`. -4. Server replies `session.started`. -5. Client may stream binary audio and/or send `input.text`. -6. Client sends `session.stop` (or closes socket). +1. Client connects to `/ws?assistant_id=`. +2. Client sends `session.start`. +3. Server replies `session.started`. +4. Client may stream binary audio and/or send `input.text`. +5. Client sends `session.stop` (or closes socket). If order is violated, server emits `error` with `code = "protocol.order"`. ## Client -> Server Messages -### `hello` - -```json -{ - "type": "hello", - "version": "v1" -} -``` - -Rules: -- `version` must be `v1`. - ### `session.start` ```json @@ -51,15 +38,18 @@ Rules: "channels": 1 }, "metadata": { - "appId": "assistant_123", "channel": "web", - "configVersionId": "cfg_20260217_01", - "client": "web-debug", - "output": { - "mode": "audio" + "source": "web-debug", + "history": { + "userId": 1 + }, + "overrides": { + "output": { + "mode": "audio" + }, + "systemPrompt": "You are concise.", + "greeting": "Hi, how can I help?" }, - "systemPrompt": "You are concise.", - "greeting": "Hi, how can I help?", "dynamicVariables": { "customer_name": "Alice", "plan_tier": "Pro" @@ -69,9 +59,13 @@ Rules: ``` Rules: -- Client-side `metadata.services` is ignored. -- Service config (including secrets) is resolved server-side (env/backend). -- Client should pass stable IDs (`appId`, `channel`, `configVersionId`) plus small runtime overrides (e.g. `output`, `bargeIn`, greeting/prompt style hints). +- Assistant config is resolved strictly by URL query `assistant_id`. +- `metadata` top-level keys allowed: `overrides`, `dynamicVariables`, `channel`, `source`, `history`, `workflow` (`workflow` is ignored). +- `metadata.overrides` whitelist: `systemPrompt`, `greeting`, `firstTurnMode`, `generatedOpenerEnabled`, `output`, `bargeIn`, `knowledgeBaseId`, `knowledge`, `tools`, `openerAudio`. +- `metadata.services` is rejected with `protocol.invalid_override`. +- `metadata.workflow` is ignored in this MVP protocol version. +- Top-level IDs are forbidden in payload (`assistantId`, `appId`, `app_id`, `configVersionId`, `config_version_id`). +- Secret-like keys are forbidden in metadata (`apiKey`, `token`, `secret`, `password`, `authorization`). - `metadata.dynamicVariables` is optional and must be an object of string key/value pairs. - Key pattern: `^[a-zA-Z_][a-zA-Z0-9_]{0,63}$` - Max entries: 30 @@ -85,7 +79,7 @@ Rules: - Invalid `dynamicVariables` payload rejects `session.start` with `protocol.dynamic_variables_invalid`. Text-only mode: -- Set `metadata.output.mode = "text"`. +- Set `metadata.overrides.output.mode = "text"`. - In this mode server still sends `assistant.response.delta/final`, but will not emit audio frames or `output.audio.start/end`. ### `input.text` @@ -158,8 +152,6 @@ Envelope notes: Common events: -- `hello.ack` - - Fields: `sessionId`, `version` - `session.started` - Fields: `sessionId`, `trackId`, `tracks`, `audio` - `config.resolved` @@ -204,7 +196,7 @@ Common events: Track IDs (MVP fixed values): - `audio_in`: ASR/VAD input-side events (`input.*`, `transcript.*`) - `audio_out`: assistant output-side events (`assistant.*`, `output.audio.*`, `response.interrupted`, `metrics.ttfb`) -- `control`: session/control events (`session.*`, `hello.*`, `error`, `config.resolved`) +- `control`: session/control events (`session.*`, `error`, `config.resolved`) Correlation IDs (`event.data`): - `turn_id`: one user-assistant interaction turn. diff --git a/engine/examples/mic_client.py b/engine/examples/mic_client.py index 00d403f..d734c95 100644 --- a/engine/examples/mic_client.py +++ b/engine/examples/mic_client.py @@ -23,6 +23,7 @@ import time import threading import queue from pathlib import Path +from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit try: import numpy as np @@ -59,9 +60,8 @@ class MicrophoneClient: url: str, sample_rate: int = 16000, chunk_duration_ms: int = 20, - app_id: str = "assistant_demo", + assistant_id: str = "assistant_demo", channel: str = "mic_client", - config_version_id: str = "local-dev", input_device: int = None, output_device: int = None, track_debug: bool = False, @@ -80,9 +80,8 @@ class MicrophoneClient: self.sample_rate = sample_rate self.chunk_duration_ms = chunk_duration_ms self.chunk_samples = int(sample_rate * chunk_duration_ms / 1000) - self.app_id = app_id + self.assistant_id = assistant_id self.channel = channel - self.config_version_id = config_version_id self.input_device = input_device self.output_device = output_device self.track_debug = track_debug @@ -125,19 +124,21 @@ class MicrophoneClient: if value: parts.append(f"{key}={value}") return f" [{' '.join(parts)}]" if parts else "" + + def _session_url(self) -> str: + parts = urlsplit(self.url) + query = dict(parse_qsl(parts.query, keep_blank_values=True)) + query["assistant_id"] = self.assistant_id + return urlunsplit((parts.scheme, parts.netloc, parts.path, urlencode(query), parts.fragment)) async def connect(self) -> None: """Connect to WebSocket server.""" - print(f"Connecting to {self.url}...") - self.ws = await websockets.connect(self.url) + session_url = self._session_url() + print(f"Connecting to {session_url}...") + self.ws = await websockets.connect(session_url) self.running = True print("Connected!") - - # WS v1 handshake: hello -> session.start - await self.send_command({ - "type": "hello", - "version": "v1", - }) + await self.send_command({ "type": "session.start", "audio": { @@ -146,9 +147,8 @@ class MicrophoneClient: "channels": 1, }, "metadata": { - "appId": self.app_id, "channel": self.channel, - "configVersionId": self.config_version_id, + "source": "mic_client", }, }) @@ -330,7 +330,7 @@ class MicrophoneClient: if self.track_debug: print(f"[track-debug] event={event_type} trackId={event.get('trackId')}{ids}") - if event_type in {"hello.ack", "session.started"}: + if event_type == "session.started": print(f"← Session ready!{ids}") elif event_type == "config.resolved": print(f"← Config resolved: {event.get('config', {}).get('output', {})}{ids}") @@ -609,20 +609,15 @@ async def main(): help="Show streaming LLM response chunks" ) parser.add_argument( - "--app-id", + "--assistant-id", default="assistant_demo", - help="Stable app/assistant identifier for server-side config lookup" + help="Assistant identifier used in websocket query parameter" ) parser.add_argument( "--channel", default="mic_client", help="Client channel name" ) - parser.add_argument( - "--config-version-id", - default="local-dev", - help="Optional config version identifier" - ) parser.add_argument( "--track-debug", action="store_true", @@ -638,9 +633,8 @@ async def main(): client = MicrophoneClient( url=args.url, sample_rate=args.sample_rate, - app_id=args.app_id, + assistant_id=args.assistant_id, channel=args.channel, - config_version_id=args.config_version_id, input_device=args.input_device, output_device=args.output_device, track_debug=args.track_debug, diff --git a/engine/examples/simple_client.py b/engine/examples/simple_client.py index b1648bf..8a33fec 100644 --- a/engine/examples/simple_client.py +++ b/engine/examples/simple_client.py @@ -15,6 +15,7 @@ import sys import time import wave import io +from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit try: import numpy as np @@ -56,16 +57,14 @@ class SimpleVoiceClient: self, url: str, sample_rate: int = 16000, - app_id: str = "assistant_demo", + assistant_id: str = "assistant_demo", channel: str = "simple_client", - config_version_id: str = "local-dev", track_debug: bool = False, ): self.url = url self.sample_rate = sample_rate - self.app_id = app_id + self.assistant_id = assistant_id self.channel = channel - self.config_version_id = config_version_id self.track_debug = track_debug self.ws = None self.running = False @@ -88,6 +87,12 @@ class SimpleVoiceClient: # Interrupt handling - discard audio until next trackStart self._discard_audio = False + def _session_url(self) -> str: + parts = urlsplit(self.url) + query = dict(parse_qsl(parts.query, keep_blank_values=True)) + query["assistant_id"] = self.assistant_id + return urlunsplit((parts.scheme, parts.netloc, parts.path, urlencode(query), parts.fragment)) + @staticmethod def _event_ids_suffix(event: dict) -> str: data = event.get("data") if isinstance(event.get("data"), dict) else {} @@ -101,16 +106,12 @@ class SimpleVoiceClient: async def connect(self): """Connect to server.""" - print(f"Connecting to {self.url}...") - self.ws = await websockets.connect(self.url) + session_url = self._session_url() + print(f"Connecting to {session_url}...") + self.ws = await websockets.connect(session_url) self.running = True print("Connected!") - - # WS v1 handshake: hello -> session.start - await self.ws.send(json.dumps({ - "type": "hello", - "version": "v1", - })) + await self.ws.send(json.dumps({ "type": "session.start", "audio": { @@ -119,12 +120,11 @@ class SimpleVoiceClient: "channels": 1, }, "metadata": { - "appId": self.app_id, "channel": self.channel, - "configVersionId": self.config_version_id, + "source": "simple_client", }, })) - print("-> hello/session.start") + print("-> session.start") async def send_chat(self, text: str): """Send chat message.""" @@ -311,9 +311,8 @@ async def main(): parser.add_argument("--text", help="Send text and play response") parser.add_argument("--list-devices", action="store_true") parser.add_argument("--sample-rate", type=int, default=16000) - parser.add_argument("--app-id", default="assistant_demo") + parser.add_argument("--assistant-id", default="assistant_demo") parser.add_argument("--channel", default="simple_client") - parser.add_argument("--config-version-id", default="local-dev") parser.add_argument("--track-debug", action="store_true") args = parser.parse_args() @@ -325,9 +324,8 @@ async def main(): client = SimpleVoiceClient( args.url, args.sample_rate, - app_id=args.app_id, + assistant_id=args.assistant_id, channel=args.channel, - config_version_id=args.config_version_id, track_debug=args.track_debug, ) await client.run(args.text) diff --git a/engine/examples/test_websocket.py b/engine/examples/test_websocket.py index 6717834..30f7aee 100644 --- a/engine/examples/test_websocket.py +++ b/engine/examples/test_websocket.py @@ -12,6 +12,7 @@ import math import argparse import os from datetime import datetime +from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit # Configuration SERVER_URL = "ws://localhost:8000/ws" @@ -130,14 +131,17 @@ async def run_client(url, file_path=None, use_sine=False, track_debug: bool = Fa """Run the WebSocket test client.""" session = aiohttp.ClientSession() try: - print(f"🔌 Connecting to {url}...") - async with session.ws_connect(url) as ws: + parts = urlsplit(url) + query = dict(parse_qsl(parts.query, keep_blank_values=True)) + query["assistant_id"] = str(query.get("assistant_id") or "assistant_demo") + session_url = urlunsplit((parts.scheme, parts.netloc, parts.path, urlencode(query), parts.fragment)) + print(f"🔌 Connecting to {session_url}...") + async with session.ws_connect(session_url) as ws: print("✅ Connected!") session_ready = asyncio.Event() recv_task = asyncio.create_task(receive_loop(ws, session_ready, track_debug=track_debug)) - # Send v1 hello + session.start handshake - await ws.send_json({"type": "hello", "version": "v1"}) + # Send v1 session.start initialization await ws.send_json({ "type": "session.start", "audio": { @@ -146,12 +150,11 @@ async def run_client(url, file_path=None, use_sine=False, track_debug: bool = Fa "channels": 1 }, "metadata": { - "appId": "assistant_demo", "channel": "test_websocket", - "configVersionId": "local-dev", + "source": "test_websocket", }, }) - print("📤 Sent v1 hello/session.start") + print("📤 Sent v1 session.start") await asyncio.wait_for(session_ready.wait(), timeout=8) # Select sender based on args diff --git a/engine/examples/wav_client.py b/engine/examples/wav_client.py index 5684256..1e4a50d 100644 --- a/engine/examples/wav_client.py +++ b/engine/examples/wav_client.py @@ -21,6 +21,7 @@ import sys import time import wave from pathlib import Path +from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit try: import numpy as np @@ -57,9 +58,8 @@ class WavFileClient: url: str, input_file: str, output_file: str, - app_id: str = "assistant_demo", + assistant_id: str = "assistant_demo", channel: str = "wav_client", - config_version_id: str = "local-dev", sample_rate: int = 16000, chunk_duration_ms: int = 20, wait_time: float = 15.0, @@ -82,9 +82,8 @@ class WavFileClient: self.url = url self.input_file = Path(input_file) self.output_file = Path(output_file) - self.app_id = app_id + self.assistant_id = assistant_id self.channel = channel - self.config_version_id = config_version_id self.sample_rate = sample_rate self.chunk_duration_ms = chunk_duration_ms self.chunk_samples = int(sample_rate * chunk_duration_ms / 1000) @@ -147,19 +146,21 @@ class WavFileClient: if value: parts.append(f"{key}={value}") return f" [{' '.join(parts)}]" if parts else "" + + def _session_url(self) -> str: + parts = urlsplit(self.url) + query = dict(parse_qsl(parts.query, keep_blank_values=True)) + query["assistant_id"] = self.assistant_id + return urlunsplit((parts.scheme, parts.netloc, parts.path, urlencode(query), parts.fragment)) async def connect(self) -> None: """Connect to WebSocket server.""" - self.log_event("→", f"Connecting to {self.url}...") - self.ws = await websockets.connect(self.url) + session_url = self._session_url() + self.log_event("→", f"Connecting to {session_url}...") + self.ws = await websockets.connect(session_url) self.running = True self.log_event("←", "Connected!") - # WS v1 handshake: hello -> session.start - await self.send_command({ - "type": "hello", - "version": "v1", - }) await self.send_command({ "type": "session.start", "audio": { @@ -168,9 +169,8 @@ class WavFileClient: "channels": 1 }, "metadata": { - "appId": self.app_id, "channel": self.channel, - "configVersionId": self.config_version_id, + "source": "wav_client", }, }) @@ -329,9 +329,7 @@ class WavFileClient: if self.track_debug: print(f"[track-debug] event={event_type} trackId={event.get('trackId')}{ids}") - if event_type == "hello.ack": - self.log_event("←", f"Handshake acknowledged{ids}") - elif event_type == "session.started": + if event_type == "session.started": self.session_ready = True self.log_event("←", f"Session ready!{ids}") elif event_type == "config.resolved": @@ -521,20 +519,15 @@ async def main(): help="Target sample rate for audio (default: 16000)" ) parser.add_argument( - "--app-id", + "--assistant-id", default="assistant_demo", - help="Stable app/assistant identifier for server-side config lookup" + help="Assistant identifier used in websocket query parameter" ) parser.add_argument( "--channel", default="wav_client", help="Client channel name" ) - parser.add_argument( - "--config-version-id", - default="local-dev", - help="Optional config version identifier" - ) parser.add_argument( "--chunk-duration", type=int, @@ -570,9 +563,8 @@ async def main(): url=args.url, input_file=args.input, output_file=args.output, - app_id=args.app_id, + assistant_id=args.assistant_id, channel=args.channel, - config_version_id=args.config_version_id, sample_rate=args.sample_rate, chunk_duration_ms=args.chunk_duration, wait_time=args.wait_time, diff --git a/engine/examples/web_client.html b/engine/examples/web_client.html index 3431c02..10a7556 100644 --- a/engine/examples/web_client.html +++ b/engine/examples/web_client.html @@ -401,9 +401,14 @@ const targetSampleRate = 16000; const playbackStopRampSec = 0.008; - const appId = "assistant_demo"; + const assistantId = "assistant_demo"; const channel = "web_client"; - const configVersionId = "local-dev"; + + function buildSessionWsUrl(baseUrl) { + const parsed = new URL(baseUrl); + parsed.searchParams.set("assistant_id", assistantId); + return parsed.toString(); + } function logLine(type, text, data) { const time = new Date().toLocaleTimeString(); @@ -556,14 +561,25 @@ async function connect() { if (ws && ws.readyState === WebSocket.OPEN) return; - ws = new WebSocket(wsUrl.value.trim()); + const sessionWsUrl = buildSessionWsUrl(wsUrl.value.trim()); + ws = new WebSocket(sessionWsUrl); ws.binaryType = "arraybuffer"; ws.onopen = () => { setStatus(true, "Session open"); logLine("sys", "WebSocket connected"); ensureAudioContext(); - sendCommand({ type: "hello", version: "v1" }); + sendCommand({ + type: "session.start", + audio: { encoding: "pcm_s16le", sample_rate_hz: targetSampleRate, channels: 1 }, + metadata: { + channel, + source: "web_client", + overrides: { + output: { mode: "audio" }, + }, + }, + }); }; ws.onclose = () => { @@ -622,17 +638,6 @@ const type = event.type || "unknown"; const ids = eventIdsSuffix(event); logLine("event", `${type}${ids}`, event); - if (type === "hello.ack") { - sendCommand({ - type: "session.start", - audio: { encoding: "pcm_s16le", sample_rate_hz: targetSampleRate, channels: 1 }, - metadata: { - appId, - channel, - configVersionId, - }, - }); - } if (type === "config.resolved") { logLine("sys", "config.resolved", event.config || {}); } diff --git a/engine/models/ws_v1.py b/engine/models/ws_v1.py index 7987ff6..39bd61e 100644 --- a/engine/models/ws_v1.py +++ b/engine/models/ws_v1.py @@ -18,12 +18,6 @@ class _StrictModel(BaseModel): model_config = ConfigDict(extra="forbid") -# Client -> Server messages -class HelloMessage(_StrictModel): - type: Literal["hello"] - version: Literal["v1"] = Field(..., description="Protocol version, fixed to v1") - - class SessionStartAudio(_StrictModel): encoding: Literal["pcm_s16le"] = "pcm_s16le" sample_rate_hz: Literal[16000] = 16000 @@ -69,7 +63,6 @@ class ToolCallResultsMessage(_StrictModel): CLIENT_MESSAGE_TYPES = { - "hello": HelloMessage, "session.start": SessionStartMessage, "session.stop": SessionStopMessage, "input.text": InputTextMessage, diff --git a/engine/tests/test_tool_call_flow.py b/engine/tests/test_tool_call_flow.py index 236e86a..ac60de9 100644 --- a/engine/tests/test_tool_call_flow.py +++ b/engine/tests/test_tool_call_flow.py @@ -86,7 +86,7 @@ def _build_pipeline(monkeypatch, llm_rounds: List[List[LLMStreamEvent]]) -> tupl async def _capture_event(event: Dict[str, Any], priority: int = 20): events.append(event) - async def _noop_speak(_text: str, fade_in_ms: int = 0, fade_out_ms: int = 8): + async def _noop_speak(_text: str, *args, **kwargs): return None monkeypatch.setattr(pipeline, "_send_event", _capture_event) diff --git a/engine/tests/test_ws_protocol_session_start.py b/engine/tests/test_ws_protocol_session_start.py new file mode 100644 index 0000000..07ee762 --- /dev/null +++ b/engine/tests/test_ws_protocol_session_start.py @@ -0,0 +1,238 @@ +import pytest + +from core.session import Session, WsSessionState +from models.ws_v1 import SessionStartMessage, parse_client_message + + +def _session() -> Session: + session = Session.__new__(Session) + session.id = "sess_test" + session._assistant_id = "assistant_demo" + return session + + +def test_parse_client_message_rejects_hello_message(): + with pytest.raises(ValueError, match="Unknown client message type: hello"): + parse_client_message({"type": "hello", "version": "v1"}) + + +@pytest.mark.asyncio +async def test_handle_text_reports_invalid_message_for_hello(): + session = Session.__new__(Session) + session.id = "sess_invalid_hello" + session.ws_state = WsSessionState.WAIT_START + + class _Transport: + async def close(self): + return None + + session.transport = _Transport() + captured = [] + + async def _send_error(sender, message, code, **kwargs): + captured.append((sender, code, message, kwargs)) + + session._send_error = _send_error + + await session.handle_text('{"type":"hello","version":"v1"}') + assert captured + sender, code, message, _ = captured[0] + assert sender == "client" + assert code == "protocol.invalid_message" + assert "Unknown client message type: hello" in message + + +def test_validate_metadata_rejects_services_payload(): + session = _session() + sanitized, error = session._validate_and_sanitize_client_metadata({"services": {"llm": {"provider": "openai"}}}) + assert sanitized == {} + assert error is not None + assert error["code"] == "protocol.invalid_override" + + +def test_validate_metadata_rejects_secret_like_override_keys(): + session = _session() + sanitized, error = session._validate_and_sanitize_client_metadata( + { + "overrides": { + "output": {"mode": "audio"}, + "apiKey": "xxx", + } + } + ) + assert sanitized == {} + assert error is not None + assert error["code"] == "protocol.invalid_override" + + +def test_validate_metadata_ignores_workflow_payload(): + session = _session() + sanitized, error = session._validate_and_sanitize_client_metadata( + { + "workflow": {"nodes": [], "edges": []}, + "channel": "web_debug", + "overrides": {"output": {"mode": "text"}}, + } + ) + assert error is None + assert "workflow" not in sanitized + assert sanitized["channel"] == "web_debug" + assert sanitized["overrides"]["output"]["mode"] == "text" + + +@pytest.mark.asyncio +async def test_load_server_runtime_metadata_returns_not_found_error(): + session = _session() + + class _Gateway: + async def fetch_assistant_config(self, assistant_id: str): + _ = assistant_id + return {"__error_code": "assistant.not_found"} + + session._backend_gateway = _Gateway() + runtime, error = await session._load_server_runtime_metadata("assistant_demo") + assert runtime == {} + assert error is not None + assert error["code"] == "assistant.not_found" + + +@pytest.mark.asyncio +async def test_load_server_runtime_metadata_returns_config_unavailable_error(): + session = _session() + + class _Gateway: + async def fetch_assistant_config(self, assistant_id: str): + _ = assistant_id + return None + + session._backend_gateway = _Gateway() + runtime, error = await session._load_server_runtime_metadata("assistant_demo") + assert runtime == {} + assert error is not None + assert error["code"] == "assistant.config_unavailable" + + +@pytest.mark.asyncio +async def test_handle_session_start_requires_assistant_id_and_closes_transport(): + session = Session.__new__(Session) + session.id = "sess_missing_assistant" + session.ws_state = WsSessionState.WAIT_START + session._assistant_id = None + + class _Transport: + def __init__(self): + self.closed = False + + async def close(self): + self.closed = True + + transport = _Transport() + session.transport = transport + captured_codes = [] + + async def _send_error(sender, message, code, **kwargs): + _ = (sender, message, kwargs) + captured_codes.append(code) + + session._send_error = _send_error + + await session._handle_session_start(SessionStartMessage(type="session.start", metadata={})) + assert captured_codes == ["protocol.assistant_id_required"] + assert transport.closed is True + assert session.ws_state == WsSessionState.STOPPED + + +@pytest.mark.asyncio +async def test_handle_session_start_applies_whitelisted_overrides_and_ignores_workflow(): + session = Session.__new__(Session) + session.id = "sess_start_ok" + session.ws_state = WsSessionState.WAIT_START + session.state = "created" + session._assistant_id = "assistant_demo" + session.current_track_id = Session.TRACK_CONTROL + session._pipeline_started = False + + class _Transport: + async def close(self): + return None + + class _Pipeline: + def __init__(self): + self.started = False + self.applied = {} + self.conversation = type("Conversation", (), {"system_prompt": ""})() + + async def start(self): + self.started = True + + async def emit_initial_greeting(self): + return None + + def apply_runtime_overrides(self, metadata): + self.applied = dict(metadata) + + def resolved_runtime_config(self): + return { + "output": {"mode": "text"}, + "services": {"llm": {"provider": "openai", "model": "gpt-4o-mini"}}, + "tools": {"allowlist": []}, + } + + session.transport = _Transport() + session.pipeline = _Pipeline() + events = [] + + async def _start_history_bridge(_metadata): + return None + + async def _load_server_runtime_metadata(_assistant_id): + return ( + { + "assistantId": "assistant_demo", + "configVersionId": "cfg_1", + "systemPrompt": "Base prompt", + "greeting": "Base greeting", + "output": {"mode": "audio"}, + }, + None, + ) + + async def _send_event(event): + events.append(event) + + async def _send_error(sender, message, code, **kwargs): + raise AssertionError(f"Unexpected error: sender={sender} code={code} message={message} kwargs={kwargs}") + + session._start_history_bridge = _start_history_bridge + session._load_server_runtime_metadata = _load_server_runtime_metadata + session._send_event = _send_event + session._send_error = _send_error + + await session._handle_session_start( + SessionStartMessage( + type="session.start", + metadata={ + "workflow": {"nodes": []}, + "channel": "web_debug", + "source": "debug_ui", + "history": {"userId": 7}, + "overrides": { + "greeting": "Override greeting", + "output": {"mode": "text"}, + "tools": [{"name": "calculator"}], + }, + }, + ) + ) + + assert session.ws_state == WsSessionState.ACTIVE + assert session.pipeline.started is True + assert session.pipeline.applied["assistantId"] == "assistant_demo" + assert session.pipeline.applied["greeting"] == "Override greeting" + assert session.pipeline.applied["output"]["mode"] == "text" + assert session.pipeline.applied["tools"] == [{"name": "calculator"}] + assert not any(str(item.get("type", "")).startswith("workflow.") for item in events) + + config_event = next(item for item in events if item.get("type") == "config.resolved") + assert config_event["config"]["appId"] == "assistant_demo" + assert config_event["config"]["channel"] == "web_debug" diff --git a/web/pages/Assistants.tsx b/web/pages/Assistants.tsx index 6ec63c4..adf0dc4 100644 --- a/web/pages/Assistants.tsx +++ b/web/pages/Assistants.tsx @@ -2842,6 +2842,61 @@ export const DebugDrawer: React.FC<{ return error; }; + const METADATA_OVERRIDE_WHITELIST = new Set([ + 'firstTurnMode', + 'greeting', + 'generatedOpenerEnabled', + 'systemPrompt', + 'output', + 'bargeIn', + 'knowledge', + 'knowledgeBaseId', + 'openerAudio', + 'tools', + ]); + const METADATA_FORBIDDEN_SECRET_TOKENS = ['apikey', 'token', 'secret', 'password', 'authorization']; + const isPlainObject = (value: unknown): value is Record => Boolean(value) && typeof value === 'object' && !Array.isArray(value); + const isForbiddenSecretKey = (key: string): boolean => { + const normalized = key.toLowerCase().replace(/[_-]/g, ''); + return METADATA_FORBIDDEN_SECRET_TOKENS.some((token) => normalized.includes(token)); + }; + const stripForbiddenSecretKeysDeep = (value: any): any => { + if (Array.isArray(value)) return value.map(stripForbiddenSecretKeysDeep); + if (!isPlainObject(value)) return value; + return Object.entries(value).reduce>((acc, [key, nested]) => { + if (isForbiddenSecretKey(key)) return acc; + acc[key] = stripForbiddenSecretKeysDeep(nested); + return acc; + }, {}); + }; + const sanitizeMetadataForWs = (raw: unknown): Record => { + if (!isPlainObject(raw)) return { overrides: {} }; + const sanitized: Record = { overrides: {} }; + + if (typeof raw.channel === 'string' && raw.channel.trim()) { + sanitized.channel = raw.channel.trim(); + } + if (typeof raw.source === 'string' && raw.source.trim()) { + sanitized.source = raw.source.trim(); + } + if (isPlainObject(raw.history) && raw.history.userId !== undefined) { + sanitized.history = { userId: raw.history.userId }; + } + if (isPlainObject(raw.dynamicVariables)) { + sanitized.dynamicVariables = raw.dynamicVariables; + } + if (isPlainObject(raw.overrides)) { + const overrides = Object.entries(raw.overrides).reduce>((acc, [key, value]) => { + if (!METADATA_OVERRIDE_WHITELIST.has(key)) return acc; + if (isForbiddenSecretKey(key)) return acc; + acc[key] = stripForbiddenSecretKeysDeep(value); + return acc; + }, {}); + sanitized.overrides = overrides; + } + return sanitized; + }; + const buildDynamicVariablesPayload = (): { variables: Record; error?: string } => { const variables: Record = {}; const nonEmptyRows = dynamicVariables @@ -2908,104 +2963,35 @@ export const DebugDrawer: React.FC<{ const buildLocalResolvedRuntime = () => { const warnings: string[] = []; - const services: Record = {}; const ttsEnabled = Boolean(textTtsEnabled); - const isExternalLlm = assistant.configMode === 'dify' || assistant.configMode === 'fastgpt'; const knowledgeBaseId = String(assistant.knowledgeBaseId || '').trim(); const knowledge = knowledgeBaseId ? { enabled: true, kbId: knowledgeBaseId, nResults: 5 } : { enabled: false }; - if (isExternalLlm) { - services.llm = { - provider: 'openai', - model: '', - apiKey: assistant.apiKey || '', - baseUrl: assistant.apiUrl || '', - }; - if (!assistant.apiUrl) warnings.push(`External LLM API URL is empty for mode: ${assistant.configMode}`); - if (!assistant.apiKey) warnings.push(`External LLM API key is empty for mode: ${assistant.configMode}`); - } else if (assistant.llmModelId) { - const llm = llmModels.find((item) => item.id === assistant.llmModelId); - if (llm) { - services.llm = { - provider: 'openai', - model: llm.modelName || llm.name, - apiKey: llm.apiKey, - baseUrl: llm.baseUrl, - }; - } else { - warnings.push(`LLM model not found in loaded list: ${assistant.llmModelId}`); - } - } else { - // Keep empty object to indicate engine should use default provider model. - services.llm = {}; - } - - if (assistant.asrModelId) { - const asr = asrModels.find((item) => item.id === assistant.asrModelId); - if (asr) { - const asrProvider = isOpenAICompatibleVendor(asr.vendor) ? 'openai_compatible' : 'buffered'; - services.asr = { - provider: asrProvider, - model: asr.modelName || asr.name, - apiKey: asrProvider === 'openai_compatible' ? asr.apiKey : null, - }; - } else { - warnings.push(`ASR model not found in loaded list: ${assistant.asrModelId}`); - } - } - - if (assistant.voice) { - const voice = voices.find((item) => item.id === assistant.voice); - if (voice) { - const ttsProvider = isOpenAICompatibleVendor(voice.vendor) ? 'openai_compatible' : 'edge'; - services.tts = { - enabled: ttsEnabled, - provider: ttsProvider, - model: voice.model, - apiKey: ttsProvider === 'openai_compatible' ? voice.apiKey : null, - voice: resolveRuntimeTtsVoice(assistant.voice, voice), - speed: assistant.speed || voice.speed || 1.0, - }; - } else { - services.tts = { - enabled: ttsEnabled, - voice: assistant.voice, - speed: assistant.speed || 1.0, - }; - warnings.push(`Voice resource not found in loaded list: ${assistant.voice}`); - } - } else if (!ttsEnabled) { - services.tts = { - enabled: false, - }; - } - const localResolved = { - assistantId: assistant.id, warnings, sessionStartMetadata: { - output: { - mode: ttsEnabled ? 'audio' : 'text', + overrides: { + output: { + mode: ttsEnabled ? 'audio' : 'text', + }, + systemPrompt: assistant.prompt || '', + firstTurnMode: assistant.firstTurnMode || 'bot_first', + greeting: assistant.opener || '', + generatedOpenerEnabled: assistant.generatedOpenerEnabled === true, + bargeIn: { + enabled: assistant.botCannotBeInterrupted !== true, + minDurationMs: Math.max(0, Number(assistant.interruptionSensitivity ?? 180)), + }, + knowledgeBaseId, + knowledge, + tools: selectedToolSchemas, }, - systemPrompt: assistant.prompt || '', - firstTurnMode: assistant.firstTurnMode || 'bot_first', - greeting: assistant.opener || '', - generatedOpenerEnabled: assistant.generatedOpenerEnabled === true, - bargeIn: { - enabled: assistant.botCannotBeInterrupted !== true, - minDurationMs: Math.max(0, Number(assistant.interruptionSensitivity ?? 180)), - }, - knowledgeBaseId, - knowledge, - tools: selectedToolSchemas, - services, history: { - assistantId: assistant.id, userId: 1, - source: 'debug', }, + source: 'web_debug', }, }; @@ -3020,21 +3006,13 @@ export const DebugDrawer: React.FC<{ } setDynamicVariablesError(''); const localResolved = buildLocalResolvedRuntime(); - const mergedMetadata: Record = { + const mergedMetadata: Record = sanitizeMetadataForWs({ ...localResolved.sessionStartMetadata, ...(sessionMetadataExtras || {}), - }; + }); if (Object.keys(dynamicVariablesResult.variables).length > 0) { mergedMetadata.dynamicVariables = dynamicVariablesResult.variables; } - // Engine resolves trusted runtime config by top-level assistant/app ID. - // Keep these IDs at metadata root so backend /assistants/{id}/config is reachable. - if (!mergedMetadata.assistantId && assistant.id) { - mergedMetadata.assistantId = assistant.id; - } - if (!mergedMetadata.appId && assistant.id) { - mergedMetadata.appId = assistant.id; - } if (!mergedMetadata.channel) { mergedMetadata.channel = 'web_debug'; } @@ -3069,6 +3047,24 @@ export const DebugDrawer: React.FC<{ if (isOpen) setWsStatus('disconnected'); }; + const buildSessionWsUrl = () => { + const base = wsUrl.trim(); + if (!base) return ''; + try { + const parsed = new URL(base); + parsed.searchParams.set('assistant_id', assistant.id); + return parsed.toString(); + } catch { + try { + const parsed = new URL(base, window.location.href); + parsed.searchParams.set('assistant_id', assistant.id); + return parsed.toString(); + } catch { + return base; + } + } + }; + const ensureWsSession = async () => { if (wsRef.current && wsReadyRef.current && wsRef.current.readyState === WebSocket.OPEN) { return; @@ -3083,18 +3079,25 @@ export const DebugDrawer: React.FC<{ } const metadata = await fetchRuntimeMetadata(); + const sessionWsUrl = buildSessionWsUrl(); setWsStatus('connecting'); setWsError(''); await new Promise((resolve, reject) => { pendingResolveRef.current = resolve; pendingRejectRef.current = reject; - const ws = new WebSocket(wsUrl); + const ws = new WebSocket(sessionWsUrl); ws.binaryType = 'arraybuffer'; wsRef.current = ws; ws.onopen = () => { - ws.send(JSON.stringify({ type: 'hello', version: 'v1' })); + ws.send( + JSON.stringify({ + type: 'session.start', + audio: { encoding: 'pcm_s16le', sample_rate_hz: 16000, channels: 1 }, + metadata, + }) + ); }; ws.onmessage = (event) => { @@ -3118,16 +3121,6 @@ export const DebugDrawer: React.FC<{ if (onProtocolEvent) { onProtocolEvent(payload); } - if (type === 'hello.ack') { - ws.send( - JSON.stringify({ - type: 'session.start', - audio: { encoding: 'pcm_s16le', sample_rate_hz: 16000, channels: 1 }, - metadata, - }) - ); - return; - } if (type === 'output.audio.start') { // New utterance audio starts: cancel old queued/playing audio to avoid overlap. stopPlaybackImmediately();