Update engine

This commit is contained in:
Xin Wang
2026-02-23 17:16:18 +08:00
parent 01c0de0a4d
commit c6c84b5af9
9 changed files with 991 additions and 186 deletions

View File

@@ -14,7 +14,8 @@ event-driven design.
import asyncio
import json
import time
from typing import Any, Dict, List, Optional, Tuple
import uuid
from typing import Any, Callable, Dict, List, Optional, Tuple
import numpy as np
from loguru import logger
@@ -59,6 +60,12 @@ class DuplexPipeline:
_MIN_SPLIT_SPOKEN_CHARS = 6
_TOOL_WAIT_TIMEOUT_SECONDS = 15.0
_SERVER_TOOL_TIMEOUT_SECONDS = 15.0
TRACK_AUDIO_IN = "audio_in"
TRACK_AUDIO_OUT = "audio_out"
TRACK_CONTROL = "control"
_PCM_FRAME_BYTES = 640 # 16k mono pcm_s16le, 20ms
_ASR_DELTA_THROTTLE_MS = 300
_LLM_DELTA_THROTTLE_MS = 80
_DEFAULT_TOOL_SCHEMAS: Dict[str, Dict[str, Any]] = {
"current_time": {
"name": "current_time",
@@ -96,6 +103,9 @@ class DuplexPipeline:
self.transport = transport
self.session_id = session_id
self.event_bus = get_event_bus()
self.track_audio_in = self.TRACK_AUDIO_IN
self.track_audio_out = self.TRACK_AUDIO_OUT
self.track_control = self.TRACK_CONTROL
# Initialize VAD
self.vad_model = SileroVAD(
@@ -120,6 +130,8 @@ class DuplexPipeline:
# Track last sent transcript to avoid duplicates
self._last_sent_transcript = ""
self._pending_transcript_delta: str = ""
self._last_transcript_delta_emit_ms: float = 0.0
# Conversation manager
self.conversation = ConversationManager(
@@ -153,6 +165,7 @@ class DuplexPipeline:
self._outbound_seq = 0
self._outbound_task: Optional[asyncio.Task] = None
self._drop_outbound_audio = False
self._audio_out_frame_buffer: bytes = b""
# Interruption handling
self._interrupt_event = asyncio.Event()
@@ -186,9 +199,28 @@ class DuplexPipeline:
self._pending_tool_waiters: Dict[str, asyncio.Future] = {}
self._early_tool_results: Dict[str, Dict[str, Any]] = {}
self._completed_tool_call_ids: set[str] = set()
self._pending_client_tool_call_ids: set[str] = set()
self._next_seq: Optional[Callable[[], int]] = None
self._local_seq: int = 0
# Cross-service correlation IDs
self._turn_count: int = 0
self._response_count: int = 0
self._tts_count: int = 0
self._utterance_count: int = 0
self._current_turn_id: Optional[str] = None
self._current_utterance_id: Optional[str] = None
self._current_response_id: Optional[str] = None
self._current_tts_id: Optional[str] = None
self._pending_llm_delta: str = ""
self._last_llm_delta_emit_ms: float = 0.0
logger.info(f"DuplexPipeline initialized for session {session_id}")
def set_event_sequence_provider(self, provider: Callable[[], int]) -> None:
"""Use session-scoped monotonic sequence provider for envelope events."""
self._next_seq = provider
def apply_runtime_overrides(self, metadata: Optional[Dict[str, Any]]) -> None:
"""
Apply runtime overrides from WS session.start metadata.
@@ -276,6 +308,131 @@ class DuplexPipeline:
if self.llm_service and hasattr(self.llm_service, "set_tool_schemas"):
self.llm_service.set_tool_schemas(self._resolved_tool_schemas())
def resolved_runtime_config(self) -> Dict[str, Any]:
    """Return current effective runtime configuration without secrets.

    Merges per-session runtime overrides (``self._runtime_*`` dicts) over
    global ``settings`` defaults so the client sees the values that will
    actually be used. No API keys or URLs with credentials are included.
    """
    llm_provider = str(self._runtime_llm.get("provider") or "openai").lower()
    tts_provider = str(self._runtime_tts.get("provider") or settings.tts_provider).lower()
    asr_provider = str(self._runtime_asr.get("provider") or settings.asr_provider).lower()
    output_mode = str(self._runtime_output.get("mode") or "").strip().lower()
    if not output_mode:
        # No explicit override: infer the mode from whether TTS output is enabled.
        output_mode = "audio" if self._tts_output_enabled() else "text"
    return {
        "output": {"mode": output_mode},
        "services": {
            "llm": {
                "provider": llm_provider,
                "model": str(self._runtime_llm.get("model") or settings.llm_model),
                "baseUrl": self._runtime_llm.get("baseUrl") or settings.openai_api_url,
            },
            "asr": {
                "provider": asr_provider,
                "model": str(self._runtime_asr.get("model") or settings.siliconflow_asr_model),
                "interimIntervalMs": int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms),
                "minAudioMs": int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms),
            },
            "tts": {
                "enabled": self._tts_output_enabled(),
                "provider": tts_provider,
                "model": str(self._runtime_tts.get("model") or settings.siliconflow_tts_model),
                "voice": str(self._runtime_tts.get("voice") or settings.tts_voice),
                "speed": float(self._runtime_tts.get("speed") or settings.tts_speed),
            },
        },
        "tools": {
            # Only expose the names of allowed tools, never their executors.
            "allowlist": sorted(self._runtime_tool_executor.keys()),
        },
        "tracks": {
            "audio_in": self.track_audio_in,
            "audio_out": self.track_audio_out,
            "control": self.track_control,
        },
    }
def _next_event_seq(self) -> int:
if self._next_seq:
return self._next_seq()
self._local_seq += 1
return self._local_seq
def _event_source(self, event_type: str) -> str:
if event_type.startswith("transcript.") or event_type.startswith("input.speech_"):
return "asr"
if event_type.startswith("assistant.response."):
return "llm"
if event_type.startswith("assistant.tool_"):
return "tool"
if event_type.startswith("output.audio.") or event_type == "metrics.ttfb":
return "tts"
return "system"
def _new_id(self, prefix: str, counter: int) -> str:
return f"{prefix}_{counter}_{uuid.uuid4().hex[:8]}"
def _start_turn(self) -> str:
self._turn_count += 1
self._current_turn_id = self._new_id("turn", self._turn_count)
self._current_utterance_id = None
self._current_response_id = None
self._current_tts_id = None
return self._current_turn_id
def _start_response(self) -> str:
self._response_count += 1
self._current_response_id = self._new_id("resp", self._response_count)
self._current_tts_id = None
return self._current_response_id
def _start_tts(self) -> str:
self._tts_count += 1
self._current_tts_id = self._new_id("tts", self._tts_count)
return self._current_tts_id
def _finalize_utterance(self) -> str:
if self._current_utterance_id:
return self._current_utterance_id
self._utterance_count += 1
self._current_utterance_id = self._new_id("utt", self._utterance_count)
if not self._current_turn_id:
self._start_turn()
return self._current_utterance_id
def _envelope_event(self, event: Dict[str, Any]) -> Dict[str, Any]:
event_type = str(event.get("type") or "")
source = str(event.get("source") or self._event_source(event_type))
track_id = event.get("trackId")
if not track_id:
if source == "asr":
track_id = self.track_audio_in
elif source in {"llm", "tts", "tool"}:
track_id = self.track_audio_out
else:
track_id = self.track_control
data = event.get("data")
if not isinstance(data, dict):
data = {}
if self._current_turn_id:
data.setdefault("turn_id", self._current_turn_id)
if self._current_utterance_id:
data.setdefault("utterance_id", self._current_utterance_id)
if self._current_response_id:
data.setdefault("response_id", self._current_response_id)
if self._current_tts_id:
data.setdefault("tts_id", self._current_tts_id)
for k, v in event.items():
if k in {"type", "timestamp", "sessionId", "seq", "source", "trackId", "data"}:
continue
data.setdefault(k, v)
event["sessionId"] = self.session_id
event["seq"] = self._next_event_seq()
event["source"] = source
event["trackId"] = track_id
event["data"] = data
return event
@staticmethod
def _coerce_bool(value: Any) -> Optional[bool]:
if isinstance(value, bool):
@@ -472,11 +629,13 @@ class DuplexPipeline:
greeting_to_speak = generated_greeting
self.conversation.greeting = generated_greeting
if greeting_to_speak:
self._start_turn()
self._start_response()
await self._send_event(
ev(
"assistant.response.final",
text=greeting_to_speak,
trackId=self.session_id,
trackId=self.track_audio_out,
),
priority=20,
)
@@ -494,10 +653,58 @@ class DuplexPipeline:
await self._outbound_q.put((priority, self._outbound_seq, kind, payload))
async def _send_event(self, event: Dict[str, Any], priority: int = 20) -> None:
await self._enqueue_outbound("event", event, priority)
await self._enqueue_outbound("event", self._envelope_event(event), priority)
async def _send_audio(self, pcm_bytes: bytes, priority: int = 50) -> None:
await self._enqueue_outbound("audio", pcm_bytes, priority)
if not pcm_bytes:
return
self._audio_out_frame_buffer += pcm_bytes
while len(self._audio_out_frame_buffer) >= self._PCM_FRAME_BYTES:
frame = self._audio_out_frame_buffer[: self._PCM_FRAME_BYTES]
self._audio_out_frame_buffer = self._audio_out_frame_buffer[self._PCM_FRAME_BYTES :]
await self._enqueue_outbound("audio", frame, priority)
async def _flush_audio_out_frames(self, priority: int = 50) -> None:
"""Flush remaining outbound audio as one padded 20ms PCM frame."""
if not self._audio_out_frame_buffer:
return
tail = self._audio_out_frame_buffer
self._audio_out_frame_buffer = b""
if len(tail) < self._PCM_FRAME_BYTES:
tail = tail + (b"\x00" * (self._PCM_FRAME_BYTES - len(tail)))
await self._enqueue_outbound("audio", tail, priority)
async def _emit_transcript_delta(self, text: str) -> None:
    """Send an interim-transcript delta on the audio-in track."""
    payload = ev(
        "transcript.delta",
        trackId=self.track_audio_in,
        text=text,
    )
    # Shallow-copy so the enveloping step never mutates ev()'s dict.
    await self._send_event(dict(payload), priority=30)
async def _emit_llm_delta(self, text: str) -> None:
    """Send an assistant streaming-text delta on the audio-out track."""
    payload = ev(
        "assistant.response.delta",
        trackId=self.track_audio_out,
        text=text,
    )
    # Shallow-copy so the enveloping step never mutates ev()'s dict.
    await self._send_event(dict(payload), priority=20)
async def _flush_pending_llm_delta(self) -> None:
if not self._pending_llm_delta:
return
chunk = self._pending_llm_delta
self._pending_llm_delta = ""
self._last_llm_delta_emit_ms = time.monotonic() * 1000.0
await self._emit_llm_delta(chunk)
async def _outbound_loop(self) -> None:
"""Single sender loop that enforces priority for interrupt events."""
@@ -546,13 +753,13 @@ class DuplexPipeline:
# Emit VAD event
await self.event_bus.publish(event_type, {
"trackId": self.session_id,
"trackId": self.track_audio_in,
"probability": probability
})
await self._send_event(
ev(
"input.speech_started" if event_type == "speaking" else "input.speech_stopped",
trackId=self.session_id,
trackId=self.track_audio_in,
probability=probability,
),
priority=30,
@@ -661,6 +868,9 @@ class DuplexPipeline:
# Cancel any current speaking
await self._stop_current_speech()
self._start_turn()
self._finalize_utterance()
# Start new turn
await self.conversation.end_user_turn(text)
self._current_turn_task = asyncio.create_task(self._handle_turn(text))
@@ -683,24 +893,45 @@ class DuplexPipeline:
if text == self._last_sent_transcript and not is_final:
return
now_ms = time.monotonic() * 1000.0
self._last_sent_transcript = text
# Send transcript event to client
await self._send_event({
**ev(
"transcript.final" if is_final else "transcript.delta",
trackId=self.session_id,
text=text,
if is_final:
self._pending_transcript_delta = ""
self._last_transcript_delta_emit_ms = 0.0
await self._send_event(
{
**ev(
"transcript.final",
trackId=self.track_audio_in,
text=text,
)
},
priority=30,
)
}, priority=30)
logger.debug(f"Sent transcript (final): {text[:50]}...")
return
self._pending_transcript_delta = text
should_emit = (
self._last_transcript_delta_emit_ms <= 0.0
or now_ms - self._last_transcript_delta_emit_ms >= self._ASR_DELTA_THROTTLE_MS
)
if should_emit and self._pending_transcript_delta:
delta = self._pending_transcript_delta
self._pending_transcript_delta = ""
self._last_transcript_delta_emit_ms = now_ms
await self._emit_transcript_delta(delta)
if not is_final:
logger.info(f"[ASR] ASR interim: {text[:100]}")
logger.debug(f"Sent transcript ({'final' if is_final else 'interim'}): {text[:50]}...")
logger.debug(f"Sent transcript (interim): {text[:50]}...")
async def _on_speech_start(self) -> None:
"""Handle user starting to speak."""
if self.conversation.state in (ConversationState.IDLE, ConversationState.INTERRUPTED):
self._start_turn()
self._finalize_utterance()
await self.conversation.start_user_turn()
self._audio_buffer = b""
self._last_sent_transcript = ""
@@ -779,6 +1010,7 @@ class DuplexPipeline:
return
logger.info(f"[EOU] Detected - user said: {user_text[:100]}...")
self._finalize_utterance()
# For ASR backends that already emitted final via callback,
# avoid duplicating transcript.final on EOU.
@@ -786,7 +1018,7 @@ class DuplexPipeline:
await self._send_event({
**ev(
"transcript.final",
trackId=self.session_id,
trackId=self.track_audio_in,
text=user_text,
)
}, priority=25)
@@ -794,6 +1026,8 @@ class DuplexPipeline:
# Clear buffers
self._audio_buffer = b""
self._last_sent_transcript = ""
self._pending_transcript_delta = ""
self._last_transcript_delta_emit_ms = 0.0
self._asr_capture_active = False
self._pending_speech_audio = b""
@@ -894,6 +1128,44 @@ class DuplexPipeline:
# Default to server execution unless explicitly marked as client.
return "server"
def _tool_arguments(self, tool_call: Dict[str, Any]) -> Dict[str, Any]:
fn = tool_call.get("function")
if not isinstance(fn, dict):
return {}
raw = fn.get("arguments")
if isinstance(raw, dict):
return raw
if isinstance(raw, str) and raw.strip():
try:
parsed = json.loads(raw)
return parsed if isinstance(parsed, dict) else {"raw": raw}
except Exception:
return {"raw": raw}
return {}
def _normalize_tool_result(self, result: Dict[str, Any]) -> Dict[str, Any]:
status = result.get("status") if isinstance(result.get("status"), dict) else {}
status_code = int(status.get("code") or 0) if status else 0
status_message = str(status.get("message") or "") if status else ""
tool_call_id = str(result.get("tool_call_id") or result.get("id") or "")
tool_name = str(result.get("name") or "unknown_tool")
ok = bool(200 <= status_code < 300)
retryable = status_code >= 500 or status_code in {429, 408}
error: Optional[Dict[str, Any]] = None
if not ok:
error = {
"code": status_code or 500,
"message": status_message or "tool_execution_failed",
"retryable": retryable,
}
return {
"tool_call_id": tool_call_id,
"tool_name": tool_name,
"ok": ok,
"error": error,
"status": {"code": status_code, "message": status_message},
}
async def _emit_tool_result(self, result: Dict[str, Any], source: str) -> None:
tool_name = str(result.get("name") or "unknown_tool")
call_id = str(result.get("tool_call_id") or result.get("id") or "")
@@ -904,12 +1176,17 @@ class DuplexPipeline:
f"[Tool] emit result source={source} name={tool_name} call_id={call_id} "
f"status={status_code} {status_message}".strip()
)
normalized = self._normalize_tool_result(result)
await self._send_event(
{
**ev(
"assistant.tool_result",
trackId=self.session_id,
trackId=self.track_audio_out,
source=source,
tool_call_id=normalized["tool_call_id"],
tool_name=normalized["tool_name"],
ok=normalized["ok"],
error=normalized["error"],
result=result,
)
},
@@ -927,6 +1204,9 @@ class DuplexPipeline:
call_id = str(item.get("tool_call_id") or item.get("id") or "").strip()
if not call_id:
continue
if self._pending_client_tool_call_ids and call_id not in self._pending_client_tool_call_ids:
logger.warning(f"[Tool] ignore unsolicited client result call_id={call_id}")
continue
if call_id in self._completed_tool_call_ids:
logger.debug(f"[Tool] ignore duplicate client result call_id={call_id}")
continue
@@ -972,6 +1252,7 @@ class DuplexPipeline:
}
finally:
self._pending_tool_waiters.pop(call_id, None)
self._pending_client_tool_call_ids.discard(call_id)
def _normalize_stream_event(self, item: Any) -> LLMStreamEvent:
if isinstance(item, LLMStreamEvent):
@@ -998,6 +1279,11 @@ class DuplexPipeline:
user_text: User's transcribed text
"""
try:
if not self._current_turn_id:
self._start_turn()
if not self._current_utterance_id:
self._finalize_utterance()
self._start_response()
# Start latency tracking
self._turn_start_time = time.time()
self._first_audio_sent = False
@@ -1012,6 +1298,8 @@ class DuplexPipeline:
self._drop_outbound_audio = False
first_audio_sent = False
self._pending_llm_delta = ""
self._last_llm_delta_emit_ms = 0.0
for _ in range(max_rounds):
if self._interrupt_event.is_set():
break
@@ -1028,6 +1316,7 @@ class DuplexPipeline:
event = self._normalize_stream_event(raw_event)
if event.type == "tool_call":
await self._flush_pending_llm_delta()
tool_call = event.tool_call if isinstance(event.tool_call, dict) else None
if not tool_call:
continue
@@ -1045,11 +1334,19 @@ class DuplexPipeline:
f"executor={executor} args={args_preview}"
)
tool_calls.append(enriched_tool_call)
tool_arguments = self._tool_arguments(enriched_tool_call)
if executor == "client" and call_id:
self._pending_client_tool_call_ids.add(call_id)
await self._send_event(
{
**ev(
"assistant.tool_call",
trackId=self.session_id,
trackId=self.track_audio_out,
tool_call_id=call_id,
tool_name=tool_name,
arguments=tool_arguments,
executor=executor,
timeout_ms=int(self._TOOL_WAIT_TIMEOUT_SECONDS * 1000),
tool_call=enriched_tool_call,
)
},
@@ -1071,19 +1368,13 @@ class DuplexPipeline:
round_response += text_chunk
sentence_buffer += text_chunk
await self.conversation.update_assistant_text(text_chunk)
await self._send_event(
{
**ev(
"assistant.response.delta",
trackId=self.session_id,
text=text_chunk,
)
},
# Keep delta/final on the same event priority so FIFO seq
# preserves stream order (avoid late-delta after final).
priority=20,
)
self._pending_llm_delta += text_chunk
now_ms = time.monotonic() * 1000.0
if (
self._last_llm_delta_emit_ms <= 0.0
or now_ms - self._last_llm_delta_emit_ms >= self._LLM_DELTA_THROTTLE_MS
):
await self._flush_pending_llm_delta()
while True:
split_result = extract_tts_sentence(
@@ -1112,11 +1403,12 @@ class DuplexPipeline:
if self._tts_output_enabled() and not self._interrupt_event.is_set():
if not first_audio_sent:
self._start_tts()
await self._send_event(
{
**ev(
"output.audio.start",
trackId=self.session_id,
trackId=self.track_audio_out,
)
},
priority=10,
@@ -1130,6 +1422,7 @@ class DuplexPipeline:
)
remaining_text = f"{pending_punctuation}{sentence_buffer}".strip()
await self._flush_pending_llm_delta()
if (
self._tts_output_enabled()
and remaining_text
@@ -1137,11 +1430,12 @@ class DuplexPipeline:
and not self._interrupt_event.is_set()
):
if not first_audio_sent:
self._start_tts()
await self._send_event(
{
**ev(
"output.audio.start",
trackId=self.session_id,
trackId=self.track_audio_out,
)
},
priority=10,
@@ -1204,11 +1498,12 @@ class DuplexPipeline:
]
if full_response and not self._interrupt_event.is_set():
await self._flush_pending_llm_delta()
await self._send_event(
{
**ev(
"assistant.response.final",
trackId=self.session_id,
trackId=self.track_audio_out,
text=full_response,
)
},
@@ -1217,10 +1512,11 @@ class DuplexPipeline:
# Send track end
if first_audio_sent:
await self._flush_audio_out_frames(priority=50)
await self._send_event({
**ev(
"output.audio.end",
trackId=self.session_id,
trackId=self.track_audio_out,
)
}, priority=10)
@@ -1241,6 +1537,8 @@ class DuplexPipeline:
self._barge_in_speech_start_time = None
self._barge_in_speech_frames = 0
self._barge_in_silence_frames = 0
self._current_response_id = None
self._current_tts_id = None
async def _speak_sentence(self, text: str, fade_in_ms: int = 0, fade_out_ms: int = 8) -> None:
"""
@@ -1277,7 +1575,7 @@ class DuplexPipeline:
await self._send_event({
**ev(
"metrics.ttfb",
trackId=self.session_id,
trackId=self.track_audio_out,
latencyMs=round(ttfb_ms),
)
}, priority=25)
@@ -1354,10 +1652,11 @@ class DuplexPipeline:
first_audio_sent = False
# Send track start event
self._start_tts()
await self._send_event({
**ev(
"output.audio.start",
trackId=self.session_id,
trackId=self.track_audio_out,
)
}, priority=10)
@@ -1379,7 +1678,7 @@ class DuplexPipeline:
await self._send_event({
**ev(
"metrics.ttfb",
trackId=self.session_id,
trackId=self.track_audio_out,
latencyMs=round(ttfb_ms),
)
}, priority=25)
@@ -1391,10 +1690,11 @@ class DuplexPipeline:
await asyncio.sleep(0.01)
# Send track end event
await self._flush_audio_out_frames(priority=50)
await self._send_event({
**ev(
"output.audio.end",
trackId=self.session_id,
trackId=self.track_audio_out,
)
}, priority=10)
@@ -1422,13 +1722,14 @@ class DuplexPipeline:
self._interrupt_event.set()
self._is_bot_speaking = False
self._drop_outbound_audio = True
self._audio_out_frame_buffer = b""
# Send interrupt event to client IMMEDIATELY
# This must happen BEFORE canceling services, so client knows to discard in-flight audio
await self._send_event({
**ev(
"response.interrupted",
trackId=self.session_id,
trackId=self.track_audio_out,
)
}, priority=0)
@@ -1455,6 +1756,7 @@ class DuplexPipeline:
async def _stop_current_speech(self) -> None:
"""Stop any current speech task."""
self._drop_outbound_audio = True
self._audio_out_frame_buffer = b""
if self._current_turn_task and not self._current_turn_task.done():
self._interrupt_event.set()
self._current_turn_task.cancel()

View File

@@ -1,15 +1,16 @@
"""Session management for active calls."""
import asyncio
import uuid
import hashlib
import json
import time
import re
import time
from enum import Enum
from typing import Optional, Dict, Any, List
from loguru import logger
from app.backend_client import (
fetch_assistant_config,
create_history_call_record,
add_history_transcript,
finalize_history_call_record,
@@ -49,6 +50,32 @@ class Session:
Uses full duplex voice conversation pipeline.
"""
TRACK_AUDIO_IN = "audio_in"
TRACK_AUDIO_OUT = "audio_out"
TRACK_CONTROL = "control"
AUDIO_FRAME_BYTES = 640 # 16k mono pcm_s16le, 20ms
_CLIENT_METADATA_OVERRIDES = {
"firstTurnMode",
"greeting",
"generatedOpenerEnabled",
"systemPrompt",
"output",
"bargeIn",
"knowledge",
"knowledgeBaseId",
"history",
"userId",
"assistantId",
"source",
}
_CLIENT_METADATA_ID_KEYS = {
"appId",
"app_id",
"channel",
"configVersionId",
"config_version_id",
}
def __init__(self, session_id: str, transport: BaseTransport, use_duplex: bool = None):
"""
Initialize session.
@@ -78,7 +105,10 @@ class Session:
self.authenticated: bool = False
# Track IDs
self.current_track_id: Optional[str] = str(uuid.uuid4())
self.current_track_id: str = self.TRACK_CONTROL
self._event_seq: int = 0
self._audio_ingress_buffer: bytes = b""
self._audio_frame_error_reported: bool = False
self._history_call_id: Optional[str] = None
self._history_turn_index: int = 0
self._history_call_started_mono: Optional[float] = None
@@ -89,6 +119,7 @@ class Session:
self._workflow_last_user_text: str = ""
self._workflow_initial_node: Optional[WorkflowNodeDef] = None
self.pipeline.set_event_sequence_provider(self._next_event_seq)
self.pipeline.conversation.on_turn_complete(self._on_turn_complete)
logger.info(f"Session {self.id} created (duplex={self.use_duplex})")
@@ -129,13 +160,52 @@ class Session:
"client",
"Audio received before session.start",
"protocol.order",
stage="protocol",
retryable=False,
)
return
try:
await self.pipeline.process_audio(audio_bytes)
if not audio_bytes:
return
if len(audio_bytes) % 2 != 0:
await self._send_error(
"client",
"Invalid PCM payload: odd number of bytes",
"audio.invalid_pcm",
stage="audio",
retryable=False,
)
return
frame_bytes = self.AUDIO_FRAME_BYTES
self._audio_ingress_buffer += audio_bytes
# Protocol v1 audio framing: 20ms PCM frame (640 bytes).
# Allow aggregated frames in one WS message (multiple of 640).
if len(audio_bytes) % frame_bytes != 0 and not self._audio_frame_error_reported:
self._audio_frame_error_reported = True
await self._send_error(
"client",
f"Audio frame size should be multiple of {frame_bytes} bytes (20ms PCM)",
"audio.frame_size_mismatch",
stage="audio",
retryable=True,
)
while len(self._audio_ingress_buffer) >= frame_bytes:
frame = self._audio_ingress_buffer[:frame_bytes]
self._audio_ingress_buffer = self._audio_ingress_buffer[frame_bytes:]
await self.pipeline.process_audio(frame)
except Exception as e:
logger.error(f"Session {self.id} handle_audio error: {e}", exc_info=True)
await self._send_error(
"server",
f"Audio processing failed: {e}",
"audio.processing_failed",
stage="audio",
retryable=True,
)
async def _handle_v1_message(self, message: Any) -> None:
"""Route validated WS v1 message to handlers."""
@@ -217,10 +287,9 @@ class Session:
self.authenticated = True
self.protocol_version = message.version
self.ws_state = WsSessionState.WAIT_START
await self.transport.send_event(
await self._send_event(
ev(
"hello.ack",
sessionId=self.id,
version=self.protocol_version,
)
)
@@ -231,8 +300,12 @@ class Session:
await self._send_error("client", "Duplicate session.start", "protocol.order")
return
metadata = message.metadata or {}
metadata = self._merge_runtime_metadata(metadata, self._bootstrap_workflow(metadata))
raw_metadata = message.metadata or {}
workflow_runtime = self._bootstrap_workflow(raw_metadata)
server_runtime = await self._load_server_runtime_metadata(raw_metadata, workflow_runtime)
client_runtime = self._sanitize_client_metadata(raw_metadata)
metadata = self._merge_runtime_metadata(server_runtime, self._sanitize_untrusted_runtime_metadata(workflow_runtime))
metadata = self._merge_runtime_metadata(metadata, client_runtime)
# Create history call record early so later turn callbacks can append transcripts.
await self._start_history_bridge(metadata)
@@ -248,28 +321,37 @@ class Session:
self.state = "accepted"
self.ws_state = WsSessionState.ACTIVE
await self.transport.send_event(
await self._send_event(
ev(
"session.started",
sessionId=self.id,
trackId=self.current_track_id,
tracks={
"audio_in": self.TRACK_AUDIO_IN,
"audio_out": self.TRACK_AUDIO_OUT,
"control": self.TRACK_CONTROL,
},
audio=message.audio or {},
)
)
await self._send_event(
ev(
"config.resolved",
trackId=self.TRACK_CONTROL,
config=self._build_config_resolved(metadata),
)
)
if self.workflow_runner and self._workflow_initial_node:
await self.transport.send_event(
await self._send_event(
ev(
"workflow.started",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id,
workflowName=self.workflow_runner.name,
nodeId=self._workflow_initial_node.id,
)
)
await self.transport.send_event(
await self._send_event(
ev(
"workflow.node.entered",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id,
nodeId=self._workflow_initial_node.id,
nodeName=self._workflow_initial_node.name,
@@ -285,17 +367,23 @@ class Session:
stop_reason = reason or "client_requested"
self.state = "hungup"
self.ws_state = WsSessionState.STOPPED
await self.transport.send_event(
await self._send_event(
ev(
"session.stopped",
sessionId=self.id,
reason=stop_reason,
)
)
await self._finalize_history(status="connected")
await self.transport.close()
async def _send_error(self, sender: str, error_message: str, code: str) -> None:
async def _send_error(
self,
sender: str,
error_message: str,
code: str,
stage: Optional[str] = None,
retryable: Optional[bool] = None,
) -> None:
"""
Send error event to client.
@@ -304,13 +392,25 @@ class Session:
error_message: Error message
code: Machine-readable error code
"""
await self.transport.send_event(
resolved_stage = stage or self._infer_error_stage(code)
resolved_retryable = retryable if retryable is not None else (resolved_stage in {"asr", "llm", "tts", "tool", "audio"})
await self._send_event(
ev(
"error",
sender=sender,
code=code,
message=error_message,
stage=resolved_stage,
retryable=resolved_retryable,
trackId=self.current_track_id,
data={
"error": {
"stage": resolved_stage,
"code": code,
"message": error_message,
"retryable": resolved_retryable,
}
},
)
)
@@ -483,10 +583,9 @@ class Session:
node = transition.node
edge = transition.edge
await self.transport.send_event(
await self._send_event(
ev(
"workflow.edge.taken",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id,
edgeId=edge.id,
fromNodeId=edge.from_node_id,
@@ -494,10 +593,9 @@ class Session:
reason=reason,
)
)
await self.transport.send_event(
await self._send_event(
ev(
"workflow.node.entered",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id,
nodeId=node.id,
nodeName=node.name,
@@ -510,10 +608,9 @@ class Session:
self.pipeline.apply_runtime_overrides(node_runtime)
if node.node_type == "tool":
await self.transport.send_event(
await self._send_event(
ev(
"workflow.tool.requested",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id,
nodeId=node.id,
tool=node.tool or {},
@@ -522,10 +619,9 @@ class Session:
return
if node.node_type == "human_transfer":
await self.transport.send_event(
await self._send_event(
ev(
"workflow.human_transfer",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id,
nodeId=node.id,
)
@@ -534,16 +630,68 @@ class Session:
return
if node.node_type == "end":
await self.transport.send_event(
await self._send_event(
ev(
"workflow.ended",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id,
nodeId=node.id,
)
)
await self._handle_session_stop("workflow_end")
def _next_event_seq(self) -> int:
self._event_seq += 1
return self._event_seq
def _event_source(self, event_type: str) -> str:
if event_type.startswith("workflow."):
return "system"
if event_type.startswith("session.") or event_type.startswith("hello.") or event_type == "heartbeat":
return "system"
if event_type == "error":
return "system"
return "system"
def _infer_error_stage(self, code: str) -> str:
normalized = str(code or "").strip().lower()
if normalized.startswith("audio."):
return "audio"
if normalized.startswith("tool."):
return "tool"
if normalized.startswith("asr."):
return "asr"
if normalized.startswith("llm."):
return "llm"
if normalized.startswith("tts."):
return "tts"
return "protocol"
def _envelope_event(self, event: Dict[str, Any]) -> Dict[str, Any]:
event_type = str(event.get("type") or "")
source = str(event.get("source") or self._event_source(event_type))
track_id = event.get("trackId") or self.TRACK_CONTROL
data = event.get("data")
if not isinstance(data, dict):
data = {}
for k, v in event.items():
if k in {"type", "timestamp", "sessionId", "seq", "source", "trackId", "data"}:
continue
data.setdefault(k, v)
event["sessionId"] = self.id
event["seq"] = self._next_event_seq()
event["source"] = source
event["trackId"] = track_id
event["data"] = data
return event
async def _send_event(self, event: Dict[str, Any]) -> None:
await self.transport.send_event(self._envelope_event(event))
async def send_heartbeat(self) -> None:
    """Emit a keepalive heartbeat event on the control track."""
    await self._send_event(ev("heartbeat", trackId=self.TRACK_CONTROL))
async def _workflow_llm_route(
self,
node: WorkflowNodeDef,
@@ -629,6 +777,100 @@ class Session:
merged[key] = value
return merged
async def _load_server_runtime_metadata(
    self,
    client_metadata: Dict[str, Any],
    workflow_runtime: Dict[str, Any],
) -> Dict[str, Any]:
    """Load trusted runtime metadata from backend assistant config.

    Args:
        client_metadata: Raw, untrusted metadata from the client's session.start.
        workflow_runtime: Metadata from workflow bootstrap; its assistantId
            takes precedence over client-supplied identifiers.

    Returns:
        A runtime-override dict (systemPrompt/greeting/services/tools plus
        assistantId), or {} when no assistant ID is available, no backend
        URL is configured, or the backend response is unusable.
    """
    # Workflow-derived ID wins; fall back to client-provided identifiers.
    assistant_id = (
        workflow_runtime.get("assistantId")
        or client_metadata.get("assistantId")
        or client_metadata.get("appId")
        or client_metadata.get("app_id")
    )
    if assistant_id is None:
        return {}
    if not settings.backend_url:
        return {}
    payload = await fetch_assistant_config(str(assistant_id).strip())
    if not isinstance(payload, dict):
        return {}
    # The backend may wrap the config under "assistant" or return it flat.
    assistant_cfg = payload.get("assistant") if isinstance(payload.get("assistant"), dict) else payload
    if not isinstance(assistant_cfg, dict):
        return {}
    runtime: Dict[str, Any] = {}
    # Prefer newer field names; fall back to legacy aliases (prompt/opener).
    if assistant_cfg.get("systemPrompt") is not None:
        runtime["systemPrompt"] = str(assistant_cfg.get("systemPrompt") or "")
    elif assistant_cfg.get("prompt") is not None:
        runtime["systemPrompt"] = str(assistant_cfg.get("prompt") or "")
    if assistant_cfg.get("greeting") is not None:
        runtime["greeting"] = assistant_cfg.get("greeting")
    elif assistant_cfg.get("opener") is not None:
        runtime["greeting"] = assistant_cfg.get("opener")
    if isinstance(assistant_cfg.get("services"), dict):
        runtime["services"] = assistant_cfg.get("services")
    if isinstance(assistant_cfg.get("tools"), list):
        runtime["tools"] = assistant_cfg.get("tools")
    runtime["assistantId"] = str(assistant_id)
    return runtime
def _sanitize_untrusted_runtime_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
"""
Sanitize untrusted metadata sources.
This keeps only a small override whitelist and stable config ID fields.
"""
if not isinstance(metadata, dict):
return {}
sanitized: Dict[str, Any] = {}
for key in self._CLIENT_METADATA_ID_KEYS:
if key in metadata:
sanitized[key] = metadata[key]
for key in self._CLIENT_METADATA_OVERRIDES:
if key in metadata:
sanitized[key] = metadata[key]
return sanitized
def _sanitize_client_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
    """Apply client metadata whitelist and remove forbidden secrets."""
    sanitized = self._sanitize_untrusted_runtime_metadata(metadata)
    client_services = metadata.get("services")
    if isinstance(client_services, dict):
        # Clients must never dictate service config (could inject keys/endpoints).
        logger.warning(
            "Session {} provided metadata.services from client; client-side service config is ignored",
            self.id,
        )
    return sanitized
def _build_config_resolved(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
"""Build public resolved config payload (secrets removed)."""
system_prompt = str(metadata.get("systemPrompt") or self.pipeline.conversation.system_prompt or "")
prompt_hash = hashlib.sha256(system_prompt.encode("utf-8")).hexdigest() if system_prompt else None
runtime = self.pipeline.resolved_runtime_config()
return {
"appId": metadata.get("appId") or metadata.get("app_id") or metadata.get("assistantId"),
"channel": metadata.get("channel"),
"configVersionId": metadata.get("configVersionId") or metadata.get("config_version_id"),
"prompt": {"sha256": prompt_hash},
"output": runtime.get("output", {}),
"services": runtime.get("services", {}),
"tools": runtime.get("tools", {}),
"tracks": {
"audio_in": self.TRACK_AUDIO_IN,
"audio_out": self.TRACK_AUDIO_OUT,
"control": self.TRACK_CONTROL,
},
}
def _extract_json_obj(self, text: str) -> Optional[Dict[str, Any]]:
"""Best-effort extraction of a JSON object from freeform text."""
try: