Update engine

This commit is contained in:
Xin Wang
2026-02-23 17:16:18 +08:00
parent 01c0de0a4d
commit c6c84b5af9
9 changed files with 991 additions and 186 deletions

View File

@@ -24,7 +24,6 @@ from core.transports import SocketTransport, WebRtcTransport, BaseTransport
from core.session import Session
from processors.tracks import Resampled16kTrack
from core.events import get_event_bus, reset_event_bus
from models.ws_v1 import ev
# Check interval for heartbeat/timeout (seconds)
_HEARTBEAT_CHECK_INTERVAL_SEC = 5
@@ -54,9 +53,7 @@ async def heartbeat_and_timeout_task(
break
if now - last_heartbeat_at[0] >= heartbeat_interval_sec:
try:
await transport.send_event({
**ev("heartbeat"),
})
await session.send_heartbeat()
last_heartbeat_at[0] = now
except Exception as e:
logger.debug(f"Session {session_id}: heartbeat send failed: {e}")

View File

@@ -14,7 +14,8 @@ event-driven design.
import asyncio
import json
import time
from typing import Any, Dict, List, Optional, Tuple
import uuid
from typing import Any, Callable, Dict, List, Optional, Tuple
import numpy as np
from loguru import logger
@@ -59,6 +60,12 @@ class DuplexPipeline:
_MIN_SPLIT_SPOKEN_CHARS = 6
_TOOL_WAIT_TIMEOUT_SECONDS = 15.0
_SERVER_TOOL_TIMEOUT_SECONDS = 15.0
TRACK_AUDIO_IN = "audio_in"
TRACK_AUDIO_OUT = "audio_out"
TRACK_CONTROL = "control"
_PCM_FRAME_BYTES = 640 # 16k mono pcm_s16le, 20ms
_ASR_DELTA_THROTTLE_MS = 300
_LLM_DELTA_THROTTLE_MS = 80
_DEFAULT_TOOL_SCHEMAS: Dict[str, Dict[str, Any]] = {
"current_time": {
"name": "current_time",
@@ -96,6 +103,9 @@ class DuplexPipeline:
self.transport = transport
self.session_id = session_id
self.event_bus = get_event_bus()
self.track_audio_in = self.TRACK_AUDIO_IN
self.track_audio_out = self.TRACK_AUDIO_OUT
self.track_control = self.TRACK_CONTROL
# Initialize VAD
self.vad_model = SileroVAD(
@@ -120,6 +130,8 @@ class DuplexPipeline:
# Track last sent transcript to avoid duplicates
self._last_sent_transcript = ""
self._pending_transcript_delta: str = ""
self._last_transcript_delta_emit_ms: float = 0.0
# Conversation manager
self.conversation = ConversationManager(
@@ -153,6 +165,7 @@ class DuplexPipeline:
self._outbound_seq = 0
self._outbound_task: Optional[asyncio.Task] = None
self._drop_outbound_audio = False
self._audio_out_frame_buffer: bytes = b""
# Interruption handling
self._interrupt_event = asyncio.Event()
@@ -186,9 +199,28 @@ class DuplexPipeline:
self._pending_tool_waiters: Dict[str, asyncio.Future] = {}
self._early_tool_results: Dict[str, Dict[str, Any]] = {}
self._completed_tool_call_ids: set[str] = set()
self._pending_client_tool_call_ids: set[str] = set()
self._next_seq: Optional[Callable[[], int]] = None
self._local_seq: int = 0
# Cross-service correlation IDs
self._turn_count: int = 0
self._response_count: int = 0
self._tts_count: int = 0
self._utterance_count: int = 0
self._current_turn_id: Optional[str] = None
self._current_utterance_id: Optional[str] = None
self._current_response_id: Optional[str] = None
self._current_tts_id: Optional[str] = None
self._pending_llm_delta: str = ""
self._last_llm_delta_emit_ms: float = 0.0
logger.info(f"DuplexPipeline initialized for session {session_id}")
def set_event_sequence_provider(self, provider: Callable[[], int]) -> None:
"""Use session-scoped monotonic sequence provider for envelope events."""
self._next_seq = provider
def apply_runtime_overrides(self, metadata: Optional[Dict[str, Any]]) -> None:
"""
Apply runtime overrides from WS session.start metadata.
@@ -276,6 +308,131 @@ class DuplexPipeline:
if self.llm_service and hasattr(self.llm_service, "set_tool_schemas"):
self.llm_service.set_tool_schemas(self._resolved_tool_schemas())
def resolved_runtime_config(self) -> Dict[str, Any]:
    """Return current effective runtime configuration without secrets.

    Merges per-session runtime overrides (``self._runtime_llm`` /
    ``_runtime_asr`` / ``_runtime_tts`` / ``_runtime_output``) over global
    ``settings`` defaults, producing the public payload embedded in the
    ``config.resolved`` event. API keys and other credentials are never
    included.

    Returns:
        Dict with "output" (mode), "services" (llm/asr/tts descriptors),
        "tools" (allowlist of tool names only) and the logical track IDs.
    """
    llm_provider = str(self._runtime_llm.get("provider") or "openai").lower()
    tts_provider = str(self._runtime_tts.get("provider") or settings.tts_provider).lower()
    asr_provider = str(self._runtime_asr.get("provider") or settings.asr_provider).lower()
    output_mode = str(self._runtime_output.get("mode") or "").strip().lower()
    if not output_mode:
        # No explicit override: derive the mode from whether TTS is enabled.
        output_mode = "audio" if self._tts_output_enabled() else "text"
    return {
        "output": {"mode": output_mode},
        "services": {
            "llm": {
                "provider": llm_provider,
                "model": str(self._runtime_llm.get("model") or settings.llm_model),
                "baseUrl": self._runtime_llm.get("baseUrl") or settings.openai_api_url,
            },
            "asr": {
                "provider": asr_provider,
                "model": str(self._runtime_asr.get("model") or settings.siliconflow_asr_model),
                "interimIntervalMs": int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms),
                "minAudioMs": int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms),
            },
            "tts": {
                "enabled": self._tts_output_enabled(),
                "provider": tts_provider,
                "model": str(self._runtime_tts.get("model") or settings.siliconflow_tts_model),
                "voice": str(self._runtime_tts.get("voice") or settings.tts_voice),
                "speed": float(self._runtime_tts.get("speed") or settings.tts_speed),
            },
        },
        "tools": {
            # Only tool names are exposed; executor/config details stay server-side.
            "allowlist": sorted(self._runtime_tool_executor.keys()),
        },
        "tracks": {
            "audio_in": self.track_audio_in,
            "audio_out": self.track_audio_out,
            "control": self.track_control,
        },
    }
def _next_event_seq(self) -> int:
    """Return the next envelope sequence number.

    Prefers the session-scoped provider installed via
    ``set_event_sequence_provider``; otherwise falls back to a
    pipeline-local monotonic counter.
    """
    provider = self._next_seq
    if provider:
        return provider()
    self._local_seq = self._local_seq + 1
    return self._local_seq
def _event_source(self, event_type: str) -> str:
    """Map an outbound event type to the subsystem that produced it."""
    prefix_table = (
        (("transcript.", "input.speech_"), "asr"),
        (("assistant.response.",), "llm"),
        (("assistant.tool_",), "tool"),
        (("output.audio.",), "tts"),
    )
    for prefixes, source in prefix_table:
        if event_type.startswith(prefixes):
            return source
    # TTFB metric is emitted by the TTS path; everything else is system-level.
    return "tts" if event_type == "metrics.ttfb" else "system"
def _new_id(self, prefix: str, counter: int) -> str:
    """Build a correlation ID like ``turn_3_1a2b3c4d`` (prefix, count, short uuid)."""
    suffix = uuid.uuid4().hex[:8]
    return "_".join((prefix, str(counter), suffix))
def _start_turn(self) -> str:
    """Open a new conversation turn and reset all child correlation IDs."""
    self._turn_count += 1
    turn_id = self._new_id("turn", self._turn_count)
    self._current_turn_id = turn_id
    # Utterance/response/TTS IDs belong to the previous turn; clear them.
    self._current_utterance_id = None
    self._current_response_id = None
    self._current_tts_id = None
    return turn_id
def _start_response(self) -> str:
    """Allocate a new response correlation ID and invalidate any prior TTS ID."""
    self._response_count += 1
    response_id = self._new_id("resp", self._response_count)
    self._current_response_id = response_id
    self._current_tts_id = None
    return response_id
def _start_tts(self) -> str:
    """Allocate a fresh TTS correlation ID for the next audio stream."""
    self._tts_count += 1
    tts_id = self._new_id("tts", self._tts_count)
    self._current_tts_id = tts_id
    return tts_id
def _finalize_utterance(self) -> str:
    """Return the current utterance ID, allocating one (and a turn) if needed.

    Bug fix: the enclosing turn must be started *before* the utterance ID is
    stored, because ``_start_turn()`` resets ``_current_utterance_id`` to
    ``None``. The previous ordering allocated the utterance ID first, then
    called ``_start_turn()`` when no turn was active — wiping the freshly
    assigned ID and returning ``None`` to the caller.
    """
    if self._current_utterance_id:
        return self._current_utterance_id
    # Ensure a turn exists first; _start_turn() clears child correlation IDs.
    if not self._current_turn_id:
        self._start_turn()
    self._utterance_count += 1
    self._current_utterance_id = self._new_id("utt", self._utterance_count)
    return self._current_utterance_id
def _envelope_event(self, event: Dict[str, Any]) -> Dict[str, Any]:
    """Normalize an outbound event into the protocol-v1 envelope (in place).

    Fills in ``sessionId``, a monotonically increasing ``seq``, ``source``,
    ``trackId`` and a ``data`` dict. Active correlation IDs
    (turn/utterance/response/tts) and any extra top-level keys are folded
    into ``data`` via ``setdefault`` so caller-provided values win.
    Returns the same (mutated) dict for convenience.
    """
    event_type = str(event.get("type") or "")
    # Respect an explicit source; otherwise infer it from the event type.
    source = str(event.get("source") or self._event_source(event_type))
    track_id = event.get("trackId")
    if not track_id:
        # Default track per originating subsystem: ASR events ride the input
        # track, synthesized output rides the output track, rest is control.
        if source == "asr":
            track_id = self.track_audio_in
        elif source in {"llm", "tts", "tool"}:
            track_id = self.track_audio_out
        else:
            track_id = self.track_control
    data = event.get("data")
    if not isinstance(data, dict):
        data = {}
    # Attach active correlation IDs without clobbering caller-provided ones.
    if self._current_turn_id:
        data.setdefault("turn_id", self._current_turn_id)
    if self._current_utterance_id:
        data.setdefault("utterance_id", self._current_utterance_id)
    if self._current_response_id:
        data.setdefault("response_id", self._current_response_id)
    if self._current_tts_id:
        data.setdefault("tts_id", self._current_tts_id)
    # Fold remaining top-level payload fields into data; reserved envelope
    # keys stay at the top level.
    for k, v in event.items():
        if k in {"type", "timestamp", "sessionId", "seq", "source", "trackId", "data"}:
            continue
        data.setdefault(k, v)
    event["sessionId"] = self.session_id
    event["seq"] = self._next_event_seq()
    event["source"] = source
    event["trackId"] = track_id
    event["data"] = data
    return event
@staticmethod
def _coerce_bool(value: Any) -> Optional[bool]:
if isinstance(value, bool):
@@ -472,11 +629,13 @@ class DuplexPipeline:
greeting_to_speak = generated_greeting
self.conversation.greeting = generated_greeting
if greeting_to_speak:
self._start_turn()
self._start_response()
await self._send_event(
ev(
"assistant.response.final",
text=greeting_to_speak,
trackId=self.session_id,
trackId=self.track_audio_out,
),
priority=20,
)
@@ -494,10 +653,58 @@ class DuplexPipeline:
await self._outbound_q.put((priority, self._outbound_seq, kind, payload))
async def _send_event(self, event: Dict[str, Any], priority: int = 20) -> None:
await self._enqueue_outbound("event", event, priority)
await self._enqueue_outbound("event", self._envelope_event(event), priority)
async def _send_audio(self, pcm_bytes: bytes, priority: int = 50) -> None:
await self._enqueue_outbound("audio", pcm_bytes, priority)
if not pcm_bytes:
return
self._audio_out_frame_buffer += pcm_bytes
while len(self._audio_out_frame_buffer) >= self._PCM_FRAME_BYTES:
frame = self._audio_out_frame_buffer[: self._PCM_FRAME_BYTES]
self._audio_out_frame_buffer = self._audio_out_frame_buffer[self._PCM_FRAME_BYTES :]
await self._enqueue_outbound("audio", frame, priority)
async def _flush_audio_out_frames(self, priority: int = 50) -> None:
    """Emit any buffered outbound audio as a single zero-padded 20ms PCM frame."""
    remainder = self._audio_out_frame_buffer
    if not remainder:
        return
    self._audio_out_frame_buffer = b""
    # Pad the tail with silence up to one full frame (no-op if already full).
    padded = remainder.ljust(self._PCM_FRAME_BYTES, b"\x00")
    await self._enqueue_outbound("audio", padded, priority)
async def _emit_transcript_delta(self, text: str) -> None:
    """Send one throttled ASR interim transcript chunk on the input track."""
    payload = ev(
        "transcript.delta",
        trackId=self.track_audio_in,
        text=text,
    )
    await self._send_event(dict(payload), priority=30)
async def _emit_llm_delta(self, text: str) -> None:
    """Send one throttled assistant response chunk on the output track."""
    payload = ev(
        "assistant.response.delta",
        trackId=self.track_audio_out,
        text=text,
    )
    await self._send_event(dict(payload), priority=20)
async def _flush_pending_llm_delta(self) -> None:
    """Emit the accumulated LLM delta text, if any, and reset the throttle clock."""
    chunk, self._pending_llm_delta = self._pending_llm_delta, ""
    if not chunk:
        return
    self._last_llm_delta_emit_ms = time.monotonic() * 1000.0
    await self._emit_llm_delta(chunk)
async def _outbound_loop(self) -> None:
"""Single sender loop that enforces priority for interrupt events."""
@@ -546,13 +753,13 @@ class DuplexPipeline:
# Emit VAD event
await self.event_bus.publish(event_type, {
"trackId": self.session_id,
"trackId": self.track_audio_in,
"probability": probability
})
await self._send_event(
ev(
"input.speech_started" if event_type == "speaking" else "input.speech_stopped",
trackId=self.session_id,
trackId=self.track_audio_in,
probability=probability,
),
priority=30,
@@ -661,6 +868,9 @@ class DuplexPipeline:
# Cancel any current speaking
await self._stop_current_speech()
self._start_turn()
self._finalize_utterance()
# Start new turn
await self.conversation.end_user_turn(text)
self._current_turn_task = asyncio.create_task(self._handle_turn(text))
@@ -683,24 +893,45 @@ class DuplexPipeline:
if text == self._last_sent_transcript and not is_final:
return
now_ms = time.monotonic() * 1000.0
self._last_sent_transcript = text
# Send transcript event to client
await self._send_event({
**ev(
"transcript.final" if is_final else "transcript.delta",
trackId=self.session_id,
text=text,
if is_final:
self._pending_transcript_delta = ""
self._last_transcript_delta_emit_ms = 0.0
await self._send_event(
{
**ev(
"transcript.final",
trackId=self.track_audio_in,
text=text,
)
},
priority=30,
)
}, priority=30)
logger.debug(f"Sent transcript (final): {text[:50]}...")
return
self._pending_transcript_delta = text
should_emit = (
self._last_transcript_delta_emit_ms <= 0.0
or now_ms - self._last_transcript_delta_emit_ms >= self._ASR_DELTA_THROTTLE_MS
)
if should_emit and self._pending_transcript_delta:
delta = self._pending_transcript_delta
self._pending_transcript_delta = ""
self._last_transcript_delta_emit_ms = now_ms
await self._emit_transcript_delta(delta)
if not is_final:
logger.info(f"[ASR] ASR interim: {text[:100]}")
logger.debug(f"Sent transcript ({'final' if is_final else 'interim'}): {text[:50]}...")
logger.debug(f"Sent transcript (interim): {text[:50]}...")
async def _on_speech_start(self) -> None:
"""Handle user starting to speak."""
if self.conversation.state in (ConversationState.IDLE, ConversationState.INTERRUPTED):
self._start_turn()
self._finalize_utterance()
await self.conversation.start_user_turn()
self._audio_buffer = b""
self._last_sent_transcript = ""
@@ -779,6 +1010,7 @@ class DuplexPipeline:
return
logger.info(f"[EOU] Detected - user said: {user_text[:100]}...")
self._finalize_utterance()
# For ASR backends that already emitted final via callback,
# avoid duplicating transcript.final on EOU.
@@ -786,7 +1018,7 @@ class DuplexPipeline:
await self._send_event({
**ev(
"transcript.final",
trackId=self.session_id,
trackId=self.track_audio_in,
text=user_text,
)
}, priority=25)
@@ -794,6 +1026,8 @@ class DuplexPipeline:
# Clear buffers
self._audio_buffer = b""
self._last_sent_transcript = ""
self._pending_transcript_delta = ""
self._last_transcript_delta_emit_ms = 0.0
self._asr_capture_active = False
self._pending_speech_audio = b""
@@ -894,6 +1128,44 @@ class DuplexPipeline:
# Default to server execution unless explicitly marked as client.
return "server"
def _tool_arguments(self, tool_call: Dict[str, Any]) -> Dict[str, Any]:
    """Extract tool-call arguments as a dict.

    Accepts either a dict payload or a JSON string under
    ``function.arguments``; a string that does not parse to a JSON object is
    preserved under a ``"raw"`` key. Anything else yields ``{}``.
    """
    fn = tool_call.get("function")
    raw = fn.get("arguments") if isinstance(fn, dict) else None
    if isinstance(raw, dict):
        return raw
    if not isinstance(raw, str) or not raw.strip():
        return {}
    try:
        parsed = json.loads(raw)
    except Exception:
        return {"raw": raw}
    return parsed if isinstance(parsed, dict) else {"raw": raw}
def _normalize_tool_result(self, result: Dict[str, Any]) -> Dict[str, Any]:
    """Reduce a raw tool result into the public envelope fields.

    2xx status codes count as success; 408/429/5xx failures are flagged
    retryable. A missing status block is treated as a failure (code 500 in
    the error payload).
    """
    status_raw = result.get("status")
    status = status_raw if isinstance(status_raw, dict) else {}
    status_code = int(status.get("code") or 0) if status else 0
    status_message = str(status.get("message") or "") if status else ""
    call_id = str(result.get("tool_call_id") or result.get("id") or "")
    name = str(result.get("name") or "unknown_tool")
    ok = 200 <= status_code < 300
    error: Optional[Dict[str, Any]] = None
    if not ok:
        retryable = status_code >= 500 or status_code in {408, 429}
        error = {
            "code": status_code or 500,
            "message": status_message or "tool_execution_failed",
            "retryable": retryable,
        }
    return {
        "tool_call_id": call_id,
        "tool_name": name,
        "ok": ok,
        "error": error,
        "status": {"code": status_code, "message": status_message},
    }
async def _emit_tool_result(self, result: Dict[str, Any], source: str) -> None:
tool_name = str(result.get("name") or "unknown_tool")
call_id = str(result.get("tool_call_id") or result.get("id") or "")
@@ -904,12 +1176,17 @@ class DuplexPipeline:
f"[Tool] emit result source={source} name={tool_name} call_id={call_id} "
f"status={status_code} {status_message}".strip()
)
normalized = self._normalize_tool_result(result)
await self._send_event(
{
**ev(
"assistant.tool_result",
trackId=self.session_id,
trackId=self.track_audio_out,
source=source,
tool_call_id=normalized["tool_call_id"],
tool_name=normalized["tool_name"],
ok=normalized["ok"],
error=normalized["error"],
result=result,
)
},
@@ -927,6 +1204,9 @@ class DuplexPipeline:
call_id = str(item.get("tool_call_id") or item.get("id") or "").strip()
if not call_id:
continue
if self._pending_client_tool_call_ids and call_id not in self._pending_client_tool_call_ids:
logger.warning(f"[Tool] ignore unsolicited client result call_id={call_id}")
continue
if call_id in self._completed_tool_call_ids:
logger.debug(f"[Tool] ignore duplicate client result call_id={call_id}")
continue
@@ -972,6 +1252,7 @@ class DuplexPipeline:
}
finally:
self._pending_tool_waiters.pop(call_id, None)
self._pending_client_tool_call_ids.discard(call_id)
def _normalize_stream_event(self, item: Any) -> LLMStreamEvent:
if isinstance(item, LLMStreamEvent):
@@ -998,6 +1279,11 @@ class DuplexPipeline:
user_text: User's transcribed text
"""
try:
if not self._current_turn_id:
self._start_turn()
if not self._current_utterance_id:
self._finalize_utterance()
self._start_response()
# Start latency tracking
self._turn_start_time = time.time()
self._first_audio_sent = False
@@ -1012,6 +1298,8 @@ class DuplexPipeline:
self._drop_outbound_audio = False
first_audio_sent = False
self._pending_llm_delta = ""
self._last_llm_delta_emit_ms = 0.0
for _ in range(max_rounds):
if self._interrupt_event.is_set():
break
@@ -1028,6 +1316,7 @@ class DuplexPipeline:
event = self._normalize_stream_event(raw_event)
if event.type == "tool_call":
await self._flush_pending_llm_delta()
tool_call = event.tool_call if isinstance(event.tool_call, dict) else None
if not tool_call:
continue
@@ -1045,11 +1334,19 @@ class DuplexPipeline:
f"executor={executor} args={args_preview}"
)
tool_calls.append(enriched_tool_call)
tool_arguments = self._tool_arguments(enriched_tool_call)
if executor == "client" and call_id:
self._pending_client_tool_call_ids.add(call_id)
await self._send_event(
{
**ev(
"assistant.tool_call",
trackId=self.session_id,
trackId=self.track_audio_out,
tool_call_id=call_id,
tool_name=tool_name,
arguments=tool_arguments,
executor=executor,
timeout_ms=int(self._TOOL_WAIT_TIMEOUT_SECONDS * 1000),
tool_call=enriched_tool_call,
)
},
@@ -1071,19 +1368,13 @@ class DuplexPipeline:
round_response += text_chunk
sentence_buffer += text_chunk
await self.conversation.update_assistant_text(text_chunk)
await self._send_event(
{
**ev(
"assistant.response.delta",
trackId=self.session_id,
text=text_chunk,
)
},
# Keep delta/final on the same event priority so FIFO seq
# preserves stream order (avoid late-delta after final).
priority=20,
)
self._pending_llm_delta += text_chunk
now_ms = time.monotonic() * 1000.0
if (
self._last_llm_delta_emit_ms <= 0.0
or now_ms - self._last_llm_delta_emit_ms >= self._LLM_DELTA_THROTTLE_MS
):
await self._flush_pending_llm_delta()
while True:
split_result = extract_tts_sentence(
@@ -1112,11 +1403,12 @@ class DuplexPipeline:
if self._tts_output_enabled() and not self._interrupt_event.is_set():
if not first_audio_sent:
self._start_tts()
await self._send_event(
{
**ev(
"output.audio.start",
trackId=self.session_id,
trackId=self.track_audio_out,
)
},
priority=10,
@@ -1130,6 +1422,7 @@ class DuplexPipeline:
)
remaining_text = f"{pending_punctuation}{sentence_buffer}".strip()
await self._flush_pending_llm_delta()
if (
self._tts_output_enabled()
and remaining_text
@@ -1137,11 +1430,12 @@ class DuplexPipeline:
and not self._interrupt_event.is_set()
):
if not first_audio_sent:
self._start_tts()
await self._send_event(
{
**ev(
"output.audio.start",
trackId=self.session_id,
trackId=self.track_audio_out,
)
},
priority=10,
@@ -1204,11 +1498,12 @@ class DuplexPipeline:
]
if full_response and not self._interrupt_event.is_set():
await self._flush_pending_llm_delta()
await self._send_event(
{
**ev(
"assistant.response.final",
trackId=self.session_id,
trackId=self.track_audio_out,
text=full_response,
)
},
@@ -1217,10 +1512,11 @@ class DuplexPipeline:
# Send track end
if first_audio_sent:
await self._flush_audio_out_frames(priority=50)
await self._send_event({
**ev(
"output.audio.end",
trackId=self.session_id,
trackId=self.track_audio_out,
)
}, priority=10)
@@ -1241,6 +1537,8 @@ class DuplexPipeline:
self._barge_in_speech_start_time = None
self._barge_in_speech_frames = 0
self._barge_in_silence_frames = 0
self._current_response_id = None
self._current_tts_id = None
async def _speak_sentence(self, text: str, fade_in_ms: int = 0, fade_out_ms: int = 8) -> None:
"""
@@ -1277,7 +1575,7 @@ class DuplexPipeline:
await self._send_event({
**ev(
"metrics.ttfb",
trackId=self.session_id,
trackId=self.track_audio_out,
latencyMs=round(ttfb_ms),
)
}, priority=25)
@@ -1354,10 +1652,11 @@ class DuplexPipeline:
first_audio_sent = False
# Send track start event
self._start_tts()
await self._send_event({
**ev(
"output.audio.start",
trackId=self.session_id,
trackId=self.track_audio_out,
)
}, priority=10)
@@ -1379,7 +1678,7 @@ class DuplexPipeline:
await self._send_event({
**ev(
"metrics.ttfb",
trackId=self.session_id,
trackId=self.track_audio_out,
latencyMs=round(ttfb_ms),
)
}, priority=25)
@@ -1391,10 +1690,11 @@ class DuplexPipeline:
await asyncio.sleep(0.01)
# Send track end event
await self._flush_audio_out_frames(priority=50)
await self._send_event({
**ev(
"output.audio.end",
trackId=self.session_id,
trackId=self.track_audio_out,
)
}, priority=10)
@@ -1422,13 +1722,14 @@ class DuplexPipeline:
self._interrupt_event.set()
self._is_bot_speaking = False
self._drop_outbound_audio = True
self._audio_out_frame_buffer = b""
# Send interrupt event to client IMMEDIATELY
# This must happen BEFORE canceling services, so client knows to discard in-flight audio
await self._send_event({
**ev(
"response.interrupted",
trackId=self.session_id,
trackId=self.track_audio_out,
)
}, priority=0)
@@ -1455,6 +1756,7 @@ class DuplexPipeline:
async def _stop_current_speech(self) -> None:
"""Stop any current speech task."""
self._drop_outbound_audio = True
self._audio_out_frame_buffer = b""
if self._current_turn_task and not self._current_turn_task.done():
self._interrupt_event.set()
self._current_turn_task.cancel()

View File

@@ -1,15 +1,16 @@
"""Session management for active calls."""
import asyncio
import uuid
import hashlib
import json
import time
import re
import time
from enum import Enum
from typing import Optional, Dict, Any, List
from loguru import logger
from app.backend_client import (
fetch_assistant_config,
create_history_call_record,
add_history_transcript,
finalize_history_call_record,
@@ -49,6 +50,32 @@ class Session:
Uses full duplex voice conversation pipeline.
"""
TRACK_AUDIO_IN = "audio_in"
TRACK_AUDIO_OUT = "audio_out"
TRACK_CONTROL = "control"
AUDIO_FRAME_BYTES = 640 # 16k mono pcm_s16le, 20ms
_CLIENT_METADATA_OVERRIDES = {
"firstTurnMode",
"greeting",
"generatedOpenerEnabled",
"systemPrompt",
"output",
"bargeIn",
"knowledge",
"knowledgeBaseId",
"history",
"userId",
"assistantId",
"source",
}
_CLIENT_METADATA_ID_KEYS = {
"appId",
"app_id",
"channel",
"configVersionId",
"config_version_id",
}
def __init__(self, session_id: str, transport: BaseTransport, use_duplex: bool = None):
"""
Initialize session.
@@ -78,7 +105,10 @@ class Session:
self.authenticated: bool = False
# Track IDs
self.current_track_id: Optional[str] = str(uuid.uuid4())
self.current_track_id: str = self.TRACK_CONTROL
self._event_seq: int = 0
self._audio_ingress_buffer: bytes = b""
self._audio_frame_error_reported: bool = False
self._history_call_id: Optional[str] = None
self._history_turn_index: int = 0
self._history_call_started_mono: Optional[float] = None
@@ -89,6 +119,7 @@ class Session:
self._workflow_last_user_text: str = ""
self._workflow_initial_node: Optional[WorkflowNodeDef] = None
self.pipeline.set_event_sequence_provider(self._next_event_seq)
self.pipeline.conversation.on_turn_complete(self._on_turn_complete)
logger.info(f"Session {self.id} created (duplex={self.use_duplex})")
@@ -129,13 +160,52 @@ class Session:
"client",
"Audio received before session.start",
"protocol.order",
stage="protocol",
retryable=False,
)
return
try:
await self.pipeline.process_audio(audio_bytes)
if not audio_bytes:
return
if len(audio_bytes) % 2 != 0:
await self._send_error(
"client",
"Invalid PCM payload: odd number of bytes",
"audio.invalid_pcm",
stage="audio",
retryable=False,
)
return
frame_bytes = self.AUDIO_FRAME_BYTES
self._audio_ingress_buffer += audio_bytes
# Protocol v1 audio framing: 20ms PCM frame (640 bytes).
# Allow aggregated frames in one WS message (multiple of 640).
if len(audio_bytes) % frame_bytes != 0 and not self._audio_frame_error_reported:
self._audio_frame_error_reported = True
await self._send_error(
"client",
f"Audio frame size should be multiple of {frame_bytes} bytes (20ms PCM)",
"audio.frame_size_mismatch",
stage="audio",
retryable=True,
)
while len(self._audio_ingress_buffer) >= frame_bytes:
frame = self._audio_ingress_buffer[:frame_bytes]
self._audio_ingress_buffer = self._audio_ingress_buffer[frame_bytes:]
await self.pipeline.process_audio(frame)
except Exception as e:
logger.error(f"Session {self.id} handle_audio error: {e}", exc_info=True)
await self._send_error(
"server",
f"Audio processing failed: {e}",
"audio.processing_failed",
stage="audio",
retryable=True,
)
async def _handle_v1_message(self, message: Any) -> None:
"""Route validated WS v1 message to handlers."""
@@ -217,10 +287,9 @@ class Session:
self.authenticated = True
self.protocol_version = message.version
self.ws_state = WsSessionState.WAIT_START
await self.transport.send_event(
await self._send_event(
ev(
"hello.ack",
sessionId=self.id,
version=self.protocol_version,
)
)
@@ -231,8 +300,12 @@ class Session:
await self._send_error("client", "Duplicate session.start", "protocol.order")
return
metadata = message.metadata or {}
metadata = self._merge_runtime_metadata(metadata, self._bootstrap_workflow(metadata))
raw_metadata = message.metadata or {}
workflow_runtime = self._bootstrap_workflow(raw_metadata)
server_runtime = await self._load_server_runtime_metadata(raw_metadata, workflow_runtime)
client_runtime = self._sanitize_client_metadata(raw_metadata)
metadata = self._merge_runtime_metadata(server_runtime, self._sanitize_untrusted_runtime_metadata(workflow_runtime))
metadata = self._merge_runtime_metadata(metadata, client_runtime)
# Create history call record early so later turn callbacks can append transcripts.
await self._start_history_bridge(metadata)
@@ -248,28 +321,37 @@ class Session:
self.state = "accepted"
self.ws_state = WsSessionState.ACTIVE
await self.transport.send_event(
await self._send_event(
ev(
"session.started",
sessionId=self.id,
trackId=self.current_track_id,
tracks={
"audio_in": self.TRACK_AUDIO_IN,
"audio_out": self.TRACK_AUDIO_OUT,
"control": self.TRACK_CONTROL,
},
audio=message.audio or {},
)
)
await self._send_event(
ev(
"config.resolved",
trackId=self.TRACK_CONTROL,
config=self._build_config_resolved(metadata),
)
)
if self.workflow_runner and self._workflow_initial_node:
await self.transport.send_event(
await self._send_event(
ev(
"workflow.started",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id,
workflowName=self.workflow_runner.name,
nodeId=self._workflow_initial_node.id,
)
)
await self.transport.send_event(
await self._send_event(
ev(
"workflow.node.entered",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id,
nodeId=self._workflow_initial_node.id,
nodeName=self._workflow_initial_node.name,
@@ -285,17 +367,23 @@ class Session:
stop_reason = reason or "client_requested"
self.state = "hungup"
self.ws_state = WsSessionState.STOPPED
await self.transport.send_event(
await self._send_event(
ev(
"session.stopped",
sessionId=self.id,
reason=stop_reason,
)
)
await self._finalize_history(status="connected")
await self.transport.close()
async def _send_error(self, sender: str, error_message: str, code: str) -> None:
async def _send_error(
self,
sender: str,
error_message: str,
code: str,
stage: Optional[str] = None,
retryable: Optional[bool] = None,
) -> None:
"""
Send error event to client.
@@ -304,13 +392,25 @@ class Session:
error_message: Error message
code: Machine-readable error code
"""
await self.transport.send_event(
resolved_stage = stage or self._infer_error_stage(code)
resolved_retryable = retryable if retryable is not None else (resolved_stage in {"asr", "llm", "tts", "tool", "audio"})
await self._send_event(
ev(
"error",
sender=sender,
code=code,
message=error_message,
stage=resolved_stage,
retryable=resolved_retryable,
trackId=self.current_track_id,
data={
"error": {
"stage": resolved_stage,
"code": code,
"message": error_message,
"retryable": resolved_retryable,
}
},
)
)
@@ -483,10 +583,9 @@ class Session:
node = transition.node
edge = transition.edge
await self.transport.send_event(
await self._send_event(
ev(
"workflow.edge.taken",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id,
edgeId=edge.id,
fromNodeId=edge.from_node_id,
@@ -494,10 +593,9 @@ class Session:
reason=reason,
)
)
await self.transport.send_event(
await self._send_event(
ev(
"workflow.node.entered",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id,
nodeId=node.id,
nodeName=node.name,
@@ -510,10 +608,9 @@ class Session:
self.pipeline.apply_runtime_overrides(node_runtime)
if node.node_type == "tool":
await self.transport.send_event(
await self._send_event(
ev(
"workflow.tool.requested",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id,
nodeId=node.id,
tool=node.tool or {},
@@ -522,10 +619,9 @@ class Session:
return
if node.node_type == "human_transfer":
await self.transport.send_event(
await self._send_event(
ev(
"workflow.human_transfer",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id,
nodeId=node.id,
)
@@ -534,16 +630,68 @@ class Session:
return
if node.node_type == "end":
await self.transport.send_event(
await self._send_event(
ev(
"workflow.ended",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id,
nodeId=node.id,
)
)
await self._handle_session_stop("workflow_end")
def _next_event_seq(self) -> int:
    """Advance and return the session-wide monotonic event sequence number."""
    next_value = self._event_seq + 1
    self._event_seq = next_value
    return next_value
def _event_source(self, event_type: str) -> str:
    """Return the source label for a session-level event.

    Every event emitted directly by the session layer (workflow.*,
    session.*, hello.*, heartbeat, error) originates from the server
    itself, so the previous prefix chain returned "system" from every
    branch; the dead branches are collapsed here. Pipeline events carry
    their own, more specific source via the pipeline's envelope.

    Args:
        event_type: The event type string (kept for interface parity with
            the pipeline's ``_event_source``).
    """
    return "system"
def _infer_error_stage(self, code: str) -> str:
    """Derive the pipeline stage from an error code prefix.

    Codes like ``audio.frame_size_mismatch`` map to their stage by prefix;
    anything unrecognized is reported as a protocol-level error.
    """
    normalized = str(code or "").strip().lower()
    for stage in ("audio", "tool", "asr", "llm", "tts"):
        if normalized.startswith(stage + "."):
            return stage
    return "protocol"
def _envelope_event(self, event: Dict[str, Any]) -> Dict[str, Any]:
    """Normalize a session-level event into the protocol-v1 envelope (in place).

    Stamps sessionId/seq/source/trackId and folds any extra top-level keys
    into a ``data`` dict without overwriting caller-provided values.
    """
    envelope_keys = {"type", "timestamp", "sessionId", "seq", "source", "trackId", "data"}
    event_type = str(event.get("type") or "")
    source = str(event.get("source") or self._event_source(event_type))
    track_id = event.get("trackId") or self.TRACK_CONTROL
    data = event.get("data")
    if not isinstance(data, dict):
        data = {}
    for key, value in event.items():
        if key not in envelope_keys:
            data.setdefault(key, value)
    event["sessionId"] = self.id
    event["seq"] = self._next_event_seq()
    event["source"] = source
    event["trackId"] = track_id
    event["data"] = data
    return event
async def _send_event(self, event: Dict[str, Any]) -> None:
    """Envelope a session-level event and forward it over the transport."""
    enveloped = self._envelope_event(event)
    await self.transport.send_event(enveloped)
async def send_heartbeat(self) -> None:
    """Emit a heartbeat event on the control track."""
    heartbeat = ev("heartbeat", trackId=self.TRACK_CONTROL)
    await self._send_event(heartbeat)
async def _workflow_llm_route(
self,
node: WorkflowNodeDef,
@@ -629,6 +777,100 @@ class Session:
merged[key] = value
return merged
async def _load_server_runtime_metadata(
    self,
    client_metadata: Dict[str, Any],
    workflow_runtime: Dict[str, Any],
) -> Dict[str, Any]:
    """Load trusted runtime metadata from backend assistant config.

    The assistant ID is resolved from workflow runtime first, then from the
    client metadata aliases (assistantId / appId / app_id). Returns ``{}``
    when no backend URL is configured or the fetched payload is unusable.
    """
    assistant_id = (
        workflow_runtime.get("assistantId")
        or client_metadata.get("assistantId")
        or client_metadata.get("appId")
        or client_metadata.get("app_id")
    )
    if assistant_id is None or not settings.backend_url:
        return {}
    payload = await fetch_assistant_config(str(assistant_id).strip())
    if not isinstance(payload, dict):
        return {}
    # Config may be nested under "assistant" or be the payload itself.
    nested = payload.get("assistant")
    assistant_cfg = nested if isinstance(nested, dict) else payload
    if not isinstance(assistant_cfg, dict):
        return {}
    runtime: Dict[str, Any] = {}
    # systemPrompt/prompt and greeting/opener are alias pairs; first non-None wins.
    prompt_value = assistant_cfg.get("systemPrompt")
    if prompt_value is None:
        prompt_value = assistant_cfg.get("prompt")
    if prompt_value is not None:
        runtime["systemPrompt"] = str(prompt_value or "")
    greeting_value = assistant_cfg.get("greeting")
    if greeting_value is None:
        greeting_value = assistant_cfg.get("opener")
    if greeting_value is not None:
        runtime["greeting"] = greeting_value
    services = assistant_cfg.get("services")
    if isinstance(services, dict):
        runtime["services"] = services
    tools = assistant_cfg.get("tools")
    if isinstance(tools, list):
        runtime["tools"] = tools
    runtime["assistantId"] = str(assistant_id)
    return runtime
def _sanitize_untrusted_runtime_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
"""
Sanitize untrusted metadata sources.
This keeps only a small override whitelist and stable config ID fields.
"""
if not isinstance(metadata, dict):
return {}
sanitized: Dict[str, Any] = {}
for key in self._CLIENT_METADATA_ID_KEYS:
if key in metadata:
sanitized[key] = metadata[key]
for key in self._CLIENT_METADATA_OVERRIDES:
if key in metadata:
sanitized[key] = metadata[key]
return sanitized
def _sanitize_client_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
"""Apply client metadata whitelist and remove forbidden secrets."""
sanitized = self._sanitize_untrusted_runtime_metadata(metadata)
if isinstance(metadata.get("services"), dict):
logger.warning(
"Session {} provided metadata.services from client; client-side service config is ignored",
self.id,
)
return sanitized
def _build_config_resolved(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
"""Build public resolved config payload (secrets removed)."""
system_prompt = str(metadata.get("systemPrompt") or self.pipeline.conversation.system_prompt or "")
prompt_hash = hashlib.sha256(system_prompt.encode("utf-8")).hexdigest() if system_prompt else None
runtime = self.pipeline.resolved_runtime_config()
return {
"appId": metadata.get("appId") or metadata.get("app_id") or metadata.get("assistantId"),
"channel": metadata.get("channel"),
"configVersionId": metadata.get("configVersionId") or metadata.get("config_version_id"),
"prompt": {"sha256": prompt_hash},
"output": runtime.get("output", {}),
"services": runtime.get("services", {}),
"tools": runtime.get("tools", {}),
"tracks": {
"audio_in": self.TRACK_AUDIO_IN,
"audio_out": self.TRACK_AUDIO_OUT,
"control": self.TRACK_CONTROL,
},
}
def _extract_json_obj(self, text: str) -> Optional[Dict[str, Any]]:
"""Best-effort extraction of a JSON object from freeform text."""
try:

View File

@@ -52,43 +52,26 @@ Rules:
"channels": 1
},
"metadata": {
"appId": "assistant_123",
"channel": "web",
"configVersionId": "cfg_20260217_01",
"client": "web-debug",
"output": {
"mode": "audio"
},
"systemPrompt": "You are concise.",
"greeting": "Hi, how can I help?",
"services": {
"llm": {
"provider": "openai",
"model": "gpt-4o-mini",
"apiKey": "sk-...",
"baseUrl": "https://api.openai.com/v1"
},
"asr": {
"provider": "openai_compatible",
"model": "FunAudioLLM/SenseVoiceSmall",
"apiKey": "sf-...",
"interimIntervalMs": 500,
"minAudioMs": 300
},
"tts": {
"enabled": true,
"provider": "openai_compatible",
"model": "FunAudioLLM/CosyVoice2-0.5B",
"apiKey": "sf-...",
"voice": "anna",
"speed": 1.0
}
}
"greeting": "Hi, how can I help?"
}
}
```
Rules:
- Client-side `metadata.services` is ignored.
- Service config (including secrets) is resolved server-side (env/backend).
- Client should pass stable IDs (`appId`, `channel`, `configVersionId`) plus small runtime overrides (e.g. `output`, `bargeIn`, greeting/prompt style hints).
Text-only mode:
- Set `metadata.output.mode = "text"`.
- In this mode server still sends `assistant.response.delta/final`, but will not emit audio frames or `output.audio.start/end`.
### `input.text`
@@ -121,6 +104,7 @@ Text-only mode:
### `tool_call.results`
Client tool execution results returned to server.
Only needed when `assistant.tool_call.executor == "client"` (default execution is server-side).
```json
{
@@ -138,21 +122,35 @@ Client tool execution results returned to server.
## Server -> Client Events
All server events include an envelope:
```json
{
"type": "event.name",
"timestamp": 1730000000000
"timestamp": 1730000000000,
"sessionId": "sess_xxx",
"seq": 42,
"source": "asr",
"trackId": "audio_in",
"data": {}
}
```
Envelope notes:
- `seq` is monotonically increasing within one session (for replay/resume).
- `source` is one of: `asr | llm | tts | tool | system`.
- `data` is structured payload; legacy top-level fields are kept for compatibility.
Common events:
- `hello.ack`
- Fields: `sessionId`, `version`
- `session.started`
- Fields: `sessionId`, `trackId`, `tracks`, `audio`
- `config.resolved`
- Fields: `sessionId`, `trackId`, `config`
- Sent immediately after `session.started`.
- Contains effective model/voice/output/tool allowlist/prompt hash, and never includes secrets.
- `session.stopped`
- Fields: `sessionId`, `reason`
- `heartbeat`
@@ -169,9 +167,10 @@ Common events:
- `assistant.response.final`
- Fields: `trackId`, `text`
- `assistant.tool_call`
- Fields: `trackId`, `tool_call`, `tool_call_id`, `tool_name`, `arguments`, `executor`, `timeout_ms` (`executor` is `client` or `server`)
- `assistant.tool_result`
- Fields: `trackId`, `source`, `result`, `tool_call_id`, `tool_name`, `ok`, `error`
- `error`: `{ code, message, retryable }` when `ok=false`
- `output.audio.start`
- Fields: `trackId`
- `output.audio.end`
@@ -183,15 +182,49 @@ Common events:
- `error`
- Fields: `sender`, `code`, `message`, `trackId`
Track IDs (MVP fixed values):
- `audio_in`: ASR/VAD input-side events (`input.*`, `transcript.*`)
- `audio_out`: assistant output-side events (`assistant.*`, `output.audio.*`, `response.interrupted`, `metrics.ttfb`)
- `control`: session/control events (`session.*`, `hello.*`, `error`, `config.resolved`)
Correlation IDs (`event.data`):
- `turn_id`: one user-assistant interaction turn.
- `utterance_id`: one ASR final utterance.
- `response_id`: one assistant response generation.
- `tool_call_id`: one tool invocation.
- `tts_id`: one TTS playback segment.
## Binary Audio Frames
After `session.started`, client may send binary PCM chunks continuously.
Recommended format:
- 16-bit signed little-endian PCM.
- 1 channel.
- 16000 Hz.
- 20ms frames (640 bytes) preferred.
MVP fixed format:
- 16-bit signed little-endian PCM (`pcm_s16le`)
- mono (1 channel)
- 16000 Hz
- 20ms frame = 640 bytes
Framing rules:
- Binary audio frame unit is 640 bytes.
- A WS binary message may carry one or multiple complete 640-byte frames.
- Non-640-multiple payloads are treated as `audio.frame_size_mismatch` protocol errors.
TTS boundary events:
- `output.audio.start` and `output.audio.end` mark assistant playback boundaries.
## Event Throttling
To keep client rendering and server load stable, v1 applies/recommends:
- `transcript.delta`: merge to ~200-500ms cadence (server default: 300ms).
- `assistant.response.delta`: merge to ~50-100ms cadence (server default: 80ms).
- Metrics streams (if enabled beyond `metrics.ttfb`): emit every ~500-1000ms.
## Error Structure
`error` keeps legacy top-level fields (`code`, `message`) and adds structured info:
- `stage`: `protocol | asr | llm | tts | tool | audio`
- `retryable`: boolean
- `data.error`: `{ stage, code, message, retryable }`
## Compatibility

View File

@@ -59,8 +59,12 @@ class MicrophoneClient:
url: str,
sample_rate: int = 16000,
chunk_duration_ms: int = 20,
app_id: str = "assistant_demo",
channel: str = "mic_client",
config_version_id: str = "local-dev",
input_device: int = None,
output_device: int = None
output_device: int = None,
track_debug: bool = False,
):
"""
Initialize microphone client.
@@ -76,8 +80,12 @@ class MicrophoneClient:
self.sample_rate = sample_rate
self.chunk_duration_ms = chunk_duration_ms
self.chunk_samples = int(sample_rate * chunk_duration_ms / 1000)
self.app_id = app_id
self.channel = channel
self.config_version_id = config_version_id
self.input_device = input_device
self.output_device = output_device
self.track_debug = track_debug
# WebSocket connection
self.ws = None
@@ -106,6 +114,17 @@ class MicrophoneClient:
# Verbose mode for streaming LLM responses
self.verbose = False
@staticmethod
def _event_ids_suffix(event: dict) -> str:
data = event.get("data") if isinstance(event.get("data"), dict) else {}
keys = ("turn_id", "utterance_id", "response_id", "tool_call_id", "tts_id")
parts = []
for key in keys:
value = data.get(key, event.get(key))
if value:
parts.append(f"{key}={value}")
return f" [{' '.join(parts)}]" if parts else ""
async def connect(self) -> None:
"""Connect to WebSocket server."""
@@ -114,20 +133,30 @@ class MicrophoneClient:
self.running = True
print("Connected!")
# Send invite command
# WS v1 handshake: hello -> session.start
await self.send_command({
"command": "invite",
"option": {
"codec": "pcm",
"sampleRate": self.sample_rate
}
"type": "hello",
"version": "v1",
})
await self.send_command({
"type": "session.start",
"audio": {
"encoding": "pcm_s16le",
"sample_rate_hz": self.sample_rate,
"channels": 1,
},
"metadata": {
"appId": self.app_id,
"channel": self.channel,
"configVersionId": self.config_version_id,
},
})
async def send_command(self, cmd: dict) -> None:
"""Send JSON command to server."""
if self.ws:
await self.ws.send(json.dumps(cmd))
print(f"→ Command: {cmd.get('command', 'unknown')}")
print(f"→ Command: {cmd.get('type', 'unknown')}")
async def send_chat(self, text: str) -> None:
"""Send chat message (text input)."""
@@ -136,7 +165,7 @@ class MicrophoneClient:
self.first_audio_received = False
await self.send_command({
"command": "chat",
"type": "input.text",
"text": text
})
print(f"→ Chat: {text}")
@@ -144,13 +173,14 @@ class MicrophoneClient:
async def send_interrupt(self) -> None:
"""Send interrupt command."""
await self.send_command({
"command": "interrupt"
"type": "response.cancel",
"graceful": False,
})
async def send_hangup(self, reason: str = "User quit") -> None:
"""Send hangup command."""
await self.send_command({
"command": "hangup",
"type": "session.stop",
"reason": reason
})
@@ -295,43 +325,48 @@ class MicrophoneClient:
async def _handle_event(self, event: dict) -> None:
"""Handle incoming event."""
event_type = event.get("event", "unknown")
event_type = event.get("type", event.get("event", "unknown"))
ids = self._event_ids_suffix(event)
if self.track_debug:
print(f"[track-debug] event={event_type} trackId={event.get('trackId')}{ids}")
if event_type == "answer":
print("← Session ready!")
elif event_type == "speaking":
print("User speech detected")
elif event_type == "silence":
print("← User silence detected")
elif event_type == "transcript":
if event_type in {"hello.ack", "session.started"}:
print(f"← Session ready!{ids}")
elif event_type == "config.resolved":
print(f"Config resolved: {event.get('config', {}).get('output', {})}{ids}")
elif event_type == "input.speech_started":
print(f"← User speech detected{ids}")
elif event_type == "input.speech_stopped":
print(f"← User silence detected{ids}")
elif event_type in {"transcript", "transcript.delta", "transcript.final"}:
# Display user speech transcription
text = event.get("text", "")
is_final = event.get("isFinal", False)
is_final = event_type == "transcript.final" or bool(event.get("isFinal"))
if is_final:
# Clear the interim line and print final
print(" " * 80, end="\r") # Clear previous interim text
print(f"→ You: {text}")
print(f"→ You: {text}{ids}")
else:
# Interim result - show with indicator (overwrite same line)
display_text = text[:60] + "..." if len(text) > 60 else text
print(f" [listening] {display_text}".ljust(80), end="\r")
elif event_type == "ttfb":
elif event_type in {"ttfb", "metrics.ttfb"}:
# Server-side TTFB event
latency_ms = event.get("latencyMs", 0)
print(f"← [TTFB] Server reported latency: {latency_ms}ms")
elif event_type == "llmResponse":
elif event_type in {"llmResponse", "assistant.response.delta", "assistant.response.final"}:
# LLM text response
text = event.get("text", "")
is_final = event.get("isFinal", False)
is_final = event_type == "assistant.response.final" or bool(event.get("isFinal"))
if is_final:
# Print final LLM response
print(f"← AI: {text}")
elif self.verbose:
# Show streaming chunks only in verbose mode
display_text = text[:60] + "..." if len(text) > 60 else text
print(f" [streaming] {display_text}")
elif event_type == "trackStart":
print("← Bot started speaking")
print(f" [streaming] {display_text}{ids}")
elif event_type in {"trackStart", "output.audio.start"}:
print(f"← Bot started speaking{ids}")
# IMPORTANT: Accept audio again after trackStart
self._discard_audio = False
self._audio_sequence += 1
@@ -342,13 +377,13 @@ class MicrophoneClient:
# Clear any old audio in buffer
with self.audio_output_lock:
self.audio_output_buffer = b""
elif event_type == "trackEnd":
print("← Bot finished speaking")
elif event_type in {"trackEnd", "output.audio.end"}:
print(f"← Bot finished speaking{ids}")
# Reset TTFB tracking after response completes
self.request_start_time = None
self.first_audio_received = False
elif event_type == "interrupt":
print("← Bot interrupted!")
elif event_type in {"interrupt", "response.interrupted"}:
print(f"← Bot interrupted!{ids}")
# IMPORTANT: Discard all audio until next trackStart
self._discard_audio = True
# Clear audio buffer immediately
@@ -357,12 +392,12 @@ class MicrophoneClient:
self.audio_output_buffer = b""
print(f" (cleared {buffer_ms:.0f}ms, discarding audio until new track)")
elif event_type == "error":
print(f"← Error: {event.get('error')}")
elif event_type == "hangup":
print(f"← Hangup: {event.get('reason')}")
print(f"← Error: {event.get('error')}{ids}")
elif event_type in {"hangup", "session.stopped"}:
print(f"← Hangup: {event.get('reason')}{ids}")
self.running = False
else:
print(f"← Event: {event_type}")
print(f"← Event: {event_type}{ids}")
async def interactive_mode(self) -> None:
"""Run interactive mode for text chat."""
@@ -573,6 +608,26 @@ async def main():
action="store_true",
help="Show streaming LLM response chunks"
)
parser.add_argument(
"--app-id",
default="assistant_demo",
help="Stable app/assistant identifier for server-side config lookup"
)
parser.add_argument(
"--channel",
default="mic_client",
help="Client channel name"
)
parser.add_argument(
"--config-version-id",
default="local-dev",
help="Optional config version identifier"
)
parser.add_argument(
"--track-debug",
action="store_true",
help="Print event trackId for protocol debugging"
)
args = parser.parse_args()
@@ -583,8 +638,12 @@ async def main():
client = MicrophoneClient(
url=args.url,
sample_rate=args.sample_rate,
app_id=args.app_id,
channel=args.channel,
config_version_id=args.config_version_id,
input_device=args.input_device,
output_device=args.output_device
output_device=args.output_device,
track_debug=args.track_debug,
)
client.verbose = args.verbose

View File

@@ -52,9 +52,21 @@ if not PYAUDIO_AVAILABLE and not SD_AVAILABLE:
class SimpleVoiceClient:
"""Simple voice client with reliable audio playback."""
def __init__(self, url: str, sample_rate: int = 16000):
def __init__(
self,
url: str,
sample_rate: int = 16000,
app_id: str = "assistant_demo",
channel: str = "simple_client",
config_version_id: str = "local-dev",
track_debug: bool = False,
):
self.url = url
self.sample_rate = sample_rate
self.app_id = app_id
self.channel = channel
self.config_version_id = config_version_id
self.track_debug = track_debug
self.ws = None
self.running = False
@@ -75,6 +87,17 @@ class SimpleVoiceClient:
# Interrupt handling - discard audio until next trackStart
self._discard_audio = False
@staticmethod
def _event_ids_suffix(event: dict) -> str:
data = event.get("data") if isinstance(event.get("data"), dict) else {}
keys = ("turn_id", "utterance_id", "response_id", "tool_call_id", "tts_id")
parts = []
for key in keys:
value = data.get(key, event.get(key))
if value:
parts.append(f"{key}={value}")
return f" [{' '.join(parts)}]" if parts else ""
async def connect(self):
"""Connect to server."""
@@ -83,12 +106,25 @@ class SimpleVoiceClient:
self.running = True
print("Connected!")
# Send invite
# WS v1 handshake: hello -> session.start
await self.ws.send(json.dumps({
"command": "invite",
"option": {"codec": "pcm", "sampleRate": self.sample_rate}
"type": "hello",
"version": "v1",
}))
print("-> invite")
await self.ws.send(json.dumps({
"type": "session.start",
"audio": {
"encoding": "pcm_s16le",
"sample_rate_hz": self.sample_rate,
"channels": 1,
},
"metadata": {
"appId": self.app_id,
"channel": self.channel,
"configVersionId": self.config_version_id,
},
}))
print("-> hello/session.start")
async def send_chat(self, text: str):
"""Send chat message."""
@@ -96,8 +132,8 @@ class SimpleVoiceClient:
self.request_start_time = time.time()
self.first_audio_received = False
await self.ws.send(json.dumps({"command": "chat", "text": text}))
print(f"-> chat: {text}")
await self.ws.send(json.dumps({"type": "input.text", "text": text}))
print(f"-> input.text: {text}")
def play_audio(self, audio_data: bytes):
"""Play audio data immediately."""
@@ -152,34 +188,39 @@ class SimpleVoiceClient:
else:
# JSON event
event = json.loads(msg)
etype = event.get("event", "?")
etype = event.get("type", event.get("event", "?"))
ids = self._event_ids_suffix(event)
if self.track_debug:
print(f"[track-debug] event={etype} trackId={event.get('trackId')}{ids}")
if etype == "transcript":
if etype in {"transcript", "transcript.delta", "transcript.final"}:
# User speech transcription
text = event.get("text", "")
is_final = event.get("isFinal", False)
is_final = etype == "transcript.final" or bool(event.get("isFinal"))
if is_final:
print(f"<- You said: {text}")
print(f"<- You said: {text}{ids}")
else:
print(f"<- [listening] {text}", end="\r")
elif etype == "ttfb":
elif etype in {"ttfb", "metrics.ttfb"}:
# Server-side TTFB event
latency_ms = event.get("latencyMs", 0)
print(f"<- [TTFB] Server reported latency: {latency_ms}ms")
elif etype == "trackStart":
elif etype in {"trackStart", "output.audio.start"}:
# New track starting - accept audio again
self._discard_audio = False
print(f"<- {etype}")
elif etype == "interrupt":
print(f"<- {etype}{ids}")
elif etype in {"interrupt", "response.interrupted"}:
# Interrupt - discard audio until next trackStart
self._discard_audio = True
print(f"<- {etype} (discarding audio until new track)")
elif etype == "hangup":
print(f"<- {etype}")
print(f"<- {etype}{ids} (discarding audio until new track)")
elif etype in {"hangup", "session.stopped"}:
print(f"<- {etype}{ids}")
self.running = False
break
elif etype == "config.resolved":
print(f"<- config.resolved {event.get('config', {}).get('output', {})}{ids}")
else:
print(f"<- {etype}")
print(f"<- {etype}{ids}")
except asyncio.TimeoutError:
continue
@@ -270,6 +311,10 @@ async def main():
parser.add_argument("--text", help="Send text and play response")
parser.add_argument("--list-devices", action="store_true")
parser.add_argument("--sample-rate", type=int, default=16000)
parser.add_argument("--app-id", default="assistant_demo")
parser.add_argument("--channel", default="simple_client")
parser.add_argument("--config-version-id", default="local-dev")
parser.add_argument("--track-debug", action="store_true")
args = parser.parse_args()
@@ -277,7 +322,14 @@ async def main():
list_audio_devices()
return
client = SimpleVoiceClient(args.url, args.sample_rate)
client = SimpleVoiceClient(
args.url,
args.sample_rate,
app_id=args.app_id,
channel=args.channel,
config_version_id=args.config_version_id,
track_debug=args.track_debug,
)
await client.run(args.text)

View File

@@ -36,8 +36,18 @@ def generate_sine_wave(duration_ms=1000):
return audio_data
async def receive_loop(ws, ready_event: asyncio.Event):
async def receive_loop(ws, ready_event: asyncio.Event, track_debug: bool = False):
"""Listen for incoming messages from the server."""
def event_ids_suffix(data):
payload = data.get("data") if isinstance(data.get("data"), dict) else {}
keys = ("turn_id", "utterance_id", "response_id", "tool_call_id", "tts_id")
parts = []
for key in keys:
value = payload.get(key, data.get(key))
if value:
parts.append(f"{key}={value}")
return f" [{' '.join(parts)}]" if parts else ""
print("👂 Listening for server responses...")
async for msg in ws:
timestamp = datetime.now().strftime("%H:%M:%S")
@@ -46,7 +56,10 @@ async def receive_loop(ws, ready_event: asyncio.Event):
try:
data = json.loads(msg.data)
event_type = data.get('type', 'Unknown')
print(f"[{timestamp}] 📨 Event: {event_type} | {msg.data[:150]}...")
ids = event_ids_suffix(data)
print(f"[{timestamp}] 📨 Event: {event_type}{ids} | {msg.data[:150]}...")
if track_debug:
print(f"[{timestamp}] [track-debug] event={event_type} trackId={data.get('trackId')}{ids}")
if event_type == "session.started":
ready_event.set()
except json.JSONDecodeError:
@@ -113,7 +126,7 @@ async def send_sine_loop(ws):
print("\n✅ Finished streaming test audio.")
async def run_client(url, file_path=None, use_sine=False):
async def run_client(url, file_path=None, use_sine=False, track_debug: bool = False):
"""Run the WebSocket test client."""
session = aiohttp.ClientSession()
try:
@@ -121,7 +134,7 @@ async def run_client(url, file_path=None, use_sine=False):
async with session.ws_connect(url) as ws:
print("✅ Connected!")
session_ready = asyncio.Event()
recv_task = asyncio.create_task(receive_loop(ws, session_ready))
recv_task = asyncio.create_task(receive_loop(ws, session_ready, track_debug=track_debug))
# Send v1 hello + session.start handshake
await ws.send_json({"type": "hello", "version": "v1"})
@@ -131,7 +144,12 @@ async def run_client(url, file_path=None, use_sine=False):
"encoding": "pcm_s16le",
"sample_rate_hz": SAMPLE_RATE,
"channels": 1
}
},
"metadata": {
"appId": "assistant_demo",
"channel": "test_websocket",
"configVersionId": "local-dev",
},
})
print("📤 Sent v1 hello/session.start")
await asyncio.wait_for(session_ready.wait(), timeout=8)
@@ -168,9 +186,10 @@ if __name__ == "__main__":
parser.add_argument("--url", default=SERVER_URL, help="WebSocket endpoint URL")
parser.add_argument("--file", help="Path to PCM/WAV file to stream")
parser.add_argument("--sine", action="store_true", help="Use sine wave generation (default)")
parser.add_argument("--track-debug", action="store_true", help="Print event trackId for protocol debugging")
args = parser.parse_args()
try:
asyncio.run(run_client(args.url, args.file, args.sine))
asyncio.run(run_client(args.url, args.file, args.sine, args.track_debug))
except KeyboardInterrupt:
print("\n👋 Client stopped.")

View File

@@ -57,10 +57,15 @@ class WavFileClient:
url: str,
input_file: str,
output_file: str,
app_id: str = "assistant_demo",
channel: str = "wav_client",
config_version_id: str = "local-dev",
sample_rate: int = 16000,
chunk_duration_ms: int = 20,
wait_time: float = 15.0,
verbose: bool = False
verbose: bool = False,
track_debug: bool = False,
tail_silence_ms: int = 800,
):
"""
Initialize WAV file client.
@@ -77,11 +82,17 @@ class WavFileClient:
self.url = url
self.input_file = Path(input_file)
self.output_file = Path(output_file)
self.app_id = app_id
self.channel = channel
self.config_version_id = config_version_id
self.sample_rate = sample_rate
self.chunk_duration_ms = chunk_duration_ms
self.chunk_samples = int(sample_rate * chunk_duration_ms / 1000)
self.wait_time = wait_time
self.verbose = verbose
self.track_debug = track_debug
self.tail_silence_ms = max(0, int(tail_silence_ms))
self.frame_bytes = 640 # 16k mono pcm_s16le, 20ms
# WebSocket connection
self.ws = None
@@ -125,6 +136,17 @@ class WavFileClient:
# Replace problematic characters for console output
safe_message = message.encode('ascii', errors='replace').decode('ascii')
print(f"{direction} {safe_message}")
@staticmethod
def _event_ids_suffix(event: dict) -> str:
data = event.get("data") if isinstance(event.get("data"), dict) else {}
keys = ("turn_id", "utterance_id", "response_id", "tool_call_id", "tts_id")
parts = []
for key in keys:
value = data.get(key, event.get(key))
if value:
parts.append(f"{key}={value}")
return f" [{' '.join(parts)}]" if parts else ""
async def connect(self) -> None:
"""Connect to WebSocket server."""
@@ -144,7 +166,12 @@ class WavFileClient:
"encoding": "pcm_s16le",
"sample_rate_hz": self.sample_rate,
"channels": 1
}
},
"metadata": {
"appId": self.app_id,
"channel": self.channel,
"configVersionId": self.config_version_id,
},
})
async def send_command(self, cmd: dict) -> None:
@@ -216,6 +243,10 @@ class WavFileClient:
end_sample = min(sent_samples + chunk_size, total_samples)
chunk = audio_data[sent_samples:end_sample]
chunk_bytes = chunk.tobytes()
if len(chunk_bytes) % self.frame_bytes != 0:
# v1 audio framing requires 640-byte (20ms) PCM units.
pad = self.frame_bytes - (len(chunk_bytes) % self.frame_bytes)
chunk_bytes += b"\x00" * pad
# Send to server
if self.ws:
@@ -232,6 +263,16 @@ class WavFileClient:
# Delay to simulate real-time streaming
# Server expects audio at real-time pace for VAD/ASR to work properly
await asyncio.sleep(self.chunk_duration_ms / 1000)
# Add a short silence tail to help VAD/EOU close the final utterance.
if self.tail_silence_ms > 0 and self.ws:
tail_frames = max(1, self.tail_silence_ms // 20)
silence = b"\x00" * self.frame_bytes
for _ in range(tail_frames):
await self.ws.send(silence)
self.bytes_sent += len(silence)
await asyncio.sleep(0.02)
self.log_event("", f"Sent trailing silence: {self.tail_silence_ms}ms")
self.send_completed = True
elapsed = time.time() - self.send_start_time
@@ -284,16 +325,22 @@ class WavFileClient:
async def _handle_event(self, event: dict) -> None:
"""Handle incoming event."""
event_type = event.get("type", "unknown")
ids = self._event_ids_suffix(event)
if self.track_debug:
print(f"[track-debug] event={event_type} trackId={event.get('trackId')}{ids}")
if event_type == "hello.ack":
self.log_event("", "Handshake acknowledged")
self.log_event("", f"Handshake acknowledged{ids}")
elif event_type == "session.started":
self.session_ready = True
self.log_event("", "Session ready!")
self.log_event("", f"Session ready!{ids}")
elif event_type == "config.resolved":
config = event.get("config", {})
self.log_event("", f"Config resolved (output={config.get('output', {})}){ids}")
elif event_type == "input.speech_started":
self.log_event("", "Speech detected")
self.log_event("", f"Speech detected{ids}")
elif event_type == "input.speech_stopped":
self.log_event("", "Silence detected")
self.log_event("", f"Silence detected{ids}")
elif event_type == "transcript.delta":
text = event.get("text", "")
display_text = text[:60] + "..." if len(text) > 60 else text
@@ -301,35 +348,35 @@ class WavFileClient:
elif event_type == "transcript.final":
text = event.get("text", "")
print(" " * 80, end="\r")
self.log_event("", f"→ You: {text}")
self.log_event("", f"→ You: {text}{ids}")
elif event_type == "metrics.ttfb":
latency_ms = event.get("latencyMs", 0)
self.log_event("", f"[TTFB] Server latency: {latency_ms}ms")
elif event_type == "assistant.response.delta":
text = event.get("text", "")
if self.verbose and text:
self.log_event("", f"LLM: {text}")
self.log_event("", f"LLM: {text}{ids}")
elif event_type == "assistant.response.final":
text = event.get("text", "")
if text:
self.log_event("", f"LLM Response (final): {text[:100]}{'...' if len(text) > 100 else ''}")
self.log_event("", f"LLM Response (final): {text[:100]}{'...' if len(text) > 100 else ''}{ids}")
elif event_type == "output.audio.start":
self.track_started = True
self.response_start_time = time.time()
self.waiting_for_first_audio = True
self.log_event("", "Bot started speaking")
self.log_event("", f"Bot started speaking{ids}")
elif event_type == "output.audio.end":
self.track_ended = True
self.log_event("", "Bot finished speaking")
self.log_event("", f"Bot finished speaking{ids}")
elif event_type == "response.interrupted":
self.log_event("", "Bot interrupted!")
self.log_event("", f"Bot interrupted!{ids}")
elif event_type == "error":
self.log_event("!", f"Error: {event.get('message')}")
self.log_event("!", f"Error: {event.get('message')}{ids}")
elif event_type == "session.stopped":
self.log_event("", f"Session stopped: {event.get('reason')}")
self.log_event("", f"Session stopped: {event.get('reason')}{ids}")
self.running = False
else:
self.log_event("", f"Event: {event_type}")
self.log_event("", f"Event: {event_type}{ids}")
def save_output_wav(self) -> None:
"""Save received audio to output WAV file."""
@@ -473,6 +520,21 @@ async def main():
default=16000,
help="Target sample rate for audio (default: 16000)"
)
parser.add_argument(
"--app-id",
default="assistant_demo",
help="Stable app/assistant identifier for server-side config lookup"
)
parser.add_argument(
"--channel",
default="wav_client",
help="Client channel name"
)
parser.add_argument(
"--config-version-id",
default="local-dev",
help="Optional config version identifier"
)
parser.add_argument(
"--chunk-duration",
type=int,
@@ -490,6 +552,17 @@ async def main():
action="store_true",
help="Enable verbose output"
)
parser.add_argument(
"--track-debug",
action="store_true",
help="Print event trackId for protocol debugging"
)
parser.add_argument(
"--tail-silence-ms",
type=int,
default=800,
help="Trailing silence to send after WAV playback for EOU detection (default: 800)"
)
args = parser.parse_args()
@@ -497,10 +570,15 @@ async def main():
url=args.url,
input_file=args.input,
output_file=args.output,
app_id=args.app_id,
channel=args.channel,
config_version_id=args.config_version_id,
sample_rate=args.sample_rate,
chunk_duration_ms=args.chunk_duration,
wait_time=args.wait_time,
verbose=args.verbose
verbose=args.verbose,
track_debug=args.track_debug,
tail_silence_ms=args.tail_silence_ms,
)
await client.run()

View File

@@ -401,6 +401,9 @@
const targetSampleRate = 16000;
const playbackStopRampSec = 0.008;
const appId = "assistant_demo";
const channel = "web_client";
const configVersionId = "local-dev";
function logLine(type, text, data) {
const time = new Date().toLocaleTimeString();
@@ -604,15 +607,35 @@
logLine("sys", `${cmd.type}`, cmd);
}
function eventIdsSuffix(event) {
const data = event && typeof event.data === "object" && event.data ? event.data : {};
const keys = ["turn_id", "utterance_id", "response_id", "tool_call_id", "tts_id"];
const parts = [];
for (const key of keys) {
const value = data[key] || event[key];
if (value) parts.push(`${key}=${value}`);
}
return parts.length ? ` [${parts.join(" ")}]` : "";
}
function handleEvent(event) {
const type = event.type || "unknown";
logLine("event", type, event);
const ids = eventIdsSuffix(event);
logLine("event", `${type}${ids}`, event);
if (type === "hello.ack") {
sendCommand({
type: "session.start",
audio: { encoding: "pcm_s16le", sample_rate_hz: targetSampleRate, channels: 1 },
metadata: {
appId,
channel,
configVersionId,
},
});
}
if (type === "config.resolved") {
logLine("sys", "config.resolved", event.config || {});
}
if (type === "transcript.final") {
if (event.text) {
setInterim("You", "");