Update engine

This commit is contained in:
Xin Wang
2026-02-23 17:16:18 +08:00
parent 01c0de0a4d
commit c6c84b5af9
9 changed files with 991 additions and 186 deletions

View File

@@ -57,10 +57,15 @@ class WavFileClient:
url: str,
input_file: str,
output_file: str,
app_id: str = "assistant_demo",
channel: str = "wav_client",
config_version_id: str = "local-dev",
sample_rate: int = 16000,
chunk_duration_ms: int = 20,
wait_time: float = 15.0,
verbose: bool = False
verbose: bool = False,
track_debug: bool = False,
tail_silence_ms: int = 800,
):
"""
Initialize WAV file client.
@@ -77,11 +82,17 @@ class WavFileClient:
self.url = url
self.input_file = Path(input_file)
self.output_file = Path(output_file)
self.app_id = app_id
self.channel = channel
self.config_version_id = config_version_id
self.sample_rate = sample_rate
self.chunk_duration_ms = chunk_duration_ms
self.chunk_samples = int(sample_rate * chunk_duration_ms / 1000)
self.wait_time = wait_time
self.verbose = verbose
self.track_debug = track_debug
self.tail_silence_ms = max(0, int(tail_silence_ms))
self.frame_bytes = 640 # 16k mono pcm_s16le, 20ms
# WebSocket connection
self.ws = None
@@ -125,6 +136,17 @@ class WavFileClient:
# Replace problematic characters for console output
safe_message = message.encode('ascii', errors='replace').decode('ascii')
print(f"{direction} {safe_message}")
@staticmethod
def _event_ids_suffix(event: dict) -> str:
data = event.get("data") if isinstance(event.get("data"), dict) else {}
keys = ("turn_id", "utterance_id", "response_id", "tool_call_id", "tts_id")
parts = []
for key in keys:
value = data.get(key, event.get(key))
if value:
parts.append(f"{key}={value}")
return f" [{' '.join(parts)}]" if parts else ""
async def connect(self) -> None:
"""Connect to WebSocket server."""
@@ -144,7 +166,12 @@ class WavFileClient:
"encoding": "pcm_s16le",
"sample_rate_hz": self.sample_rate,
"channels": 1
}
},
"metadata": {
"appId": self.app_id,
"channel": self.channel,
"configVersionId": self.config_version_id,
},
})
async def send_command(self, cmd: dict) -> None:
@@ -216,6 +243,10 @@ class WavFileClient:
end_sample = min(sent_samples + chunk_size, total_samples)
chunk = audio_data[sent_samples:end_sample]
chunk_bytes = chunk.tobytes()
if len(chunk_bytes) % self.frame_bytes != 0:
# v1 audio framing requires 640-byte (20ms) PCM units.
pad = self.frame_bytes - (len(chunk_bytes) % self.frame_bytes)
chunk_bytes += b"\x00" * pad
# Send to server
if self.ws:
@@ -232,6 +263,16 @@ class WavFileClient:
# Delay to simulate real-time streaming
# Server expects audio at real-time pace for VAD/ASR to work properly
await asyncio.sleep(self.chunk_duration_ms / 1000)
# Add a short silence tail to help VAD/EOU close the final utterance.
if self.tail_silence_ms > 0 and self.ws:
tail_frames = max(1, self.tail_silence_ms // 20)
silence = b"\x00" * self.frame_bytes
for _ in range(tail_frames):
await self.ws.send(silence)
self.bytes_sent += len(silence)
await asyncio.sleep(0.02)
self.log_event("", f"Sent trailing silence: {self.tail_silence_ms}ms")
self.send_completed = True
elapsed = time.time() - self.send_start_time
@@ -284,16 +325,22 @@ class WavFileClient:
async def _handle_event(self, event: dict) -> None:
"""Handle incoming event."""
event_type = event.get("type", "unknown")
ids = self._event_ids_suffix(event)
if self.track_debug:
print(f"[track-debug] event={event_type} trackId={event.get('trackId')}{ids}")
if event_type == "hello.ack":
self.log_event("", "Handshake acknowledged")
self.log_event("", f"Handshake acknowledged{ids}")
elif event_type == "session.started":
self.session_ready = True
self.log_event("", "Session ready!")
self.log_event("", f"Session ready!{ids}")
elif event_type == "config.resolved":
config = event.get("config", {})
self.log_event("", f"Config resolved (output={config.get('output', {})}){ids}")
elif event_type == "input.speech_started":
self.log_event("", "Speech detected")
self.log_event("", f"Speech detected{ids}")
elif event_type == "input.speech_stopped":
self.log_event("", "Silence detected")
self.log_event("", f"Silence detected{ids}")
elif event_type == "transcript.delta":
text = event.get("text", "")
display_text = text[:60] + "..." if len(text) > 60 else text
@@ -301,35 +348,35 @@ class WavFileClient:
elif event_type == "transcript.final":
text = event.get("text", "")
print(" " * 80, end="\r")
self.log_event("", f"→ You: {text}")
self.log_event("", f"→ You: {text}{ids}")
elif event_type == "metrics.ttfb":
latency_ms = event.get("latencyMs", 0)
self.log_event("", f"[TTFB] Server latency: {latency_ms}ms")
elif event_type == "assistant.response.delta":
text = event.get("text", "")
if self.verbose and text:
self.log_event("", f"LLM: {text}")
self.log_event("", f"LLM: {text}{ids}")
elif event_type == "assistant.response.final":
text = event.get("text", "")
if text:
self.log_event("", f"LLM Response (final): {text[:100]}{'...' if len(text) > 100 else ''}")
self.log_event("", f"LLM Response (final): {text[:100]}{'...' if len(text) > 100 else ''}{ids}")
elif event_type == "output.audio.start":
self.track_started = True
self.response_start_time = time.time()
self.waiting_for_first_audio = True
self.log_event("", "Bot started speaking")
self.log_event("", f"Bot started speaking{ids}")
elif event_type == "output.audio.end":
self.track_ended = True
self.log_event("", "Bot finished speaking")
self.log_event("", f"Bot finished speaking{ids}")
elif event_type == "response.interrupted":
self.log_event("", "Bot interrupted!")
self.log_event("", f"Bot interrupted!{ids}")
elif event_type == "error":
self.log_event("!", f"Error: {event.get('message')}")
self.log_event("!", f"Error: {event.get('message')}{ids}")
elif event_type == "session.stopped":
self.log_event("", f"Session stopped: {event.get('reason')}")
self.log_event("", f"Session stopped: {event.get('reason')}{ids}")
self.running = False
else:
self.log_event("", f"Event: {event_type}")
self.log_event("", f"Event: {event_type}{ids}")
def save_output_wav(self) -> None:
"""Save received audio to output WAV file."""
@@ -473,6 +520,21 @@ async def main():
default=16000,
help="Target sample rate for audio (default: 16000)"
)
parser.add_argument(
"--app-id",
default="assistant_demo",
help="Stable app/assistant identifier for server-side config lookup"
)
parser.add_argument(
"--channel",
default="wav_client",
help="Client channel name"
)
parser.add_argument(
"--config-version-id",
default="local-dev",
help="Optional config version identifier"
)
parser.add_argument(
"--chunk-duration",
type=int,
@@ -490,6 +552,17 @@ async def main():
action="store_true",
help="Enable verbose output"
)
parser.add_argument(
"--track-debug",
action="store_true",
help="Print event trackId for protocol debugging"
)
parser.add_argument(
"--tail-silence-ms",
type=int,
default=800,
help="Trailing silence to send after WAV playback for EOU detection (default: 800)"
)
args = parser.parse_args()
@@ -497,10 +570,15 @@ async def main():
url=args.url,
input_file=args.input,
output_file=args.output,
app_id=args.app_id,
channel=args.channel,
config_version_id=args.config_version_id,
sample_rate=args.sample_rate,
chunk_duration_ms=args.chunk_duration,
wait_time=args.wait_time,
verbose=args.verbose
verbose=args.verbose,
track_debug=args.track_debug,
tail_silence_ms=args.tail_silence_ms,
)
await client.run()