199 lines
6.9 KiB
Python
199 lines
6.9 KiB
Python
"""WebSocket endpoint test client.

Exercises the /ws endpoint by streaming audio from either a generated sine
wave or a PCM/WAV file.

Based on reference/py-active-call/exec/test_ws_endpoint/test_ws.py
"""
|
|
|
|
import asyncio
|
|
import aiohttp
|
|
import json
|
|
import struct
|
|
import math
|
|
import argparse
|
|
import os
|
|
from datetime import datetime
|
|
from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit
|
|
|
|
# Configuration
SERVER_URL = "ws://localhost:8000/ws"
SAMPLE_RATE = 16000  # samples per second
FREQUENCY = 440  # 440Hz Sine Wave
CHUNK_DURATION_MS = 20  # audio time carried by each binary frame we send
_BYTES_PER_SAMPLE = 2  # 16-bit PCM

# 16 kHz * 2 bytes * 20 ms = 640 bytes per chunk
CHUNK_SIZE_BYTES = SAMPLE_RATE * _BYTES_PER_SAMPLE * CHUNK_DURATION_MS // 1000
|
|
|
|
|
|
def generate_sine_wave(duration_ms=1000, *, frequency=None, sample_rate=None):
    """Generate a full-amplitude sine tone as mono 16-bit little-endian PCM.

    Args:
        duration_ms: Length of the tone in milliseconds. Values that yield no
            whole samples (including negatives) produce an empty buffer.
        frequency: Tone frequency in Hz; defaults to the module's FREQUENCY.
        sample_rate: Samples per second; defaults to the module's SAMPLE_RATE.

    Returns:
        A bytearray holding ``int(sample_rate * (duration_ms / 1000.0))``
        samples, 2 bytes each.
    """
    # Resolve defaults at call time so the module constants remain the
    # single source of truth (and can be changed before calling).
    if frequency is None:
        frequency = FREQUENCY
    if sample_rate is None:
        sample_rate = SAMPLE_RATE

    num_samples = int(sample_rate * (duration_ms / 1000.0))
    samples = [
        int(32767.0 * math.sin(2 * math.pi * frequency * x / sample_rate))
        for x in range(num_samples)
    ]
    # Pack every sample in one call instead of one struct.pack per sample.
    return bytearray(struct.pack(f"<{len(samples)}h", *samples))
|
|
|
|
|
|
async def receive_loop(ws, ready_event: asyncio.Event, track_debug: bool = False):
    """Print every message the server sends until the socket closes or errors.

    Args:
        ws: An open aiohttp client WebSocket to iterate over.
        ready_event: Set as soon as a ``session.started`` event is received.
        track_debug: When true, also print each event's ``trackId``.
    """

    def event_ids_suffix(data):
        """Return " [key=value ...]" for the correlation IDs present, else ""."""
        # IDs may live under a nested "data" object or at the top level.
        payload = data.get("data") if isinstance(data.get("data"), dict) else {}
        keys = ("turn_id", "utterance_id", "response_id", "tool_call_id", "tts_id")
        parts = []
        for key in keys:
            value = payload.get(key, data.get(key))
            if value:
                parts.append(f"{key}={value}")
        return f" [{' '.join(parts)}]" if parts else ""

    print("👂 Listening for server responses...")
    async for msg in ws:
        timestamp = datetime.now().strftime("%H:%M:%S")

        if msg.type == aiohttp.WSMsgType.TEXT:
            try:
                data = json.loads(msg.data)
            except json.JSONDecodeError:
                print(f"[{timestamp}] 📨 Text: {msg.data[:100]}...")
                continue
            if not isinstance(data, dict):
                # Valid JSON but not an event object (list, string, number):
                # calling .get() on it would crash the receive task.
                print(f"[{timestamp}] 📨 Text: {msg.data[:100]}...")
                continue

            event_type = data.get('type', 'Unknown')
            ids = event_ids_suffix(data)
            print(f"[{timestamp}] 📨 Event: {event_type}{ids} | {msg.data[:150]}...")
            if track_debug:
                print(f"[{timestamp}] [track-debug] event={event_type} trackId={data.get('trackId')}{ids}")
            if event_type == "session.started":
                ready_event.set()

        elif msg.type == aiohttp.WSMsgType.BINARY:
            # Received audio chunk back (e.g., TTS or echo)
            print(f"[{timestamp}] 🔊 Audio: {len(msg.data)} bytes", end="\r")

        elif msg.type == aiohttp.WSMsgType.ERROR:
            print(f"\n[{timestamp}] ⚠️ Socket Error")
            break
    else:
        # aiohttp's async iterator ends (instead of yielding a message) on
        # CLOSE/CLOSING/CLOSED, so a closed socket is reported here, after the
        # loop is exhausted without hitting the error branch.
        timestamp = datetime.now().strftime("%H:%M:%S")
        print(f"\n[{timestamp}] ❌ Socket Closed")
