diff --git a/docs/content/api-reference.md b/docs/content/api-reference.md index 20b6fa1..d1a4366 100644 --- a/docs/content/api-reference.md +++ b/docs/content/api-reference.md @@ -11,9 +11,11 @@ WebSocket 端点提供双向实时语音对话能力,支持音频流输入输 ### 连接地址 ``` -ws:///ws +ws:///ws?assistant_id= ``` +- `assistant_id` 为必填 query 参数,用于从数据库加载该助手的运行时配置。 + ### 传输规则 - **文本帧**:JSON 格式控制消息 @@ -25,8 +27,6 @@ ws:///ws ### 消息流程 ``` -Client -> hello -Server <- hello.ack Client -> session.start Server <- session.started Server <- config.resolved @@ -44,27 +44,9 @@ Server <- session.stopped ### 客户端 -> 服务端消息 -#### 1. Handshake: `hello` +#### 1. Session Start: `session.start` -客户端连接后发送的第一个消息,用于协议版本协商。 - -```json -{ - "type": "hello", - "version": "v1" -} -``` - -| 字段 | 类型 | 必填 | 说明 | -|---|---|---|---| -| `type` | string | 是 | 固定为 `"hello"` | -| `version` | string | 是 | 协议版本,固定为 `"v1"` | - ---- - -#### 2. Session Start: `session.start` - -握手成功后发送的第二个消息,用于启动对话会话。 +客户端连接后发送的第一个消息,用于启动对话会话。 ```json { @@ -75,13 +57,17 @@ Server <- session.stopped "channels": 1 }, "metadata": { - "appId": "assistant_123", "channel": "web", - "configVersionId": "cfg_20260217_01", - "systemPrompt": "你是简洁助手", - "greeting": "你好,我能帮你什么?", - "output": { - "mode": "audio" + "source": "web_debug", + "history": { + "userId": 1 + }, + "overrides": { + "systemPrompt": "你是简洁助手", + "greeting": "你好,我能帮你什么?", + "output": { + "mode": "audio" + } }, "dynamicVariables": { "customer_name": "Alice", @@ -101,17 +87,33 @@ Server <- session.stopped | `metadata` | object | 否 | 运行时配置 | **metadata 支持的字段**: -- `appId` / `app_id` - 应用 ID - `channel` - 渠道标识 -- `configVersionId` / `config_version_id` - 配置版本 -- `systemPrompt` - 系统提示词 -- `greeting` - 开场白 -- `output.mode` - 输出模式 (`audio` / `text`) +- `source` - 来源标识 +- `history.userId` - 历史记录用户 ID +- `overrides` - 可覆盖字段(仅限安全白名单) - `dynamicVariables` - 动态变量(支持 `{{variable}}` 占位符) +**`metadata.overrides` 白名单字段**: +- `systemPrompt` +- `greeting` +- `firstTurnMode` +- `generatedOpenerEnabled` +- `output` +- 
`bargeIn` +- `knowledgeBaseId` +- `knowledge` +- `tools` +- `openerAudio` + +**限制**: +- `metadata.workflow` 会被忽略(不触发 workflow 事件) +- 禁止提交 `metadata.services` +- 禁止提交 `assistantId` / `appId` / `app_id` / `configVersionId` / `config_version_id` +- 禁止提交包含密钥语义的字段(如 `apiKey` / `token` / `secret` / `password` / `authorization`) + --- -#### 3. Text Input: `input.text` +#### 2. Text Input: `input.text` 发送文本输入,跳过 ASR 识别,直接触发 LLM 回复。 @@ -129,7 +131,7 @@ Server <- session.stopped --- -#### 4. Response Cancel: `response.cancel` +#### 3. Response Cancel: `response.cancel` 请求中断当前回答。 @@ -147,7 +149,7 @@ Server <- session.stopped --- -#### 5. Tool Call Results: `tool_call.results` +#### 4. Tool Call Results: `tool_call.results` 回传客户端执行的工具结果。 @@ -176,7 +178,7 @@ Server <- session.stopped --- -#### 6. Session Stop: `session.stop` +#### 5. Session Stop: `session.stop` 结束对话会话。 @@ -194,7 +196,7 @@ Server <- session.stopped --- -#### 7. Binary Audio +#### 6. Binary Audio 在 `session.started` 之后可持续发送二进制 PCM 音频。 @@ -239,7 +241,6 @@ Server <- session.stopped | 事件 | 说明 | |---|---| -| `hello.ack` | 握手成功响应 | | `session.started` | 会话启动成功 | | `config.resolved` | 服务端最终配置快照 | | `heartbeat` | 保活心跳(默认 50 秒间隔) | @@ -281,7 +282,10 @@ Server <- session.stopped | `protocol.invalid_json` | JSON 格式错误 | | `protocol.invalid_message` | 消息格式错误 | | `protocol.order` | 消息顺序错误 | -| `protocol.version_unsupported` | 协议版本不支持 | +| `protocol.assistant_id_required` | 缺少 `assistant_id` query 参数 | +| `protocol.invalid_override` | metadata 覆盖字段不合法 | +| `assistant.not_found` | assistant 不存在 | +| `assistant.config_unavailable` | assistant 配置不可用 | | `audio.invalid_pcm` | PCM 数据无效 | | `audio.frame_size_mismatch` | 音频帧大小不匹配 | | `server.internal` | 服务端内部错误 | diff --git a/engine/app/backend_adapters.py b/engine/app/backend_adapters.py index a05bd8f..6ff2716 100644 --- a/engine/app/backend_adapters.py +++ b/engine/app/backend_adapters.py @@ -157,16 +157,16 @@ class HttpBackendAdapter: async with session.get(url) as resp: if 
resp.status == 404: logger.warning(f"Assistant config not found: {assistant_id}") - return None + return {"__error_code": "assistant.not_found", "assistantId": assistant_id} resp.raise_for_status() payload = await resp.json() if not isinstance(payload, dict): logger.warning("Assistant config payload is not a dict; ignoring") - return None + return {"__error_code": "assistant.config_unavailable", "assistantId": assistant_id} return payload except Exception as exc: logger.warning(f"Failed to fetch assistant config ({assistant_id}): {exc}") - return None + return {"__error_code": "assistant.config_unavailable", "assistantId": assistant_id} async def create_call_record( self, diff --git a/engine/app/main.py b/engine/app/main.py index c13daba..b8a39bb 100644 --- a/engine/app/main.py +++ b/engine/app/main.py @@ -163,13 +163,19 @@ async def websocket_endpoint(websocket: WebSocket): """ await websocket.accept() session_id = str(uuid.uuid4()) + assistant_id = str(websocket.query_params.get("assistant_id") or "").strip() or None # Create transport and session transport = SocketTransport(websocket) - session = Session(session_id, transport, backend_gateway=backend_gateway) + session = Session( + session_id, + transport, + backend_gateway=backend_gateway, + assistant_id=assistant_id, + ) active_sessions[session_id] = session - logger.info(f"WebSocket connection established: {session_id}") + logger.info(f"WebSocket connection established: {session_id} assistant_id={assistant_id or '-'}") last_received_at: List[float] = [time.monotonic()] last_heartbeat_at: List[float] = [0.0] @@ -239,16 +245,22 @@ async def webrtc_endpoint(websocket: WebSocket): return await websocket.accept() session_id = str(uuid.uuid4()) + assistant_id = str(websocket.query_params.get("assistant_id") or "").strip() or None # Create WebRTC peer connection pc = RTCPeerConnection() # Create transport and session transport = WebRtcTransport(websocket, pc) - session = Session(session_id, transport, 
backend_gateway=backend_gateway) + session = Session( + session_id, + transport, + backend_gateway=backend_gateway, + assistant_id=assistant_id, + ) active_sessions[session_id] = session - logger.info(f"WebRTC connection established: {session_id}") + logger.info(f"WebRTC connection established: {session_id} assistant_id={assistant_id or '-'}") last_received_at: List[float] = [time.monotonic()] last_heartbeat_at: List[float] = [0.0] diff --git a/engine/core/session.py b/engine/core/session.py index dda0382..ac365ce 100644 --- a/engine/core/session.py +++ b/engine/core/session.py @@ -21,7 +21,6 @@ from services.base import LLMMessage from models.ws_v1 import ( parse_client_message, ev, - HelloMessage, SessionStartMessage, SessionStopMessage, InputTextMessage, @@ -39,7 +38,6 @@ _SYSTEM_DYNAMIC_VARIABLE_KEYS = {"system__time", "system_utc", "system_timezone" class WsSessionState(str, Enum): """Protocol state machine for WS sessions.""" - WAIT_HELLO = "wait_hello" WAIT_START = "wait_start" ACTIVE = "active" STOPPED = "stopped" @@ -57,7 +55,15 @@ class Session: TRACK_AUDIO_OUT = "audio_out" TRACK_CONTROL = "control" AUDIO_FRAME_BYTES = 640 # 16k mono pcm_s16le, 20ms - _CLIENT_METADATA_OVERRIDES = { + _METADATA_ALLOWED_TOP_LEVEL_KEYS = { + "overrides", + "dynamicVariables", + "channel", + "source", + "history", + "workflow", # explicitly ignored for this MVP protocol version + } + _METADATA_ALLOWED_OVERRIDE_KEYS = { "firstTurnMode", "greeting", "generatedOpenerEnabled", @@ -67,18 +73,22 @@ class Session: "knowledge", "knowledgeBaseId", "openerAudio", - "dynamicVariables", - "history", - "userId", - "assistantId", - "source", + "tools", } - _CLIENT_METADATA_ID_KEYS = { + _METADATA_FORBIDDEN_TOP_LEVEL_KEYS = { + "assistantId", "appId", "app_id", - "channel", "configVersionId", "config_version_id", + "services", + } + _METADATA_FORBIDDEN_KEY_TOKENS = { + "apikey", + "token", + "secret", + "password", + "authorization", } def __init__( @@ -87,6 +97,7 @@ class Session: 
transport: BaseTransport, use_duplex: bool = None, backend_gateway: Optional[Any] = None, + assistant_id: Optional[str] = None, ): """ Initialize session. @@ -99,6 +110,7 @@ class Session: self.id = session_id self.transport = transport self.use_duplex = use_duplex if use_duplex is not None else settings.duplex_enabled + self._assistant_id = str(assistant_id or "").strip() or None self._backend_gateway = backend_gateway or build_backend_adapter_from_settings() self._history_bridge = SessionHistoryBridge( history_writer=self._backend_gateway, @@ -121,9 +133,8 @@ class Session: # Session state self.created_at = None self.state = "created" # Legacy call state for /call/lists - self.ws_state = WsSessionState.WAIT_HELLO + self.ws_state = WsSessionState.WAIT_START self._pipeline_started = False - self.protocol_version: Optional[str] = None # Track IDs self.current_track_id: str = self.TRACK_CONTROL @@ -137,7 +148,12 @@ class Session: self.pipeline.set_event_sequence_provider(self._next_event_seq) self.pipeline.conversation.on_turn_complete(self._on_turn_complete) - logger.info(f"Session {self.id} created (duplex={self.use_duplex})") + logger.info( + "Session {} created (duplex={}, assistant_id={})", + self.id, + self.use_duplex, + self._assistant_id or "-", + ) async def handle_text(self, text_data: str) -> None: """ @@ -222,23 +238,19 @@ class Session: msg_type = message.type logger.info(f"Session {self.id} received message: {msg_type}") - if isinstance(message, HelloMessage): - await self._handle_hello(message) - return - - # All messages below require hello handshake first - if self.ws_state == WsSessionState.WAIT_HELLO: - await self._send_error( - "client", - "Expected hello message first", - "protocol.order", - ) - return - if isinstance(message, SessionStartMessage): await self._handle_session_start(message) return + # All messages below require session.start first + if self.ws_state == WsSessionState.WAIT_START: + await self._send_error( + "client", + "Expected 
session.start message first", + "protocol.order", + ) + return + # All messages below require active session if self.ws_state != WsSessionState.ACTIVE: await self._send_error( @@ -262,67 +274,56 @@ class Session: else: await self._send_error("client", f"Unsupported message type: {msg_type}", "protocol.unsupported") - async def _handle_hello(self, message: HelloMessage) -> None: - """Handle initial hello/version negotiation.""" - if self.ws_state != WsSessionState.WAIT_HELLO: - await self._send_error("client", "Duplicate hello", "protocol.order") - return - - if message.version != settings.ws_protocol_version: - await self._send_error( - "client", - f"Unsupported protocol version '{message.version}'", - "protocol.version_unsupported", - ) - await self.transport.close() - self.ws_state = WsSessionState.STOPPED - return - - self.protocol_version = message.version - self.ws_state = WsSessionState.WAIT_START - await self._send_event( - ev( - "hello.ack", - version=self.protocol_version, - ) - ) - async def _handle_session_start(self, message: SessionStartMessage) -> None: - """Handle explicit session start after successful hello.""" + """Handle explicit session start.""" if self.ws_state != WsSessionState.WAIT_START: await self._send_error("client", "Duplicate session.start", "protocol.order") return raw_metadata = message.metadata or {} - workflow_runtime = self._bootstrap_workflow(raw_metadata) - server_runtime = await self._load_server_runtime_metadata(raw_metadata, workflow_runtime) - client_runtime = self._sanitize_client_metadata(raw_metadata) - requested_assistant_id = ( - workflow_runtime.get("assistantId") - or raw_metadata.get("assistantId") - or raw_metadata.get("appId") - or raw_metadata.get("app_id") - ) - if server_runtime: - logger.info( - "Session {} loaded trusted runtime config from backend " - "(requested_assistant_id={}, resolved_assistant_id={}, configVersionId={}, has_services={})", - self.id, - requested_assistant_id, - 
server_runtime.get("assistantId"), - server_runtime.get("configVersionId") or server_runtime.get("config_version_id"), - isinstance(server_runtime.get("services"), dict), + if not self._assistant_id: + await self._send_error( + "client", + "Missing required query parameter assistant_id", + "protocol.assistant_id_required", + stage="protocol", + retryable=False, ) - else: - logger.warning( - "Session {} missing trusted backend runtime config " - "(requested_assistant_id={}); falling back to engine defaults + safe client overrides", - self.id, - requested_assistant_id, + await self.transport.close() + self.ws_state = WsSessionState.STOPPED + return + + sanitized_metadata, metadata_error = self._validate_and_sanitize_client_metadata(raw_metadata) + if metadata_error: + await self._send_error( + "client", + metadata_error["message"], + metadata_error["code"], + stage="protocol", + retryable=False, ) - metadata = self._merge_runtime_metadata(server_runtime, self._sanitize_untrusted_runtime_metadata(workflow_runtime)) - metadata = self._merge_runtime_metadata(metadata, client_runtime) - metadata, dynamic_var_error = self._apply_dynamic_variables(metadata, raw_metadata) + await self.transport.close() + self.ws_state = WsSessionState.STOPPED + return + + server_runtime, runtime_error = await self._load_server_runtime_metadata(self._assistant_id) + if runtime_error: + await self._send_error( + "server", + runtime_error["message"], + runtime_error["code"], + stage="protocol", + retryable=False, + ) + await self.transport.close() + self.ws_state = WsSessionState.STOPPED + return + + metadata = self._merge_runtime_metadata(server_runtime, sanitized_metadata.get("overrides", {})) + for key in ("channel", "source", "history"): + if key in sanitized_metadata: + metadata[key] = sanitized_metadata[key] + metadata, dynamic_var_error = self._apply_dynamic_variables(metadata, sanitized_metadata) if dynamic_var_error: await self._send_error( "client", @@ -331,6 +332,8 @@ class Session: 
stage="protocol", retryable=False, ) + await self.transport.close() + self.ws_state = WsSessionState.STOPPED return # Create history call record early so later turn callbacks can append transcripts. @@ -348,7 +351,7 @@ class Session: "(assistantId={}, configVersionId={}, output_mode={}, " "llm={}/{}, asr={}/{}, tts={}/{}, tts_enabled={})", self.id, - metadata.get("assistantId") or metadata.get("appId") or metadata.get("app_id"), + metadata.get("assistantId"), metadata.get("configVersionId") or metadata.get("config_version_id"), (resolved_preview.get("output") or {}).get("mode") if isinstance(resolved_preview, dict) else None, llm_cfg.get("provider"), @@ -387,24 +390,6 @@ class Session: config=self._build_config_resolved(metadata), ) ) - if self.workflow_runner and self._workflow_initial_node: - await self._send_event( - ev( - "workflow.started", - workflowId=self.workflow_runner.workflow_id, - workflowName=self.workflow_runner.name, - nodeId=self._workflow_initial_node.id, - ) - ) - await self._send_event( - ev( - "workflow.node.entered", - workflowId=self.workflow_runner.workflow_id, - nodeId=self._workflow_initial_node.id, - nodeName=self._workflow_initial_node.name, - nodeType=self._workflow_initial_node.node_type, - ) - ) # Emit opener only after frontend has received session.started/config events. 
await self.pipeline.emit_initial_greeting() @@ -658,7 +643,7 @@ class Session: def _event_source(self, event_type: str) -> str: if event_type.startswith("workflow."): return "system" - if event_type.startswith("session.") or event_type.startswith("hello.") or event_type == "heartbeat": + if event_type.startswith("session.") or event_type == "heartbeat": return "system" if event_type == "error": return "system" @@ -931,26 +916,41 @@ class Session: async def _load_server_runtime_metadata( self, - client_metadata: Dict[str, Any], - workflow_runtime: Dict[str, Any], - ) -> Dict[str, Any]: + assistant_id: str, + ) -> tuple[Dict[str, Any], Optional[Dict[str, str]]]: """Load trusted runtime metadata from backend assistant config.""" - assistant_id = ( - workflow_runtime.get("assistantId") - or client_metadata.get("assistantId") - or client_metadata.get("appId") - or client_metadata.get("app_id") - ) - if assistant_id is None: - return {} + if not assistant_id: + return {}, { + "code": "protocol.assistant_id_required", + "message": "Missing required query parameter assistant_id", + } provider = getattr(self._backend_gateway, "fetch_assistant_config", None) if not callable(provider): - return {} + return {}, { + "code": "assistant.config_unavailable", + "message": "Assistant config backend unavailable", + } payload = await provider(str(assistant_id).strip()) + if isinstance(payload, dict): + error_code = str(payload.get("__error_code") or "").strip() + if error_code == "assistant.not_found": + return {}, { + "code": "assistant.not_found", + "message": f"Assistant not found: {assistant_id}", + } + if error_code == "assistant.config_unavailable": + return {}, { + "code": "assistant.config_unavailable", + "message": f"Assistant config unavailable: {assistant_id}", + } + if not isinstance(payload, dict): - return {} + return {}, { + "code": "assistant.config_unavailable", + "message": f"Assistant config unavailable: {assistant_id}", + } assistant_cfg: Dict[str, Any] = {} 
session_start_cfg = payload.get("sessionStartMetadata") @@ -962,7 +962,10 @@ class Session: assistant_cfg = payload if not isinstance(assistant_cfg, dict): - return {} + return {}, { + "code": "assistant.config_unavailable", + "message": f"Assistant config unavailable: {assistant_id}", + } runtime: Dict[str, Any] = {} passthrough_keys = { @@ -1009,36 +1012,110 @@ class Session: if runtime.get("config_version_id") is not None and runtime.get("configVersionId") is None: runtime["configVersionId"] = runtime.get("config_version_id") - return runtime + return runtime, None - def _sanitize_untrusted_runtime_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]: - """ - Sanitize untrusted metadata sources. + def _find_forbidden_secret_key(self, payload: Any, path: str = "metadata") -> Optional[str]: + if isinstance(payload, dict): + for key, value in payload.items(): + key_str = str(key) + normalized = key_str.lower().replace("_", "").replace("-", "") + if any(token in normalized for token in self._METADATA_FORBIDDEN_KEY_TOKENS): + return f"{path}.{key_str}" + nested = self._find_forbidden_secret_key(value, f"{path}.{key_str}") + if nested: + return nested + return None + if isinstance(payload, list): + for idx, value in enumerate(payload): + nested = self._find_forbidden_secret_key(value, f"{path}[{idx}]") + if nested: + return nested + return None - This keeps only a small override whitelist and stable config ID fields. 
- """ + def _validate_and_sanitize_client_metadata( + self, + metadata: Dict[str, Any], + ) -> tuple[Dict[str, Any], Optional[Dict[str, str]]]: if not isinstance(metadata, dict): - return {} + return {}, { + "code": "protocol.invalid_override", + "message": "metadata must be an object", + } - sanitized: Dict[str, Any] = {} - for key in self._CLIENT_METADATA_ID_KEYS: - if key in metadata: - sanitized[key] = metadata[key] - for key in self._CLIENT_METADATA_OVERRIDES: - if key in metadata: - sanitized[key] = metadata[key] + forbidden_top_level = [key for key in self._METADATA_FORBIDDEN_TOP_LEVEL_KEYS if key in metadata] + if forbidden_top_level: + return {}, { + "code": "protocol.invalid_override", + "message": f"Forbidden metadata keys: {', '.join(sorted(forbidden_top_level))}", + } - return sanitized + unknown_keys = [ + key + for key in metadata.keys() + if key not in self._METADATA_ALLOWED_TOP_LEVEL_KEYS + and key not in self._METADATA_FORBIDDEN_TOP_LEVEL_KEYS + ] + if unknown_keys: + return {}, { + "code": "protocol.invalid_override", + "message": f"Unsupported metadata keys: {', '.join(sorted(unknown_keys))}", + } - def _sanitize_client_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]: - """Apply client metadata whitelist and remove forbidden secrets.""" - sanitized = self._sanitize_untrusted_runtime_metadata(metadata) - if isinstance(metadata.get("services"), dict): - logger.warning( - "Session {} provided metadata.services from client; client-side service config is ignored", - self.id, - ) - return sanitized + if "workflow" in metadata: + logger.warning("Session {} received metadata.workflow; workflow payload is ignored in MVP", self.id) + + overrides_raw = metadata.get("overrides") + overrides: Dict[str, Any] = {} + if overrides_raw is not None: + if not isinstance(overrides_raw, dict): + return {}, { + "code": "protocol.invalid_override", + "message": "metadata.overrides must be an object", + } + unsupported_override_keys = [key for key in 
overrides_raw.keys() if key not in self._METADATA_ALLOWED_OVERRIDE_KEYS] + if unsupported_override_keys: + return {}, { + "code": "protocol.invalid_override", + "message": f"Unsupported metadata.overrides keys: {', '.join(sorted(unsupported_override_keys))}", + } + overrides = dict(overrides_raw) + + dynamic_variables = metadata.get("dynamicVariables") + history_raw = metadata.get("history") + history: Dict[str, Any] = {} + if history_raw is not None: + if not isinstance(history_raw, dict): + return {}, { + "code": "protocol.invalid_override", + "message": "metadata.history must be an object", + } + unsupported_history_keys = [key for key in history_raw.keys() if key != "userId"] + if unsupported_history_keys: + return {}, { + "code": "protocol.invalid_override", + "message": f"Unsupported metadata.history keys: {', '.join(sorted(unsupported_history_keys))}", + } + if "userId" in history_raw: + history["userId"] = history_raw.get("userId") + + sanitized: Dict[str, Any] = {"overrides": overrides} + if dynamic_variables is not None: + sanitized["dynamicVariables"] = dynamic_variables + if "channel" in metadata: + sanitized["channel"] = metadata.get("channel") + if "source" in metadata: + sanitized["source"] = metadata.get("source") + if history: + sanitized["history"] = history + + forbidden_path = self._find_forbidden_secret_key(sanitized) + if forbidden_path: + return {}, { + "code": "protocol.invalid_override", + "message": f"Forbidden secret-like key detected at {forbidden_path}", + } + + return sanitized, None def _build_config_resolved(self, metadata: Dict[str, Any]) -> Dict[str, Any]: """Build public resolved config payload (secrets removed).""" @@ -1047,7 +1124,7 @@ class Session: runtime = self.pipeline.resolved_runtime_config() return { - "appId": metadata.get("appId") or metadata.get("app_id") or metadata.get("assistantId"), + "appId": metadata.get("assistantId"), "channel": metadata.get("channel"), "configVersionId": metadata.get("configVersionId") or 
metadata.get("config_version_id"), "prompt": {"sha256": prompt_hash}, diff --git a/engine/docs/ws_v1_schema.md b/engine/docs/ws_v1_schema.md index 8cf3ebc..eb9db62 100644 --- a/engine/docs/ws_v1_schema.md +++ b/engine/docs/ws_v1_schema.md @@ -5,7 +5,7 @@ This document defines the public WebSocket protocol for the `/ws` endpoint. Validation policy: - WS v1 JSON control messages are validated strictly. - Unknown top-level fields are rejected for all defined client message types. -- `hello.version` is fixed to `"v1"`. +- `assistant_id` query parameter is required on `/ws`. ## Transport @@ -17,29 +17,16 @@ Validation policy: Required message order: -1. Client sends `hello`. -2. Server replies `hello.ack`. -3. Client sends `session.start`. -4. Server replies `session.started`. -5. Client may stream binary audio and/or send `input.text`. -6. Client sends `session.stop` (or closes socket). +1. Client connects to `/ws?assistant_id=`. +2. Client sends `session.start`. +3. Server replies `session.started`. +4. Client may stream binary audio and/or send `input.text`. +5. Client sends `session.stop` (or closes socket). If order is violated, server emits `error` with `code = "protocol.order"`. ## Client -> Server Messages -### `hello` - -```json -{ - "type": "hello", - "version": "v1" -} -``` - -Rules: -- `version` must be `v1`. - ### `session.start` ```json @@ -51,15 +38,18 @@ Rules: "channels": 1 }, "metadata": { - "appId": "assistant_123", "channel": "web", - "configVersionId": "cfg_20260217_01", - "client": "web-debug", - "output": { - "mode": "audio" + "source": "web-debug", + "history": { + "userId": 1 + }, + "overrides": { + "output": { + "mode": "audio" + }, + "systemPrompt": "You are concise.", + "greeting": "Hi, how can I help?" }, - "systemPrompt": "You are concise.", - "greeting": "Hi, how can I help?", "dynamicVariables": { "customer_name": "Alice", "plan_tier": "Pro" @@ -69,9 +59,13 @@ Rules: ``` Rules: -- Client-side `metadata.services` is ignored. 
-- Service config (including secrets) is resolved server-side (env/backend). -- Client should pass stable IDs (`appId`, `channel`, `configVersionId`) plus small runtime overrides (e.g. `output`, `bargeIn`, greeting/prompt style hints). +- Assistant config is resolved strictly by URL query `assistant_id`. +- `metadata` top-level keys allowed: `overrides`, `dynamicVariables`, `channel`, `source`, `history`, `workflow` (`workflow` is ignored). +- `metadata.overrides` whitelist: `systemPrompt`, `greeting`, `firstTurnMode`, `generatedOpenerEnabled`, `output`, `bargeIn`, `knowledgeBaseId`, `knowledge`, `tools`, `openerAudio`. +- `metadata.services` is rejected with `protocol.invalid_override`. +- `metadata.workflow` is ignored in this MVP protocol version. +- Top-level IDs are forbidden in payload (`assistantId`, `appId`, `app_id`, `configVersionId`, `config_version_id`). +- Secret-like keys are forbidden in metadata (`apiKey`, `token`, `secret`, `password`, `authorization`). - `metadata.dynamicVariables` is optional and must be an object of string key/value pairs. - Key pattern: `^[a-zA-Z_][a-zA-Z0-9_]{0,63}$` - Max entries: 30 @@ -85,7 +79,7 @@ Rules: - Invalid `dynamicVariables` payload rejects `session.start` with `protocol.dynamic_variables_invalid`. Text-only mode: -- Set `metadata.output.mode = "text"`. +- Set `metadata.overrides.output.mode = "text"`. - In this mode server still sends `assistant.response.delta/final`, but will not emit audio frames or `output.audio.start/end`. 
### `input.text` @@ -158,8 +152,6 @@ Envelope notes: Common events: -- `hello.ack` - - Fields: `sessionId`, `version` - `session.started` - Fields: `sessionId`, `trackId`, `tracks`, `audio` - `config.resolved` @@ -204,7 +196,7 @@ Common events: Track IDs (MVP fixed values): - `audio_in`: ASR/VAD input-side events (`input.*`, `transcript.*`) - `audio_out`: assistant output-side events (`assistant.*`, `output.audio.*`, `response.interrupted`, `metrics.ttfb`) -- `control`: session/control events (`session.*`, `hello.*`, `error`, `config.resolved`) +- `control`: session/control events (`session.*`, `error`, `config.resolved`) Correlation IDs (`event.data`): - `turn_id`: one user-assistant interaction turn. diff --git a/engine/examples/mic_client.py b/engine/examples/mic_client.py index 00d403f..d734c95 100644 --- a/engine/examples/mic_client.py +++ b/engine/examples/mic_client.py @@ -23,6 +23,7 @@ import time import threading import queue from pathlib import Path +from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit try: import numpy as np @@ -59,9 +60,8 @@ class MicrophoneClient: url: str, sample_rate: int = 16000, chunk_duration_ms: int = 20, - app_id: str = "assistant_demo", + assistant_id: str = "assistant_demo", channel: str = "mic_client", - config_version_id: str = "local-dev", input_device: int = None, output_device: int = None, track_debug: bool = False, @@ -80,9 +80,8 @@ class MicrophoneClient: self.sample_rate = sample_rate self.chunk_duration_ms = chunk_duration_ms self.chunk_samples = int(sample_rate * chunk_duration_ms / 1000) - self.app_id = app_id + self.assistant_id = assistant_id self.channel = channel - self.config_version_id = config_version_id self.input_device = input_device self.output_device = output_device self.track_debug = track_debug @@ -125,19 +124,21 @@ class MicrophoneClient: if value: parts.append(f"{key}={value}") return f" [{' '.join(parts)}]" if parts else "" + + def _session_url(self) -> str: + parts = 
urlsplit(self.url) + query = dict(parse_qsl(parts.query, keep_blank_values=True)) + query["assistant_id"] = self.assistant_id + return urlunsplit((parts.scheme, parts.netloc, parts.path, urlencode(query), parts.fragment)) async def connect(self) -> None: """Connect to WebSocket server.""" - print(f"Connecting to {self.url}...") - self.ws = await websockets.connect(self.url) + session_url = self._session_url() + print(f"Connecting to {session_url}...") + self.ws = await websockets.connect(session_url) self.running = True print("Connected!") - - # WS v1 handshake: hello -> session.start - await self.send_command({ - "type": "hello", - "version": "v1", - }) + await self.send_command({ "type": "session.start", "audio": { @@ -146,9 +147,8 @@ class MicrophoneClient: "channels": 1, }, "metadata": { - "appId": self.app_id, "channel": self.channel, - "configVersionId": self.config_version_id, + "source": "mic_client", }, }) @@ -330,7 +330,7 @@ class MicrophoneClient: if self.track_debug: print(f"[track-debug] event={event_type} trackId={event.get('trackId')}{ids}") - if event_type in {"hello.ack", "session.started"}: + if event_type == "session.started": print(f"← Session ready!{ids}") elif event_type == "config.resolved": print(f"← Config resolved: {event.get('config', {}).get('output', {})}{ids}") @@ -609,20 +609,15 @@ async def main(): help="Show streaming LLM response chunks" ) parser.add_argument( - "--app-id", + "--assistant-id", default="assistant_demo", - help="Stable app/assistant identifier for server-side config lookup" + help="Assistant identifier used in websocket query parameter" ) parser.add_argument( "--channel", default="mic_client", help="Client channel name" ) - parser.add_argument( - "--config-version-id", - default="local-dev", - help="Optional config version identifier" - ) parser.add_argument( "--track-debug", action="store_true", @@ -638,9 +633,8 @@ async def main(): client = MicrophoneClient( url=args.url, sample_rate=args.sample_rate, - 
app_id=args.app_id, + assistant_id=args.assistant_id, channel=args.channel, - config_version_id=args.config_version_id, input_device=args.input_device, output_device=args.output_device, track_debug=args.track_debug, diff --git a/engine/examples/simple_client.py b/engine/examples/simple_client.py index b1648bf..8a33fec 100644 --- a/engine/examples/simple_client.py +++ b/engine/examples/simple_client.py @@ -15,6 +15,7 @@ import sys import time import wave import io +from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit try: import numpy as np @@ -56,16 +57,14 @@ class SimpleVoiceClient: self, url: str, sample_rate: int = 16000, - app_id: str = "assistant_demo", + assistant_id: str = "assistant_demo", channel: str = "simple_client", - config_version_id: str = "local-dev", track_debug: bool = False, ): self.url = url self.sample_rate = sample_rate - self.app_id = app_id + self.assistant_id = assistant_id self.channel = channel - self.config_version_id = config_version_id self.track_debug = track_debug self.ws = None self.running = False @@ -88,6 +87,12 @@ class SimpleVoiceClient: # Interrupt handling - discard audio until next trackStart self._discard_audio = False + def _session_url(self) -> str: + parts = urlsplit(self.url) + query = dict(parse_qsl(parts.query, keep_blank_values=True)) + query["assistant_id"] = self.assistant_id + return urlunsplit((parts.scheme, parts.netloc, parts.path, urlencode(query), parts.fragment)) + @staticmethod def _event_ids_suffix(event: dict) -> str: data = event.get("data") if isinstance(event.get("data"), dict) else {} @@ -101,16 +106,12 @@ class SimpleVoiceClient: async def connect(self): """Connect to server.""" - print(f"Connecting to {self.url}...") - self.ws = await websockets.connect(self.url) + session_url = self._session_url() + print(f"Connecting to {session_url}...") + self.ws = await websockets.connect(session_url) self.running = True print("Connected!") - - # WS v1 handshake: hello -> session.start - await 
self.ws.send(json.dumps({ - "type": "hello", - "version": "v1", - })) + await self.ws.send(json.dumps({ "type": "session.start", "audio": { @@ -119,12 +120,11 @@ class SimpleVoiceClient: "channels": 1, }, "metadata": { - "appId": self.app_id, "channel": self.channel, - "configVersionId": self.config_version_id, + "source": "simple_client", }, })) - print("-> hello/session.start") + print("-> session.start") async def send_chat(self, text: str): """Send chat message.""" @@ -311,9 +311,8 @@ async def main(): parser.add_argument("--text", help="Send text and play response") parser.add_argument("--list-devices", action="store_true") parser.add_argument("--sample-rate", type=int, default=16000) - parser.add_argument("--app-id", default="assistant_demo") + parser.add_argument("--assistant-id", default="assistant_demo") parser.add_argument("--channel", default="simple_client") - parser.add_argument("--config-version-id", default="local-dev") parser.add_argument("--track-debug", action="store_true") args = parser.parse_args() @@ -325,9 +324,8 @@ async def main(): client = SimpleVoiceClient( args.url, args.sample_rate, - app_id=args.app_id, + assistant_id=args.assistant_id, channel=args.channel, - config_version_id=args.config_version_id, track_debug=args.track_debug, ) await client.run(args.text) diff --git a/engine/examples/test_websocket.py b/engine/examples/test_websocket.py index 6717834..30f7aee 100644 --- a/engine/examples/test_websocket.py +++ b/engine/examples/test_websocket.py @@ -12,6 +12,7 @@ import math import argparse import os from datetime import datetime +from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit # Configuration SERVER_URL = "ws://localhost:8000/ws" @@ -130,14 +131,17 @@ async def run_client(url, file_path=None, use_sine=False, track_debug: bool = Fa """Run the WebSocket test client.""" session = aiohttp.ClientSession() try: - print(f"🔌 Connecting to {url}...") - async with session.ws_connect(url) as ws: + parts = urlsplit(url) + 
query = dict(parse_qsl(parts.query, keep_blank_values=True)) + query["assistant_id"] = str(query.get("assistant_id") or "assistant_demo") + session_url = urlunsplit((parts.scheme, parts.netloc, parts.path, urlencode(query), parts.fragment)) + print(f"🔌 Connecting to {session_url}...") + async with session.ws_connect(session_url) as ws: print("✅ Connected!") session_ready = asyncio.Event() recv_task = asyncio.create_task(receive_loop(ws, session_ready, track_debug=track_debug)) - # Send v1 hello + session.start handshake - await ws.send_json({"type": "hello", "version": "v1"}) + # Send v1 session.start initialization await ws.send_json({ "type": "session.start", "audio": { @@ -146,12 +150,11 @@ async def run_client(url, file_path=None, use_sine=False, track_debug: bool = Fa "channels": 1 }, "metadata": { - "appId": "assistant_demo", "channel": "test_websocket", - "configVersionId": "local-dev", + "source": "test_websocket", }, }) - print("📤 Sent v1 hello/session.start") + print("📤 Sent v1 session.start") await asyncio.wait_for(session_ready.wait(), timeout=8) # Select sender based on args diff --git a/engine/examples/wav_client.py b/engine/examples/wav_client.py index 5684256..1e4a50d 100644 --- a/engine/examples/wav_client.py +++ b/engine/examples/wav_client.py @@ -21,6 +21,7 @@ import sys import time import wave from pathlib import Path +from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit try: import numpy as np @@ -57,9 +58,8 @@ class WavFileClient: url: str, input_file: str, output_file: str, - app_id: str = "assistant_demo", + assistant_id: str = "assistant_demo", channel: str = "wav_client", - config_version_id: str = "local-dev", sample_rate: int = 16000, chunk_duration_ms: int = 20, wait_time: float = 15.0, @@ -82,9 +82,8 @@ class WavFileClient: self.url = url self.input_file = Path(input_file) self.output_file = Path(output_file) - self.app_id = app_id + self.assistant_id = assistant_id self.channel = channel - self.config_version_id = 
config_version_id self.sample_rate = sample_rate self.chunk_duration_ms = chunk_duration_ms self.chunk_samples = int(sample_rate * chunk_duration_ms / 1000) @@ -147,19 +146,21 @@ class WavFileClient: if value: parts.append(f"{key}={value}") return f" [{' '.join(parts)}]" if parts else "" + + def _session_url(self) -> str: + parts = urlsplit(self.url) + query = dict(parse_qsl(parts.query, keep_blank_values=True)) + query["assistant_id"] = self.assistant_id + return urlunsplit((parts.scheme, parts.netloc, parts.path, urlencode(query), parts.fragment)) async def connect(self) -> None: """Connect to WebSocket server.""" - self.log_event("→", f"Connecting to {self.url}...") - self.ws = await websockets.connect(self.url) + session_url = self._session_url() + self.log_event("→", f"Connecting to {session_url}...") + self.ws = await websockets.connect(session_url) self.running = True self.log_event("←", "Connected!") - # WS v1 handshake: hello -> session.start - await self.send_command({ - "type": "hello", - "version": "v1", - }) await self.send_command({ "type": "session.start", "audio": { @@ -168,9 +169,8 @@ class WavFileClient: "channels": 1 }, "metadata": { - "appId": self.app_id, "channel": self.channel, - "configVersionId": self.config_version_id, + "source": "wav_client", }, }) @@ -329,9 +329,7 @@ class WavFileClient: if self.track_debug: print(f"[track-debug] event={event_type} trackId={event.get('trackId')}{ids}") - if event_type == "hello.ack": - self.log_event("←", f"Handshake acknowledged{ids}") - elif event_type == "session.started": + if event_type == "session.started": self.session_ready = True self.log_event("←", f"Session ready!{ids}") elif event_type == "config.resolved": @@ -521,20 +519,15 @@ async def main(): help="Target sample rate for audio (default: 16000)" ) parser.add_argument( - "--app-id", + "--assistant-id", default="assistant_demo", - help="Stable app/assistant identifier for server-side config lookup" + help="Assistant identifier used in 
websocket query parameter" ) parser.add_argument( "--channel", default="wav_client", help="Client channel name" ) - parser.add_argument( - "--config-version-id", - default="local-dev", - help="Optional config version identifier" - ) parser.add_argument( "--chunk-duration", type=int, @@ -570,9 +563,8 @@ async def main(): url=args.url, input_file=args.input, output_file=args.output, - app_id=args.app_id, + assistant_id=args.assistant_id, channel=args.channel, - config_version_id=args.config_version_id, sample_rate=args.sample_rate, chunk_duration_ms=args.chunk_duration, wait_time=args.wait_time, diff --git a/engine/examples/web_client.html b/engine/examples/web_client.html index 3431c02..10a7556 100644 --- a/engine/examples/web_client.html +++ b/engine/examples/web_client.html @@ -401,9 +401,14 @@ const targetSampleRate = 16000; const playbackStopRampSec = 0.008; - const appId = "assistant_demo"; + const assistantId = "assistant_demo"; const channel = "web_client"; - const configVersionId = "local-dev"; + + function buildSessionWsUrl(baseUrl) { + const parsed = new URL(baseUrl); + parsed.searchParams.set("assistant_id", assistantId); + return parsed.toString(); + } function logLine(type, text, data) { const time = new Date().toLocaleTimeString(); @@ -556,14 +561,25 @@ async function connect() { if (ws && ws.readyState === WebSocket.OPEN) return; - ws = new WebSocket(wsUrl.value.trim()); + const sessionWsUrl = buildSessionWsUrl(wsUrl.value.trim()); + ws = new WebSocket(sessionWsUrl); ws.binaryType = "arraybuffer"; ws.onopen = () => { setStatus(true, "Session open"); logLine("sys", "WebSocket connected"); ensureAudioContext(); - sendCommand({ type: "hello", version: "v1" }); + sendCommand({ + type: "session.start", + audio: { encoding: "pcm_s16le", sample_rate_hz: targetSampleRate, channels: 1 }, + metadata: { + channel, + source: "web_client", + overrides: { + output: { mode: "audio" }, + }, + }, + }); }; ws.onclose = () => { @@ -622,17 +638,6 @@ const type = 
event.type || "unknown"; const ids = eventIdsSuffix(event); logLine("event", `${type}${ids}`, event); - if (type === "hello.ack") { - sendCommand({ - type: "session.start", - audio: { encoding: "pcm_s16le", sample_rate_hz: targetSampleRate, channels: 1 }, - metadata: { - appId, - channel, - configVersionId, - }, - }); - } if (type === "config.resolved") { logLine("sys", "config.resolved", event.config || {}); } diff --git a/engine/models/ws_v1.py b/engine/models/ws_v1.py index 7987ff6..39bd61e 100644 --- a/engine/models/ws_v1.py +++ b/engine/models/ws_v1.py @@ -18,12 +18,6 @@ class _StrictModel(BaseModel): model_config = ConfigDict(extra="forbid") -# Client -> Server messages -class HelloMessage(_StrictModel): - type: Literal["hello"] - version: Literal["v1"] = Field(..., description="Protocol version, fixed to v1") - - class SessionStartAudio(_StrictModel): encoding: Literal["pcm_s16le"] = "pcm_s16le" sample_rate_hz: Literal[16000] = 16000 @@ -69,7 +63,6 @@ class ToolCallResultsMessage(_StrictModel): CLIENT_MESSAGE_TYPES = { - "hello": HelloMessage, "session.start": SessionStartMessage, "session.stop": SessionStopMessage, "input.text": InputTextMessage, diff --git a/engine/tests/test_tool_call_flow.py b/engine/tests/test_tool_call_flow.py index 236e86a..ac60de9 100644 --- a/engine/tests/test_tool_call_flow.py +++ b/engine/tests/test_tool_call_flow.py @@ -86,7 +86,7 @@ def _build_pipeline(monkeypatch, llm_rounds: List[List[LLMStreamEvent]]) -> tupl async def _capture_event(event: Dict[str, Any], priority: int = 20): events.append(event) - async def _noop_speak(_text: str, fade_in_ms: int = 0, fade_out_ms: int = 8): + async def _noop_speak(_text: str, *args, **kwargs): return None monkeypatch.setattr(pipeline, "_send_event", _capture_event) diff --git a/engine/tests/test_ws_protocol_session_start.py b/engine/tests/test_ws_protocol_session_start.py new file mode 100644 index 0000000..07ee762 --- /dev/null +++ b/engine/tests/test_ws_protocol_session_start.py @@ -0,0 
+1,238 @@ +import pytest + +from core.session import Session, WsSessionState +from models.ws_v1 import SessionStartMessage, parse_client_message + + +def _session() -> Session: + session = Session.__new__(Session) + session.id = "sess_test" + session._assistant_id = "assistant_demo" + return session + + +def test_parse_client_message_rejects_hello_message(): + with pytest.raises(ValueError, match="Unknown client message type: hello"): + parse_client_message({"type": "hello", "version": "v1"}) + + +@pytest.mark.asyncio +async def test_handle_text_reports_invalid_message_for_hello(): + session = Session.__new__(Session) + session.id = "sess_invalid_hello" + session.ws_state = WsSessionState.WAIT_START + + class _Transport: + async def close(self): + return None + + session.transport = _Transport() + captured = [] + + async def _send_error(sender, message, code, **kwargs): + captured.append((sender, code, message, kwargs)) + + session._send_error = _send_error + + await session.handle_text('{"type":"hello","version":"v1"}') + assert captured + sender, code, message, _ = captured[0] + assert sender == "client" + assert code == "protocol.invalid_message" + assert "Unknown client message type: hello" in message + + +def test_validate_metadata_rejects_services_payload(): + session = _session() + sanitized, error = session._validate_and_sanitize_client_metadata({"services": {"llm": {"provider": "openai"}}}) + assert sanitized == {} + assert error is not None + assert error["code"] == "protocol.invalid_override" + + +def test_validate_metadata_rejects_secret_like_override_keys(): + session = _session() + sanitized, error = session._validate_and_sanitize_client_metadata( + { + "overrides": { + "output": {"mode": "audio"}, + "apiKey": "xxx", + } + } + ) + assert sanitized == {} + assert error is not None + assert error["code"] == "protocol.invalid_override" + + +def test_validate_metadata_ignores_workflow_payload(): + session = _session() + sanitized, error = 
session._validate_and_sanitize_client_metadata( + { + "workflow": {"nodes": [], "edges": []}, + "channel": "web_debug", + "overrides": {"output": {"mode": "text"}}, + } + ) + assert error is None + assert "workflow" not in sanitized + assert sanitized["channel"] == "web_debug" + assert sanitized["overrides"]["output"]["mode"] == "text" + + +@pytest.mark.asyncio +async def test_load_server_runtime_metadata_returns_not_found_error(): + session = _session() + + class _Gateway: + async def fetch_assistant_config(self, assistant_id: str): + _ = assistant_id + return {"__error_code": "assistant.not_found"} + + session._backend_gateway = _Gateway() + runtime, error = await session._load_server_runtime_metadata("assistant_demo") + assert runtime == {} + assert error is not None + assert error["code"] == "assistant.not_found" + + +@pytest.mark.asyncio +async def test_load_server_runtime_metadata_returns_config_unavailable_error(): + session = _session() + + class _Gateway: + async def fetch_assistant_config(self, assistant_id: str): + _ = assistant_id + return None + + session._backend_gateway = _Gateway() + runtime, error = await session._load_server_runtime_metadata("assistant_demo") + assert runtime == {} + assert error is not None + assert error["code"] == "assistant.config_unavailable" + + +@pytest.mark.asyncio +async def test_handle_session_start_requires_assistant_id_and_closes_transport(): + session = Session.__new__(Session) + session.id = "sess_missing_assistant" + session.ws_state = WsSessionState.WAIT_START + session._assistant_id = None + + class _Transport: + def __init__(self): + self.closed = False + + async def close(self): + self.closed = True + + transport = _Transport() + session.transport = transport + captured_codes = [] + + async def _send_error(sender, message, code, **kwargs): + _ = (sender, message, kwargs) + captured_codes.append(code) + + session._send_error = _send_error + + await 
session._handle_session_start(SessionStartMessage(type="session.start", metadata={})) + assert captured_codes == ["protocol.assistant_id_required"] + assert transport.closed is True + assert session.ws_state == WsSessionState.STOPPED + + +@pytest.mark.asyncio +async def test_handle_session_start_applies_whitelisted_overrides_and_ignores_workflow(): + session = Session.__new__(Session) + session.id = "sess_start_ok" + session.ws_state = WsSessionState.WAIT_START + session.state = "created" + session._assistant_id = "assistant_demo" + session.current_track_id = Session.TRACK_CONTROL + session._pipeline_started = False + + class _Transport: + async def close(self): + return None + + class _Pipeline: + def __init__(self): + self.started = False + self.applied = {} + self.conversation = type("Conversation", (), {"system_prompt": ""})() + + async def start(self): + self.started = True + + async def emit_initial_greeting(self): + return None + + def apply_runtime_overrides(self, metadata): + self.applied = dict(metadata) + + def resolved_runtime_config(self): + return { + "output": {"mode": "text"}, + "services": {"llm": {"provider": "openai", "model": "gpt-4o-mini"}}, + "tools": {"allowlist": []}, + } + + session.transport = _Transport() + session.pipeline = _Pipeline() + events = [] + + async def _start_history_bridge(_metadata): + return None + + async def _load_server_runtime_metadata(_assistant_id): + return ( + { + "assistantId": "assistant_demo", + "configVersionId": "cfg_1", + "systemPrompt": "Base prompt", + "greeting": "Base greeting", + "output": {"mode": "audio"}, + }, + None, + ) + + async def _send_event(event): + events.append(event) + + async def _send_error(sender, message, code, **kwargs): + raise AssertionError(f"Unexpected error: sender={sender} code={code} message={message} kwargs={kwargs}") + + session._start_history_bridge = _start_history_bridge + session._load_server_runtime_metadata = _load_server_runtime_metadata + session._send_event = 
_send_event
+    session._send_error = _send_error
+
+    await session._handle_session_start(
+        SessionStartMessage(
+            type="session.start",
+            metadata={
+                "workflow": {"nodes": []},
+                "channel": "web_debug",
+                "source": "debug_ui",
+                "history": {"userId": 7},
+                "overrides": {
+                    "greeting": "Override greeting",
+                    "output": {"mode": "text"},
+                    "tools": [{"name": "calculator"}],
+                },
+            },
+        )
+    )
+
+    assert session.ws_state == WsSessionState.ACTIVE
+    assert session.pipeline.started is True
+    assert session.pipeline.applied["assistantId"] == "assistant_demo"
+    assert session.pipeline.applied["greeting"] == "Override greeting"
+    assert session.pipeline.applied["output"]["mode"] == "text"
+    assert session.pipeline.applied["tools"] == [{"name": "calculator"}]
+    assert not any(str(item.get("type", "")).startswith("workflow.") for item in events)
+
+    config_event = next(item for item in events if item.get("type") == "config.resolved")
+    assert config_event["config"]["appId"] == "assistant_demo"
+    assert config_event["config"]["channel"] == "web_debug"
diff --git a/web/pages/Assistants.tsx b/web/pages/Assistants.tsx
index 6ec63c4..adf0dc4 100644
--- a/web/pages/Assistants.tsx
+++ b/web/pages/Assistants.tsx
@@ -2842,6 +2842,61 @@ export const DebugDrawer: React.FC<{
     return error;
   };
 
+  const METADATA_OVERRIDE_WHITELIST = new Set([
+    'firstTurnMode',
+    'greeting',
+    'generatedOpenerEnabled',
+    'systemPrompt',
+    'output',
+    'bargeIn',
+    'knowledge',
+    'knowledgeBaseId',
+    'openerAudio',
+    'tools',
+  ]);
+  const METADATA_FORBIDDEN_SECRET_TOKENS = ['apikey', 'token', 'secret', 'password', 'authorization'];
+  const isPlainObject = (value: unknown): value is Record<string, unknown> => Boolean(value) && typeof value === 'object' && !Array.isArray(value);
+  const isForbiddenSecretKey = (key: string): boolean => {
+    const normalized = key.toLowerCase().replace(/[_-]/g, '');
+    return METADATA_FORBIDDEN_SECRET_TOKENS.some((token) => normalized.includes(token));
+  };
+  const stripForbiddenSecretKeysDeep = 
(value: any): any => {
+    if (Array.isArray(value)) return value.map(stripForbiddenSecretKeysDeep);
+    if (!isPlainObject(value)) return value;
+    return Object.entries(value).reduce<Record<string, unknown>>((acc, [key, nested]) => {
+      if (isForbiddenSecretKey(key)) return acc;
+      acc[key] = stripForbiddenSecretKeysDeep(nested);
+      return acc;
+    }, {});
+  };
+  const sanitizeMetadataForWs = (raw: unknown): Record<string, unknown> => {
+    if (!isPlainObject(raw)) return { overrides: {} };
+    const sanitized: Record<string, unknown> = { overrides: {} };
+
+    if (typeof raw.channel === 'string' && raw.channel.trim()) {
+      sanitized.channel = raw.channel.trim();
+    }
+    if (typeof raw.source === 'string' && raw.source.trim()) {
+      sanitized.source = raw.source.trim();
+    }
+    if (isPlainObject(raw.history) && raw.history.userId !== undefined) {
+      sanitized.history = { userId: raw.history.userId };
+    }
+    if (isPlainObject(raw.dynamicVariables)) {
+      sanitized.dynamicVariables = raw.dynamicVariables;
+    }
+    if (isPlainObject(raw.overrides)) {
+      const overrides = Object.entries(raw.overrides).reduce<Record<string, unknown>>((acc, [key, value]) => {
+        if (!METADATA_OVERRIDE_WHITELIST.has(key)) return acc;
+        if (isForbiddenSecretKey(key)) return acc;
+        acc[key] = stripForbiddenSecretKeysDeep(value);
+        return acc;
+      }, {});
+      sanitized.overrides = overrides;
+    }
+    return sanitized;
+  };
+
   const buildDynamicVariablesPayload = (): { variables: Record<string, string>; error?: string } => {
     const variables: Record<string, string> = {};
     const nonEmptyRows = dynamicVariables
@@ -2908,104 +2963,35 @@ export const DebugDrawer: React.FC<{
   const buildLocalResolvedRuntime = () => {
     const warnings: string[] = [];
-    const services: Record<string, unknown> = {};
     const ttsEnabled = Boolean(textTtsEnabled);
-    const isExternalLlm = assistant.configMode === 'dify' || assistant.configMode === 'fastgpt';
     const knowledgeBaseId = String(assistant.knowledgeBaseId || '').trim();
     const knowledge = knowledgeBaseId ? 
{ enabled: true, kbId: knowledgeBaseId, nResults: 5 } : { enabled: false }; - if (isExternalLlm) { - services.llm = { - provider: 'openai', - model: '', - apiKey: assistant.apiKey || '', - baseUrl: assistant.apiUrl || '', - }; - if (!assistant.apiUrl) warnings.push(`External LLM API URL is empty for mode: ${assistant.configMode}`); - if (!assistant.apiKey) warnings.push(`External LLM API key is empty for mode: ${assistant.configMode}`); - } else if (assistant.llmModelId) { - const llm = llmModels.find((item) => item.id === assistant.llmModelId); - if (llm) { - services.llm = { - provider: 'openai', - model: llm.modelName || llm.name, - apiKey: llm.apiKey, - baseUrl: llm.baseUrl, - }; - } else { - warnings.push(`LLM model not found in loaded list: ${assistant.llmModelId}`); - } - } else { - // Keep empty object to indicate engine should use default provider model. - services.llm = {}; - } - - if (assistant.asrModelId) { - const asr = asrModels.find((item) => item.id === assistant.asrModelId); - if (asr) { - const asrProvider = isOpenAICompatibleVendor(asr.vendor) ? 'openai_compatible' : 'buffered'; - services.asr = { - provider: asrProvider, - model: asr.modelName || asr.name, - apiKey: asrProvider === 'openai_compatible' ? asr.apiKey : null, - }; - } else { - warnings.push(`ASR model not found in loaded list: ${assistant.asrModelId}`); - } - } - - if (assistant.voice) { - const voice = voices.find((item) => item.id === assistant.voice); - if (voice) { - const ttsProvider = isOpenAICompatibleVendor(voice.vendor) ? 'openai_compatible' : 'edge'; - services.tts = { - enabled: ttsEnabled, - provider: ttsProvider, - model: voice.model, - apiKey: ttsProvider === 'openai_compatible' ? 
voice.apiKey : null, - voice: resolveRuntimeTtsVoice(assistant.voice, voice), - speed: assistant.speed || voice.speed || 1.0, - }; - } else { - services.tts = { - enabled: ttsEnabled, - voice: assistant.voice, - speed: assistant.speed || 1.0, - }; - warnings.push(`Voice resource not found in loaded list: ${assistant.voice}`); - } - } else if (!ttsEnabled) { - services.tts = { - enabled: false, - }; - } - const localResolved = { - assistantId: assistant.id, warnings, sessionStartMetadata: { - output: { - mode: ttsEnabled ? 'audio' : 'text', + overrides: { + output: { + mode: ttsEnabled ? 'audio' : 'text', + }, + systemPrompt: assistant.prompt || '', + firstTurnMode: assistant.firstTurnMode || 'bot_first', + greeting: assistant.opener || '', + generatedOpenerEnabled: assistant.generatedOpenerEnabled === true, + bargeIn: { + enabled: assistant.botCannotBeInterrupted !== true, + minDurationMs: Math.max(0, Number(assistant.interruptionSensitivity ?? 180)), + }, + knowledgeBaseId, + knowledge, + tools: selectedToolSchemas, }, - systemPrompt: assistant.prompt || '', - firstTurnMode: assistant.firstTurnMode || 'bot_first', - greeting: assistant.opener || '', - generatedOpenerEnabled: assistant.generatedOpenerEnabled === true, - bargeIn: { - enabled: assistant.botCannotBeInterrupted !== true, - minDurationMs: Math.max(0, Number(assistant.interruptionSensitivity ?? 
180)),
-          },
-          knowledgeBaseId,
-          knowledge,
-          tools: selectedToolSchemas,
-          services,
           history: {
-            assistantId: assistant.id,
             userId: 1,
-            source: 'debug',
           },
+          source: 'web_debug',
         },
       };
 
@@ -3020,21 +3006,13 @@ export const DebugDrawer: React.FC<{
       }
       setDynamicVariablesError('');
       const localResolved = buildLocalResolvedRuntime();
-      const mergedMetadata: Record<string, unknown> = {
+      const mergedMetadata: Record<string, unknown> = sanitizeMetadataForWs({
         ...localResolved.sessionStartMetadata,
         ...(sessionMetadataExtras || {}),
-      };
+      });
       if (Object.keys(dynamicVariablesResult.variables).length > 0) {
         mergedMetadata.dynamicVariables = dynamicVariablesResult.variables;
       }
-      // Engine resolves trusted runtime config by top-level assistant/app ID.
-      // Keep these IDs at metadata root so backend /assistants/{id}/config is reachable.
-      if (!mergedMetadata.assistantId && assistant.id) {
-        mergedMetadata.assistantId = assistant.id;
-      }
-      if (!mergedMetadata.appId && assistant.id) {
-        mergedMetadata.appId = assistant.id;
-      }
       if (!mergedMetadata.channel) {
         mergedMetadata.channel = 'web_debug';
       }
@@ -3069,6 +3047,24 @@ export const DebugDrawer: React.FC<{
     if (isOpen) setWsStatus('disconnected');
   };
 
+  const buildSessionWsUrl = () => {
+    const base = wsUrl.trim();
+    if (!base) return '';
+    try {
+      const parsed = new URL(base);
+      parsed.searchParams.set('assistant_id', assistant.id);
+      return parsed.toString();
+    } catch {
+      try {
+        const parsed = new URL(base, window.location.href);
+        parsed.searchParams.set('assistant_id', assistant.id);
+        return parsed.toString();
+      } catch {
+        return base;
+      }
+    }
+  };
+
   const ensureWsSession = async () => {
     if (wsRef.current && wsReadyRef.current && wsRef.current.readyState === WebSocket.OPEN) {
       return;
     }
@@ -3083,18 +3079,25 @@
     }
 
     const metadata = await fetchRuntimeMetadata();
+    const sessionWsUrl = buildSessionWsUrl();
     setWsStatus('connecting');
     setWsError('');
     await new Promise((resolve, reject) => {
       pendingResolveRef.current = resolve; 
pendingRejectRef.current = reject; - const ws = new WebSocket(wsUrl); + const ws = new WebSocket(sessionWsUrl); ws.binaryType = 'arraybuffer'; wsRef.current = ws; ws.onopen = () => { - ws.send(JSON.stringify({ type: 'hello', version: 'v1' })); + ws.send( + JSON.stringify({ + type: 'session.start', + audio: { encoding: 'pcm_s16le', sample_rate_hz: 16000, channels: 1 }, + metadata, + }) + ); }; ws.onmessage = (event) => { @@ -3118,16 +3121,6 @@ export const DebugDrawer: React.FC<{ if (onProtocolEvent) { onProtocolEvent(payload); } - if (type === 'hello.ack') { - ws.send( - JSON.stringify({ - type: 'session.start', - audio: { encoding: 'pcm_s16le', sample_rate_hz: 16000, channels: 1 }, - metadata, - }) - ); - return; - } if (type === 'output.audio.start') { // New utterance audio starts: cancel old queued/playing audio to avoid overlap. stopPlaybackImmediately();