Implement WS v1 protocol and runtime-config powered debug drawer

Xin Wang
2026-02-09 08:19:39 +08:00
parent 0fc56e2685
commit fb6d1eb1da
13 changed files with 986 additions and 298 deletions


@@ -13,7 +13,7 @@ event-driven design.
 import asyncio
 import time
-from typing import Optional, Callable, Awaitable
+from typing import Optional, Callable, Awaitable, Dict, Any
 from loguru import logger
 from core.transports import BaseTransport
@@ -28,6 +28,7 @@ from services.asr import BufferedASRService
 from services.siliconflow_tts import SiliconFlowTTSService
 from services.siliconflow_asr import SiliconFlowASRService
 from app.config import settings
+from models.ws_v1 import ev
 class DuplexPipeline:
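
The `ev` helper lives in the new `models.ws_v1` module, which this view truncates. Judging from the call sites below (it is passed to `send_event` both directly and via `**ev(...)` spread, and the old payloads carried a `timestamp` field that the new calls omit), a minimal sketch of what it presumably returns, assumptions and all:

import time
from typing import Any, Dict

def ev(event: str, **fields: Any) -> Dict[str, Any]:
    # Sketch only: envelope keys inferred from the replaced dict payloads.
    # The real models.ws_v1.ev may add validation or a protocol version.
    return {
        "event": event,
        "timestamp": int(time.time() * 1000),
        **fields,
    }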
@@ -126,19 +127,66 @@ class DuplexPipeline:
         self._barge_in_speech_frames: int = 0  # Count speech frames
         self._barge_in_silence_frames: int = 0  # Count silence frames during potential barge-in
         self._barge_in_silence_tolerance: int = 3  # Allow up to 3 silence frames (60ms at 20ms chunks)
+        # Runtime overrides injected from session.start metadata
+        self._runtime_llm: Dict[str, Any] = {}
+        self._runtime_asr: Dict[str, Any] = {}
+        self._runtime_tts: Dict[str, Any] = {}
+        self._runtime_system_prompt: Optional[str] = None
+        self._runtime_greeting: Optional[str] = None
         logger.info(f"DuplexPipeline initialized for session {session_id}")
+    def apply_runtime_overrides(self, metadata: Optional[Dict[str, Any]]) -> None:
+        """
+        Apply runtime overrides from WS session.start metadata.
+        Expected metadata shape:
+        {
+            "systemPrompt": "...",
+            "greeting": "...",
+            "services": {
+                "llm": {...},
+                "asr": {...},
+                "tts": {...}
+            }
+        }
+        """
+        if not metadata:
+            return
+        if "systemPrompt" in metadata:
+            self._runtime_system_prompt = str(metadata.get("systemPrompt") or "")
+            if self._runtime_system_prompt:
+                self.conversation.system_prompt = self._runtime_system_prompt
+        if "greeting" in metadata:
+            self._runtime_greeting = str(metadata.get("greeting") or "")
+            self.conversation.greeting = self._runtime_greeting or None
+        services = metadata.get("services") or {}
+        if isinstance(services, dict):
+            if isinstance(services.get("llm"), dict):
+                self._runtime_llm = services["llm"]
+            if isinstance(services.get("asr"), dict):
+                self._runtime_asr = services["asr"]
+            if isinstance(services.get("tts"), dict):
+                self._runtime_tts = services["tts"]
     async def start(self) -> None:
         """Start the pipeline and connect services."""
         try:
             # Connect LLM service
             if not self.llm_service:
-                if settings.openai_api_key:
+                llm_api_key = self._runtime_llm.get("apiKey") or settings.openai_api_key
+                llm_base_url = self._runtime_llm.get("baseUrl") or settings.openai_api_url
+                llm_model = self._runtime_llm.get("model") or settings.llm_model
+                llm_provider = (self._runtime_llm.get("provider") or "openai").lower()
+                if llm_provider == "openai" and llm_api_key:
                     self.llm_service = OpenAILLMService(
-                        api_key=settings.openai_api_key,
-                        base_url=settings.openai_api_url,
-                        model=settings.llm_model
+                        api_key=llm_api_key,
+                        base_url=llm_base_url,
+                        model=llm_model
                     )
                 else:
                     logger.warning("No OpenAI API key - using mock LLM")
@@ -148,33 +196,52 @@ class DuplexPipeline:
             # Connect TTS service
             if not self.tts_service:
-                if settings.tts_provider == "siliconflow" and settings.siliconflow_api_key:
+                tts_provider = (self._runtime_tts.get("provider") or settings.tts_provider).lower()
+                tts_api_key = self._runtime_tts.get("apiKey") or settings.siliconflow_api_key
+                tts_voice = self._runtime_tts.get("voice") or settings.tts_voice
+                tts_model = self._runtime_tts.get("model") or settings.siliconflow_tts_model
+                tts_speed = float(self._runtime_tts.get("speed") or settings.tts_speed)
+                if tts_provider == "siliconflow" and tts_api_key:
                     self.tts_service = SiliconFlowTTSService(
-                        api_key=settings.siliconflow_api_key,
-                        voice=settings.tts_voice,
-                        model=settings.siliconflow_tts_model,
+                        api_key=tts_api_key,
+                        voice=tts_voice,
+                        model=tts_model,
                         sample_rate=settings.sample_rate,
-                        speed=settings.tts_speed
+                        speed=tts_speed
                     )
                     logger.info("Using SiliconFlow TTS service")
                 else:
                     self.tts_service = EdgeTTSService(
-                        voice=settings.tts_voice,
+                        voice=tts_voice,
                         sample_rate=settings.sample_rate
                     )
                     logger.info("Using Edge TTS service")
-            await self.tts_service.connect()
+            try:
+                await self.tts_service.connect()
+            except Exception as e:
+                logger.warning(f"TTS backend unavailable ({e}); falling back to MockTTS")
+                self.tts_service = MockTTSService(
+                    sample_rate=settings.sample_rate
+                )
+                await self.tts_service.connect()
             # Connect ASR service
             if not self.asr_service:
-                if settings.asr_provider == "siliconflow" and settings.siliconflow_api_key:
+                asr_provider = (self._runtime_asr.get("provider") or settings.asr_provider).lower()
+                asr_api_key = self._runtime_asr.get("apiKey") or settings.siliconflow_api_key
+                asr_model = self._runtime_asr.get("model") or settings.siliconflow_asr_model
+                asr_interim_interval = int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms)
+                asr_min_audio_ms = int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms)
+                if asr_provider == "siliconflow" and asr_api_key:
                     self.asr_service = SiliconFlowASRService(
-                        api_key=settings.siliconflow_api_key,
-                        model=settings.siliconflow_asr_model,
+                        api_key=asr_api_key,
+                        model=asr_model,
                         sample_rate=settings.sample_rate,
-                        interim_interval_ms=settings.asr_interim_interval_ms,
-                        min_audio_for_interim_ms=settings.asr_min_audio_ms,
+                        interim_interval_ms=asr_interim_interval,
+                        min_audio_for_interim_ms=asr_min_audio_ms,
                         on_transcript=self._on_transcript_callback
                     )
                     logger.info("Using SiliconFlow ASR service")
@@ -223,6 +290,13 @@ class DuplexPipeline:
                 "trackId": self.session_id,
                 "probability": probability
             })
+            await self.transport.send_event(
+                ev(
+                    "input.speech_started" if event_type == "speaking" else "input.speech_stopped",
+                    trackId=self.session_id,
+                    probability=probability,
+                )
+            )
         else:
             # No state change - keep previous status
             vad_status = self._last_vad_status
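
This hunk is purely additive (6 lines become 13): the legacy event above keeps flowing, and a WS v1 frame is emitted right after it, so a client sees both during the transition. Assuming `ev` stamps a millisecond timestamp as sketched earlier, the new frame serializes roughly as:

{
    "event": "input.speech_started",  # or "input.speech_stopped" on silence
    "trackId": "sess-1",              # illustrative session id
    "probability": 0.97,
    "timestamp": 1739059200123,
}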
@@ -325,11 +399,11 @@ class DuplexPipeline:
         # Send transcript event to client
         await self.transport.send_event({
-            "event": "transcript",
-            "trackId": self.session_id,
-            "text": text,
-            "isFinal": is_final,
-            "timestamp": self._get_timestamp_ms()
+            **ev(
+                "transcript.final" if is_final else "transcript.delta",
+                trackId=self.session_id,
+                text=text,
+            )
         })
         logger.debug(f"Sent transcript ({'final' if is_final else 'interim'}): {text[:50]}...")
@@ -383,11 +457,11 @@ class DuplexPipeline:
         # Send final transcription to client
         await self.transport.send_event({
-            "event": "transcript",
-            "trackId": self.session_id,
-            "text": user_text,
-            "isFinal": True,
-            "timestamp": self._get_timestamp_ms()
+            **ev(
+                "transcript.final",
+                trackId=self.session_id,
+                text=user_text,
+            )
         })
         # Clear buffers
@@ -438,11 +512,11 @@ class DuplexPipeline:
             # Send LLM response streaming event to client
             await self.transport.send_event({
-                "event": "llmResponse",
-                "trackId": self.session_id,
-                "text": text_chunk,
-                "isFinal": False,
-                "timestamp": self._get_timestamp_ms()
+                **ev(
+                    "assistant.response.delta",
+                    trackId=self.session_id,
+                    text=text_chunk,
+                )
             })
             # Check for sentence completion - synthesize immediately for low latency
@@ -462,9 +536,10 @@ class DuplexPipeline:
                 # Send track start on first audio
                 if not first_audio_sent:
                     await self.transport.send_event({
-                        "event": "trackStart",
-                        "trackId": self.session_id,
-                        "timestamp": self._get_timestamp_ms()
+                        **ev(
+                            "output.audio.start",
+                            trackId=self.session_id,
+                        )
                     })
                     first_audio_sent = True
@@ -476,20 +551,21 @@ class DuplexPipeline:
         # Send final LLM response event
         if full_response and not self._interrupt_event.is_set():
             await self.transport.send_event({
-                "event": "llmResponse",
-                "trackId": self.session_id,
-                "text": full_response,
-                "isFinal": True,
-                "timestamp": self._get_timestamp_ms()
+                **ev(
+                    "assistant.response.final",
+                    trackId=self.session_id,
+                    text=full_response,
+                )
             })
         # Speak any remaining text
         if sentence_buffer.strip() and not self._interrupt_event.is_set():
             if not first_audio_sent:
                 await self.transport.send_event({
-                    "event": "trackStart",
-                    "trackId": self.session_id,
-                    "timestamp": self._get_timestamp_ms()
+                    **ev(
+                        "output.audio.start",
+                        trackId=self.session_id,
+                    )
                 })
                 first_audio_sent = True
             await self._speak_sentence(sentence_buffer.strip())
@@ -497,9 +573,10 @@ class DuplexPipeline:
         # Send track end
         if first_audio_sent:
             await self.transport.send_event({
-                "event": "trackEnd",
-                "trackId": self.session_id,
-                "timestamp": self._get_timestamp_ms()
+                **ev(
+                    "output.audio.end",
+                    trackId=self.session_id,
+                )
             })
         # End assistant turn
@@ -545,10 +622,11 @@ class DuplexPipeline:
             # Send TTFB event to client
             await self.transport.send_event({
-                "event": "ttfb",
-                "trackId": self.session_id,
-                "timestamp": self._get_timestamp_ms(),
-                "latencyMs": round(ttfb_ms)
+                **ev(
+                    "metrics.ttfb",
+                    trackId=self.session_id,
+                    latencyMs=round(ttfb_ms),
+                )
             })
             # Double-check interrupt right before sending audio
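
`ttfb_ms` is computed upstream of this hunk and not shown in the diff; presumably it is the classic time-to-first-byte measurement around the TTS stream, something like this sketch (names hypothetical, runs inside an async method):

import time

t0 = time.perf_counter()                       # just before requesting synthesis
first_chunk = await anext(audio_iter)          # hypothetical: first audio chunk arrives
ttfb_ms = (time.perf_counter() - t0) * 1000.0  # reported to the client via metrics.ttfb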
@@ -579,9 +657,10 @@ class DuplexPipeline:
             # Send track start event
             await self.transport.send_event({
-                "event": "trackStart",
-                "trackId": self.session_id,
-                "timestamp": self._get_timestamp_ms()
+                **ev(
+                    "output.audio.start",
+                    trackId=self.session_id,
+                )
             })
             self._is_bot_speaking = True
@@ -600,10 +679,11 @@ class DuplexPipeline:
                 # Send TTFB event to client
                 await self.transport.send_event({
-                    "event": "ttfb",
-                    "trackId": self.session_id,
-                    "timestamp": self._get_timestamp_ms(),
-                    "latencyMs": round(ttfb_ms)
+                    **ev(
+                        "metrics.ttfb",
+                        trackId=self.session_id,
+                        latencyMs=round(ttfb_ms),
+                    )
                 })
             # Send audio to client
@@ -614,9 +694,10 @@ class DuplexPipeline:
             # Send track end event
             await self.transport.send_event({
-                "event": "trackEnd",
-                "trackId": self.session_id,
-                "timestamp": self._get_timestamp_ms()
+                **ev(
+                    "output.audio.end",
+                    trackId=self.session_id,
+                )
             })
         except asyncio.CancelledError:
@@ -646,9 +727,10 @@ class DuplexPipeline:
         # Send interrupt event to client IMMEDIATELY
         # This must happen BEFORE canceling services, so client knows to discard in-flight audio
         await self.transport.send_event({
-            "event": "interrupt",
-            "trackId": self.session_id,
-            "timestamp": self._get_timestamp_ms()
+            **ev(
+                "response.interrupted",
+                trackId=self.session_id,
+            )
         })
         # Cancel TTS
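
Taken together, the event renames this file applies (legacy → WS v1):

- transcript (isFinal=false) → transcript.delta
- transcript (isFinal=true) → transcript.final
- llmResponse (isFinal=false) → assistant.response.delta
- llmResponse (isFinal=true) → assistant.response.final
- trackStart → output.audio.start
- trackEnd → output.audio.end
- ttfb → metrics.ttfb
- interrupt → response.interrupted
- input.speech_started / input.speech_stopped are new and additive, emitted alongside the legacy VAD event rather than replacing it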