|
|
|
|
|
|
def _seek_wav_data(f):
    """Position binary file *f* at the first byte of its WAV ``data`` payload.

    WAV files may carry extra RIFF chunks (e.g. ``LIST`` or ``fact``) before
    ``data``, so the header is not always 44 bytes. This walks the chunk list
    instead of assuming a fixed offset. If *f* is not a parseable RIFF/WAVE
    stream, or no ``data`` chunk is found, it falls back to skipping 44 bytes
    from the starting position.
    """
    start = f.tell()
    riff = f.read(12)
    if len(riff) == 12 and riff[:4] == b"RIFF" and riff[8:12] == b"WAVE":
        while True:
            header = f.read(8)
            if len(header) < 8:
                break  # ran out of chunks without finding "data"
            chunk_id, chunk_size = struct.unpack("<4sI", header)
            if chunk_id == b"data":
                return
            # RIFF chunk bodies are padded to an even number of bytes.
            f.seek(chunk_size + (chunk_size & 1), os.SEEK_CUR)
    f.seek(start + 44)


async def send_file_loop(ws, file_path):
    """Stream a raw PCM or WAV file to the server in real-time-paced chunks.

    Files ending in ``.wav`` (any case) have their RIFF header skipped so only
    sample data is sent; any other file is streamed byte-for-byte. No format
    conversion or resampling is performed.

    Args:
        ws: An open WebSocket exposing an awaitable ``send_bytes``.
        file_path: Path of the audio file to stream.
    """
    if not os.path.isfile(file_path):
        print(f"❌ Error: File '{file_path}' not found.")
        return

    print(f"📂 Streaming file: {file_path} ...")

    with open(file_path, "rb") as f:
        if file_path.lower().endswith(".wav"):
            _seek_wav_data(f)

        while True:
            chunk = f.read(CHUNK_SIZE_BYTES)
            if not chunk:
                break
            # Send binary frame
            await ws.send_bytes(chunk)
            # Sleep one chunk's duration to simulate real-time playback
            await asyncio.sleep(CHUNK_DURATION_MS / 1000.0)

    print(f"\n✅ Finished streaming {file_path}")
|
|
|
|
|
|
async def send_sine_loop(ws, duration_ms=5000):
    """Stream a generated sine tone to the server, paced in real time.

    Args:
        ws: An open WebSocket exposing an awaitable ``send_bytes``.
        duration_ms: Length of the generated tone in milliseconds
            (default 5000, i.e. 5 seconds).
    """
    print("🎙️ Starting Audio Stream (Sine Wave)...")

    # Generate the whole tone up front (5 seconds by default), then slice it
    # into CHUNK_SIZE_BYTES frames; the final frame may be shorter.
    audio_buffer = generate_sine_wave(duration_ms)

    for offset in range(0, len(audio_buffer), CHUNK_SIZE_BYTES):
        await ws.send_bytes(audio_buffer[offset:offset + CHUNK_SIZE_BYTES])
        # Sleep one chunk's duration to simulate real-time capture.
        await asyncio.sleep(CHUNK_DURATION_MS / 1000.0)

    print("\n✅ Finished streaming test audio.")
|
|
|
|
|
|
def _build_session_url(url):
    """Return *url* with an ``assistant_id`` query parameter (default ``assistant_demo``)."""
    parts = urlsplit(url)
    query = dict(parse_qsl(parts.query, keep_blank_values=True))
    query["assistant_id"] = str(query.get("assistant_id") or "assistant_demo")
    return urlunsplit((parts.scheme, parts.netloc, parts.path, urlencode(query), parts.fragment))


