Enhance WebSocket session management by requiring assistant_id as a query parameter for connection. Update API reference documentation to reflect changes in message flow and metadata validation rules, including the introduction of whitelists for allowed metadata fields and restrictions on sensitive keys. Refactor client examples to align with the new session initiation process.

Xin Wang
2026-03-01 14:10:38 +08:00
parent b4fa664d73
commit 6a46ec69f4
14 changed files with 725 additions and 424 deletions


@@ -11,9 +11,11 @@ The WebSocket endpoint provides bidirectional real-time voice conversation and supports audio stream input and out
### Connection URL ### Connection URL
``` ```
ws://<host>/ws ws://<host>/ws?assistant_id=<assistant_id>
``` ```
- `assistant_id` is a required query parameter, used to load the assistant's runtime configuration from the database.
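
A minimal sketch of how a client might attach the required `assistant_id` to an existing WebSocket URL, using the same `urllib.parse` helpers that the client example in this commit imports. The helper name `with_assistant_id` and the host `localhost:8000` are assumptions for illustration only.

```python
from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit


def with_assistant_id(url: str, assistant_id: str) -> str:
    """Return `url` with `assistant_id` set as a query parameter (hypothetical helper)."""
    parts = urlsplit(url)
    # Keep existing query parameters, but drop any stale assistant_id value.
    query = [
        (key, value)
        for key, value in parse_qsl(parts.query, keep_blank_values=True)
        if key != "assistant_id"
    ]
    query.append(("assistant_id", assistant_id))
    return urlunsplit(
        (parts.scheme, parts.netloc, parts.path, urlencode(query), parts.fragment)
    )


# The host and port below are placeholders, not values from the source.
print(with_assistant_id("ws://localhost:8000/ws", "assistant_demo"))
```

Replacing an existing `assistant_id` rather than appending a second one avoids ambiguity about which value the server reads.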
### Transport Rules ### Transport Rules
- **Text frames**: JSON control messages - **Text frames**: JSON control messages
@@ -25,8 +27,6 @@ ws://<host>/ws
### Message Flow ### Message Flow
``` ```
Client -> hello
Server <- hello.ack
Client -> session.start Client -> session.start
Server <- session.started Server <- session.started
Server <- config.resolved Server <- config.resolved
@@ -44,27 +44,9 @@ Server <- session.stopped
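### Client -> Server Messages ### Client -> Server Messages
#### 1. Handshake: `hello` #### 1. Session Start: `session.start`
The first message the client sends after connecting, used for protocol version negotiation. The first message the client sends after connecting, used to start a conversation session.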
```json
{
"type": "hello",
"version": "v1"
}
```
| Field | Type | Required | Description |
|---|---|---|---|
| `type` | string | Yes | Always `"hello"` |
| `version` | string | Yes | Protocol version; always `"v1"` |
---
#### 2. Session Start: `session.start`
The second message, sent after a successful handshake, used to start a conversation session.
```json ```json
{ {
@@ -75,13 +57,17 @@ Server <- session.stopped
"channels": 1 "channels": 1
}, },
"metadata": { "metadata": {
"appId": "assistant_123",
"channel": "web", "channel": "web",
"configVersionId": "cfg_20260217_01", "source": "web_debug",
"history": {
"userId": 1
},
"overrides": {
"systemPrompt": "You are a concise assistant", "systemPrompt": "You are a concise assistant",
"greeting": "Hello, how can I help you?", "greeting": "Hello, how can I help you?",
"output": { "output": {
"mode": "audio" "mode": "audio"
}
}, },
"dynamicVariables": { "dynamicVariables": {
"customer_name": "Alice", "customer_name": "Alice",
@@ -101,17 +87,33 @@ Server <- session.stopped
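| `metadata` | object | No | Runtime configuration | | `metadata` | object | No | Runtime configuration |
**Fields supported in metadata** **Fields supported in metadata**
- `appId` / `app_id` - Application ID
- `channel` - Channel identifier - `channel` - Channel identifier
- `configVersionId` / `config_version_id` - Configuration version - `source` - Source identifier
- `systemPrompt` - System prompt - `history.userId` - User ID for history records
- `greeting` - Opening greeting - `overrides` - Overridable fields (safe whitelist only)
- `output.mode` - Output mode (`audio` / `text`)
- `dynamicVariables` - Dynamic variables (supports `{{variable}}` placeholders) - `dynamicVariables` - Dynamic variables (supports `{{variable}}` placeholders)
**`metadata.overrides` whitelisted fields**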
- `systemPrompt`
- `greeting`
- `firstTurnMode`
- `generatedOpenerEnabled`
- `output`
- `bargeIn`
- `knowledgeBaseId`
- `knowledge`
- `tools`
- `openerAudio`
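**Restrictions**
- `metadata.workflow` is ignored (no workflow events are triggered)
- Submitting `metadata.services` is forbidden
- Submitting `assistantId` / `appId` / `app_id` / `configVersionId` / `config_version_id` is forbidden
- Submitting fields with secret-like names is forbidden (such as `apiKey` / `token` / `secret` / `password` / `authorization`)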
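
The whitelist and restrictions above can be checked on the client before sending `session.start`. The following is a sketch only: the key sets are copied from this commit, but the function names `find_secret_like_key` and `check_metadata` and the returned message strings are hypothetical, and the server remains the authority on what it accepts.

```python
from typing import Any, Dict, Optional

ALLOWED_TOP_LEVEL = {"overrides", "dynamicVariables", "channel", "source", "history", "workflow"}
ALLOWED_OVERRIDES = {
    "systemPrompt", "greeting", "firstTurnMode", "generatedOpenerEnabled", "output",
    "bargeIn", "knowledgeBaseId", "knowledge", "tools", "openerAudio",
}
FORBIDDEN_TOP_LEVEL = {
    "assistantId", "appId", "app_id", "configVersionId", "config_version_id", "services",
}
SECRET_TOKENS = ("apikey", "token", "secret", "password", "authorization")


def find_secret_like_key(payload: Any, path: str = "metadata") -> Optional[str]:
    """Return the path of the first secret-like key found, or None."""
    if isinstance(payload, dict):
        for key, value in payload.items():
            # Normalize so that api_key, api-key and apiKey all compare equal.
            normalized = str(key).lower().replace("_", "").replace("-", "")
            if any(token in normalized for token in SECRET_TOKENS):
                return f"{path}.{key}"
            nested = find_secret_like_key(value, f"{path}.{key}")
            if nested:
                return nested
    elif isinstance(payload, list):
        for index, value in enumerate(payload):
            nested = find_secret_like_key(value, f"{path}[{index}]")
            if nested:
                return nested
    return None


def check_metadata(metadata: Dict[str, Any]) -> Optional[str]:
    """Return a description of the first problem found, or None if none is found."""
    forbidden = sorted(key for key in metadata if key in FORBIDDEN_TOP_LEVEL)
    if forbidden:
        return f"forbidden keys: {', '.join(forbidden)}"
    unknown = sorted(key for key in metadata if key not in ALLOWED_TOP_LEVEL)
    if unknown:
        return f"unsupported keys: {', '.join(unknown)}"
    overrides = metadata.get("overrides")
    if overrides is not None:
        if not isinstance(overrides, dict):
            return "overrides must be an object"
        bad = sorted(key for key in overrides if key not in ALLOWED_OVERRIDES)
        if bad:
            return f"unsupported overrides keys: {', '.join(bad)}"
    secret_path = find_secret_like_key(metadata)
    if secret_path:
        return f"secret-like key at {secret_path}"
    return None
```

Because the token check is a substring match on the normalized key, a harmless-looking key such as `token_count` is also flagged.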
--- ---
#### 3. Text Input: `input.text` #### 2. Text Input: `input.text`
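Sends text input, skipping ASR recognition and directly triggering an LLM reply. Sends text input, skipping ASR recognition and directly triggering an LLM reply.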
@@ -129,7 +131,7 @@ Server <- session.stopped
--- ---
#### 4. Response Cancel: `response.cancel` #### 3. Response Cancel: `response.cancel`
Requests interruption of the current response. Requests interruption of the current response.
@@ -147,7 +149,7 @@ Server <- session.stopped
--- ---
#### 5. Tool Call Results: `tool_call.results` #### 4. Tool Call Results: `tool_call.results`
Returns the results of tools executed by the client. Returns the results of tools executed by the client.
@@ -176,7 +178,7 @@ Server <- session.stopped
--- ---
#### 6. Session Stop: `session.stop` #### 5. Session Stop: `session.stop`
Ends the conversation session. Ends the conversation session.
@@ -194,7 +196,7 @@ Server <- session.stopped
--- ---
#### 7. Binary Audio #### 6. Binary Audio
After `session.started`, binary PCM audio can be sent continuously. After `session.started`, binary PCM audio can be sent continuously.
@@ -239,7 +241,6 @@ Server <- session.stopped
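| Event | Description | | Event | Description |
|---|---| |---|---|
| `hello.ack` | Handshake success response |
| `session.started` | Session started successfully | | `session.started` | Session started successfully |
| `config.resolved` | Snapshot of the server's final resolved configuration | | `config.resolved` | Snapshot of the server's final resolved configuration |
| `heartbeat` | Keep-alive heartbeat (50-second interval by default) | | `heartbeat` | Keep-alive heartbeat (50-second interval by default) |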
@@ -281,7 +282,10 @@ Server <- session.stopped
| `protocol.invalid_json` | Malformed JSON | | `protocol.invalid_json` | Malformed JSON |
| `protocol.invalid_message` | Invalid message format | | `protocol.invalid_message` | Invalid message format |
| `protocol.order` | Message order violation | | `protocol.order` | Message order violation |
| `protocol.version_unsupported` | Unsupported protocol version | | `protocol.assistant_id_required` | Missing `assistant_id` query parameter |
| `protocol.invalid_override` | Invalid metadata override field |
| `assistant.not_found` | Assistant does not exist |
| `assistant.config_unavailable` | Assistant configuration unavailable |
| `audio.invalid_pcm` | Invalid PCM data | | `audio.invalid_pcm` | Invalid PCM data |
| `audio.frame_size_mismatch` | Audio frame size mismatch | | `audio.frame_size_mismatch` | Audio frame size mismatch |
| `server.internal` | Internal server error | | `server.internal` | Internal server error |
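
A small client-side sketch of how these error codes might be classified. In the session handler changed by this commit, the server closes the transport after sending any of the codes listed in `FATAL_START_ERRORS` while processing `session.start`, whereas `protocol.order` is reported without closing. The names `FATAL_START_ERRORS` and `is_fatal_start_error` are hypothetical, and how the code is extracted from the `error` event envelope is left out because the envelope is not shown here.

```python
# Error codes after which the server closes the connection during session.start
# processing, as read from the session handler in this commit.
FATAL_START_ERRORS = frozenset(
    {
        "protocol.assistant_id_required",
        "protocol.invalid_override",
        "protocol.dynamic_variables_invalid",
        "assistant.not_found",
        "assistant.config_unavailable",
    }
)


def is_fatal_start_error(code: str) -> bool:
    """Return True if the client should expect the socket to be closed."""
    return code in FATAL_START_ERRORS


for code in ("assistant.not_found", "protocol.order"):
    print(code, "fatal" if is_fatal_start_error(code) else "recoverable")
```

All of these errors are sent with `retryable=False`, so a client would typically fix the request (URL or metadata) before reconnecting rather than retrying unchanged.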


@@ -157,16 +157,16 @@ class HttpBackendAdapter:
async with session.get(url) as resp: async with session.get(url) as resp:
if resp.status == 404: if resp.status == 404:
logger.warning(f"Assistant config not found: {assistant_id}") logger.warning(f"Assistant config not found: {assistant_id}")
return None return {"__error_code": "assistant.not_found", "assistantId": assistant_id}
resp.raise_for_status() resp.raise_for_status()
payload = await resp.json() payload = await resp.json()
if not isinstance(payload, dict): if not isinstance(payload, dict):
logger.warning("Assistant config payload is not a dict; ignoring") logger.warning("Assistant config payload is not a dict; ignoring")
return None return {"__error_code": "assistant.config_unavailable", "assistantId": assistant_id}
return payload return payload
except Exception as exc: except Exception as exc:
logger.warning(f"Failed to fetch assistant config ({assistant_id}): {exc}") logger.warning(f"Failed to fetch assistant config ({assistant_id}): {exc}")
return None return {"__error_code": "assistant.config_unavailable", "assistantId": assistant_id}
async def create_call_record( async def create_call_record(
self, self,
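
The hunk above replaces `None` returns with sentinel dictionaries carrying an `__error_code` key. A sketch of how a caller could turn such a payload into a `(config, error)` pair follows; it mirrors the checks in the session code of this commit, but the function name `interpret_config_payload` is an assumption for illustration.

```python
from typing import Any, Dict, Optional, Tuple

ErrorInfo = Dict[str, str]


def interpret_config_payload(
    payload: Any, assistant_id: str
) -> Tuple[Dict[str, Any], Optional[ErrorInfo]]:
    """Split a backend payload into (config, error); exactly one is meaningful."""
    if not isinstance(payload, dict):
        # Anything other than a dict is treated as an unavailable config.
        return {}, {
            "code": "assistant.config_unavailable",
            "message": f"Assistant config unavailable: {assistant_id}",
        }
    code = str(payload.get("__error_code") or "").strip()
    if code == "assistant.not_found":
        return {}, {"code": code, "message": f"Assistant not found: {assistant_id}"}
    if code == "assistant.config_unavailable":
        return {}, {"code": code, "message": f"Assistant config unavailable: {assistant_id}"}
    return payload, None


config, error = interpret_config_payload(
    {"__error_code": "assistant.not_found", "assistantId": "a1"}, "a1"
)
print(config, error)
```

Returning a sentinel instead of `None` lets the caller distinguish a missing assistant from a transient backend failure.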


@@ -163,13 +163,19 @@ async def websocket_endpoint(websocket: WebSocket):
""" """
await websocket.accept() await websocket.accept()
session_id = str(uuid.uuid4()) session_id = str(uuid.uuid4())
assistant_id = str(websocket.query_params.get("assistant_id") or "").strip() or None
# Create transport and session # Create transport and session
transport = SocketTransport(websocket) transport = SocketTransport(websocket)
session = Session(session_id, transport, backend_gateway=backend_gateway) session = Session(
session_id,
transport,
backend_gateway=backend_gateway,
assistant_id=assistant_id,
)
active_sessions[session_id] = session active_sessions[session_id] = session
logger.info(f"WebSocket connection established: {session_id}") logger.info(f"WebSocket connection established: {session_id} assistant_id={assistant_id or '-'}")
last_received_at: List[float] = [time.monotonic()] last_received_at: List[float] = [time.monotonic()]
last_heartbeat_at: List[float] = [0.0] last_heartbeat_at: List[float] = [0.0]
@@ -239,16 +245,22 @@ async def webrtc_endpoint(websocket: WebSocket):
return return
await websocket.accept() await websocket.accept()
session_id = str(uuid.uuid4()) session_id = str(uuid.uuid4())
assistant_id = str(websocket.query_params.get("assistant_id") or "").strip() or None
# Create WebRTC peer connection # Create WebRTC peer connection
pc = RTCPeerConnection() pc = RTCPeerConnection()
# Create transport and session # Create transport and session
transport = WebRtcTransport(websocket, pc) transport = WebRtcTransport(websocket, pc)
session = Session(session_id, transport, backend_gateway=backend_gateway) session = Session(
session_id,
transport,
backend_gateway=backend_gateway,
assistant_id=assistant_id,
)
active_sessions[session_id] = session active_sessions[session_id] = session
logger.info(f"WebRTC connection established: {session_id}") logger.info(f"WebRTC connection established: {session_id} assistant_id={assistant_id or '-'}")
last_received_at: List[float] = [time.monotonic()] last_received_at: List[float] = [time.monotonic()]
last_heartbeat_at: List[float] = [0.0] last_heartbeat_at: List[float] = [0.0]
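
Both endpoints normalize the query parameter with the expression `str(value or "").strip() or None`. The sketch below isolates that expression to show its behavior on missing and blank values; `read_assistant_id` is a hypothetical helper, and a plain mapping stands in for the framework's query parameter object.

```python
from typing import Mapping, Optional


def read_assistant_id(query_params: Mapping[str, str]) -> Optional[str]:
    """Return a trimmed assistant_id, or None if it is missing or blank."""
    return str(query_params.get("assistant_id") or "").strip() or None


print(read_assistant_id({}))                          # missing
print(read_assistant_id({"assistant_id": "   "}))     # blank
print(read_assistant_id({"assistant_id": " a1 "}))    # padded
```

Note that in the code above the connection is accepted before this value is checked; a missing `assistant_id` is only rejected later, when `session.start` is handled.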


@@ -21,7 +21,6 @@ from services.base import LLMMessage
from models.ws_v1 import ( from models.ws_v1 import (
parse_client_message, parse_client_message,
ev, ev,
HelloMessage,
SessionStartMessage, SessionStartMessage,
SessionStopMessage, SessionStopMessage,
InputTextMessage, InputTextMessage,
@@ -39,7 +38,6 @@ _SYSTEM_DYNAMIC_VARIABLE_KEYS = {"system__time", "system_utc", "system_timezone"
class WsSessionState(str, Enum): class WsSessionState(str, Enum):
"""Protocol state machine for WS sessions.""" """Protocol state machine for WS sessions."""
WAIT_HELLO = "wait_hello"
WAIT_START = "wait_start" WAIT_START = "wait_start"
ACTIVE = "active" ACTIVE = "active"
STOPPED = "stopped" STOPPED = "stopped"
@@ -57,7 +55,15 @@ class Session:
TRACK_AUDIO_OUT = "audio_out" TRACK_AUDIO_OUT = "audio_out"
TRACK_CONTROL = "control" TRACK_CONTROL = "control"
AUDIO_FRAME_BYTES = 640 # 16k mono pcm_s16le, 20ms AUDIO_FRAME_BYTES = 640 # 16k mono pcm_s16le, 20ms
_CLIENT_METADATA_OVERRIDES = { _METADATA_ALLOWED_TOP_LEVEL_KEYS = {
"overrides",
"dynamicVariables",
"channel",
"source",
"history",
"workflow", # explicitly ignored for this MVP protocol version
}
_METADATA_ALLOWED_OVERRIDE_KEYS = {
"firstTurnMode", "firstTurnMode",
"greeting", "greeting",
"generatedOpenerEnabled", "generatedOpenerEnabled",
@@ -67,18 +73,22 @@ class Session:
"knowledge", "knowledge",
"knowledgeBaseId", "knowledgeBaseId",
"openerAudio", "openerAudio",
"dynamicVariables", "tools",
"history",
"userId",
"assistantId",
"source",
} }
_CLIENT_METADATA_ID_KEYS = { _METADATA_FORBIDDEN_TOP_LEVEL_KEYS = {
"assistantId",
"appId", "appId",
"app_id", "app_id",
"channel",
"configVersionId", "configVersionId",
"config_version_id", "config_version_id",
"services",
}
_METADATA_FORBIDDEN_KEY_TOKENS = {
"apikey",
"token",
"secret",
"password",
"authorization",
} }
def __init__( def __init__(
@@ -87,6 +97,7 @@ class Session:
transport: BaseTransport, transport: BaseTransport,
use_duplex: bool = None, use_duplex: bool = None,
backend_gateway: Optional[Any] = None, backend_gateway: Optional[Any] = None,
assistant_id: Optional[str] = None,
): ):
""" """
Initialize session. Initialize session.
@@ -99,6 +110,7 @@ class Session:
self.id = session_id self.id = session_id
self.transport = transport self.transport = transport
self.use_duplex = use_duplex if use_duplex is not None else settings.duplex_enabled self.use_duplex = use_duplex if use_duplex is not None else settings.duplex_enabled
self._assistant_id = str(assistant_id or "").strip() or None
self._backend_gateway = backend_gateway or build_backend_adapter_from_settings() self._backend_gateway = backend_gateway or build_backend_adapter_from_settings()
self._history_bridge = SessionHistoryBridge( self._history_bridge = SessionHistoryBridge(
history_writer=self._backend_gateway, history_writer=self._backend_gateway,
@@ -121,9 +133,8 @@ class Session:
# Session state # Session state
self.created_at = None self.created_at = None
self.state = "created" # Legacy call state for /call/lists self.state = "created" # Legacy call state for /call/lists
self.ws_state = WsSessionState.WAIT_HELLO self.ws_state = WsSessionState.WAIT_START
self._pipeline_started = False self._pipeline_started = False
self.protocol_version: Optional[str] = None
# Track IDs # Track IDs
self.current_track_id: str = self.TRACK_CONTROL self.current_track_id: str = self.TRACK_CONTROL
@@ -137,7 +148,12 @@ class Session:
self.pipeline.set_event_sequence_provider(self._next_event_seq) self.pipeline.set_event_sequence_provider(self._next_event_seq)
self.pipeline.conversation.on_turn_complete(self._on_turn_complete) self.pipeline.conversation.on_turn_complete(self._on_turn_complete)
logger.info(f"Session {self.id} created (duplex={self.use_duplex})") logger.info(
"Session {} created (duplex={}, assistant_id={})",
self.id,
self.use_duplex,
self._assistant_id or "-",
)
async def handle_text(self, text_data: str) -> None: async def handle_text(self, text_data: str) -> None:
""" """
@@ -222,23 +238,19 @@ class Session:
msg_type = message.type msg_type = message.type
logger.info(f"Session {self.id} received message: {msg_type}") logger.info(f"Session {self.id} received message: {msg_type}")
if isinstance(message, HelloMessage):
await self._handle_hello(message)
return
# All messages below require hello handshake first
if self.ws_state == WsSessionState.WAIT_HELLO:
await self._send_error(
"client",
"Expected hello message first",
"protocol.order",
)
return
if isinstance(message, SessionStartMessage): if isinstance(message, SessionStartMessage):
await self._handle_session_start(message) await self._handle_session_start(message)
return return
# All messages below require session.start first
if self.ws_state == WsSessionState.WAIT_START:
await self._send_error(
"client",
"Expected session.start message first",
"protocol.order",
)
return
# All messages below require active session # All messages below require active session
if self.ws_state != WsSessionState.ACTIVE: if self.ws_state != WsSessionState.ACTIVE:
await self._send_error( await self._send_error(
@@ -262,67 +274,56 @@ class Session:
else: else:
await self._send_error("client", f"Unsupported message type: {msg_type}", "protocol.unsupported") await self._send_error("client", f"Unsupported message type: {msg_type}", "protocol.unsupported")
async def _handle_hello(self, message: HelloMessage) -> None:
"""Handle initial hello/version negotiation."""
if self.ws_state != WsSessionState.WAIT_HELLO:
await self._send_error("client", "Duplicate hello", "protocol.order")
return
if message.version != settings.ws_protocol_version:
await self._send_error(
"client",
f"Unsupported protocol version '{message.version}'",
"protocol.version_unsupported",
)
await self.transport.close()
self.ws_state = WsSessionState.STOPPED
return
self.protocol_version = message.version
self.ws_state = WsSessionState.WAIT_START
await self._send_event(
ev(
"hello.ack",
version=self.protocol_version,
)
)
async def _handle_session_start(self, message: SessionStartMessage) -> None: async def _handle_session_start(self, message: SessionStartMessage) -> None:
"""Handle explicit session start after successful hello.""" """Handle explicit session start."""
if self.ws_state != WsSessionState.WAIT_START: if self.ws_state != WsSessionState.WAIT_START:
await self._send_error("client", "Duplicate session.start", "protocol.order") await self._send_error("client", "Duplicate session.start", "protocol.order")
return return
raw_metadata = message.metadata or {} raw_metadata = message.metadata or {}
workflow_runtime = self._bootstrap_workflow(raw_metadata) if not self._assistant_id:
server_runtime = await self._load_server_runtime_metadata(raw_metadata, workflow_runtime) await self._send_error(
client_runtime = self._sanitize_client_metadata(raw_metadata) "client",
requested_assistant_id = ( "Missing required query parameter assistant_id",
workflow_runtime.get("assistantId") "protocol.assistant_id_required",
or raw_metadata.get("assistantId") stage="protocol",
or raw_metadata.get("appId") retryable=False,
or raw_metadata.get("app_id")
) )
if server_runtime: await self.transport.close()
logger.info( self.ws_state = WsSessionState.STOPPED
"Session {} loaded trusted runtime config from backend " return
"(requested_assistant_id={}, resolved_assistant_id={}, configVersionId={}, has_services={})",
self.id, sanitized_metadata, metadata_error = self._validate_and_sanitize_client_metadata(raw_metadata)
requested_assistant_id, if metadata_error:
server_runtime.get("assistantId"), await self._send_error(
server_runtime.get("configVersionId") or server_runtime.get("config_version_id"), "client",
isinstance(server_runtime.get("services"), dict), metadata_error["message"],
metadata_error["code"],
stage="protocol",
retryable=False,
) )
else: await self.transport.close()
logger.warning( self.ws_state = WsSessionState.STOPPED
"Session {} missing trusted backend runtime config " return
"(requested_assistant_id={}); falling back to engine defaults + safe client overrides",
self.id, server_runtime, runtime_error = await self._load_server_runtime_metadata(self._assistant_id)
requested_assistant_id, if runtime_error:
await self._send_error(
"server",
runtime_error["message"],
runtime_error["code"],
stage="protocol",
retryable=False,
) )
metadata = self._merge_runtime_metadata(server_runtime, self._sanitize_untrusted_runtime_metadata(workflow_runtime)) await self.transport.close()
metadata = self._merge_runtime_metadata(metadata, client_runtime) self.ws_state = WsSessionState.STOPPED
metadata, dynamic_var_error = self._apply_dynamic_variables(metadata, raw_metadata) return
metadata = self._merge_runtime_metadata(server_runtime, sanitized_metadata.get("overrides", {}))
for key in ("channel", "source", "history"):
if key in sanitized_metadata:
metadata[key] = sanitized_metadata[key]
metadata, dynamic_var_error = self._apply_dynamic_variables(metadata, sanitized_metadata)
if dynamic_var_error: if dynamic_var_error:
await self._send_error( await self._send_error(
"client", "client",
@@ -331,6 +332,8 @@ class Session:
stage="protocol", stage="protocol",
retryable=False, retryable=False,
) )
await self.transport.close()
self.ws_state = WsSessionState.STOPPED
return return
# Create history call record early so later turn callbacks can append transcripts. # Create history call record early so later turn callbacks can append transcripts.
@@ -348,7 +351,7 @@ class Session:
"(assistantId={}, configVersionId={}, output_mode={}, " "(assistantId={}, configVersionId={}, output_mode={}, "
"llm={}/{}, asr={}/{}, tts={}/{}, tts_enabled={})", "llm={}/{}, asr={}/{}, tts={}/{}, tts_enabled={})",
self.id, self.id,
metadata.get("assistantId") or metadata.get("appId") or metadata.get("app_id"), metadata.get("assistantId"),
metadata.get("configVersionId") or metadata.get("config_version_id"), metadata.get("configVersionId") or metadata.get("config_version_id"),
(resolved_preview.get("output") or {}).get("mode") if isinstance(resolved_preview, dict) else None, (resolved_preview.get("output") or {}).get("mode") if isinstance(resolved_preview, dict) else None,
llm_cfg.get("provider"), llm_cfg.get("provider"),
@@ -387,24 +390,6 @@ class Session:
config=self._build_config_resolved(metadata), config=self._build_config_resolved(metadata),
) )
) )
if self.workflow_runner and self._workflow_initial_node:
await self._send_event(
ev(
"workflow.started",
workflowId=self.workflow_runner.workflow_id,
workflowName=self.workflow_runner.name,
nodeId=self._workflow_initial_node.id,
)
)
await self._send_event(
ev(
"workflow.node.entered",
workflowId=self.workflow_runner.workflow_id,
nodeId=self._workflow_initial_node.id,
nodeName=self._workflow_initial_node.name,
nodeType=self._workflow_initial_node.node_type,
)
)
# Emit opener only after frontend has received session.started/config events. # Emit opener only after frontend has received session.started/config events.
await self.pipeline.emit_initial_greeting() await self.pipeline.emit_initial_greeting()
@@ -658,7 +643,7 @@ class Session:
def _event_source(self, event_type: str) -> str: def _event_source(self, event_type: str) -> str:
if event_type.startswith("workflow."): if event_type.startswith("workflow."):
return "system" return "system"
if event_type.startswith("session.") or event_type.startswith("hello.") or event_type == "heartbeat": if event_type.startswith("session.") or event_type == "heartbeat":
return "system" return "system"
if event_type == "error": if event_type == "error":
return "system" return "system"
@@ -931,26 +916,41 @@ class Session:
async def _load_server_runtime_metadata( async def _load_server_runtime_metadata(
self, self,
client_metadata: Dict[str, Any], assistant_id: str,
workflow_runtime: Dict[str, Any], ) -> tuple[Dict[str, Any], Optional[Dict[str, str]]]:
) -> Dict[str, Any]:
"""Load trusted runtime metadata from backend assistant config.""" """Load trusted runtime metadata from backend assistant config."""
assistant_id = ( if not assistant_id:
workflow_runtime.get("assistantId") return {}, {
or client_metadata.get("assistantId") "code": "protocol.assistant_id_required",
or client_metadata.get("appId") "message": "Missing required query parameter assistant_id",
or client_metadata.get("app_id") }
)
if assistant_id is None:
return {}
provider = getattr(self._backend_gateway, "fetch_assistant_config", None) provider = getattr(self._backend_gateway, "fetch_assistant_config", None)
if not callable(provider): if not callable(provider):
return {} return {}, {
"code": "assistant.config_unavailable",
"message": "Assistant config backend unavailable",
}
payload = await provider(str(assistant_id).strip()) payload = await provider(str(assistant_id).strip())
if isinstance(payload, dict):
error_code = str(payload.get("__error_code") or "").strip()
if error_code == "assistant.not_found":
return {}, {
"code": "assistant.not_found",
"message": f"Assistant not found: {assistant_id}",
}
if error_code == "assistant.config_unavailable":
return {}, {
"code": "assistant.config_unavailable",
"message": f"Assistant config unavailable: {assistant_id}",
}
if not isinstance(payload, dict): if not isinstance(payload, dict):
return {} return {}, {
"code": "assistant.config_unavailable",
"message": f"Assistant config unavailable: {assistant_id}",
}
assistant_cfg: Dict[str, Any] = {} assistant_cfg: Dict[str, Any] = {}
session_start_cfg = payload.get("sessionStartMetadata") session_start_cfg = payload.get("sessionStartMetadata")
@@ -962,7 +962,10 @@ class Session:
assistant_cfg = payload assistant_cfg = payload
if not isinstance(assistant_cfg, dict): if not isinstance(assistant_cfg, dict):
return {} return {}, {
"code": "assistant.config_unavailable",
"message": f"Assistant config unavailable: {assistant_id}",
}
runtime: Dict[str, Any] = {} runtime: Dict[str, Any] = {}
passthrough_keys = { passthrough_keys = {
@@ -1009,36 +1012,110 @@ class Session:
if runtime.get("config_version_id") is not None and runtime.get("configVersionId") is None: if runtime.get("config_version_id") is not None and runtime.get("configVersionId") is None:
runtime["configVersionId"] = runtime.get("config_version_id") runtime["configVersionId"] = runtime.get("config_version_id")
return runtime return runtime, None
def _sanitize_untrusted_runtime_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]: def _find_forbidden_secret_key(self, payload: Any, path: str = "metadata") -> Optional[str]:
""" if isinstance(payload, dict):
Sanitize untrusted metadata sources. for key, value in payload.items():
key_str = str(key)
normalized = key_str.lower().replace("_", "").replace("-", "")
if any(token in normalized for token in self._METADATA_FORBIDDEN_KEY_TOKENS):
return f"{path}.{key_str}"
nested = self._find_forbidden_secret_key(value, f"{path}.{key_str}")
if nested:
return nested
return None
if isinstance(payload, list):
for idx, value in enumerate(payload):
nested = self._find_forbidden_secret_key(value, f"{path}[{idx}]")
if nested:
return nested
return None
This keeps only a small override whitelist and stable config ID fields. def _validate_and_sanitize_client_metadata(
""" self,
metadata: Dict[str, Any],
) -> tuple[Dict[str, Any], Optional[Dict[str, str]]]:
if not isinstance(metadata, dict): if not isinstance(metadata, dict):
return {} return {}, {
"code": "protocol.invalid_override",
"message": "metadata must be an object",
}
sanitized: Dict[str, Any] = {} forbidden_top_level = [key for key in self._METADATA_FORBIDDEN_TOP_LEVEL_KEYS if key in metadata]
for key in self._CLIENT_METADATA_ID_KEYS: if forbidden_top_level:
if key in metadata: return {}, {
sanitized[key] = metadata[key] "code": "protocol.invalid_override",
for key in self._CLIENT_METADATA_OVERRIDES: "message": f"Forbidden metadata keys: {', '.join(sorted(forbidden_top_level))}",
if key in metadata: }
sanitized[key] = metadata[key]
return sanitized unknown_keys = [
key
for key in metadata.keys()
if key not in self._METADATA_ALLOWED_TOP_LEVEL_KEYS
and key not in self._METADATA_FORBIDDEN_TOP_LEVEL_KEYS
]
if unknown_keys:
return {}, {
"code": "protocol.invalid_override",
"message": f"Unsupported metadata keys: {', '.join(sorted(unknown_keys))}",
}
def _sanitize_client_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]: if "workflow" in metadata:
"""Apply client metadata whitelist and remove forbidden secrets.""" logger.warning("Session {} received metadata.workflow; workflow payload is ignored in MVP", self.id)
sanitized = self._sanitize_untrusted_runtime_metadata(metadata)
if isinstance(metadata.get("services"), dict): overrides_raw = metadata.get("overrides")
logger.warning( overrides: Dict[str, Any] = {}
"Session {} provided metadata.services from client; client-side service config is ignored", if overrides_raw is not None:
self.id, if not isinstance(overrides_raw, dict):
) return {}, {
return sanitized "code": "protocol.invalid_override",
"message": "metadata.overrides must be an object",
}
unsupported_override_keys = [key for key in overrides_raw.keys() if key not in self._METADATA_ALLOWED_OVERRIDE_KEYS]
if unsupported_override_keys:
return {}, {
"code": "protocol.invalid_override",
"message": f"Unsupported metadata.overrides keys: {', '.join(sorted(unsupported_override_keys))}",
}
overrides = dict(overrides_raw)
dynamic_variables = metadata.get("dynamicVariables")
history_raw = metadata.get("history")
history: Dict[str, Any] = {}
if history_raw is not None:
if not isinstance(history_raw, dict):
return {}, {
"code": "protocol.invalid_override",
"message": "metadata.history must be an object",
}
unsupported_history_keys = [key for key in history_raw.keys() if key != "userId"]
if unsupported_history_keys:
return {}, {
"code": "protocol.invalid_override",
"message": f"Unsupported metadata.history keys: {', '.join(sorted(unsupported_history_keys))}",
}
if "userId" in history_raw:
history["userId"] = history_raw.get("userId")
sanitized: Dict[str, Any] = {"overrides": overrides}
if dynamic_variables is not None:
sanitized["dynamicVariables"] = dynamic_variables
if "channel" in metadata:
sanitized["channel"] = metadata.get("channel")
if "source" in metadata:
sanitized["source"] = metadata.get("source")
if history:
sanitized["history"] = history
forbidden_path = self._find_forbidden_secret_key(sanitized)
if forbidden_path:
return {}, {
"code": "protocol.invalid_override",
"message": f"Forbidden secret-like key detected at {forbidden_path}",
}
return sanitized, None
def _build_config_resolved(self, metadata: Dict[str, Any]) -> Dict[str, Any]: def _build_config_resolved(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
"""Build public resolved config payload (secrets removed).""" """Build public resolved config payload (secrets removed)."""
@@ -1047,7 +1124,7 @@ class Session:
runtime = self.pipeline.resolved_runtime_config() runtime = self.pipeline.resolved_runtime_config()
return { return {
"appId": metadata.get("appId") or metadata.get("app_id") or metadata.get("assistantId"), "appId": metadata.get("assistantId"),
"channel": metadata.get("channel"), "channel": metadata.get("channel"),
"configVersionId": metadata.get("configVersionId") or metadata.get("config_version_id"), "configVersionId": metadata.get("configVersionId") or metadata.get("config_version_id"),
"prompt": {"sha256": prompt_hash}, "prompt": {"sha256": prompt_hash},


@@ -5,7 +5,7 @@ This document defines the public WebSocket protocol for the `/ws` endpoint.
Validation policy: Validation policy:
- WS v1 JSON control messages are validated strictly. - WS v1 JSON control messages are validated strictly.
- Unknown top-level fields are rejected for all defined client message types. - Unknown top-level fields are rejected for all defined client message types.
- `hello.version` is fixed to `"v1"`. - `assistant_id` query parameter is required on `/ws`.
## Transport ## Transport
@@ -17,29 +17,16 @@ Validation policy:
Required message order: Required message order:
1. Client sends `hello`. 1. Client connects to `/ws?assistant_id=<id>`.
2. Server replies `hello.ack`. 2. Client sends `session.start`.
3. Client sends `session.start`. 3. Server replies `session.started`.
4. Server replies `session.started`. 4. Client may stream binary audio and/or send `input.text`.
5. Client may stream binary audio and/or send `input.text`. 5. Client sends `session.stop` (or closes socket).
6. Client sends `session.stop` (or closes socket).
If order is violated, server emits `error` with `code = "protocol.order"`. If order is violated, server emits `error` with `code = "protocol.order"`.
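
The ordering rule can be sketched as a small state check. The state names follow the `WsSessionState` enum in this commit after `WAIT_HELLO` was removed; the function `order_error` is hypothetical, and treating every non-active state as a `protocol.order` violation is an assumption, since the exact code used for a stopped session is not shown in this excerpt.

```python
from enum import Enum
from typing import Optional


class WsState(str, Enum):
    WAIT_START = "wait_start"
    ACTIVE = "active"
    STOPPED = "stopped"


def order_error(state: WsState, msg_type: str) -> Optional[str]:
    """Return "protocol.order" if `msg_type` is out of order in `state`, else None."""
    if msg_type == "session.start":
        # session.start is only valid as the first message; a repeat is a duplicate.
        return None if state is WsState.WAIT_START else "protocol.order"
    if state is not WsState.ACTIVE:
        # Assumption: any other message before session.start (or after stop) is rejected.
        return "protocol.order"
    return None


print(order_error(WsState.WAIT_START, "input.text"))
print(order_error(WsState.WAIT_START, "session.start"))
```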
## Client -> Server Messages ## Client -> Server Messages
### `hello`
```json
{
"type": "hello",
"version": "v1"
}
```
Rules:
- `version` must be `v1`.
### `session.start` ### `session.start`
```json ```json
@@ -51,15 +38,18 @@ Rules:
"channels": 1 "channels": 1
}, },
"metadata": { "metadata": {
"appId": "assistant_123",
"channel": "web", "channel": "web",
"configVersionId": "cfg_20260217_01", "source": "web-debug",
"client": "web-debug", "history": {
"userId": 1
},
"overrides": {
"output": { "output": {
"mode": "audio" "mode": "audio"
}, },
"systemPrompt": "You are concise.", "systemPrompt": "You are concise.",
"greeting": "Hi, how can I help?", "greeting": "Hi, how can I help?"
},
"dynamicVariables": { "dynamicVariables": {
"customer_name": "Alice", "customer_name": "Alice",
"plan_tier": "Pro" "plan_tier": "Pro"
@@ -69,9 +59,13 @@ Rules:
``` ```
Rules: Rules:
- Client-side `metadata.services` is ignored. - Assistant config is resolved strictly by URL query `assistant_id`.
- Service config (including secrets) is resolved server-side (env/backend). - `metadata` top-level keys allowed: `overrides`, `dynamicVariables`, `channel`, `source`, `history`, `workflow` (`workflow` is ignored).
- Client should pass stable IDs (`appId`, `channel`, `configVersionId`) plus small runtime overrides (e.g. `output`, `bargeIn`, greeting/prompt style hints). - `metadata.overrides` whitelist: `systemPrompt`, `greeting`, `firstTurnMode`, `generatedOpenerEnabled`, `output`, `bargeIn`, `knowledgeBaseId`, `knowledge`, `tools`, `openerAudio`.
- `metadata.services` is rejected with `protocol.invalid_override`.
- `metadata.workflow` is ignored in this MVP protocol version.
- Top-level IDs are forbidden in payload (`assistantId`, `appId`, `app_id`, `configVersionId`, `config_version_id`).
- Secret-like keys are forbidden in metadata (`apiKey`, `token`, `secret`, `password`, `authorization`).
- `metadata.dynamicVariables` is optional and must be an object of string key/value pairs. - `metadata.dynamicVariables` is optional and must be an object of string key/value pairs.
- Key pattern: `^[a-zA-Z_][a-zA-Z0-9_]{0,63}$` - Key pattern: `^[a-zA-Z_][a-zA-Z0-9_]{0,63}$`
- Max entries: 30 - Max entries: 30
@@ -85,7 +79,7 @@ Rules:
- Invalid `dynamicVariables` payload rejects `session.start` with `protocol.dynamic_variables_invalid`. - Invalid `dynamicVariables` payload rejects `session.start` with `protocol.dynamic_variables_invalid`.
Text-only mode: Text-only mode:
- Set `metadata.output.mode = "text"`. - Set `metadata.overrides.output.mode = "text"`.
- In this mode server still sends `assistant.response.delta/final`, but will not emit audio frames or `output.audio.start/end`. - In this mode server still sends `assistant.response.delta/final`, but will not emit audio frames or `output.audio.start/end`.
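
A sketch of building a text-only `session.start` message under the new layout, where `output.mode` lives under `metadata.overrides`. The helper `build_session_start` is hypothetical, the `"type": "session.start"` field is assumed by analogy with the other messages, and the `audio` block is omitted because its full field list is not shown in this excerpt.

```python
import json
from typing import Any, Dict, Optional


def build_session_start(
    overrides: Optional[Dict[str, Any]] = None,
    channel: str = "web",
    source: str = "web-debug",
) -> Dict[str, Any]:
    """Build a session.start message without any top-level ID fields in metadata."""
    metadata: Dict[str, Any] = {
        "channel": channel,
        "source": source,
        "overrides": dict(overrides or {}),
    }
    # The assistant is selected by the URL query parameter, not by this payload.
    return {"type": "session.start", "metadata": metadata}


message = build_session_start({"output": {"mode": "text"}})
print(json.dumps(message, indent=2))
```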
### `input.text` ### `input.text`
@@ -158,8 +152,6 @@ Envelope notes:
Common events: Common events:
- `hello.ack`
- Fields: `sessionId`, `version`
- `session.started` - `session.started`
- Fields: `sessionId`, `trackId`, `tracks`, `audio` - Fields: `sessionId`, `trackId`, `tracks`, `audio`
- `config.resolved` - `config.resolved`
@@ -204,7 +196,7 @@ Common events:
 Track IDs (MVP fixed values):
 - `audio_in`: ASR/VAD input-side events (`input.*`, `transcript.*`)
 - `audio_out`: assistant output-side events (`assistant.*`, `output.audio.*`, `response.interrupted`, `metrics.ttfb`)
-- `control`: session/control events (`session.*`, `hello.*`, `error`, `config.resolved`)
+- `control`: session/control events (`session.*`, `error`, `config.resolved`)
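
The track assignment listed above can be expressed as a small lookup, which may help when checking `trackId` values in a debugging client. This is a sketch only: `expected_track` is a hypothetical name, and the patterns are copied from the updated list (so `hello.*` no longer maps to any track).

```python
from fnmatch import fnmatchcase

# Event-type patterns per track, taken from the updated list above.
TRACK_PATTERNS = [
    ("audio_in", ("input.*", "transcript.*")),
    ("audio_out", ("assistant.*", "output.audio.*", "response.interrupted", "metrics.ttfb")),
    ("control", ("session.*", "error", "config.resolved")),
]


def expected_track(event_type):
    """Return the track ID documented for an event type, or None if unlisted."""
    for track, patterns in TRACK_PATTERNS:
        if any(fnmatchcase(event_type, pattern) for pattern in patterns):
            return track
    return None
```

A client running with track debugging enabled could compare `expected_track(event["type"])` against `event.get("trackId")` and log any mismatch.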
 Correlation IDs (`event.data`):
 - `turn_id`: one user-assistant interaction turn.


@@ -23,6 +23,7 @@ import time
import threading import threading
import queue import queue
from pathlib import Path from pathlib import Path
from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit
try: try:
import numpy as np import numpy as np
@@ -59,9 +60,8 @@ class MicrophoneClient:
url: str, url: str,
sample_rate: int = 16000, sample_rate: int = 16000,
chunk_duration_ms: int = 20, chunk_duration_ms: int = 20,
app_id: str = "assistant_demo", assistant_id: str = "assistant_demo",
channel: str = "mic_client", channel: str = "mic_client",
config_version_id: str = "local-dev",
input_device: int = None, input_device: int = None,
output_device: int = None, output_device: int = None,
track_debug: bool = False, track_debug: bool = False,
@@ -80,9 +80,8 @@ class MicrophoneClient:
self.sample_rate = sample_rate self.sample_rate = sample_rate
self.chunk_duration_ms = chunk_duration_ms self.chunk_duration_ms = chunk_duration_ms
self.chunk_samples = int(sample_rate * chunk_duration_ms / 1000) self.chunk_samples = int(sample_rate * chunk_duration_ms / 1000)
self.app_id = app_id self.assistant_id = assistant_id
self.channel = channel self.channel = channel
self.config_version_id = config_version_id
self.input_device = input_device self.input_device = input_device
self.output_device = output_device self.output_device = output_device
self.track_debug = track_debug self.track_debug = track_debug
@@ -126,18 +125,20 @@ class MicrophoneClient:
parts.append(f"{key}={value}") parts.append(f"{key}={value}")
return f" [{' '.join(parts)}]" if parts else "" return f" [{' '.join(parts)}]" if parts else ""
def _session_url(self) -> str:
parts = urlsplit(self.url)
query = dict(parse_qsl(parts.query, keep_blank_values=True))
query["assistant_id"] = self.assistant_id
return urlunsplit((parts.scheme, parts.netloc, parts.path, urlencode(query), parts.fragment))
async def connect(self) -> None: async def connect(self) -> None:
"""Connect to WebSocket server.""" """Connect to WebSocket server."""
print(f"Connecting to {self.url}...") session_url = self._session_url()
self.ws = await websockets.connect(self.url) print(f"Connecting to {session_url}...")
self.ws = await websockets.connect(session_url)
self.running = True self.running = True
print("Connected!") print("Connected!")
# WS v1 handshake: hello -> session.start
await self.send_command({
"type": "hello",
"version": "v1",
})
await self.send_command({ await self.send_command({
"type": "session.start", "type": "session.start",
"audio": { "audio": {
@@ -146,9 +147,8 @@ class MicrophoneClient:
"channels": 1, "channels": 1,
}, },
"metadata": { "metadata": {
"appId": self.app_id,
"channel": self.channel, "channel": self.channel,
"configVersionId": self.config_version_id, "source": "mic_client",
}, },
}) })
@@ -330,7 +330,7 @@ class MicrophoneClient:
if self.track_debug: if self.track_debug:
print(f"[track-debug] event={event_type} trackId={event.get('trackId')}{ids}") print(f"[track-debug] event={event_type} trackId={event.get('trackId')}{ids}")
if event_type in {"hello.ack", "session.started"}: if event_type == "session.started":
print(f"← Session ready!{ids}") print(f"← Session ready!{ids}")
elif event_type == "config.resolved": elif event_type == "config.resolved":
print(f"← Config resolved: {event.get('config', {}).get('output', {})}{ids}") print(f"← Config resolved: {event.get('config', {}).get('output', {})}{ids}")
@@ -609,20 +609,15 @@ async def main():
help="Show streaming LLM response chunks" help="Show streaming LLM response chunks"
) )
parser.add_argument( parser.add_argument(
"--app-id", "--assistant-id",
default="assistant_demo", default="assistant_demo",
help="Stable app/assistant identifier for server-side config lookup" help="Assistant identifier used in websocket query parameter"
) )
parser.add_argument( parser.add_argument(
"--channel", "--channel",
default="mic_client", default="mic_client",
help="Client channel name" help="Client channel name"
) )
parser.add_argument(
"--config-version-id",
default="local-dev",
help="Optional config version identifier"
)
parser.add_argument( parser.add_argument(
"--track-debug", "--track-debug",
action="store_true", action="store_true",
@@ -638,9 +633,8 @@ async def main():
client = MicrophoneClient( client = MicrophoneClient(
url=args.url, url=args.url,
sample_rate=args.sample_rate, sample_rate=args.sample_rate,
app_id=args.app_id, assistant_id=args.assistant_id,
channel=args.channel, channel=args.channel,
config_version_id=args.config_version_id,
input_device=args.input_device, input_device=args.input_device,
output_device=args.output_device, output_device=args.output_device,
track_debug=args.track_debug, track_debug=args.track_debug,


@@ -15,6 +15,7 @@ import sys
import time import time
import wave import wave
import io import io
from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit
try: try:
import numpy as np import numpy as np
@@ -56,16 +57,14 @@ class SimpleVoiceClient:
self, self,
url: str, url: str,
sample_rate: int = 16000, sample_rate: int = 16000,
app_id: str = "assistant_demo", assistant_id: str = "assistant_demo",
channel: str = "simple_client", channel: str = "simple_client",
config_version_id: str = "local-dev",
track_debug: bool = False, track_debug: bool = False,
): ):
self.url = url self.url = url
self.sample_rate = sample_rate self.sample_rate = sample_rate
self.app_id = app_id self.assistant_id = assistant_id
self.channel = channel self.channel = channel
self.config_version_id = config_version_id
self.track_debug = track_debug self.track_debug = track_debug
self.ws = None self.ws = None
self.running = False self.running = False
@@ -88,6 +87,12 @@ class SimpleVoiceClient:
# Interrupt handling - discard audio until next trackStart # Interrupt handling - discard audio until next trackStart
self._discard_audio = False self._discard_audio = False
def _session_url(self) -> str:
parts = urlsplit(self.url)
query = dict(parse_qsl(parts.query, keep_blank_values=True))
query["assistant_id"] = self.assistant_id
return urlunsplit((parts.scheme, parts.netloc, parts.path, urlencode(query), parts.fragment))
@staticmethod @staticmethod
def _event_ids_suffix(event: dict) -> str: def _event_ids_suffix(event: dict) -> str:
data = event.get("data") if isinstance(event.get("data"), dict) else {} data = event.get("data") if isinstance(event.get("data"), dict) else {}
@@ -101,16 +106,12 @@ class SimpleVoiceClient:
async def connect(self): async def connect(self):
"""Connect to server.""" """Connect to server."""
print(f"Connecting to {self.url}...") session_url = self._session_url()
self.ws = await websockets.connect(self.url) print(f"Connecting to {session_url}...")
self.ws = await websockets.connect(session_url)
self.running = True self.running = True
print("Connected!") print("Connected!")
# WS v1 handshake: hello -> session.start
await self.ws.send(json.dumps({
"type": "hello",
"version": "v1",
}))
await self.ws.send(json.dumps({ await self.ws.send(json.dumps({
"type": "session.start", "type": "session.start",
"audio": { "audio": {
@@ -119,12 +120,11 @@ class SimpleVoiceClient:
"channels": 1, "channels": 1,
}, },
"metadata": { "metadata": {
"appId": self.app_id,
"channel": self.channel, "channel": self.channel,
"configVersionId": self.config_version_id, "source": "simple_client",
}, },
})) }))
print("-> hello/session.start") print("-> session.start")
async def send_chat(self, text: str): async def send_chat(self, text: str):
"""Send chat message.""" """Send chat message."""
@@ -311,9 +311,8 @@ async def main():
parser.add_argument("--text", help="Send text and play response") parser.add_argument("--text", help="Send text and play response")
parser.add_argument("--list-devices", action="store_true") parser.add_argument("--list-devices", action="store_true")
parser.add_argument("--sample-rate", type=int, default=16000) parser.add_argument("--sample-rate", type=int, default=16000)
parser.add_argument("--app-id", default="assistant_demo") parser.add_argument("--assistant-id", default="assistant_demo")
parser.add_argument("--channel", default="simple_client") parser.add_argument("--channel", default="simple_client")
parser.add_argument("--config-version-id", default="local-dev")
parser.add_argument("--track-debug", action="store_true") parser.add_argument("--track-debug", action="store_true")
args = parser.parse_args() args = parser.parse_args()
@@ -325,9 +324,8 @@ async def main():
client = SimpleVoiceClient( client = SimpleVoiceClient(
args.url, args.url,
args.sample_rate, args.sample_rate,
app_id=args.app_id, assistant_id=args.assistant_id,
channel=args.channel, channel=args.channel,
config_version_id=args.config_version_id,
track_debug=args.track_debug, track_debug=args.track_debug,
) )
await client.run(args.text) await client.run(args.text)


@@ -12,6 +12,7 @@ import math
import argparse import argparse
import os import os
from datetime import datetime from datetime import datetime
from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit
# Configuration # Configuration
SERVER_URL = "ws://localhost:8000/ws" SERVER_URL = "ws://localhost:8000/ws"
@@ -130,14 +131,17 @@ async def run_client(url, file_path=None, use_sine=False, track_debug: bool = Fa
"""Run the WebSocket test client.""" """Run the WebSocket test client."""
session = aiohttp.ClientSession() session = aiohttp.ClientSession()
try: try:
print(f"🔌 Connecting to {url}...") parts = urlsplit(url)
async with session.ws_connect(url) as ws: query = dict(parse_qsl(parts.query, keep_blank_values=True))
query["assistant_id"] = str(query.get("assistant_id") or "assistant_demo")
session_url = urlunsplit((parts.scheme, parts.netloc, parts.path, urlencode(query), parts.fragment))
print(f"🔌 Connecting to {session_url}...")
async with session.ws_connect(session_url) as ws:
print("✅ Connected!") print("✅ Connected!")
session_ready = asyncio.Event() session_ready = asyncio.Event()
recv_task = asyncio.create_task(receive_loop(ws, session_ready, track_debug=track_debug)) recv_task = asyncio.create_task(receive_loop(ws, session_ready, track_debug=track_debug))
# Send v1 hello + session.start handshake # Send v1 session.start initialization
await ws.send_json({"type": "hello", "version": "v1"})
await ws.send_json({ await ws.send_json({
"type": "session.start", "type": "session.start",
"audio": { "audio": {
@@ -146,12 +150,11 @@ async def run_client(url, file_path=None, use_sine=False, track_debug: bool = Fa
"channels": 1 "channels": 1
}, },
"metadata": { "metadata": {
"appId": "assistant_demo",
"channel": "test_websocket", "channel": "test_websocket",
"configVersionId": "local-dev", "source": "test_websocket",
}, },
}) })
print("📤 Sent v1 hello/session.start") print("📤 Sent v1 session.start")
await asyncio.wait_for(session_ready.wait(), timeout=8) await asyncio.wait_for(session_ready.wait(), timeout=8)
# Select sender based on args # Select sender based on args


@@ -21,6 +21,7 @@ import sys
import time import time
import wave import wave
from pathlib import Path from pathlib import Path
from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit
try: try:
import numpy as np import numpy as np
@@ -57,9 +58,8 @@ class WavFileClient:
url: str, url: str,
input_file: str, input_file: str,
output_file: str, output_file: str,
app_id: str = "assistant_demo", assistant_id: str = "assistant_demo",
channel: str = "wav_client", channel: str = "wav_client",
config_version_id: str = "local-dev",
sample_rate: int = 16000, sample_rate: int = 16000,
chunk_duration_ms: int = 20, chunk_duration_ms: int = 20,
wait_time: float = 15.0, wait_time: float = 15.0,
@@ -82,9 +82,8 @@ class WavFileClient:
self.url = url self.url = url
self.input_file = Path(input_file) self.input_file = Path(input_file)
self.output_file = Path(output_file) self.output_file = Path(output_file)
self.app_id = app_id self.assistant_id = assistant_id
self.channel = channel self.channel = channel
self.config_version_id = config_version_id
self.sample_rate = sample_rate self.sample_rate = sample_rate
self.chunk_duration_ms = chunk_duration_ms self.chunk_duration_ms = chunk_duration_ms
self.chunk_samples = int(sample_rate * chunk_duration_ms / 1000) self.chunk_samples = int(sample_rate * chunk_duration_ms / 1000)
@@ -148,18 +147,20 @@ class WavFileClient:
parts.append(f"{key}={value}") parts.append(f"{key}={value}")
return f" [{' '.join(parts)}]" if parts else "" return f" [{' '.join(parts)}]" if parts else ""
def _session_url(self) -> str:
parts = urlsplit(self.url)
query = dict(parse_qsl(parts.query, keep_blank_values=True))
query["assistant_id"] = self.assistant_id
return urlunsplit((parts.scheme, parts.netloc, parts.path, urlencode(query), parts.fragment))
async def connect(self) -> None: async def connect(self) -> None:
"""Connect to WebSocket server.""" """Connect to WebSocket server."""
self.log_event("", f"Connecting to {self.url}...") session_url = self._session_url()
self.ws = await websockets.connect(self.url) self.log_event("", f"Connecting to {session_url}...")
self.ws = await websockets.connect(session_url)
self.running = True self.running = True
self.log_event("", "Connected!") self.log_event("", "Connected!")
# WS v1 handshake: hello -> session.start
await self.send_command({
"type": "hello",
"version": "v1",
})
await self.send_command({ await self.send_command({
"type": "session.start", "type": "session.start",
"audio": { "audio": {
@@ -168,9 +169,8 @@ class WavFileClient:
"channels": 1 "channels": 1
}, },
"metadata": { "metadata": {
"appId": self.app_id,
"channel": self.channel, "channel": self.channel,
"configVersionId": self.config_version_id, "source": "wav_client",
}, },
}) })
@@ -329,9 +329,7 @@ class WavFileClient:
if self.track_debug: if self.track_debug:
print(f"[track-debug] event={event_type} trackId={event.get('trackId')}{ids}") print(f"[track-debug] event={event_type} trackId={event.get('trackId')}{ids}")
if event_type == "hello.ack": if event_type == "session.started":
self.log_event("", f"Handshake acknowledged{ids}")
elif event_type == "session.started":
self.session_ready = True self.session_ready = True
self.log_event("", f"Session ready!{ids}") self.log_event("", f"Session ready!{ids}")
elif event_type == "config.resolved": elif event_type == "config.resolved":
@@ -521,20 +519,15 @@ async def main():
help="Target sample rate for audio (default: 16000)" help="Target sample rate for audio (default: 16000)"
) )
parser.add_argument( parser.add_argument(
"--app-id", "--assistant-id",
default="assistant_demo", default="assistant_demo",
help="Stable app/assistant identifier for server-side config lookup" help="Assistant identifier used in websocket query parameter"
) )
parser.add_argument( parser.add_argument(
"--channel", "--channel",
default="wav_client", default="wav_client",
help="Client channel name" help="Client channel name"
) )
parser.add_argument(
"--config-version-id",
default="local-dev",
help="Optional config version identifier"
)
parser.add_argument( parser.add_argument(
"--chunk-duration", "--chunk-duration",
type=int, type=int,
@@ -570,9 +563,8 @@ async def main():
url=args.url, url=args.url,
input_file=args.input, input_file=args.input,
output_file=args.output, output_file=args.output,
app_id=args.app_id, assistant_id=args.assistant_id,
channel=args.channel, channel=args.channel,
config_version_id=args.config_version_id,
sample_rate=args.sample_rate, sample_rate=args.sample_rate,
chunk_duration_ms=args.chunk_duration, chunk_duration_ms=args.chunk_duration,
wait_time=args.wait_time, wait_time=args.wait_time,


@@ -401,9 +401,14 @@
const targetSampleRate = 16000; const targetSampleRate = 16000;
const playbackStopRampSec = 0.008; const playbackStopRampSec = 0.008;
const appId = "assistant_demo"; const assistantId = "assistant_demo";
const channel = "web_client"; const channel = "web_client";
const configVersionId = "local-dev";
function buildSessionWsUrl(baseUrl) {
const parsed = new URL(baseUrl);
parsed.searchParams.set("assistant_id", assistantId);
return parsed.toString();
}
function logLine(type, text, data) { function logLine(type, text, data) {
const time = new Date().toLocaleTimeString(); const time = new Date().toLocaleTimeString();
@@ -556,14 +561,25 @@
async function connect() { async function connect() {
if (ws && ws.readyState === WebSocket.OPEN) return; if (ws && ws.readyState === WebSocket.OPEN) return;
ws = new WebSocket(wsUrl.value.trim()); const sessionWsUrl = buildSessionWsUrl(wsUrl.value.trim());
ws = new WebSocket(sessionWsUrl);
ws.binaryType = "arraybuffer"; ws.binaryType = "arraybuffer";
ws.onopen = () => { ws.onopen = () => {
setStatus(true, "Session open"); setStatus(true, "Session open");
logLine("sys", "WebSocket connected"); logLine("sys", "WebSocket connected");
ensureAudioContext(); ensureAudioContext();
sendCommand({ type: "hello", version: "v1" }); sendCommand({
type: "session.start",
audio: { encoding: "pcm_s16le", sample_rate_hz: targetSampleRate, channels: 1 },
metadata: {
channel,
source: "web_client",
overrides: {
output: { mode: "audio" },
},
},
});
}; };
ws.onclose = () => { ws.onclose = () => {
@@ -622,17 +638,6 @@
const type = event.type || "unknown"; const type = event.type || "unknown";
const ids = eventIdsSuffix(event); const ids = eventIdsSuffix(event);
logLine("event", `${type}${ids}`, event); logLine("event", `${type}${ids}`, event);
if (type === "hello.ack") {
sendCommand({
type: "session.start",
audio: { encoding: "pcm_s16le", sample_rate_hz: targetSampleRate, channels: 1 },
metadata: {
appId,
channel,
configVersionId,
},
});
}
if (type === "config.resolved") { if (type === "config.resolved") {
logLine("sys", "config.resolved", event.config || {}); logLine("sys", "config.resolved", event.config || {});
} }


@@ -18,12 +18,6 @@ class _StrictModel(BaseModel):
model_config = ConfigDict(extra="forbid") model_config = ConfigDict(extra="forbid")
# Client -> Server messages
class HelloMessage(_StrictModel):
type: Literal["hello"]
version: Literal["v1"] = Field(..., description="Protocol version, fixed to v1")
class SessionStartAudio(_StrictModel): class SessionStartAudio(_StrictModel):
encoding: Literal["pcm_s16le"] = "pcm_s16le" encoding: Literal["pcm_s16le"] = "pcm_s16le"
sample_rate_hz: Literal[16000] = 16000 sample_rate_hz: Literal[16000] = 16000
@@ -69,7 +63,6 @@ class ToolCallResultsMessage(_StrictModel):
CLIENT_MESSAGE_TYPES = { CLIENT_MESSAGE_TYPES = {
"hello": HelloMessage,
"session.start": SessionStartMessage, "session.start": SessionStartMessage,
"session.stop": SessionStopMessage, "session.stop": SessionStopMessage,
"input.text": InputTextMessage, "input.text": InputTextMessage,


@@ -86,7 +86,7 @@ def _build_pipeline(monkeypatch, llm_rounds: List[List[LLMStreamEvent]]) -> tupl
async def _capture_event(event: Dict[str, Any], priority: int = 20): async def _capture_event(event: Dict[str, Any], priority: int = 20):
events.append(event) events.append(event)
async def _noop_speak(_text: str, fade_in_ms: int = 0, fade_out_ms: int = 8): async def _noop_speak(_text: str, *args, **kwargs):
return None return None
monkeypatch.setattr(pipeline, "_send_event", _capture_event) monkeypatch.setattr(pipeline, "_send_event", _capture_event)


@@ -0,0 +1,238 @@
import pytest
from core.session import Session, WsSessionState
from models.ws_v1 import SessionStartMessage, parse_client_message
def _session() -> Session:
session = Session.__new__(Session)
session.id = "sess_test"
session._assistant_id = "assistant_demo"
return session
def test_parse_client_message_rejects_hello_message():
with pytest.raises(ValueError, match="Unknown client message type: hello"):
parse_client_message({"type": "hello", "version": "v1"})
@pytest.mark.asyncio
async def test_handle_text_reports_invalid_message_for_hello():
session = Session.__new__(Session)
session.id = "sess_invalid_hello"
session.ws_state = WsSessionState.WAIT_START
class _Transport:
async def close(self):
return None
session.transport = _Transport()
captured = []
async def _send_error(sender, message, code, **kwargs):
captured.append((sender, code, message, kwargs))
session._send_error = _send_error
await session.handle_text('{"type":"hello","version":"v1"}')
assert captured
sender, code, message, _ = captured[0]
assert sender == "client"
assert code == "protocol.invalid_message"
assert "Unknown client message type: hello" in message
def test_validate_metadata_rejects_services_payload():
session = _session()
sanitized, error = session._validate_and_sanitize_client_metadata({"services": {"llm": {"provider": "openai"}}})
assert sanitized == {}
assert error is not None
assert error["code"] == "protocol.invalid_override"
def test_validate_metadata_rejects_secret_like_override_keys():
session = _session()
sanitized, error = session._validate_and_sanitize_client_metadata(
{
"overrides": {
"output": {"mode": "audio"},
"apiKey": "xxx",
}
}
)
assert sanitized == {}
assert error is not None
assert error["code"] == "protocol.invalid_override"
def test_validate_metadata_ignores_workflow_payload():
session = _session()
sanitized, error = session._validate_and_sanitize_client_metadata(
{
"workflow": {"nodes": [], "edges": []},
"channel": "web_debug",
"overrides": {"output": {"mode": "text"}},
}
)
assert error is None
assert "workflow" not in sanitized
assert sanitized["channel"] == "web_debug"
assert sanitized["overrides"]["output"]["mode"] == "text"
@pytest.mark.asyncio
async def test_load_server_runtime_metadata_returns_not_found_error():
session = _session()
class _Gateway:
async def fetch_assistant_config(self, assistant_id: str):
_ = assistant_id
return {"__error_code": "assistant.not_found"}
session._backend_gateway = _Gateway()
runtime, error = await session._load_server_runtime_metadata("assistant_demo")
assert runtime == {}
assert error is not None
assert error["code"] == "assistant.not_found"
@pytest.mark.asyncio
async def test_load_server_runtime_metadata_returns_config_unavailable_error():
session = _session()
class _Gateway:
async def fetch_assistant_config(self, assistant_id: str):
_ = assistant_id
return None
session._backend_gateway = _Gateway()
runtime, error = await session._load_server_runtime_metadata("assistant_demo")
assert runtime == {}
assert error is not None
assert error["code"] == "assistant.config_unavailable"
@pytest.mark.asyncio
async def test_handle_session_start_requires_assistant_id_and_closes_transport():
session = Session.__new__(Session)
session.id = "sess_missing_assistant"
session.ws_state = WsSessionState.WAIT_START
session._assistant_id = None
class _Transport:
def __init__(self):
self.closed = False
async def close(self):
self.closed = True
transport = _Transport()
session.transport = transport
captured_codes = []
async def _send_error(sender, message, code, **kwargs):
_ = (sender, message, kwargs)
captured_codes.append(code)
session._send_error = _send_error
await session._handle_session_start(SessionStartMessage(type="session.start", metadata={}))
assert captured_codes == ["protocol.assistant_id_required"]
assert transport.closed is True
assert session.ws_state == WsSessionState.STOPPED
@pytest.mark.asyncio
async def test_handle_session_start_applies_whitelisted_overrides_and_ignores_workflow():
session = Session.__new__(Session)
session.id = "sess_start_ok"
session.ws_state = WsSessionState.WAIT_START
session.state = "created"
session._assistant_id = "assistant_demo"
session.current_track_id = Session.TRACK_CONTROL
session._pipeline_started = False
class _Transport:
async def close(self):
return None
class _Pipeline:
def __init__(self):
self.started = False
self.applied = {}
self.conversation = type("Conversation", (), {"system_prompt": ""})()
async def start(self):
self.started = True
async def emit_initial_greeting(self):
return None
def apply_runtime_overrides(self, metadata):
self.applied = dict(metadata)
def resolved_runtime_config(self):
return {
"output": {"mode": "text"},
"services": {"llm": {"provider": "openai", "model": "gpt-4o-mini"}},
"tools": {"allowlist": []},
}
session.transport = _Transport()
session.pipeline = _Pipeline()
events = []
async def _start_history_bridge(_metadata):
return None
async def _load_server_runtime_metadata(_assistant_id):
return (
{
"assistantId": "assistant_demo",
"configVersionId": "cfg_1",
"systemPrompt": "Base prompt",
"greeting": "Base greeting",
"output": {"mode": "audio"},
},
None,
)
async def _send_event(event):
events.append(event)
async def _send_error(sender, message, code, **kwargs):
raise AssertionError(f"Unexpected error: sender={sender} code={code} message={message} kwargs={kwargs}")
session._start_history_bridge = _start_history_bridge
session._load_server_runtime_metadata = _load_server_runtime_metadata
session._send_event = _send_event
session._send_error = _send_error
await session._handle_session_start(
SessionStartMessage(
type="session.start",
metadata={
"workflow": {"nodes": []},
"channel": "web_debug",
"source": "debug_ui",
"history": {"userId": 7},
"overrides": {
"greeting": "Override greeting",
"output": {"mode": "text"},
"tools": [{"name": "calculator"}],
},
},
)
)
assert session.ws_state == WsSessionState.ACTIVE
assert session.pipeline.started is True
assert session.pipeline.applied["assistantId"] == "assistant_demo"
assert session.pipeline.applied["greeting"] == "Override greeting"
assert session.pipeline.applied["output"]["mode"] == "text"
assert session.pipeline.applied["tools"] == [{"name": "calculator"}]
assert not any(str(item.get("type", "")).startswith("workflow.") for item in events)
config_event = next(item for item in events if item.get("type") == "config.resolved")
assert config_event["config"]["appId"] == "assistant_demo"
assert config_event["config"]["channel"] == "web_debug"


@@ -2842,6 +2842,61 @@ export const DebugDrawer: React.FC<{
return error; return error;
}; };
const METADATA_OVERRIDE_WHITELIST = new Set([
'firstTurnMode',
'greeting',
'generatedOpenerEnabled',
'systemPrompt',
'output',
'bargeIn',
'knowledge',
'knowledgeBaseId',
'openerAudio',
'tools',
]);
const METADATA_FORBIDDEN_SECRET_TOKENS = ['apikey', 'token', 'secret', 'password', 'authorization'];
const isPlainObject = (value: unknown): value is Record<string, any> => Boolean(value) && typeof value === 'object' && !Array.isArray(value);
const isForbiddenSecretKey = (key: string): boolean => {
const normalized = key.toLowerCase().replace(/[_-]/g, '');
return METADATA_FORBIDDEN_SECRET_TOKENS.some((token) => normalized.includes(token));
};
const stripForbiddenSecretKeysDeep = (value: any): any => {
if (Array.isArray(value)) return value.map(stripForbiddenSecretKeysDeep);
if (!isPlainObject(value)) return value;
return Object.entries(value).reduce<Record<string, any>>((acc, [key, nested]) => {
if (isForbiddenSecretKey(key)) return acc;
acc[key] = stripForbiddenSecretKeysDeep(nested);
return acc;
}, {});
};
const sanitizeMetadataForWs = (raw: unknown): Record<string, any> => {
if (!isPlainObject(raw)) return { overrides: {} };
const sanitized: Record<string, any> = { overrides: {} };
if (typeof raw.channel === 'string' && raw.channel.trim()) {
sanitized.channel = raw.channel.trim();
}
if (typeof raw.source === 'string' && raw.source.trim()) {
sanitized.source = raw.source.trim();
}
if (isPlainObject(raw.history) && raw.history.userId !== undefined) {
sanitized.history = { userId: raw.history.userId };
}
if (isPlainObject(raw.dynamicVariables)) {
sanitized.dynamicVariables = raw.dynamicVariables;
}
if (isPlainObject(raw.overrides)) {
const overrides = Object.entries(raw.overrides).reduce<Record<string, any>>((acc, [key, value]) => {
if (!METADATA_OVERRIDE_WHITELIST.has(key)) return acc;
if (isForbiddenSecretKey(key)) return acc;
acc[key] = stripForbiddenSecretKeysDeep(value);
return acc;
}, {});
sanitized.overrides = overrides;
}
return sanitized;
};
const buildDynamicVariablesPayload = (): { variables: Record<string, string>; error?: string } => { const buildDynamicVariablesPayload = (): { variables: Record<string, string>; error?: string } => {
const variables: Record<string, string> = {}; const variables: Record<string, string> = {};
const nonEmptyRows = dynamicVariables const nonEmptyRows = dynamicVariables
@@ -2908,84 +2963,16 @@ export const DebugDrawer: React.FC<{
const buildLocalResolvedRuntime = () => { const buildLocalResolvedRuntime = () => {
const warnings: string[] = []; const warnings: string[] = [];
const services: Record<string, any> = {};
const ttsEnabled = Boolean(textTtsEnabled); const ttsEnabled = Boolean(textTtsEnabled);
const isExternalLlm = assistant.configMode === 'dify' || assistant.configMode === 'fastgpt';
const knowledgeBaseId = String(assistant.knowledgeBaseId || '').trim(); const knowledgeBaseId = String(assistant.knowledgeBaseId || '').trim();
const knowledge = knowledgeBaseId const knowledge = knowledgeBaseId
? { enabled: true, kbId: knowledgeBaseId, nResults: 5 } ? { enabled: true, kbId: knowledgeBaseId, nResults: 5 }
: { enabled: false }; : { enabled: false };
if (isExternalLlm) {
services.llm = {
provider: 'openai',
model: '',
apiKey: assistant.apiKey || '',
baseUrl: assistant.apiUrl || '',
};
if (!assistant.apiUrl) warnings.push(`External LLM API URL is empty for mode: ${assistant.configMode}`);
if (!assistant.apiKey) warnings.push(`External LLM API key is empty for mode: ${assistant.configMode}`);
} else if (assistant.llmModelId) {
const llm = llmModels.find((item) => item.id === assistant.llmModelId);
if (llm) {
services.llm = {
provider: 'openai',
model: llm.modelName || llm.name,
apiKey: llm.apiKey,
baseUrl: llm.baseUrl,
};
} else {
warnings.push(`LLM model not found in loaded list: ${assistant.llmModelId}`);
}
} else {
// Keep empty object to indicate engine should use default provider model.
services.llm = {};
}
if (assistant.asrModelId) {
const asr = asrModels.find((item) => item.id === assistant.asrModelId);
if (asr) {
const asrProvider = isOpenAICompatibleVendor(asr.vendor) ? 'openai_compatible' : 'buffered';
services.asr = {
provider: asrProvider,
model: asr.modelName || asr.name,
apiKey: asrProvider === 'openai_compatible' ? asr.apiKey : null,
};
} else {
warnings.push(`ASR model not found in loaded list: ${assistant.asrModelId}`);
}
}
if (assistant.voice) {
const voice = voices.find((item) => item.id === assistant.voice);
if (voice) {
const ttsProvider = isOpenAICompatibleVendor(voice.vendor) ? 'openai_compatible' : 'edge';
services.tts = {
enabled: ttsEnabled,
provider: ttsProvider,
model: voice.model,
apiKey: ttsProvider === 'openai_compatible' ? voice.apiKey : null,
voice: resolveRuntimeTtsVoice(assistant.voice, voice),
speed: assistant.speed || voice.speed || 1.0,
};
} else {
services.tts = {
enabled: ttsEnabled,
voice: assistant.voice,
speed: assistant.speed || 1.0,
};
warnings.push(`Voice resource not found in loaded list: ${assistant.voice}`);
}
} else if (!ttsEnabled) {
services.tts = {
enabled: false,
};
}
const localResolved = { const localResolved = {
assistantId: assistant.id,
warnings, warnings,
sessionStartMetadata: { sessionStartMetadata: {
overrides: {
output: { output: {
mode: ttsEnabled ? 'audio' : 'text', mode: ttsEnabled ? 'audio' : 'text',
}, },
@@ -3000,12 +2987,11 @@ export const DebugDrawer: React.FC<{
 knowledgeBaseId,
 knowledge,
 tools: selectedToolSchemas,
-services,
-history: {
-assistantId: assistant.id,
-userId: 1,
-source: 'debug',
 },
+history: {
+userId: 1,
+},
+source: 'web_debug',
 },
 };
@@ -3020,21 +3006,13 @@ export const DebugDrawer: React.FC<{
 }
 setDynamicVariablesError('');
 const localResolved = buildLocalResolvedRuntime();
-const mergedMetadata: Record<string, any> = {
+const mergedMetadata: Record<string, any> = sanitizeMetadataForWs({
 ...localResolved.sessionStartMetadata,
 ...(sessionMetadataExtras || {}),
-};
+});
 if (Object.keys(dynamicVariablesResult.variables).length > 0) {
 mergedMetadata.dynamicVariables = dynamicVariablesResult.variables;
 }
-// Engine resolves trusted runtime config by top-level assistant/app ID.
-// Keep these IDs at metadata root so backend /assistants/{id}/config is reachable.
-if (!mergedMetadata.assistantId && assistant.id) {
-mergedMetadata.assistantId = assistant.id;
-}
-if (!mergedMetadata.appId && assistant.id) {
-mergedMetadata.appId = assistant.id;
-}
 if (!mergedMetadata.channel) {
 mergedMetadata.channel = 'web_debug';
 }
@@ -3069,6 +3047,24 @@ export const DebugDrawer: React.FC<{
 if (isOpen) setWsStatus('disconnected');
 };
+const buildSessionWsUrl = () => {
+const base = wsUrl.trim();
+if (!base) return '';
+try {
+const parsed = new URL(base);
+parsed.searchParams.set('assistant_id', assistant.id);
+return parsed.toString();
+} catch {
+try {
+const parsed = new URL(base, window.location.href);
+parsed.searchParams.set('assistant_id', assistant.id);
+return parsed.toString();
+} catch {
+return base;
+}
+}
+};
 const ensureWsSession = async () => {
 if (wsRef.current && wsReadyRef.current && wsRef.current.readyState === WebSocket.OPEN) {
 return;
@@ -3083,18 +3079,25 @@ export const DebugDrawer: React.FC<{
 }
 const metadata = await fetchRuntimeMetadata();
+const sessionWsUrl = buildSessionWsUrl();
 setWsStatus('connecting');
 setWsError('');
 await new Promise<void>((resolve, reject) => {
 pendingResolveRef.current = resolve;
 pendingRejectRef.current = reject;
-const ws = new WebSocket(wsUrl);
+const ws = new WebSocket(sessionWsUrl);
 ws.binaryType = 'arraybuffer';
 wsRef.current = ws;
 ws.onopen = () => {
-ws.send(JSON.stringify({ type: 'hello', version: 'v1' }));
+ws.send(
+JSON.stringify({
+type: 'session.start',
+audio: { encoding: 'pcm_s16le', sample_rate_hz: 16000, channels: 1 },
+metadata,
+})
+);
 };
 ws.onmessage = (event) => {
@@ -3118,16 +3121,6 @@ export const DebugDrawer: React.FC<{
 if (onProtocolEvent) {
 onProtocolEvent(payload);
 }
-if (type === 'hello.ack') {
-ws.send(
-JSON.stringify({
-type: 'session.start',
-audio: { encoding: 'pcm_s16le', sample_rate_hz: 16000, channels: 1 },
-metadata,
-})
-);
-return;
-}
 if (type === 'output.audio.start') {
 // New utterance audio starts: cancel old queued/playing audio to avoid overlap.
 stopPlaybackImmediately();