Compare commits

...

2 Commits

Author SHA1 Message Date
Xin Wang
c15c5283e2 Merge WS v1 engine and web debug runtime-config integration 2026-02-09 08:19:55 +08:00
Xin Wang
fb6d1eb1da Implement WS v1 protocol and runtime-config powered debug drawer 2026-02-09 08:19:39 +08:00
13 changed files with 986 additions and 298 deletions

View File

@@ -5,7 +5,7 @@ import uuid
from datetime import datetime from datetime import datetime
from ..db import get_db from ..db import get_db
from ..models import Assistant from ..models import Assistant, LLMModel, ASRModel, Voice
from ..schemas import ( from ..schemas import (
AssistantCreate, AssistantUpdate, AssistantOut AssistantCreate, AssistantUpdate, AssistantOut
) )
@@ -13,6 +13,73 @@ from ..schemas import (
router = APIRouter(prefix="/assistants", tags=["Assistants"]) router = APIRouter(prefix="/assistants", tags=["Assistants"])
def _is_siliconflow_vendor(vendor: Optional[str]) -> bool:
return (vendor or "").strip().lower() in {"siliconflow", "硅基流动"}
def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> dict:
metadata = {
"systemPrompt": assistant.prompt or "",
"greeting": assistant.opener or "",
"services": {},
}
warnings = []
if assistant.llm_model_id:
llm = db.query(LLMModel).filter(LLMModel.id == assistant.llm_model_id).first()
if llm:
metadata["services"]["llm"] = {
"provider": "openai",
"model": llm.model_name or llm.name,
"apiKey": llm.api_key,
"baseUrl": llm.base_url,
}
else:
warnings.append(f"LLM model not found: {assistant.llm_model_id}")
if assistant.asr_model_id:
asr = db.query(ASRModel).filter(ASRModel.id == assistant.asr_model_id).first()
if asr:
asr_provider = "siliconflow" if _is_siliconflow_vendor(asr.vendor) else "buffered"
metadata["services"]["asr"] = {
"provider": asr_provider,
"model": asr.model_name or asr.name,
"apiKey": asr.api_key if asr_provider == "siliconflow" else None,
}
else:
warnings.append(f"ASR model not found: {assistant.asr_model_id}")
if assistant.voice:
voice = db.query(Voice).filter(Voice.id == assistant.voice).first()
if voice:
tts_provider = "siliconflow" if _is_siliconflow_vendor(voice.vendor) else "edge"
metadata["services"]["tts"] = {
"provider": tts_provider,
"model": voice.model,
"apiKey": voice.api_key if tts_provider == "siliconflow" else None,
"voice": voice.voice_key or voice.id,
"speed": assistant.speed or voice.speed,
}
else:
# Keep assistant.voice as direct voice identifier fallback
metadata["services"]["tts"] = {
"voice": assistant.voice,
"speed": assistant.speed or 1.0,
}
warnings.append(f"Voice resource not found: {assistant.voice}")
return {
"assistantId": assistant.id,
"sessionStartMetadata": metadata,
"sources": {
"llmModelId": assistant.llm_model_id,
"asrModelId": assistant.asr_model_id,
"voiceId": assistant.voice,
},
"warnings": warnings,
}
def assistant_to_dict(assistant: Assistant) -> dict: def assistant_to_dict(assistant: Assistant) -> dict:
return { return {
"id": assistant.id, "id": assistant.id,
@@ -84,6 +151,15 @@ def get_assistant(id: str, db: Session = Depends(get_db)):
return assistant_to_dict(assistant) return assistant_to_dict(assistant)
@router.get("/{id}/runtime-config")
def get_assistant_runtime_config(id: str, db: Session = Depends(get_db)):
"""Resolve assistant runtime config for engine WS session.start metadata."""
assistant = db.query(Assistant).filter(Assistant.id == id).first()
if not assistant:
raise HTTPException(status_code=404, detail="Assistant not found")
return _resolve_runtime_metadata(db, assistant)
@router.post("", response_model=AssistantOut) @router.post("", response_model=AssistantOut)
def create_assistant(data: AssistantCreate, db: Session = Depends(get_db)): def create_assistant(data: AssistantCreate, db: Session = Depends(get_db)):
"""创建新助手""" """创建新助手"""
@@ -139,4 +215,3 @@ def delete_assistant(id: str, db: Session = Depends(get_db)):
db.delete(assistant) db.delete(assistant)
db.commit() db.commit()
return {"message": "Deleted successfully"} return {"message": "Deleted successfully"}

View File

@@ -166,3 +166,39 @@ class TestAssistantAPI:
response = client.post("/api/assistants", json=sample_assistant_data) response = client.post("/api/assistants", json=sample_assistant_data)
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["language"] == lang assert response.json()["language"] == lang
def test_get_runtime_config(self, client, sample_assistant_data, sample_llm_model_data, sample_asr_model_data, sample_voice_data):
"""Test resolved runtime config endpoint for WS session.start metadata."""
llm_resp = client.post("/api/llm", json=sample_llm_model_data)
assert llm_resp.status_code == 200
asr_resp = client.post("/api/asr", json=sample_asr_model_data)
assert asr_resp.status_code == 200
voice_resp = client.post("/api/voices", json=sample_voice_data)
assert voice_resp.status_code == 200
voice_id = voice_resp.json()["id"]
sample_assistant_data.update({
"llmModelId": sample_llm_model_data["id"],
"asrModelId": sample_asr_model_data["id"],
"voice": voice_id,
"prompt": "runtime prompt",
"opener": "runtime opener",
"speed": 1.1,
})
assistant_resp = client.post("/api/assistants", json=sample_assistant_data)
assert assistant_resp.status_code == 200
assistant_id = assistant_resp.json()["id"]
runtime_resp = client.get(f"/api/assistants/{assistant_id}/runtime-config")
assert runtime_resp.status_code == 200
payload = runtime_resp.json()
assert payload["assistantId"] == assistant_id
metadata = payload["sessionStartMetadata"]
assert metadata["systemPrompt"] == "runtime prompt"
assert metadata["greeting"] == "runtime opener"
assert metadata["services"]["llm"]["model"] == sample_llm_model_data["model_name"]
assert metadata["services"]["asr"]["model"] == sample_asr_model_data["model_name"]
assert metadata["services"]["tts"]["voice"] == sample_voice_data["voice_key"]

View File

@@ -23,3 +23,9 @@ python examples/test_websocket.py
``` ```
python mic_client.py python mic_client.py
``` ```
## WS Protocol
`/ws` uses a strict `v1` JSON control protocol with binary PCM audio frames.
See `engine/docs/ws_v1_schema.md`.

View File

@@ -87,6 +87,9 @@ class Settings(BaseSettings):
# WebSocket heartbeat and inactivity # WebSocket heartbeat and inactivity
inactivity_timeout_sec: int = Field(default=60, description="Close connection after no message from client (seconds)") inactivity_timeout_sec: int = Field(default=60, description="Close connection after no message from client (seconds)")
heartbeat_interval_sec: int = Field(default=50, description="Send heartBeat event to client every N seconds") heartbeat_interval_sec: int = Field(default=50, description="Send heartBeat event to client every N seconds")
ws_protocol_version: str = Field(default="v1", description="Public WS protocol version")
ws_api_key: Optional[str] = Field(default=None, description="Optional API key required for WS hello auth")
ws_require_auth: bool = Field(default=False, description="Require auth in hello message even when ws_api_key is not set")
@property @property
def chunk_size_bytes(self) -> int: def chunk_size_bytes(self) -> int:

View File

@@ -24,6 +24,7 @@ from core.transports import SocketTransport, WebRtcTransport, BaseTransport
from core.session import Session from core.session import Session
from processors.tracks import Resampled16kTrack from processors.tracks import Resampled16kTrack
from core.events import get_event_bus, reset_event_bus from core.events import get_event_bus, reset_event_bus
from models.ws_v1 import ev
# Check interval for heartbeat/timeout (seconds) # Check interval for heartbeat/timeout (seconds)
_HEARTBEAT_CHECK_INTERVAL_SEC = 5 _HEARTBEAT_CHECK_INTERVAL_SEC = 5
@@ -54,8 +55,7 @@ async def heartbeat_and_timeout_task(
if now - last_heartbeat_at[0] >= heartbeat_interval_sec: if now - last_heartbeat_at[0] >= heartbeat_interval_sec:
try: try:
await transport.send_event({ await transport.send_event({
"event": "heartBeat", **ev("heartbeat"),
"timestamp": int(time.time() * 1000),
}) })
last_heartbeat_at[0] = now last_heartbeat_at[0] = now
except Exception as e: except Exception as e:

View File

@@ -13,7 +13,7 @@ event-driven design.
import asyncio import asyncio
import time import time
from typing import Optional, Callable, Awaitable from typing import Optional, Callable, Awaitable, Dict, Any
from loguru import logger from loguru import logger
from core.transports import BaseTransport from core.transports import BaseTransport
@@ -28,6 +28,7 @@ from services.asr import BufferedASRService
from services.siliconflow_tts import SiliconFlowTTSService from services.siliconflow_tts import SiliconFlowTTSService
from services.siliconflow_asr import SiliconFlowASRService from services.siliconflow_asr import SiliconFlowASRService
from app.config import settings from app.config import settings
from models.ws_v1 import ev
class DuplexPipeline: class DuplexPipeline:
@@ -127,18 +128,65 @@ class DuplexPipeline:
self._barge_in_silence_frames: int = 0 # Count silence frames during potential barge-in self._barge_in_silence_frames: int = 0 # Count silence frames during potential barge-in
self._barge_in_silence_tolerance: int = 3 # Allow up to 3 silence frames (60ms at 20ms chunks) self._barge_in_silence_tolerance: int = 3 # Allow up to 3 silence frames (60ms at 20ms chunks)
# Runtime overrides injected from session.start metadata
self._runtime_llm: Dict[str, Any] = {}
self._runtime_asr: Dict[str, Any] = {}
self._runtime_tts: Dict[str, Any] = {}
self._runtime_system_prompt: Optional[str] = None
self._runtime_greeting: Optional[str] = None
logger.info(f"DuplexPipeline initialized for session {session_id}") logger.info(f"DuplexPipeline initialized for session {session_id}")
def apply_runtime_overrides(self, metadata: Optional[Dict[str, Any]]) -> None:
"""
Apply runtime overrides from WS session.start metadata.
Expected metadata shape:
{
"systemPrompt": "...",
"greeting": "...",
"services": {
"llm": {...},
"asr": {...},
"tts": {...}
}
}
"""
if not metadata:
return
if "systemPrompt" in metadata:
self._runtime_system_prompt = str(metadata.get("systemPrompt") or "")
if self._runtime_system_prompt:
self.conversation.system_prompt = self._runtime_system_prompt
if "greeting" in metadata:
self._runtime_greeting = str(metadata.get("greeting") or "")
self.conversation.greeting = self._runtime_greeting or None
services = metadata.get("services") or {}
if isinstance(services, dict):
if isinstance(services.get("llm"), dict):
self._runtime_llm = services["llm"]
if isinstance(services.get("asr"), dict):
self._runtime_asr = services["asr"]
if isinstance(services.get("tts"), dict):
self._runtime_tts = services["tts"]
async def start(self) -> None: async def start(self) -> None:
"""Start the pipeline and connect services.""" """Start the pipeline and connect services."""
try: try:
# Connect LLM service # Connect LLM service
if not self.llm_service: if not self.llm_service:
if settings.openai_api_key: llm_api_key = self._runtime_llm.get("apiKey") or settings.openai_api_key
llm_base_url = self._runtime_llm.get("baseUrl") or settings.openai_api_url
llm_model = self._runtime_llm.get("model") or settings.llm_model
llm_provider = (self._runtime_llm.get("provider") or "openai").lower()
if llm_provider == "openai" and llm_api_key:
self.llm_service = OpenAILLMService( self.llm_service = OpenAILLMService(
api_key=settings.openai_api_key, api_key=llm_api_key,
base_url=settings.openai_api_url, base_url=llm_base_url,
model=settings.llm_model model=llm_model
) )
else: else:
logger.warning("No OpenAI API key - using mock LLM") logger.warning("No OpenAI API key - using mock LLM")
@@ -148,33 +196,52 @@ class DuplexPipeline:
# Connect TTS service # Connect TTS service
if not self.tts_service: if not self.tts_service:
if settings.tts_provider == "siliconflow" and settings.siliconflow_api_key: tts_provider = (self._runtime_tts.get("provider") or settings.tts_provider).lower()
tts_api_key = self._runtime_tts.get("apiKey") or settings.siliconflow_api_key
tts_voice = self._runtime_tts.get("voice") or settings.tts_voice
tts_model = self._runtime_tts.get("model") or settings.siliconflow_tts_model
tts_speed = float(self._runtime_tts.get("speed") or settings.tts_speed)
if tts_provider == "siliconflow" and tts_api_key:
self.tts_service = SiliconFlowTTSService( self.tts_service = SiliconFlowTTSService(
api_key=settings.siliconflow_api_key, api_key=tts_api_key,
voice=settings.tts_voice, voice=tts_voice,
model=settings.siliconflow_tts_model, model=tts_model,
sample_rate=settings.sample_rate, sample_rate=settings.sample_rate,
speed=settings.tts_speed speed=tts_speed
) )
logger.info("Using SiliconFlow TTS service") logger.info("Using SiliconFlow TTS service")
else: else:
self.tts_service = EdgeTTSService( self.tts_service = EdgeTTSService(
voice=settings.tts_voice, voice=tts_voice,
sample_rate=settings.sample_rate sample_rate=settings.sample_rate
) )
logger.info("Using Edge TTS service") logger.info("Using Edge TTS service")
await self.tts_service.connect() try:
await self.tts_service.connect()
except Exception as e:
logger.warning(f"TTS backend unavailable ({e}); falling back to MockTTS")
self.tts_service = MockTTSService(
sample_rate=settings.sample_rate
)
await self.tts_service.connect()
# Connect ASR service # Connect ASR service
if not self.asr_service: if not self.asr_service:
if settings.asr_provider == "siliconflow" and settings.siliconflow_api_key: asr_provider = (self._runtime_asr.get("provider") or settings.asr_provider).lower()
asr_api_key = self._runtime_asr.get("apiKey") or settings.siliconflow_api_key
asr_model = self._runtime_asr.get("model") or settings.siliconflow_asr_model
asr_interim_interval = int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms)
asr_min_audio_ms = int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms)
if asr_provider == "siliconflow" and asr_api_key:
self.asr_service = SiliconFlowASRService( self.asr_service = SiliconFlowASRService(
api_key=settings.siliconflow_api_key, api_key=asr_api_key,
model=settings.siliconflow_asr_model, model=asr_model,
sample_rate=settings.sample_rate, sample_rate=settings.sample_rate,
interim_interval_ms=settings.asr_interim_interval_ms, interim_interval_ms=asr_interim_interval,
min_audio_for_interim_ms=settings.asr_min_audio_ms, min_audio_for_interim_ms=asr_min_audio_ms,
on_transcript=self._on_transcript_callback on_transcript=self._on_transcript_callback
) )
logger.info("Using SiliconFlow ASR service") logger.info("Using SiliconFlow ASR service")
@@ -223,6 +290,13 @@ class DuplexPipeline:
"trackId": self.session_id, "trackId": self.session_id,
"probability": probability "probability": probability
}) })
await self.transport.send_event(
ev(
"input.speech_started" if event_type == "speaking" else "input.speech_stopped",
trackId=self.session_id,
probability=probability,
)
)
else: else:
# No state change - keep previous status # No state change - keep previous status
vad_status = self._last_vad_status vad_status = self._last_vad_status
@@ -325,11 +399,11 @@ class DuplexPipeline:
# Send transcript event to client # Send transcript event to client
await self.transport.send_event({ await self.transport.send_event({
"event": "transcript", **ev(
"trackId": self.session_id, "transcript.final" if is_final else "transcript.delta",
"text": text, trackId=self.session_id,
"isFinal": is_final, text=text,
"timestamp": self._get_timestamp_ms() )
}) })
logger.debug(f"Sent transcript ({'final' if is_final else 'interim'}): {text[:50]}...") logger.debug(f"Sent transcript ({'final' if is_final else 'interim'}): {text[:50]}...")
@@ -383,11 +457,11 @@ class DuplexPipeline:
# Send final transcription to client # Send final transcription to client
await self.transport.send_event({ await self.transport.send_event({
"event": "transcript", **ev(
"trackId": self.session_id, "transcript.final",
"text": user_text, trackId=self.session_id,
"isFinal": True, text=user_text,
"timestamp": self._get_timestamp_ms() )
}) })
# Clear buffers # Clear buffers
@@ -438,11 +512,11 @@ class DuplexPipeline:
# Send LLM response streaming event to client # Send LLM response streaming event to client
await self.transport.send_event({ await self.transport.send_event({
"event": "llmResponse", **ev(
"trackId": self.session_id, "assistant.response.delta",
"text": text_chunk, trackId=self.session_id,
"isFinal": False, text=text_chunk,
"timestamp": self._get_timestamp_ms() )
}) })
# Check for sentence completion - synthesize immediately for low latency # Check for sentence completion - synthesize immediately for low latency
@@ -462,9 +536,10 @@ class DuplexPipeline:
# Send track start on first audio # Send track start on first audio
if not first_audio_sent: if not first_audio_sent:
await self.transport.send_event({ await self.transport.send_event({
"event": "trackStart", **ev(
"trackId": self.session_id, "output.audio.start",
"timestamp": self._get_timestamp_ms() trackId=self.session_id,
)
}) })
first_audio_sent = True first_audio_sent = True
@@ -476,20 +551,21 @@ class DuplexPipeline:
# Send final LLM response event # Send final LLM response event
if full_response and not self._interrupt_event.is_set(): if full_response and not self._interrupt_event.is_set():
await self.transport.send_event({ await self.transport.send_event({
"event": "llmResponse", **ev(
"trackId": self.session_id, "assistant.response.final",
"text": full_response, trackId=self.session_id,
"isFinal": True, text=full_response,
"timestamp": self._get_timestamp_ms() )
}) })
# Speak any remaining text # Speak any remaining text
if sentence_buffer.strip() and not self._interrupt_event.is_set(): if sentence_buffer.strip() and not self._interrupt_event.is_set():
if not first_audio_sent: if not first_audio_sent:
await self.transport.send_event({ await self.transport.send_event({
"event": "trackStart", **ev(
"trackId": self.session_id, "output.audio.start",
"timestamp": self._get_timestamp_ms() trackId=self.session_id,
)
}) })
first_audio_sent = True first_audio_sent = True
await self._speak_sentence(sentence_buffer.strip()) await self._speak_sentence(sentence_buffer.strip())
@@ -497,9 +573,10 @@ class DuplexPipeline:
# Send track end # Send track end
if first_audio_sent: if first_audio_sent:
await self.transport.send_event({ await self.transport.send_event({
"event": "trackEnd", **ev(
"trackId": self.session_id, "output.audio.end",
"timestamp": self._get_timestamp_ms() trackId=self.session_id,
)
}) })
# End assistant turn # End assistant turn
@@ -545,10 +622,11 @@ class DuplexPipeline:
# Send TTFB event to client # Send TTFB event to client
await self.transport.send_event({ await self.transport.send_event({
"event": "ttfb", **ev(
"trackId": self.session_id, "metrics.ttfb",
"timestamp": self._get_timestamp_ms(), trackId=self.session_id,
"latencyMs": round(ttfb_ms) latencyMs=round(ttfb_ms),
)
}) })
# Double-check interrupt right before sending audio # Double-check interrupt right before sending audio
@@ -579,9 +657,10 @@ class DuplexPipeline:
# Send track start event # Send track start event
await self.transport.send_event({ await self.transport.send_event({
"event": "trackStart", **ev(
"trackId": self.session_id, "output.audio.start",
"timestamp": self._get_timestamp_ms() trackId=self.session_id,
)
}) })
self._is_bot_speaking = True self._is_bot_speaking = True
@@ -600,10 +679,11 @@ class DuplexPipeline:
# Send TTFB event to client # Send TTFB event to client
await self.transport.send_event({ await self.transport.send_event({
"event": "ttfb", **ev(
"trackId": self.session_id, "metrics.ttfb",
"timestamp": self._get_timestamp_ms(), trackId=self.session_id,
"latencyMs": round(ttfb_ms) latencyMs=round(ttfb_ms),
)
}) })
# Send audio to client # Send audio to client
@@ -614,9 +694,10 @@ class DuplexPipeline:
# Send track end event # Send track end event
await self.transport.send_event({ await self.transport.send_event({
"event": "trackEnd", **ev(
"trackId": self.session_id, "output.audio.end",
"timestamp": self._get_timestamp_ms() trackId=self.session_id,
)
}) })
except asyncio.CancelledError: except asyncio.CancelledError:
@@ -646,9 +727,10 @@ class DuplexPipeline:
# Send interrupt event to client IMMEDIATELY # Send interrupt event to client IMMEDIATELY
# This must happen BEFORE canceling services, so client knows to discard in-flight audio # This must happen BEFORE canceling services, so client knows to discard in-flight audio
await self.transport.send_event({ await self.transport.send_event({
"event": "interrupt", **ev(
"trackId": self.session_id, "response.interrupted",
"timestamp": self._get_timestamp_ms() trackId=self.session_id,
)
}) })
# Cancel TTS # Cancel TTS

