Update engine: add event envelopes (session-scoped seq, source, trackId), cross-service correlation IDs (turn/utterance/response/TTS), outbound PCM 20ms framing, and ASR/LLM delta throttling

This commit is contained in:
Xin Wang
2026-02-23 17:16:18 +08:00
parent 01c0de0a4d
commit c6c84b5af9
9 changed files with 991 additions and 186 deletions

View File

@@ -14,7 +14,8 @@ event-driven design.
import asyncio
import json
import time
from typing import Any, Dict, List, Optional, Tuple
import uuid
from typing import Any, Callable, Dict, List, Optional, Tuple
import numpy as np
from loguru import logger
@@ -59,6 +60,12 @@ class DuplexPipeline:
_MIN_SPLIT_SPOKEN_CHARS = 6
_TOOL_WAIT_TIMEOUT_SECONDS = 15.0
_SERVER_TOOL_TIMEOUT_SECONDS = 15.0
TRACK_AUDIO_IN = "audio_in"
TRACK_AUDIO_OUT = "audio_out"
TRACK_CONTROL = "control"
_PCM_FRAME_BYTES = 640 # 16k mono pcm_s16le, 20ms
_ASR_DELTA_THROTTLE_MS = 300
_LLM_DELTA_THROTTLE_MS = 80
_DEFAULT_TOOL_SCHEMAS: Dict[str, Dict[str, Any]] = {
"current_time": {
"name": "current_time",
@@ -96,6 +103,9 @@ class DuplexPipeline:
self.transport = transport
self.session_id = session_id
self.event_bus = get_event_bus()
self.track_audio_in = self.TRACK_AUDIO_IN
self.track_audio_out = self.TRACK_AUDIO_OUT
self.track_control = self.TRACK_CONTROL
# Initialize VAD
self.vad_model = SileroVAD(
@@ -120,6 +130,8 @@ class DuplexPipeline:
# Track last sent transcript to avoid duplicates
self._last_sent_transcript = ""
self._pending_transcript_delta: str = ""
self._last_transcript_delta_emit_ms: float = 0.0
# Conversation manager
self.conversation = ConversationManager(
@@ -153,6 +165,7 @@ class DuplexPipeline:
self._outbound_seq = 0
self._outbound_task: Optional[asyncio.Task] = None
self._drop_outbound_audio = False
self._audio_out_frame_buffer: bytes = b""
# Interruption handling
self._interrupt_event = asyncio.Event()
@@ -186,9 +199,28 @@ class DuplexPipeline:
self._pending_tool_waiters: Dict[str, asyncio.Future] = {}
self._early_tool_results: Dict[str, Dict[str, Any]] = {}
self._completed_tool_call_ids: set[str] = set()
self._pending_client_tool_call_ids: set[str] = set()
self._next_seq: Optional[Callable[[], int]] = None
self._local_seq: int = 0
# Cross-service correlation IDs
self._turn_count: int = 0
self._response_count: int = 0
self._tts_count: int = 0
self._utterance_count: int = 0
self._current_turn_id: Optional[str] = None
self._current_utterance_id: Optional[str] = None
self._current_response_id: Optional[str] = None
self._current_tts_id: Optional[str] = None
self._pending_llm_delta: str = ""
self._last_llm_delta_emit_ms: float = 0.0
logger.info(f"DuplexPipeline initialized for session {session_id}")
def set_event_sequence_provider(self, provider: Callable[[], int]) -> None:
"""Use session-scoped monotonic sequence provider for envelope events."""
self._next_seq = provider
def apply_runtime_overrides(self, metadata: Optional[Dict[str, Any]]) -> None:
"""
Apply runtime overrides from WS session.start metadata.
@@ -276,6 +308,131 @@ class DuplexPipeline:
if self.llm_service and hasattr(self.llm_service, "set_tool_schemas"):
self.llm_service.set_tool_schemas(self._resolved_tool_schemas())
def resolved_runtime_config(self) -> Dict[str, Any]:
    """Return the currently effective runtime configuration (secrets excluded).

    Merges per-session runtime overrides with global ``settings`` defaults
    and reports the resolved output mode, per-service parameters, the tool
    allowlist, and the logical track identifiers.
    """
    rt_llm = self._runtime_llm
    rt_asr = self._runtime_asr
    rt_tts = self._runtime_tts

    mode = str(self._runtime_output.get("mode") or "").strip().lower()
    if not mode:
        # No explicit mode from the client: derive it from the TTS toggle.
        mode = "audio" if self._tts_output_enabled() else "text"

    llm_section = {
        "provider": str(rt_llm.get("provider") or "openai").lower(),
        "model": str(rt_llm.get("model") or settings.llm_model),
        "baseUrl": rt_llm.get("baseUrl") or settings.openai_api_url,
    }
    asr_section = {
        "provider": str(rt_asr.get("provider") or settings.asr_provider).lower(),
        "model": str(rt_asr.get("model") or settings.siliconflow_asr_model),
        "interimIntervalMs": int(rt_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms),
        "minAudioMs": int(rt_asr.get("minAudioMs") or settings.asr_min_audio_ms),
    }
    tts_section = {
        "enabled": self._tts_output_enabled(),
        "provider": str(rt_tts.get("provider") or settings.tts_provider).lower(),
        "model": str(rt_tts.get("model") or settings.siliconflow_tts_model),
        "voice": str(rt_tts.get("voice") or settings.tts_voice),
        "speed": float(rt_tts.get("speed") or settings.tts_speed),
    }

    return {
        "output": {"mode": mode},
        "services": {"llm": llm_section, "asr": asr_section, "tts": tts_section},
        "tools": {"allowlist": sorted(self._runtime_tool_executor.keys())},
        "tracks": {
            "audio_in": self.track_audio_in,
            "audio_out": self.track_audio_out,
            "control": self.track_control,
        },
    }
def _next_event_seq(self) -> int:
if self._next_seq:
return self._next_seq()
self._local_seq += 1
return self._local_seq
def _event_source(self, event_type: str) -> str:
if event_type.startswith("transcript.") or event_type.startswith("input.speech_"):
return "asr"
if event_type.startswith("assistant.response."):
return "llm"
if event_type.startswith("assistant.tool_"):
return "tool"
if event_type.startswith("output.audio.") or event_type == "metrics.ttfb":
return "tts"
return "system"
def _new_id(self, prefix: str, counter: int) -> str:
return f"{prefix}_{counter}_{uuid.uuid4().hex[:8]}"
def _start_turn(self) -> str:
self._turn_count += 1
self._current_turn_id = self._new_id("turn", self._turn_count)
self._current_utterance_id = None
self._current_response_id = None
self._current_tts_id = None
return self._current_turn_id
def _start_response(self) -> str:
self._response_count += 1
self._current_response_id = self._new_id("resp", self._response_count)
self._current_tts_id = None
return self._current_response_id
def _start_tts(self) -> str:
self._tts_count += 1
self._current_tts_id = self._new_id("tts", self._tts_count)
return self._current_tts_id
def _finalize_utterance(self) -> str:
if self._current_utterance_id:
return self._current_utterance_id
self._utterance_count += 1
self._current_utterance_id = self._new_id("utt", self._utterance_count)
if not self._current_turn_id:
self._start_turn()
return self._current_utterance_id
def _envelope_event(self, event: Dict[str, Any]) -> Dict[str, Any]:
event_type = str(event.get("type") or "")
source = str(event.get("source") or self._event_source(event_type))
track_id = event.get("trackId")
if not track_id:
if source == "asr":
track_id = self.track_audio_in
elif source in {"llm", "tts", "tool"}:
track_id = self.track_audio_out
else:
track_id = self.track_control
data = event.get("data")
if not isinstance(data, dict):
data = {}
if self._current_turn_id:
data.setdefault("turn_id", self._current_turn_id)
if self._current_utterance_id:
data.setdefault("utterance_id", self._current_utterance_id)
if self._current_response_id:
data.setdefault("response_id", self._current_response_id)
if self._current_tts_id:
data.setdefault("tts_id", self._current_tts_id)
for k, v in event.items():
if k in {"type", "timestamp", "sessionId", "seq", "source", "trackId", "data"}:
continue
data.setdefault(k, v)
event["sessionId"] = self.session_id
event["seq"] = self._next_event_seq()
event["source"] = source
event["trackId"] = track_id
event["data"] = data
return event
@staticmethod
def _coerce_bool(value: Any) -> Optional[bool]:
if isinstance(value, bool):
@@ -472,11 +629,13 @@ class DuplexPipeline:
greeting_to_speak = generated_greeting
self.conversation.greeting = generated_greeting
if greeting_to_speak:
self._start_turn()
self._start_response()
await self._send_event(
ev(
"assistant.response.final",
text=greeting_to_speak,
trackId=self.session_id,
trackId=self.track_audio_out,
),
priority=20,
)
@@ -494,10 +653,58 @@ class DuplexPipeline:
await self._outbound_q.put((priority, self._outbound_seq, kind, payload))
async def _send_event(self, event: Dict[str, Any], priority: int = 20) -> None:
await self._enqueue_outbound("event", event, priority)
await self._enqueue_outbound("event", self._envelope_event(event), priority)
async def _send_audio(self, pcm_bytes: bytes, priority: int = 50) -> None:
await self._enqueue_outbound("audio", pcm_bytes, priority)
if not pcm_bytes:
return
self._audio_out_frame_buffer += pcm_bytes
while len(self._audio_out_frame_buffer) >= self._PCM_FRAME_BYTES:
frame = self._audio_out_frame_buffer[: self._PCM_FRAME_BYTES]
self._audio_out_frame_buffer = self._audio_out_frame_buffer[self._PCM_FRAME_BYTES :]
await self._enqueue_outbound("audio", frame, priority)
async def _flush_audio_out_frames(self, priority: int = 50) -> None:
"""Flush remaining outbound audio as one padded 20ms PCM frame."""
if not self._audio_out_frame_buffer:
return
tail = self._audio_out_frame_buffer
self._audio_out_frame_buffer = b""
if len(tail) < self._PCM_FRAME_BYTES:
tail = tail + (b"\x00" * (self._PCM_FRAME_BYTES - len(tail)))
await self._enqueue_outbound("audio", tail, priority)
async def _emit_transcript_delta(self, text: str) -> None:
    """Send one throttled interim-transcript chunk on the inbound audio track."""
    # Copy the ev() payload so the envelope step mutates a fresh dict.
    payload = {
        **ev(
            "transcript.delta",
            trackId=self.track_audio_in,
            text=text,
        )
    }
    await self._send_event(payload, priority=30)
async def _emit_llm_delta(self, text: str) -> None:
    """Send one throttled assistant-response chunk on the outbound audio track."""
    # Copy the ev() payload so the envelope step mutates a fresh dict.
    payload = {
        **ev(
            "assistant.response.delta",
            trackId=self.track_audio_out,
            text=text,
        )
    }
    await self._send_event(payload, priority=20)
async def _flush_pending_llm_delta(self) -> None:
if not self._pending_llm_delta:
return
chunk = self._pending_llm_delta
self._pending_llm_delta = ""
self._last_llm_delta_emit_ms = time.monotonic() * 1000.0
await self._emit_llm_delta(chunk)
async def _outbound_loop(self) -> None:
"""Single sender loop that enforces priority for interrupt events."""
@@ -546,13 +753,13 @@ class DuplexPipeline:
# Emit VAD event
await self.event_bus.publish(event_type, {
"trackId": self.session_id,
"trackId": self.track_audio_in,
"probability": probability
})
await self._send_event(
ev(
"input.speech_started" if event_type == "speaking" else "input.speech_stopped",
trackId=self.session_id,
trackId=self.track_audio_in,
probability=probability,
),
priority=30,
@@ -661,6 +868,9 @@ class DuplexPipeline:
# Cancel any current speaking
await self._stop_current_speech()
self._start_turn()
self._finalize_utterance()
# Start new turn
await self.conversation.end_user_turn(text)
self._current_turn_task = asyncio.create_task(self._handle_turn(text))
@@ -683,24 +893,45 @@ class DuplexPipeline:
if text == self._last_sent_transcript and not is_final:
return
now_ms = time.monotonic() * 1000.0
self._last_sent_transcript = text
# Send transcript event to client
await self._send_event({
**ev(
"transcript.final" if is_final else "transcript.delta",
trackId=self.session_id,
text=text,
if is_final:
self._pending_transcript_delta = ""
self._last_transcript_delta_emit_ms = 0.0
await self._send_event(
{
**ev(
"transcript.final",
trackId=self.track_audio_in,
text=text,
)
},
priority=30,
)
}, priority=30)
logger.debug(f"Sent transcript (final): {text[:50]}...")
return
self._pending_transcript_delta = text
should_emit = (
self._last_transcript_delta_emit_ms <= 0.0
or now_ms - self._last_transcript_delta_emit_ms >= self._ASR_DELTA_THROTTLE_MS
)
if should_emit and self._pending_transcript_delta:
delta = self._pending_transcript_delta
self._pending_transcript_delta = ""
self._last_transcript_delta_emit_ms = now_ms
await self._emit_transcript_delta(delta)
if not is_final:
logger.info(f"[ASR] ASR interim: {text[:100]}")
logger.debug(f"Sent transcript ({'final' if is_final else 'interim'}): {text[:50]}...")
logger.debug(f"Sent transcript (interim): {text[:50]}...")
async def _on_speech_start(self) -> None:
"""Handle user starting to speak."""
if self.conversation.state in (ConversationState.IDLE, ConversationState.INTERRUPTED):
self._start_turn()
self._finalize_utterance()
await self.conversation.start_user_turn()
self._audio_buffer = b""
self._last_sent_transcript = ""
@@ -779,6 +1010,7 @@ class DuplexPipeline:
return
logger.info(f"[EOU] Detected - user said: {user_text[:100]}...")
self._finalize_utterance()
# For ASR backends that already emitted final via callback,
# avoid duplicating transcript.final on EOU.
@@ -786,7 +1018,7 @@ class DuplexPipeline:
await self._send_event({
**ev(
"transcript.final",
trackId=self.session_id,
trackId=self.track_audio_in,
text=user_text,
)
}, priority=25)
@@ -794,6 +1026,8 @@ class DuplexPipeline:
# Clear buffers
self._audio_buffer = b""
self._last_sent_transcript = ""
self._pending_transcript_delta = ""
self._last_transcript_delta_emit_ms = 0.0
self._asr_capture_active = False
self._pending_speech_audio = b""
@@ -894,6 +1128,44 @@ class DuplexPipeline:
# Default to server execution unless explicitly marked as client.
return "server"
def _tool_arguments(self, tool_call: Dict[str, Any]) -> Dict[str, Any]:
fn = tool_call.get("function")
if not isinstance(fn, dict):
return {}
raw = fn.get("arguments")
if isinstance(raw, dict):
return raw
if isinstance(raw, str) and raw.strip():
try:
parsed = json.loads(raw)
return parsed if isinstance(parsed, dict) else {"raw": raw}
except Exception:
return {"raw": raw}
return {}
def _normalize_tool_result(self, result: Dict[str, Any]) -> Dict[str, Any]:
status = result.get("status") if isinstance(result.get("status"), dict) else {}
status_code = int(status.get("code") or 0) if status else 0
status_message = str(status.get("message") or "") if status else ""
tool_call_id = str(result.get("tool_call_id") or result.get("id") or "")
tool_name = str(result.get("name") or "unknown_tool")
ok = bool(200 <= status_code < 300)
retryable = status_code >= 500 or status_code in {429, 408}
error: Optional[Dict[str, Any]] = None
if not ok:
error = {
"code": status_code or 500,
"message": status_message or "tool_execution_failed",
"retryable": retryable,
}
return {
"tool_call_id": tool_call_id,
"tool_name": tool_name,
"ok": ok,
"error": error,
"status": {"code": status_code, "message": status_message},
}
async def _emit_tool_result(self, result: Dict[str, Any], source: str) -> None:
tool_name = str(result.get("name") or "unknown_tool")
call_id = str(result.get("tool_call_id") or result.get("id") or "")
@@ -904,12 +1176,17 @@ class DuplexPipeline:
f"[Tool] emit result source={source} name={tool_name} call_id={call_id} "
f"status={status_code} {status_message}".strip()
)
normalized = self._normalize_tool_result(result)
await self._send_event(
{
**ev(
"assistant.tool_result",
trackId=self.session_id,
trackId=self.track_audio_out,
source=source,
tool_call_id=normalized["tool_call_id"],
tool_name=normalized["tool_name"],
ok=normalized["ok"],
error=normalized["error"],
result=result,
)
},
@@ -927,6 +1204,9 @@ class DuplexPipeline:
call_id = str(item.get("tool_call_id") or item.get("id") or "").strip()
if not call_id:
continue
if self._pending_client_tool_call_ids and call_id not in self._pending_client_tool_call_ids:
logger.warning(f"[Tool] ignore unsolicited client result call_id={call_id}")
continue
if call_id in self._completed_tool_call_ids:
logger.debug(f"[Tool] ignore duplicate client result call_id={call_id}")
continue
@@ -972,6 +1252,7 @@ class DuplexPipeline:
}
finally:
self._pending_tool_waiters.pop(call_id, None)
self._pending_client_tool_call_ids.discard(call_id)
def _normalize_stream_event(self, item: Any) -> LLMStreamEvent:
if isinstance(item, LLMStreamEvent):
@@ -998,6 +1279,11 @@ class DuplexPipeline:
user_text: User's transcribed text
"""
try:
if not self._current_turn_id:
self._start_turn()
if not self._current_utterance_id:
self._finalize_utterance()
self._start_response()
# Start latency tracking
self._turn_start_time = time.time()
self._first_audio_sent = False
@@ -1012,6 +1298,8 @@ class DuplexPipeline:
self._drop_outbound_audio = False
first_audio_sent = False
self._pending_llm_delta = ""
self._last_llm_delta_emit_ms = 0.0
for _ in range(max_rounds):
if self._interrupt_event.is_set():
break
@@ -1028,6 +1316,7 @@ class DuplexPipeline:
event = self._normalize_stream_event(raw_event)
if event.type == "tool_call":
await self._flush_pending_llm_delta()
tool_call = event.tool_call if isinstance(event.tool_call, dict) else None
if not tool_call:
continue
@@ -1045,11 +1334,19 @@ class DuplexPipeline:
f"executor={executor} args={args_preview}"
)
tool_calls.append(enriched_tool_call)
tool_arguments = self._tool_arguments(enriched_tool_call)
if executor == "client" and call_id:
self._pending_client_tool_call_ids.add(call_id)
await self._send_event(
{
**ev(
"assistant.tool_call",
trackId=self.session_id,
trackId=self.track_audio_out,
tool_call_id=call_id,
tool_name=tool_name,
arguments=tool_arguments,
executor=executor,
timeout_ms=int(self._TOOL_WAIT_TIMEOUT_SECONDS * 1000),
tool_call=enriched_tool_call,
)
},
@@ -1071,19 +1368,13 @@ class DuplexPipeline:
round_response += text_chunk
sentence_buffer += text_chunk
await self.conversation.update_assistant_text(text_chunk)
await self._send_event(
{
**ev(
"assistant.response.delta",
trackId=self.session_id,
text=text_chunk,
)
},
# Keep delta/final on the same event priority so FIFO seq
# preserves stream order (avoid late-delta after final).
priority=20,
)
self._pending_llm_delta += text_chunk
now_ms = time.monotonic() * 1000.0
if (
self._last_llm_delta_emit_ms <= 0.0
or now_ms - self._last_llm_delta_emit_ms >= self._LLM_DELTA_THROTTLE_MS
):
await self._flush_pending_llm_delta()
while True:
split_result = extract_tts_sentence(
@@ -1112,11 +1403,12 @@ class DuplexPipeline:
if self._tts_output_enabled() and not self._interrupt_event.is_set():
if not first_audio_sent:
self._start_tts()
await self._send_event(
{
**ev(
"output.audio.start",
trackId=self.session_id,
trackId=self.track_audio_out,
)
},
priority=10,
@@ -1130,6 +1422,7 @@ class DuplexPipeline:
)
remaining_text = f"{pending_punctuation}{sentence_buffer}".strip()
await self._flush_pending_llm_delta()
if (
self._tts_output_enabled()
and remaining_text
@@ -1137,11 +1430,12 @@ class DuplexPipeline:
and not self._interrupt_event.is_set()
):
if not first_audio_sent:
self._start_tts()
await self._send_event(
{
**ev(
"output.audio.start",
trackId=self.session_id,
trackId=self.track_audio_out,
)
},
priority=10,
@@ -1204,11 +1498,12 @@ class DuplexPipeline:
]
if full_response and not self._interrupt_event.is_set():
await self._flush_pending_llm_delta()
await self._send_event(
{
**ev(
"assistant.response.final",
trackId=self.session_id,
trackId=self.track_audio_out,
text=full_response,
)
},
@@ -1217,10 +1512,11 @@ class DuplexPipeline:
# Send track end
if first_audio_sent:
await self._flush_audio_out_frames(priority=50)
await self._send_event({
**ev(
"output.audio.end",
trackId=self.session_id,
trackId=self.track_audio_out,
)
}, priority=10)
@@ -1241,6 +1537,8 @@ class DuplexPipeline:
self._barge_in_speech_start_time = None
self._barge_in_speech_frames = 0
self._barge_in_silence_frames = 0
self._current_response_id = None
self._current_tts_id = None
async def _speak_sentence(self, text: str, fade_in_ms: int = 0, fade_out_ms: int = 8) -> None:
"""
@@ -1277,7 +1575,7 @@ class DuplexPipeline:
await self._send_event({
**ev(
"metrics.ttfb",
trackId=self.session_id,
trackId=self.track_audio_out,
latencyMs=round(ttfb_ms),
)
}, priority=25)
@@ -1354,10 +1652,11 @@ class DuplexPipeline:
first_audio_sent = False
# Send track start event
self._start_tts()
await self._send_event({
**ev(
"output.audio.start",
trackId=self.session_id,
trackId=self.track_audio_out,
)
}, priority=10)
@@ -1379,7 +1678,7 @@ class DuplexPipeline:
await self._send_event({
**ev(
"metrics.ttfb",
trackId=self.session_id,
trackId=self.track_audio_out,
latencyMs=round(ttfb_ms),
)
}, priority=25)
@@ -1391,10 +1690,11 @@ class DuplexPipeline:
await asyncio.sleep(0.01)
# Send track end event
await self._flush_audio_out_frames(priority=50)
await self._send_event({
**ev(
"output.audio.end",
trackId=self.session_id,
trackId=self.track_audio_out,
)
}, priority=10)
@@ -1422,13 +1722,14 @@ class DuplexPipeline:
self._interrupt_event.set()
self._is_bot_speaking = False
self._drop_outbound_audio = True
self._audio_out_frame_buffer = b""
# Send interrupt event to client IMMEDIATELY
# This must happen BEFORE canceling services, so client knows to discard in-flight audio
await self._send_event({
**ev(
"response.interrupted",
trackId=self.session_id,
trackId=self.track_audio_out,
)
}, priority=0)
@@ -1455,6 +1756,7 @@ class DuplexPipeline:
async def _stop_current_speech(self) -> None:
"""Stop any current speech task."""
self._drop_outbound_audio = True
self._audio_out_frame_buffer = b""
if self._current_turn_task and not self._current_turn_task.done():
self._interrupt_event.set()
self._current_turn_task.cancel()