Update engine

This commit is contained in:
Xin Wang
2026-02-23 17:16:18 +08:00
parent 01c0de0a4d
commit c6c84b5af9
9 changed files with 991 additions and 186 deletions

View File

@@ -24,7 +24,6 @@ from core.transports import SocketTransport, WebRtcTransport, BaseTransport
from core.session import Session from core.session import Session
from processors.tracks import Resampled16kTrack from processors.tracks import Resampled16kTrack
from core.events import get_event_bus, reset_event_bus from core.events import get_event_bus, reset_event_bus
from models.ws_v1 import ev
# Check interval for heartbeat/timeout (seconds) # Check interval for heartbeat/timeout (seconds)
_HEARTBEAT_CHECK_INTERVAL_SEC = 5 _HEARTBEAT_CHECK_INTERVAL_SEC = 5
@@ -54,9 +53,7 @@ async def heartbeat_and_timeout_task(
break break
if now - last_heartbeat_at[0] >= heartbeat_interval_sec: if now - last_heartbeat_at[0] >= heartbeat_interval_sec:
try: try:
await transport.send_event({ await session.send_heartbeat()
**ev("heartbeat"),
})
last_heartbeat_at[0] = now last_heartbeat_at[0] = now
except Exception as e: except Exception as e:
logger.debug(f"Session {session_id}: heartbeat send failed: {e}") logger.debug(f"Session {session_id}: heartbeat send failed: {e}")

View File

@@ -14,7 +14,8 @@ event-driven design.
import asyncio import asyncio
import json import json
import time import time
from typing import Any, Dict, List, Optional, Tuple import uuid
from typing import Any, Callable, Dict, List, Optional, Tuple
import numpy as np import numpy as np
from loguru import logger from loguru import logger
@@ -59,6 +60,12 @@ class DuplexPipeline:
_MIN_SPLIT_SPOKEN_CHARS = 6 _MIN_SPLIT_SPOKEN_CHARS = 6
_TOOL_WAIT_TIMEOUT_SECONDS = 15.0 _TOOL_WAIT_TIMEOUT_SECONDS = 15.0
_SERVER_TOOL_TIMEOUT_SECONDS = 15.0 _SERVER_TOOL_TIMEOUT_SECONDS = 15.0
TRACK_AUDIO_IN = "audio_in"
TRACK_AUDIO_OUT = "audio_out"
TRACK_CONTROL = "control"
_PCM_FRAME_BYTES = 640 # 16k mono pcm_s16le, 20ms
_ASR_DELTA_THROTTLE_MS = 300
_LLM_DELTA_THROTTLE_MS = 80
_DEFAULT_TOOL_SCHEMAS: Dict[str, Dict[str, Any]] = { _DEFAULT_TOOL_SCHEMAS: Dict[str, Dict[str, Any]] = {
"current_time": { "current_time": {
"name": "current_time", "name": "current_time",
@@ -96,6 +103,9 @@ class DuplexPipeline:
self.transport = transport self.transport = transport
self.session_id = session_id self.session_id = session_id
self.event_bus = get_event_bus() self.event_bus = get_event_bus()
self.track_audio_in = self.TRACK_AUDIO_IN
self.track_audio_out = self.TRACK_AUDIO_OUT
self.track_control = self.TRACK_CONTROL
# Initialize VAD # Initialize VAD
self.vad_model = SileroVAD( self.vad_model = SileroVAD(
@@ -120,6 +130,8 @@ class DuplexPipeline:
# Track last sent transcript to avoid duplicates # Track last sent transcript to avoid duplicates
self._last_sent_transcript = "" self._last_sent_transcript = ""
self._pending_transcript_delta: str = ""
self._last_transcript_delta_emit_ms: float = 0.0
# Conversation manager # Conversation manager
self.conversation = ConversationManager( self.conversation = ConversationManager(
@@ -153,6 +165,7 @@ class DuplexPipeline:
self._outbound_seq = 0 self._outbound_seq = 0
self._outbound_task: Optional[asyncio.Task] = None self._outbound_task: Optional[asyncio.Task] = None
self._drop_outbound_audio = False self._drop_outbound_audio = False
self._audio_out_frame_buffer: bytes = b""
# Interruption handling # Interruption handling
self._interrupt_event = asyncio.Event() self._interrupt_event = asyncio.Event()
@@ -186,9 +199,28 @@ class DuplexPipeline:
self._pending_tool_waiters: Dict[str, asyncio.Future] = {} self._pending_tool_waiters: Dict[str, asyncio.Future] = {}
self._early_tool_results: Dict[str, Dict[str, Any]] = {} self._early_tool_results: Dict[str, Dict[str, Any]] = {}
self._completed_tool_call_ids: set[str] = set() self._completed_tool_call_ids: set[str] = set()
self._pending_client_tool_call_ids: set[str] = set()
self._next_seq: Optional[Callable[[], int]] = None
self._local_seq: int = 0
# Cross-service correlation IDs
self._turn_count: int = 0
self._response_count: int = 0
self._tts_count: int = 0
self._utterance_count: int = 0
self._current_turn_id: Optional[str] = None
self._current_utterance_id: Optional[str] = None
self._current_response_id: Optional[str] = None
self._current_tts_id: Optional[str] = None
self._pending_llm_delta: str = ""
self._last_llm_delta_emit_ms: float = 0.0
logger.info(f"DuplexPipeline initialized for session {session_id}") logger.info(f"DuplexPipeline initialized for session {session_id}")
def set_event_sequence_provider(self, provider: Callable[[], int]) -> None:
"""Use session-scoped monotonic sequence provider for envelope events."""
self._next_seq = provider
def apply_runtime_overrides(self, metadata: Optional[Dict[str, Any]]) -> None: def apply_runtime_overrides(self, metadata: Optional[Dict[str, Any]]) -> None:
""" """
Apply runtime overrides from WS session.start metadata. Apply runtime overrides from WS session.start metadata.
@@ -276,6 +308,131 @@ class DuplexPipeline:
if self.llm_service and hasattr(self.llm_service, "set_tool_schemas"): if self.llm_service and hasattr(self.llm_service, "set_tool_schemas"):
self.llm_service.set_tool_schemas(self._resolved_tool_schemas()) self.llm_service.set_tool_schemas(self._resolved_tool_schemas())
def resolved_runtime_config(self) -> Dict[str, Any]:
"""Return current effective runtime configuration without secrets."""
llm_provider = str(self._runtime_llm.get("provider") or "openai").lower()
tts_provider = str(self._runtime_tts.get("provider") or settings.tts_provider).lower()
asr_provider = str(self._runtime_asr.get("provider") or settings.asr_provider).lower()
output_mode = str(self._runtime_output.get("mode") or "").strip().lower()
if not output_mode:
output_mode = "audio" if self._tts_output_enabled() else "text"
return {
"output": {"mode": output_mode},
"services": {
"llm": {
"provider": llm_provider,
"model": str(self._runtime_llm.get("model") or settings.llm_model),
"baseUrl": self._runtime_llm.get("baseUrl") or settings.openai_api_url,
},
"asr": {
"provider": asr_provider,
"model": str(self._runtime_asr.get("model") or settings.siliconflow_asr_model),
"interimIntervalMs": int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms),
"minAudioMs": int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms),
},
"tts": {
"enabled": self._tts_output_enabled(),
"provider": tts_provider,
"model": str(self._runtime_tts.get("model") or settings.siliconflow_tts_model),
"voice": str(self._runtime_tts.get("voice") or settings.tts_voice),
"speed": float(self._runtime_tts.get("speed") or settings.tts_speed),
},
},
"tools": {
"allowlist": sorted(self._runtime_tool_executor.keys()),
},
"tracks": {
"audio_in": self.track_audio_in,
"audio_out": self.track_audio_out,
"control": self.track_control,
},
}
def _next_event_seq(self) -> int:
if self._next_seq:
return self._next_seq()
self._local_seq += 1
return self._local_seq
def _event_source(self, event_type: str) -> str:
if event_type.startswith("transcript.") or event_type.startswith("input.speech_"):
return "asr"
if event_type.startswith("assistant.response."):
return "llm"
if event_type.startswith("assistant.tool_"):
return "tool"
if event_type.startswith("output.audio.") or event_type == "metrics.ttfb":
return "tts"
return "system"
def _new_id(self, prefix: str, counter: int) -> str:
return f"{prefix}_{counter}_{uuid.uuid4().hex[:8]}"
def _start_turn(self) -> str:
self._turn_count += 1
self._current_turn_id = self._new_id("turn", self._turn_count)
self._current_utterance_id = None
self._current_response_id = None
self._current_tts_id = None
return self._current_turn_id
def _start_response(self) -> str:
self._response_count += 1
self._current_response_id = self._new_id("resp", self._response_count)
self._current_tts_id = None
return self._current_response_id
def _start_tts(self) -> str:
self._tts_count += 1
self._current_tts_id = self._new_id("tts", self._tts_count)
return self._current_tts_id
def _finalize_utterance(self) -> str:
if self._current_utterance_id:
return self._current_utterance_id
self._utterance_count += 1
self._current_utterance_id = self._new_id("utt", self._utterance_count)
if not self._current_turn_id:
self._start_turn()
return self._current_utterance_id
def _envelope_event(self, event: Dict[str, Any]) -> Dict[str, Any]:
event_type = str(event.get("type") or "")
source = str(event.get("source") or self._event_source(event_type))
track_id = event.get("trackId")
if not track_id:
if source == "asr":
track_id = self.track_audio_in
elif source in {"llm", "tts", "tool"}:
track_id = self.track_audio_out
else:
track_id = self.track_control
data = event.get("data")
if not isinstance(data, dict):
data = {}
if self._current_turn_id:
data.setdefault("turn_id", self._current_turn_id)
if self._current_utterance_id:
data.setdefault("utterance_id", self._current_utterance_id)
if self._current_response_id:
data.setdefault("response_id", self._current_response_id)
if self._current_tts_id:
data.setdefault("tts_id", self._current_tts_id)
for k, v in event.items():
if k in {"type", "timestamp", "sessionId", "seq", "source", "trackId", "data"}:
continue
data.setdefault(k, v)
event["sessionId"] = self.session_id
event["seq"] = self._next_event_seq()
event["source"] = source
event["trackId"] = track_id
event["data"] = data
return event
@staticmethod @staticmethod
def _coerce_bool(value: Any) -> Optional[bool]: def _coerce_bool(value: Any) -> Optional[bool]:
if isinstance(value, bool): if isinstance(value, bool):
@@ -472,11 +629,13 @@ class DuplexPipeline:
greeting_to_speak = generated_greeting greeting_to_speak = generated_greeting
self.conversation.greeting = generated_greeting self.conversation.greeting = generated_greeting
if greeting_to_speak: if greeting_to_speak:
self._start_turn()
self._start_response()
await self._send_event( await self._send_event(
ev( ev(
"assistant.response.final", "assistant.response.final",
text=greeting_to_speak, text=greeting_to_speak,
trackId=self.session_id, trackId=self.track_audio_out,
), ),
priority=20, priority=20,
) )
@@ -494,10 +653,58 @@ class DuplexPipeline:
await self._outbound_q.put((priority, self._outbound_seq, kind, payload)) await self._outbound_q.put((priority, self._outbound_seq, kind, payload))
async def _send_event(self, event: Dict[str, Any], priority: int = 20) -> None: async def _send_event(self, event: Dict[str, Any], priority: int = 20) -> None:
await self._enqueue_outbound("event", event, priority) await self._enqueue_outbound("event", self._envelope_event(event), priority)
async def _send_audio(self, pcm_bytes: bytes, priority: int = 50) -> None: async def _send_audio(self, pcm_bytes: bytes, priority: int = 50) -> None:
await self._enqueue_outbound("audio", pcm_bytes, priority) if not pcm_bytes:
return
self._audio_out_frame_buffer += pcm_bytes
while len(self._audio_out_frame_buffer) >= self._PCM_FRAME_BYTES:
frame = self._audio_out_frame_buffer[: self._PCM_FRAME_BYTES]
self._audio_out_frame_buffer = self._audio_out_frame_buffer[self._PCM_FRAME_BYTES :]
await self._enqueue_outbound("audio", frame, priority)
async def _flush_audio_out_frames(self, priority: int = 50) -> None:
"""Flush remaining outbound audio as one padded 20ms PCM frame."""
if not self._audio_out_frame_buffer:
return
tail = self._audio_out_frame_buffer
self._audio_out_frame_buffer = b""
if len(tail) < self._PCM_FRAME_BYTES:
tail = tail + (b"\x00" * (self._PCM_FRAME_BYTES - len(tail)))
await self._enqueue_outbound("audio", tail, priority)
async def _emit_transcript_delta(self, text: str) -> None:
await self._send_event(
{
**ev(
"transcript.delta",
trackId=self.track_audio_in,
text=text,
)
},
priority=30,
)
async def _emit_llm_delta(self, text: str) -> None:
await self._send_event(
{
**ev(
"assistant.response.delta",
trackId=self.track_audio_out,
text=text,
)
},
priority=20,
)
async def _flush_pending_llm_delta(self) -> None:
if not self._pending_llm_delta:
return
chunk = self._pending_llm_delta
self._pending_llm_delta = ""
self._last_llm_delta_emit_ms = time.monotonic() * 1000.0
await self._emit_llm_delta(chunk)
async def _outbound_loop(self) -> None: async def _outbound_loop(self) -> None:
"""Single sender loop that enforces priority for interrupt events.""" """Single sender loop that enforces priority for interrupt events."""
@@ -546,13 +753,13 @@ class DuplexPipeline:
# Emit VAD event # Emit VAD event
await self.event_bus.publish(event_type, { await self.event_bus.publish(event_type, {
"trackId": self.session_id, "trackId": self.track_audio_in,
"probability": probability "probability": probability
}) })
await self._send_event( await self._send_event(
ev( ev(
"input.speech_started" if event_type == "speaking" else "input.speech_stopped", "input.speech_started" if event_type == "speaking" else "input.speech_stopped",
trackId=self.session_id, trackId=self.track_audio_in,
probability=probability, probability=probability,
), ),
priority=30, priority=30,
@@ -661,6 +868,9 @@ class DuplexPipeline:
# Cancel any current speaking # Cancel any current speaking
await self._stop_current_speech() await self._stop_current_speech()
self._start_turn()
self._finalize_utterance()
# Start new turn # Start new turn
await self.conversation.end_user_turn(text) await self.conversation.end_user_turn(text)
self._current_turn_task = asyncio.create_task(self._handle_turn(text)) self._current_turn_task = asyncio.create_task(self._handle_turn(text))
@@ -683,24 +893,45 @@ class DuplexPipeline:
if text == self._last_sent_transcript and not is_final: if text == self._last_sent_transcript and not is_final:
return return
now_ms = time.monotonic() * 1000.0
self._last_sent_transcript = text self._last_sent_transcript = text
# Send transcript event to client if is_final:
await self._send_event({ self._pending_transcript_delta = ""
**ev( self._last_transcript_delta_emit_ms = 0.0
"transcript.final" if is_final else "transcript.delta", await self._send_event(
trackId=self.session_id, {
text=text, **ev(
"transcript.final",
trackId=self.track_audio_in,
text=text,
)
},
priority=30,
) )
}, priority=30) logger.debug(f"Sent transcript (final): {text[:50]}...")
return
self._pending_transcript_delta = text
should_emit = (
self._last_transcript_delta_emit_ms <= 0.0
or now_ms - self._last_transcript_delta_emit_ms >= self._ASR_DELTA_THROTTLE_MS
)
if should_emit and self._pending_transcript_delta:
delta = self._pending_transcript_delta
self._pending_transcript_delta = ""
self._last_transcript_delta_emit_ms = now_ms
await self._emit_transcript_delta(delta)
if not is_final: if not is_final:
logger.info(f"[ASR] ASR interim: {text[:100]}") logger.info(f"[ASR] ASR interim: {text[:100]}")
logger.debug(f"Sent transcript ({'final' if is_final else 'interim'}): {text[:50]}...") logger.debug(f"Sent transcript (interim): {text[:50]}...")
async def _on_speech_start(self) -> None: async def _on_speech_start(self) -> None:
"""Handle user starting to speak.""" """Handle user starting to speak."""
if self.conversation.state in (ConversationState.IDLE, ConversationState.INTERRUPTED): if self.conversation.state in (ConversationState.IDLE, ConversationState.INTERRUPTED):
self._start_turn()
self._finalize_utterance()
await self.conversation.start_user_turn() await self.conversation.start_user_turn()
self._audio_buffer = b"" self._audio_buffer = b""
self._last_sent_transcript = "" self._last_sent_transcript = ""
@@ -779,6 +1010,7 @@ class DuplexPipeline:
return return
logger.info(f"[EOU] Detected - user said: {user_text[:100]}...") logger.info(f"[EOU] Detected - user said: {user_text[:100]}...")
self._finalize_utterance()
# For ASR backends that already emitted final via callback, # For ASR backends that already emitted final via callback,
# avoid duplicating transcript.final on EOU. # avoid duplicating transcript.final on EOU.
@@ -786,7 +1018,7 @@ class DuplexPipeline:
await self._send_event({ await self._send_event({
**ev( **ev(
"transcript.final", "transcript.final",
trackId=self.session_id, trackId=self.track_audio_in,
text=user_text, text=user_text,
) )
}, priority=25) }, priority=25)
@@ -794,6 +1026,8 @@ class DuplexPipeline:
# Clear buffers # Clear buffers
self._audio_buffer = b"" self._audio_buffer = b""
self._last_sent_transcript = "" self._last_sent_transcript = ""
self._pending_transcript_delta = ""
self._last_transcript_delta_emit_ms = 0.0
self._asr_capture_active = False self._asr_capture_active = False
self._pending_speech_audio = b"" self._pending_speech_audio = b""
@@ -894,6 +1128,44 @@ class DuplexPipeline:
# Default to server execution unless explicitly marked as client. # Default to server execution unless explicitly marked as client.
return "server" return "server"
def _tool_arguments(self, tool_call: Dict[str, Any]) -> Dict[str, Any]:
fn = tool_call.get("function")
if not isinstance(fn, dict):
return {}
raw = fn.get("arguments")
if isinstance(raw, dict):
return raw
if isinstance(raw, str) and raw.strip():
try:
parsed = json.loads(raw)
return parsed if isinstance(parsed, dict) else {"raw": raw}
except Exception:
return {"raw": raw}
return {}
def _normalize_tool_result(self, result: Dict[str, Any]) -> Dict[str, Any]:
status = result.get("status") if isinstance(result.get("status"), dict) else {}
status_code = int(status.get("code") or 0) if status else 0
status_message = str(status.get("message") or "") if status else ""
tool_call_id = str(result.get("tool_call_id") or result.get("id") or "")
tool_name = str(result.get("name") or "unknown_tool")
ok = bool(200 <= status_code < 300)
retryable = status_code >= 500 or status_code in {429, 408}
error: Optional[Dict[str, Any]] = None
if not ok:
error = {
"code": status_code or 500,
"message": status_message or "tool_execution_failed",
"retryable": retryable,
}
return {
"tool_call_id": tool_call_id,
"tool_name": tool_name,
"ok": ok,
"error": error,
"status": {"code": status_code, "message": status_message},
}
async def _emit_tool_result(self, result: Dict[str, Any], source: str) -> None: async def _emit_tool_result(self, result: Dict[str, Any], source: str) -> None:
tool_name = str(result.get("name") or "unknown_tool") tool_name = str(result.get("name") or "unknown_tool")
call_id = str(result.get("tool_call_id") or result.get("id") or "") call_id = str(result.get("tool_call_id") or result.get("id") or "")
@@ -904,12 +1176,17 @@ class DuplexPipeline:
f"[Tool] emit result source={source} name={tool_name} call_id={call_id} " f"[Tool] emit result source={source} name={tool_name} call_id={call_id} "
f"status={status_code} {status_message}".strip() f"status={status_code} {status_message}".strip()
) )
normalized = self._normalize_tool_result(result)
await self._send_event( await self._send_event(
{ {
**ev( **ev(
"assistant.tool_result", "assistant.tool_result",
trackId=self.session_id, trackId=self.track_audio_out,
source=source, source=source,
tool_call_id=normalized["tool_call_id"],
tool_name=normalized["tool_name"],
ok=normalized["ok"],
error=normalized["error"],
result=result, result=result,
) )
}, },
@@ -927,6 +1204,9 @@ class DuplexPipeline:
call_id = str(item.get("tool_call_id") or item.get("id") or "").strip() call_id = str(item.get("tool_call_id") or item.get("id") or "").strip()
if not call_id: if not call_id:
continue continue
if self._pending_client_tool_call_ids and call_id not in self._pending_client_tool_call_ids:
logger.warning(f"[Tool] ignore unsolicited client result call_id={call_id}")
continue
if call_id in self._completed_tool_call_ids: if call_id in self._completed_tool_call_ids:
logger.debug(f"[Tool] ignore duplicate client result call_id={call_id}") logger.debug(f"[Tool] ignore duplicate client result call_id={call_id}")
continue continue
@@ -972,6 +1252,7 @@ class DuplexPipeline:
} }
finally: finally:
self._pending_tool_waiters.pop(call_id, None) self._pending_tool_waiters.pop(call_id, None)
self._pending_client_tool_call_ids.discard(call_id)
def _normalize_stream_event(self, item: Any) -> LLMStreamEvent: def _normalize_stream_event(self, item: Any) -> LLMStreamEvent:
if isinstance(item, LLMStreamEvent): if isinstance(item, LLMStreamEvent):
@@ -998,6 +1279,11 @@ class DuplexPipeline:
user_text: User's transcribed text user_text: User's transcribed text
""" """
try: try:
if not self._current_turn_id:
self._start_turn()
if not self._current_utterance_id:
self._finalize_utterance()
self._start_response()
# Start latency tracking # Start latency tracking
self._turn_start_time = time.time() self._turn_start_time = time.time()
self._first_audio_sent = False self._first_audio_sent = False
@@ -1012,6 +1298,8 @@ class DuplexPipeline:
self._drop_outbound_audio = False self._drop_outbound_audio = False
first_audio_sent = False first_audio_sent = False
self._pending_llm_delta = ""
self._last_llm_delta_emit_ms = 0.0
for _ in range(max_rounds): for _ in range(max_rounds):
if self._interrupt_event.is_set(): if self._interrupt_event.is_set():
break break
@@ -1028,6 +1316,7 @@ class DuplexPipeline:
event = self._normalize_stream_event(raw_event) event = self._normalize_stream_event(raw_event)
if event.type == "tool_call": if event.type == "tool_call":
await self._flush_pending_llm_delta()
tool_call = event.tool_call if isinstance(event.tool_call, dict) else None tool_call = event.tool_call if isinstance(event.tool_call, dict) else None
if not tool_call: if not tool_call:
continue continue
@@ -1045,11 +1334,19 @@ class DuplexPipeline:
f"executor={executor} args={args_preview}" f"executor={executor} args={args_preview}"
) )
tool_calls.append(enriched_tool_call) tool_calls.append(enriched_tool_call)
tool_arguments = self._tool_arguments(enriched_tool_call)
if executor == "client" and call_id:
self._pending_client_tool_call_ids.add(call_id)
await self._send_event( await self._send_event(
{ {
**ev( **ev(
"assistant.tool_call", "assistant.tool_call",
trackId=self.session_id, trackId=self.track_audio_out,
tool_call_id=call_id,
tool_name=tool_name,
arguments=tool_arguments,
executor=executor,
timeout_ms=int(self._TOOL_WAIT_TIMEOUT_SECONDS * 1000),
tool_call=enriched_tool_call, tool_call=enriched_tool_call,
) )
}, },
@@ -1071,19 +1368,13 @@ class DuplexPipeline:
round_response += text_chunk round_response += text_chunk
sentence_buffer += text_chunk sentence_buffer += text_chunk
await self.conversation.update_assistant_text(text_chunk) await self.conversation.update_assistant_text(text_chunk)
self._pending_llm_delta += text_chunk
await self._send_event( now_ms = time.monotonic() * 1000.0
{ if (
**ev( self._last_llm_delta_emit_ms <= 0.0
"assistant.response.delta", or now_ms - self._last_llm_delta_emit_ms >= self._LLM_DELTA_THROTTLE_MS
trackId=self.session_id, ):
text=text_chunk, await self._flush_pending_llm_delta()
)
},
# Keep delta/final on the same event priority so FIFO seq
# preserves stream order (avoid late-delta after final).
priority=20,
)
while True: while True:
split_result = extract_tts_sentence( split_result = extract_tts_sentence(
@@ -1112,11 +1403,12 @@ class DuplexPipeline:
if self._tts_output_enabled() and not self._interrupt_event.is_set(): if self._tts_output_enabled() and not self._interrupt_event.is_set():
if not first_audio_sent: if not first_audio_sent:
self._start_tts()
await self._send_event( await self._send_event(
{ {
**ev( **ev(
"output.audio.start", "output.audio.start",
trackId=self.session_id, trackId=self.track_audio_out,
) )
}, },
priority=10, priority=10,
@@ -1130,6 +1422,7 @@ class DuplexPipeline:
) )
remaining_text = f"{pending_punctuation}{sentence_buffer}".strip() remaining_text = f"{pending_punctuation}{sentence_buffer}".strip()
await self._flush_pending_llm_delta()
if ( if (
self._tts_output_enabled() self._tts_output_enabled()
and remaining_text and remaining_text
@@ -1137,11 +1430,12 @@ class DuplexPipeline:
and not self._interrupt_event.is_set() and not self._interrupt_event.is_set()
): ):
if not first_audio_sent: if not first_audio_sent:
self._start_tts()
await self._send_event( await self._send_event(
{ {
**ev( **ev(
"output.audio.start", "output.audio.start",
trackId=self.session_id, trackId=self.track_audio_out,
) )
}, },
priority=10, priority=10,
@@ -1204,11 +1498,12 @@ class DuplexPipeline:
] ]
if full_response and not self._interrupt_event.is_set(): if full_response and not self._interrupt_event.is_set():
await self._flush_pending_llm_delta()
await self._send_event( await self._send_event(
{ {
**ev( **ev(
"assistant.response.final", "assistant.response.final",
trackId=self.session_id, trackId=self.track_audio_out,
text=full_response, text=full_response,
) )
}, },
@@ -1217,10 +1512,11 @@ class DuplexPipeline:
# Send track end # Send track end
if first_audio_sent: if first_audio_sent:
await self._flush_audio_out_frames(priority=50)
await self._send_event({ await self._send_event({
**ev( **ev(
"output.audio.end", "output.audio.end",
trackId=self.session_id, trackId=self.track_audio_out,
) )
}, priority=10) }, priority=10)
@@ -1241,6 +1537,8 @@ class DuplexPipeline:
self._barge_in_speech_start_time = None self._barge_in_speech_start_time = None
self._barge_in_speech_frames = 0 self._barge_in_speech_frames = 0
self._barge_in_silence_frames = 0 self._barge_in_silence_frames = 0
self._current_response_id = None
self._current_tts_id = None
async def _speak_sentence(self, text: str, fade_in_ms: int = 0, fade_out_ms: int = 8) -> None: async def _speak_sentence(self, text: str, fade_in_ms: int = 0, fade_out_ms: int = 8) -> None:
""" """
@@ -1277,7 +1575,7 @@ class DuplexPipeline:
await self._send_event({ await self._send_event({
**ev( **ev(
"metrics.ttfb", "metrics.ttfb",
trackId=self.session_id, trackId=self.track_audio_out,
latencyMs=round(ttfb_ms), latencyMs=round(ttfb_ms),
) )
}, priority=25) }, priority=25)
@@ -1354,10 +1652,11 @@ class DuplexPipeline:
first_audio_sent = False first_audio_sent = False
# Send track start event # Send track start event
self._start_tts()
await self._send_event({ await self._send_event({
**ev( **ev(
"output.audio.start", "output.audio.start",
trackId=self.session_id, trackId=self.track_audio_out,
) )
}, priority=10) }, priority=10)
@@ -1379,7 +1678,7 @@ class DuplexPipeline:
await self._send_event({ await self._send_event({
**ev( **ev(
"metrics.ttfb", "metrics.ttfb",
trackId=self.session_id, trackId=self.track_audio_out,
latencyMs=round(ttfb_ms), latencyMs=round(ttfb_ms),
) )
}, priority=25) }, priority=25)
@@ -1391,10 +1690,11 @@ class DuplexPipeline:
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
# Send track end event # Send track end event
await self._flush_audio_out_frames(priority=50)
await self._send_event({ await self._send_event({
**ev( **ev(
"output.audio.end", "output.audio.end",
trackId=self.session_id, trackId=self.track_audio_out,
) )
}, priority=10) }, priority=10)
@@ -1422,13 +1722,14 @@ class DuplexPipeline:
self._interrupt_event.set() self._interrupt_event.set()
self._is_bot_speaking = False self._is_bot_speaking = False
self._drop_outbound_audio = True self._drop_outbound_audio = True
self._audio_out_frame_buffer = b""
# Send interrupt event to client IMMEDIATELY # Send interrupt event to client IMMEDIATELY
# This must happen BEFORE canceling services, so client knows to discard in-flight audio # This must happen BEFORE canceling services, so client knows to discard in-flight audio
await self._send_event({ await self._send_event({
**ev( **ev(
"response.interrupted", "response.interrupted",
trackId=self.session_id, trackId=self.track_audio_out,
) )
}, priority=0) }, priority=0)
@@ -1455,6 +1756,7 @@ class DuplexPipeline:
async def _stop_current_speech(self) -> None: async def _stop_current_speech(self) -> None:
"""Stop any current speech task.""" """Stop any current speech task."""
self._drop_outbound_audio = True self._drop_outbound_audio = True
self._audio_out_frame_buffer = b""
if self._current_turn_task and not self._current_turn_task.done(): if self._current_turn_task and not self._current_turn_task.done():
self._interrupt_event.set() self._interrupt_event.set()
self._current_turn_task.cancel() self._current_turn_task.cancel()

View File

@@ -1,15 +1,16 @@
"""Session management for active calls.""" """Session management for active calls."""
import asyncio import asyncio
import uuid import hashlib
import json import json
import time
import re import re
import time
from enum import Enum from enum import Enum
from typing import Optional, Dict, Any, List from typing import Optional, Dict, Any, List
from loguru import logger from loguru import logger
from app.backend_client import ( from app.backend_client import (
fetch_assistant_config,
create_history_call_record, create_history_call_record,
add_history_transcript, add_history_transcript,
finalize_history_call_record, finalize_history_call_record,
@@ -49,6 +50,32 @@ class Session:
Uses full duplex voice conversation pipeline. Uses full duplex voice conversation pipeline.
""" """
TRACK_AUDIO_IN = "audio_in"
TRACK_AUDIO_OUT = "audio_out"
TRACK_CONTROL = "control"
AUDIO_FRAME_BYTES = 640 # 16k mono pcm_s16le, 20ms
_CLIENT_METADATA_OVERRIDES = {
"firstTurnMode",
"greeting",
"generatedOpenerEnabled",
"systemPrompt",
"output",
"bargeIn",
"knowledge",
"knowledgeBaseId",
"history",
"userId",
"assistantId",
"source",
}
_CLIENT_METADATA_ID_KEYS = {
"appId",
"app_id",
"channel",
"configVersionId",
"config_version_id",
}
def __init__(self, session_id: str, transport: BaseTransport, use_duplex: bool = None): def __init__(self, session_id: str, transport: BaseTransport, use_duplex: bool = None):
""" """
Initialize session. Initialize session.
@@ -78,7 +105,10 @@ class Session:
self.authenticated: bool = False self.authenticated: bool = False
# Track IDs # Track IDs
self.current_track_id: Optional[str] = str(uuid.uuid4()) self.current_track_id: str = self.TRACK_CONTROL
self._event_seq: int = 0
self._audio_ingress_buffer: bytes = b""
self._audio_frame_error_reported: bool = False
self._history_call_id: Optional[str] = None self._history_call_id: Optional[str] = None
self._history_turn_index: int = 0 self._history_turn_index: int = 0
self._history_call_started_mono: Optional[float] = None self._history_call_started_mono: Optional[float] = None
@@ -89,6 +119,7 @@ class Session:
self._workflow_last_user_text: str = "" self._workflow_last_user_text: str = ""
self._workflow_initial_node: Optional[WorkflowNodeDef] = None self._workflow_initial_node: Optional[WorkflowNodeDef] = None
self.pipeline.set_event_sequence_provider(self._next_event_seq)
self.pipeline.conversation.on_turn_complete(self._on_turn_complete) self.pipeline.conversation.on_turn_complete(self._on_turn_complete)
logger.info(f"Session {self.id} created (duplex={self.use_duplex})") logger.info(f"Session {self.id} created (duplex={self.use_duplex})")
@@ -129,13 +160,52 @@ class Session:
"client", "client",
"Audio received before session.start", "Audio received before session.start",
"protocol.order", "protocol.order",
stage="protocol",
retryable=False,
) )
return return
try: try:
await self.pipeline.process_audio(audio_bytes) if not audio_bytes:
return
if len(audio_bytes) % 2 != 0:
await self._send_error(
"client",
"Invalid PCM payload: odd number of bytes",
"audio.invalid_pcm",
stage="audio",
retryable=False,
)
return
frame_bytes = self.AUDIO_FRAME_BYTES
self._audio_ingress_buffer += audio_bytes
# Protocol v1 audio framing: 20ms PCM frame (640 bytes).
# Allow aggregated frames in one WS message (multiple of 640).
if len(audio_bytes) % frame_bytes != 0 and not self._audio_frame_error_reported:
self._audio_frame_error_reported = True
await self._send_error(
"client",
f"Audio frame size should be multiple of {frame_bytes} bytes (20ms PCM)",
"audio.frame_size_mismatch",
stage="audio",
retryable=True,
)
while len(self._audio_ingress_buffer) >= frame_bytes:
frame = self._audio_ingress_buffer[:frame_bytes]
self._audio_ingress_buffer = self._audio_ingress_buffer[frame_bytes:]
await self.pipeline.process_audio(frame)
except Exception as e: except Exception as e:
logger.error(f"Session {self.id} handle_audio error: {e}", exc_info=True) logger.error(f"Session {self.id} handle_audio error: {e}", exc_info=True)
await self._send_error(
"server",
f"Audio processing failed: {e}",
"audio.processing_failed",
stage="audio",
retryable=True,
)
async def _handle_v1_message(self, message: Any) -> None: async def _handle_v1_message(self, message: Any) -> None:
"""Route validated WS v1 message to handlers.""" """Route validated WS v1 message to handlers."""
@@ -217,10 +287,9 @@ class Session:
self.authenticated = True self.authenticated = True
self.protocol_version = message.version self.protocol_version = message.version
self.ws_state = WsSessionState.WAIT_START self.ws_state = WsSessionState.WAIT_START
await self.transport.send_event( await self._send_event(
ev( ev(
"hello.ack", "hello.ack",
sessionId=self.id,
version=self.protocol_version, version=self.protocol_version,
) )
) )
@@ -231,8 +300,12 @@ class Session:
await self._send_error("client", "Duplicate session.start", "protocol.order") await self._send_error("client", "Duplicate session.start", "protocol.order")
return return
metadata = message.metadata or {} raw_metadata = message.metadata or {}
metadata = self._merge_runtime_metadata(metadata, self._bootstrap_workflow(metadata)) workflow_runtime = self._bootstrap_workflow(raw_metadata)
server_runtime = await self._load_server_runtime_metadata(raw_metadata, workflow_runtime)
client_runtime = self._sanitize_client_metadata(raw_metadata)
metadata = self._merge_runtime_metadata(server_runtime, self._sanitize_untrusted_runtime_metadata(workflow_runtime))
metadata = self._merge_runtime_metadata(metadata, client_runtime)
# Create history call record early so later turn callbacks can append transcripts. # Create history call record early so later turn callbacks can append transcripts.
await self._start_history_bridge(metadata) await self._start_history_bridge(metadata)
@@ -248,28 +321,37 @@ class Session:
self.state = "accepted" self.state = "accepted"
self.ws_state = WsSessionState.ACTIVE self.ws_state = WsSessionState.ACTIVE
await self.transport.send_event( await self._send_event(
ev( ev(
"session.started", "session.started",
sessionId=self.id,
trackId=self.current_track_id, trackId=self.current_track_id,
tracks={
"audio_in": self.TRACK_AUDIO_IN,
"audio_out": self.TRACK_AUDIO_OUT,
"control": self.TRACK_CONTROL,
},
audio=message.audio or {}, audio=message.audio or {},
) )
) )
await self._send_event(
ev(
"config.resolved",
trackId=self.TRACK_CONTROL,
config=self._build_config_resolved(metadata),
)
)
if self.workflow_runner and self._workflow_initial_node: if self.workflow_runner and self._workflow_initial_node:
await self.transport.send_event( await self._send_event(
ev( ev(
"workflow.started", "workflow.started",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id, workflowId=self.workflow_runner.workflow_id,
workflowName=self.workflow_runner.name, workflowName=self.workflow_runner.name,
nodeId=self._workflow_initial_node.id, nodeId=self._workflow_initial_node.id,
) )
) )
await self.transport.send_event( await self._send_event(
ev( ev(
"workflow.node.entered", "workflow.node.entered",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id, workflowId=self.workflow_runner.workflow_id,
nodeId=self._workflow_initial_node.id, nodeId=self._workflow_initial_node.id,
nodeName=self._workflow_initial_node.name, nodeName=self._workflow_initial_node.name,
@@ -285,17 +367,23 @@ class Session:
stop_reason = reason or "client_requested" stop_reason = reason or "client_requested"
self.state = "hungup" self.state = "hungup"
self.ws_state = WsSessionState.STOPPED self.ws_state = WsSessionState.STOPPED
await self.transport.send_event( await self._send_event(
ev( ev(
"session.stopped", "session.stopped",
sessionId=self.id,
reason=stop_reason, reason=stop_reason,
) )
) )
await self._finalize_history(status="connected") await self._finalize_history(status="connected")
await self.transport.close() await self.transport.close()
async def _send_error(self, sender: str, error_message: str, code: str) -> None: async def _send_error(
self,
sender: str,
error_message: str,
code: str,
stage: Optional[str] = None,
retryable: Optional[bool] = None,
) -> None:
""" """
Send error event to client. Send error event to client.
@@ -304,13 +392,25 @@ class Session:
error_message: Error message error_message: Error message
code: Machine-readable error code code: Machine-readable error code
""" """
await self.transport.send_event( resolved_stage = stage or self._infer_error_stage(code)
resolved_retryable = retryable if retryable is not None else (resolved_stage in {"asr", "llm", "tts", "tool", "audio"})
await self._send_event(
ev( ev(
"error", "error",
sender=sender, sender=sender,
code=code, code=code,
message=error_message, message=error_message,
stage=resolved_stage,
retryable=resolved_retryable,
trackId=self.current_track_id, trackId=self.current_track_id,
data={
"error": {
"stage": resolved_stage,
"code": code,
"message": error_message,
"retryable": resolved_retryable,
}
},
) )
) )
@@ -483,10 +583,9 @@ class Session:
node = transition.node node = transition.node
edge = transition.edge edge = transition.edge
await self.transport.send_event( await self._send_event(
ev( ev(
"workflow.edge.taken", "workflow.edge.taken",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id, workflowId=self.workflow_runner.workflow_id,
edgeId=edge.id, edgeId=edge.id,
fromNodeId=edge.from_node_id, fromNodeId=edge.from_node_id,
@@ -494,10 +593,9 @@ class Session:
reason=reason, reason=reason,
) )
) )
await self.transport.send_event( await self._send_event(
ev( ev(
"workflow.node.entered", "workflow.node.entered",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id, workflowId=self.workflow_runner.workflow_id,
nodeId=node.id, nodeId=node.id,
nodeName=node.name, nodeName=node.name,
@@ -510,10 +608,9 @@ class Session:
self.pipeline.apply_runtime_overrides(node_runtime) self.pipeline.apply_runtime_overrides(node_runtime)
if node.node_type == "tool": if node.node_type == "tool":
await self.transport.send_event( await self._send_event(
ev( ev(
"workflow.tool.requested", "workflow.tool.requested",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id, workflowId=self.workflow_runner.workflow_id,
nodeId=node.id, nodeId=node.id,
tool=node.tool or {}, tool=node.tool or {},
@@ -522,10 +619,9 @@ class Session:
return return
if node.node_type == "human_transfer": if node.node_type == "human_transfer":
await self.transport.send_event( await self._send_event(
ev( ev(
"workflow.human_transfer", "workflow.human_transfer",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id, workflowId=self.workflow_runner.workflow_id,
nodeId=node.id, nodeId=node.id,
) )
@@ -534,16 +630,68 @@ class Session:
return return
if node.node_type == "end": if node.node_type == "end":
await self.transport.send_event( await self._send_event(
ev( ev(
"workflow.ended", "workflow.ended",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id, workflowId=self.workflow_runner.workflow_id,
nodeId=node.id, nodeId=node.id,
) )
) )
await self._handle_session_stop("workflow_end") await self._handle_session_stop("workflow_end")
def _next_event_seq(self) -> int:
self._event_seq += 1
return self._event_seq
def _event_source(self, event_type: str) -> str:
if event_type.startswith("workflow."):
return "system"
if event_type.startswith("session.") or event_type.startswith("hello.") or event_type == "heartbeat":
return "system"
if event_type == "error":
return "system"
return "system"
def _infer_error_stage(self, code: str) -> str:
normalized = str(code or "").strip().lower()
if normalized.startswith("audio."):
return "audio"
if normalized.startswith("tool."):
return "tool"
if normalized.startswith("asr."):
return "asr"
if normalized.startswith("llm."):
return "llm"
if normalized.startswith("tts."):
return "tts"
return "protocol"
def _envelope_event(self, event: Dict[str, Any]) -> Dict[str, Any]:
    """Wrap a raw event dict in the WS v1 envelope, mutating it in place.

    Fills in ``sessionId``, a monotonically increasing ``seq``, a
    ``source`` tag (honouring a pre-set value), a ``trackId`` (defaulting
    to the control track), and a structured ``data`` payload.  Legacy
    top-level fields are mirrored into ``data`` without clobbering keys
    that the caller already put there.

    Args:
        event: Event dict (typically built by ``ev(...)``).

    Returns:
        The same dict, enveloped.
    """
    reserved = {"type", "timestamp", "sessionId", "seq", "source", "trackId", "data"}
    evt_type = str(event.get("type") or "")
    resolved_source = str(event.get("source") or self._event_source(evt_type))
    resolved_track = event.get("trackId") or self.TRACK_CONTROL
    payload = event.get("data")
    if not isinstance(payload, dict):
        payload = {}
    # Mirror any non-envelope top-level field into data for compatibility.
    for key, value in event.items():
        if key not in reserved:
            payload.setdefault(key, value)
    event["sessionId"] = self.id
    event["seq"] = self._next_event_seq()
    event["source"] = resolved_source
    event["trackId"] = resolved_track
    event["data"] = payload
    return event
async def _send_event(self, event: Dict[str, Any]) -> None:
    """Envelope *event* and push it to the client over the session transport."""
    enveloped = self._envelope_event(event)
    await self.transport.send_event(enveloped)
async def send_heartbeat(self) -> None:
    """Emit a protocol heartbeat event on the control track."""
    heartbeat_event = ev("heartbeat", trackId=self.TRACK_CONTROL)
    await self._send_event(heartbeat_event)
async def _workflow_llm_route( async def _workflow_llm_route(
self, self,
node: WorkflowNodeDef, node: WorkflowNodeDef,
@@ -629,6 +777,100 @@ class Session:
merged[key] = value merged[key] = value
return merged return merged
async def _load_server_runtime_metadata(
    self,
    client_metadata: Dict[str, Any],
    workflow_runtime: Dict[str, Any],
) -> Dict[str, Any]:
    """Load trusted runtime metadata from the backend assistant config.

    Args:
        client_metadata: Raw (untrusted) metadata from ``session.start``.
        workflow_runtime: Runtime metadata produced by workflow bootstrap.

    Returns:
        Trusted runtime dict (systemPrompt/greeting/services/tools plus the
        resolved ``assistantId``), or ``{}`` when no assistant ID can be
        resolved, no backend is configured, or the payload is unusable.
    """
    assistant_id = (
        workflow_runtime.get("assistantId")
        or client_metadata.get("assistantId")
        or client_metadata.get("appId")
        or client_metadata.get("app_id")
    )
    # Normalize once: the previous `is None` check let an empty/blank
    # string through to a pointless backend fetch.
    assistant_key = str(assistant_id).strip() if assistant_id is not None else ""
    if not assistant_key:
        return {}
    if not settings.backend_url:
        return {}
    payload = await fetch_assistant_config(assistant_key)
    if not isinstance(payload, dict):
        return {}
    # Backend may return either the assistant object directly or wrap it
    # under an "assistant" key.
    assistant_cfg = payload.get("assistant") if isinstance(payload.get("assistant"), dict) else payload
    if not isinstance(assistant_cfg, dict):
        return {}
    runtime: Dict[str, Any] = {}
    # Prefer canonical field names; fall back to legacy aliases.
    if assistant_cfg.get("systemPrompt") is not None:
        runtime["systemPrompt"] = str(assistant_cfg.get("systemPrompt") or "")
    elif assistant_cfg.get("prompt") is not None:
        runtime["systemPrompt"] = str(assistant_cfg.get("prompt") or "")
    if assistant_cfg.get("greeting") is not None:
        runtime["greeting"] = assistant_cfg.get("greeting")
    elif assistant_cfg.get("opener") is not None:
        runtime["greeting"] = assistant_cfg.get("opener")
    if isinstance(assistant_cfg.get("services"), dict):
        runtime["services"] = assistant_cfg.get("services")
    if isinstance(assistant_cfg.get("tools"), list):
        runtime["tools"] = assistant_cfg.get("tools")
    runtime["assistantId"] = assistant_key
    return runtime
def _sanitize_untrusted_runtime_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
"""
Sanitize untrusted metadata sources.
This keeps only a small override whitelist and stable config ID fields.
"""
if not isinstance(metadata, dict):
return {}
sanitized: Dict[str, Any] = {}
for key in self._CLIENT_METADATA_ID_KEYS:
if key in metadata:
sanitized[key] = metadata[key]
for key in self._CLIENT_METADATA_OVERRIDES:
if key in metadata:
sanitized[key] = metadata[key]
return sanitized
def _sanitize_client_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
    """Apply the client metadata whitelist and strip forbidden secrets.

    Client-supplied ``metadata.services`` (which may contain provider
    secrets) is never honoured; its presence is only logged.

    Args:
        metadata: Raw metadata from the client's ``session.start``.

    Returns:
        Whitelisted metadata; ``{}`` for non-dict input.
    """
    # Guard non-dict input explicitly: the helper tolerates it, but the
    # `.get("services")` probe below would raise AttributeError on None.
    if not isinstance(metadata, dict):
        return {}
    if isinstance(metadata.get("services"), dict):
        logger.warning(
            "Session {} provided metadata.services from client; client-side service config is ignored",
            self.id,
        )
    return self._sanitize_untrusted_runtime_metadata(metadata)
def _build_config_resolved(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
"""Build public resolved config payload (secrets removed)."""
system_prompt = str(metadata.get("systemPrompt") or self.pipeline.conversation.system_prompt or "")
prompt_hash = hashlib.sha256(system_prompt.encode("utf-8")).hexdigest() if system_prompt else None
runtime = self.pipeline.resolved_runtime_config()
return {
"appId": metadata.get("appId") or metadata.get("app_id") or metadata.get("assistantId"),
"channel": metadata.get("channel"),
"configVersionId": metadata.get("configVersionId") or metadata.get("config_version_id"),
"prompt": {"sha256": prompt_hash},
"output": runtime.get("output", {}),
"services": runtime.get("services", {}),
"tools": runtime.get("tools", {}),
"tracks": {
"audio_in": self.TRACK_AUDIO_IN,
"audio_out": self.TRACK_AUDIO_OUT,
"control": self.TRACK_CONTROL,
},
}
def _extract_json_obj(self, text: str) -> Optional[Dict[str, Any]]: def _extract_json_obj(self, text: str) -> Optional[Dict[str, Any]]:
"""Best-effort extraction of a JSON object from freeform text.""" """Best-effort extraction of a JSON object from freeform text."""
try: try:

View File

@@ -52,43 +52,26 @@ Rules:
"channels": 1 "channels": 1
}, },
"metadata": { "metadata": {
"appId": "assistant_123",
"channel": "web",
"configVersionId": "cfg_20260217_01",
"client": "web-debug", "client": "web-debug",
"output": { "output": {
"mode": "audio" "mode": "audio"
}, },
"systemPrompt": "You are concise.", "systemPrompt": "You are concise.",
"greeting": "Hi, how can I help?", "greeting": "Hi, how can I help?"
"services": {
"llm": {
"provider": "openai",
"model": "gpt-4o-mini",
"apiKey": "sk-...",
"baseUrl": "https://api.openai.com/v1"
},
"asr": {
"provider": "openai_compatible",
"model": "FunAudioLLM/SenseVoiceSmall",
"apiKey": "sf-...",
"interimIntervalMs": 500,
"minAudioMs": 300
},
"tts": {
"enabled": true,
"provider": "openai_compatible",
"model": "FunAudioLLM/CosyVoice2-0.5B",
"apiKey": "sf-...",
"voice": "anna",
"speed": 1.0
}
}
} }
} }
``` ```
`metadata.services` is optional. If omitted, server defaults to environment configuration. Rules:
- Client-side `metadata.services` is ignored.
- Service config (including secrets) is resolved server-side (env/backend).
- Client should pass stable IDs (`appId`, `channel`, `configVersionId`) plus small runtime overrides (e.g. `output`, `bargeIn`, greeting/prompt style hints).
Text-only mode: Text-only mode:
- Set `metadata.output.mode = "text"` OR `metadata.services.tts.enabled = false`. - Set `metadata.output.mode = "text"`.
- In this mode server still sends `assistant.response.delta/final`, but will not emit audio frames or `output.audio.start/end`. - In this mode server still sends `assistant.response.delta/final`, but will not emit audio frames or `output.audio.start/end`.
### `input.text` ### `input.text`
@@ -121,6 +104,7 @@ Text-only mode:
### `tool_call.results` ### `tool_call.results`
Client tool execution results returned to server. Client tool execution results returned to server.
Only needed when `assistant.tool_call.executor == "client"` (default execution is server-side).
```json ```json
{ {
@@ -138,21 +122,35 @@ Client tool execution results returned to server.
## Server -> Client Events ## Server -> Client Events
All server events include: All server events include an envelope:
```json ```json
{ {
"type": "event.name", "type": "event.name",
"timestamp": 1730000000000 "timestamp": 1730000000000,
"sessionId": "sess_xxx",
"seq": 42,
"source": "asr",
"trackId": "audio_in",
"data": {}
} }
``` ```
Envelope notes:
- `seq` is monotonically increasing within one session (for replay/resume).
- `source` is one of: `asr | llm | tts | tool | system`.
- `data` is structured payload; legacy top-level fields are kept for compatibility.
Common events: Common events:
- `hello.ack` - `hello.ack`
- Fields: `sessionId`, `version` - Fields: `sessionId`, `version`
- `session.started` - `session.started`
- Fields: `sessionId`, `trackId`, `audio` - Fields: `sessionId`, `trackId`, `tracks`, `audio`
- `config.resolved`
- Fields: `sessionId`, `trackId`, `config`
- Sent immediately after `session.started`.
- Contains effective model/voice/output/tool allowlist/prompt hash, and never includes secrets.
- `session.stopped` - `session.stopped`
- Fields: `sessionId`, `reason` - Fields: `sessionId`, `reason`
- `heartbeat` - `heartbeat`
@@ -169,9 +167,10 @@ Common events:
- `assistant.response.final` - `assistant.response.final`
- Fields: `trackId`, `text` - Fields: `trackId`, `text`
- `assistant.tool_call` - `assistant.tool_call`
- Fields: `trackId`, `tool_call` (`tool_call.executor` is `client` or `server`) - Fields: `trackId`, `tool_call`, `tool_call_id`, `tool_name`, `arguments`, `executor`, `timeout_ms`
- `assistant.tool_result` - `assistant.tool_result`
- Fields: `trackId`, `source`, `result` - Fields: `trackId`, `source`, `result`, `tool_call_id`, `tool_name`, `ok`, `error`
- `error`: `{ code, message, retryable }` when `ok=false`
- `output.audio.start` - `output.audio.start`
- Fields: `trackId` - Fields: `trackId`
- `output.audio.end` - `output.audio.end`
@@ -183,15 +182,49 @@ Common events:
- `error` - `error`
- Fields: `sender`, `code`, `message`, `trackId` - Fields: `sender`, `code`, `message`, `trackId`
Track IDs (MVP fixed values):
- `audio_in`: ASR/VAD input-side events (`input.*`, `transcript.*`)
- `audio_out`: assistant output-side events (`assistant.*`, `output.audio.*`, `response.interrupted`, `metrics.ttfb`)
- `control`: session/control events (`session.*`, `hello.*`, `error`, `config.resolved`)
Correlation IDs (`event.data`):
- `turn_id`: one user-assistant interaction turn.
- `utterance_id`: one ASR final utterance.
- `response_id`: one assistant response generation.
- `tool_call_id`: one tool invocation.
- `tts_id`: one TTS playback segment.
## Binary Audio Frames ## Binary Audio Frames
After `session.started`, client may send binary PCM chunks continuously. After `session.started`, client may send binary PCM chunks continuously.
Recommended format: MVP fixed format:
- 16-bit signed little-endian PCM. - 16-bit signed little-endian PCM (`pcm_s16le`)
- 1 channel. - mono (1 channel)
- 16000 Hz. - 16000 Hz
- 20ms frames (640 bytes) preferred. - 20ms frame = 640 bytes
Framing rules:
- Binary audio frame unit is 640 bytes.
- A WS binary message may carry one or multiple complete 640-byte frames.
- Non-640-multiple payloads are treated as `audio.frame_size_mismatch` protocol errors.
TTS boundary events:
- `output.audio.start` and `output.audio.end` mark assistant playback boundaries.
## Event Throttling
To keep client rendering and server load stable, v1 applies/recommends:
- `transcript.delta`: merge to ~200-500ms cadence (server default: 300ms).
- `assistant.response.delta`: merge to ~50-100ms cadence (server default: 80ms).
- Metrics streams (if enabled beyond `metrics.ttfb`): emit every ~500-1000ms.
## Error Structure
`error` keeps legacy top-level fields (`code`, `message`) and adds structured info:
- `stage`: `protocol | asr | llm | tts | tool | audio`
- `retryable`: boolean
- `data.error`: `{ stage, code, message, retryable }`
## Compatibility ## Compatibility

View File

@@ -59,8 +59,12 @@ class MicrophoneClient:
url: str, url: str,
sample_rate: int = 16000, sample_rate: int = 16000,
chunk_duration_ms: int = 20, chunk_duration_ms: int = 20,
app_id: str = "assistant_demo",
channel: str = "mic_client",
config_version_id: str = "local-dev",
input_device: int = None, input_device: int = None,
output_device: int = None output_device: int = None,
track_debug: bool = False,
): ):
""" """
Initialize microphone client. Initialize microphone client.
@@ -76,8 +80,12 @@ class MicrophoneClient:
self.sample_rate = sample_rate self.sample_rate = sample_rate
self.chunk_duration_ms = chunk_duration_ms self.chunk_duration_ms = chunk_duration_ms
self.chunk_samples = int(sample_rate * chunk_duration_ms / 1000) self.chunk_samples = int(sample_rate * chunk_duration_ms / 1000)
self.app_id = app_id
self.channel = channel
self.config_version_id = config_version_id
self.input_device = input_device self.input_device = input_device
self.output_device = output_device self.output_device = output_device
self.track_debug = track_debug
# WebSocket connection # WebSocket connection
self.ws = None self.ws = None
@@ -106,6 +114,17 @@ class MicrophoneClient:
# Verbose mode for streaming LLM responses # Verbose mode for streaming LLM responses
self.verbose = False self.verbose = False
@staticmethod
def _event_ids_suffix(event: dict) -> str:
data = event.get("data") if isinstance(event.get("data"), dict) else {}
keys = ("turn_id", "utterance_id", "response_id", "tool_call_id", "tts_id")
parts = []
for key in keys:
value = data.get(key, event.get(key))
if value:
parts.append(f"{key}={value}")
return f" [{' '.join(parts)}]" if parts else ""
async def connect(self) -> None: async def connect(self) -> None:
"""Connect to WebSocket server.""" """Connect to WebSocket server."""
@@ -114,20 +133,30 @@ class MicrophoneClient:
self.running = True self.running = True
print("Connected!") print("Connected!")
# Send invite command # WS v1 handshake: hello -> session.start
await self.send_command({ await self.send_command({
"command": "invite", "type": "hello",
"option": { "version": "v1",
"codec": "pcm", })
"sampleRate": self.sample_rate await self.send_command({
} "type": "session.start",
"audio": {
"encoding": "pcm_s16le",
"sample_rate_hz": self.sample_rate,
"channels": 1,
},
"metadata": {
"appId": self.app_id,
"channel": self.channel,
"configVersionId": self.config_version_id,
},
}) })
async def send_command(self, cmd: dict) -> None: async def send_command(self, cmd: dict) -> None:
"""Send JSON command to server.""" """Send JSON command to server."""
if self.ws: if self.ws:
await self.ws.send(json.dumps(cmd)) await self.ws.send(json.dumps(cmd))
print(f"→ Command: {cmd.get('command', 'unknown')}") print(f"→ Command: {cmd.get('type', 'unknown')}")
async def send_chat(self, text: str) -> None: async def send_chat(self, text: str) -> None:
"""Send chat message (text input).""" """Send chat message (text input)."""
@@ -136,7 +165,7 @@ class MicrophoneClient:
self.first_audio_received = False self.first_audio_received = False
await self.send_command({ await self.send_command({
"command": "chat", "type": "input.text",
"text": text "text": text
}) })
print(f"→ Chat: {text}") print(f"→ Chat: {text}")
@@ -144,13 +173,14 @@ class MicrophoneClient:
async def send_interrupt(self) -> None: async def send_interrupt(self) -> None:
"""Send interrupt command.""" """Send interrupt command."""
await self.send_command({ await self.send_command({
"command": "interrupt" "type": "response.cancel",
"graceful": False,
}) })
async def send_hangup(self, reason: str = "User quit") -> None: async def send_hangup(self, reason: str = "User quit") -> None:
"""Send hangup command.""" """Send hangup command."""
await self.send_command({ await self.send_command({
"command": "hangup", "type": "session.stop",
"reason": reason "reason": reason
}) })
@@ -295,43 +325,48 @@ class MicrophoneClient:
async def _handle_event(self, event: dict) -> None: async def _handle_event(self, event: dict) -> None:
"""Handle incoming event.""" """Handle incoming event."""
event_type = event.get("event", "unknown") event_type = event.get("type", event.get("event", "unknown"))
ids = self._event_ids_suffix(event)
if self.track_debug:
print(f"[track-debug] event={event_type} trackId={event.get('trackId')}{ids}")
if event_type == "answer": if event_type in {"hello.ack", "session.started"}:
print("← Session ready!") print(f"← Session ready!{ids}")
elif event_type == "speaking": elif event_type == "config.resolved":
print("User speech detected") print(f"Config resolved: {event.get('config', {}).get('output', {})}{ids}")
elif event_type == "silence": elif event_type == "input.speech_started":
print("← User silence detected") print(f"← User speech detected{ids}")
elif event_type == "transcript": elif event_type == "input.speech_stopped":
print(f"← User silence detected{ids}")
elif event_type in {"transcript", "transcript.delta", "transcript.final"}:
# Display user speech transcription # Display user speech transcription
text = event.get("text", "") text = event.get("text", "")
is_final = event.get("isFinal", False) is_final = event_type == "transcript.final" or bool(event.get("isFinal"))
if is_final: if is_final:
# Clear the interim line and print final # Clear the interim line and print final
print(" " * 80, end="\r") # Clear previous interim text print(" " * 80, end="\r") # Clear previous interim text
print(f"→ You: {text}") print(f"→ You: {text}{ids}")
else: else:
# Interim result - show with indicator (overwrite same line) # Interim result - show with indicator (overwrite same line)
display_text = text[:60] + "..." if len(text) > 60 else text display_text = text[:60] + "..." if len(text) > 60 else text
print(f" [listening] {display_text}".ljust(80), end="\r") print(f" [listening] {display_text}".ljust(80), end="\r")
elif event_type == "ttfb": elif event_type in {"ttfb", "metrics.ttfb"}:
# Server-side TTFB event # Server-side TTFB event
latency_ms = event.get("latencyMs", 0) latency_ms = event.get("latencyMs", 0)
print(f"← [TTFB] Server reported latency: {latency_ms}ms") print(f"← [TTFB] Server reported latency: {latency_ms}ms")
elif event_type == "llmResponse": elif event_type in {"llmResponse", "assistant.response.delta", "assistant.response.final"}:
# LLM text response # LLM text response
text = event.get("text", "") text = event.get("text", "")
is_final = event.get("isFinal", False) is_final = event_type == "assistant.response.final" or bool(event.get("isFinal"))
if is_final: if is_final:
# Print final LLM response # Print final LLM response
print(f"← AI: {text}") print(f"← AI: {text}")
elif self.verbose: elif self.verbose:
# Show streaming chunks only in verbose mode # Show streaming chunks only in verbose mode
display_text = text[:60] + "..." if len(text) > 60 else text display_text = text[:60] + "..." if len(text) > 60 else text
print(f" [streaming] {display_text}") print(f" [streaming] {display_text}{ids}")
elif event_type == "trackStart": elif event_type in {"trackStart", "output.audio.start"}:
print("← Bot started speaking") print(f"← Bot started speaking{ids}")
# IMPORTANT: Accept audio again after trackStart # IMPORTANT: Accept audio again after trackStart
self._discard_audio = False self._discard_audio = False
self._audio_sequence += 1 self._audio_sequence += 1
@@ -342,13 +377,13 @@ class MicrophoneClient:
# Clear any old audio in buffer # Clear any old audio in buffer
with self.audio_output_lock: with self.audio_output_lock:
self.audio_output_buffer = b"" self.audio_output_buffer = b""
elif event_type == "trackEnd": elif event_type in {"trackEnd", "output.audio.end"}:
print("← Bot finished speaking") print(f"← Bot finished speaking{ids}")
# Reset TTFB tracking after response completes # Reset TTFB tracking after response completes
self.request_start_time = None self.request_start_time = None
self.first_audio_received = False self.first_audio_received = False
elif event_type == "interrupt": elif event_type in {"interrupt", "response.interrupted"}:
print("← Bot interrupted!") print(f"← Bot interrupted!{ids}")
# IMPORTANT: Discard all audio until next trackStart # IMPORTANT: Discard all audio until next trackStart
self._discard_audio = True self._discard_audio = True
# Clear audio buffer immediately # Clear audio buffer immediately
@@ -357,12 +392,12 @@ class MicrophoneClient:
self.audio_output_buffer = b"" self.audio_output_buffer = b""
print(f" (cleared {buffer_ms:.0f}ms, discarding audio until new track)") print(f" (cleared {buffer_ms:.0f}ms, discarding audio until new track)")
elif event_type == "error": elif event_type == "error":
print(f"← Error: {event.get('error')}") print(f"← Error: {event.get('error')}{ids}")
elif event_type == "hangup": elif event_type in {"hangup", "session.stopped"}:
print(f"← Hangup: {event.get('reason')}") print(f"← Hangup: {event.get('reason')}{ids}")
self.running = False self.running = False
else: else:
print(f"← Event: {event_type}") print(f"← Event: {event_type}{ids}")
async def interactive_mode(self) -> None: async def interactive_mode(self) -> None:
"""Run interactive mode for text chat.""" """Run interactive mode for text chat."""
@@ -573,6 +608,26 @@ async def main():
action="store_true", action="store_true",
help="Show streaming LLM response chunks" help="Show streaming LLM response chunks"
) )
parser.add_argument(
"--app-id",
default="assistant_demo",
help="Stable app/assistant identifier for server-side config lookup"
)
parser.add_argument(
"--channel",
default="mic_client",
help="Client channel name"
)
parser.add_argument(
"--config-version-id",
default="local-dev",
help="Optional config version identifier"
)
parser.add_argument(
"--track-debug",
action="store_true",
help="Print event trackId for protocol debugging"
)
args = parser.parse_args() args = parser.parse_args()
@@ -583,8 +638,12 @@ async def main():
client = MicrophoneClient( client = MicrophoneClient(
url=args.url, url=args.url,
sample_rate=args.sample_rate, sample_rate=args.sample_rate,
app_id=args.app_id,
channel=args.channel,
config_version_id=args.config_version_id,
input_device=args.input_device, input_device=args.input_device,
output_device=args.output_device output_device=args.output_device,
track_debug=args.track_debug,
) )
client.verbose = args.verbose client.verbose = args.verbose

View File

@@ -52,9 +52,21 @@ if not PYAUDIO_AVAILABLE and not SD_AVAILABLE:
class SimpleVoiceClient: class SimpleVoiceClient:
"""Simple voice client with reliable audio playback.""" """Simple voice client with reliable audio playback."""
def __init__(self, url: str, sample_rate: int = 16000): def __init__(
self,
url: str,
sample_rate: int = 16000,
app_id: str = "assistant_demo",
channel: str = "simple_client",
config_version_id: str = "local-dev",
track_debug: bool = False,
):
self.url = url self.url = url
self.sample_rate = sample_rate self.sample_rate = sample_rate
self.app_id = app_id
self.channel = channel
self.config_version_id = config_version_id
self.track_debug = track_debug
self.ws = None self.ws = None
self.running = False self.running = False
@@ -75,6 +87,17 @@ class SimpleVoiceClient:
# Interrupt handling - discard audio until next trackStart # Interrupt handling - discard audio until next trackStart
self._discard_audio = False self._discard_audio = False
@staticmethod
def _event_ids_suffix(event: dict) -> str:
data = event.get("data") if isinstance(event.get("data"), dict) else {}
keys = ("turn_id", "utterance_id", "response_id", "tool_call_id", "tts_id")
parts = []
for key in keys:
value = data.get(key, event.get(key))
if value:
parts.append(f"{key}={value}")
return f" [{' '.join(parts)}]" if parts else ""
async def connect(self): async def connect(self):
"""Connect to server.""" """Connect to server."""
@@ -83,12 +106,25 @@ class SimpleVoiceClient:
self.running = True self.running = True
print("Connected!") print("Connected!")
# Send invite # WS v1 handshake: hello -> session.start
await self.ws.send(json.dumps({ await self.ws.send(json.dumps({
"command": "invite", "type": "hello",
"option": {"codec": "pcm", "sampleRate": self.sample_rate} "version": "v1",
})) }))
print("-> invite") await self.ws.send(json.dumps({
"type": "session.start",
"audio": {
"encoding": "pcm_s16le",
"sample_rate_hz": self.sample_rate,
"channels": 1,
},
"metadata": {
"appId": self.app_id,
"channel": self.channel,
"configVersionId": self.config_version_id,
},
}))
print("-> hello/session.start")
async def send_chat(self, text: str): async def send_chat(self, text: str):
"""Send chat message.""" """Send chat message."""
@@ -96,8 +132,8 @@ class SimpleVoiceClient:
self.request_start_time = time.time() self.request_start_time = time.time()
self.first_audio_received = False self.first_audio_received = False
await self.ws.send(json.dumps({"command": "chat", "text": text})) await self.ws.send(json.dumps({"type": "input.text", "text": text}))
print(f"-> chat: {text}") print(f"-> input.text: {text}")
def play_audio(self, audio_data: bytes): def play_audio(self, audio_data: bytes):
"""Play audio data immediately.""" """Play audio data immediately."""
@@ -152,34 +188,39 @@ class SimpleVoiceClient:
else: else:
# JSON event # JSON event
event = json.loads(msg) event = json.loads(msg)
etype = event.get("event", "?") etype = event.get("type", event.get("event", "?"))
ids = self._event_ids_suffix(event)
if self.track_debug:
print(f"[track-debug] event={etype} trackId={event.get('trackId')}{ids}")
if etype == "transcript": if etype in {"transcript", "transcript.delta", "transcript.final"}:
# User speech transcription # User speech transcription
text = event.get("text", "") text = event.get("text", "")
is_final = event.get("isFinal", False) is_final = etype == "transcript.final" or bool(event.get("isFinal"))
if is_final: if is_final:
print(f"<- You said: {text}") print(f"<- You said: {text}{ids}")
else: else:
print(f"<- [listening] {text}", end="\r") print(f"<- [listening] {text}", end="\r")
elif etype == "ttfb": elif etype in {"ttfb", "metrics.ttfb"}:
# Server-side TTFB event # Server-side TTFB event
latency_ms = event.get("latencyMs", 0) latency_ms = event.get("latencyMs", 0)
print(f"<- [TTFB] Server reported latency: {latency_ms}ms") print(f"<- [TTFB] Server reported latency: {latency_ms}ms")
elif etype == "trackStart": elif etype in {"trackStart", "output.audio.start"}:
# New track starting - accept audio again # New track starting - accept audio again
self._discard_audio = False self._discard_audio = False
print(f"<- {etype}") print(f"<- {etype}{ids}")
elif etype == "interrupt": elif etype in {"interrupt", "response.interrupted"}:
# Interrupt - discard audio until next trackStart # Interrupt - discard audio until next trackStart
self._discard_audio = True self._discard_audio = True
print(f"<- {etype} (discarding audio until new track)") print(f"<- {etype}{ids} (discarding audio until new track)")
elif etype == "hangup": elif etype in {"hangup", "session.stopped"}:
print(f"<- {etype}") print(f"<- {etype}{ids}")
self.running = False self.running = False
break break
elif etype == "config.resolved":
print(f"<- config.resolved {event.get('config', {}).get('output', {})}{ids}")
else: else:
print(f"<- {etype}") print(f"<- {etype}{ids}")
except asyncio.TimeoutError: except asyncio.TimeoutError:
continue continue
@@ -270,6 +311,10 @@ async def main():
parser.add_argument("--text", help="Send text and play response") parser.add_argument("--text", help="Send text and play response")
parser.add_argument("--list-devices", action="store_true") parser.add_argument("--list-devices", action="store_true")
parser.add_argument("--sample-rate", type=int, default=16000) parser.add_argument("--sample-rate", type=int, default=16000)
parser.add_argument("--app-id", default="assistant_demo")
parser.add_argument("--channel", default="simple_client")
parser.add_argument("--config-version-id", default="local-dev")
parser.add_argument("--track-debug", action="store_true")
args = parser.parse_args() args = parser.parse_args()
@@ -277,7 +322,14 @@ async def main():
list_audio_devices() list_audio_devices()
return return
client = SimpleVoiceClient(args.url, args.sample_rate) client = SimpleVoiceClient(
args.url,
args.sample_rate,
app_id=args.app_id,
channel=args.channel,
config_version_id=args.config_version_id,
track_debug=args.track_debug,
)
await client.run(args.text) await client.run(args.text)

View File

@@ -36,8 +36,18 @@ def generate_sine_wave(duration_ms=1000):
return audio_data return audio_data
async def receive_loop(ws, ready_event: asyncio.Event): async def receive_loop(ws, ready_event: asyncio.Event, track_debug: bool = False):
"""Listen for incoming messages from the server.""" """Listen for incoming messages from the server."""
def event_ids_suffix(data):
payload = data.get("data") if isinstance(data.get("data"), dict) else {}
keys = ("turn_id", "utterance_id", "response_id", "tool_call_id", "tts_id")
parts = []
for key in keys:
value = payload.get(key, data.get(key))
if value:
parts.append(f"{key}={value}")
return f" [{' '.join(parts)}]" if parts else ""
print("👂 Listening for server responses...") print("👂 Listening for server responses...")
async for msg in ws: async for msg in ws:
timestamp = datetime.now().strftime("%H:%M:%S") timestamp = datetime.now().strftime("%H:%M:%S")
@@ -46,7 +56,10 @@ async def receive_loop(ws, ready_event: asyncio.Event):
try: try:
data = json.loads(msg.data) data = json.loads(msg.data)
event_type = data.get('type', 'Unknown') event_type = data.get('type', 'Unknown')
print(f"[{timestamp}] 📨 Event: {event_type} | {msg.data[:150]}...") ids = event_ids_suffix(data)
print(f"[{timestamp}] 📨 Event: {event_type}{ids} | {msg.data[:150]}...")
if track_debug:
print(f"[{timestamp}] [track-debug] event={event_type} trackId={data.get('trackId')}{ids}")
if event_type == "session.started": if event_type == "session.started":
ready_event.set() ready_event.set()
except json.JSONDecodeError: except json.JSONDecodeError:
@@ -113,7 +126,7 @@ async def send_sine_loop(ws):
print("\n✅ Finished streaming test audio.") print("\n✅ Finished streaming test audio.")
async def run_client(url, file_path=None, use_sine=False): async def run_client(url, file_path=None, use_sine=False, track_debug: bool = False):
"""Run the WebSocket test client.""" """Run the WebSocket test client."""
session = aiohttp.ClientSession() session = aiohttp.ClientSession()
try: try:
@@ -121,7 +134,7 @@ async def run_client(url, file_path=None, use_sine=False):
async with session.ws_connect(url) as ws: async with session.ws_connect(url) as ws:
print("✅ Connected!") print("✅ Connected!")
session_ready = asyncio.Event() session_ready = asyncio.Event()
recv_task = asyncio.create_task(receive_loop(ws, session_ready)) recv_task = asyncio.create_task(receive_loop(ws, session_ready, track_debug=track_debug))
# Send v1 hello + session.start handshake # Send v1 hello + session.start handshake
await ws.send_json({"type": "hello", "version": "v1"}) await ws.send_json({"type": "hello", "version": "v1"})
@@ -131,7 +144,12 @@ async def run_client(url, file_path=None, use_sine=False):
"encoding": "pcm_s16le", "encoding": "pcm_s16le",
"sample_rate_hz": SAMPLE_RATE, "sample_rate_hz": SAMPLE_RATE,
"channels": 1 "channels": 1
} },
"metadata": {
"appId": "assistant_demo",
"channel": "test_websocket",
"configVersionId": "local-dev",
},
}) })
print("📤 Sent v1 hello/session.start") print("📤 Sent v1 hello/session.start")
await asyncio.wait_for(session_ready.wait(), timeout=8) await asyncio.wait_for(session_ready.wait(), timeout=8)
@@ -168,9 +186,10 @@ if __name__ == "__main__":
parser.add_argument("--url", default=SERVER_URL, help="WebSocket endpoint URL") parser.add_argument("--url", default=SERVER_URL, help="WebSocket endpoint URL")
parser.add_argument("--file", help="Path to PCM/WAV file to stream") parser.add_argument("--file", help="Path to PCM/WAV file to stream")
parser.add_argument("--sine", action="store_true", help="Use sine wave generation (default)") parser.add_argument("--sine", action="store_true", help="Use sine wave generation (default)")
parser.add_argument("--track-debug", action="store_true", help="Print event trackId for protocol debugging")
args = parser.parse_args() args = parser.parse_args()
try: try:
asyncio.run(run_client(args.url, args.file, args.sine)) asyncio.run(run_client(args.url, args.file, args.sine, args.track_debug))
except KeyboardInterrupt: except KeyboardInterrupt:
print("\n👋 Client stopped.") print("\n👋 Client stopped.")

View File

@@ -57,10 +57,15 @@ class WavFileClient:
url: str, url: str,
input_file: str, input_file: str,
output_file: str, output_file: str,
app_id: str = "assistant_demo",
channel: str = "wav_client",
config_version_id: str = "local-dev",
sample_rate: int = 16000, sample_rate: int = 16000,
chunk_duration_ms: int = 20, chunk_duration_ms: int = 20,
wait_time: float = 15.0, wait_time: float = 15.0,
verbose: bool = False verbose: bool = False,
track_debug: bool = False,
tail_silence_ms: int = 800,
): ):
""" """
Initialize WAV file client. Initialize WAV file client.
@@ -77,11 +82,17 @@ class WavFileClient:
self.url = url self.url = url
self.input_file = Path(input_file) self.input_file = Path(input_file)
self.output_file = Path(output_file) self.output_file = Path(output_file)
self.app_id = app_id
self.channel = channel
self.config_version_id = config_version_id
self.sample_rate = sample_rate self.sample_rate = sample_rate
self.chunk_duration_ms = chunk_duration_ms self.chunk_duration_ms = chunk_duration_ms
self.chunk_samples = int(sample_rate * chunk_duration_ms / 1000) self.chunk_samples = int(sample_rate * chunk_duration_ms / 1000)
self.wait_time = wait_time self.wait_time = wait_time
self.verbose = verbose self.verbose = verbose
self.track_debug = track_debug
self.tail_silence_ms = max(0, int(tail_silence_ms))
self.frame_bytes = 640 # 16k mono pcm_s16le, 20ms
# WebSocket connection # WebSocket connection
self.ws = None self.ws = None
@@ -125,6 +136,17 @@ class WavFileClient:
# Replace problematic characters for console output # Replace problematic characters for console output
safe_message = message.encode('ascii', errors='replace').decode('ascii') safe_message = message.encode('ascii', errors='replace').decode('ascii')
print(f"{direction} {safe_message}") print(f"{direction} {safe_message}")
@staticmethod
def _event_ids_suffix(event: dict) -> str:
data = event.get("data") if isinstance(event.get("data"), dict) else {}
keys = ("turn_id", "utterance_id", "response_id", "tool_call_id", "tts_id")
parts = []
for key in keys:
value = data.get(key, event.get(key))
if value:
parts.append(f"{key}={value}")
return f" [{' '.join(parts)}]" if parts else ""
async def connect(self) -> None: async def connect(self) -> None:
"""Connect to WebSocket server.""" """Connect to WebSocket server."""
@@ -144,7 +166,12 @@ class WavFileClient:
"encoding": "pcm_s16le", "encoding": "pcm_s16le",
"sample_rate_hz": self.sample_rate, "sample_rate_hz": self.sample_rate,
"channels": 1 "channels": 1
} },
"metadata": {
"appId": self.app_id,
"channel": self.channel,
"configVersionId": self.config_version_id,
},
}) })
async def send_command(self, cmd: dict) -> None: async def send_command(self, cmd: dict) -> None:
@@ -216,6 +243,10 @@ class WavFileClient:
end_sample = min(sent_samples + chunk_size, total_samples) end_sample = min(sent_samples + chunk_size, total_samples)
chunk = audio_data[sent_samples:end_sample] chunk = audio_data[sent_samples:end_sample]
chunk_bytes = chunk.tobytes() chunk_bytes = chunk.tobytes()
if len(chunk_bytes) % self.frame_bytes != 0:
# v1 audio framing requires 640-byte (20ms) PCM units.
pad = self.frame_bytes - (len(chunk_bytes) % self.frame_bytes)
chunk_bytes += b"\x00" * pad
# Send to server # Send to server
if self.ws: if self.ws:
@@ -232,6 +263,16 @@ class WavFileClient:
# Delay to simulate real-time streaming # Delay to simulate real-time streaming
# Server expects audio at real-time pace for VAD/ASR to work properly # Server expects audio at real-time pace for VAD/ASR to work properly
await asyncio.sleep(self.chunk_duration_ms / 1000) await asyncio.sleep(self.chunk_duration_ms / 1000)
# Add a short silence tail to help VAD/EOU close the final utterance.
if self.tail_silence_ms > 0 and self.ws:
tail_frames = max(1, self.tail_silence_ms // 20)
silence = b"\x00" * self.frame_bytes
for _ in range(tail_frames):
await self.ws.send(silence)
self.bytes_sent += len(silence)
await asyncio.sleep(0.02)
self.log_event("", f"Sent trailing silence: {self.tail_silence_ms}ms")
self.send_completed = True self.send_completed = True
elapsed = time.time() - self.send_start_time elapsed = time.time() - self.send_start_time
@@ -284,16 +325,22 @@ class WavFileClient:
async def _handle_event(self, event: dict) -> None: async def _handle_event(self, event: dict) -> None:
"""Handle incoming event.""" """Handle incoming event."""
event_type = event.get("type", "unknown") event_type = event.get("type", "unknown")
ids = self._event_ids_suffix(event)
if self.track_debug:
print(f"[track-debug] event={event_type} trackId={event.get('trackId')}{ids}")
if event_type == "hello.ack": if event_type == "hello.ack":
self.log_event("", "Handshake acknowledged") self.log_event("", f"Handshake acknowledged{ids}")
elif event_type == "session.started": elif event_type == "session.started":
self.session_ready = True self.session_ready = True
self.log_event("", "Session ready!") self.log_event("", f"Session ready!{ids}")
elif event_type == "config.resolved":
config = event.get("config", {})
self.log_event("", f"Config resolved (output={config.get('output', {})}){ids}")
elif event_type == "input.speech_started": elif event_type == "input.speech_started":
self.log_event("", "Speech detected") self.log_event("", f"Speech detected{ids}")
elif event_type == "input.speech_stopped": elif event_type == "input.speech_stopped":
self.log_event("", "Silence detected") self.log_event("", f"Silence detected{ids}")
elif event_type == "transcript.delta": elif event_type == "transcript.delta":
text = event.get("text", "") text = event.get("text", "")
display_text = text[:60] + "..." if len(text) > 60 else text display_text = text[:60] + "..." if len(text) > 60 else text
@@ -301,35 +348,35 @@ class WavFileClient:
elif event_type == "transcript.final": elif event_type == "transcript.final":
text = event.get("text", "") text = event.get("text", "")
print(" " * 80, end="\r") print(" " * 80, end="\r")
self.log_event("", f"→ You: {text}") self.log_event("", f"→ You: {text}{ids}")
elif event_type == "metrics.ttfb": elif event_type == "metrics.ttfb":
latency_ms = event.get("latencyMs", 0) latency_ms = event.get("latencyMs", 0)
self.log_event("", f"[TTFB] Server latency: {latency_ms}ms") self.log_event("", f"[TTFB] Server latency: {latency_ms}ms")
elif event_type == "assistant.response.delta": elif event_type == "assistant.response.delta":
text = event.get("text", "") text = event.get("text", "")
if self.verbose and text: if self.verbose and text:
self.log_event("", f"LLM: {text}") self.log_event("", f"LLM: {text}{ids}")
elif event_type == "assistant.response.final": elif event_type == "assistant.response.final":
text = event.get("text", "") text = event.get("text", "")
if text: if text:
self.log_event("", f"LLM Response (final): {text[:100]}{'...' if len(text) > 100 else ''}") self.log_event("", f"LLM Response (final): {text[:100]}{'...' if len(text) > 100 else ''}{ids}")
elif event_type == "output.audio.start": elif event_type == "output.audio.start":
self.track_started = True self.track_started = True
self.response_start_time = time.time() self.response_start_time = time.time()
self.waiting_for_first_audio = True self.waiting_for_first_audio = True
self.log_event("", "Bot started speaking") self.log_event("", f"Bot started speaking{ids}")
elif event_type == "output.audio.end": elif event_type == "output.audio.end":
self.track_ended = True self.track_ended = True
self.log_event("", "Bot finished speaking") self.log_event("", f"Bot finished speaking{ids}")
elif event_type == "response.interrupted": elif event_type == "response.interrupted":
self.log_event("", "Bot interrupted!") self.log_event("", f"Bot interrupted!{ids}")
elif event_type == "error": elif event_type == "error":
self.log_event("!", f"Error: {event.get('message')}") self.log_event("!", f"Error: {event.get('message')}{ids}")
elif event_type == "session.stopped": elif event_type == "session.stopped":
self.log_event("", f"Session stopped: {event.get('reason')}") self.log_event("", f"Session stopped: {event.get('reason')}{ids}")
self.running = False self.running = False
else: else:
self.log_event("", f"Event: {event_type}") self.log_event("", f"Event: {event_type}{ids}")
def save_output_wav(self) -> None: def save_output_wav(self) -> None:
"""Save received audio to output WAV file.""" """Save received audio to output WAV file."""
@@ -473,6 +520,21 @@ async def main():
default=16000, default=16000,
help="Target sample rate for audio (default: 16000)" help="Target sample rate for audio (default: 16000)"
) )
parser.add_argument(
"--app-id",
default="assistant_demo",
help="Stable app/assistant identifier for server-side config lookup"
)
parser.add_argument(
"--channel",
default="wav_client",
help="Client channel name"
)
parser.add_argument(
"--config-version-id",
default="local-dev",
help="Optional config version identifier"
)
parser.add_argument( parser.add_argument(
"--chunk-duration", "--chunk-duration",
type=int, type=int,
@@ -490,6 +552,17 @@ async def main():
action="store_true", action="store_true",
help="Enable verbose output" help="Enable verbose output"
) )
parser.add_argument(
"--track-debug",
action="store_true",
help="Print event trackId for protocol debugging"
)
parser.add_argument(
"--tail-silence-ms",
type=int,
default=800,
help="Trailing silence to send after WAV playback for EOU detection (default: 800)"
)
args = parser.parse_args() args = parser.parse_args()
@@ -497,10 +570,15 @@ async def main():
url=args.url, url=args.url,
input_file=args.input, input_file=args.input,
output_file=args.output, output_file=args.output,
app_id=args.app_id,
channel=args.channel,
config_version_id=args.config_version_id,
sample_rate=args.sample_rate, sample_rate=args.sample_rate,
chunk_duration_ms=args.chunk_duration, chunk_duration_ms=args.chunk_duration,
wait_time=args.wait_time, wait_time=args.wait_time,
verbose=args.verbose verbose=args.verbose,
track_debug=args.track_debug,
tail_silence_ms=args.tail_silence_ms,
) )
await client.run() await client.run()

View File

@@ -401,6 +401,9 @@
const targetSampleRate = 16000; const targetSampleRate = 16000;
const playbackStopRampSec = 0.008; const playbackStopRampSec = 0.008;
const appId = "assistant_demo";
const channel = "web_client";
const configVersionId = "local-dev";
function logLine(type, text, data) { function logLine(type, text, data) {
const time = new Date().toLocaleTimeString(); const time = new Date().toLocaleTimeString();
@@ -604,15 +607,35 @@
logLine("sys", `${cmd.type}`, cmd); logLine("sys", `${cmd.type}`, cmd);
} }
function eventIdsSuffix(event) {
const data = event && typeof event.data === "object" && event.data ? event.data : {};
const keys = ["turn_id", "utterance_id", "response_id", "tool_call_id", "tts_id"];
const parts = [];
for (const key of keys) {
const value = data[key] || event[key];
if (value) parts.push(`${key}=${value}`);
}
return parts.length ? ` [${parts.join(" ")}]` : "";
}
function handleEvent(event) { function handleEvent(event) {
const type = event.type || "unknown"; const type = event.type || "unknown";
logLine("event", type, event); const ids = eventIdsSuffix(event);
logLine("event", `${type}${ids}`, event);
if (type === "hello.ack") { if (type === "hello.ack") {
sendCommand({ sendCommand({
type: "session.start", type: "session.start",
audio: { encoding: "pcm_s16le", sample_rate_hz: targetSampleRate, channels: 1 }, audio: { encoding: "pcm_s16le", sample_rate_hz: targetSampleRate, channels: 1 },
metadata: {
appId,
channel,
configVersionId,
},
}); });
} }
if (type === "config.resolved") {
logLine("sys", "config.resolved", event.config || {});
}
if (type === "transcript.final") { if (type === "transcript.final") {
if (event.text) { if (event.text) {
setInterim("You", ""); setInterim("You", "");