Enhance WebSocket session management by requiring assistant_id as a query parameter for connection. Update API reference documentation to reflect changes in message flow and metadata validation rules, including the introduction of whitelists for allowed metadata fields and restrictions on sensitive keys. Refactor client examples to align with the new session initiation process.
@@ -11,9 +11,11 @@ WebSocket 端点提供双向实时语音对话能力,支持音频流输入输
|
|||||||
### Connection URL
|
### Connection URL
|
||||||
|
|
||||||
```
|
```
|
||||||
ws://<host>/ws
|
ws://<host>/ws?assistant_id=<assistant_id>
|
||||||
```
|
```
|
||||||
|
|
||||||
|
- `assistant_id` is a required query parameter; the server uses it to load the assistant's runtime configuration from the database.
|
||||||
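A minimal sketch of how a client might build the new connection URL. Only the `/ws` path and the `assistant_id` query parameter come from the documentation; the helper name `build_ws_url`, the `secure` flag, and the host and ID values are assumptions for illustration.

```python
from urllib.parse import urlencode


def build_ws_url(host: str, assistant_id: str, secure: bool = False) -> str:
    """Return the WebSocket URL with assistant_id as a query parameter."""
    assistant_id = assistant_id.strip()
    if not assistant_id:
        # The server rejects sessions without assistant_id, so fail early on the client.
        raise ValueError("assistant_id is required")
    scheme = "wss" if secure else "ws"  # wss is an assumption; the docs only show ws://
    return f"{scheme}://{host}/ws?{urlencode({'assistant_id': assistant_id})}"


# Hypothetical host and assistant ID.
print(build_ws_url("localhost:8000", "assistant_123"))
```

Using `urlencode` keeps the ID safely escaped if it ever contains characters that are not valid in a query string.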
|
|
||||||
### Transport Rules
|
### Transport Rules
|
||||||
|
|
||||||
- **Text frames**: JSON-formatted control messages
|
- **Text frames**: JSON-formatted control messages
|
||||||
@@ -25,8 +27,6 @@ ws://<host>/ws
|
|||||||
### Message Flow
|
### Message Flow
|
||||||
|
|
||||||
```
|
```
|
||||||
Client -> hello
|
|
||||||
Server <- hello.ack
|
|
||||||
Client -> session.start
|
Client -> session.start
|
||||||
Server <- session.started
|
Server <- session.started
|
||||||
Server <- config.resolved
|
Server <- config.resolved
|
||||||
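With `hello` and `hello.ack` removed from the flow, the first text frame a client sends is `session.start`. The sketch below serializes such a frame. The helper name `build_session_start` and the sample values are hypothetical, and the `audio` block of the documented example is left out because its fields are not fully visible in this diff.

```python
import json
from typing import Any, Dict, Optional


def build_session_start(metadata: Optional[Dict[str, Any]] = None) -> str:
    """Serialize a session.start text frame; metadata is documented as optional."""
    message: Dict[str, Any] = {"type": "session.start"}
    if metadata:
        message["metadata"] = metadata
    # ensure_ascii=False keeps non-ASCII prompts readable on the wire.
    return json.dumps(message, ensure_ascii=False)


# Sample values only; every key used here appears in the documented metadata fields.
frame = build_session_start(
    {
        "channel": "web",
        "source": "web_debug",
        "history": {"userId": 1},
        "overrides": {"greeting": "Hello", "output": {"mode": "audio"}},
    }
)
print(frame)
```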
@@ -44,27 +44,9 @@ Server <- session.stopped
|
|||||||
|
|
||||||
### Client -> Server Messages
|
### Client -> Server Messages
|
||||||
|
|
||||||
#### 1. Handshake: `hello`
|
#### 1. Session Start: `session.start`
|
||||||
|
|
||||||
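The first message the client sends after connecting; it negotiates the protocol version.
|
The first message the client sends after connecting; it starts the conversation session.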
|
||||||
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"type": "hello",
|
|
||||||
"version": "v1"
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
| Field | Type | Required | Description |
|
|
||||||
|---|---|---|---|
|
|
||||||
| `type` | string | Yes | Always `"hello"` |
|
|
||||||
| `version` | string | Yes | Protocol version; always `"v1"` |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
#### 2. Session Start: `session.start`
|
|
||||||
|
|
||||||
The second message, sent after a successful handshake; it starts the conversation session.
|
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
@@ -75,13 +57,17 @@ Server <- session.stopped
|
|||||||
"channels": 1
|
"channels": 1
|
||||||
},
|
},
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"appId": "assistant_123",
|
|
||||||
"channel": "web",
|
"channel": "web",
|
||||||
"configVersionId": "cfg_20260217_01",
|
"source": "web_debug",
|
||||||
|
"history": {
|
||||||
|
"userId": 1
|
||||||
|
},
|
||||||
|
"overrides": {
|
||||||
"systemPrompt": "你是简洁助手",
|
"systemPrompt": "你是简洁助手",
|
||||||
"greeting": "你好,我能帮你什么?",
|
"greeting": "你好,我能帮你什么?",
|
||||||
"output": {
|
"output": {
|
||||||
"mode": "audio"
|
"mode": "audio"
|
||||||
|
}
|
||||||
},
|
},
|
||||||
"dynamicVariables": {
|
"dynamicVariables": {
|
||||||
"customer_name": "Alice",
|
"customer_name": "Alice",
|
||||||
@@ -101,17 +87,33 @@ Server <- session.stopped
|
|||||||
| `metadata` | object | No | Runtime configuration |
|
| `metadata` | object | No | Runtime configuration |
|
||||||
|
|
||||||
**Fields supported in `metadata`**:
|
**Fields supported in `metadata`**:
|
||||||
- `appId` / `app_id` - Application ID
|
|
||||||
- `channel` - Channel identifier
|
- `channel` - Channel identifier
|
||||||
- `configVersionId` / `config_version_id` - Configuration version
|
- `source` - Source identifier
|
||||||
- `systemPrompt` - System prompt
|
- `history.userId` - User ID for history records
|
||||||
- `greeting` - Opening greeting
|
- `overrides` - Overridable fields (safe whitelist only)
|
||||||
- `output.mode` - Output mode (`audio` / `text`)
|
|
||||||
- `dynamicVariables` - Dynamic variables (supports `{{variable}}` placeholders)
|
- `dynamicVariables` - Dynamic variables (supports `{{variable}}` placeholders)
|
||||||
|
|
||||||
|
**Whitelisted `metadata.overrides` fields**:
|
||||||
|
- `systemPrompt`
|
||||||
|
- `greeting`
|
||||||
|
- `firstTurnMode`
|
||||||
|
- `generatedOpenerEnabled`
|
||||||
|
- `output`
|
||||||
|
- `bargeIn`
|
||||||
|
- `knowledgeBaseId`
|
||||||
|
- `knowledge`
|
||||||
|
- `tools`
|
||||||
|
- `openerAudio`
|
||||||
|
|
||||||
|
**Restrictions**:
|
||||||
|
- `metadata.workflow` is ignored (no workflow events are triggered)
|
||||||
|
- `metadata.services` must not be submitted
|
||||||
|
- `assistantId` / `appId` / `app_id` / `configVersionId` / `config_version_id` must not be submitted
|
||||||
|
- Fields whose names suggest secrets (such as `apiKey` / `token` / `secret` / `password` / `authorization`) must not be submitted
|
||||||
|
|
||||||
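The whitelist and restriction rules above can be checked on the client before sending `session.start`. This is a sketch under assumptions: the function name `check_metadata` is hypothetical, the key sets are copied from the lists in this document, and the secret-like key scan is not included here.

```python
from typing import Any, Dict, Optional

ALLOWED_TOP_LEVEL = {"overrides", "dynamicVariables", "channel", "source", "history", "workflow"}
FORBIDDEN_TOP_LEVEL = {
    "assistantId", "appId", "app_id", "configVersionId", "config_version_id", "services",
}
ALLOWED_OVERRIDES = {
    "systemPrompt", "greeting", "firstTurnMode", "generatedOpenerEnabled", "output",
    "bargeIn", "knowledgeBaseId", "knowledge", "tools", "openerAudio",
}


def check_metadata(metadata: Dict[str, Any]) -> Optional[str]:
    """Return a description of the first problem found, or None if these checks pass."""
    forbidden = sorted(k for k in metadata if k in FORBIDDEN_TOP_LEVEL)
    if forbidden:
        return f"Forbidden metadata keys: {', '.join(forbidden)}"
    unknown = sorted(k for k in metadata if k not in ALLOWED_TOP_LEVEL)
    if unknown:
        return f"Unsupported metadata keys: {', '.join(unknown)}"
    overrides = metadata.get("overrides")
    if overrides is None:
        return None
    if not isinstance(overrides, dict):
        return "metadata.overrides must be an object"
    bad = sorted(k for k in overrides if k not in ALLOWED_OVERRIDES)
    if bad:
        return f"Unsupported metadata.overrides keys: {', '.join(bad)}"
    return None


print(check_metadata({"channel": "web", "overrides": {"greeting": "Hi"}}))
print(check_metadata({"appId": "assistant_123"}))
```

Checking on the client is only a convenience; the server remains the authority and answers violations with `protocol.invalid_override`.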
---
|
---
|
||||||
|
|
||||||
#### 3. Text Input: `input.text`
|
#### 2. Text Input: `input.text`
|
||||||
|
|
||||||
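Sends text input, skipping ASR recognition and directly triggering an LLM reply.
|
Sends text input, skipping ASR recognition and directly triggering an LLM reply.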
|
||||||
|
|
||||||
@@ -129,7 +131,7 @@ Server <- session.stopped
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
#### 4. Response Cancel: `response.cancel`
|
#### 3. Response Cancel: `response.cancel`
|
||||||
|
|
||||||
Requests interruption of the current response.
|
Requests interruption of the current response.
|
||||||
|
|
||||||
@@ -147,7 +149,7 @@ Server <- session.stopped
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
#### 5. Tool Call Results: `tool_call.results`
|
#### 4. Tool Call Results: `tool_call.results`
|
||||||
|
|
||||||
Returns the results of tools executed on the client.
|
Returns the results of tools executed on the client.
|
||||||
|
|
||||||
@@ -176,7 +178,7 @@ Server <- session.stopped
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
#### 6. Session Stop: `session.stop`
|
#### 5. Session Stop: `session.stop`
|
||||||
|
|
||||||
Ends the conversation session.
|
Ends the conversation session.
|
||||||
|
|
||||||
@@ -194,7 +196,7 @@ Server <- session.stopped
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
#### 7. Binary Audio
|
#### 6. Binary Audio
|
||||||
|
|
||||||
After `session.started`, binary PCM audio can be sent continuously.
|
After `session.started`, binary PCM audio can be sent continuously.
|
||||||
|
|
||||||
@@ -239,7 +241,6 @@ Server <- session.stopped
|
|||||||
|
|
||||||
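| Event | Description |
|
| Event | Description |
|
||||||
|---|---|
|
|---|---|
|
||||||
| `hello.ack` | Handshake success response |
|
|
||||||
| `session.started` | Session started successfully |
|
| `session.started` | Session started successfully |
|
||||||
| `config.resolved` | Snapshot of the server's final resolved configuration |
|
| `config.resolved` | Snapshot of the server's final resolved configuration |
|
||||||
| `heartbeat` | Keep-alive heartbeat (default interval: 50 seconds) |
|
| `heartbeat` | Keep-alive heartbeat (default interval: 50 seconds) |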
|
||||||
@@ -281,7 +282,10 @@ Server <- session.stopped
|
|||||||
| `protocol.invalid_json` | Malformed JSON |
|
| `protocol.invalid_json` | Malformed JSON |
|
||||||
| `protocol.invalid_message` | Invalid message format |
|
| `protocol.invalid_message` | Invalid message format |
|
||||||
| `protocol.order` | Messages sent out of order |
|
| `protocol.order` | Messages sent out of order |
|
||||||
| `protocol.version_unsupported` | Unsupported protocol version |
|
| `protocol.assistant_id_required` | Missing `assistant_id` query parameter |
|
||||||
|
| `protocol.invalid_override` | Invalid metadata override field |
|
||||||
|
| `assistant.not_found` | Assistant does not exist |
|
||||||
|
| `assistant.config_unavailable` | Assistant configuration unavailable |
|
||||||
| `audio.invalid_pcm` | Invalid PCM data |
|
| `audio.invalid_pcm` | Invalid PCM data |
|
||||||
| `audio.frame_size_mismatch` | Audio frame size mismatch |
|
| `audio.frame_size_mismatch` | Audio frame size mismatch |
|
||||||
| `server.internal` | Internal server error |
|
| `server.internal` | Internal server error |
|
||||||
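A client may want to treat some of these codes differently from others. In this commit's `session.start` handler, the four codes introduced here are sent with `retryable=False` and followed by closing the transport. The sketch below encodes that observation; the names `CLOSES_ON_START`, `error_category`, and `closes_connection` are hypothetical, and codes outside the set are simply not covered by it.

```python
# Codes that this commit's session.start handler sends right before closing the transport.
CLOSES_ON_START = {
    "protocol.assistant_id_required",
    "protocol.invalid_override",
    "assistant.not_found",
    "assistant.config_unavailable",
}


def error_category(code: str) -> str:
    """Return the namespace before the first dot, e.g. 'protocol' or 'audio'."""
    return code.split(".", 1)[0]


def closes_connection(code: str) -> bool:
    """True for codes known, from this commit, to be followed by a closed connection."""
    return code in CLOSES_ON_START


for code in ("assistant.not_found", "audio.invalid_pcm"):
    print(code, error_category(code), closes_connection(code))
```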
|
|||||||
@@ -157,16 +157,16 @@ class HttpBackendAdapter:
|
|||||||
async with session.get(url) as resp:
|
async with session.get(url) as resp:
|
||||||
if resp.status == 404:
|
if resp.status == 404:
|
||||||
logger.warning(f"Assistant config not found: {assistant_id}")
|
logger.warning(f"Assistant config not found: {assistant_id}")
|
||||||
return None
|
return {"__error_code": "assistant.not_found", "assistantId": assistant_id}
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
payload = await resp.json()
|
payload = await resp.json()
|
||||||
if not isinstance(payload, dict):
|
if not isinstance(payload, dict):
|
||||||
logger.warning("Assistant config payload is not a dict; ignoring")
|
logger.warning("Assistant config payload is not a dict; ignoring")
|
||||||
return None
|
return {"__error_code": "assistant.config_unavailable", "assistantId": assistant_id}
|
||||||
return payload
|
return payload
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning(f"Failed to fetch assistant config ({assistant_id}): {exc}")
|
logger.warning(f"Failed to fetch assistant config ({assistant_id}): {exc}")
|
||||||
return None
|
return {"__error_code": "assistant.config_unavailable", "assistantId": assistant_id}
|
||||||
|
|
||||||
async def create_call_record(
|
async def create_call_record(
|
||||||
self,
|
self,
|
||||||
|
|||||||
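The adapter now returns a sentinel dict carrying `__error_code` instead of `None`, so callers can tell "not found" apart from "unavailable". A minimal sketch of how a caller might separate the two cases follows. The function name `split_config_result` is hypothetical, and unlike the session code in this commit it treats any non-empty `__error_code` as an error rather than matching two specific codes.

```python
from typing import Any, Dict, Optional, Tuple


def split_config_result(payload: Any) -> Tuple[Dict[str, Any], Optional[str]]:
    """Separate a real assistant config from an error sentinel."""
    if not isinstance(payload, dict):
        # Anything that is not a dict cannot be a usable config.
        return {}, "assistant.config_unavailable"
    error_code = str(payload.get("__error_code") or "").strip()
    if error_code:
        return {}, error_code
    return payload, None


print(split_config_result({"__error_code": "assistant.not_found", "assistantId": "a1"}))
print(split_config_result({"assistantId": "a1", "greeting": "Hi"}))
```

A sentinel dict keeps the adapter's return type stable, at the cost of reserving the `__error_code` key in real payloads.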
@@ -163,13 +163,19 @@ async def websocket_endpoint(websocket: WebSocket):
|
|||||||
"""
|
"""
|
||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
session_id = str(uuid.uuid4())
|
session_id = str(uuid.uuid4())
|
||||||
|
assistant_id = str(websocket.query_params.get("assistant_id") or "").strip() or None
|
||||||
|
|
||||||
# Create transport and session
|
# Create transport and session
|
||||||
transport = SocketTransport(websocket)
|
transport = SocketTransport(websocket)
|
||||||
session = Session(session_id, transport, backend_gateway=backend_gateway)
|
session = Session(
|
||||||
|
session_id,
|
||||||
|
transport,
|
||||||
|
backend_gateway=backend_gateway,
|
||||||
|
assistant_id=assistant_id,
|
||||||
|
)
|
||||||
active_sessions[session_id] = session
|
active_sessions[session_id] = session
|
||||||
|
|
||||||
logger.info(f"WebSocket connection established: {session_id}")
|
logger.info(f"WebSocket connection established: {session_id} assistant_id={assistant_id or '-'}")
|
||||||
|
|
||||||
last_received_at: List[float] = [time.monotonic()]
|
last_received_at: List[float] = [time.monotonic()]
|
||||||
last_heartbeat_at: List[float] = [0.0]
|
last_heartbeat_at: List[float] = [0.0]
|
||||||
@@ -239,16 +245,22 @@ async def webrtc_endpoint(websocket: WebSocket):
|
|||||||
return
|
return
|
||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
session_id = str(uuid.uuid4())
|
session_id = str(uuid.uuid4())
|
||||||
|
assistant_id = str(websocket.query_params.get("assistant_id") or "").strip() or None
|
||||||
|
|
||||||
# Create WebRTC peer connection
|
# Create WebRTC peer connection
|
||||||
pc = RTCPeerConnection()
|
pc = RTCPeerConnection()
|
||||||
|
|
||||||
# Create transport and session
|
# Create transport and session
|
||||||
transport = WebRtcTransport(websocket, pc)
|
transport = WebRtcTransport(websocket, pc)
|
||||||
session = Session(session_id, transport, backend_gateway=backend_gateway)
|
session = Session(
|
||||||
|
session_id,
|
||||||
|
transport,
|
||||||
|
backend_gateway=backend_gateway,
|
||||||
|
assistant_id=assistant_id,
|
||||||
|
)
|
||||||
active_sessions[session_id] = session
|
active_sessions[session_id] = session
|
||||||
|
|
||||||
logger.info(f"WebRTC connection established: {session_id}")
|
logger.info(f"WebRTC connection established: {session_id} assistant_id={assistant_id or '-'}")
|
||||||
|
|
||||||
last_received_at: List[float] = [time.monotonic()]
|
last_received_at: List[float] = [time.monotonic()]
|
||||||
last_heartbeat_at: List[float] = [0.0]
|
last_heartbeat_at: List[float] = [0.0]
|
||||||
|
|||||||
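Both endpoints normalize the query parameter with the expression `str(... or "").strip() or None`. The sketch below isolates that expression in a hypothetical helper, `normalize_assistant_id`, to show which inputs collapse to `None`.

```python
from typing import Optional


def normalize_assistant_id(raw: Optional[str]) -> Optional[str]:
    """Map missing, empty, or whitespace-only values to None; trim everything else."""
    return str(raw or "").strip() or None


for value in (None, "", "   ", " assistant_123 "):
    print(repr(value), "->", repr(normalize_assistant_id(value)))
```

Collapsing every "absent" form to `None` lets the session later use a single `if not self._assistant_id` check.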
@@ -21,7 +21,6 @@ from services.base import LLMMessage
|
|||||||
from models.ws_v1 import (
|
from models.ws_v1 import (
|
||||||
parse_client_message,
|
parse_client_message,
|
||||||
ev,
|
ev,
|
||||||
HelloMessage,
|
|
||||||
SessionStartMessage,
|
SessionStartMessage,
|
||||||
SessionStopMessage,
|
SessionStopMessage,
|
||||||
InputTextMessage,
|
InputTextMessage,
|
||||||
@@ -39,7 +38,6 @@ _SYSTEM_DYNAMIC_VARIABLE_KEYS = {"system__time", "system_utc", "system_timezone"
|
|||||||
class WsSessionState(str, Enum):
|
class WsSessionState(str, Enum):
|
||||||
"""Protocol state machine for WS sessions."""
|
"""Protocol state machine for WS sessions."""
|
||||||
|
|
||||||
WAIT_HELLO = "wait_hello"
|
|
||||||
WAIT_START = "wait_start"
|
WAIT_START = "wait_start"
|
||||||
ACTIVE = "active"
|
ACTIVE = "active"
|
||||||
STOPPED = "stopped"
|
STOPPED = "stopped"
|
||||||
@@ -57,7 +55,15 @@ class Session:
|
|||||||
TRACK_AUDIO_OUT = "audio_out"
|
TRACK_AUDIO_OUT = "audio_out"
|
||||||
TRACK_CONTROL = "control"
|
TRACK_CONTROL = "control"
|
||||||
AUDIO_FRAME_BYTES = 640 # 16k mono pcm_s16le, 20ms
|
AUDIO_FRAME_BYTES = 640 # 16k mono pcm_s16le, 20ms
|
||||||
_CLIENT_METADATA_OVERRIDES = {
|
_METADATA_ALLOWED_TOP_LEVEL_KEYS = {
|
||||||
|
"overrides",
|
||||||
|
"dynamicVariables",
|
||||||
|
"channel",
|
||||||
|
"source",
|
||||||
|
"history",
|
||||||
|
"workflow", # explicitly ignored for this MVP protocol version
|
||||||
|
}
|
||||||
|
_METADATA_ALLOWED_OVERRIDE_KEYS = {
|
||||||
"firstTurnMode",
|
"firstTurnMode",
|
||||||
"greeting",
|
"greeting",
|
||||||
"generatedOpenerEnabled",
|
"generatedOpenerEnabled",
|
||||||
@@ -67,18 +73,22 @@ class Session:
|
|||||||
"knowledge",
|
"knowledge",
|
||||||
"knowledgeBaseId",
|
"knowledgeBaseId",
|
||||||
"openerAudio",
|
"openerAudio",
|
||||||
"dynamicVariables",
|
"tools",
|
||||||
"history",
|
|
||||||
"userId",
|
|
||||||
"assistantId",
|
|
||||||
"source",
|
|
||||||
}
|
}
|
||||||
_CLIENT_METADATA_ID_KEYS = {
|
_METADATA_FORBIDDEN_TOP_LEVEL_KEYS = {
|
||||||
|
"assistantId",
|
||||||
"appId",
|
"appId",
|
||||||
"app_id",
|
"app_id",
|
||||||
"channel",
|
|
||||||
"configVersionId",
|
"configVersionId",
|
||||||
"config_version_id",
|
"config_version_id",
|
||||||
|
"services",
|
||||||
|
}
|
||||||
|
_METADATA_FORBIDDEN_KEY_TOKENS = {
|
||||||
|
"apikey",
|
||||||
|
"token",
|
||||||
|
"secret",
|
||||||
|
"password",
|
||||||
|
"authorization",
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -87,6 +97,7 @@ class Session:
|
|||||||
transport: BaseTransport,
|
transport: BaseTransport,
|
||||||
use_duplex: bool = None,
|
use_duplex: bool = None,
|
||||||
backend_gateway: Optional[Any] = None,
|
backend_gateway: Optional[Any] = None,
|
||||||
|
assistant_id: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize session.
|
Initialize session.
|
||||||
@@ -99,6 +110,7 @@ class Session:
|
|||||||
self.id = session_id
|
self.id = session_id
|
||||||
self.transport = transport
|
self.transport = transport
|
||||||
self.use_duplex = use_duplex if use_duplex is not None else settings.duplex_enabled
|
self.use_duplex = use_duplex if use_duplex is not None else settings.duplex_enabled
|
||||||
|
self._assistant_id = str(assistant_id or "").strip() or None
|
||||||
self._backend_gateway = backend_gateway or build_backend_adapter_from_settings()
|
self._backend_gateway = backend_gateway or build_backend_adapter_from_settings()
|
||||||
self._history_bridge = SessionHistoryBridge(
|
self._history_bridge = SessionHistoryBridge(
|
||||||
history_writer=self._backend_gateway,
|
history_writer=self._backend_gateway,
|
||||||
@@ -121,9 +133,8 @@ class Session:
|
|||||||
# Session state
|
# Session state
|
||||||
self.created_at = None
|
self.created_at = None
|
||||||
self.state = "created" # Legacy call state for /call/lists
|
self.state = "created" # Legacy call state for /call/lists
|
||||||
self.ws_state = WsSessionState.WAIT_HELLO
|
self.ws_state = WsSessionState.WAIT_START
|
||||||
self._pipeline_started = False
|
self._pipeline_started = False
|
||||||
self.protocol_version: Optional[str] = None
|
|
||||||
|
|
||||||
# Track IDs
|
# Track IDs
|
||||||
self.current_track_id: str = self.TRACK_CONTROL
|
self.current_track_id: str = self.TRACK_CONTROL
|
||||||
@@ -137,7 +148,12 @@ class Session:
|
|||||||
self.pipeline.set_event_sequence_provider(self._next_event_seq)
|
self.pipeline.set_event_sequence_provider(self._next_event_seq)
|
||||||
self.pipeline.conversation.on_turn_complete(self._on_turn_complete)
|
self.pipeline.conversation.on_turn_complete(self._on_turn_complete)
|
||||||
|
|
||||||
logger.info(f"Session {self.id} created (duplex={self.use_duplex})")
|
logger.info(
|
||||||
|
"Session {} created (duplex={}, assistant_id={})",
|
||||||
|
self.id,
|
||||||
|
self.use_duplex,
|
||||||
|
self._assistant_id or "-",
|
||||||
|
)
|
||||||
|
|
||||||
async def handle_text(self, text_data: str) -> None:
|
async def handle_text(self, text_data: str) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -222,23 +238,19 @@ class Session:
|
|||||||
msg_type = message.type
|
msg_type = message.type
|
||||||
logger.info(f"Session {self.id} received message: {msg_type}")
|
logger.info(f"Session {self.id} received message: {msg_type}")
|
||||||
|
|
||||||
if isinstance(message, HelloMessage):
|
|
||||||
await self._handle_hello(message)
|
|
||||||
return
|
|
||||||
|
|
||||||
# All messages below require hello handshake first
|
|
||||||
if self.ws_state == WsSessionState.WAIT_HELLO:
|
|
||||||
await self._send_error(
|
|
||||||
"client",
|
|
||||||
"Expected hello message first",
|
|
||||||
"protocol.order",
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
if isinstance(message, SessionStartMessage):
|
if isinstance(message, SessionStartMessage):
|
||||||
await self._handle_session_start(message)
|
await self._handle_session_start(message)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# All messages below require session.start first
|
||||||
|
if self.ws_state == WsSessionState.WAIT_START:
|
||||||
|
await self._send_error(
|
||||||
|
"client",
|
||||||
|
"Expected session.start message first",
|
||||||
|
"protocol.order",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
# All messages below require active session
|
# All messages below require active session
|
||||||
if self.ws_state != WsSessionState.ACTIVE:
|
if self.ws_state != WsSessionState.ACTIVE:
|
||||||
await self._send_error(
|
await self._send_error(
|
||||||
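After this change the state machine starts in `WAIT_START`, and `session.start` is the only message accepted there. The following is a simplified, standalone sketch of that ordering gate. The function `ordering_error` is hypothetical, and the message returned for a non-active session is a placeholder because the original text is cut off by the hunk boundary.

```python
from enum import Enum
from typing import Optional


class WsSessionState(str, Enum):
    WAIT_START = "wait_start"
    ACTIVE = "active"
    STOPPED = "stopped"


def ordering_error(state: WsSessionState, msg_type: str) -> Optional[str]:
    """Return the protocol.order error text for a message, or None if it is allowed."""
    if msg_type == "session.start":
        # session.start is only valid once, while the session is still waiting for it.
        return None if state == WsSessionState.WAIT_START else "Duplicate session.start"
    if state == WsSessionState.WAIT_START:
        return "Expected session.start message first"
    if state != WsSessionState.ACTIVE:
        return "Session is not active"  # placeholder text, not taken from the source
    return None


print(ordering_error(WsSessionState.WAIT_START, "input.text"))
print(ordering_error(WsSessionState.ACTIVE, "input.text"))
```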
@@ -262,67 +274,56 @@ class Session:
|
|||||||
else:
|
else:
|
||||||
await self._send_error("client", f"Unsupported message type: {msg_type}", "protocol.unsupported")
|
await self._send_error("client", f"Unsupported message type: {msg_type}", "protocol.unsupported")
|
||||||
|
|
||||||
async def _handle_hello(self, message: HelloMessage) -> None:
|
|
||||||
"""Handle initial hello/version negotiation."""
|
|
||||||
if self.ws_state != WsSessionState.WAIT_HELLO:
|
|
||||||
await self._send_error("client", "Duplicate hello", "protocol.order")
|
|
||||||
return
|
|
||||||
|
|
||||||
if message.version != settings.ws_protocol_version:
|
|
||||||
await self._send_error(
|
|
||||||
"client",
|
|
||||||
f"Unsupported protocol version '{message.version}'",
|
|
||||||
"protocol.version_unsupported",
|
|
||||||
)
|
|
||||||
await self.transport.close()
|
|
||||||
self.ws_state = WsSessionState.STOPPED
|
|
||||||
return
|
|
||||||
|
|
||||||
self.protocol_version = message.version
|
|
||||||
self.ws_state = WsSessionState.WAIT_START
|
|
||||||
await self._send_event(
|
|
||||||
ev(
|
|
||||||
"hello.ack",
|
|
||||||
version=self.protocol_version,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _handle_session_start(self, message: SessionStartMessage) -> None:
|
async def _handle_session_start(self, message: SessionStartMessage) -> None:
|
||||||
"""Handle explicit session start after successful hello."""
|
"""Handle explicit session start."""
|
||||||
if self.ws_state != WsSessionState.WAIT_START:
|
if self.ws_state != WsSessionState.WAIT_START:
|
||||||
await self._send_error("client", "Duplicate session.start", "protocol.order")
|
await self._send_error("client", "Duplicate session.start", "protocol.order")
|
||||||
return
|
return
|
||||||
|
|
||||||
raw_metadata = message.metadata or {}
|
raw_metadata = message.metadata or {}
|
||||||
workflow_runtime = self._bootstrap_workflow(raw_metadata)
|
if not self._assistant_id:
|
||||||
server_runtime = await self._load_server_runtime_metadata(raw_metadata, workflow_runtime)
|
await self._send_error(
|
||||||
client_runtime = self._sanitize_client_metadata(raw_metadata)
|
"client",
|
||||||
requested_assistant_id = (
|
"Missing required query parameter assistant_id",
|
||||||
workflow_runtime.get("assistantId")
|
"protocol.assistant_id_required",
|
||||||
or raw_metadata.get("assistantId")
|
stage="protocol",
|
||||||
or raw_metadata.get("appId")
|
retryable=False,
|
||||||
or raw_metadata.get("app_id")
|
|
||||||
)
|
)
|
||||||
if server_runtime:
|
await self.transport.close()
|
||||||
logger.info(
|
self.ws_state = WsSessionState.STOPPED
|
||||||
"Session {} loaded trusted runtime config from backend "
|
return
|
||||||
"(requested_assistant_id={}, resolved_assistant_id={}, configVersionId={}, has_services={})",
|
|
||||||
self.id,
|
sanitized_metadata, metadata_error = self._validate_and_sanitize_client_metadata(raw_metadata)
|
||||||
requested_assistant_id,
|
if metadata_error:
|
||||||
server_runtime.get("assistantId"),
|
await self._send_error(
|
||||||
server_runtime.get("configVersionId") or server_runtime.get("config_version_id"),
|
"client",
|
||||||
isinstance(server_runtime.get("services"), dict),
|
metadata_error["message"],
|
||||||
|
metadata_error["code"],
|
||||||
|
stage="protocol",
|
||||||
|
retryable=False,
|
||||||
)
|
)
|
||||||
else:
|
await self.transport.close()
|
||||||
logger.warning(
|
self.ws_state = WsSessionState.STOPPED
|
||||||
"Session {} missing trusted backend runtime config "
|
return
|
||||||
"(requested_assistant_id={}); falling back to engine defaults + safe client overrides",
|
|
||||||
self.id,
|
server_runtime, runtime_error = await self._load_server_runtime_metadata(self._assistant_id)
|
||||||
requested_assistant_id,
|
if runtime_error:
|
||||||
|
await self._send_error(
|
||||||
|
"server",
|
||||||
|
runtime_error["message"],
|
||||||
|
runtime_error["code"],
|
||||||
|
stage="protocol",
|
||||||
|
retryable=False,
|
||||||
)
|
)
|
||||||
metadata = self._merge_runtime_metadata(server_runtime, self._sanitize_untrusted_runtime_metadata(workflow_runtime))
|
await self.transport.close()
|
||||||
metadata = self._merge_runtime_metadata(metadata, client_runtime)
|
self.ws_state = WsSessionState.STOPPED
|
||||||
metadata, dynamic_var_error = self._apply_dynamic_variables(metadata, raw_metadata)
|
return
|
||||||
|
|
||||||
|
metadata = self._merge_runtime_metadata(server_runtime, sanitized_metadata.get("overrides", {}))
|
||||||
|
for key in ("channel", "source", "history"):
|
||||||
|
if key in sanitized_metadata:
|
||||||
|
metadata[key] = sanitized_metadata[key]
|
||||||
|
metadata, dynamic_var_error = self._apply_dynamic_variables(metadata, sanitized_metadata)
|
||||||
if dynamic_var_error:
|
if dynamic_var_error:
|
||||||
await self._send_error(
|
await self._send_error(
|
||||||
"client",
|
"client",
|
||||||
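The new merge order is: trusted server config first, then whitelisted `overrides`, then the `channel`, `source`, and `history` fields copied from the sanitized client metadata. The sketch below illustrates that order. The body of `_merge_runtime_metadata` is not shown in this diff, so the shallow dict merge used here is an assumption, and `merge_runtime` is a hypothetical name.

```python
from typing import Any, Dict


def merge_runtime(server_runtime: Dict[str, Any], sanitized: Dict[str, Any]) -> Dict[str, Any]:
    """Layer client overrides and pass-through fields on top of the server config."""
    # Assumption: a shallow merge where override keys replace server keys.
    metadata = {**server_runtime, **sanitized.get("overrides", {})}
    for key in ("channel", "source", "history"):
        if key in sanitized:
            metadata[key] = sanitized[key]
    return metadata


server = {"assistantId": "a1", "greeting": "Welcome", "systemPrompt": "Be brief"}
client = {"overrides": {"greeting": "Hi there"}, "channel": "web", "history": {"userId": 1}}
print(merge_runtime(server, client))
```

Because forbidden keys such as `assistantId` are rejected before the merge, the client cannot replace the identity that came from the query parameter.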
@@ -331,6 +332,8 @@ class Session:
|
|||||||
stage="protocol",
|
stage="protocol",
|
||||||
retryable=False,
|
retryable=False,
|
||||||
)
|
)
|
||||||
|
await self.transport.close()
|
||||||
|
self.ws_state = WsSessionState.STOPPED
|
||||||
return
|
return
|
||||||
|
|
||||||
# Create history call record early so later turn callbacks can append transcripts.
|
# Create history call record early so later turn callbacks can append transcripts.
|
||||||
@@ -348,7 +351,7 @@ class Session:
|
|||||||
"(assistantId={}, configVersionId={}, output_mode={}, "
|
"(assistantId={}, configVersionId={}, output_mode={}, "
|
||||||
"llm={}/{}, asr={}/{}, tts={}/{}, tts_enabled={})",
|
"llm={}/{}, asr={}/{}, tts={}/{}, tts_enabled={})",
|
||||||
self.id,
|
self.id,
|
||||||
metadata.get("assistantId") or metadata.get("appId") or metadata.get("app_id"),
|
metadata.get("assistantId"),
|
||||||
metadata.get("configVersionId") or metadata.get("config_version_id"),
|
metadata.get("configVersionId") or metadata.get("config_version_id"),
|
||||||
(resolved_preview.get("output") or {}).get("mode") if isinstance(resolved_preview, dict) else None,
|
(resolved_preview.get("output") or {}).get("mode") if isinstance(resolved_preview, dict) else None,
|
||||||
llm_cfg.get("provider"),
|
llm_cfg.get("provider"),
|
||||||
@@ -387,24 +390,6 @@ class Session:
|
|||||||
config=self._build_config_resolved(metadata),
|
config=self._build_config_resolved(metadata),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if self.workflow_runner and self._workflow_initial_node:
|
|
||||||
await self._send_event(
|
|
||||||
ev(
|
|
||||||
"workflow.started",
|
|
||||||
workflowId=self.workflow_runner.workflow_id,
|
|
||||||
workflowName=self.workflow_runner.name,
|
|
||||||
nodeId=self._workflow_initial_node.id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
await self._send_event(
|
|
||||||
ev(
|
|
||||||
"workflow.node.entered",
|
|
||||||
workflowId=self.workflow_runner.workflow_id,
|
|
||||||
nodeId=self._workflow_initial_node.id,
|
|
||||||
nodeName=self._workflow_initial_node.name,
|
|
||||||
nodeType=self._workflow_initial_node.node_type,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Emit opener only after frontend has received session.started/config events.
|
# Emit opener only after frontend has received session.started/config events.
|
||||||
await self.pipeline.emit_initial_greeting()
|
await self.pipeline.emit_initial_greeting()
|
||||||
@@ -658,7 +643,7 @@ class Session:
|
|||||||
def _event_source(self, event_type: str) -> str:
|
def _event_source(self, event_type: str) -> str:
|
||||||
if event_type.startswith("workflow."):
|
if event_type.startswith("workflow."):
|
||||||
return "system"
|
return "system"
|
||||||
if event_type.startswith("session.") or event_type.startswith("hello.") or event_type == "heartbeat":
|
if event_type.startswith("session.") or event_type == "heartbeat":
|
||||||
return "system"
|
return "system"
|
||||||
if event_type == "error":
|
if event_type == "error":
|
||||||
return "system"
|
return "system"
|
||||||
@@ -931,26 +916,41 @@ class Session:
|
|||||||
|
|
||||||
async def _load_server_runtime_metadata(
|
async def _load_server_runtime_metadata(
|
||||||
self,
|
self,
|
||||||
client_metadata: Dict[str, Any],
|
assistant_id: str,
|
||||||
workflow_runtime: Dict[str, Any],
|
) -> tuple[Dict[str, Any], Optional[Dict[str, str]]]:
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""Load trusted runtime metadata from backend assistant config."""
|
"""Load trusted runtime metadata from backend assistant config."""
|
||||||
assistant_id = (
|
if not assistant_id:
|
||||||
workflow_runtime.get("assistantId")
|
return {}, {
|
||||||
or client_metadata.get("assistantId")
|
"code": "protocol.assistant_id_required",
|
||||||
or client_metadata.get("appId")
|
"message": "Missing required query parameter assistant_id",
|
||||||
or client_metadata.get("app_id")
|
}
|
||||||
)
|
|
||||||
if assistant_id is None:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
provider = getattr(self._backend_gateway, "fetch_assistant_config", None)
|
provider = getattr(self._backend_gateway, "fetch_assistant_config", None)
|
||||||
if not callable(provider):
|
if not callable(provider):
|
||||||
return {}
|
return {}, {
|
||||||
|
"code": "assistant.config_unavailable",
|
||||||
|
"message": "Assistant config backend unavailable",
|
||||||
|
}
|
||||||
|
|
||||||
payload = await provider(str(assistant_id).strip())
|
payload = await provider(str(assistant_id).strip())
|
||||||
|
if isinstance(payload, dict):
|
||||||
|
error_code = str(payload.get("__error_code") or "").strip()
|
||||||
|
if error_code == "assistant.not_found":
|
||||||
|
return {}, {
|
||||||
|
"code": "assistant.not_found",
|
||||||
|
"message": f"Assistant not found: {assistant_id}",
|
||||||
|
}
|
||||||
|
if error_code == "assistant.config_unavailable":
|
||||||
|
return {}, {
|
||||||
|
"code": "assistant.config_unavailable",
|
||||||
|
"message": f"Assistant config unavailable: {assistant_id}",
|
||||||
|
}
|
||||||
|
|
||||||
if not isinstance(payload, dict):
|
if not isinstance(payload, dict):
|
||||||
return {}
|
return {}, {
|
||||||
|
"code": "assistant.config_unavailable",
|
||||||
|
"message": f"Assistant config unavailable: {assistant_id}",
|
||||||
|
}
|
||||||
|
|
||||||
assistant_cfg: Dict[str, Any] = {}
|
assistant_cfg: Dict[str, Any] = {}
|
||||||
session_start_cfg = payload.get("sessionStartMetadata")
|
session_start_cfg = payload.get("sessionStartMetadata")
|
||||||
@@ -962,7 +962,10 @@ class Session:
|
|||||||
assistant_cfg = payload
|
assistant_cfg = payload
|
||||||
|
|
||||||
if not isinstance(assistant_cfg, dict):
|
if not isinstance(assistant_cfg, dict):
|
||||||
return {}
|
return {}, {
|
||||||
|
"code": "assistant.config_unavailable",
|
||||||
|
"message": f"Assistant config unavailable: {assistant_id}",
|
||||||
|
}
|
||||||
|
|
||||||
runtime: Dict[str, Any] = {}
|
runtime: Dict[str, Any] = {}
|
||||||
passthrough_keys = {
|
passthrough_keys = {
|
||||||
@@ -1009,36 +1012,110 @@ class Session:
|
|||||||
|
|
||||||
if runtime.get("config_version_id") is not None and runtime.get("configVersionId") is None:
|
if runtime.get("config_version_id") is not None and runtime.get("configVersionId") is None:
|
||||||
runtime["configVersionId"] = runtime.get("config_version_id")
|
runtime["configVersionId"] = runtime.get("config_version_id")
|
||||||
return runtime
|
return runtime, None
|
||||||
|
|
||||||
def _sanitize_untrusted_runtime_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
|
def _find_forbidden_secret_key(self, payload: Any, path: str = "metadata") -> Optional[str]:
|
||||||
"""
|
if isinstance(payload, dict):
|
||||||
Sanitize untrusted metadata sources.
|
for key, value in payload.items():
|
||||||
|
key_str = str(key)
|
||||||
|
normalized = key_str.lower().replace("_", "").replace("-", "")
|
||||||
|
if any(token in normalized for token in self._METADATA_FORBIDDEN_KEY_TOKENS):
|
||||||
|
return f"{path}.{key_str}"
|
||||||
|
nested = self._find_forbidden_secret_key(value, f"{path}.{key_str}")
|
||||||
|
if nested:
|
||||||
|
return nested
|
||||||
|
return None
|
||||||
|
if isinstance(payload, list):
|
||||||
|
for idx, value in enumerate(payload):
|
||||||
|
nested = self._find_forbidden_secret_key(value, f"{path}[{idx}]")
|
||||||
|
if nested:
|
||||||
|
return nested
|
||||||
|
return None
|
||||||
|
|
||||||
This keeps only a small override whitelist and stable config ID fields.
|
def _validate_and_sanitize_client_metadata(
|
||||||
"""
|
self,
|
||||||
|
metadata: Dict[str, Any],
|
||||||
|
) -> tuple[Dict[str, Any], Optional[Dict[str, str]]]:
|
||||||
if not isinstance(metadata, dict):
|
if not isinstance(metadata, dict):
|
||||||
return {}
|
return {}, {
|
||||||
|
"code": "protocol.invalid_override",
|
||||||
|
"message": "metadata must be an object",
|
||||||
|
}
|
||||||
|
|
||||||
sanitized: Dict[str, Any] = {}
|
forbidden_top_level = [key for key in self._METADATA_FORBIDDEN_TOP_LEVEL_KEYS if key in metadata]
|
||||||
for key in self._CLIENT_METADATA_ID_KEYS:
|
if forbidden_top_level:
|
||||||
if key in metadata:
|
return {}, {
|
||||||
sanitized[key] = metadata[key]
|
"code": "protocol.invalid_override",
|
||||||
for key in self._CLIENT_METADATA_OVERRIDES:
|
"message": f"Forbidden metadata keys: {', '.join(sorted(forbidden_top_level))}",
|
||||||
if key in metadata:
|
}
|
||||||
sanitized[key] = metadata[key]
|
|
||||||
|
|
||||||
return sanitized
|
unknown_keys = [
|
||||||
|
key
|
||||||
|
for key in metadata.keys()
|
||||||
|
if key not in self._METADATA_ALLOWED_TOP_LEVEL_KEYS
|
||||||
|
and key not in self._METADATA_FORBIDDEN_TOP_LEVEL_KEYS
|
||||||
|
]
|
||||||
|
if unknown_keys:
|
||||||
|
return {}, {
|
||||||
|
"code": "protocol.invalid_override",
|
||||||
|
"message": f"Unsupported metadata keys: {', '.join(sorted(unknown_keys))}",
|
||||||
|
}
|
||||||
|
|
||||||
def _sanitize_client_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
|
if "workflow" in metadata:
|
||||||
"""Apply client metadata whitelist and remove forbidden secrets."""
|
logger.warning("Session {} received metadata.workflow; workflow payload is ignored in MVP", self.id)
|
||||||
sanitized = self._sanitize_untrusted_runtime_metadata(metadata)
|
|
||||||
if isinstance(metadata.get("services"), dict):
|
overrides_raw = metadata.get("overrides")
|
||||||
logger.warning(
|
overrides: Dict[str, Any] = {}
|
||||||
"Session {} provided metadata.services from client; client-side service config is ignored",
|
if overrides_raw is not None:
|
||||||
self.id,
|
if not isinstance(overrides_raw, dict):
|
||||||
)
|
return {}, {
|
||||||
return sanitized
|
"code": "protocol.invalid_override",
|
||||||
|
"message": "metadata.overrides must be an object",
|
||||||
|
}
|
||||||
|
unsupported_override_keys = [key for key in overrides_raw.keys() if key not in self._METADATA_ALLOWED_OVERRIDE_KEYS]
|
||||||
|
if unsupported_override_keys:
|
||||||
|
return {}, {
|
||||||
|
"code": "protocol.invalid_override",
|
||||||
|
"message": f"Unsupported metadata.overrides keys: {', '.join(sorted(unsupported_override_keys))}",
|
||||||
|
}
|
||||||
|
overrides = dict(overrides_raw)
|
||||||
|
|
||||||
|
dynamic_variables = metadata.get("dynamicVariables")
|
||||||
|
history_raw = metadata.get("history")
|
||||||
|
history: Dict[str, Any] = {}
|
||||||
|
if history_raw is not None:
|
||||||
|
if not isinstance(history_raw, dict):
|
||||||
|
return {}, {
|
||||||
|
"code": "protocol.invalid_override",
|
||||||
|
"message": "metadata.history must be an object",
|
||||||
|
}
|
||||||
|
unsupported_history_keys = [key for key in history_raw.keys() if key != "userId"]
|
||||||
|
if unsupported_history_keys:
|
||||||
|
return {}, {
|
||||||
|
"code": "protocol.invalid_override",
|
||||||
|
"message": f"Unsupported metadata.history keys: {', '.join(sorted(unsupported_history_keys))}",
|
||||||
|
}
|
||||||
|
if "userId" in history_raw:
|
||||||
|
history["userId"] = history_raw.get("userId")
|
||||||
|
|
||||||
|
sanitized: Dict[str, Any] = {"overrides": overrides}
|
||||||
|
if dynamic_variables is not None:
|
||||||
|
sanitized["dynamicVariables"] = dynamic_variables
|
||||||
|
if "channel" in metadata:
|
||||||
|
sanitized["channel"] = metadata.get("channel")
|
||||||
|
if "source" in metadata:
|
||||||
|
sanitized["source"] = metadata.get("source")
|
||||||
|
if history:
|
||||||
|
sanitized["history"] = history
|
||||||
|
|
||||||
|
forbidden_path = self._find_forbidden_secret_key(sanitized)
|
||||||
|
if forbidden_path:
|
||||||
|
return {}, {
|
||||||
|
"code": "protocol.invalid_override",
|
||||||
|
"message": f"Forbidden secret-like key detected at {forbidden_path}",
|
||||||
|
}
|
||||||
|
|
||||||
|
return sanitized, None
|
||||||
|
|
||||||
def _build_config_resolved(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
|
def _build_config_resolved(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""Build public resolved config payload (secrets removed)."""
|
"""Build public resolved config payload (secrets removed)."""
|
||||||
@@ -1047,7 +1124,7 @@ class Session:
|
|||||||
runtime = self.pipeline.resolved_runtime_config()
|
runtime = self.pipeline.resolved_runtime_config()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"appId": metadata.get("appId") or metadata.get("app_id") or metadata.get("assistantId"),
|
"appId": metadata.get("assistantId"),
|
||||||
"channel": metadata.get("channel"),
|
"channel": metadata.get("channel"),
|
||||||
"configVersionId": metadata.get("configVersionId") or metadata.get("config_version_id"),
|
"configVersionId": metadata.get("configVersionId") or metadata.get("config_version_id"),
|
||||||
"prompt": {"sha256": prompt_hash},
|
"prompt": {"sha256": prompt_hash},
|
||||||
|
|||||||
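Review note: the metadata checks added to `Session` above can be condensed into a standalone sketch. The key sets are copied from the rules documented in this commit. The function and constant names are illustrative, not the engine's real identifiers, and the exact matching rule for secret-like keys (case-insensitive, whole key name) is an assumption.

```python
from typing import Any, Dict, Optional, Tuple

# Key sets copied from the rules documented in this commit; names are illustrative.
ALLOWED_TOP_LEVEL = {"overrides", "dynamicVariables", "channel", "source", "history", "workflow"}
FORBIDDEN_TOP_LEVEL = {"assistantId", "appId", "app_id", "configVersionId", "config_version_id"}
ALLOWED_OVERRIDES = {
    "systemPrompt", "greeting", "firstTurnMode", "generatedOpenerEnabled", "output",
    "bargeIn", "knowledgeBaseId", "knowledge", "tools", "openerAudio",
}
# Assumption: secret-like keys are matched case-insensitively on the whole key name.
SECRET_LIKE = {"apikey", "token", "secret", "password", "authorization"}

Result = Tuple[Dict[str, Any], Optional[Dict[str, str]]]


def _invalid(message: str) -> Result:
    return {}, {"code": "protocol.invalid_override", "message": message}


def find_secret_key(value: Any, path: str = "metadata") -> Optional[str]:
    """Return the path of the first secret-like key found, or None."""
    if isinstance(value, dict):
        for key, child in value.items():
            child_path = f"{path}.{key}"
            if str(key).lower() in SECRET_LIKE:
                return child_path
            found = find_secret_key(child, child_path)
            if found:
                return found
    elif isinstance(value, list):
        for index, child in enumerate(value):
            found = find_secret_key(child, f"{path}[{index}]")
            if found:
                return found
    return None


def validate_metadata(metadata: Dict[str, Any]) -> Result:
    """Validate client metadata and return (sanitized, error)."""
    forbidden = sorted(k for k in metadata if k in FORBIDDEN_TOP_LEVEL)
    if forbidden:
        return _invalid(f"Forbidden metadata keys: {', '.join(forbidden)}")
    unknown = sorted(k for k in metadata if k not in ALLOWED_TOP_LEVEL)
    if unknown:
        return _invalid(f"Unsupported metadata keys: {', '.join(unknown)}")

    overrides = metadata.get("overrides")
    if overrides is None:
        overrides = {}
    if not isinstance(overrides, dict):
        return _invalid("metadata.overrides must be an object")
    bad = sorted(k for k in overrides if k not in ALLOWED_OVERRIDES)
    if bad:
        return _invalid(f"Unsupported metadata.overrides keys: {', '.join(bad)}")

    history = metadata.get("history")
    if history is None:
        history = {}
    if not isinstance(history, dict):
        return _invalid("metadata.history must be an object")
    bad = sorted(k for k in history if k != "userId")
    if bad:
        return _invalid(f"Unsupported metadata.history keys: {', '.join(bad)}")

    # `workflow` is accepted at the top level but dropped from the result.
    sanitized: Dict[str, Any] = {"overrides": dict(overrides)}
    if metadata.get("dynamicVariables") is not None:
        sanitized["dynamicVariables"] = metadata["dynamicVariables"]
    for key in ("channel", "source"):
        if key in metadata:
            sanitized[key] = metadata[key]
    if "userId" in history:
        sanitized["history"] = {"userId": history["userId"]}

    secret_path = find_secret_key(sanitized)
    if secret_path:
        return _invalid(f"Forbidden secret-like key detected at {secret_path}")
    return sanitized, None
```

Every rejection uses the same `protocol.invalid_override` code, so a client can branch on the code and show the message for details.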
@@ -5,7 +5,7 @@ This document defines the public WebSocket protocol for the `/ws` endpoint.
 Validation policy:
 - WS v1 JSON control messages are validated strictly.
 - Unknown top-level fields are rejected for all defined client message types.
-- `hello.version` is fixed to `"v1"`.
+- `assistant_id` query parameter is required on `/ws`.
 
 ## Transport
 
@@ -17,29 +17,16 @@ Validation policy:
 
 Required message order:
 
-1. Client sends `hello`.
-2. Server replies `hello.ack`.
-3. Client sends `session.start`.
-4. Server replies `session.started`.
-5. Client may stream binary audio and/or send `input.text`.
-6. Client sends `session.stop` (or closes socket).
+1. Client connects to `/ws?assistant_id=<id>`.
+2. Client sends `session.start`.
+3. Server replies `session.started`.
+4. Client may stream binary audio and/or send `input.text`.
+5. Client sends `session.stop` (or closes socket).
 
 If order is violated, server emits `error` with `code = "protocol.order"`.
 
 ## Client -> Server Messages
 
-### `hello`
-
-```json
-{
-  "type": "hello",
-  "version": "v1"
-}
-```
-
-Rules:
-- `version` must be `v1`.
-
 ### `session.start`
 
 ```json
@@ -51,15 +38,18 @@ Rules:
     "channels": 1
   },
   "metadata": {
-    "appId": "assistant_123",
     "channel": "web",
-    "configVersionId": "cfg_20260217_01",
-    "client": "web-debug",
-    "output": {
-      "mode": "audio"
-    },
-    "systemPrompt": "You are concise.",
-    "greeting": "Hi, how can I help?",
+    "source": "web-debug",
+    "history": {
+      "userId": 1
+    },
+    "overrides": {
+      "output": {
+        "mode": "audio"
+      },
+      "systemPrompt": "You are concise.",
+      "greeting": "Hi, how can I help?"
+    },
     "dynamicVariables": {
       "customer_name": "Alice",
       "plan_tier": "Pro"
@@ -69,9 +59,13 @@ Rules:
 ```
 
 Rules:
-- Client-side `metadata.services` is ignored.
-- Service config (including secrets) is resolved server-side (env/backend).
-- Client should pass stable IDs (`appId`, `channel`, `configVersionId`) plus small runtime overrides (e.g. `output`, `bargeIn`, greeting/prompt style hints).
+- Assistant config is resolved strictly by URL query `assistant_id`.
+- `metadata` top-level keys allowed: `overrides`, `dynamicVariables`, `channel`, `source`, `history`, `workflow` (`workflow` is ignored).
+- `metadata.overrides` whitelist: `systemPrompt`, `greeting`, `firstTurnMode`, `generatedOpenerEnabled`, `output`, `bargeIn`, `knowledgeBaseId`, `knowledge`, `tools`, `openerAudio`.
+- `metadata.services` is rejected with `protocol.invalid_override`.
+- `metadata.workflow` is ignored in this MVP protocol version.
+- Top-level IDs are forbidden in payload (`assistantId`, `appId`, `app_id`, `configVersionId`, `config_version_id`).
+- Secret-like keys are forbidden in metadata (`apiKey`, `token`, `secret`, `password`, `authorization`).
 - `metadata.dynamicVariables` is optional and must be an object of string key/value pairs.
   - Key pattern: `^[a-zA-Z_][a-zA-Z0-9_]{0,63}$`
   - Max entries: 30
@@ -85,7 +79,7 @@ Rules:
 - Invalid `dynamicVariables` payload rejects `session.start` with `protocol.dynamic_variables_invalid`.
 
 Text-only mode:
-- Set `metadata.output.mode = "text"`.
+- Set `metadata.overrides.output.mode = "text"`.
 - In this mode server still sends `assistant.response.delta/final`, but will not emit audio frames or `output.audio.start/end`.
 
 ### `input.text`
@@ -158,8 +152,6 @@ Envelope notes:
 
 Common events:
 
-- `hello.ack`
-  - Fields: `sessionId`, `version`
 - `session.started`
   - Fields: `sessionId`, `trackId`, `tracks`, `audio`
 - `config.resolved`
@@ -204,7 +196,7 @@ Common events:
 Track IDs (MVP fixed values):
 - `audio_in`: ASR/VAD input-side events (`input.*`, `transcript.*`)
 - `audio_out`: assistant output-side events (`assistant.*`, `output.audio.*`, `response.interrupted`, `metrics.ttfb`)
-- `control`: session/control events (`session.*`, `hello.*`, `error`, `config.resolved`)
+- `control`: session/control events (`session.*`, `error`, `config.resolved`)
 
 Correlation IDs (`event.data`):
 - `turn_id`: one user-assistant interaction turn.
||||||
|
|||||||
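Review note: the `metadata.dynamicVariables` rules in the protocol document above (object of string key/value pairs, key pattern, at most 30 entries) could be checked roughly as follows. This is a sketch covering only the rules visible in this diff. The function name and the error strings are hypothetical, and the real validator may enforce further limits that are not shown here.

```python
import re
from typing import Any, List

# Pattern and limit are taken from the protocol document; names are illustrative.
KEY_PATTERN = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]{0,63}$")
MAX_ENTRIES = 30


def dynamic_variable_errors(value: Any) -> List[str]:
    """Return a list of problems; an empty list means the payload looks valid."""
    if value is None:
        return []  # the field is optional
    if not isinstance(value, dict):
        return ["dynamicVariables must be an object"]

    errors: List[str] = []
    if len(value) > MAX_ENTRIES:
        errors.append(f"too many entries: {len(value)} > {MAX_ENTRIES}")
    for key, item in value.items():
        # fullmatch also rejects keys with a trailing newline.
        if not isinstance(key, str) or not KEY_PATTERN.fullmatch(key):
            errors.append(f"invalid key: {key!r}")
        elif not isinstance(item, str):
            errors.append(f"value for {key!r} must be a string")
    return errors
```

Per the protocol document, a server that finds any problem rejects `session.start` with `protocol.dynamic_variables_invalid`. Running the same check on the client before connecting avoids a wasted round trip.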
@@ -23,6 +23,7 @@ import time
|
|||||||
import threading
|
import threading
|
||||||
import queue
|
import queue
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -59,9 +60,8 @@ class MicrophoneClient:
|
|||||||
url: str,
|
url: str,
|
||||||
sample_rate: int = 16000,
|
sample_rate: int = 16000,
|
||||||
chunk_duration_ms: int = 20,
|
chunk_duration_ms: int = 20,
|
||||||
app_id: str = "assistant_demo",
|
assistant_id: str = "assistant_demo",
|
||||||
channel: str = "mic_client",
|
channel: str = "mic_client",
|
||||||
config_version_id: str = "local-dev",
|
|
||||||
input_device: int = None,
|
input_device: int = None,
|
||||||
output_device: int = None,
|
output_device: int = None,
|
||||||
track_debug: bool = False,
|
track_debug: bool = False,
|
||||||
@@ -80,9 +80,8 @@ class MicrophoneClient:
|
|||||||
self.sample_rate = sample_rate
|
self.sample_rate = sample_rate
|
||||||
self.chunk_duration_ms = chunk_duration_ms
|
self.chunk_duration_ms = chunk_duration_ms
|
||||||
self.chunk_samples = int(sample_rate * chunk_duration_ms / 1000)
|
self.chunk_samples = int(sample_rate * chunk_duration_ms / 1000)
|
||||||
self.app_id = app_id
|
self.assistant_id = assistant_id
|
||||||
self.channel = channel
|
self.channel = channel
|
||||||
self.config_version_id = config_version_id
|
|
||||||
self.input_device = input_device
|
self.input_device = input_device
|
||||||
self.output_device = output_device
|
self.output_device = output_device
|
||||||
self.track_debug = track_debug
|
self.track_debug = track_debug
|
||||||
@@ -126,18 +125,20 @@ class MicrophoneClient:
|
|||||||
parts.append(f"{key}={value}")
|
parts.append(f"{key}={value}")
|
||||||
return f" [{' '.join(parts)}]" if parts else ""
|
return f" [{' '.join(parts)}]" if parts else ""
|
||||||
|
|
||||||
|
def _session_url(self) -> str:
|
||||||
|
parts = urlsplit(self.url)
|
||||||
|
query = dict(parse_qsl(parts.query, keep_blank_values=True))
|
||||||
|
query["assistant_id"] = self.assistant_id
|
||||||
|
return urlunsplit((parts.scheme, parts.netloc, parts.path, urlencode(query), parts.fragment))
|
||||||
|
|
||||||
async def connect(self) -> None:
|
async def connect(self) -> None:
|
||||||
"""Connect to WebSocket server."""
|
"""Connect to WebSocket server."""
|
||||||
print(f"Connecting to {self.url}...")
|
session_url = self._session_url()
|
||||||
self.ws = await websockets.connect(self.url)
|
print(f"Connecting to {session_url}...")
|
||||||
|
self.ws = await websockets.connect(session_url)
|
||||||
self.running = True
|
self.running = True
|
||||||
print("Connected!")
|
print("Connected!")
|
||||||
|
|
||||||
# WS v1 handshake: hello -> session.start
|
|
||||||
await self.send_command({
|
|
||||||
"type": "hello",
|
|
||||||
"version": "v1",
|
|
||||||
})
|
|
||||||
await self.send_command({
|
await self.send_command({
|
||||||
"type": "session.start",
|
"type": "session.start",
|
||||||
"audio": {
|
"audio": {
|
||||||
@@ -146,9 +147,8 @@ class MicrophoneClient:
|
|||||||
"channels": 1,
|
"channels": 1,
|
||||||
},
|
},
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"appId": self.app_id,
|
|
||||||
"channel": self.channel,
|
"channel": self.channel,
|
||||||
"configVersionId": self.config_version_id,
|
"source": "mic_client",
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -330,7 +330,7 @@ class MicrophoneClient:
|
|||||||
if self.track_debug:
|
if self.track_debug:
|
||||||
print(f"[track-debug] event={event_type} trackId={event.get('trackId')}{ids}")
|
print(f"[track-debug] event={event_type} trackId={event.get('trackId')}{ids}")
|
||||||
|
|
||||||
if event_type in {"hello.ack", "session.started"}:
|
if event_type == "session.started":
|
||||||
print(f"← Session ready!{ids}")
|
print(f"← Session ready!{ids}")
|
||||||
elif event_type == "config.resolved":
|
elif event_type == "config.resolved":
|
||||||
print(f"← Config resolved: {event.get('config', {}).get('output', {})}{ids}")
|
print(f"← Config resolved: {event.get('config', {}).get('output', {})}{ids}")
|
||||||
@@ -609,20 +609,15 @@ async def main():
|
|||||||
help="Show streaming LLM response chunks"
|
help="Show streaming LLM response chunks"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--app-id",
|
"--assistant-id",
|
||||||
default="assistant_demo",
|
default="assistant_demo",
|
||||||
help="Stable app/assistant identifier for server-side config lookup"
|
help="Assistant identifier used in websocket query parameter"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--channel",
|
"--channel",
|
||||||
default="mic_client",
|
default="mic_client",
|
||||||
help="Client channel name"
|
help="Client channel name"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--config-version-id",
|
|
||||||
default="local-dev",
|
|
||||||
help="Optional config version identifier"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--track-debug",
|
"--track-debug",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
@@ -638,9 +633,8 @@ async def main():
|
|||||||
client = MicrophoneClient(
|
client = MicrophoneClient(
|
||||||
url=args.url,
|
url=args.url,
|
||||||
sample_rate=args.sample_rate,
|
sample_rate=args.sample_rate,
|
||||||
app_id=args.app_id,
|
assistant_id=args.assistant_id,
|
||||||
channel=args.channel,
|
channel=args.channel,
|
||||||
config_version_id=args.config_version_id,
|
|
||||||
input_device=args.input_device,
|
input_device=args.input_device,
|
||||||
output_device=args.output_device,
|
output_device=args.output_device,
|
||||||
track_debug=args.track_debug,
|
track_debug=args.track_debug,
|
||||||
|
|||||||
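Review note: the `_session_url` helper added to the example clients can be tried on its own. The sketch below lifts the same `urllib.parse` calls into a free function. The name `session_url` and the sample URLs are illustrative only.

```python
from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit


def session_url(base_url: str, assistant_id: str) -> str:
    """Return base_url with assistant_id set (or replaced) in the query string."""
    parts = urlsplit(base_url)
    # dict() keeps one value per key, so repeated query keys collapse to the last one.
    query = dict(parse_qsl(parts.query, keep_blank_values=True))
    query["assistant_id"] = assistant_id
    return urlunsplit((parts.scheme, parts.netloc, parts.path, urlencode(query), parts.fragment))


print(session_url("ws://localhost:8000/ws", "assistant_demo"))
# ws://localhost:8000/ws?assistant_id=assistant_demo
print(session_url("ws://localhost:8000/ws?debug=1&assistant_id=old", "assistant 42"))
# ws://localhost:8000/ws?debug=1&assistant_id=assistant+42
```

Existing query parameters are preserved, an `assistant_id` already present in the URL is overwritten, and `urlencode` escapes the value, so IDs with spaces or reserved characters remain valid in the URL.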
@@ -15,6 +15,7 @@ import sys
|
|||||||
import time
|
import time
|
||||||
import wave
|
import wave
|
||||||
import io
|
import io
|
||||||
|
from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -56,16 +57,14 @@ class SimpleVoiceClient:
|
|||||||
self,
|
self,
|
||||||
url: str,
|
url: str,
|
||||||
sample_rate: int = 16000,
|
sample_rate: int = 16000,
|
||||||
app_id: str = "assistant_demo",
|
assistant_id: str = "assistant_demo",
|
||||||
channel: str = "simple_client",
|
channel: str = "simple_client",
|
||||||
config_version_id: str = "local-dev",
|
|
||||||
track_debug: bool = False,
|
track_debug: bool = False,
|
||||||
):
|
):
|
||||||
self.url = url
|
self.url = url
|
||||||
self.sample_rate = sample_rate
|
self.sample_rate = sample_rate
|
||||||
self.app_id = app_id
|
self.assistant_id = assistant_id
|
||||||
self.channel = channel
|
self.channel = channel
|
||||||
self.config_version_id = config_version_id
|
|
||||||
self.track_debug = track_debug
|
self.track_debug = track_debug
|
||||||
self.ws = None
|
self.ws = None
|
||||||
self.running = False
|
self.running = False
|
||||||
@@ -88,6 +87,12 @@ class SimpleVoiceClient:
|
|||||||
# Interrupt handling - discard audio until next trackStart
|
# Interrupt handling - discard audio until next trackStart
|
||||||
self._discard_audio = False
|
self._discard_audio = False
|
||||||
|
|
||||||
|
def _session_url(self) -> str:
|
||||||
|
parts = urlsplit(self.url)
|
||||||
|
query = dict(parse_qsl(parts.query, keep_blank_values=True))
|
||||||
|
query["assistant_id"] = self.assistant_id
|
||||||
|
return urlunsplit((parts.scheme, parts.netloc, parts.path, urlencode(query), parts.fragment))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _event_ids_suffix(event: dict) -> str:
|
def _event_ids_suffix(event: dict) -> str:
|
||||||
data = event.get("data") if isinstance(event.get("data"), dict) else {}
|
data = event.get("data") if isinstance(event.get("data"), dict) else {}
|
||||||
@@ -101,16 +106,12 @@ class SimpleVoiceClient:
|
|||||||
|
|
||||||
async def connect(self):
|
async def connect(self):
|
||||||
"""Connect to server."""
|
"""Connect to server."""
|
||||||
print(f"Connecting to {self.url}...")
|
session_url = self._session_url()
|
||||||
self.ws = await websockets.connect(self.url)
|
print(f"Connecting to {session_url}...")
|
||||||
|
self.ws = await websockets.connect(session_url)
|
||||||
self.running = True
|
self.running = True
|
||||||
print("Connected!")
|
print("Connected!")
|
||||||
|
|
||||||
# WS v1 handshake: hello -> session.start
|
|
||||||
await self.ws.send(json.dumps({
|
|
||||||
"type": "hello",
|
|
||||||
"version": "v1",
|
|
||||||
}))
|
|
||||||
await self.ws.send(json.dumps({
|
await self.ws.send(json.dumps({
|
||||||
"type": "session.start",
|
"type": "session.start",
|
||||||
"audio": {
|
"audio": {
|
||||||
@@ -119,12 +120,11 @@ class SimpleVoiceClient:
|
|||||||
"channels": 1,
|
"channels": 1,
|
||||||
},
|
},
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"appId": self.app_id,
|
|
||||||
"channel": self.channel,
|
"channel": self.channel,
|
||||||
"configVersionId": self.config_version_id,
|
"source": "simple_client",
|
||||||
},
|
},
|
||||||
}))
|
}))
|
||||||
print("-> hello/session.start")
|
print("-> session.start")
|
||||||
|
|
||||||
async def send_chat(self, text: str):
|
async def send_chat(self, text: str):
|
||||||
"""Send chat message."""
|
"""Send chat message."""
|
||||||
@@ -311,9 +311,8 @@ async def main():
|
|||||||
parser.add_argument("--text", help="Send text and play response")
|
parser.add_argument("--text", help="Send text and play response")
|
||||||
parser.add_argument("--list-devices", action="store_true")
|
parser.add_argument("--list-devices", action="store_true")
|
||||||
parser.add_argument("--sample-rate", type=int, default=16000)
|
parser.add_argument("--sample-rate", type=int, default=16000)
|
||||||
parser.add_argument("--app-id", default="assistant_demo")
|
parser.add_argument("--assistant-id", default="assistant_demo")
|
||||||
parser.add_argument("--channel", default="simple_client")
|
parser.add_argument("--channel", default="simple_client")
|
||||||
parser.add_argument("--config-version-id", default="local-dev")
|
|
||||||
parser.add_argument("--track-debug", action="store_true")
|
parser.add_argument("--track-debug", action="store_true")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@@ -325,9 +324,8 @@ async def main():
|
|||||||
client = SimpleVoiceClient(
|
client = SimpleVoiceClient(
|
||||||
args.url,
|
args.url,
|
||||||
args.sample_rate,
|
args.sample_rate,
|
||||||
app_id=args.app_id,
|
assistant_id=args.assistant_id,
|
||||||
channel=args.channel,
|
channel=args.channel,
|
||||||
config_version_id=args.config_version_id,
|
|
||||||
track_debug=args.track_debug,
|
track_debug=args.track_debug,
|
||||||
)
|
)
|
||||||
await client.run(args.text)
|
await client.run(args.text)
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import math
|
|||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit
|
||||||
|
|
||||||
# Configuration
|
# Configuration
|
||||||
SERVER_URL = "ws://localhost:8000/ws"
|
SERVER_URL = "ws://localhost:8000/ws"
|
||||||
@@ -130,14 +131,17 @@ async def run_client(url, file_path=None, use_sine=False, track_debug: bool = Fa
|
|||||||
"""Run the WebSocket test client."""
|
"""Run the WebSocket test client."""
|
||||||
session = aiohttp.ClientSession()
|
session = aiohttp.ClientSession()
|
||||||
try:
|
try:
|
||||||
print(f"🔌 Connecting to {url}...")
|
parts = urlsplit(url)
|
||||||
async with session.ws_connect(url) as ws:
|
query = dict(parse_qsl(parts.query, keep_blank_values=True))
|
||||||
|
query["assistant_id"] = str(query.get("assistant_id") or "assistant_demo")
|
||||||
|
session_url = urlunsplit((parts.scheme, parts.netloc, parts.path, urlencode(query), parts.fragment))
|
||||||
|
print(f"🔌 Connecting to {session_url}...")
|
||||||
|
async with session.ws_connect(session_url) as ws:
|
||||||
print("✅ Connected!")
|
print("✅ Connected!")
|
||||||
session_ready = asyncio.Event()
|
session_ready = asyncio.Event()
|
||||||
recv_task = asyncio.create_task(receive_loop(ws, session_ready, track_debug=track_debug))
|
recv_task = asyncio.create_task(receive_loop(ws, session_ready, track_debug=track_debug))
|
||||||
|
|
||||||
# Send v1 hello + session.start handshake
|
# Send v1 session.start initialization
|
||||||
await ws.send_json({"type": "hello", "version": "v1"})
|
|
||||||
await ws.send_json({
|
await ws.send_json({
|
||||||
"type": "session.start",
|
"type": "session.start",
|
||||||
"audio": {
|
"audio": {
|
||||||
@@ -146,12 +150,11 @@ async def run_client(url, file_path=None, use_sine=False, track_debug: bool = Fa
|
|||||||
"channels": 1
|
"channels": 1
|
||||||
},
|
},
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"appId": "assistant_demo",
|
|
||||||
"channel": "test_websocket",
|
"channel": "test_websocket",
|
||||||
"configVersionId": "local-dev",
|
"source": "test_websocket",
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
print("📤 Sent v1 hello/session.start")
|
print("📤 Sent v1 session.start")
|
||||||
await asyncio.wait_for(session_ready.wait(), timeout=8)
|
await asyncio.wait_for(session_ready.wait(), timeout=8)
|
||||||
|
|
||||||
# Select sender based on args
|
# Select sender based on args
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import sys
|
|||||||
import time
|
import time
|
||||||
import wave
|
import wave
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -57,9 +58,8 @@ class WavFileClient:
|
|||||||
url: str,
|
url: str,
|
||||||
input_file: str,
|
input_file: str,
|
||||||
output_file: str,
|
output_file: str,
|
||||||
app_id: str = "assistant_demo",
|
assistant_id: str = "assistant_demo",
|
||||||
channel: str = "wav_client",
|
channel: str = "wav_client",
|
||||||
config_version_id: str = "local-dev",
|
|
||||||
sample_rate: int = 16000,
|
sample_rate: int = 16000,
|
||||||
chunk_duration_ms: int = 20,
|
chunk_duration_ms: int = 20,
|
||||||
wait_time: float = 15.0,
|
wait_time: float = 15.0,
|
||||||
@@ -82,9 +82,8 @@ class WavFileClient:
|
|||||||
self.url = url
|
self.url = url
|
||||||
self.input_file = Path(input_file)
|
self.input_file = Path(input_file)
|
||||||
self.output_file = Path(output_file)
|
self.output_file = Path(output_file)
|
||||||
self.app_id = app_id
|
self.assistant_id = assistant_id
|
||||||
self.channel = channel
|
self.channel = channel
|
||||||
self.config_version_id = config_version_id
|
|
||||||
self.sample_rate = sample_rate
|
self.sample_rate = sample_rate
|
||||||
self.chunk_duration_ms = chunk_duration_ms
|
self.chunk_duration_ms = chunk_duration_ms
|
||||||
self.chunk_samples = int(sample_rate * chunk_duration_ms / 1000)
|
self.chunk_samples = int(sample_rate * chunk_duration_ms / 1000)
|
||||||
@@ -148,18 +147,20 @@ class WavFileClient:
|
|||||||
parts.append(f"{key}={value}")
|
parts.append(f"{key}={value}")
|
||||||
return f" [{' '.join(parts)}]" if parts else ""
|
return f" [{' '.join(parts)}]" if parts else ""
|
||||||
|
|
||||||
|
def _session_url(self) -> str:
|
||||||
|
parts = urlsplit(self.url)
|
||||||
|
query = dict(parse_qsl(parts.query, keep_blank_values=True))
|
||||||
|
query["assistant_id"] = self.assistant_id
|
||||||
|
return urlunsplit((parts.scheme, parts.netloc, parts.path, urlencode(query), parts.fragment))
|
||||||
|
|
||||||
async def connect(self) -> None:
|
async def connect(self) -> None:
|
||||||
"""Connect to WebSocket server."""
|
"""Connect to WebSocket server."""
|
||||||
self.log_event("→", f"Connecting to {self.url}...")
|
session_url = self._session_url()
|
||||||
self.ws = await websockets.connect(self.url)
|
self.log_event("→", f"Connecting to {session_url}...")
|
||||||
|
self.ws = await websockets.connect(session_url)
|
||||||
self.running = True
|
self.running = True
|
||||||
self.log_event("←", "Connected!")
|
self.log_event("←", "Connected!")
|
||||||
|
|
||||||
# WS v1 handshake: hello -> session.start
|
|
||||||
await self.send_command({
|
|
||||||
"type": "hello",
|
|
||||||
"version": "v1",
|
|
||||||
})
|
|
||||||
await self.send_command({
|
await self.send_command({
|
||||||
"type": "session.start",
|
"type": "session.start",
|
||||||
"audio": {
|
"audio": {
|
||||||
@@ -168,9 +169,8 @@ class WavFileClient:
|
|||||||
"channels": 1
|
"channels": 1
|
||||||
},
|
},
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"appId": self.app_id,
|
|
||||||
"channel": self.channel,
|
"channel": self.channel,
|
||||||
"configVersionId": self.config_version_id,
|
"source": "wav_client",
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -329,9 +329,7 @@ class WavFileClient:
|
|||||||
if self.track_debug:
|
if self.track_debug:
|
||||||
print(f"[track-debug] event={event_type} trackId={event.get('trackId')}{ids}")
|
print(f"[track-debug] event={event_type} trackId={event.get('trackId')}{ids}")
|
||||||
|
|
||||||
if event_type == "hello.ack":
|
if event_type == "session.started":
|
||||||
self.log_event("←", f"Handshake acknowledged{ids}")
|
|
||||||
elif event_type == "session.started":
|
|
||||||
self.session_ready = True
|
self.session_ready = True
|
||||||
self.log_event("←", f"Session ready!{ids}")
|
self.log_event("←", f"Session ready!{ids}")
|
||||||
elif event_type == "config.resolved":
|
elif event_type == "config.resolved":
|
||||||
@@ -521,20 +519,15 @@ async def main():
|
|||||||
help="Target sample rate for audio (default: 16000)"
|
help="Target sample rate for audio (default: 16000)"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--app-id",
|
"--assistant-id",
|
||||||
default="assistant_demo",
|
default="assistant_demo",
|
||||||
help="Stable app/assistant identifier for server-side config lookup"
|
help="Assistant identifier used in websocket query parameter"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--channel",
|
"--channel",
|
||||||
default="wav_client",
|
default="wav_client",
|
||||||
help="Client channel name"
|
help="Client channel name"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--config-version-id",
|
|
||||||
default="local-dev",
|
|
||||||
help="Optional config version identifier"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--chunk-duration",
|
"--chunk-duration",
|
||||||
type=int,
|
type=int,
|
||||||
@@ -570,9 +563,8 @@ async def main():
|
|||||||
url=args.url,
|
url=args.url,
|
||||||
input_file=args.input,
|
input_file=args.input,
|
||||||
output_file=args.output,
|
output_file=args.output,
|
||||||
app_id=args.app_id,
|
assistant_id=args.assistant_id,
|
||||||
channel=args.channel,
|
channel=args.channel,
|
||||||
config_version_id=args.config_version_id,
|
|
||||||
sample_rate=args.sample_rate,
|
sample_rate=args.sample_rate,
|
||||||
chunk_duration_ms=args.chunk_duration,
|
chunk_duration_ms=args.chunk_duration,
|
||||||
wait_time=args.wait_time,
|
wait_time=args.wait_time,
|
||||||
|
|||||||
@@ -401,9 +401,14 @@
|
|||||||
|
|
||||||
const targetSampleRate = 16000;
|
const targetSampleRate = 16000;
|
||||||
const playbackStopRampSec = 0.008;
|
const playbackStopRampSec = 0.008;
|
||||||
const appId = "assistant_demo";
|
const assistantId = "assistant_demo";
|
||||||
const channel = "web_client";
|
const channel = "web_client";
|
||||||
const configVersionId = "local-dev";
|
|
||||||
|
function buildSessionWsUrl(baseUrl) {
|
||||||
|
const parsed = new URL(baseUrl);
|
||||||
|
parsed.searchParams.set("assistant_id", assistantId);
|
||||||
|
return parsed.toString();
|
||||||
|
}
|
||||||
|
|
||||||
function logLine(type, text, data) {
|
function logLine(type, text, data) {
|
||||||
const time = new Date().toLocaleTimeString();
|
const time = new Date().toLocaleTimeString();
|
||||||
@@ -556,14 +561,25 @@
|
|||||||
|
|
||||||
async function connect() {
|
async function connect() {
|
||||||
if (ws && ws.readyState === WebSocket.OPEN) return;
|
if (ws && ws.readyState === WebSocket.OPEN) return;
|
||||||
ws = new WebSocket(wsUrl.value.trim());
|
const sessionWsUrl = buildSessionWsUrl(wsUrl.value.trim());
|
||||||
|
ws = new WebSocket(sessionWsUrl);
|
||||||
ws.binaryType = "arraybuffer";
|
ws.binaryType = "arraybuffer";
|
||||||
|
|
||||||
ws.onopen = () => {
|
ws.onopen = () => {
|
||||||
setStatus(true, "Session open");
|
setStatus(true, "Session open");
|
||||||
logLine("sys", "WebSocket connected");
|
logLine("sys", "WebSocket connected");
|
||||||
ensureAudioContext();
|
ensureAudioContext();
|
||||||
sendCommand({ type: "hello", version: "v1" });
|
sendCommand({
|
||||||
|
type: "session.start",
|
||||||
|
audio: { encoding: "pcm_s16le", sample_rate_hz: targetSampleRate, channels: 1 },
|
||||||
|
metadata: {
|
||||||
|
channel,
|
||||||
|
source: "web_client",
|
||||||
|
overrides: {
|
||||||
|
output: { mode: "audio" },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
ws.onclose = () => {
|
ws.onclose = () => {
|
||||||
@@ -622,17 +638,6 @@
|
|||||||
const type = event.type || "unknown";
|
const type = event.type || "unknown";
|
||||||
const ids = eventIdsSuffix(event);
|
const ids = eventIdsSuffix(event);
|
||||||
logLine("event", `${type}${ids}`, event);
|
logLine("event", `${type}${ids}`, event);
|
||||||
if (type === "hello.ack") {
|
|
||||||
sendCommand({
|
|
||||||
type: "session.start",
|
|
||||||
audio: { encoding: "pcm_s16le", sample_rate_hz: targetSampleRate, channels: 1 },
|
|
||||||
metadata: {
|
|
||||||
appId,
|
|
||||||
channel,
|
|
||||||
configVersionId,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
}
|
|
||||||
if (type === "config.resolved") {
|
if (type === "config.resolved") {
|
||||||
logLine("sys", "config.resolved", event.config || {});
|
logLine("sys", "config.resolved", event.config || {});
|
||||||
}
|
}
|
||||||
|
|||||||
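Review note: with `hello` removed, the first client message after connecting is `session.start`. The required order can be pictured as a small state machine. Only the state name `WAIT_START` and the `protocol.order` error code come from this change. The other state names, the simplified error shape, and the handling of a repeated `session.start` are assumptions made for illustration.

```python
from enum import Enum
from typing import Optional, Tuple


class WsState(Enum):
    WAIT_START = "wait_start"  # socket open, waiting for session.start
    ACTIVE = "active"          # assumed name: session.started has been sent
    STOPPED = "stopped"        # assumed name: session.stop has been handled


# Simplified error payload; the real event envelope carries more fields.
ORDER_ERROR = {"type": "error", "code": "protocol.order"}


def advance(state: WsState, message_type: str) -> Tuple[WsState, Optional[dict]]:
    """Return (next_state, error) for one client text message."""
    if state is WsState.WAIT_START:
        if message_type == "session.start":
            return WsState.ACTIVE, None
        return state, dict(ORDER_ERROR)  # anything before session.start is out of order
    if state is WsState.ACTIVE:
        if message_type == "session.start":
            return state, dict(ORDER_ERROR)  # assumption: a second start is rejected
        if message_type == "session.stop":
            return WsState.STOPPED, None
        return state, None  # e.g. input.text while the session is running
    return state, dict(ORDER_ERROR)  # nothing is accepted after stop
```

A client that follows the updated examples (connect with `assistant_id`, send `session.start` on open, wait for `session.started`) never leaves the error-free path of this machine.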
@@ -18,12 +18,6 @@ class _StrictModel(BaseModel):
     model_config = ConfigDict(extra="forbid")
 
 
-# Client -> Server messages
-class HelloMessage(_StrictModel):
-    type: Literal["hello"]
-    version: Literal["v1"] = Field(..., description="Protocol version, fixed to v1")
-
-
 class SessionStartAudio(_StrictModel):
     encoding: Literal["pcm_s16le"] = "pcm_s16le"
     sample_rate_hz: Literal[16000] = 16000
@@ -69,7 +63,6 @@ class ToolCallResultsMessage(_StrictModel):
 
 
 CLIENT_MESSAGE_TYPES = {
-    "hello": HelloMessage,
     "session.start": SessionStartMessage,
     "session.stop": SessionStopMessage,
     "input.text": InputTextMessage,
||||||
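Note on the hunk above: dropping `"hello"` from `CLIENT_MESSAGE_TYPES` is what makes the parser reject the old handshake. The sketch below illustrates that registry-based dispatch under stated assumptions: the dataclass models and the body of `parse_client_message` are hypothetical stand-ins for the project's pydantic models, and only the error string is taken from the tests in this commit.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict


@dataclass
class SessionStartMessage:
    # Hypothetical stand-in for the project's pydantic model.
    type: str
    metadata: Dict[str, Any] = field(default_factory=dict)


@dataclass
class SessionStopMessage:
    # Hypothetical stand-in for the project's pydantic model.
    type: str


# Assumed shape of the registry once the "hello" entry is removed.
CLIENT_MESSAGE_TYPES: Dict[str, Callable[..., Any]] = {
    "session.start": SessionStartMessage,
    "session.stop": SessionStopMessage,
}


def parse_client_message(data: Dict[str, Any]) -> Any:
    """Look up the model by message type and build it from the payload."""
    msg_type = data.get("type")
    model = CLIENT_MESSAGE_TYPES.get(msg_type)
    if model is None:
        # Any type missing from the registry, including "hello", ends up here.
        raise ValueError(f"Unknown client message type: {msg_type}")
    return model(**data)
```

With this shape, retiring a message type only requires deleting its registry entry; no separate "deprecated" branch is needed in the parser.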
|
|||||||
@@ -86,7 +86,7 @@ def _build_pipeline(monkeypatch, llm_rounds: List[List[LLMStreamEvent]]) -> tupl
|
|||||||
async def _capture_event(event: Dict[str, Any], priority: int = 20):
|
async def _capture_event(event: Dict[str, Any], priority: int = 20):
|
||||||
events.append(event)
|
events.append(event)
|
||||||
|
|
||||||
async def _noop_speak(_text: str, fade_in_ms: int = 0, fade_out_ms: int = 8):
|
async def _noop_speak(_text: str, *args, **kwargs):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
monkeypatch.setattr(pipeline, "_send_event", _capture_event)
|
monkeypatch.setattr(pipeline, "_send_event", _capture_event)
|
||||||
|
|||||||
238
engine/tests/test_ws_protocol_session_start.py
Normal file
@@ -0,0 +1,238 @@
import pytest

from core.session import Session, WsSessionState
from models.ws_v1 import SessionStartMessage, parse_client_message


def _session() -> Session:
    session = Session.__new__(Session)
    session.id = "sess_test"
    session._assistant_id = "assistant_demo"
    return session


def test_parse_client_message_rejects_hello_message():
    with pytest.raises(ValueError, match="Unknown client message type: hello"):
        parse_client_message({"type": "hello", "version": "v1"})


@pytest.mark.asyncio
async def test_handle_text_reports_invalid_message_for_hello():
    session = Session.__new__(Session)
    session.id = "sess_invalid_hello"
    session.ws_state = WsSessionState.WAIT_START

    class _Transport:
        async def close(self):
            return None

    session.transport = _Transport()
    captured = []

    async def _send_error(sender, message, code, **kwargs):
        captured.append((sender, code, message, kwargs))

    session._send_error = _send_error

    await session.handle_text('{"type":"hello","version":"v1"}')
    assert captured
    sender, code, message, _ = captured[0]
    assert sender == "client"
    assert code == "protocol.invalid_message"
    assert "Unknown client message type: hello" in message


def test_validate_metadata_rejects_services_payload():
    session = _session()
    sanitized, error = session._validate_and_sanitize_client_metadata({"services": {"llm": {"provider": "openai"}}})
    assert sanitized == {}
    assert error is not None
    assert error["code"] == "protocol.invalid_override"


def test_validate_metadata_rejects_secret_like_override_keys():
    session = _session()
    sanitized, error = session._validate_and_sanitize_client_metadata(
        {
            "overrides": {
                "output": {"mode": "audio"},
                "apiKey": "xxx",
            }
        }
    )
    assert sanitized == {}
    assert error is not None
    assert error["code"] == "protocol.invalid_override"


def test_validate_metadata_ignores_workflow_payload():
    session = _session()
    sanitized, error = session._validate_and_sanitize_client_metadata(
        {
            "workflow": {"nodes": [], "edges": []},
            "channel": "web_debug",
            "overrides": {"output": {"mode": "text"}},
        }
    )
    assert error is None
    assert "workflow" not in sanitized
    assert sanitized["channel"] == "web_debug"
    assert sanitized["overrides"]["output"]["mode"] == "text"


@pytest.mark.asyncio
async def test_load_server_runtime_metadata_returns_not_found_error():
    session = _session()

    class _Gateway:
        async def fetch_assistant_config(self, assistant_id: str):
            _ = assistant_id
            return {"__error_code": "assistant.not_found"}

    session._backend_gateway = _Gateway()
    runtime, error = await session._load_server_runtime_metadata("assistant_demo")
    assert runtime == {}
    assert error is not None
    assert error["code"] == "assistant.not_found"


@pytest.mark.asyncio
async def test_load_server_runtime_metadata_returns_config_unavailable_error():
    session = _session()

    class _Gateway:
        async def fetch_assistant_config(self, assistant_id: str):
            _ = assistant_id
            return None

    session._backend_gateway = _Gateway()
    runtime, error = await session._load_server_runtime_metadata("assistant_demo")
    assert runtime == {}
    assert error is not None
    assert error["code"] == "assistant.config_unavailable"


@pytest.mark.asyncio
async def test_handle_session_start_requires_assistant_id_and_closes_transport():
    session = Session.__new__(Session)
    session.id = "sess_missing_assistant"
    session.ws_state = WsSessionState.WAIT_START
    session._assistant_id = None

    class _Transport:
        def __init__(self):
            self.closed = False

        async def close(self):
            self.closed = True

    transport = _Transport()
    session.transport = transport
    captured_codes = []

    async def _send_error(sender, message, code, **kwargs):
        _ = (sender, message, kwargs)
        captured_codes.append(code)

    session._send_error = _send_error

    await session._handle_session_start(SessionStartMessage(type="session.start", metadata={}))
    assert captured_codes == ["protocol.assistant_id_required"]
    assert transport.closed is True
    assert session.ws_state == WsSessionState.STOPPED


@pytest.mark.asyncio
async def test_handle_session_start_applies_whitelisted_overrides_and_ignores_workflow():
    session = Session.__new__(Session)
    session.id = "sess_start_ok"
    session.ws_state = WsSessionState.WAIT_START
    session.state = "created"
    session._assistant_id = "assistant_demo"
    session.current_track_id = Session.TRACK_CONTROL
    session._pipeline_started = False

    class _Transport:
        async def close(self):
            return None

    class _Pipeline:
        def __init__(self):
            self.started = False
            self.applied = {}
            self.conversation = type("Conversation", (), {"system_prompt": ""})()

        async def start(self):
            self.started = True

        async def emit_initial_greeting(self):
            return None

        def apply_runtime_overrides(self, metadata):
            self.applied = dict(metadata)

        def resolved_runtime_config(self):
            return {
                "output": {"mode": "text"},
                "services": {"llm": {"provider": "openai", "model": "gpt-4o-mini"}},
                "tools": {"allowlist": []},
            }

    session.transport = _Transport()
    session.pipeline = _Pipeline()
    events = []

    async def _start_history_bridge(_metadata):
        return None

    async def _load_server_runtime_metadata(_assistant_id):
        return (
            {
                "assistantId": "assistant_demo",
                "configVersionId": "cfg_1",
                "systemPrompt": "Base prompt",
                "greeting": "Base greeting",
                "output": {"mode": "audio"},
            },
            None,
        )

    async def _send_event(event):
        events.append(event)

    async def _send_error(sender, message, code, **kwargs):
        raise AssertionError(f"Unexpected error: sender={sender} code={code} message={message} kwargs={kwargs}")

    session._start_history_bridge = _start_history_bridge
    session._load_server_runtime_metadata = _load_server_runtime_metadata
    session._send_event = _send_event
    session._send_error = _send_error

    await session._handle_session_start(
        SessionStartMessage(
            type="session.start",
            metadata={
                "workflow": {"nodes": []},
                "channel": "web_debug",
                "source": "debug_ui",
                "history": {"userId": 7},
                "overrides": {
                    "greeting": "Override greeting",
                    "output": {"mode": "text"},
                    "tools": [{"name": "calculator"}],
                },
            },
        )
    )

    assert session.ws_state == WsSessionState.ACTIVE
    assert session.pipeline.started is True
    assert session.pipeline.applied["assistantId"] == "assistant_demo"
    assert session.pipeline.applied["greeting"] == "Override greeting"
    assert session.pipeline.applied["output"]["mode"] == "text"
    assert session.pipeline.applied["tools"] == [{"name": "calculator"}]
    assert not any(str(item.get("type", "")).startswith("workflow.") for item in events)

    config_event = next(item for item in events if item.get("type") == "config.resolved")
    assert config_event["config"]["appId"] == "assistant_demo"
    assert config_event["config"]["channel"] == "web_debug"
||||||
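Note on the tests above: they pin down three metadata behaviours. A `services` payload is rejected, secret-like keys inside `overrides` are rejected, and `workflow` is silently ignored. The following is a minimal sketch of a validator with that behaviour, not the engine's actual `_validate_and_sanitize_client_metadata`. The function name `validate_metadata`, the error messages, and the decision to drop unknown non-secret override keys silently are all assumptions. The whitelist and secret tokens are copied from the debug client's sanitizer in this same commit.

```python
from typing import Any, Dict, Optional, Tuple

# Copied from the web client's METADATA_OVERRIDE_WHITELIST in this commit.
OVERRIDE_WHITELIST = {
    "firstTurnMode", "greeting", "generatedOpenerEnabled", "systemPrompt",
    "output", "bargeIn", "knowledge", "knowledgeBaseId", "openerAudio", "tools",
}
SECRET_TOKENS = ("apikey", "token", "secret", "password", "authorization")


def is_secret_key(key: str) -> bool:
    # Normalise "api_key", "API-Key", "apiKey" to the same form before matching.
    normalized = key.lower().replace("_", "").replace("-", "")
    return any(token in normalized for token in SECRET_TOKENS)


def has_secret_key(value: Any) -> bool:
    # Walk nested dicts and lists looking for any secret-like key.
    if isinstance(value, dict):
        return any(is_secret_key(str(k)) or has_secret_key(v) for k, v in value.items())
    if isinstance(value, list):
        return any(has_secret_key(item) for item in value)
    return False


def validate_metadata(
    metadata: Dict[str, Any],
) -> Tuple[Dict[str, Any], Optional[Dict[str, str]]]:
    def reject(message: str) -> Tuple[Dict[str, Any], Dict[str, str]]:
        return {}, {"code": "protocol.invalid_override", "message": message}

    if "services" in metadata:
        return reject("metadata.services is not allowed")
    overrides = metadata.get("overrides") or {}
    if not isinstance(overrides, dict):
        return reject("metadata.overrides must be an object")
    if has_secret_key(overrides):
        return reject("metadata.overrides contains a secret-like key")

    sanitized: Dict[str, Any] = {}
    for name in ("channel", "source"):
        value = metadata.get(name)
        if isinstance(value, str) and value.strip():
            sanitized[name] = value.strip()
    history = metadata.get("history")
    if isinstance(history, dict) and "userId" in history:
        sanitized["history"] = {"userId": history["userId"]}
    # Assumption: unknown, non-secret override keys are dropped rather than rejected.
    sanitized["overrides"] = {k: v for k, v in overrides.items() if k in OVERRIDE_WHITELIST}
    # Anything else at the top level (for example "workflow") is simply not copied.
    return sanitized, None
```

Returning a `(sanitized, error)` pair, as the tests do, lets the caller decide whether to close the transport without using exceptions for control flow.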
@@ -2842,6 +2842,61 @@ export const DebugDrawer: React.FC<{
|
|||||||
return error;
|
return error;
|
||||||
};
|
};
|
||||||

const METADATA_OVERRIDE_WHITELIST = new Set([
  'firstTurnMode',
  'greeting',
  'generatedOpenerEnabled',
  'systemPrompt',
  'output',
  'bargeIn',
  'knowledge',
  'knowledgeBaseId',
  'openerAudio',
  'tools',
]);
const METADATA_FORBIDDEN_SECRET_TOKENS = ['apikey', 'token', 'secret', 'password', 'authorization'];
const isPlainObject = (value: unknown): value is Record<string, any> => Boolean(value) && typeof value === 'object' && !Array.isArray(value);
const isForbiddenSecretKey = (key: string): boolean => {
  const normalized = key.toLowerCase().replace(/[_-]/g, '');
  return METADATA_FORBIDDEN_SECRET_TOKENS.some((token) => normalized.includes(token));
};
const stripForbiddenSecretKeysDeep = (value: any): any => {
  if (Array.isArray(value)) return value.map(stripForbiddenSecretKeysDeep);
  if (!isPlainObject(value)) return value;
  return Object.entries(value).reduce<Record<string, any>>((acc, [key, nested]) => {
    if (isForbiddenSecretKey(key)) return acc;
    acc[key] = stripForbiddenSecretKeysDeep(nested);
    return acc;
  }, {});
};
const sanitizeMetadataForWs = (raw: unknown): Record<string, any> => {
  if (!isPlainObject(raw)) return { overrides: {} };
  const sanitized: Record<string, any> = { overrides: {} };

  if (typeof raw.channel === 'string' && raw.channel.trim()) {
    sanitized.channel = raw.channel.trim();
  }
  if (typeof raw.source === 'string' && raw.source.trim()) {
    sanitized.source = raw.source.trim();
  }
  if (isPlainObject(raw.history) && raw.history.userId !== undefined) {
    sanitized.history = { userId: raw.history.userId };
  }
  if (isPlainObject(raw.dynamicVariables)) {
    sanitized.dynamicVariables = raw.dynamicVariables;
  }
  if (isPlainObject(raw.overrides)) {
    const overrides = Object.entries(raw.overrides).reduce<Record<string, any>>((acc, [key, value]) => {
      if (!METADATA_OVERRIDE_WHITELIST.has(key)) return acc;
      if (isForbiddenSecretKey(key)) return acc;
      acc[key] = stripForbiddenSecretKeysDeep(value);
      return acc;
    }, {});
    sanitized.overrides = overrides;
  }
  return sanitized;
};

const buildDynamicVariablesPayload = (): { variables: Record<string, string>; error?: string } => {
|
const buildDynamicVariablesPayload = (): { variables: Record<string, string>; error?: string } => {
|
||||||
const variables: Record<string, string> = {};
|
const variables: Record<string, string> = {};
|
||||||
const nonEmptyRows = dynamicVariables
|
const nonEmptyRows = dynamicVariables
|
||||||
@@ -2908,84 +2963,16 @@ export const DebugDrawer: React.FC<{
|
|||||||
|
|
||||||
const buildLocalResolvedRuntime = () => {
|
const buildLocalResolvedRuntime = () => {
|
||||||
const warnings: string[] = [];
|
const warnings: string[] = [];
|
||||||
const services: Record<string, any> = {};
|
|
||||||
const ttsEnabled = Boolean(textTtsEnabled);
|
const ttsEnabled = Boolean(textTtsEnabled);
|
||||||
const isExternalLlm = assistant.configMode === 'dify' || assistant.configMode === 'fastgpt';
|
|
||||||
const knowledgeBaseId = String(assistant.knowledgeBaseId || '').trim();
|
const knowledgeBaseId = String(assistant.knowledgeBaseId || '').trim();
|
||||||
const knowledge = knowledgeBaseId
|
const knowledge = knowledgeBaseId
|
||||||
? { enabled: true, kbId: knowledgeBaseId, nResults: 5 }
|
? { enabled: true, kbId: knowledgeBaseId, nResults: 5 }
|
||||||
: { enabled: false };
|
: { enabled: false };
|
||||||
|
|
||||||
if (isExternalLlm) {
|
|
||||||
services.llm = {
|
|
||||||
provider: 'openai',
|
|
||||||
model: '',
|
|
||||||
apiKey: assistant.apiKey || '',
|
|
||||||
baseUrl: assistant.apiUrl || '',
|
|
||||||
};
|
|
||||||
if (!assistant.apiUrl) warnings.push(`External LLM API URL is empty for mode: ${assistant.configMode}`);
|
|
||||||
if (!assistant.apiKey) warnings.push(`External LLM API key is empty for mode: ${assistant.configMode}`);
|
|
||||||
} else if (assistant.llmModelId) {
|
|
||||||
const llm = llmModels.find((item) => item.id === assistant.llmModelId);
|
|
||||||
if (llm) {
|
|
||||||
services.llm = {
|
|
||||||
provider: 'openai',
|
|
||||||
model: llm.modelName || llm.name,
|
|
||||||
apiKey: llm.apiKey,
|
|
||||||
baseUrl: llm.baseUrl,
|
|
||||||
};
|
|
||||||
} else {
|
|
||||||
warnings.push(`LLM model not found in loaded list: ${assistant.llmModelId}`);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Keep empty object to indicate engine should use default provider model.
|
|
||||||
services.llm = {};
|
|
||||||
}
|
|
||||||
|
|
||||||
if (assistant.asrModelId) {
|
|
||||||
const asr = asrModels.find((item) => item.id === assistant.asrModelId);
|
|
||||||
if (asr) {
|
|
||||||
const asrProvider = isOpenAICompatibleVendor(asr.vendor) ? 'openai_compatible' : 'buffered';
|
|
||||||
services.asr = {
|
|
||||||
provider: asrProvider,
|
|
||||||
model: asr.modelName || asr.name,
|
|
||||||
apiKey: asrProvider === 'openai_compatible' ? asr.apiKey : null,
|
|
||||||
};
|
|
||||||
} else {
|
|
||||||
warnings.push(`ASR model not found in loaded list: ${assistant.asrModelId}`);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (assistant.voice) {
|
|
||||||
const voice = voices.find((item) => item.id === assistant.voice);
|
|
||||||
if (voice) {
|
|
||||||
const ttsProvider = isOpenAICompatibleVendor(voice.vendor) ? 'openai_compatible' : 'edge';
|
|
||||||
services.tts = {
|
|
||||||
enabled: ttsEnabled,
|
|
||||||
provider: ttsProvider,
|
|
||||||
model: voice.model,
|
|
||||||
apiKey: ttsProvider === 'openai_compatible' ? voice.apiKey : null,
|
|
||||||
voice: resolveRuntimeTtsVoice(assistant.voice, voice),
|
|
||||||
speed: assistant.speed || voice.speed || 1.0,
|
|
||||||
};
|
|
||||||
} else {
|
|
||||||
services.tts = {
|
|
||||||
enabled: ttsEnabled,
|
|
||||||
voice: assistant.voice,
|
|
||||||
speed: assistant.speed || 1.0,
|
|
||||||
};
|
|
||||||
warnings.push(`Voice resource not found in loaded list: ${assistant.voice}`);
|
|
||||||
}
|
|
||||||
} else if (!ttsEnabled) {
|
|
||||||
services.tts = {
|
|
||||||
enabled: false,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
const localResolved = {
|
const localResolved = {
|
||||||
assistantId: assistant.id,
|
|
||||||
warnings,
|
warnings,
|
||||||
sessionStartMetadata: {
|
sessionStartMetadata: {
|
||||||
|
overrides: {
|
||||||
output: {
|
output: {
|
||||||
mode: ttsEnabled ? 'audio' : 'text',
|
mode: ttsEnabled ? 'audio' : 'text',
|
||||||
},
|
},
|
||||||
@@ -3000,12 +2987,11 @@ export const DebugDrawer: React.FC<{
|
|||||||
knowledgeBaseId,
|
knowledgeBaseId,
|
||||||
knowledge,
|
knowledge,
|
||||||
tools: selectedToolSchemas,
|
tools: selectedToolSchemas,
|
||||||
services,
|
|
||||||
history: {
|
|
||||||
assistantId: assistant.id,
|
|
||||||
userId: 1,
|
|
||||||
source: 'debug',
|
|
||||||
},
|
},
|
||||||
|
history: {
|
||||||
|
userId: 1,
|
||||||
|
},
|
||||||
|
source: 'web_debug',
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -3020,21 +3006,13 @@ export const DebugDrawer: React.FC<{
|
|||||||
}
|
}
|
||||||
setDynamicVariablesError('');
|
setDynamicVariablesError('');
|
||||||
const localResolved = buildLocalResolvedRuntime();
|
const localResolved = buildLocalResolvedRuntime();
|
||||||
const mergedMetadata: Record<string, any> = {
|
const mergedMetadata: Record<string, any> = sanitizeMetadataForWs({
|
||||||
...localResolved.sessionStartMetadata,
|
...localResolved.sessionStartMetadata,
|
||||||
...(sessionMetadataExtras || {}),
|
...(sessionMetadataExtras || {}),
|
||||||
};
|
});
|
||||||
if (Object.keys(dynamicVariablesResult.variables).length > 0) {
|
if (Object.keys(dynamicVariablesResult.variables).length > 0) {
|
||||||
mergedMetadata.dynamicVariables = dynamicVariablesResult.variables;
|
mergedMetadata.dynamicVariables = dynamicVariablesResult.variables;
|
||||||
}
|
}
|
||||||
// Engine resolves trusted runtime config by top-level assistant/app ID.
|
|
||||||
// Keep these IDs at metadata root so backend /assistants/{id}/config is reachable.
|
|
||||||
if (!mergedMetadata.assistantId && assistant.id) {
|
|
||||||
mergedMetadata.assistantId = assistant.id;
|
|
||||||
}
|
|
||||||
if (!mergedMetadata.appId && assistant.id) {
|
|
||||||
mergedMetadata.appId = assistant.id;
|
|
||||||
}
|
|
||||||
if (!mergedMetadata.channel) {
|
if (!mergedMetadata.channel) {
|
||||||
mergedMetadata.channel = 'web_debug';
|
mergedMetadata.channel = 'web_debug';
|
||||||
}
|
}
|
||||||
@@ -3069,6 +3047,24 @@ export const DebugDrawer: React.FC<{
|
|||||||
if (isOpen) setWsStatus('disconnected');
|
if (isOpen) setWsStatus('disconnected');
|
||||||
};
|
};
|
||||||

const buildSessionWsUrl = () => {
  const base = wsUrl.trim();
  if (!base) return '';
  try {
    const parsed = new URL(base);
    parsed.searchParams.set('assistant_id', assistant.id);
    return parsed.toString();
  } catch {
    try {
      const parsed = new URL(base, window.location.href);
      parsed.searchParams.set('assistant_id', assistant.id);
      return parsed.toString();
    } catch {
      return base;
    }
  }
};

const ensureWsSession = async () => {
|
const ensureWsSession = async () => {
|
||||||
if (wsRef.current && wsReadyRef.current && wsRef.current.readyState === WebSocket.OPEN) {
|
if (wsRef.current && wsReadyRef.current && wsRef.current.readyState === WebSocket.OPEN) {
|
||||||
return;
|
return;
|
||||||
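Note on `buildSessionWsUrl` above: non-browser clients need the same `assistant_id` query parameter on the connection URL. The sketch below shows how a Python client might build it with the standard library only. The function name `build_session_ws_url` and the example host are assumptions; an existing `assistant_id` is replaced but moved to the end of the query string, which differs slightly from the in-place ordering of `URLSearchParams.set`.

```python
from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit


def build_session_ws_url(base: str, assistant_id: str) -> str:
    """Return `base` with its assistant_id query parameter set to `assistant_id`."""
    base = base.strip()
    if not base:
        return ""
    parts = urlsplit(base)
    # Keep unrelated query parameters, drop any stale assistant_id.
    query = [
        (key, value)
        for key, value in parse_qsl(parts.query, keep_blank_values=True)
        if key != "assistant_id"
    ]
    query.append(("assistant_id", assistant_id))
    return urlunsplit(parts._replace(query=urlencode(query)))


if __name__ == "__main__":
    # Hypothetical host; per the updated docs the path is /ws.
    print(build_session_ws_url("ws://localhost:8000/ws", "assistant_demo"))
```

After connecting to this URL, the first text frame the client sends is `session.start`; the `hello` / `hello.ack` exchange no longer exists.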
@@ -3083,18 +3079,25 @@ export const DebugDrawer: React.FC<{
|
|||||||
}
|
}
|
||||||
|
|
||||||
const metadata = await fetchRuntimeMetadata();
|
const metadata = await fetchRuntimeMetadata();
|
||||||
|
const sessionWsUrl = buildSessionWsUrl();
|
||||||
setWsStatus('connecting');
|
setWsStatus('connecting');
|
||||||
setWsError('');
|
setWsError('');
|
||||||
|
|
||||||
await new Promise<void>((resolve, reject) => {
|
await new Promise<void>((resolve, reject) => {
|
||||||
pendingResolveRef.current = resolve;
|
pendingResolveRef.current = resolve;
|
||||||
pendingRejectRef.current = reject;
|
pendingRejectRef.current = reject;
|
||||||
const ws = new WebSocket(wsUrl);
|
const ws = new WebSocket(sessionWsUrl);
|
||||||
ws.binaryType = 'arraybuffer';
|
ws.binaryType = 'arraybuffer';
|
||||||
wsRef.current = ws;
|
wsRef.current = ws;
|
||||||
|
|
||||||
ws.onopen = () => {
|
ws.onopen = () => {
|
||||||
ws.send(JSON.stringify({ type: 'hello', version: 'v1' }));
|
ws.send(
|
||||||
|
JSON.stringify({
|
||||||
|
type: 'session.start',
|
||||||
|
audio: { encoding: 'pcm_s16le', sample_rate_hz: 16000, channels: 1 },
|
||||||
|
metadata,
|
||||||
|
})
|
||||||
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
ws.onmessage = (event) => {
|
ws.onmessage = (event) => {
|
||||||
@@ -3118,16 +3121,6 @@ export const DebugDrawer: React.FC<{
|
|||||||
if (onProtocolEvent) {
|
if (onProtocolEvent) {
|
||||||
onProtocolEvent(payload);
|
onProtocolEvent(payload);
|
||||||
}
|
}
|
||||||
if (type === 'hello.ack') {
|
|
||||||
ws.send(
|
|
||||||
JSON.stringify({
|
|
||||||
type: 'session.start',
|
|
||||||
audio: { encoding: 'pcm_s16le', sample_rate_hz: 16000, channels: 1 },
|
|
||||||
metadata,
|
|
||||||
})
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (type === 'output.audio.start') {
|
if (type === 'output.audio.start') {
|
||||||
// New utterance audio starts: cancel old queued/playing audio to avoid overlap.
|
// New utterance audio starts: cancel old queued/playing audio to avoid overlap.
|
||||||
stopPlaybackImmediately();
|
stopPlaybackImmediately();
|
||||||
|
|||||||