Enhance WebSocket session management by requiring assistant_id as a query parameter for connection. Update API reference documentation to reflect changes in message flow and metadata validation rules, including the introduction of whitelists for allowed metadata fields and restrictions on sensitive keys. Refactor client examples to align with the new session initiation process.
This commit is contained in:
@@ -23,6 +23,7 @@ import time
|
||||
import threading
|
||||
import queue
|
||||
from pathlib import Path
|
||||
from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
@@ -59,9 +60,8 @@ class MicrophoneClient:
|
||||
url: str,
|
||||
sample_rate: int = 16000,
|
||||
chunk_duration_ms: int = 20,
|
||||
app_id: str = "assistant_demo",
|
||||
assistant_id: str = "assistant_demo",
|
||||
channel: str = "mic_client",
|
||||
config_version_id: str = "local-dev",
|
||||
input_device: int = None,
|
||||
output_device: int = None,
|
||||
track_debug: bool = False,
|
||||
@@ -80,9 +80,8 @@ class MicrophoneClient:
|
||||
self.sample_rate = sample_rate
|
||||
self.chunk_duration_ms = chunk_duration_ms
|
||||
self.chunk_samples = int(sample_rate * chunk_duration_ms / 1000)
|
||||
self.app_id = app_id
|
||||
self.assistant_id = assistant_id
|
||||
self.channel = channel
|
||||
self.config_version_id = config_version_id
|
||||
self.input_device = input_device
|
||||
self.output_device = output_device
|
||||
self.track_debug = track_debug
|
||||
@@ -125,19 +124,21 @@ class MicrophoneClient:
|
||||
if value:
|
||||
parts.append(f"{key}={value}")
|
||||
return f" [{' '.join(parts)}]" if parts else ""
|
||||
|
||||
def _session_url(self) -> str:
    """Return ``self.url`` with the ``assistant_id`` query parameter set.

    Every other query parameter already present in ``self.url`` is kept
    as-is, including blank values and repeated keys. An ``assistant_id``
    already present in the URL is overwritten with ``self.assistant_id``.

    Returns:
        The URL to open the WebSocket session with.
    """
    parts = urlsplit(self.url)

    # Work on the (key, value) pair list rather than a dict: a dict would
    # collapse repeated keys (e.g. ``?tag=a&tag=b``) and silently drop values.
    pairs = []
    assistant_set = False
    for key, value in parse_qsl(parts.query, keep_blank_values=True):
        if key == "assistant_id":
            # Keep a single assistant_id, at its first position in the URL.
            if not assistant_set:
                pairs.append((key, self.assistant_id))
                assistant_set = True
            continue
        pairs.append((key, value))
    if not assistant_set:
        pairs.append(("assistant_id", self.assistant_id))

    return urlunsplit(
        (parts.scheme, parts.netloc, parts.path, urlencode(pairs), parts.fragment)
    )
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Connect to WebSocket server."""
|
||||
print(f"Connecting to {self.url}...")
|
||||
self.ws = await websockets.connect(self.url)
|
||||
session_url = self._session_url()
|
||||
print(f"Connecting to {session_url}...")
|
||||
self.ws = await websockets.connect(session_url)
|
||||
self.running = True
|
||||
print("Connected!")
|
||||
|
||||
# WS v1 handshake: hello -> session.start
|
||||
await self.send_command({
|
||||
"type": "hello",
|
||||
"version": "v1",
|
||||
})
|
||||
|
||||
await self.send_command({
|
||||
"type": "session.start",
|
||||
"audio": {
|
||||
@@ -146,9 +147,8 @@ class MicrophoneClient:
|
||||
"channels": 1,
|
||||
},
|
||||
"metadata": {
|
||||
"appId": self.app_id,
|
||||
"channel": self.channel,
|
||||
"configVersionId": self.config_version_id,
|
||||
"source": "mic_client",
|
||||
},
|
||||
})
|
||||
|
||||
@@ -330,7 +330,7 @@ class MicrophoneClient:
|
||||
if self.track_debug:
|
||||
print(f"[track-debug] event={event_type} trackId={event.get('trackId')}{ids}")
|
||||
|
||||
if event_type in {"hello.ack", "session.started"}:
|
||||
if event_type == "session.started":
|
||||
print(f"← Session ready!{ids}")
|
||||
elif event_type == "config.resolved":
|
||||
print(f"← Config resolved: {event.get('config', {}).get('output', {})}{ids}")
|
||||
@@ -609,20 +609,15 @@ async def main():
|
||||
help="Show streaming LLM response chunks"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--app-id",
|
||||
"--assistant-id",
|
||||
default="assistant_demo",
|
||||
help="Stable app/assistant identifier for server-side config lookup"
|
||||
help="Assistant identifier used in websocket query parameter"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--channel",
|
||||
default="mic_client",
|
||||
help="Client channel name"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config-version-id",
|
||||
default="local-dev",
|
||||
help="Optional config version identifier"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--track-debug",
|
||||
action="store_true",
|
||||
@@ -638,9 +633,8 @@ async def main():
|
||||
client = MicrophoneClient(
|
||||
url=args.url,
|
||||
sample_rate=args.sample_rate,
|
||||
app_id=args.app_id,
|
||||
assistant_id=args.assistant_id,
|
||||
channel=args.channel,
|
||||
config_version_id=args.config_version_id,
|
||||
input_device=args.input_device,
|
||||
output_device=args.output_device,
|
||||
track_debug=args.track_debug,
|
||||
|
||||
@@ -15,6 +15,7 @@ import sys
|
||||
import time
|
||||
import wave
|
||||
import io
|
||||
from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
@@ -56,16 +57,14 @@ class SimpleVoiceClient:
|
||||
self,
|
||||
url: str,
|
||||
sample_rate: int = 16000,
|
||||
app_id: str = "assistant_demo",
|
||||
assistant_id: str = "assistant_demo",
|
||||
channel: str = "simple_client",
|
||||
config_version_id: str = "local-dev",
|
||||
track_debug: bool = False,
|
||||
):
|
||||
self.url = url
|
||||
self.sample_rate = sample_rate
|
||||
self.app_id = app_id
|
||||
self.assistant_id = assistant_id
|
||||
self.channel = channel
|
||||
self.config_version_id = config_version_id
|
||||
self.track_debug = track_debug
|
||||
self.ws = None
|
||||
self.running = False
|
||||
@@ -88,6 +87,12 @@ class SimpleVoiceClient:
|
||||
# Interrupt handling - discard audio until next trackStart
|
||||
self._discard_audio = False
|
||||
|
||||
def _session_url(self) -> str:
    """Return ``self.url`` with the ``assistant_id`` query parameter set.

    Every other query parameter already present in ``self.url`` is kept
    as-is, including blank values and repeated keys. An ``assistant_id``
    already present in the URL is overwritten with ``self.assistant_id``.

    Returns:
        The URL to open the WebSocket session with.
    """
    parts = urlsplit(self.url)

    # Work on the (key, value) pair list rather than a dict: a dict would
    # collapse repeated keys (e.g. ``?tag=a&tag=b``) and silently drop values.
    pairs = []
    assistant_set = False
    for key, value in parse_qsl(parts.query, keep_blank_values=True):
        if key == "assistant_id":
            # Keep a single assistant_id, at its first position in the URL.
            if not assistant_set:
                pairs.append((key, self.assistant_id))
                assistant_set = True
            continue
        pairs.append((key, value))
    if not assistant_set:
        pairs.append(("assistant_id", self.assistant_id))

    return urlunsplit(
        (parts.scheme, parts.netloc, parts.path, urlencode(pairs), parts.fragment)
    )
|
||||
|
||||
@staticmethod
|
||||
def _event_ids_suffix(event: dict) -> str:
|
||||
data = event.get("data") if isinstance(event.get("data"), dict) else {}
|
||||
@@ -101,16 +106,12 @@ class SimpleVoiceClient:
|
||||
|
||||
async def connect(self):
|
||||
"""Connect to server."""
|
||||
print(f"Connecting to {self.url}...")
|
||||
self.ws = await websockets.connect(self.url)
|
||||
session_url = self._session_url()
|
||||
print(f"Connecting to {session_url}...")
|
||||
self.ws = await websockets.connect(session_url)
|
||||
self.running = True
|
||||
print("Connected!")
|
||||
|
||||
# WS v1 handshake: hello -> session.start
|
||||
await self.ws.send(json.dumps({
|
||||
"type": "hello",
|
||||
"version": "v1",
|
||||
}))
|
||||
|
||||
await self.ws.send(json.dumps({
|
||||
"type": "session.start",
|
||||
"audio": {
|
||||
@@ -119,12 +120,11 @@ class SimpleVoiceClient:
|
||||
"channels": 1,
|
||||
},
|
||||
"metadata": {
|
||||
"appId": self.app_id,
|
||||
"channel": self.channel,
|
||||
"configVersionId": self.config_version_id,
|
||||
"source": "simple_client",
|
||||
},
|
||||
}))
|
||||
print("-> hello/session.start")
|
||||
print("-> session.start")
|
||||
|
||||
async def send_chat(self, text: str):
|
||||
"""Send chat message."""
|
||||
@@ -311,9 +311,8 @@ async def main():
|
||||
parser.add_argument("--text", help="Send text and play response")
|
||||
parser.add_argument("--list-devices", action="store_true")
|
||||
parser.add_argument("--sample-rate", type=int, default=16000)
|
||||
parser.add_argument("--app-id", default="assistant_demo")
|
||||
parser.add_argument("--assistant-id", default="assistant_demo")
|
||||
parser.add_argument("--channel", default="simple_client")
|
||||
parser.add_argument("--config-version-id", default="local-dev")
|
||||
parser.add_argument("--track-debug", action="store_true")
|
||||
|
||||
args = parser.parse_args()
|
||||
@@ -325,9 +324,8 @@ async def main():
|
||||
client = SimpleVoiceClient(
|
||||
args.url,
|
||||
args.sample_rate,
|
||||
app_id=args.app_id,
|
||||
assistant_id=args.assistant_id,
|
||||
channel=args.channel,
|
||||
config_version_id=args.config_version_id,
|
||||
track_debug=args.track_debug,
|
||||
)
|
||||
await client.run(args.text)
|
||||
|
||||
@@ -12,6 +12,7 @@ import math
|
||||
import argparse
|
||||
import os
|
||||
from datetime import datetime
|
||||
from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit
|
||||
|
||||
# Configuration
|
||||
SERVER_URL = "ws://localhost:8000/ws"
|
||||
@@ -130,14 +131,17 @@ async def run_client(url, file_path=None, use_sine=False, track_debug: bool = Fa
|
||||
"""Run the WebSocket test client."""
|
||||
session = aiohttp.ClientSession()
|
||||
try:
|
||||
print(f"🔌 Connecting to {url}...")
|
||||
async with session.ws_connect(url) as ws:
|
||||
parts = urlsplit(url)
|
||||
query = dict(parse_qsl(parts.query, keep_blank_values=True))
|
||||
query["assistant_id"] = str(query.get("assistant_id") or "assistant_demo")
|
||||
session_url = urlunsplit((parts.scheme, parts.netloc, parts.path, urlencode(query), parts.fragment))
|
||||
print(f"🔌 Connecting to {session_url}...")
|
||||
async with session.ws_connect(session_url) as ws:
|
||||
print("✅ Connected!")
|
||||
session_ready = asyncio.Event()
|
||||
recv_task = asyncio.create_task(receive_loop(ws, session_ready, track_debug=track_debug))
|
||||
|
||||
# Send v1 hello + session.start handshake
|
||||
await ws.send_json({"type": "hello", "version": "v1"})
|
||||
# Send v1 session.start initialization
|
||||
await ws.send_json({
|
||||
"type": "session.start",
|
||||
"audio": {
|
||||
@@ -146,12 +150,11 @@ async def run_client(url, file_path=None, use_sine=False, track_debug: bool = Fa
|
||||
"channels": 1
|
||||
},
|
||||
"metadata": {
|
||||
"appId": "assistant_demo",
|
||||
"channel": "test_websocket",
|
||||
"configVersionId": "local-dev",
|
||||
"source": "test_websocket",
|
||||
},
|
||||
})
|
||||
print("📤 Sent v1 hello/session.start")
|
||||
print("📤 Sent v1 session.start")
|
||||
await asyncio.wait_for(session_ready.wait(), timeout=8)
|
||||
|
||||
# Select sender based on args
|
||||
|
||||
@@ -21,6 +21,7 @@ import sys
|
||||
import time
|
||||
import wave
|
||||
from pathlib import Path
|
||||
from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
@@ -57,9 +58,8 @@ class WavFileClient:
|
||||
url: str,
|
||||
input_file: str,
|
||||
output_file: str,
|
||||
app_id: str = "assistant_demo",
|
||||
assistant_id: str = "assistant_demo",
|
||||
channel: str = "wav_client",
|
||||
config_version_id: str = "local-dev",
|
||||
sample_rate: int = 16000,
|
||||
chunk_duration_ms: int = 20,
|
||||
wait_time: float = 15.0,
|
||||
@@ -82,9 +82,8 @@ class WavFileClient:
|
||||
self.url = url
|
||||
self.input_file = Path(input_file)
|
||||
self.output_file = Path(output_file)
|
||||
self.app_id = app_id
|
||||
self.assistant_id = assistant_id
|
||||
self.channel = channel
|
||||
self.config_version_id = config_version_id
|
||||
self.sample_rate = sample_rate
|
||||
self.chunk_duration_ms = chunk_duration_ms
|
||||
self.chunk_samples = int(sample_rate * chunk_duration_ms / 1000)
|
||||
@@ -147,19 +146,21 @@ class WavFileClient:
|
||||
if value:
|
||||
parts.append(f"{key}={value}")
|
||||
return f" [{' '.join(parts)}]" if parts else ""
|
||||
|
||||
def _session_url(self) -> str:
    """Return ``self.url`` with the ``assistant_id`` query parameter set.

    Every other query parameter already present in ``self.url`` is kept
    as-is, including blank values and repeated keys. An ``assistant_id``
    already present in the URL is overwritten with ``self.assistant_id``.

    Returns:
        The URL to open the WebSocket session with.
    """
    parts = urlsplit(self.url)

    # Work on the (key, value) pair list rather than a dict: a dict would
    # collapse repeated keys (e.g. ``?tag=a&tag=b``) and silently drop values.
    pairs = []
    assistant_set = False
    for key, value in parse_qsl(parts.query, keep_blank_values=True):
        if key == "assistant_id":
            # Keep a single assistant_id, at its first position in the URL.
            if not assistant_set:
                pairs.append((key, self.assistant_id))
                assistant_set = True
            continue
        pairs.append((key, value))
    if not assistant_set:
        pairs.append(("assistant_id", self.assistant_id))

    return urlunsplit(
        (parts.scheme, parts.netloc, parts.path, urlencode(pairs), parts.fragment)
    )
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Connect to WebSocket server."""
|
||||
self.log_event("→", f"Connecting to {self.url}...")
|
||||
self.ws = await websockets.connect(self.url)
|
||||
session_url = self._session_url()
|
||||
self.log_event("→", f"Connecting to {session_url}...")
|
||||
self.ws = await websockets.connect(session_url)
|
||||
self.running = True
|
||||
self.log_event("←", "Connected!")
|
||||
|
||||
# WS v1 handshake: hello -> session.start
|
||||
await self.send_command({
|
||||
"type": "hello",
|
||||
"version": "v1",
|
||||
})
|
||||
await self.send_command({
|
||||
"type": "session.start",
|
||||
"audio": {
|
||||
@@ -168,9 +169,8 @@ class WavFileClient:
|
||||
"channels": 1
|
||||
},
|
||||
"metadata": {
|
||||
"appId": self.app_id,
|
||||
"channel": self.channel,
|
||||
"configVersionId": self.config_version_id,
|
||||
"source": "wav_client",
|
||||
},
|
||||
})
|
||||
|
||||
@@ -329,9 +329,7 @@ class WavFileClient:
|
||||
if self.track_debug:
|
||||
print(f"[track-debug] event={event_type} trackId={event.get('trackId')}{ids}")
|
||||
|
||||
if event_type == "hello.ack":
|
||||
self.log_event("←", f"Handshake acknowledged{ids}")
|
||||
elif event_type == "session.started":
|
||||
if event_type == "session.started":
|
||||
self.session_ready = True
|
||||
self.log_event("←", f"Session ready!{ids}")
|
||||
elif event_type == "config.resolved":
|
||||
@@ -521,20 +519,15 @@ async def main():
|
||||
help="Target sample rate for audio (default: 16000)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--app-id",
|
||||
"--assistant-id",
|
||||
default="assistant_demo",
|
||||
help="Stable app/assistant identifier for server-side config lookup"
|
||||
help="Assistant identifier used in websocket query parameter"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--channel",
|
||||
default="wav_client",
|
||||
help="Client channel name"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config-version-id",
|
||||
default="local-dev",
|
||||
help="Optional config version identifier"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--chunk-duration",
|
||||
type=int,
|
||||
@@ -570,9 +563,8 @@ async def main():
|
||||
url=args.url,
|
||||
input_file=args.input,
|
||||
output_file=args.output,
|
||||
app_id=args.app_id,
|
||||
assistant_id=args.assistant_id,
|
||||
channel=args.channel,
|
||||
config_version_id=args.config_version_id,
|
||||
sample_rate=args.sample_rate,
|
||||
chunk_duration_ms=args.chunk_duration,
|
||||
wait_time=args.wait_time,
|
||||
|
||||
@@ -401,9 +401,14 @@
|
||||
|
||||
const targetSampleRate = 16000;
|
||||
const playbackStopRampSec = 0.008;
|
||||
const appId = "assistant_demo";
|
||||
const assistantId = "assistant_demo";
|
||||
const channel = "web_client";
|
||||
const configVersionId = "local-dev";
|
||||
|
||||
// Build the WebSocket URL used to open a session: the given base URL with
// its `assistant_id` query parameter set to the page-level `assistantId`.
// Any other query parameters on `baseUrl` are left in place.
function buildSessionWsUrl(baseUrl) {
  const sessionUrl = new URL(baseUrl);
  const params = sessionUrl.searchParams;
  // set() overwrites an assistant_id that is already present on the URL.
  params.set("assistant_id", assistantId);
  return sessionUrl.href;
}
|
||||
|
||||
function logLine(type, text, data) {
|
||||
const time = new Date().toLocaleTimeString();
|
||||
@@ -556,14 +561,25 @@
|
||||
|
||||
async function connect() {
|
||||
if (ws && ws.readyState === WebSocket.OPEN) return;
|
||||
ws = new WebSocket(wsUrl.value.trim());
|
||||
const sessionWsUrl = buildSessionWsUrl(wsUrl.value.trim());
|
||||
ws = new WebSocket(sessionWsUrl);
|
||||
ws.binaryType = "arraybuffer";
|
||||
|
||||
ws.onopen = () => {
|
||||
setStatus(true, "Session open");
|
||||
logLine("sys", "WebSocket connected");
|
||||
ensureAudioContext();
|
||||
sendCommand({ type: "hello", version: "v1" });
|
||||
sendCommand({
|
||||
type: "session.start",
|
||||
audio: { encoding: "pcm_s16le", sample_rate_hz: targetSampleRate, channels: 1 },
|
||||
metadata: {
|
||||
channel,
|
||||
source: "web_client",
|
||||
overrides: {
|
||||
output: { mode: "audio" },
|
||||
},
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
ws.onclose = () => {
|
||||
@@ -622,17 +638,6 @@
|
||||
const type = event.type || "unknown";
|
||||
const ids = eventIdsSuffix(event);
|
||||
logLine("event", `${type}${ids}`, event);
|
||||
if (type === "hello.ack") {
|
||||
sendCommand({
|
||||
type: "session.start",
|
||||
audio: { encoding: "pcm_s16le", sample_rate_hz: targetSampleRate, channels: 1 },
|
||||
metadata: {
|
||||
appId,
|
||||
channel,
|
||||
configVersionId,
|
||||
},
|
||||
});
|
||||
}
|
||||
if (type === "config.resolved") {
|
||||
logLine("sys", "config.resolved", event.config || {});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user