async def _wait_for_session_ready(ready_event, recv_task, timeout):
    """Wait for *ready_event*, failing fast if *recv_task* finishes first.

    Raises:
        asyncio.TimeoutError: *timeout* seconds elapsed with neither happening.
        ConnectionError: the receive loop ended cleanly before the event was set.
        Exception: whatever the receive loop raised, if it crashed first.
    """
    ready_waiter = asyncio.create_task(ready_event.wait())
    try:
        done, _ = await asyncio.wait(
            {ready_waiter, recv_task},
            timeout=timeout,
            return_when=asyncio.FIRST_COMPLETED,
        )
    finally:
        ready_waiter.cancel()

    if ready_event.is_set():
        return
    if recv_task in done:
        recv_task.result()  # re-raises the receive loop's exception, if any
        raise ConnectionError("connection closed before session.started")
    raise asyncio.TimeoutError


async def run_client(url, file_path=None, use_sine=False, track_debug: bool = False):
    """Run one end-to-end session against the WebSocket endpoint.

    Connects to *url* (adding ``assistant_id=assistant_demo`` when absent) and
    sends ``session.start`` for 16 kHz mono ``pcm_s16le`` audio. It then waits
    up to 8 s for ``session.started``, streams audio, sends ``session.stop``,
    and gives the server up to 1 s to finish before shutting down. Audio comes
    from *file_path* unless *use_sine* is set or no file is given, in which
    case a sine tone is streamed.

    Failures are reported with ``print`` rather than raised.
    """
    try:
        session_url = _build_session_url(url)
        print(f"🔌 Connecting to {session_url}...")

        async with aiohttp.ClientSession() as session:
            async with session.ws_connect(session_url) as ws:
                print("✅ Connected!")
                session_ready = asyncio.Event()
                recv_task = asyncio.create_task(
                    receive_loop(ws, session_ready, track_debug=track_debug)
                )
                try:
                    # Send v1 session.start initialization
                    await ws.send_json({
                        "type": "session.start",
                        "audio": {
                            "encoding": "pcm_s16le",
                            "sample_rate_hz": SAMPLE_RATE,
                            "channels": 1,
                        },
                        "metadata": {
                            "channel": "test_websocket",
                            "source": "test_websocket",
                        },
                    })
                    print("📤 Sent v1 session.start")
                    await _wait_for_session_ready(session_ready, recv_task, timeout=8)

                    # --sine wins over --file; with neither, default to sine.
                    if file_path and not use_sine:
                        await send_file_loop(ws, file_path)
                    else:
                        await send_sine_loop(ws)

                    await ws.send_json({"type": "session.stop", "reason": "test_complete"})
                    # Let trailing events arrive for up to 1 s, returning early
                    # if the server closes the socket first.
                    await asyncio.wait({recv_task}, timeout=1)
                finally:
                    # Stop the listener on every path, including errors and timeouts.
                    recv_task.cancel()
                    try:
                        await recv_task
                    except asyncio.CancelledError:
                        pass

    except aiohttp.ClientConnectorError:
        print(f"❌ Connection Failed. Is the server running at {url}?")
    except asyncio.TimeoutError:
        print("❌ Timeout waiting for session.started")
    except Exception as e:
        print(f"❌ Error: {e}")
|
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="WebSocket Audio Test Client")
    parser.add_argument("--url", default=SERVER_URL, help="WebSocket endpoint URL")
    # --file and --sine choose different audio sources. Reject them together
    # instead of silently ignoring --file.
    source = parser.add_mutually_exclusive_group()
    source.add_argument("--file", help="Path to PCM/WAV file to stream")
    source.add_argument("--sine", action="store_true", help="Use sine wave generation (default)")
    parser.add_argument("--track-debug", action="store_true", help="Print event trackId for protocol debugging")
    args = parser.parse_args()

    try:
        asyncio.run(
            run_client(
                args.url,
                file_path=args.file,
                use_sine=args.sine,
                track_debug=args.track_debug,
            )
        )
    except KeyboardInterrupt:
        print("\n👋 Client stopped.")
|