From c6c84b5af963cd2d5acbad43521dd50db71b0c8a Mon Sep 17 00:00:00 2001 From: Xin Wang Date: Mon, 23 Feb 2026 17:16:18 +0800 Subject: [PATCH] Update engine --- app/main.py | 5 +- core/duplex_pipeline.py | 380 +++++++++++++++++++++++++++++++++---- core/session.py | 298 ++++++++++++++++++++++++++--- docs/ws_v1_schema.md | 105 ++++++---- examples/mic_client.py | 131 +++++++++---- examples/simple_client.py | 92 +++++++-- examples/test_websocket.py | 31 ++- examples/wav_client.py | 110 +++++++++-- examples/web_client.html | 25 ++- 9 files changed, 991 insertions(+), 186 deletions(-) diff --git a/app/main.py b/app/main.py index 259204c..e2c6c75 100644 --- a/app/main.py +++ b/app/main.py @@ -24,7 +24,6 @@ from core.transports import SocketTransport, WebRtcTransport, BaseTransport from core.session import Session from processors.tracks import Resampled16kTrack from core.events import get_event_bus, reset_event_bus -from models.ws_v1 import ev # Check interval for heartbeat/timeout (seconds) _HEARTBEAT_CHECK_INTERVAL_SEC = 5 @@ -54,9 +53,7 @@ async def heartbeat_and_timeout_task( break if now - last_heartbeat_at[0] >= heartbeat_interval_sec: try: - await transport.send_event({ - **ev("heartbeat"), - }) + await session.send_heartbeat() last_heartbeat_at[0] = now except Exception as e: logger.debug(f"Session {session_id}: heartbeat send failed: {e}") diff --git a/core/duplex_pipeline.py b/core/duplex_pipeline.py index 508ba2b..2265db4 100644 --- a/core/duplex_pipeline.py +++ b/core/duplex_pipeline.py @@ -14,7 +14,8 @@ event-driven design. import asyncio import json import time -from typing import Any, Dict, List, Optional, Tuple +import uuid +from typing import Any, Callable, Dict, List, Optional, Tuple import numpy as np from loguru import logger @@ -59,6 +60,12 @@ class DuplexPipeline: _MIN_SPLIT_SPOKEN_CHARS = 6 _TOOL_WAIT_TIMEOUT_SECONDS = 15.0 _SERVER_TOOL_TIMEOUT_SECONDS = 15.0 + TRACK_AUDIO_IN = "audio_in" + TRACK_AUDIO_OUT = "audio_out" + TRACK_CONTROL = "control" + _PCM_FRAME_BYTES = 640 # 16k mono pcm_s16le, 20ms + _ASR_DELTA_THROTTLE_MS = 300 + _LLM_DELTA_THROTTLE_MS = 80 _DEFAULT_TOOL_SCHEMAS: Dict[str, Dict[str, Any]] = { "current_time": { "name": "current_time", @@ -96,6 +103,9 @@ class DuplexPipeline: self.transport = transport self.session_id = session_id self.event_bus = get_event_bus() + self.track_audio_in = self.TRACK_AUDIO_IN + self.track_audio_out = self.TRACK_AUDIO_OUT + self.track_control = self.TRACK_CONTROL # Initialize VAD self.vad_model = SileroVAD( @@ -120,6 +130,8 @@ class DuplexPipeline: # Track last sent transcript to avoid duplicates self._last_sent_transcript = "" + self._pending_transcript_delta: str = "" + self._last_transcript_delta_emit_ms: float = 0.0 # Conversation manager self.conversation = ConversationManager( @@ -153,6 +165,7 @@ class DuplexPipeline: self._outbound_seq = 0 self._outbound_task: Optional[asyncio.Task] = None self._drop_outbound_audio = False + self._audio_out_frame_buffer: bytes = b"" # Interruption handling self._interrupt_event = asyncio.Event() @@ -186,9 +199,28 @@ class DuplexPipeline: self._pending_tool_waiters: Dict[str, asyncio.Future] = {} self._early_tool_results: Dict[str, Dict[str, Any]] = {} self._completed_tool_call_ids: set[str] = set() + self._pending_client_tool_call_ids: set[str] = set() + self._next_seq: Optional[Callable[[], int]] = None + self._local_seq: int = 0 + + # Cross-service correlation IDs + self._turn_count: int = 0 + self._response_count: int = 0 + self._tts_count: int = 0 + self._utterance_count: int = 0 + self._current_turn_id: Optional[str] = None + self._current_utterance_id: Optional[str] = None + self._current_response_id: Optional[str] = None + self._current_tts_id: Optional[str] = None + self._pending_llm_delta: str = "" + self._last_llm_delta_emit_ms: float = 0.0 logger.info(f"DuplexPipeline initialized for session {session_id}") + def set_event_sequence_provider(self, provider: Callable[[], int]) -> None: + """Use session-scoped monotonic sequence provider for envelope events.""" + self._next_seq = provider + def apply_runtime_overrides(self, metadata: Optional[Dict[str, Any]]) -> None: """ Apply runtime overrides from WS session.start metadata. @@ -276,6 +308,131 @@ class DuplexPipeline: if self.llm_service and hasattr(self.llm_service, "set_tool_schemas"): self.llm_service.set_tool_schemas(self._resolved_tool_schemas()) + def resolved_runtime_config(self) -> Dict[str, Any]: + """Return current effective runtime configuration without secrets.""" + llm_provider = str(self._runtime_llm.get("provider") or "openai").lower() + tts_provider = str(self._runtime_tts.get("provider") or settings.tts_provider).lower() + asr_provider = str(self._runtime_asr.get("provider") or settings.asr_provider).lower() + output_mode = str(self._runtime_output.get("mode") or "").strip().lower() + if not output_mode: + output_mode = "audio" if self._tts_output_enabled() else "text" + + return { + "output": {"mode": output_mode}, + "services": { + "llm": { + "provider": llm_provider, + "model": str(self._runtime_llm.get("model") or settings.llm_model), + "baseUrl": self._runtime_llm.get("baseUrl") or settings.openai_api_url, + }, + "asr": { + "provider": asr_provider, + "model": str(self._runtime_asr.get("model") or settings.siliconflow_asr_model), + "interimIntervalMs": int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms), + "minAudioMs": int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms), + }, + "tts": { + "enabled": self._tts_output_enabled(), + "provider": tts_provider, + "model": str(self._runtime_tts.get("model") or settings.siliconflow_tts_model), + "voice": str(self._runtime_tts.get("voice") or settings.tts_voice), + "speed": float(self._runtime_tts.get("speed") or settings.tts_speed), + }, + }, + "tools": { + "allowlist": sorted(self._runtime_tool_executor.keys()), + }, + "tracks": { + "audio_in": self.track_audio_in, + "audio_out": self.track_audio_out, + "control": self.track_control, + }, + } + + def _next_event_seq(self) -> int: + if self._next_seq: + return self._next_seq() + self._local_seq += 1 + return self._local_seq + + def _event_source(self, event_type: str) -> str: + if event_type.startswith("transcript.") or event_type.startswith("input.speech_"): + return "asr" + if event_type.startswith("assistant.response."): + return "llm" + if event_type.startswith("assistant.tool_"): + return "tool" + if event_type.startswith("output.audio.") or event_type == "metrics.ttfb": + return "tts" + return "system" + + def _new_id(self, prefix: str, counter: int) -> str: + return f"{prefix}_{counter}_{uuid.uuid4().hex[:8]}" + + def _start_turn(self) -> str: + self._turn_count += 1 + self._current_turn_id = self._new_id("turn", self._turn_count) + self._current_utterance_id = None + self._current_response_id = None + self._current_tts_id = None + return self._current_turn_id + + def _start_response(self) -> str: + self._response_count += 1 + self._current_response_id = self._new_id("resp", self._response_count) + self._current_tts_id = None + return self._current_response_id + + def _start_tts(self) -> str: + self._tts_count += 1 + self._current_tts_id = self._new_id("tts", self._tts_count) + return self._current_tts_id + + def _finalize_utterance(self) -> str: + if self._current_utterance_id: + return self._current_utterance_id + self._utterance_count += 1 + self._current_utterance_id = self._new_id("utt", self._utterance_count) + if not self._current_turn_id: + self._start_turn() + return self._current_utterance_id + + def _envelope_event(self, event: Dict[str, Any]) -> Dict[str, Any]: + event_type = str(event.get("type") or "") + source = str(event.get("source") or self._event_source(event_type)) + track_id = event.get("trackId") + if not track_id: + if source == "asr": + track_id = self.track_audio_in + elif source in {"llm", "tts", "tool"}: + track_id = self.track_audio_out + else: + track_id = self.track_control + + data = event.get("data") + if not isinstance(data, dict): + data = {} + if self._current_turn_id: + data.setdefault("turn_id", self._current_turn_id) + if self._current_utterance_id: + data.setdefault("utterance_id", self._current_utterance_id) + if self._current_response_id: + data.setdefault("response_id", self._current_response_id) + if self._current_tts_id: + data.setdefault("tts_id", self._current_tts_id) + + for k, v in event.items(): + if k in {"type", "timestamp", "sessionId", "seq", "source", "trackId", "data"}: + continue + data.setdefault(k, v) + + event["sessionId"] = self.session_id + event["seq"] = self._next_event_seq() + event["source"] = source + event["trackId"] = track_id + event["data"] = data + return event + @staticmethod def _coerce_bool(value: Any) -> Optional[bool]: if isinstance(value, bool): @@ -472,11 +629,13 @@ class DuplexPipeline: greeting_to_speak = generated_greeting self.conversation.greeting = generated_greeting if greeting_to_speak: + self._start_turn() + self._start_response() await self._send_event( ev( "assistant.response.final", text=greeting_to_speak, - trackId=self.session_id, + trackId=self.track_audio_out, ), priority=20, ) @@ -494,10 +653,58 @@ class DuplexPipeline: await self._outbound_q.put((priority, self._outbound_seq, kind, payload)) async def _send_event(self, event: Dict[str, Any], priority: int = 20) -> None: - await self._enqueue_outbound("event", event, priority) + await self._enqueue_outbound("event", self._envelope_event(event), priority) async def _send_audio(self, pcm_bytes: bytes, priority: int = 50) -> None: - await self._enqueue_outbound("audio", pcm_bytes, priority) + if not pcm_bytes: + return + self._audio_out_frame_buffer += pcm_bytes + while len(self._audio_out_frame_buffer) >= self._PCM_FRAME_BYTES: + frame = self._audio_out_frame_buffer[: self._PCM_FRAME_BYTES] + self._audio_out_frame_buffer = self._audio_out_frame_buffer[self._PCM_FRAME_BYTES :] + await self._enqueue_outbound("audio", frame, priority) + + async def _flush_audio_out_frames(self, priority: int = 50) -> None: + """Flush remaining outbound audio as one padded 20ms PCM frame.""" + if not self._audio_out_frame_buffer: + return + tail = self._audio_out_frame_buffer + self._audio_out_frame_buffer = b"" + if len(tail) < self._PCM_FRAME_BYTES: + tail = tail + (b"\x00" * (self._PCM_FRAME_BYTES - len(tail))) + await self._enqueue_outbound("audio", tail, priority) + + async def _emit_transcript_delta(self, text: str) -> None: + await self._send_event( + { + **ev( + "transcript.delta", + trackId=self.track_audio_in, + text=text, + ) + }, + priority=30, + ) + + async def _emit_llm_delta(self, text: str) -> None: + await self._send_event( + { + **ev( + "assistant.response.delta", + trackId=self.track_audio_out, + text=text, + ) + }, + priority=20, + ) + + async def _flush_pending_llm_delta(self) -> None: + if not self._pending_llm_delta: + return + chunk = self._pending_llm_delta + self._pending_llm_delta = "" + self._last_llm_delta_emit_ms = time.monotonic() * 1000.0 + await self._emit_llm_delta(chunk) async def _outbound_loop(self) -> None: """Single sender loop that enforces priority for interrupt events.""" @@ -546,13 +753,13 @@ class DuplexPipeline: # Emit VAD event await self.event_bus.publish(event_type, { - "trackId": self.session_id, + "trackId": self.track_audio_in, "probability": probability }) await self._send_event( ev( "input.speech_started" if event_type == "speaking" else "input.speech_stopped", - trackId=self.session_id, + trackId=self.track_audio_in, probability=probability, ), priority=30, @@ -661,6 +868,9 @@ class DuplexPipeline: # Cancel any current speaking await self._stop_current_speech() + self._start_turn() + self._finalize_utterance() + # Start new turn await self.conversation.end_user_turn(text) self._current_turn_task = asyncio.create_task(self._handle_turn(text)) @@ -683,24 +893,45 @@ class DuplexPipeline: if text == self._last_sent_transcript and not is_final: return + now_ms = time.monotonic() * 1000.0 self._last_sent_transcript = text - # Send transcript event to client - await self._send_event({ - **ev( - "transcript.final" if is_final else "transcript.delta", - trackId=self.session_id, - text=text, + if is_final: + self._pending_transcript_delta = "" + self._last_transcript_delta_emit_ms = 0.0 + await self._send_event( + { + **ev( + "transcript.final", + trackId=self.track_audio_in, + text=text, + ) + }, + priority=30, ) - }, priority=30) + logger.debug(f"Sent transcript (final): {text[:50]}...") + return + + self._pending_transcript_delta = text + should_emit = ( + self._last_transcript_delta_emit_ms <= 0.0 + or now_ms - self._last_transcript_delta_emit_ms >= self._ASR_DELTA_THROTTLE_MS + ) + if should_emit and self._pending_transcript_delta: + delta = self._pending_transcript_delta + self._pending_transcript_delta = "" + self._last_transcript_delta_emit_ms = now_ms + await self._emit_transcript_delta(delta) if not is_final: logger.info(f"[ASR] ASR interim: {text[:100]}") - logger.debug(f"Sent transcript ({'final' if is_final else 'interim'}): {text[:50]}...") + logger.debug(f"Sent transcript (interim): {text[:50]}...") async def _on_speech_start(self) -> None: """Handle user starting to speak.""" if self.conversation.state in (ConversationState.IDLE, ConversationState.INTERRUPTED): + self._start_turn() + self._finalize_utterance() await self.conversation.start_user_turn() self._audio_buffer = b"" self._last_sent_transcript = "" @@ -779,6 +1010,7 @@ class DuplexPipeline: return logger.info(f"[EOU] Detected - user said: {user_text[:100]}...") + self._finalize_utterance() # For ASR backends that already emitted final via callback, # avoid duplicating transcript.final on EOU. @@ -786,7 +1018,7 @@ class DuplexPipeline: await self._send_event({ **ev( "transcript.final", - trackId=self.session_id, + trackId=self.track_audio_in, text=user_text, ) }, priority=25) @@ -794,6 +1026,8 @@ class DuplexPipeline: # Clear buffers self._audio_buffer = b"" self._last_sent_transcript = "" + self._pending_transcript_delta = "" + self._last_transcript_delta_emit_ms = 0.0 self._asr_capture_active = False self._pending_speech_audio = b"" @@ -894,6 +1128,44 @@ class DuplexPipeline: # Default to server execution unless explicitly marked as client. return "server" + def _tool_arguments(self, tool_call: Dict[str, Any]) -> Dict[str, Any]: + fn = tool_call.get("function") + if not isinstance(fn, dict): + return {} + raw = fn.get("arguments") + if isinstance(raw, dict): + return raw + if isinstance(raw, str) and raw.strip(): + try: + parsed = json.loads(raw) + return parsed if isinstance(parsed, dict) else {"raw": raw} + except Exception: + return {"raw": raw} + return {} + + def _normalize_tool_result(self, result: Dict[str, Any]) -> Dict[str, Any]: + status = result.get("status") if isinstance(result.get("status"), dict) else {} + status_code = int(status.get("code") or 0) if status else 0 + status_message = str(status.get("message") or "") if status else "" + tool_call_id = str(result.get("tool_call_id") or result.get("id") or "") + tool_name = str(result.get("name") or "unknown_tool") + ok = bool(200 <= status_code < 300) + retryable = status_code >= 500 or status_code in {429, 408} + error: Optional[Dict[str, Any]] = None + if not ok: + error = { + "code": status_code or 500, + "message": status_message or "tool_execution_failed", + "retryable": retryable, + } + return { + "tool_call_id": tool_call_id, + "tool_name": tool_name, + "ok": ok, + "error": error, + "status": {"code": status_code, "message": status_message}, + } + async def _emit_tool_result(self, result: Dict[str, Any], source: str) -> None: tool_name = str(result.get("name") or "unknown_tool") call_id = str(result.get("tool_call_id") or result.get("id") or "") @@ -904,12 +1176,17 @@ class DuplexPipeline: f"[Tool] emit result source={source} name={tool_name} call_id={call_id} " f"status={status_code} {status_message}".strip() ) + normalized = self._normalize_tool_result(result) await self._send_event( { **ev( "assistant.tool_result", - trackId=self.session_id, + trackId=self.track_audio_out, source=source, + tool_call_id=normalized["tool_call_id"], + tool_name=normalized["tool_name"], + ok=normalized["ok"], + error=normalized["error"], result=result, ) }, @@ -927,6 +1204,9 @@ class DuplexPipeline: call_id = str(item.get("tool_call_id") or item.get("id") or "").strip() if not call_id: continue + if self._pending_client_tool_call_ids and call_id not in self._pending_client_tool_call_ids: + logger.warning(f"[Tool] ignore unsolicited client result call_id={call_id}") + continue if call_id in self._completed_tool_call_ids: logger.debug(f"[Tool] ignore duplicate client result call_id={call_id}") continue @@ -972,6 +1252,7 @@ class DuplexPipeline: } finally: self._pending_tool_waiters.pop(call_id, None) + self._pending_client_tool_call_ids.discard(call_id) def _normalize_stream_event(self, item: Any) -> LLMStreamEvent: if isinstance(item, LLMStreamEvent): @@ -998,6 +1279,11 @@ class DuplexPipeline: user_text: User's transcribed text """ try: + if not self._current_turn_id: + self._start_turn() + if not self._current_utterance_id: + self._finalize_utterance() + self._start_response() # Start latency tracking self._turn_start_time = time.time() self._first_audio_sent = False @@ -1012,6 +1298,8 @@ class DuplexPipeline: self._drop_outbound_audio = False first_audio_sent = False + self._pending_llm_delta = "" + self._last_llm_delta_emit_ms = 0.0 for _ in range(max_rounds): if self._interrupt_event.is_set(): break @@ -1028,6 +1316,7 @@ class DuplexPipeline: event = self._normalize_stream_event(raw_event) if event.type == "tool_call": + await self._flush_pending_llm_delta() tool_call = event.tool_call if isinstance(event.tool_call, dict) else None if not tool_call: continue @@ -1045,11 +1334,19 @@ class DuplexPipeline: f"executor={executor} args={args_preview}" ) tool_calls.append(enriched_tool_call) + tool_arguments = self._tool_arguments(enriched_tool_call) + if executor == "client" and call_id: + self._pending_client_tool_call_ids.add(call_id) await self._send_event( { **ev( "assistant.tool_call", - trackId=self.session_id, + trackId=self.track_audio_out, + tool_call_id=call_id, + tool_name=tool_name, + arguments=tool_arguments, + executor=executor, + timeout_ms=int(self._TOOL_WAIT_TIMEOUT_SECONDS * 1000), tool_call=enriched_tool_call, ) }, @@ -1071,19 +1368,13 @@ class DuplexPipeline: round_response += text_chunk sentence_buffer += text_chunk await self.conversation.update_assistant_text(text_chunk) - - await self._send_event( - { - **ev( - "assistant.response.delta", - trackId=self.session_id, - text=text_chunk, - ) - }, - # Keep delta/final on the same event priority so FIFO seq - # preserves stream order (avoid late-delta after final). - priority=20, - ) + self._pending_llm_delta += text_chunk + now_ms = time.monotonic() * 1000.0 + if ( + self._last_llm_delta_emit_ms <= 0.0 + or now_ms - self._last_llm_delta_emit_ms >= self._LLM_DELTA_THROTTLE_MS + ): + await self._flush_pending_llm_delta() while True: split_result = extract_tts_sentence( @@ -1112,11 +1403,12 @@ class DuplexPipeline: if self._tts_output_enabled() and not self._interrupt_event.is_set(): if not first_audio_sent: + self._start_tts() await self._send_event( { **ev( "output.audio.start", - trackId=self.session_id, + trackId=self.track_audio_out, ) }, priority=10, @@ -1130,6 +1422,7 @@ class DuplexPipeline: ) remaining_text = f"{pending_punctuation}{sentence_buffer}".strip() + await self._flush_pending_llm_delta() if ( self._tts_output_enabled() and remaining_text @@ -1137,11 +1430,12 @@ class DuplexPipeline: and not self._interrupt_event.is_set() ): if not first_audio_sent: + self._start_tts() await self._send_event( { **ev( "output.audio.start", - trackId=self.session_id, + trackId=self.track_audio_out, ) }, priority=10, @@ -1204,11 +1498,12 @@ class DuplexPipeline: ] if full_response and not self._interrupt_event.is_set(): + await self._flush_pending_llm_delta() await self._send_event( { **ev( "assistant.response.final", - trackId=self.session_id, + trackId=self.track_audio_out, text=full_response, ) }, @@ -1217,10 +1512,11 @@ class DuplexPipeline: # Send track end if first_audio_sent: + await self._flush_audio_out_frames(priority=50) await self._send_event({ **ev( "output.audio.end", - trackId=self.session_id, + trackId=self.track_audio_out, ) }, priority=10) @@ -1241,6 +1537,8 @@ class DuplexPipeline: self._barge_in_speech_start_time = None self._barge_in_speech_frames = 0 self._barge_in_silence_frames = 0 + self._current_response_id = None + self._current_tts_id = None async def _speak_sentence(self, text: str, fade_in_ms: int = 0, fade_out_ms: int = 8) -> None: """ @@ -1277,7 +1575,7 @@ class DuplexPipeline: await self._send_event({ **ev( "metrics.ttfb", - trackId=self.session_id, + trackId=self.track_audio_out, latencyMs=round(ttfb_ms), ) }, priority=25) @@ -1354,10 +1652,11 @@ class DuplexPipeline: first_audio_sent = False # Send track start event + self._start_tts() await self._send_event({ **ev( "output.audio.start", - trackId=self.session_id, + trackId=self.track_audio_out, ) }, priority=10) @@ -1379,7 +1678,7 @@ class DuplexPipeline: await self._send_event({ **ev( "metrics.ttfb", - trackId=self.session_id, + trackId=self.track_audio_out, latencyMs=round(ttfb_ms), ) }, priority=25) @@ -1391,10 +1690,11 @@ class DuplexPipeline: await asyncio.sleep(0.01) # Send track end event + await self._flush_audio_out_frames(priority=50) await self._send_event({ **ev( "output.audio.end", - trackId=self.session_id, + trackId=self.track_audio_out, ) }, priority=10) @@ -1422,13 +1722,14 @@ class DuplexPipeline: self._interrupt_event.set() self._is_bot_speaking = False self._drop_outbound_audio = True + self._audio_out_frame_buffer = b"" # Send interrupt event to client IMMEDIATELY # This must happen BEFORE canceling services, so client knows to discard in-flight audio await self._send_event({ **ev( "response.interrupted", - trackId=self.session_id, + trackId=self.track_audio_out, ) }, priority=0) @@ -1455,6 +1756,7 @@ class DuplexPipeline: async def _stop_current_speech(self) -> None: """Stop any current speech task.""" self._drop_outbound_audio = True + self._audio_out_frame_buffer = b"" if self._current_turn_task and not self._current_turn_task.done(): self._interrupt_event.set() self._current_turn_task.cancel() diff --git a/core/session.py b/core/session.py index 3f8f18d..597b694 100644 --- a/core/session.py +++ b/core/session.py @@ -1,15 +1,16 @@ """Session management for active calls.""" import asyncio -import uuid +import hashlib import json -import time import re +import time from enum import Enum from typing import Optional, Dict, Any, List from loguru import logger from app.backend_client import ( + fetch_assistant_config, create_history_call_record, add_history_transcript, finalize_history_call_record, @@ -49,6 +50,32 @@ class Session: Uses full duplex voice conversation pipeline. """ + TRACK_AUDIO_IN = "audio_in" + TRACK_AUDIO_OUT = "audio_out" + TRACK_CONTROL = "control" + AUDIO_FRAME_BYTES = 640 # 16k mono pcm_s16le, 20ms + _CLIENT_METADATA_OVERRIDES = { + "firstTurnMode", + "greeting", + "generatedOpenerEnabled", + "systemPrompt", + "output", + "bargeIn", + "knowledge", + "knowledgeBaseId", + "history", + "userId", + "assistantId", + "source", + } + _CLIENT_METADATA_ID_KEYS = { + "appId", + "app_id", + "channel", + "configVersionId", + "config_version_id", + } + def __init__(self, session_id: str, transport: BaseTransport, use_duplex: bool = None): """ Initialize session. @@ -78,7 +105,10 @@ class Session: self.authenticated: bool = False # Track IDs - self.current_track_id: Optional[str] = str(uuid.uuid4()) + self.current_track_id: str = self.TRACK_CONTROL + self._event_seq: int = 0 + self._audio_ingress_buffer: bytes = b"" + self._audio_frame_error_reported: bool = False self._history_call_id: Optional[str] = None self._history_turn_index: int = 0 self._history_call_started_mono: Optional[float] = None @@ -89,6 +119,7 @@ class Session: self._workflow_last_user_text: str = "" self._workflow_initial_node: Optional[WorkflowNodeDef] = None + self.pipeline.set_event_sequence_provider(self._next_event_seq) self.pipeline.conversation.on_turn_complete(self._on_turn_complete) logger.info(f"Session {self.id} created (duplex={self.use_duplex})") @@ -129,13 +160,52 @@ class Session: "client", "Audio received before session.start", "protocol.order", + stage="protocol", + retryable=False, ) return try: - await self.pipeline.process_audio(audio_bytes) + if not audio_bytes: + return + if len(audio_bytes) % 2 != 0: + await self._send_error( + "client", + "Invalid PCM payload: odd number of bytes", + "audio.invalid_pcm", + stage="audio", + retryable=False, + ) + return + + frame_bytes = self.AUDIO_FRAME_BYTES + self._audio_ingress_buffer += audio_bytes + + # Protocol v1 audio framing: 20ms PCM frame (640 bytes). + # Allow aggregated frames in one WS message (multiple of 640). + if len(audio_bytes) % frame_bytes != 0 and not self._audio_frame_error_reported: + self._audio_frame_error_reported = True + await self._send_error( + "client", + f"Audio frame size should be multiple of {frame_bytes} bytes (20ms PCM)", + "audio.frame_size_mismatch", + stage="audio", + retryable=True, + ) + + while len(self._audio_ingress_buffer) >= frame_bytes: + frame = self._audio_ingress_buffer[:frame_bytes] + self._audio_ingress_buffer = self._audio_ingress_buffer[frame_bytes:] + await self.pipeline.process_audio(frame) except Exception as e: logger.error(f"Session {self.id} handle_audio error: {e}", exc_info=True) + await self._send_error( + "server", + f"Audio processing failed: {e}", + "audio.processing_failed", + stage="audio", + retryable=True, + ) async def _handle_v1_message(self, message: Any) -> None: """Route validated WS v1 message to handlers.""" @@ -217,10 +287,9 @@ class Session: self.authenticated = True self.protocol_version = message.version self.ws_state = WsSessionState.WAIT_START - await self.transport.send_event( + await self._send_event( ev( "hello.ack", - sessionId=self.id, version=self.protocol_version, ) ) @@ -231,8 +300,12 @@ class Session: await self._send_error("client", "Duplicate session.start", "protocol.order") return - metadata = message.metadata or {} - metadata = self._merge_runtime_metadata(metadata, self._bootstrap_workflow(metadata)) + raw_metadata = message.metadata or {} + workflow_runtime = self._bootstrap_workflow(raw_metadata) + server_runtime = await self._load_server_runtime_metadata(raw_metadata, workflow_runtime) + client_runtime = self._sanitize_client_metadata(raw_metadata) + metadata = self._merge_runtime_metadata(server_runtime, self._sanitize_untrusted_runtime_metadata(workflow_runtime)) + metadata = self._merge_runtime_metadata(metadata, client_runtime) # Create history call record early so later turn callbacks can append transcripts. await self._start_history_bridge(metadata) @@ -248,28 +321,37 @@ class Session: self.state = "accepted" self.ws_state = WsSessionState.ACTIVE - await self.transport.send_event( + await self._send_event( ev( "session.started", - sessionId=self.id, trackId=self.current_track_id, + tracks={ + "audio_in": self.TRACK_AUDIO_IN, + "audio_out": self.TRACK_AUDIO_OUT, + "control": self.TRACK_CONTROL, + }, audio=message.audio or {}, ) ) + await self._send_event( + ev( + "config.resolved", + trackId=self.TRACK_CONTROL, + config=self._build_config_resolved(metadata), + ) + ) if self.workflow_runner and self._workflow_initial_node: - await self.transport.send_event( + await self._send_event( ev( "workflow.started", - sessionId=self.id, workflowId=self.workflow_runner.workflow_id, workflowName=self.workflow_runner.name, nodeId=self._workflow_initial_node.id, ) ) - await self.transport.send_event( + await self._send_event( ev( "workflow.node.entered", - sessionId=self.id, workflowId=self.workflow_runner.workflow_id, nodeId=self._workflow_initial_node.id, nodeName=self._workflow_initial_node.name, @@ -285,17 +367,23 @@ class Session: stop_reason = reason or "client_requested" self.state = "hungup" self.ws_state = WsSessionState.STOPPED - await self.transport.send_event( + await self._send_event( ev( "session.stopped", - sessionId=self.id, reason=stop_reason, ) ) await self._finalize_history(status="connected") await self.transport.close() - async def _send_error(self, sender: str, error_message: str, code: str) -> None: + async def _send_error( + self, + sender: str, + error_message: str, + code: str, + stage: Optional[str] = None, + retryable: Optional[bool] = None, + ) -> None: """ Send error event to client. @@ -304,13 +392,25 @@ class Session: error_message: Error message code: Machine-readable error code """ - await self.transport.send_event( + resolved_stage = stage or self._infer_error_stage(code) + resolved_retryable = retryable if retryable is not None else (resolved_stage in {"asr", "llm", "tts", "tool", "audio"}) + await self._send_event( ev( "error", sender=sender, code=code, message=error_message, + stage=resolved_stage, + retryable=resolved_retryable, trackId=self.current_track_id, + data={ + "error": { + "stage": resolved_stage, + "code": code, + "message": error_message, + "retryable": resolved_retryable, + } + }, ) ) @@ -483,10 +583,9 @@ class Session: node = transition.node edge = transition.edge - await self.transport.send_event( + await self._send_event( ev( "workflow.edge.taken", - sessionId=self.id, workflowId=self.workflow_runner.workflow_id, edgeId=edge.id, fromNodeId=edge.from_node_id, @@ -494,10 +593,9 @@ class Session: reason=reason, ) ) - await self.transport.send_event( + await self._send_event( ev( "workflow.node.entered", - sessionId=self.id, workflowId=self.workflow_runner.workflow_id, nodeId=node.id, nodeName=node.name, @@ -510,10 +608,9 @@ class Session: self.pipeline.apply_runtime_overrides(node_runtime) if node.node_type == "tool": - await self.transport.send_event( + await self._send_event( ev( "workflow.tool.requested", - sessionId=self.id, workflowId=self.workflow_runner.workflow_id, nodeId=node.id, tool=node.tool or {}, @@ -522,10 +619,9 @@ class Session: return if node.node_type == "human_transfer": - await self.transport.send_event( + await self._send_event( ev( "workflow.human_transfer", - sessionId=self.id, workflowId=self.workflow_runner.workflow_id, nodeId=node.id, ) @@ -534,16 +630,68 @@ class Session: return if node.node_type == "end": - await self.transport.send_event( + await self._send_event( ev( "workflow.ended", - sessionId=self.id, workflowId=self.workflow_runner.workflow_id, nodeId=node.id, ) ) await self._handle_session_stop("workflow_end") + def _next_event_seq(self) -> int: + self._event_seq += 1 + return self._event_seq + + def _event_source(self, event_type: str) -> str: + if event_type.startswith("workflow."): + return "system" + if event_type.startswith("session.") or event_type.startswith("hello.") or event_type == "heartbeat": + return "system" + if event_type == "error": + return "system" + return "system" + + def _infer_error_stage(self, code: str) -> str: + normalized = str(code or "").strip().lower() + if normalized.startswith("audio."): + return "audio" + if normalized.startswith("tool."): + return "tool" + if normalized.startswith("asr."): + return "asr" + if normalized.startswith("llm."): + return "llm" + if normalized.startswith("tts."): + return "tts" + return "protocol" + + def _envelope_event(self, event: Dict[str, Any]) -> Dict[str, Any]: + event_type = str(event.get("type") or "") + source = str(event.get("source") or self._event_source(event_type)) + track_id = event.get("trackId") or self.TRACK_CONTROL + + data = event.get("data") + if not isinstance(data, dict): + data = {} + for k, v in event.items(): + if k in {"type", "timestamp", "sessionId", "seq", "source", "trackId", "data"}: + continue + data.setdefault(k, v) + + event["sessionId"] = self.id + event["seq"] = self._next_event_seq() + event["source"] = source + event["trackId"] = track_id + event["data"] = data + return event + + async def _send_event(self, event: Dict[str, Any]) -> None: + await self.transport.send_event(self._envelope_event(event)) + + async def send_heartbeat(self) -> None: + await self._send_event(ev("heartbeat", trackId=self.TRACK_CONTROL)) + async def _workflow_llm_route( self, node: WorkflowNodeDef, @@ -629,6 +777,100 @@ class Session: merged[key] = value return merged + async def _load_server_runtime_metadata( + self, + client_metadata: Dict[str, Any], + workflow_runtime: Dict[str, Any], + ) -> Dict[str, Any]: + """Load trusted runtime metadata from backend assistant config.""" + assistant_id = ( + workflow_runtime.get("assistantId") + or client_metadata.get("assistantId") + or client_metadata.get("appId") + or client_metadata.get("app_id") + ) + if assistant_id is None: + return {} + if not settings.backend_url: + return {} + + payload = await fetch_assistant_config(str(assistant_id).strip()) + if not isinstance(payload, dict): + return {} + + assistant_cfg = payload.get("assistant") if isinstance(payload.get("assistant"), dict) else payload + if not isinstance(assistant_cfg, dict): + return {} + + runtime: Dict[str, Any] = {} + if assistant_cfg.get("systemPrompt") is not None: + runtime["systemPrompt"] = str(assistant_cfg.get("systemPrompt") or "") + elif assistant_cfg.get("prompt") is not None: + runtime["systemPrompt"] = str(assistant_cfg.get("prompt") or "") + + if assistant_cfg.get("greeting") is not None: + runtime["greeting"] = assistant_cfg.get("greeting") + elif assistant_cfg.get("opener") is not None: + runtime["greeting"] = assistant_cfg.get("opener") + + if isinstance(assistant_cfg.get("services"), dict): + runtime["services"] = assistant_cfg.get("services") + if isinstance(assistant_cfg.get("tools"), list): + runtime["tools"] = assistant_cfg.get("tools") + + runtime["assistantId"] = str(assistant_id) + return runtime + + def _sanitize_untrusted_runtime_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]: + """ + Sanitize untrusted metadata sources. + + This keeps only a small override whitelist and stable config ID fields. + """ + if not isinstance(metadata, dict): + return {} + + sanitized: Dict[str, Any] = {} + for key in self._CLIENT_METADATA_ID_KEYS: + if key in metadata: + sanitized[key] = metadata[key] + for key in self._CLIENT_METADATA_OVERRIDES: + if key in metadata: + sanitized[key] = metadata[key] + + return sanitized + + def _sanitize_client_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]: + """Apply client metadata whitelist and remove forbidden secrets.""" + sanitized = self._sanitize_untrusted_runtime_metadata(metadata) + if isinstance(metadata.get("services"), dict): + logger.warning( + "Session {} provided metadata.services from client; client-side service config is ignored", + self.id, + ) + return sanitized + + def _build_config_resolved(self, metadata: Dict[str, Any]) -> Dict[str, Any]: + """Build public resolved config payload (secrets removed).""" + system_prompt = str(metadata.get("systemPrompt") or self.pipeline.conversation.system_prompt or "") + prompt_hash = hashlib.sha256(system_prompt.encode("utf-8")).hexdigest() if system_prompt else None + runtime = self.pipeline.resolved_runtime_config() + + return { + "appId": metadata.get("appId") or metadata.get("app_id") or metadata.get("assistantId"), + "channel": metadata.get("channel"), + "configVersionId": metadata.get("configVersionId") or metadata.get("config_version_id"), + "prompt": {"sha256": prompt_hash}, + "output": runtime.get("output", {}), + "services": runtime.get("services", {}), + "tools": runtime.get("tools", {}), + "tracks": { + "audio_in": self.TRACK_AUDIO_IN, + "audio_out": self.TRACK_AUDIO_OUT, + "control": self.TRACK_CONTROL, + }, + } + def _extract_json_obj(self, text: str) -> Optional[Dict[str, Any]]: """Best-effort extraction of a JSON object from freeform text.""" try: diff --git a/docs/ws_v1_schema.md b/docs/ws_v1_schema.md index 9db0900..c2a7ab4 100644 --- a/docs/ws_v1_schema.md +++ b/docs/ws_v1_schema.md @@ -52,43 +52,26 @@ Rules: "channels": 1 }, "metadata": { + "appId": "assistant_123", + "channel": "web", + "configVersionId": "cfg_20260217_01", "client": "web-debug", "output": { "mode": "audio" }, "systemPrompt": "You are concise.", - "greeting": "Hi, how can I help?", - "services": { - "llm": { - "provider": "openai", - "model": "gpt-4o-mini", - "apiKey": "sk-...", - "baseUrl": "https://api.openai.com/v1" - }, - "asr": { - "provider": "openai_compatible", - "model": "FunAudioLLM/SenseVoiceSmall", - "apiKey": "sf-...", - "interimIntervalMs": 500, - "minAudioMs": 300 - }, - "tts": { - "enabled": true, - "provider": "openai_compatible", - "model": "FunAudioLLM/CosyVoice2-0.5B", - "apiKey": "sf-...", - "voice": "anna", - "speed": 1.0 - } - } + "greeting": "Hi, how can I help?" } } ``` -`metadata.services` is optional. If omitted, server defaults to environment configuration. +Rules: +- Client-side `metadata.services` is ignored. +- Service config (including secrets) is resolved server-side (env/backend). +- Client should pass stable IDs (`appId`, `channel`, `configVersionId`) plus small runtime overrides (e.g. `output`, `bargeIn`, greeting/prompt style hints). Text-only mode: -- Set `metadata.output.mode = "text"` OR `metadata.services.tts.enabled = false`. +- Set `metadata.output.mode = "text"`. - In this mode server still sends `assistant.response.delta/final`, but will not emit audio frames or `output.audio.start/end`. ### `input.text` @@ -121,6 +104,7 @@ Text-only mode: ### `tool_call.results` Client tool execution results returned to server. +Only needed when `assistant.tool_call.executor == "client"` (default execution is server-side). ```json { @@ -138,21 +122,35 @@ Client tool execution results returned to server. ## Server -> Client Events -All server events include: +All server events include an envelope: ```json { "type": "event.name", - "timestamp": 1730000000000 + "timestamp": 1730000000000, + "sessionId": "sess_xxx", + "seq": 42, + "source": "asr", + "trackId": "audio_in", + "data": {} } ``` +Envelope notes: +- `seq` is monotonically increasing within one session (for replay/resume). +- `source` is one of: `asr | llm | tts | tool | system`. +- `data` is structured payload; legacy top-level fields are kept for compatibility. + Common events: - `hello.ack` - Fields: `sessionId`, `version` - `session.started` - - Fields: `sessionId`, `trackId`, `audio` + - Fields: `sessionId`, `trackId`, `tracks`, `audio` +- `config.resolved` + - Fields: `sessionId`, `trackId`, `config` + - Sent immediately after `session.started`. + - Contains effective model/voice/output/tool allowlist/prompt hash, and never includes secrets. - `session.stopped` - Fields: `sessionId`, `reason` - `heartbeat` @@ -169,9 +167,10 @@ Common events: - `assistant.response.final` - Fields: `trackId`, `text` - `assistant.tool_call` - - Fields: `trackId`, `tool_call` (`tool_call.executor` is `client` or `server`) + - Fields: `trackId`, `tool_call`, `tool_call_id`, `tool_name`, `arguments`, `executor`, `timeout_ms` - `assistant.tool_result` - - Fields: `trackId`, `source`, `result` + - Fields: `trackId`, `source`, `result`, `tool_call_id`, `tool_name`, `ok`, `error` + - `error`: `{ code, message, retryable }` when `ok=false` - `output.audio.start` - Fields: `trackId` - `output.audio.end` @@ -183,15 +182,49 @@ Common events: - `error` - Fields: `sender`, `code`, `message`, `trackId` +Track IDs (MVP fixed values): +- `audio_in`: ASR/VAD input-side events (`input.*`, `transcript.*`) +- `audio_out`: assistant output-side events (`assistant.*`, `output.audio.*`, `response.interrupted`, `metrics.ttfb`) +- `control`: session/control events (`session.*`, `hello.*`, `error`, `config.resolved`) + +Correlation IDs (`event.data`): +- `turn_id`: one user-assistant interaction turn. +- `utterance_id`: one ASR final utterance. +- `response_id`: one assistant response generation. +- `tool_call_id`: one tool invocation. +- `tts_id`: one TTS playback segment. + ## Binary Audio Frames After `session.started`, client may send binary PCM chunks continuously. -Recommended format: -- 16-bit signed little-endian PCM. -- 1 channel. -- 16000 Hz. -- 20ms frames (640 bytes) preferred. +MVP fixed format: +- 16-bit signed little-endian PCM (`pcm_s16le`) +- mono (1 channel) +- 16000 Hz +- 20ms frame = 640 bytes + +Framing rules: +- Binary audio frame unit is 640 bytes. +- A WS binary message may carry one or multiple complete 640-byte frames. +- Non-640-multiple payloads are treated as `audio.frame_size_mismatch` protocol errors. + +TTS boundary events: +- `output.audio.start` and `output.audio.end` mark assistant playback boundaries. + +## Event Throttling + +To keep client rendering and server load stable, v1 applies/recommends: +- `transcript.delta`: merge to ~200-500ms cadence (server default: 300ms). +- `assistant.response.delta`: merge to ~50-100ms cadence (server default: 80ms). +- Metrics streams (if enabled beyond `metrics.ttfb`): emit every ~500-1000ms. + +## Error Structure + +`error` keeps legacy top-level fields (`code`, `message`) and adds structured info: +- `stage`: `protocol | asr | llm | tts | tool | audio` +- `retryable`: boolean +- `data.error`: `{ stage, code, message, retryable }` ## Compatibility diff --git a/examples/mic_client.py b/examples/mic_client.py index 509aeaa..00d403f 100644 --- a/examples/mic_client.py +++ b/examples/mic_client.py @@ -59,8 +59,12 @@ class MicrophoneClient: url: str, sample_rate: int = 16000, chunk_duration_ms: int = 20, + app_id: str = "assistant_demo", + channel: str = "mic_client", + config_version_id: str = "local-dev", input_device: int = None, - output_device: int = None + output_device: int = None, + track_debug: bool = False, ): """ Initialize microphone client. @@ -76,8 +80,12 @@ class MicrophoneClient: self.sample_rate = sample_rate self.chunk_duration_ms = chunk_duration_ms self.chunk_samples = int(sample_rate * chunk_duration_ms / 1000) + self.app_id = app_id + self.channel = channel + self.config_version_id = config_version_id self.input_device = input_device self.output_device = output_device + self.track_debug = track_debug # WebSocket connection self.ws = None @@ -106,6 +114,17 @@ class MicrophoneClient: # Verbose mode for streaming LLM responses self.verbose = False + + @staticmethod + def _event_ids_suffix(event: dict) -> str: + data = event.get("data") if isinstance(event.get("data"), dict) else {} + keys = ("turn_id", "utterance_id", "response_id", "tool_call_id", "tts_id") + parts = [] + for key in keys: + value = data.get(key, event.get(key)) + if value: + parts.append(f"{key}={value}") + return f" [{' '.join(parts)}]" if parts else "" async def connect(self) -> None: """Connect to WebSocket server.""" @@ -114,20 +133,30 @@ class MicrophoneClient: self.running = True print("Connected!") - # Send invite command + # WS v1 handshake: hello -> session.start await self.send_command({ - "command": "invite", - "option": { - "codec": "pcm", - "sampleRate": self.sample_rate - } + "type": "hello", + "version": "v1", + }) + await self.send_command({ + "type": "session.start", + "audio": { + "encoding": "pcm_s16le", + "sample_rate_hz": self.sample_rate, + "channels": 1, + }, + "metadata": { + "appId": self.app_id, + "channel": self.channel, + "configVersionId": self.config_version_id, + }, }) async def send_command(self, cmd: dict) -> None: """Send JSON command to server.""" if self.ws: await self.ws.send(json.dumps(cmd)) - print(f"→ Command: {cmd.get('command', 'unknown')}") + print(f"→ Command: {cmd.get('type', 'unknown')}") async def send_chat(self, text: str) -> None: """Send chat message (text input).""" @@ -136,7 +165,7 @@ class MicrophoneClient: self.first_audio_received = False await self.send_command({ - "command": "chat", + "type": "input.text", "text": text }) print(f"→ Chat: {text}") @@ -144,13 +173,14 @@ class MicrophoneClient: async def send_interrupt(self) -> None: """Send interrupt command.""" await self.send_command({ - "command": "interrupt" + "type": "response.cancel", + "graceful": False, }) async def send_hangup(self, reason: str = "User quit") -> None: """Send hangup command.""" await self.send_command({ - "command": "hangup", + "type": "session.stop", "reason": reason }) @@ -295,43 +325,48 @@ class MicrophoneClient: async def _handle_event(self, event: dict) -> None: """Handle incoming event.""" - event_type = event.get("event", "unknown") + event_type = event.get("type", event.get("event", "unknown")) + ids = self._event_ids_suffix(event) + if self.track_debug: + print(f"[track-debug] event={event_type} trackId={event.get('trackId')}{ids}") - if event_type == "answer": - print("← Session ready!") - elif event_type == "speaking": - print("← User speech detected") - elif event_type == "silence": - print("← User silence detected") - elif event_type == "transcript": + if event_type in {"hello.ack", "session.started"}: + print(f"← Session ready!{ids}") + elif event_type == "config.resolved": + print(f"← Config resolved: {event.get('config', {}).get('output', {})}{ids}") + elif event_type == "input.speech_started": + print(f"← User speech detected{ids}") + elif event_type == "input.speech_stopped": + print(f"← User silence detected{ids}") + elif event_type in {"transcript", "transcript.delta", "transcript.final"}: # Display user speech transcription text = event.get("text", "") - is_final = event.get("isFinal", False) + is_final = event_type == "transcript.final" or bool(event.get("isFinal")) if is_final: # Clear the interim line and print final print(" " * 80, end="\r") # Clear previous interim text - print(f"→ You: {text}") + print(f"→ You: {text}{ids}") else: # Interim result - show with indicator (overwrite same line) display_text = text[:60] + "..." if len(text) > 60 else text print(f" [listening] {display_text}".ljust(80), end="\r") - elif event_type == "ttfb": + elif event_type in {"ttfb", "metrics.ttfb"}: # Server-side TTFB event latency_ms = event.get("latencyMs", 0) print(f"← [TTFB] Server reported latency: {latency_ms}ms") - elif event_type == "llmResponse": + elif event_type in {"llmResponse", "assistant.response.delta", "assistant.response.final"}: # LLM text response text = event.get("text", "") - is_final = event.get("isFinal", False) + is_final = event_type == "assistant.response.final" or bool(event.get("isFinal")) if is_final: # Print final LLM response print(f"← AI: {text}") elif self.verbose: # Show streaming chunks only in verbose mode display_text = text[:60] + "..." if len(text) > 60 else text - print(f" [streaming] {display_text}") - elif event_type == "trackStart": - print("← Bot started speaking") + print(f" [streaming] {display_text}{ids}") + elif event_type in {"trackStart", "output.audio.start"}: + print(f"← Bot started speaking{ids}") # IMPORTANT: Accept audio again after trackStart self._discard_audio = False self._audio_sequence += 1 @@ -342,13 +377,13 @@ class MicrophoneClient: # Clear any old audio in buffer with self.audio_output_lock: self.audio_output_buffer = b"" - elif event_type == "trackEnd": - print("← Bot finished speaking") + elif event_type in {"trackEnd", "output.audio.end"}: + print(f"← Bot finished speaking{ids}") # Reset TTFB tracking after response completes self.request_start_time = None self.first_audio_received = False - elif event_type == "interrupt": - print("← Bot interrupted!") + elif event_type in {"interrupt", "response.interrupted"}: + print(f"← Bot interrupted!{ids}") # IMPORTANT: Discard all audio until next trackStart self._discard_audio = True # Clear audio buffer immediately @@ -357,12 +392,12 @@ class MicrophoneClient: self.audio_output_buffer = b"" print(f" (cleared {buffer_ms:.0f}ms, discarding audio until new track)") elif event_type == "error": - print(f"← Error: {event.get('error')}") - elif event_type == "hangup": - print(f"← Hangup: {event.get('reason')}") + print(f"← Error: {event.get('error')}{ids}") + elif event_type in {"hangup", "session.stopped"}: + print(f"← Hangup: {event.get('reason')}{ids}") self.running = False else: - print(f"← Event: {event_type}") + print(f"← Event: {event_type}{ids}") async def interactive_mode(self) -> None: """Run interactive mode for text chat.""" @@ -573,6 +608,26 @@ async def main(): action="store_true", help="Show streaming LLM response chunks" ) + parser.add_argument( + "--app-id", + default="assistant_demo", + help="Stable app/assistant identifier for server-side config lookup" + ) + parser.add_argument( + "--channel", + default="mic_client", + help="Client channel name" + ) + parser.add_argument( + "--config-version-id", + default="local-dev", + help="Optional config version identifier" + ) + parser.add_argument( + "--track-debug", + action="store_true", + help="Print event trackId for protocol debugging" + ) args = parser.parse_args() @@ -583,8 +638,12 @@ async def main(): client = MicrophoneClient( url=args.url, sample_rate=args.sample_rate, + app_id=args.app_id, + channel=args.channel, + config_version_id=args.config_version_id, input_device=args.input_device, - output_device=args.output_device + output_device=args.output_device, + track_debug=args.track_debug, ) client.verbose = args.verbose diff --git a/examples/simple_client.py b/examples/simple_client.py index 4280f93..b1648bf 100644 --- a/examples/simple_client.py +++ b/examples/simple_client.py @@ -52,9 +52,21 @@ if not PYAUDIO_AVAILABLE and not SD_AVAILABLE: class SimpleVoiceClient: """Simple voice client with reliable audio playback.""" - def __init__(self, url: str, sample_rate: int = 16000): + def __init__( + self, + url: str, + sample_rate: int = 16000, + app_id: str = "assistant_demo", + channel: str = "simple_client", + config_version_id: str = "local-dev", + track_debug: bool = False, + ): self.url = url self.sample_rate = sample_rate + self.app_id = app_id + self.channel = channel + self.config_version_id = config_version_id + self.track_debug = track_debug self.ws = None self.running = False @@ -75,6 +87,17 @@ class SimpleVoiceClient: # Interrupt handling - discard audio until next trackStart self._discard_audio = False + + @staticmethod + def _event_ids_suffix(event: dict) -> str: + data = event.get("data") if isinstance(event.get("data"), dict) else {} + keys = ("turn_id", "utterance_id", "response_id", "tool_call_id", "tts_id") + parts = [] + for key in keys: + value = data.get(key, event.get(key)) + if value: + parts.append(f"{key}={value}") + return f" [{' '.join(parts)}]" if parts else "" async def connect(self): """Connect to server.""" @@ -83,12 +106,25 @@ class SimpleVoiceClient: self.running = True print("Connected!") - # Send invite + # WS v1 handshake: hello -> session.start await self.ws.send(json.dumps({ - "command": "invite", - "option": {"codec": "pcm", "sampleRate": self.sample_rate} + "type": "hello", + "version": "v1", })) - print("-> invite") + await self.ws.send(json.dumps({ + "type": "session.start", + "audio": { + "encoding": "pcm_s16le", + "sample_rate_hz": self.sample_rate, + "channels": 1, + }, + "metadata": { + "appId": self.app_id, + "channel": self.channel, + "configVersionId": self.config_version_id, + }, + })) + print("-> hello/session.start") async def send_chat(self, text: str): """Send chat message.""" @@ -96,8 +132,8 @@ class SimpleVoiceClient: self.request_start_time = time.time() self.first_audio_received = False - await self.ws.send(json.dumps({"command": "chat", "text": text})) - print(f"-> chat: {text}") + await self.ws.send(json.dumps({"type": "input.text", "text": text})) + print(f"-> input.text: {text}") def play_audio(self, audio_data: bytes): """Play audio data immediately.""" @@ -152,34 +188,39 @@ class SimpleVoiceClient: else: # JSON event event = json.loads(msg) - etype = event.get("event", "?") + etype = event.get("type", event.get("event", "?")) + ids = self._event_ids_suffix(event) + if self.track_debug: + print(f"[track-debug] event={etype} trackId={event.get('trackId')}{ids}") - if etype == "transcript": + if etype in {"transcript", "transcript.delta", "transcript.final"}: # User speech transcription text = event.get("text", "") - is_final = event.get("isFinal", False) + is_final = etype == "transcript.final" or bool(event.get("isFinal")) if is_final: - print(f"<- You said: {text}") + print(f"<- You said: {text}{ids}") else: print(f"<- [listening] {text}", end="\r") - elif etype == "ttfb": + elif etype in {"ttfb", "metrics.ttfb"}: # Server-side TTFB event latency_ms = event.get("latencyMs", 0) print(f"<- [TTFB] Server reported latency: {latency_ms}ms") - elif etype == "trackStart": + elif etype in {"trackStart", "output.audio.start"}: # New track starting - accept audio again self._discard_audio = False - print(f"<- {etype}") - elif etype == "interrupt": + print(f"<- {etype}{ids}") + elif etype in {"interrupt", "response.interrupted"}: # Interrupt - discard audio until next trackStart self._discard_audio = True - print(f"<- {etype} (discarding audio until new track)") - elif etype == "hangup": - print(f"<- {etype}") + print(f"<- {etype}{ids} (discarding audio until new track)") + elif etype in {"hangup", "session.stopped"}: + print(f"<- {etype}{ids}") self.running = False break + elif etype == "config.resolved": + print(f"<- config.resolved {event.get('config', {}).get('output', {})}{ids}") else: - print(f"<- {etype}") + print(f"<- {etype}{ids}") except asyncio.TimeoutError: continue @@ -270,6 +311,10 @@ async def main(): parser.add_argument("--text", help="Send text and play response") parser.add_argument("--list-devices", action="store_true") parser.add_argument("--sample-rate", type=int, default=16000) + parser.add_argument("--app-id", default="assistant_demo") + parser.add_argument("--channel", default="simple_client") + parser.add_argument("--config-version-id", default="local-dev") + parser.add_argument("--track-debug", action="store_true") args = parser.parse_args() @@ -277,7 +322,14 @@ async def main(): list_audio_devices() return - client = SimpleVoiceClient(args.url, args.sample_rate) + client = SimpleVoiceClient( + args.url, + args.sample_rate, + app_id=args.app_id, + channel=args.channel, + config_version_id=args.config_version_id, + track_debug=args.track_debug, + ) await client.run(args.text) diff --git a/examples/test_websocket.py b/examples/test_websocket.py index 0d2675d..6717834 100644 --- a/examples/test_websocket.py +++ b/examples/test_websocket.py @@ -36,8 +36,18 @@ def generate_sine_wave(duration_ms=1000): return audio_data -async def receive_loop(ws, ready_event: asyncio.Event): +async def receive_loop(ws, ready_event: asyncio.Event, track_debug: bool = False): """Listen for incoming messages from the server.""" + def event_ids_suffix(data): + payload = data.get("data") if isinstance(data.get("data"), dict) else {} + keys = ("turn_id", "utterance_id", "response_id", "tool_call_id", "tts_id") + parts = [] + for key in keys: + value = payload.get(key, data.get(key)) + if value: + parts.append(f"{key}={value}") + return f" [{' '.join(parts)}]" if parts else "" + print("👂 Listening for server responses...") async for msg in ws: timestamp = datetime.now().strftime("%H:%M:%S") @@ -46,7 +56,10 @@ async def receive_loop(ws, ready_event: asyncio.Event): try: data = json.loads(msg.data) event_type = data.get('type', 'Unknown') - print(f"[{timestamp}] 📨 Event: {event_type} | {msg.data[:150]}...") + ids = event_ids_suffix(data) + print(f"[{timestamp}] 📨 Event: {event_type}{ids} | {msg.data[:150]}...") + if track_debug: + print(f"[{timestamp}] [track-debug] event={event_type} trackId={data.get('trackId')}{ids}") if event_type == "session.started": ready_event.set() except json.JSONDecodeError: @@ -113,7 +126,7 @@ async def send_sine_loop(ws): print("\n✅ Finished streaming test audio.") -async def run_client(url, file_path=None, use_sine=False): +async def run_client(url, file_path=None, use_sine=False, track_debug: bool = False): """Run the WebSocket test client.""" session = aiohttp.ClientSession() try: @@ -121,7 +134,7 @@ async def run_client(url, file_path=None, use_sine=False): async with session.ws_connect(url) as ws: print("✅ Connected!") session_ready = asyncio.Event() - recv_task = asyncio.create_task(receive_loop(ws, session_ready)) + recv_task = asyncio.create_task(receive_loop(ws, session_ready, track_debug=track_debug)) # Send v1 hello + session.start handshake await ws.send_json({"type": "hello", "version": "v1"}) @@ -131,7 +144,12 @@ async def run_client(url, file_path=None, use_sine=False): "encoding": "pcm_s16le", "sample_rate_hz": SAMPLE_RATE, "channels": 1 - } + }, + "metadata": { + "appId": "assistant_demo", + "channel": "test_websocket", + "configVersionId": "local-dev", + }, }) print("📤 Sent v1 hello/session.start") await asyncio.wait_for(session_ready.wait(), timeout=8) @@ -168,9 +186,10 @@ if __name__ == "__main__": parser.add_argument("--url", default=SERVER_URL, help="WebSocket endpoint URL") parser.add_argument("--file", help="Path to PCM/WAV file to stream") parser.add_argument("--sine", action="store_true", help="Use sine wave generation (default)") + parser.add_argument("--track-debug", action="store_true", help="Print event trackId for protocol debugging") args = parser.parse_args() try: - asyncio.run(run_client(args.url, args.file, args.sine)) + asyncio.run(run_client(args.url, args.file, args.sine, args.track_debug)) except KeyboardInterrupt: print("\n👋 Client stopped.") diff --git a/examples/wav_client.py b/examples/wav_client.py index 729e4d2..5684256 100644 --- a/examples/wav_client.py +++ b/examples/wav_client.py @@ -57,10 +57,15 @@ class WavFileClient: url: str, input_file: str, output_file: str, + app_id: str = "assistant_demo", + channel: str = "wav_client", + config_version_id: str = "local-dev", sample_rate: int = 16000, chunk_duration_ms: int = 20, wait_time: float = 15.0, - verbose: bool = False + verbose: bool = False, + track_debug: bool = False, + tail_silence_ms: int = 800, ): """ Initialize WAV file client. @@ -77,11 +82,17 @@ class WavFileClient: self.url = url self.input_file = Path(input_file) self.output_file = Path(output_file) + self.app_id = app_id + self.channel = channel + self.config_version_id = config_version_id self.sample_rate = sample_rate self.chunk_duration_ms = chunk_duration_ms self.chunk_samples = int(sample_rate * chunk_duration_ms / 1000) self.wait_time = wait_time self.verbose = verbose + self.track_debug = track_debug + self.tail_silence_ms = max(0, int(tail_silence_ms)) + self.frame_bytes = 640 # 16k mono pcm_s16le, 20ms # WebSocket connection self.ws = None @@ -125,6 +136,17 @@ class WavFileClient: # Replace problematic characters for console output safe_message = message.encode('ascii', errors='replace').decode('ascii') print(f"{direction} {safe_message}") + + @staticmethod + def _event_ids_suffix(event: dict) -> str: + data = event.get("data") if isinstance(event.get("data"), dict) else {} + keys = ("turn_id", "utterance_id", "response_id", "tool_call_id", "tts_id") + parts = [] + for key in keys: + value = data.get(key, event.get(key)) + if value: + parts.append(f"{key}={value}") + return f" [{' '.join(parts)}]" if parts else "" async def connect(self) -> None: """Connect to WebSocket server.""" @@ -144,7 +166,12 @@ class WavFileClient: "encoding": "pcm_s16le", "sample_rate_hz": self.sample_rate, "channels": 1 - } + }, + "metadata": { + "appId": self.app_id, + "channel": self.channel, + "configVersionId": self.config_version_id, + }, }) async def send_command(self, cmd: dict) -> None: @@ -216,6 +243,10 @@ class WavFileClient: end_sample = min(sent_samples + chunk_size, total_samples) chunk = audio_data[sent_samples:end_sample] chunk_bytes = chunk.tobytes() + if len(chunk_bytes) % self.frame_bytes != 0: + # v1 audio framing requires 640-byte (20ms) PCM units. + pad = self.frame_bytes - (len(chunk_bytes) % self.frame_bytes) + chunk_bytes += b"\x00" * pad # Send to server if self.ws: @@ -232,6 +263,16 @@ class WavFileClient: # Delay to simulate real-time streaming # Server expects audio at real-time pace for VAD/ASR to work properly await asyncio.sleep(self.chunk_duration_ms / 1000) + + # Add a short silence tail to help VAD/EOU close the final utterance. + if self.tail_silence_ms > 0 and self.ws: + tail_frames = max(1, self.tail_silence_ms // 20) + silence = b"\x00" * self.frame_bytes + for _ in range(tail_frames): + await self.ws.send(silence) + self.bytes_sent += len(silence) + await asyncio.sleep(0.02) + self.log_event("→", f"Sent trailing silence: {self.tail_silence_ms}ms") self.send_completed = True elapsed = time.time() - self.send_start_time @@ -284,16 +325,22 @@ class WavFileClient: async def _handle_event(self, event: dict) -> None: """Handle incoming event.""" event_type = event.get("type", "unknown") + ids = self._event_ids_suffix(event) + if self.track_debug: + print(f"[track-debug] event={event_type} trackId={event.get('trackId')}{ids}") if event_type == "hello.ack": - self.log_event("←", "Handshake acknowledged") + self.log_event("←", f"Handshake acknowledged{ids}") elif event_type == "session.started": self.session_ready = True - self.log_event("←", "Session ready!") + self.log_event("←", f"Session ready!{ids}") + elif event_type == "config.resolved": + config = event.get("config", {}) + self.log_event("←", f"Config resolved (output={config.get('output', {})}){ids}") elif event_type == "input.speech_started": - self.log_event("←", "Speech detected") + self.log_event("←", f"Speech detected{ids}") elif event_type == "input.speech_stopped": - self.log_event("←", "Silence detected") + self.log_event("←", f"Silence detected{ids}") elif event_type == "transcript.delta": text = event.get("text", "") display_text = text[:60] + "..." if len(text) > 60 else text @@ -301,35 +348,35 @@ class WavFileClient: elif event_type == "transcript.final": text = event.get("text", "") print(" " * 80, end="\r") - self.log_event("←", f"→ You: {text}") + self.log_event("←", f"→ You: {text}{ids}") elif event_type == "metrics.ttfb": latency_ms = event.get("latencyMs", 0) self.log_event("←", f"[TTFB] Server latency: {latency_ms}ms") elif event_type == "assistant.response.delta": text = event.get("text", "") if self.verbose and text: - self.log_event("←", f"LLM: {text}") + self.log_event("←", f"LLM: {text}{ids}") elif event_type == "assistant.response.final": text = event.get("text", "") if text: - self.log_event("←", f"LLM Response (final): {text[:100]}{'...' if len(text) > 100 else ''}") + self.log_event("←", f"LLM Response (final): {text[:100]}{'...' if len(text) > 100 else ''}{ids}") elif event_type == "output.audio.start": self.track_started = True self.response_start_time = time.time() self.waiting_for_first_audio = True - self.log_event("←", "Bot started speaking") + self.log_event("←", f"Bot started speaking{ids}") elif event_type == "output.audio.end": self.track_ended = True - self.log_event("←", "Bot finished speaking") + self.log_event("←", f"Bot finished speaking{ids}") elif event_type == "response.interrupted": - self.log_event("←", "Bot interrupted!") + self.log_event("←", f"Bot interrupted!{ids}") elif event_type == "error": - self.log_event("!", f"Error: {event.get('message')}") + self.log_event("!", f"Error: {event.get('message')}{ids}") elif event_type == "session.stopped": - self.log_event("←", f"Session stopped: {event.get('reason')}") + self.log_event("←", f"Session stopped: {event.get('reason')}{ids}") self.running = False else: - self.log_event("←", f"Event: {event_type}") + self.log_event("←", f"Event: {event_type}{ids}") def save_output_wav(self) -> None: """Save received audio to output WAV file.""" @@ -473,6 +520,21 @@ async def main(): default=16000, help="Target sample rate for audio (default: 16000)" ) + parser.add_argument( + "--app-id", + default="assistant_demo", + help="Stable app/assistant identifier for server-side config lookup" + ) + parser.add_argument( + "--channel", + default="wav_client", + help="Client channel name" + ) + parser.add_argument( + "--config-version-id", + default="local-dev", + help="Optional config version identifier" + ) parser.add_argument( "--chunk-duration", type=int, @@ -490,6 +552,17 @@ async def main(): action="store_true", help="Enable verbose output" ) + parser.add_argument( + "--track-debug", + action="store_true", + help="Print event trackId for protocol debugging" + ) + parser.add_argument( + "--tail-silence-ms", + type=int, + default=800, + help="Trailing silence to send after WAV playback for EOU detection (default: 800)" + ) args = parser.parse_args() @@ -497,10 +570,15 @@ async def main(): url=args.url, input_file=args.input, output_file=args.output, + app_id=args.app_id, + channel=args.channel, + config_version_id=args.config_version_id, sample_rate=args.sample_rate, chunk_duration_ms=args.chunk_duration, wait_time=args.wait_time, - verbose=args.verbose + verbose=args.verbose, + track_debug=args.track_debug, + tail_silence_ms=args.tail_silence_ms, ) await client.run() diff --git a/examples/web_client.html b/examples/web_client.html index aaeb636..3431c02 100644 --- a/examples/web_client.html +++ b/examples/web_client.html @@ -401,6 +401,9 @@ const targetSampleRate = 16000; const playbackStopRampSec = 0.008; + const appId = "assistant_demo"; + const channel = "web_client"; + const configVersionId = "local-dev"; function logLine(type, text, data) { const time = new Date().toLocaleTimeString(); @@ -604,15 +607,35 @@ logLine("sys", `→ ${cmd.type}`, cmd); } + function eventIdsSuffix(event) { + const data = event && typeof event.data === "object" && event.data ? event.data : {}; + const keys = ["turn_id", "utterance_id", "response_id", "tool_call_id", "tts_id"]; + const parts = []; + for (const key of keys) { + const value = data[key] || event[key]; + if (value) parts.push(`${key}=${value}`); + } + return parts.length ? ` [${parts.join(" ")}]` : ""; + } + function handleEvent(event) { const type = event.type || "unknown"; - logLine("event", type, event); + const ids = eventIdsSuffix(event); + logLine("event", `${type}${ids}`, event); if (type === "hello.ack") { sendCommand({ type: "session.start", audio: { encoding: "pcm_s16le", sample_rate_hz: targetSampleRate, channels: 1 }, + metadata: { + appId, + channel, + configVersionId, + }, }); } + if (type === "config.resolved") { + logLine("sys", "config.resolved", event.config || {}); + } if (type === "transcript.final") { if (event.text) { setInterim("You", "");