View File

@@ -2,13 +2,31 @@
import uuid import uuid
import json import json
from enum import Enum
from typing import Optional, Dict, Any from typing import Optional, Dict, Any
from loguru import logger from loguru import logger
from core.transports import BaseTransport from core.transports import BaseTransport
from core.duplex_pipeline import DuplexPipeline from core.duplex_pipeline import DuplexPipeline
from models.commands import parse_command, TTSCommand, ChatCommand, InterruptCommand, HangupCommand
from app.config import settings from app.config import settings
from models.ws_v1 import (
parse_client_message,
ev,
HelloMessage,
SessionStartMessage,
SessionStopMessage,
InputTextMessage,
ResponseCancelMessage,
)
class WsSessionState(str, Enum):
"""Protocol state machine for WS sessions."""
WAIT_HELLO = "wait_hello"
WAIT_START = "wait_start"
ACTIVE = "active"
STOPPED = "stopped"
class Session: class Session:
@@ -41,8 +59,11 @@ class Session:
# Session state # Session state
self.created_at = None self.created_at = None
self.state = "created" # created, invited, accepted, ringing, hungup self.state = "created" # Legacy call state for /call/lists
self.ws_state = WsSessionState.WAIT_HELLO
self._pipeline_started = False self._pipeline_started = False
self.protocol_version: Optional[str] = None
self.authenticated: bool = False
# Track IDs # Track IDs
self.current_track_id: Optional[str] = str(uuid.uuid4()) self.current_track_id: Optional[str] = str(uuid.uuid4())
@@ -51,69 +72,27 @@ class Session:
async def handle_text(self, text_data: str) -> None: async def handle_text(self, text_data: str) -> None:
""" """
Handle incoming text data (JSON commands). Handle incoming text data (WS v1 JSON control messages).
Args: Args:
text_data: JSON text data text_data: JSON text data
""" """
try: try:
data = json.loads(text_data) data = json.loads(text_data)
command = parse_command(data) message = parse_client_message(data)
command_type = command.command await self._handle_v1_message(message)
logger.info(f"Session {self.id} received command: {command_type}")
# Route command to appropriate handler
if command_type == "invite":
await self._handle_invite(data)
elif command_type == "accept":
await self._handle_accept(data)
elif command_type == "reject":
await self._handle_reject(data)
elif command_type == "ringing":
await self._handle_ringing(data)
elif command_type == "tts":
await self._handle_tts(command)
elif command_type == "play":
await self._handle_play(data)
elif command_type == "interrupt":
await self._handle_interrupt(command)
elif command_type == "pause":
await self._handle_pause()
elif command_type == "resume":
await self._handle_resume()
elif command_type == "hangup":
await self._handle_hangup(command)
elif command_type == "history":
await self._handle_history(data)
elif command_type == "chat":
await self._handle_chat(command)
else:
logger.warning(f"Session {self.id} unknown command: {command_type}")
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
logger.error(f"Session {self.id} JSON decode error: {e}") logger.error(f"Session {self.id} JSON decode error: {e}")
await self._send_error("client", f"Invalid JSON: {e}") await self._send_error("client", f"Invalid JSON: {e}", "protocol.invalid_json")
except ValueError as e: except ValueError as e:
logger.error(f"Session {self.id} command parse error: {e}") logger.error(f"Session {self.id} command parse error: {e}")
await self._send_error("client", f"Invalid command: {e}") await self._send_error("client", f"Invalid message: {e}", "protocol.invalid_message")
except Exception as e: except Exception as e:
logger.error(f"Session {self.id} handle_text error: {e}", exc_info=True) logger.error(f"Session {self.id} handle_text error: {e}", exc_info=True)
await self._send_error("server", f"Internal error: {e}") await self._send_error("server", f"Internal error: {e}", "server.internal")
async def handle_audio(self, audio_bytes: bytes) -> None: async def handle_audio(self, audio_bytes: bytes) -> None:
""" """
@@ -122,156 +101,166 @@ class Session:
Args: Args:
audio_bytes: PCM audio data audio_bytes: PCM audio data
""" """
if self.ws_state != WsSessionState.ACTIVE:
await self._send_error(
"client",
"Audio received before session.start",
"protocol.order",
)
return
try: try:
await self.pipeline.process_audio(audio_bytes) await self.pipeline.process_audio(audio_bytes)
except Exception as e: except Exception as e:
logger.error(f"Session {self.id} handle_audio error: {e}", exc_info=True) logger.error(f"Session {self.id} handle_audio error: {e}", exc_info=True)
async def _handle_invite(self, data: Dict[str, Any]) -> None: async def _handle_v1_message(self, message: Any) -> None:
"""Handle invite command.""" """Route validated WS v1 message to handlers."""
self.state = "invited" msg_type = message.type
option = data.get("option", {}) logger.info(f"Session {self.id} received message: {msg_type}")
# Send answer event if isinstance(message, HelloMessage):
await self.transport.send_event({ await self._handle_hello(message)
"event": "answer", return
"trackId": self.current_track_id,
"timestamp": self._get_timestamp_ms() # All messages below require hello handshake first
}) if self.ws_state == WsSessionState.WAIT_HELLO:
await self._send_error(
"client",
"Expected hello message first",
"protocol.order",
)
return
if isinstance(message, SessionStartMessage):
await self._handle_session_start(message)
return
# All messages below require active session
if self.ws_state != WsSessionState.ACTIVE:
await self._send_error(
"client",
f"Message '{msg_type}' requires active session",
"protocol.order",
)
return
if isinstance(message, InputTextMessage):
await self.pipeline.process_text(message.text)
elif isinstance(message, ResponseCancelMessage):
if message.graceful:
logger.info(f"Session {self.id} graceful response.cancel")
else:
await self.pipeline.interrupt()
elif isinstance(message, SessionStopMessage):
await self._handle_session_stop(message.reason)
else:
await self._send_error("client", f"Unsupported message type: {msg_type}", "protocol.unsupported")
async def _handle_hello(self, message: HelloMessage) -> None:
"""Handle initial hello/auth/version negotiation."""
if self.ws_state != WsSessionState.WAIT_HELLO:
await self._send_error("client", "Duplicate hello", "protocol.order")
return
if message.version != settings.ws_protocol_version:
await self._send_error(
"client",
f"Unsupported protocol version '{message.version}'",
"protocol.version_unsupported",
)
await self.transport.close()
self.ws_state = WsSessionState.STOPPED
return
auth_payload = message.auth or {}
api_key = auth_payload.get("apiKey")
jwt = auth_payload.get("jwt")
if settings.ws_api_key:
if api_key != settings.ws_api_key:
await self._send_error("auth", "Invalid API key", "auth.invalid_api_key")
await self.transport.close()
self.ws_state = WsSessionState.STOPPED
return
elif settings.ws_require_auth and not (api_key or jwt):
await self._send_error("auth", "Authentication required", "auth.required")
await self.transport.close()
self.ws_state = WsSessionState.STOPPED
return
self.authenticated = True
self.protocol_version = message.version
self.ws_state = WsSessionState.WAIT_START
await self.transport.send_event(
ev(
"hello.ack",
sessionId=self.id,
version=self.protocol_version,
)
)
async def _handle_session_start(self, message: SessionStartMessage) -> None:
"""Handle explicit session start after successful hello."""
if self.ws_state != WsSessionState.WAIT_START:
await self._send_error("client", "Duplicate session.start", "protocol.order")
return
# Apply runtime service/prompt overrides from backend if provided
self.pipeline.apply_runtime_overrides(message.metadata)
# Start duplex pipeline # Start duplex pipeline
if not self._pipeline_started: if not self._pipeline_started:
try: await self.pipeline.start()
await self.pipeline.start() self._pipeline_started = True
self._pipeline_started = True logger.info(f"Session {self.id} duplex pipeline started")
logger.info(f"Session {self.id} duplex pipeline started")
except Exception as e:
logger.error(f"Failed to start duplex pipeline: {e}")
logger.info(f"Session {self.id} invited with codec: {option.get('codec', 'pcm')}")
async def _handle_accept(self, data: Dict[str, Any]) -> None:
"""Handle accept command."""
self.state = "accepted" self.state = "accepted"
logger.info(f"Session {self.id} accepted") self.ws_state = WsSessionState.ACTIVE
await self.transport.send_event(
ev(
"session.started",
sessionId=self.id,
trackId=self.current_track_id,
audio=message.audio or {},
)
)
async def _handle_reject(self, data: Dict[str, Any]) -> None: async def _handle_session_stop(self, reason: Optional[str]) -> None:
"""Handle reject command.""" """Handle session stop."""
self.state = "rejected" if self.ws_state == WsSessionState.STOPPED:
reason = data.get("reason", "Rejected") return
logger.info(f"Session {self.id} rejected: {reason}")
async def _handle_ringing(self, data: Dict[str, Any]) -> None: stop_reason = reason or "client_requested"
"""Handle ringing command."""
self.state = "ringing"
logger.info(f"Session {self.id} ringing")
async def _handle_tts(self, command: TTSCommand) -> None:
"""Handle TTS command."""
logger.info(f"Session {self.id} TTS: {command.text[:50]}...")
# Send track start event
await self.transport.send_event({
"event": "trackStart",
"trackId": self.current_track_id,
"timestamp": self._get_timestamp_ms(),
"playId": command.play_id
})
# TODO: Implement actual TTS synthesis
# For now, just send track end event
await self.transport.send_event({
"event": "trackEnd",
"trackId": self.current_track_id,
"timestamp": self._get_timestamp_ms(),
"duration": 1000,
"ssrc": 0,
"playId": command.play_id
})
async def _handle_play(self, data: Dict[str, Any]) -> None:
"""Handle play command."""
url = data.get("url", "")
logger.info(f"Session {self.id} play: {url}")
# Send track start event
await self.transport.send_event({
"event": "trackStart",
"trackId": self.current_track_id,
"timestamp": self._get_timestamp_ms(),
"playId": url
})
# TODO: Implement actual audio playback
# For now, just send track end event
await self.transport.send_event({
"event": "trackEnd",
"trackId": self.current_track_id,
"timestamp": self._get_timestamp_ms(),
"duration": 1000,
"ssrc": 0,
"playId": url
})
async def _handle_interrupt(self, command: InterruptCommand) -> None:
"""Handle interrupt command."""
if command.graceful:
logger.info(f"Session {self.id} graceful interrupt")
else:
logger.info(f"Session {self.id} immediate interrupt")
await self.pipeline.interrupt()
async def _handle_pause(self) -> None:
"""Handle pause command."""
logger.info(f"Session {self.id} paused")
async def _handle_resume(self) -> None:
"""Handle resume command."""
logger.info(f"Session {self.id} resumed")
async def _handle_hangup(self, command: HangupCommand) -> None:
"""Handle hangup command."""
self.state = "hungup" self.state = "hungup"
reason = command.reason or "User requested" self.ws_state = WsSessionState.STOPPED
logger.info(f"Session {self.id} hung up: {reason}") await self.transport.send_event(
ev(
# Send hangup event "session.stopped",
await self.transport.send_event({ sessionId=self.id,
"event": "hangup", reason=stop_reason,
"timestamp": self._get_timestamp_ms(), )
"reason": reason, )
"initiator": command.initiator or "user"
})
# Close transport
await self.transport.close() await self.transport.close()
async def _handle_history(self, data: Dict[str, Any]) -> None: async def _send_error(self, sender: str, error_message: str, code: str) -> None:
"""Handle history command."""
speaker = data.get("speaker", "unknown")
text = data.get("text", "")
logger.info(f"Session {self.id} history [{speaker}]: {text[:50]}...")
async def _handle_chat(self, command: ChatCommand) -> None:
    """Handle a chat command: forward the user's text into the pipeline.

    Args:
        command: Chat request whose ``text`` is processed as user input.
    """
    preview = command.text[:50]
    logger.info(f"Session {self.id} chat: {preview}...")
    await self.pipeline.process_text(command.text)
async def _send_error(self, sender: str, error_message: str) -> None:
""" """
Send error event to client. Send error event to client.
Args: Args:
sender: Component that generated the error sender: Component that generated the error
error_message: Error message error_message: Error message
code: Machine-readable error code
""" """
await self.transport.send_event({ await self.transport.send_event(
"event": "error", ev(
"trackId": self.current_track_id, "error",
"timestamp": self._get_timestamp_ms(), sender=sender,
"sender": sender, code=code,
"error": error_message message=error_message,
}) trackId=self.current_track_id,
)
)
def _get_timestamp_ms(self) -> int: def _get_timestamp_ms(self) -> int:
"""Get current timestamp in milliseconds.""" """Get current timestamp in milliseconds."""

169
engine/docs/ws_v1_schema.md Normal file
View File

@@ -0,0 +1,169 @@
# WS v1 Protocol Schema (`/ws`)
This document defines the public WebSocket protocol for the `/ws` endpoint.
## Transport
- A single WebSocket connection carries:
- JSON text frames for control/events.
- Binary frames for raw PCM audio (`pcm_s16le`, mono, 16kHz by default).
## Handshake and State Machine
Required message order:
1. Client sends `hello`.
2. Server replies `hello.ack`.
3. Client sends `session.start`.
4. Server replies `session.started`.
5. Client may stream binary audio and/or send `input.text`.
6. Client sends `session.stop` (or closes socket).
If order is violated, server emits `error` with `code = "protocol.order"`.
## Client -> Server Messages
### `hello`
```json
{
"type": "hello",
"version": "v1",
"auth": {
"apiKey": "optional-api-key",
"jwt": "optional-jwt"
}
}
```
Rules:
- `version` must be `v1`.
- If `WS_API_KEY` is configured on server, `auth.apiKey` must match.
- If `WS_REQUIRE_AUTH=true`, either `auth.apiKey` or `auth.jwt` must be present.
### `session.start`
```json
{
"type": "session.start",
"audio": {
"encoding": "pcm_s16le",
"sample_rate_hz": 16000,
"channels": 1
},
"metadata": {
"client": "web-debug",
"systemPrompt": "You are concise.",
"greeting": "Hi, how can I help?",
"services": {
"llm": {
"provider": "openai",
"model": "gpt-4o-mini",
"apiKey": "sk-...",
"baseUrl": "https://api.openai.com/v1"
},
"asr": {
"provider": "siliconflow",
"model": "FunAudioLLM/SenseVoiceSmall",
"apiKey": "sf-...",
"interimIntervalMs": 500,
"minAudioMs": 300
},
"tts": {
"provider": "siliconflow",
"model": "FunAudioLLM/CosyVoice2-0.5B",
"apiKey": "sf-...",
"voice": "anna",
"speed": 1.0
}
}
}
}
```
`metadata.services` is optional. If it is omitted, the server falls back to its environment-based default configuration.
### `input.text`
```json
{
"type": "input.text",
"text": "What can you do?"
}
```
### `response.cancel`
```json
{
"type": "response.cancel",
"graceful": false
}
```
### `session.stop`
```json
{
"type": "session.stop",
"reason": "client_disconnect"
}
```
## Server -> Client Events
All server events include:
```json
{
"type": "event.name",
"timestamp": 1730000000000
}
```
Common events:
- `hello.ack`
- Fields: `sessionId`, `version`
- `session.started`
- Fields: `sessionId`, `trackId`, `audio`
- `session.stopped`
- Fields: `sessionId`, `reason`
- `heartbeat`
- `input.speech_started`
- Fields: `trackId`, `probability`
- `input.speech_stopped`
- Fields: `trackId`, `probability`
- `transcript.delta`
- Fields: `trackId`, `text`
- `transcript.final`
- Fields: `trackId`, `text`
- `assistant.response.delta`
- Fields: `trackId`, `text`
- `assistant.response.final`
- Fields: `trackId`, `text`
- `output.audio.start`
- Fields: `trackId`
- `output.audio.end`
- Fields: `trackId`
- `response.interrupted`
- Fields: `trackId`
- `metrics.ttfb`
- Fields: `trackId`, `latencyMs`
- `error`
- Fields: `sender`, `code`, `message`, `trackId`
## Binary Audio Frames
After receiving `session.started`, the client may stream binary PCM audio chunks continuously.
Recommended format:
- 16-bit signed little-endian PCM.
- 1 channel.
- 16000 Hz.
- 20ms frames (640 bytes) preferred.
## Compatibility
This endpoint now enforces v1 message schema for JSON control frames.
Legacy command names (`invite`, `chat`, etc.) are no longer part of the public protocol.

View File

@@ -36,7 +36,7 @@ def generate_sine_wave(duration_ms=1000):
return audio_data return audio_data
async def receive_loop(ws): async def receive_loop(ws, ready_event: asyncio.Event):
"""Listen for incoming messages from the server.""" """Listen for incoming messages from the server."""
print("👂 Listening for server responses...") print("👂 Listening for server responses...")
async for msg in ws: async for msg in ws:
@@ -45,8 +45,10 @@ async def receive_loop(ws):
if msg.type == aiohttp.WSMsgType.TEXT: if msg.type == aiohttp.WSMsgType.TEXT:
try: try:
data = json.loads(msg.data) data = json.loads(msg.data)
event_type = data.get('event', 'Unknown') event_type = data.get('type', 'Unknown')
print(f"[{timestamp}] 📨 Event: {event_type} | {msg.data[:150]}...") print(f"[{timestamp}] 📨 Event: {event_type} | {msg.data[:150]}...")
if event_type == "session.started":
ready_event.set()
except json.JSONDecodeError: except json.JSONDecodeError:
print(f"[{timestamp}] 📨 Text: {msg.data[:100]}...") print(f"[{timestamp}] 📨 Text: {msg.data[:100]}...")
@@ -118,35 +120,43 @@ async def run_client(url, file_path=None, use_sine=False):
print(f"🔌 Connecting to {url}...") print(f"🔌 Connecting to {url}...")
async with session.ws_connect(url) as ws: async with session.ws_connect(url) as ws:
print("✅ Connected!") print("✅ Connected!")
session_ready = asyncio.Event()
recv_task = asyncio.create_task(receive_loop(ws, session_ready))
# Send initial invite command # Send v1 hello + session.start handshake
init_cmd = { await ws.send_json({"type": "hello", "version": "v1"})
"command": "invite", await ws.send_json({
"option": { "type": "session.start",
"codec": "pcm", "audio": {
"samplerate": SAMPLE_RATE "encoding": "pcm_s16le",
"sample_rate_hz": SAMPLE_RATE,
"channels": 1
} }
} })
await ws.send_json(init_cmd) print("📤 Sent v1 hello/session.start")
print("📤 Sent Invite Command") await asyncio.wait_for(session_ready.wait(), timeout=8)
# Select sender based on args # Select sender based on args
if use_sine: if use_sine:
sender_task = send_sine_loop(ws) await send_sine_loop(ws)
elif file_path: elif file_path:
sender_task = send_file_loop(ws, file_path) await send_file_loop(ws, file_path)
else: else:
# Default to sine wave # Default to sine wave
sender_task = send_sine_loop(ws) await send_sine_loop(ws)
# Run send and receive loops in parallel await ws.send_json({"type": "session.stop", "reason": "test_complete"})
await asyncio.gather( await asyncio.sleep(1)
receive_loop(ws), recv_task.cancel()
sender_task try:
) await recv_task
except asyncio.CancelledError:
pass
except aiohttp.ClientConnectorError: except aiohttp.ClientConnectorError:
print(f"❌ Connection Failed. Is the server running at {url}?") print(f"❌ Connection Failed. Is the server running at {url}?")
except asyncio.TimeoutError:
print("❌ Timeout waiting for session.started")
except Exception as e: except Exception as e:
print(f"❌ Error: {e}") print(f"❌ Error: {e}")
finally: finally:

View File

@@ -547,7 +547,7 @@
setStatus(true, "Session open"); setStatus(true, "Session open");
logLine("sys", "WebSocket connected"); logLine("sys", "WebSocket connected");
ensureAudioContext(); ensureAudioContext();
sendCommand({ command: "invite", option: { codec: "pcm", sampleRate: targetSampleRate } }); sendCommand({ type: "hello", version: "v1" });
}; };
ws.onclose = () => { ws.onclose = () => {
@@ -574,7 +574,10 @@
} }
function disconnect() { function disconnect() {
if (ws) ws.close(); if (ws && ws.readyState === WebSocket.OPEN) {
sendCommand({ type: "session.stop", reason: "client_disconnect" });
ws.close();
}
ws = null; ws = null;
setStatus(false, "Disconnected"); setStatus(false, "Disconnected");
} }
@@ -585,40 +588,48 @@
return; return;
} }
ws.send(JSON.stringify(cmd)); ws.send(JSON.stringify(cmd));
logLine("sys", `${cmd.command}`, cmd); logLine("sys", `${cmd.type}`, cmd);
} }
function handleEvent(event) { function handleEvent(event) {
const type = event.event || "unknown"; const type = event.type || "unknown";
logLine("event", type, event); logLine("event", type, event);
if (type === "transcript") { if (type === "hello.ack") {
if (event.isFinal && event.text) { sendCommand({
type: "session.start",
audio: { encoding: "pcm_s16le", sample_rate_hz: targetSampleRate, channels: 1 },
});
}
if (type === "transcript.final") {
if (event.text) {
setInterim("You", ""); setInterim("You", "");
addChat("You", event.text); addChat("You", event.text);
} else if (event.text) {
interimUserText += event.text;
setInterim("You", interimUserText);
} }
} }
if (type === "llmResponse") { if (type === "transcript.delta" && event.text) {
if (event.isFinal && event.text) { setInterim("You", event.text);
}
if (type === "assistant.response.final") {
if (event.text) {
setInterim("AI", ""); setInterim("AI", "");
addChat("AI", event.text); addChat("AI", event.text);
} else if (event.text) {
interimAiText += event.text;
setInterim("AI", interimAiText);
} }
} }
if (type === "trackStart") { if (type === "assistant.response.delta" && event.text) {
interimAiText += event.text;
setInterim("AI", interimAiText);
}
if (type === "output.audio.start") {
// New bot audio: stop any previous playback to avoid overlap // New bot audio: stop any previous playback to avoid overlap
stopPlayback(); stopPlayback();
discardAudio = false; discardAudio = false;
interimAiText = "";
} }
if (type === "speaking") { if (type === "input.speech_started") {
// User started speaking: clear any in-flight audio to avoid overlap // User started speaking: clear any in-flight audio to avoid overlap
stopPlayback(); stopPlayback();
} }
if (type === "interrupt") { if (type === "response.interrupted") {
stopPlayback(); stopPlayback();
} }
} }
@@ -716,7 +727,7 @@
if (!text) return; if (!text) return;
ensureAudioContext(); ensureAudioContext();
addChat("You", text); addChat("You", text);
sendCommand({ command: "chat", text }); sendCommand({ type: "input.text", text });
chatInput.value = ""; chatInput.value = "";
}); });
clearLogBtn.addEventListener("click", () => { clearLogBtn.addEventListener("click", () => {

67
engine/models/ws_v1.py Normal file
View File

@@ -0,0 +1,67 @@
"""WS v1 protocol message models and helpers."""
import time
from typing import Optional, Dict, Any, Literal

from pydantic import BaseModel, Field
def now_ms() -> int:
    """Return the current Unix timestamp in whole milliseconds.

    Returns:
        Milliseconds since the epoch, truncated to an int.
    """
    # `time` is imported at module scope; a function-local import added
    # per-call overhead and hid the dependency from readers.
    return int(time.time() * 1000)
# Client -> Server messages
class HelloMessage(BaseModel):
    """First JSON frame a client must send; opens the v1 handshake."""
    type: Literal["hello"]
    # Must be "v1"; per the protocol schema, other versions are rejected.
    version: str = Field(..., description="Protocol version, currently v1")
    # Optional credentials, e.g. {"apiKey": "..."} or {"jwt": "..."}.
    auth: Optional[Dict[str, str]] = Field(default=None, description="Auth payload, e.g. {'apiKey': '...'}")
class SessionStartMessage(BaseModel):
    """Handshake frame sent after hello.ack; configures the session."""
    type: Literal["session.start"]
    # e.g. {"encoding": "pcm_s16le", "sample_rate_hz": 16000, "channels": 1}
    audio: Optional[Dict[str, Any]] = Field(default=None, description="Optional audio format metadata")
    # Free-form session settings (systemPrompt, greeting, services, ...);
    # omitted -> server falls back to environment configuration.
    metadata: Optional[Dict[str, Any]] = Field(default=None, description="Optional session metadata")
class SessionStopMessage(BaseModel):
    """Client request to end the session."""
    type: Literal["session.stop"]
    # Optional human-readable reason, e.g. "client_disconnect".
    reason: Optional[str] = None
class InputTextMessage(BaseModel):
    """Typed user text submitted as conversational input."""
    type: Literal["input.text"]
    text: str
class ResponseCancelMessage(BaseModel):
    """Client request to cancel the in-flight assistant response."""
    type: Literal["response.cancel"]
    # NOTE(review): presumably True requests a softer interruption —
    # confirm against the engine's interrupt handling.
    graceful: bool = False
# Registry mapping the "type" field of an incoming JSON frame to its
# pydantic model; used by parse_client_message for validation/dispatch.
CLIENT_MESSAGE_TYPES = {
    "hello": HelloMessage,
    "session.start": SessionStartMessage,
    "session.stop": SessionStopMessage,
    "input.text": InputTextMessage,
    "response.cancel": ResponseCancelMessage,
}
def parse_client_message(data: Dict[str, Any]) -> BaseModel:
    """Validate a raw WS v1 client payload and return its typed model.

    Args:
        data: Decoded JSON dict from a client text frame.

    Returns:
        The matching pydantic message instance.

    Raises:
        ValueError: if ``type`` is missing/empty or not a known message type.
            (Field-level failures propagate as pydantic validation errors.)
    """
    message_type = data.get("type")
    if not message_type:
        raise ValueError("Missing 'type' field")
    model_cls = CLIENT_MESSAGE_TYPES.get(message_type)
    if model_cls is None:
        raise ValueError(f"Unknown client message type: {message_type}")
    return model_cls(**data)
# Server -> Client event helpers
def ev(event_type: str, **payload: Any) -> Dict[str, Any]:
    """Build a server->client WS v1 event envelope.

    Args:
        event_type: Value for the ``type`` field.
        **payload: Extra fields merged into the envelope; an explicit
            ``type``/``timestamp`` kwarg overrides the defaults, matching
            dict.update semantics.

    Returns:
        Dict with ``type``, a ``timestamp`` in ms, and the payload fields.
    """
    return {"type": event_type, "timestamp": now_ms(), **payload}

View File

@@ -4,8 +4,7 @@ import { Plus, Search, Play, Copy, Trash2, Edit2, Mic, MessageSquare, Save, Vide
import { Button, Input, Card, Badge, Drawer, Dialog } from '../components/UI'; import { Button, Input, Card, Badge, Drawer, Dialog } from '../components/UI';
import { mockLLMModels, mockASRModels } from '../services/mockData'; import { mockLLMModels, mockASRModels } from '../services/mockData';
import { Assistant, KnowledgeBase, TabValue, Voice } from '../types'; import { Assistant, KnowledgeBase, TabValue, Voice } from '../types';
import { GoogleGenAI } from "@google/genai"; import { createAssistant, deleteAssistant, fetchAssistantRuntimeConfig, fetchAssistants, fetchKnowledgeBases, fetchVoices, updateAssistant as updateAssistantApi } from '../services/backendApi';
import { createAssistant, deleteAssistant, fetchAssistants, fetchKnowledgeBases, fetchVoices, updateAssistant as updateAssistantApi } from '../services/backendApi';
interface ToolItem { interface ToolItem {
id: string; id: string;
@@ -1022,11 +1021,26 @@ export const DebugDrawer: React.FC<{ isOpen: boolean; onClose: () => void; assis
const [inputText, setInputText] = useState(''); const [inputText, setInputText] = useState('');
const [isLoading, setIsLoading] = useState(false); const [isLoading, setIsLoading] = useState(false);
const [callStatus, setCallStatus] = useState<'idle' | 'calling' | 'active'>('idle'); const [callStatus, setCallStatus] = useState<'idle' | 'calling' | 'active'>('idle');
const [wsStatus, setWsStatus] = useState<'disconnected' | 'connecting' | 'ready' | 'error'>('disconnected');
const [wsError, setWsError] = useState('');
const [resolvedConfigOpen, setResolvedConfigOpen] = useState(false);
const [resolvedConfigView, setResolvedConfigView] = useState<string>('');
const [wsUrl, setWsUrl] = useState<string>(() => {
const fromStorage = localStorage.getItem('debug_ws_url');
if (fromStorage) return fromStorage;
const defaultHost = window.location.hostname || 'localhost';
return `ws://${defaultHost}:8000/ws`;
});
// Media State // Media State
const videoRef = useRef<HTMLVideoElement>(null); const videoRef = useRef<HTMLVideoElement>(null);
const streamRef = useRef<MediaStream | null>(null); const streamRef = useRef<MediaStream | null>(null);
const scrollRef = useRef<HTMLDivElement>(null); const scrollRef = useRef<HTMLDivElement>(null);
const wsRef = useRef<WebSocket | null>(null);
const wsReadyRef = useRef(false);
const pendingResolveRef = useRef<(() => void) | null>(null);
const pendingRejectRef = useRef<((e: Error) => void) | null>(null);
const assistantDraftIndexRef = useRef<number | null>(null);
const [devices, setDevices] = useState<MediaDeviceInfo[]>([]); const [devices, setDevices] = useState<MediaDeviceInfo[]>([]);
const [selectedCamera, setSelectedCamera] = useState<string>(''); const [selectedCamera, setSelectedCamera] = useState<string>('');
@@ -1045,11 +1059,16 @@ export const DebugDrawer: React.FC<{ isOpen: boolean; onClose: () => void; assis
} else { } else {
setMode('text'); setMode('text');
stopMedia(); stopMedia();
closeWs();
setIsSwapped(false); setIsSwapped(false);
setCallStatus('idle'); setCallStatus('idle');
} }
}, [isOpen, assistant, mode]); }, [isOpen, assistant, mode]);
useEffect(() => {
localStorage.setItem('debug_ws_url', wsUrl);
}, [wsUrl]);
// Auto-scroll logic // Auto-scroll logic
useEffect(() => { useEffect(() => {
if (scrollRef.current) { if (scrollRef.current) {
@@ -1132,29 +1151,196 @@ export const DebugDrawer: React.FC<{ isOpen: boolean; onClose: () => void; assis
setIsLoading(true); setIsLoading(true);
try { try {
if (process.env.API_KEY) { if (mode === 'text') {
const ai = new GoogleGenAI({ apiKey: process.env.API_KEY }); await ensureWsSession();
const chat = ai.chats.create({ wsRef.current?.send(JSON.stringify({ type: 'input.text', text: userMsg }));
model: "gemini-3-flash-preview",
config: { systemInstruction: assistant.prompt },
history: messages.map(m => ({ role: m.role, parts: [{ text: m.text }] }))
});
const result = await chat.sendMessage({ message: userMsg });
setMessages(prev => [...prev, { role: 'model', text: result.text || '' }]);
} else { } else {
setTimeout(() => { setTimeout(() => {
setMessages(prev => [...prev, { role: 'model', text: `[Mock Response]: Received "${userMsg}"` }]); setMessages(prev => [...prev, { role: 'model', text: `[Mock Response]: Received "${userMsg}"` }]);
setIsLoading(false); setIsLoading(false);
}, 1000); }, 1000);
} }
} catch (e) { } catch (e) {
console.error(e); console.error(e);
setMessages(prev => [...prev, { role: 'model', text: "Error: Failed to connect to AI service." }]); setMessages(prev => [...prev, { role: 'model', text: "Error: Failed to connect to AI service." }]);
} finally {
setIsLoading(false); setIsLoading(false);
} finally {
if (mode !== 'text') setIsLoading(false);
} }
}; };
/**
 * Resolve the assistant's runtime config from the backend and mirror it into
 * the read-only viewer. Falls back to locally-known assistant fields when the
 * runtime-config endpoint fails, so a debug session can still start.
 */
const fetchRuntimeMetadata = async (): Promise<Record<string, any>> => {
  // Render a resolved-config snapshot into the collapsible debug panel.
  const renderSnapshot = (snapshot: Record<string, any>) =>
    setResolvedConfigView(JSON.stringify(snapshot, null, 2));

  try {
    const resolved = await fetchAssistantRuntimeConfig(assistant.id);
    const metadata = resolved.sessionStartMetadata || {};
    renderSnapshot({
      assistantId: resolved.assistantId,
      sources: resolved.sources,
      warnings: resolved.warnings,
      sessionStartMetadata: metadata,
    });
    return metadata;
  } catch (error) {
    console.error('Failed to load runtime config, using fallback.', error);
    const fallback = {
      systemPrompt: assistant.prompt || '',
      greeting: assistant.opener || '',
      services: {},
    };
    renderSnapshot({
      assistantId: assistant.id,
      sources: {},
      warnings: ['runtime-config endpoint failed; using local fallback'],
      sessionStartMetadata: fallback,
    });
    return fallback;
  }
};
const closeWs = () => {
if (wsRef.current && wsRef.current.readyState === WebSocket.OPEN) {
wsRef.current.send(JSON.stringify({ type: 'session.stop', reason: 'debug_drawer_closed' }));
}
wsRef.current?.close();
wsRef.current = null;
wsReadyRef.current = false;
pendingResolveRef.current = null;
pendingRejectRef.current = null;
assistantDraftIndexRef.current = null;
if (isOpen) setWsStatus('disconnected');
};
/**
 * Ensure a live, fully-handshaken WS v1 session exists, creating one if
 * needed. Resolves once `session.started` arrives; rejects on socket error,
 * protocol error, or premature close.
 */
const ensureWsSession = async () => {
  // Fast path: socket open and handshake already completed.
  if (wsRef.current && wsReadyRef.current && wsRef.current.readyState === WebSocket.OPEN) {
    return;
  }
  // A connect attempt is in flight; piggyback on its pending promise.
  // NOTE(review): concurrent callers overwrite each other's pending refs, so
  // only the most recent waiter is settled — confirm callers are serialized.
  if (wsRef.current && wsRef.current.readyState === WebSocket.CONNECTING) {
    await new Promise<void>((resolve, reject) => {
      pendingResolveRef.current = resolve;
      pendingRejectRef.current = reject;
    });
    return;
  }
  // Resolve the assistant runtime config first so session.start carries it.
  const metadata = await fetchRuntimeMetadata();
  setWsStatus('connecting');
  setWsError('');
  await new Promise<void>((resolve, reject) => {
    pendingResolveRef.current = resolve;
    pendingRejectRef.current = reject;
    const ws = new WebSocket(wsUrl);
    ws.binaryType = 'arraybuffer';
    wsRef.current = ws;

    // Reject the in-flight handshake promise (if any), exactly once.
    const rejectPending = (err: Error) => {
      pendingRejectRef.current?.(err);
      pendingResolveRef.current = null;
      pendingRejectRef.current = null;
    };

    ws.onopen = () => {
      // Handshake step 1: hello -> expect hello.ack.
      ws.send(JSON.stringify({ type: 'hello', version: 'v1' }));
    };
    ws.onmessage = (event) => {
      if (typeof event.data !== 'string') return; // binary audio frames not handled here
      let payload: any;
      try {
        payload = JSON.parse(event.data);
      } catch {
        return; // ignore malformed JSON frames
      }
      const type = payload?.type;
      if (type === 'hello.ack') {
        // Handshake step 2: start the session with resolved runtime metadata.
        ws.send(
          JSON.stringify({
            type: 'session.start',
            audio: { encoding: 'pcm_s16le', sample_rate_hz: 16000, channels: 1 },
            metadata,
          })
        );
        return;
      }
      if (type === 'session.started') {
        wsReadyRef.current = true;
        setWsStatus('ready');
        pendingResolveRef.current?.();
        pendingResolveRef.current = null;
        pendingRejectRef.current = null;
        return;
      }
      if (type === 'assistant.response.delta') {
        // Stream tokens into a single draft assistant message.
        const delta = String(payload.text || '');
        if (!delta) return;
        setMessages((prev) => {
          const idx = assistantDraftIndexRef.current;
          if (idx === null || !prev[idx] || prev[idx].role !== 'model') {
            const next = [...prev, { role: 'model' as const, text: delta }];
            assistantDraftIndexRef.current = next.length - 1;
            return next;
          }
          const next = [...prev];
          next[idx] = { ...next[idx], text: next[idx].text + delta };
          return next;
        });
        return;
      }
      if (type === 'assistant.response.final') {
        // Replace the draft with the final text, or append if no draft exists.
        const finalText = String(payload.text || '');
        setMessages((prev) => {
          const idx = assistantDraftIndexRef.current;
          assistantDraftIndexRef.current = null;
          if (idx !== null && prev[idx] && prev[idx].role === 'model') {
            const next = [...prev];
            next[idx] = { ...next[idx], text: finalText || next[idx].text };
            return next;
          }
          return finalText ? [...prev, { role: 'model', text: finalText }] : prev;
        });
        setIsLoading(false);
        return;
      }
      if (type === 'error') {
        const message = String(payload.message || 'Unknown error');
        setWsStatus('error');
        setWsError(message);
        setIsLoading(false);
        rejectPending(new Error(message));
      }
    };
    ws.onerror = () => {
      const err = new Error('WebSocket connection error');
      setWsStatus('error');
      setWsError(err.message);
      setIsLoading(false);
      rejectPending(err);
    };
    ws.onclose = () => {
      wsReadyRef.current = false;
      // BUGFIX: the original compared against `wsStatus` captured in this
      // closure at connect time (always 'connecting'), so an 'error' status
      // set by onerror was clobbered. A functional update sees the live value.
      setWsStatus((prev) => (prev === 'error' ? prev : 'disconnected'));
      // BUGFIX: if the server closes before `session.started` without an
      // error event, the handshake promise would hang forever. No-op when
      // the promise was already settled.
      rejectPending(new Error('WebSocket closed before session.started'));
    };
  });
};
const TranscriptionLog = () => ( const TranscriptionLog = () => (
<div ref={scrollRef} className="flex-1 overflow-y-auto space-y-4 p-2 border border-white/5 rounded-md bg-black/20 min-h-0"> <div ref={scrollRef} className="flex-1 overflow-y-auto space-y-4 p-2 border border-white/5 rounded-md bg-black/20 min-h-0">
{messages.length === 0 && <div className="text-center text-muted-foreground text-xs py-4"></div>} {messages.length === 0 && <div className="text-center text-muted-foreground text-xs py-4"></div>}
@@ -1205,7 +1391,38 @@ export const DebugDrawer: React.FC<{ isOpen: boolean; onClose: () => void; assis
<div className="flex-1 overflow-hidden flex flex-col min-h-0 mb-4"> <div className="flex-1 overflow-hidden flex flex-col min-h-0 mb-4">
{mode === 'text' ? ( {mode === 'text' ? (
<TranscriptionLog /> <div className="flex flex-col gap-2 h-full min-h-0">
<div className="shrink-0 rounded-md border border-white/10 bg-white/5 p-2 grid grid-cols-1 md:grid-cols-3 gap-2">
<Input value={wsUrl} onChange={(e) => setWsUrl(e.target.value)} placeholder="ws://localhost:8000/ws" />
<div className="md:col-span-2 flex items-center gap-2">
<Badge variant="outline" className="text-xs">
WS: {wsStatus}
</Badge>
<Button size="sm" variant="secondary" onClick={() => ensureWsSession()} disabled={wsStatus === 'connecting'}>
Connect
</Button>
<Button size="sm" variant="ghost" onClick={closeWs}>
Disconnect
</Button>
{wsError && <span className="text-xs text-red-400 truncate">{wsError}</span>}
</div>
<div className="md:col-span-3 rounded-md border border-white/10 bg-black/20">
<button
className="w-full px-3 py-2 text-left text-xs text-muted-foreground hover:text-foreground flex items-center justify-between"
onClick={() => setResolvedConfigOpen((v) => !v)}
>
<span>View Resolved Runtime Config (read-only)</span>
<ChevronDown className={`h-3.5 w-3.5 transition-transform ${resolvedConfigOpen ? 'rotate-180' : ''}`} />
</button>
{resolvedConfigOpen && (
<pre className="px-3 pb-3 text-[11px] leading-5 text-cyan-100/90 whitespace-pre-wrap break-all max-h-52 overflow-auto">
{resolvedConfigView || 'Connect to load resolved config...'}
</pre>
)}
</div>
</div>
<TranscriptionLog />
</div>
) : callStatus === 'idle' ? ( ) : callStatus === 'idle' ? (
<div className="flex-1 flex flex-col items-center justify-center space-y-6 border border-white/5 rounded-xl bg-black/20 animate-in fade-in zoom-in-95"> <div className="flex-1 flex flex-col items-center justify-center space-y-6 border border-white/5 rounded-xl bg-black/20 animate-in fade-in zoom-in-95">
<div className="relative"> <div className="relative">

View File

@@ -41,6 +41,10 @@ const mapAssistant = (raw: AnyRecord): Assistant => ({
configMode: readField(raw, ['configMode', 'config_mode'], 'platform') as 'platform' | 'dify' | 'fastgpt' | 'none', configMode: readField(raw, ['configMode', 'config_mode'], 'platform') as 'platform' | 'dify' | 'fastgpt' | 'none',
apiUrl: readField(raw, ['apiUrl', 'api_url'], ''), apiUrl: readField(raw, ['apiUrl', 'api_url'], ''),
apiKey: readField(raw, ['apiKey', 'api_key'], ''), apiKey: readField(raw, ['apiKey', 'api_key'], ''),
llmModelId: readField(raw, ['llmModelId', 'llm_model_id'], ''),
asrModelId: readField(raw, ['asrModelId', 'asr_model_id'], ''),
embeddingModelId: readField(raw, ['embeddingModelId', 'embedding_model_id'], ''),
rerankModelId: readField(raw, ['rerankModelId', 'rerank_model_id'], ''),
}); });
const mapVoice = (raw: AnyRecord): Voice => ({ const mapVoice = (raw: AnyRecord): Voice => ({
@@ -218,6 +222,10 @@ export const updateAssistant = async (id: string, data: Partial<Assistant>): Pro
configMode: data.configMode, configMode: data.configMode,
apiUrl: data.apiUrl, apiUrl: data.apiUrl,
apiKey: data.apiKey, apiKey: data.apiKey,
llmModelId: data.llmModelId,
asrModelId: data.asrModelId,
embeddingModelId: data.embeddingModelId,
rerankModelId: data.rerankModelId,
}; };
const response = await apiRequest<AnyRecord>(`/assistants/${id}`, { method: 'PUT', body: payload }); const response = await apiRequest<AnyRecord>(`/assistants/${id}`, { method: 'PUT', body: payload });
return mapAssistant(response); return mapAssistant(response);
@@ -227,6 +235,21 @@ export const deleteAssistant = async (id: string): Promise<void> => {
await apiRequest(`/assistants/${id}`, { method: 'DELETE' }); await apiRequest(`/assistants/${id}`, { method: 'DELETE' });
}; };
/** Response shape of `GET /assistants/{id}/runtime-config`. */
export interface AssistantRuntimeConfigResponse {
  assistantId: string;
  /** Metadata payload to embed in the WS v1 `session.start` message. */
  sessionStartMetadata: Record<string, any>;
  /** IDs of the records the backend resolved the config from. */
  sources?: {
    llmModelId?: string;
    asrModelId?: string;
    voiceId?: string;
  };
  /** Human-readable resolution problems (e.g. dangling model references). */
  warnings?: string[];
}
/** Fetch the server-resolved runtime configuration for one assistant. */
export const fetchAssistantRuntimeConfig = async (assistantId: string): Promise<AssistantRuntimeConfigResponse> =>
  apiRequest<AssistantRuntimeConfigResponse>(`/assistants/${assistantId}/runtime-config`);
export const fetchVoices = async (): Promise<Voice[]> => { export const fetchVoices = async (): Promise<Voice[]> => {
const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>('/voices'); const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>('/voices');
const list = Array.isArray(response) ? response : (response.list || []); const list = Array.isArray(response) ? response : (response.list || []);