diff --git a/api/app/routers/assistants.py b/api/app/routers/assistants.py index 4949d06..44ae3ba 100644 --- a/api/app/routers/assistants.py +++ b/api/app/routers/assistants.py @@ -5,7 +5,7 @@ import uuid from datetime import datetime from ..db import get_db -from ..models import Assistant +from ..models import Assistant, LLMModel, ASRModel, Voice from ..schemas import ( AssistantCreate, AssistantUpdate, AssistantOut ) @@ -13,6 +13,73 @@ from ..schemas import ( router = APIRouter(prefix="/assistants", tags=["Assistants"]) +def _is_siliconflow_vendor(vendor: Optional[str]) -> bool: + return (vendor or "").strip().lower() in {"siliconflow", "硅基流动"} + + +def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> dict: + metadata = { + "systemPrompt": assistant.prompt or "", + "greeting": assistant.opener or "", + "services": {}, + } + warnings = [] + + if assistant.llm_model_id: + llm = db.query(LLMModel).filter(LLMModel.id == assistant.llm_model_id).first() + if llm: + metadata["services"]["llm"] = { + "provider": "openai", + "model": llm.model_name or llm.name, + "apiKey": llm.api_key, + "baseUrl": llm.base_url, + } + else: + warnings.append(f"LLM model not found: {assistant.llm_model_id}") + + if assistant.asr_model_id: + asr = db.query(ASRModel).filter(ASRModel.id == assistant.asr_model_id).first() + if asr: + asr_provider = "siliconflow" if _is_siliconflow_vendor(asr.vendor) else "buffered" + metadata["services"]["asr"] = { + "provider": asr_provider, + "model": asr.model_name or asr.name, + "apiKey": asr.api_key if asr_provider == "siliconflow" else None, + } + else: + warnings.append(f"ASR model not found: {assistant.asr_model_id}") + + if assistant.voice: + voice = db.query(Voice).filter(Voice.id == assistant.voice).first() + if voice: + tts_provider = "siliconflow" if _is_siliconflow_vendor(voice.vendor) else "edge" + metadata["services"]["tts"] = { + "provider": tts_provider, + "model": voice.model, + "apiKey": voice.api_key if tts_provider == "siliconflow" else None, + "voice": voice.voice_key or voice.id, + "speed": assistant.speed or voice.speed, + } + else: + # Keep assistant.voice as direct voice identifier fallback + metadata["services"]["tts"] = { + "voice": assistant.voice, + "speed": assistant.speed or 1.0, + } + warnings.append(f"Voice resource not found: {assistant.voice}") + + return { + "assistantId": assistant.id, + "sessionStartMetadata": metadata, + "sources": { + "llmModelId": assistant.llm_model_id, + "asrModelId": assistant.asr_model_id, + "voiceId": assistant.voice, + }, + "warnings": warnings, + } + + def assistant_to_dict(assistant: Assistant) -> dict: return { "id": assistant.id, @@ -84,6 +151,15 @@ def get_assistant(id: str, db: Session = Depends(get_db)): return assistant_to_dict(assistant) +@router.get("/{id}/runtime-config") +def get_assistant_runtime_config(id: str, db: Session = Depends(get_db)): + """Resolve assistant runtime config for engine WS session.start metadata.""" + assistant = db.query(Assistant).filter(Assistant.id == id).first() + if not assistant: + raise HTTPException(status_code=404, detail="Assistant not found") + return _resolve_runtime_metadata(db, assistant) + + @router.post("", response_model=AssistantOut) def create_assistant(data: AssistantCreate, db: Session = Depends(get_db)): """创建新助手""" @@ -139,4 +215,3 @@ def delete_assistant(id: str, db: Session = Depends(get_db)): db.delete(assistant) db.commit() return {"message": "Deleted successfully"} - diff --git a/api/tests/test_assistants.py b/api/tests/test_assistants.py 
index fd10704..9d6b200 100644
--- a/api/tests/test_assistants.py
+++ b/api/tests/test_assistants.py
@@ -166,3 +166,39 @@ class TestAssistantAPI:
             response = client.post("/api/assistants", json=sample_assistant_data)
             assert response.status_code == 200
             assert response.json()["language"] == lang
+
+    def test_get_runtime_config(self, client, sample_assistant_data, sample_llm_model_data, sample_asr_model_data, sample_voice_data):
+        """Test resolved runtime config endpoint for WS session.start metadata."""
+        llm_resp = client.post("/api/llm", json=sample_llm_model_data)
+        assert llm_resp.status_code == 200
+
+        asr_resp = client.post("/api/asr", json=sample_asr_model_data)
+        assert asr_resp.status_code == 200
+
+        voice_resp = client.post("/api/voices", json=sample_voice_data)
+        assert voice_resp.status_code == 200
+        voice_id = voice_resp.json()["id"]
+
+        sample_assistant_data.update({
+            "llmModelId": sample_llm_model_data["id"],
+            "asrModelId": sample_asr_model_data["id"],
+            "voice": voice_id,
+            "prompt": "runtime prompt",
+            "opener": "runtime opener",
+            "speed": 1.1,
+        })
+        assistant_resp = client.post("/api/assistants", json=sample_assistant_data)
+        assert assistant_resp.status_code == 200
+        assistant_id = assistant_resp.json()["id"]
+
+        runtime_resp = client.get(f"/api/assistants/{assistant_id}/runtime-config")
+        assert runtime_resp.status_code == 200
+        payload = runtime_resp.json()
+
+        assert payload["assistantId"] == assistant_id
+        metadata = payload["sessionStartMetadata"]
+        assert metadata["systemPrompt"] == "runtime prompt"
+        assert metadata["greeting"] == "runtime opener"
+        assert metadata["services"]["llm"]["model"] == sample_llm_model_data["model_name"]
+        assert metadata["services"]["asr"]["model"] == sample_asr_model_data["model_name"]
+        assert metadata["services"]["tts"]["voice"] == sample_voice_data["voice_key"]
diff --git a/engine/README.md b/engine/README.md
index 6e7da04..17d9e3a 100644
--- a/engine/README.md
+++ b/engine/README.md
@@ -22,4 +22,10 @@ python examples/test_websocket.py
 ```
 python mic_client.py
-```
\ No newline at end of file
+```
+
+## WS Protocol
+
+`/ws` uses a strict `v1` JSON control protocol with binary PCM audio frames.
+
+See `docs/ws_v1_schema.md`.
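+
+A minimal handshake sketch (assumes the engine is listening on `ws://localhost:8000/ws`; message names follow `docs/ws_v1_schema.md`):
+
+```
+import asyncio
+import json
+
+import aiohttp
+
+
+async def main():
+    async with aiohttp.ClientSession() as http:
+        async with http.ws_connect("ws://localhost:8000/ws") as ws:
+            # hello -> hello.ack -> session.start -> session.started
+            await ws.send_json({"type": "hello", "version": "v1"})
+            await ws.send_json({
+                "type": "session.start",
+                "audio": {"encoding": "pcm_s16le", "sample_rate_hz": 16000, "channels": 1},
+            })
+            await ws.send_json({"type": "input.text", "text": "Hello!"})
+            async for msg in ws:
+                if msg.type != aiohttp.WSMsgType.TEXT:
+                    continue
+                event = json.loads(msg.data)
+                print(event["type"])
+                if event["type"] == "assistant.response.final":
+                    break
+            await ws.send_json({"type": "session.stop", "reason": "done"})
+
+
+asyncio.run(main())
+```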
diff --git a/engine/app/config.py b/engine/app/config.py index 689eee5..0c10fc1 100644 --- a/engine/app/config.py +++ b/engine/app/config.py @@ -87,6 +87,9 @@ class Settings(BaseSettings): # WebSocket heartbeat and inactivity inactivity_timeout_sec: int = Field(default=60, description="Close connection after no message from client (seconds)") heartbeat_interval_sec: int = Field(default=50, description="Send heartBeat event to client every N seconds") + ws_protocol_version: str = Field(default="v1", description="Public WS protocol version") + ws_api_key: Optional[str] = Field(default=None, description="Optional API key required for WS hello auth") + ws_require_auth: bool = Field(default=False, description="Require auth in hello message even when ws_api_key is not set") @property def chunk_size_bytes(self) -> int: diff --git a/engine/app/main.py b/engine/app/main.py index fa77621..593d534 100644 --- a/engine/app/main.py +++ b/engine/app/main.py @@ -24,6 +24,7 @@ from core.transports import SocketTransport, WebRtcTransport, BaseTransport from core.session import Session from processors.tracks import Resampled16kTrack from core.events import get_event_bus, reset_event_bus +from models.ws_v1 import ev # Check interval for heartbeat/timeout (seconds) _HEARTBEAT_CHECK_INTERVAL_SEC = 5 @@ -54,8 +55,7 @@ async def heartbeat_and_timeout_task( if now - last_heartbeat_at[0] >= heartbeat_interval_sec: try: await transport.send_event({ - "event": "heartBeat", - "timestamp": int(time.time() * 1000), + **ev("heartbeat"), }) last_heartbeat_at[0] = now except Exception as e: diff --git a/engine/core/duplex_pipeline.py b/engine/core/duplex_pipeline.py index 9d4a938..be31239 100644 --- a/engine/core/duplex_pipeline.py +++ b/engine/core/duplex_pipeline.py @@ -13,7 +13,7 @@ event-driven design. import asyncio import time -from typing import Optional, Callable, Awaitable +from typing import Optional, Callable, Awaitable, Dict, Any from loguru import logger from core.transports import BaseTransport @@ -28,6 +28,7 @@ from services.asr import BufferedASRService from services.siliconflow_tts import SiliconFlowTTSService from services.siliconflow_asr import SiliconFlowASRService from app.config import settings +from models.ws_v1 import ev class DuplexPipeline: @@ -126,19 +127,66 @@ class DuplexPipeline: self._barge_in_speech_frames: int = 0 # Count speech frames self._barge_in_silence_frames: int = 0 # Count silence frames during potential barge-in self._barge_in_silence_tolerance: int = 3 # Allow up to 3 silence frames (60ms at 20ms chunks) + + # Runtime overrides injected from session.start metadata + self._runtime_llm: Dict[str, Any] = {} + self._runtime_asr: Dict[str, Any] = {} + self._runtime_tts: Dict[str, Any] = {} + self._runtime_system_prompt: Optional[str] = None + self._runtime_greeting: Optional[str] = None logger.info(f"DuplexPipeline initialized for session {session_id}") + + def apply_runtime_overrides(self, metadata: Optional[Dict[str, Any]]) -> None: + """ + Apply runtime overrides from WS session.start metadata. 
+ + Expected metadata shape: + { + "systemPrompt": "...", + "greeting": "...", + "services": { + "llm": {...}, + "asr": {...}, + "tts": {...} + } + } + """ + if not metadata: + return + + if "systemPrompt" in metadata: + self._runtime_system_prompt = str(metadata.get("systemPrompt") or "") + if self._runtime_system_prompt: + self.conversation.system_prompt = self._runtime_system_prompt + if "greeting" in metadata: + self._runtime_greeting = str(metadata.get("greeting") or "") + self.conversation.greeting = self._runtime_greeting or None + + services = metadata.get("services") or {} + if isinstance(services, dict): + if isinstance(services.get("llm"), dict): + self._runtime_llm = services["llm"] + if isinstance(services.get("asr"), dict): + self._runtime_asr = services["asr"] + if isinstance(services.get("tts"), dict): + self._runtime_tts = services["tts"] async def start(self) -> None: """Start the pipeline and connect services.""" try: # Connect LLM service if not self.llm_service: - if settings.openai_api_key: + llm_api_key = self._runtime_llm.get("apiKey") or settings.openai_api_key + llm_base_url = self._runtime_llm.get("baseUrl") or settings.openai_api_url + llm_model = self._runtime_llm.get("model") or settings.llm_model + llm_provider = (self._runtime_llm.get("provider") or "openai").lower() + + if llm_provider == "openai" and llm_api_key: self.llm_service = OpenAILLMService( - api_key=settings.openai_api_key, - base_url=settings.openai_api_url, - model=settings.llm_model + api_key=llm_api_key, + base_url=llm_base_url, + model=llm_model ) else: logger.warning("No OpenAI API key - using mock LLM") @@ -148,33 +196,52 @@ class DuplexPipeline: # Connect TTS service if not self.tts_service: - if settings.tts_provider == "siliconflow" and settings.siliconflow_api_key: + tts_provider = (self._runtime_tts.get("provider") or settings.tts_provider).lower() + tts_api_key = self._runtime_tts.get("apiKey") or settings.siliconflow_api_key + tts_voice = self._runtime_tts.get("voice") or settings.tts_voice + tts_model = self._runtime_tts.get("model") or settings.siliconflow_tts_model + tts_speed = float(self._runtime_tts.get("speed") or settings.tts_speed) + + if tts_provider == "siliconflow" and tts_api_key: self.tts_service = SiliconFlowTTSService( - api_key=settings.siliconflow_api_key, - voice=settings.tts_voice, - model=settings.siliconflow_tts_model, + api_key=tts_api_key, + voice=tts_voice, + model=tts_model, sample_rate=settings.sample_rate, - speed=settings.tts_speed + speed=tts_speed ) logger.info("Using SiliconFlow TTS service") else: self.tts_service = EdgeTTSService( - voice=settings.tts_voice, + voice=tts_voice, sample_rate=settings.sample_rate ) logger.info("Using Edge TTS service") - - await self.tts_service.connect() + + try: + await self.tts_service.connect() + except Exception as e: + logger.warning(f"TTS backend unavailable ({e}); falling back to MockTTS") + self.tts_service = MockTTSService( + sample_rate=settings.sample_rate + ) + await self.tts_service.connect() # Connect ASR service if not self.asr_service: - if settings.asr_provider == "siliconflow" and settings.siliconflow_api_key: + asr_provider = (self._runtime_asr.get("provider") or settings.asr_provider).lower() + asr_api_key = self._runtime_asr.get("apiKey") or settings.siliconflow_api_key + asr_model = self._runtime_asr.get("model") or settings.siliconflow_asr_model + asr_interim_interval = int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms) + asr_min_audio_ms = 
int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms) + + if asr_provider == "siliconflow" and asr_api_key: self.asr_service = SiliconFlowASRService( - api_key=settings.siliconflow_api_key, - model=settings.siliconflow_asr_model, + api_key=asr_api_key, + model=asr_model, sample_rate=settings.sample_rate, - interim_interval_ms=settings.asr_interim_interval_ms, - min_audio_for_interim_ms=settings.asr_min_audio_ms, + interim_interval_ms=asr_interim_interval, + min_audio_for_interim_ms=asr_min_audio_ms, on_transcript=self._on_transcript_callback ) logger.info("Using SiliconFlow ASR service") @@ -223,6 +290,13 @@ class DuplexPipeline: "trackId": self.session_id, "probability": probability }) + await self.transport.send_event( + ev( + "input.speech_started" if event_type == "speaking" else "input.speech_stopped", + trackId=self.session_id, + probability=probability, + ) + ) else: # No state change - keep previous status vad_status = self._last_vad_status @@ -325,11 +399,11 @@ class DuplexPipeline: # Send transcript event to client await self.transport.send_event({ - "event": "transcript", - "trackId": self.session_id, - "text": text, - "isFinal": is_final, - "timestamp": self._get_timestamp_ms() + **ev( + "transcript.final" if is_final else "transcript.delta", + trackId=self.session_id, + text=text, + ) }) logger.debug(f"Sent transcript ({'final' if is_final else 'interim'}): {text[:50]}...") @@ -383,11 +457,11 @@ class DuplexPipeline: # Send final transcription to client await self.transport.send_event({ - "event": "transcript", - "trackId": self.session_id, - "text": user_text, - "isFinal": True, - "timestamp": self._get_timestamp_ms() + **ev( + "transcript.final", + trackId=self.session_id, + text=user_text, + ) }) # Clear buffers @@ -438,11 +512,11 @@ class DuplexPipeline: # Send LLM response streaming event to client await self.transport.send_event({ - "event": "llmResponse", - "trackId": self.session_id, - "text": text_chunk, - "isFinal": False, - "timestamp": self._get_timestamp_ms() + **ev( + "assistant.response.delta", + trackId=self.session_id, + text=text_chunk, + ) }) # Check for sentence completion - synthesize immediately for low latency @@ -462,9 +536,10 @@ class DuplexPipeline: # Send track start on first audio if not first_audio_sent: await self.transport.send_event({ - "event": "trackStart", - "trackId": self.session_id, - "timestamp": self._get_timestamp_ms() + **ev( + "output.audio.start", + trackId=self.session_id, + ) }) first_audio_sent = True @@ -476,20 +551,21 @@ class DuplexPipeline: # Send final LLM response event if full_response and not self._interrupt_event.is_set(): await self.transport.send_event({ - "event": "llmResponse", - "trackId": self.session_id, - "text": full_response, - "isFinal": True, - "timestamp": self._get_timestamp_ms() + **ev( + "assistant.response.final", + trackId=self.session_id, + text=full_response, + ) }) # Speak any remaining text if sentence_buffer.strip() and not self._interrupt_event.is_set(): if not first_audio_sent: await self.transport.send_event({ - "event": "trackStart", - "trackId": self.session_id, - "timestamp": self._get_timestamp_ms() + **ev( + "output.audio.start", + trackId=self.session_id, + ) }) first_audio_sent = True await self._speak_sentence(sentence_buffer.strip()) @@ -497,9 +573,10 @@ class DuplexPipeline: # Send track end if first_audio_sent: await self.transport.send_event({ - "event": "trackEnd", - "trackId": self.session_id, - "timestamp": self._get_timestamp_ms() + **ev( + "output.audio.end", + 
trackId=self.session_id, + ) }) # End assistant turn @@ -545,10 +622,11 @@ class DuplexPipeline: # Send TTFB event to client await self.transport.send_event({ - "event": "ttfb", - "trackId": self.session_id, - "timestamp": self._get_timestamp_ms(), - "latencyMs": round(ttfb_ms) + **ev( + "metrics.ttfb", + trackId=self.session_id, + latencyMs=round(ttfb_ms), + ) }) # Double-check interrupt right before sending audio @@ -579,9 +657,10 @@ class DuplexPipeline: # Send track start event await self.transport.send_event({ - "event": "trackStart", - "trackId": self.session_id, - "timestamp": self._get_timestamp_ms() + **ev( + "output.audio.start", + trackId=self.session_id, + ) }) self._is_bot_speaking = True @@ -600,10 +679,11 @@ class DuplexPipeline: # Send TTFB event to client await self.transport.send_event({ - "event": "ttfb", - "trackId": self.session_id, - "timestamp": self._get_timestamp_ms(), - "latencyMs": round(ttfb_ms) + **ev( + "metrics.ttfb", + trackId=self.session_id, + latencyMs=round(ttfb_ms), + ) }) # Send audio to client @@ -614,9 +694,10 @@ class DuplexPipeline: # Send track end event await self.transport.send_event({ - "event": "trackEnd", - "trackId": self.session_id, - "timestamp": self._get_timestamp_ms() + **ev( + "output.audio.end", + trackId=self.session_id, + ) }) except asyncio.CancelledError: @@ -646,9 +727,10 @@ class DuplexPipeline: # Send interrupt event to client IMMEDIATELY # This must happen BEFORE canceling services, so client knows to discard in-flight audio await self.transport.send_event({ - "event": "interrupt", - "trackId": self.session_id, - "timestamp": self._get_timestamp_ms() + **ev( + "response.interrupted", + trackId=self.session_id, + ) }) # Cancel TTS diff --git a/engine/core/session.py b/engine/core/session.py index 54bf0d4..7ff368e 100644 --- a/engine/core/session.py +++ b/engine/core/session.py @@ -2,13 +2,31 @@ import uuid import json +from enum import Enum from typing import Optional, Dict, Any from loguru import logger from core.transports import BaseTransport from core.duplex_pipeline import DuplexPipeline -from models.commands import parse_command, TTSCommand, ChatCommand, InterruptCommand, HangupCommand from app.config import settings +from models.ws_v1 import ( + parse_client_message, + ev, + HelloMessage, + SessionStartMessage, + SessionStopMessage, + InputTextMessage, + ResponseCancelMessage, +) + + +class WsSessionState(str, Enum): + """Protocol state machine for WS sessions.""" + + WAIT_HELLO = "wait_hello" + WAIT_START = "wait_start" + ACTIVE = "active" + STOPPED = "stopped" class Session: @@ -41,8 +59,11 @@ class Session: # Session state self.created_at = None - self.state = "created" # created, invited, accepted, ringing, hungup + self.state = "created" # Legacy call state for /call/lists + self.ws_state = WsSessionState.WAIT_HELLO self._pipeline_started = False + self.protocol_version: Optional[str] = None + self.authenticated: bool = False # Track IDs self.current_track_id: Optional[str] = str(uuid.uuid4()) @@ -51,69 +72,27 @@ class Session: async def handle_text(self, text_data: str) -> None: """ - Handle incoming text data (JSON commands). + Handle incoming text data (WS v1 JSON control messages). 
Args: text_data: JSON text data """ try: data = json.loads(text_data) - command = parse_command(data) - command_type = command.command - - logger.info(f"Session {self.id} received command: {command_type}") - - # Route command to appropriate handler - if command_type == "invite": - await self._handle_invite(data) - - elif command_type == "accept": - await self._handle_accept(data) - - elif command_type == "reject": - await self._handle_reject(data) - - elif command_type == "ringing": - await self._handle_ringing(data) - - elif command_type == "tts": - await self._handle_tts(command) - - elif command_type == "play": - await self._handle_play(data) - - elif command_type == "interrupt": - await self._handle_interrupt(command) - - elif command_type == "pause": - await self._handle_pause() - - elif command_type == "resume": - await self._handle_resume() - - elif command_type == "hangup": - await self._handle_hangup(command) - - elif command_type == "history": - await self._handle_history(data) - - elif command_type == "chat": - await self._handle_chat(command) - - else: - logger.warning(f"Session {self.id} unknown command: {command_type}") + message = parse_client_message(data) + await self._handle_v1_message(message) except json.JSONDecodeError as e: logger.error(f"Session {self.id} JSON decode error: {e}") - await self._send_error("client", f"Invalid JSON: {e}") + await self._send_error("client", f"Invalid JSON: {e}", "protocol.invalid_json") except ValueError as e: logger.error(f"Session {self.id} command parse error: {e}") - await self._send_error("client", f"Invalid command: {e}") + await self._send_error("client", f"Invalid message: {e}", "protocol.invalid_message") except Exception as e: logger.error(f"Session {self.id} handle_text error: {e}", exc_info=True) - await self._send_error("server", f"Internal error: {e}") + await self._send_error("server", f"Internal error: {e}", "server.internal") async def handle_audio(self, audio_bytes: bytes) -> None: """ @@ -122,156 +101,166 @@ class Session: Args: audio_bytes: PCM audio data """ + if self.ws_state != WsSessionState.ACTIVE: + await self._send_error( + "client", + "Audio received before session.start", + "protocol.order", + ) + return + try: await self.pipeline.process_audio(audio_bytes) except Exception as e: logger.error(f"Session {self.id} handle_audio error: {e}", exc_info=True) - async def _handle_invite(self, data: Dict[str, Any]) -> None: - """Handle invite command.""" - self.state = "invited" - option = data.get("option", {}) + async def _handle_v1_message(self, message: Any) -> None: + """Route validated WS v1 message to handlers.""" + msg_type = message.type + logger.info(f"Session {self.id} received message: {msg_type}") - # Send answer event - await self.transport.send_event({ - "event": "answer", - "trackId": self.current_track_id, - "timestamp": self._get_timestamp_ms() - }) + if isinstance(message, HelloMessage): + await self._handle_hello(message) + return + + # All messages below require hello handshake first + if self.ws_state == WsSessionState.WAIT_HELLO: + await self._send_error( + "client", + "Expected hello message first", + "protocol.order", + ) + return + + if isinstance(message, SessionStartMessage): + await self._handle_session_start(message) + return + + # All messages below require active session + if self.ws_state != WsSessionState.ACTIVE: + await self._send_error( + "client", + f"Message '{msg_type}' requires active session", + "protocol.order", + ) + return + + if isinstance(message, InputTextMessage): + 
await self.pipeline.process_text(message.text) + elif isinstance(message, ResponseCancelMessage): + if message.graceful: + logger.info(f"Session {self.id} graceful response.cancel") + else: + await self.pipeline.interrupt() + elif isinstance(message, SessionStopMessage): + await self._handle_session_stop(message.reason) + else: + await self._send_error("client", f"Unsupported message type: {msg_type}", "protocol.unsupported") + + async def _handle_hello(self, message: HelloMessage) -> None: + """Handle initial hello/auth/version negotiation.""" + if self.ws_state != WsSessionState.WAIT_HELLO: + await self._send_error("client", "Duplicate hello", "protocol.order") + return + + if message.version != settings.ws_protocol_version: + await self._send_error( + "client", + f"Unsupported protocol version '{message.version}'", + "protocol.version_unsupported", + ) + await self.transport.close() + self.ws_state = WsSessionState.STOPPED + return + + auth_payload = message.auth or {} + api_key = auth_payload.get("apiKey") + jwt = auth_payload.get("jwt") + + if settings.ws_api_key: + if api_key != settings.ws_api_key: + await self._send_error("auth", "Invalid API key", "auth.invalid_api_key") + await self.transport.close() + self.ws_state = WsSessionState.STOPPED + return + elif settings.ws_require_auth and not (api_key or jwt): + await self._send_error("auth", "Authentication required", "auth.required") + await self.transport.close() + self.ws_state = WsSessionState.STOPPED + return + + self.authenticated = True + self.protocol_version = message.version + self.ws_state = WsSessionState.WAIT_START + await self.transport.send_event( + ev( + "hello.ack", + sessionId=self.id, + version=self.protocol_version, + ) + ) + + async def _handle_session_start(self, message: SessionStartMessage) -> None: + """Handle explicit session start after successful hello.""" + if self.ws_state != WsSessionState.WAIT_START: + await self._send_error("client", "Duplicate session.start", "protocol.order") + return + + # Apply runtime service/prompt overrides from backend if provided + self.pipeline.apply_runtime_overrides(message.metadata) # Start duplex pipeline if not self._pipeline_started: - try: - await self.pipeline.start() - self._pipeline_started = True - logger.info(f"Session {self.id} duplex pipeline started") - except Exception as e: - logger.error(f"Failed to start duplex pipeline: {e}") + await self.pipeline.start() + self._pipeline_started = True + logger.info(f"Session {self.id} duplex pipeline started") - logger.info(f"Session {self.id} invited with codec: {option.get('codec', 'pcm')}") - - async def _handle_accept(self, data: Dict[str, Any]) -> None: - """Handle accept command.""" self.state = "accepted" - logger.info(f"Session {self.id} accepted") + self.ws_state = WsSessionState.ACTIVE + await self.transport.send_event( + ev( + "session.started", + sessionId=self.id, + trackId=self.current_track_id, + audio=message.audio or {}, + ) + ) - async def _handle_reject(self, data: Dict[str, Any]) -> None: - """Handle reject command.""" - self.state = "rejected" - reason = data.get("reason", "Rejected") - logger.info(f"Session {self.id} rejected: {reason}") + async def _handle_session_stop(self, reason: Optional[str]) -> None: + """Handle session stop.""" + if self.ws_state == WsSessionState.STOPPED: + return - async def _handle_ringing(self, data: Dict[str, Any]) -> None: - """Handle ringing command.""" - self.state = "ringing" - logger.info(f"Session {self.id} ringing") - - async def _handle_tts(self, command: 
TTSCommand) -> None: - """Handle TTS command.""" - logger.info(f"Session {self.id} TTS: {command.text[:50]}...") - - # Send track start event - await self.transport.send_event({ - "event": "trackStart", - "trackId": self.current_track_id, - "timestamp": self._get_timestamp_ms(), - "playId": command.play_id - }) - - # TODO: Implement actual TTS synthesis - # For now, just send track end event - await self.transport.send_event({ - "event": "trackEnd", - "trackId": self.current_track_id, - "timestamp": self._get_timestamp_ms(), - "duration": 1000, - "ssrc": 0, - "playId": command.play_id - }) - - async def _handle_play(self, data: Dict[str, Any]) -> None: - """Handle play command.""" - url = data.get("url", "") - logger.info(f"Session {self.id} play: {url}") - - # Send track start event - await self.transport.send_event({ - "event": "trackStart", - "trackId": self.current_track_id, - "timestamp": self._get_timestamp_ms(), - "playId": url - }) - - # TODO: Implement actual audio playback - # For now, just send track end event - await self.transport.send_event({ - "event": "trackEnd", - "trackId": self.current_track_id, - "timestamp": self._get_timestamp_ms(), - "duration": 1000, - "ssrc": 0, - "playId": url - }) - - async def _handle_interrupt(self, command: InterruptCommand) -> None: - """Handle interrupt command.""" - if command.graceful: - logger.info(f"Session {self.id} graceful interrupt") - else: - logger.info(f"Session {self.id} immediate interrupt") - await self.pipeline.interrupt() - - async def _handle_pause(self) -> None: - """Handle pause command.""" - logger.info(f"Session {self.id} paused") - - async def _handle_resume(self) -> None: - """Handle resume command.""" - logger.info(f"Session {self.id} resumed") - - async def _handle_hangup(self, command: HangupCommand) -> None: - """Handle hangup command.""" + stop_reason = reason or "client_requested" self.state = "hungup" - reason = command.reason or "User requested" - logger.info(f"Session {self.id} hung up: {reason}") - - # Send hangup event - await self.transport.send_event({ - "event": "hangup", - "timestamp": self._get_timestamp_ms(), - "reason": reason, - "initiator": command.initiator or "user" - }) - - # Close transport + self.ws_state = WsSessionState.STOPPED + await self.transport.send_event( + ev( + "session.stopped", + sessionId=self.id, + reason=stop_reason, + ) + ) await self.transport.close() - async def _handle_history(self, data: Dict[str, Any]) -> None: - """Handle history command.""" - speaker = data.get("speaker", "unknown") - text = data.get("text", "") - logger.info(f"Session {self.id} history [{speaker}]: {text[:50]}...") - - async def _handle_chat(self, command: ChatCommand) -> None: - """Handle chat command.""" - logger.info(f"Session {self.id} chat: {command.text[:50]}...") - await self.pipeline.process_text(command.text) - - async def _send_error(self, sender: str, error_message: str) -> None: + async def _send_error(self, sender: str, error_message: str, code: str) -> None: """ Send error event to client. 
Args: sender: Component that generated the error error_message: Error message + code: Machine-readable error code """ - await self.transport.send_event({ - "event": "error", - "trackId": self.current_track_id, - "timestamp": self._get_timestamp_ms(), - "sender": sender, - "error": error_message - }) + await self.transport.send_event( + ev( + "error", + sender=sender, + code=code, + message=error_message, + trackId=self.current_track_id, + ) + ) def _get_timestamp_ms(self) -> int: """Get current timestamp in milliseconds.""" diff --git a/engine/docs/ws_v1_schema.md b/engine/docs/ws_v1_schema.md new file mode 100644 index 0000000..de2cc2c --- /dev/null +++ b/engine/docs/ws_v1_schema.md @@ -0,0 +1,169 @@ +# WS v1 Protocol Schema (`/ws`) + +This document defines the public WebSocket protocol for the `/ws` endpoint. + +## Transport + +- A single WebSocket connection carries: + - JSON text frames for control/events. + - Binary frames for raw PCM audio (`pcm_s16le`, mono, 16kHz by default). + +## Handshake and State Machine + +Required message order: + +1. Client sends `hello`. +2. Server replies `hello.ack`. +3. Client sends `session.start`. +4. Server replies `session.started`. +5. Client may stream binary audio and/or send `input.text`. +6. Client sends `session.stop` (or closes socket). + +If order is violated, server emits `error` with `code = "protocol.order"`. + +## Client -> Server Messages + +### `hello` + +```json +{ + "type": "hello", + "version": "v1", + "auth": { + "apiKey": "optional-api-key", + "jwt": "optional-jwt" + } +} +``` + +Rules: +- `version` must be `v1`. +- If `WS_API_KEY` is configured on server, `auth.apiKey` must match. +- If `WS_REQUIRE_AUTH=true`, either `auth.apiKey` or `auth.jwt` must be present. + +### `session.start` + +```json +{ + "type": "session.start", + "audio": { + "encoding": "pcm_s16le", + "sample_rate_hz": 16000, + "channels": 1 + }, + "metadata": { + "client": "web-debug", + "systemPrompt": "You are concise.", + "greeting": "Hi, how can I help?", + "services": { + "llm": { + "provider": "openai", + "model": "gpt-4o-mini", + "apiKey": "sk-...", + "baseUrl": "https://api.openai.com/v1" + }, + "asr": { + "provider": "siliconflow", + "model": "FunAudioLLM/SenseVoiceSmall", + "apiKey": "sf-...", + "interimIntervalMs": 500, + "minAudioMs": 300 + }, + "tts": { + "provider": "siliconflow", + "model": "FunAudioLLM/CosyVoice2-0.5B", + "apiKey": "sf-...", + "voice": "anna", + "speed": 1.0 + } + } + } +} +``` + +`metadata.services` is optional. If omitted, server defaults to environment configuration. + +### `input.text` + +```json +{ + "type": "input.text", + "text": "What can you do?" 
+} +``` + +### `response.cancel` + +```json +{ + "type": "response.cancel", + "graceful": false +} +``` + +### `session.stop` + +```json +{ + "type": "session.stop", + "reason": "client_disconnect" +} +``` + +## Server -> Client Events + +All server events include: + +```json +{ + "type": "event.name", + "timestamp": 1730000000000 +} +``` + +Common events: + +- `hello.ack` + - Fields: `sessionId`, `version` +- `session.started` + - Fields: `sessionId`, `trackId`, `audio` +- `session.stopped` + - Fields: `sessionId`, `reason` +- `heartbeat` +- `input.speech_started` + - Fields: `trackId`, `probability` +- `input.speech_stopped` + - Fields: `trackId`, `probability` +- `transcript.delta` + - Fields: `trackId`, `text` +- `transcript.final` + - Fields: `trackId`, `text` +- `assistant.response.delta` + - Fields: `trackId`, `text` +- `assistant.response.final` + - Fields: `trackId`, `text` +- `output.audio.start` + - Fields: `trackId` +- `output.audio.end` + - Fields: `trackId` +- `response.interrupted` + - Fields: `trackId` +- `metrics.ttfb` + - Fields: `trackId`, `latencyMs` +- `error` + - Fields: `sender`, `code`, `message`, `trackId` + +## Binary Audio Frames + +After `session.started`, client may send binary PCM chunks continuously. + +Recommended format: +- 16-bit signed little-endian PCM. +- 1 channel. +- 16000 Hz. +- 20ms frames (640 bytes) preferred. + +## Compatibility + +This endpoint now enforces v1 message schema for JSON control frames. +Legacy command names (`invite`, `chat`, etc.) are no longer part of the public protocol. diff --git a/engine/examples/test_websocket.py b/engine/examples/test_websocket.py index 20d388d..0d2675d 100644 --- a/engine/examples/test_websocket.py +++ b/engine/examples/test_websocket.py @@ -36,7 +36,7 @@ def generate_sine_wave(duration_ms=1000): return audio_data -async def receive_loop(ws): +async def receive_loop(ws, ready_event: asyncio.Event): """Listen for incoming messages from the server.""" print("👂 Listening for server responses...") async for msg in ws: @@ -45,8 +45,10 @@ async def receive_loop(ws): if msg.type == aiohttp.WSMsgType.TEXT: try: data = json.loads(msg.data) - event_type = data.get('event', 'Unknown') + event_type = data.get('type', 'Unknown') print(f"[{timestamp}] 📨 Event: {event_type} | {msg.data[:150]}...") + if event_type == "session.started": + ready_event.set() except json.JSONDecodeError: print(f"[{timestamp}] 📨 Text: {msg.data[:100]}...") @@ -118,35 +120,43 @@ async def run_client(url, file_path=None, use_sine=False): print(f"🔌 Connecting to {url}...") async with session.ws_connect(url) as ws: print("✅ Connected!") + session_ready = asyncio.Event() + recv_task = asyncio.create_task(receive_loop(ws, session_ready)) - # Send initial invite command - init_cmd = { - "command": "invite", - "option": { - "codec": "pcm", - "samplerate": SAMPLE_RATE + # Send v1 hello + session.start handshake + await ws.send_json({"type": "hello", "version": "v1"}) + await ws.send_json({ + "type": "session.start", + "audio": { + "encoding": "pcm_s16le", + "sample_rate_hz": SAMPLE_RATE, + "channels": 1 } - } - await ws.send_json(init_cmd) - print("📤 Sent Invite Command") + }) + print("📤 Sent v1 hello/session.start") + await asyncio.wait_for(session_ready.wait(), timeout=8) # Select sender based on args if use_sine: - sender_task = send_sine_loop(ws) + await send_sine_loop(ws) elif file_path: - sender_task = send_file_loop(ws, file_path) + await send_file_loop(ws, file_path) else: # Default to sine wave - sender_task = send_sine_loop(ws) + await 
send_sine_loop(ws) - # Run send and receive loops in parallel - await asyncio.gather( - receive_loop(ws), - sender_task - ) + await ws.send_json({"type": "session.stop", "reason": "test_complete"}) + await asyncio.sleep(1) + recv_task.cancel() + try: + await recv_task + except asyncio.CancelledError: + pass except aiohttp.ClientConnectorError: print(f"❌ Connection Failed. Is the server running at {url}?") + except asyncio.TimeoutError: + print("❌ Timeout waiting for session.started") except Exception as e: print(f"❌ Error: {e}") finally: diff --git a/engine/examples/web_client.html b/engine/examples/web_client.html index bee3d28..c259657 100644 --- a/engine/examples/web_client.html +++ b/engine/examples/web_client.html @@ -547,7 +547,7 @@ setStatus(true, "Session open"); logLine("sys", "WebSocket connected"); ensureAudioContext(); - sendCommand({ command: "invite", option: { codec: "pcm", sampleRate: targetSampleRate } }); + sendCommand({ type: "hello", version: "v1" }); }; ws.onclose = () => { @@ -574,7 +574,10 @@ } function disconnect() { - if (ws) ws.close(); + if (ws && ws.readyState === WebSocket.OPEN) { + sendCommand({ type: "session.stop", reason: "client_disconnect" }); + ws.close(); + } ws = null; setStatus(false, "Disconnected"); } @@ -585,40 +588,48 @@ return; } ws.send(JSON.stringify(cmd)); - logLine("sys", `→ ${cmd.command}`, cmd); + logLine("sys", `→ ${cmd.type}`, cmd); } function handleEvent(event) { - const type = event.event || "unknown"; + const type = event.type || "unknown"; logLine("event", type, event); - if (type === "transcript") { - if (event.isFinal && event.text) { + if (type === "hello.ack") { + sendCommand({ + type: "session.start", + audio: { encoding: "pcm_s16le", sample_rate_hz: targetSampleRate, channels: 1 }, + }); + } + if (type === "transcript.final") { + if (event.text) { setInterim("You", ""); addChat("You", event.text); - } else if (event.text) { - interimUserText += event.text; - setInterim("You", interimUserText); } } - if (type === "llmResponse") { - if (event.isFinal && event.text) { + if (type === "transcript.delta" && event.text) { + setInterim("You", event.text); + } + if (type === "assistant.response.final") { + if (event.text) { setInterim("AI", ""); addChat("AI", event.text); - } else if (event.text) { - interimAiText += event.text; - setInterim("AI", interimAiText); } } - if (type === "trackStart") { + if (type === "assistant.response.delta" && event.text) { + interimAiText += event.text; + setInterim("AI", interimAiText); + } + if (type === "output.audio.start") { // New bot audio: stop any previous playback to avoid overlap stopPlayback(); discardAudio = false; + interimAiText = ""; } - if (type === "speaking") { + if (type === "input.speech_started") { // User started speaking: clear any in-flight audio to avoid overlap stopPlayback(); } - if (type === "interrupt") { + if (type === "response.interrupted") { stopPlayback(); } } @@ -716,7 +727,7 @@ if (!text) return; ensureAudioContext(); addChat("You", text); - sendCommand({ command: "chat", text }); + sendCommand({ type: "input.text", text }); chatInput.value = ""; }); clearLogBtn.addEventListener("click", () => { diff --git a/engine/models/ws_v1.py b/engine/models/ws_v1.py new file mode 100644 index 0000000..7c51a90 --- /dev/null +++ b/engine/models/ws_v1.py @@ -0,0 +1,67 @@ +"""WS v1 protocol message models and helpers.""" + +from typing import Optional, Dict, Any, Literal +from pydantic import BaseModel, Field + + +def now_ms() -> int: + """Current unix timestamp in milliseconds.""" + 
import time + + return int(time.time() * 1000) + + +# Client -> Server messages +class HelloMessage(BaseModel): + type: Literal["hello"] + version: str = Field(..., description="Protocol version, currently v1") + auth: Optional[Dict[str, str]] = Field(default=None, description="Auth payload, e.g. {'apiKey': '...'}") + + +class SessionStartMessage(BaseModel): + type: Literal["session.start"] + audio: Optional[Dict[str, Any]] = Field(default=None, description="Optional audio format metadata") + metadata: Optional[Dict[str, Any]] = Field(default=None, description="Optional session metadata") + + +class SessionStopMessage(BaseModel): + type: Literal["session.stop"] + reason: Optional[str] = None + + +class InputTextMessage(BaseModel): + type: Literal["input.text"] + text: str + + +class ResponseCancelMessage(BaseModel): + type: Literal["response.cancel"] + graceful: bool = False + + +CLIENT_MESSAGE_TYPES = { + "hello": HelloMessage, + "session.start": SessionStartMessage, + "session.stop": SessionStopMessage, + "input.text": InputTextMessage, + "response.cancel": ResponseCancelMessage, +} + + +def parse_client_message(data: Dict[str, Any]) -> BaseModel: + """Parse and validate a WS v1 client message.""" + msg_type = data.get("type") + if not msg_type: + raise ValueError("Missing 'type' field") + msg_class = CLIENT_MESSAGE_TYPES.get(msg_type) + if not msg_class: + raise ValueError(f"Unknown client message type: {msg_type}") + return msg_class(**data) + + +# Server -> Client event helpers +def ev(event_type: str, **payload: Any) -> Dict[str, Any]: + """Create a WS v1 server event payload.""" + base = {"type": event_type, "timestamp": now_ms()} + base.update(payload) + return base diff --git a/web/pages/Assistants.tsx b/web/pages/Assistants.tsx index caed9fa..c11aa44 100644 --- a/web/pages/Assistants.tsx +++ b/web/pages/Assistants.tsx @@ -4,8 +4,7 @@ import { Plus, Search, Play, Copy, Trash2, Edit2, Mic, MessageSquare, Save, Vide import { Button, Input, Card, Badge, Drawer, Dialog } from '../components/UI'; import { mockLLMModels, mockASRModels } from '../services/mockData'; import { Assistant, KnowledgeBase, TabValue, Voice } from '../types'; -import { GoogleGenAI } from "@google/genai"; -import { createAssistant, deleteAssistant, fetchAssistants, fetchKnowledgeBases, fetchVoices, updateAssistant as updateAssistantApi } from '../services/backendApi'; +import { createAssistant, deleteAssistant, fetchAssistantRuntimeConfig, fetchAssistants, fetchKnowledgeBases, fetchVoices, updateAssistant as updateAssistantApi } from '../services/backendApi'; interface ToolItem { id: string; @@ -1022,11 +1021,26 @@ export const DebugDrawer: React.FC<{ isOpen: boolean; onClose: () => void; assis const [inputText, setInputText] = useState(''); const [isLoading, setIsLoading] = useState(false); const [callStatus, setCallStatus] = useState<'idle' | 'calling' | 'active'>('idle'); + const [wsStatus, setWsStatus] = useState<'disconnected' | 'connecting' | 'ready' | 'error'>('disconnected'); + const [wsError, setWsError] = useState(''); + const [resolvedConfigOpen, setResolvedConfigOpen] = useState(false); + const [resolvedConfigView, setResolvedConfigView] = useState(''); + const [wsUrl, setWsUrl] = useState(() => { + const fromStorage = localStorage.getItem('debug_ws_url'); + if (fromStorage) return fromStorage; + const defaultHost = window.location.hostname || 'localhost'; + return `ws://${defaultHost}:8000/ws`; + }); // Media State const videoRef = useRef(null); const streamRef = useRef(null); const scrollRef = 
useRef(null);
+  const wsRef = useRef<WebSocket | null>(null);
+  const wsReadyRef = useRef(false);
+  const pendingResolveRef = useRef<(() => void) | null>(null);
+  const pendingRejectRef = useRef<((e: Error) => void) | null>(null);
+  const assistantDraftIndexRef = useRef<number | null>(null);
 
   const [devices, setDevices] = useState([]);
   const [selectedCamera, setSelectedCamera] = useState('');
@@ -1045,11 +1059,16 @@
     } else {
       setMode('text');
       stopMedia();
+      closeWs();
       setIsSwapped(false);
       setCallStatus('idle');
     }
   }, [isOpen, assistant, mode]);
 
+  useEffect(() => {
+    localStorage.setItem('debug_ws_url', wsUrl);
+  }, [wsUrl]);
+
   // Auto-scroll logic
   useEffect(() => {
     if (scrollRef.current) {
@@ -1132,29 +1151,196 @@
     setIsLoading(true);
 
     try {
-      if (process.env.API_KEY) {
-        const ai = new GoogleGenAI({ apiKey: process.env.API_KEY });
-        const chat = ai.chats.create({
-          model: "gemini-3-flash-preview",
-          config: { systemInstruction: assistant.prompt },
-          history: messages.map(m => ({ role: m.role, parts: [{ text: m.text }] }))
-        });
-        const result = await chat.sendMessage({ message: userMsg });
-        setMessages(prev => [...prev, { role: 'model', text: result.text || '' }]);
+      if (mode === 'text') {
+        await ensureWsSession();
+        wsRef.current?.send(JSON.stringify({ type: 'input.text', text: userMsg }));
       } else {
         setTimeout(() => {
-          setMessages(prev => [...prev, { role: 'model', text: `[Mock Response]: Received "${userMsg}"` }]);
-          setIsLoading(false);
+          setMessages(prev => [...prev, { role: 'model', text: `[Mock Response]: Received "${userMsg}"` }]);
+          setIsLoading(false);
         }, 1000);
       }
     } catch (e) {
       console.error(e);
       setMessages(prev => [...prev, { role: 'model', text: "Error: Failed to connect to AI service."
}]);
-    } finally {
-      setIsLoading(false);
+    } finally {
+      if (mode !== 'text') setIsLoading(false);
     }
   };
 
+  const fetchRuntimeMetadata = async (): Promise<Record<string, any>> => {
+    try {
+      const resolved = await fetchAssistantRuntimeConfig(assistant.id);
+      setResolvedConfigView(
+        JSON.stringify(
+          {
+            assistantId: resolved.assistantId,
+            sources: resolved.sources,
+            warnings: resolved.warnings,
+            sessionStartMetadata: resolved.sessionStartMetadata || {},
+          },
+          null,
+          2,
+        ),
+      );
+      return resolved.sessionStartMetadata || {};
+    } catch (error) {
+      console.error('Failed to load runtime config, using fallback.', error);
+      const fallback = {
+        systemPrompt: assistant.prompt || '',
+        greeting: assistant.opener || '',
+        services: {},
+      };
+      setResolvedConfigView(
+        JSON.stringify(
+          {
+            assistantId: assistant.id,
+            sources: {},
+            warnings: ['runtime-config endpoint failed; using local fallback'],
+            sessionStartMetadata: fallback,
+          },
+          null,
+          2,
+        ),
+      );
+      return fallback;
+    }
+  };
+
+  const closeWs = () => {
+    if (wsRef.current && wsRef.current.readyState === WebSocket.OPEN) {
+      wsRef.current.send(JSON.stringify({ type: 'session.stop', reason: 'debug_drawer_closed' }));
+    }
+    wsRef.current?.close();
+    wsRef.current = null;
+    wsReadyRef.current = false;
+    pendingResolveRef.current = null;
+    pendingRejectRef.current = null;
+    assistantDraftIndexRef.current = null;
+    if (isOpen) setWsStatus('disconnected');
+  };
+
+  const ensureWsSession = async () => {
+    if (wsRef.current && wsReadyRef.current && wsRef.current.readyState === WebSocket.OPEN) {
+      return;
+    }
+
+    if (wsRef.current && wsRef.current.readyState === WebSocket.CONNECTING) {
+      await new Promise<void>((resolve, reject) => {
+        pendingResolveRef.current = resolve;
+        pendingRejectRef.current = reject;
+      });
+      return;
+    }
+
+    const metadata = await fetchRuntimeMetadata();
+    setWsStatus('connecting');
+    setWsError('');
+
+    await new Promise<void>((resolve, reject) => {
+      pendingResolveRef.current = resolve;
+      pendingRejectRef.current = reject;
+      const ws = new WebSocket(wsUrl);
+      ws.binaryType = 'arraybuffer';
+      wsRef.current = ws;
+
+      ws.onopen = () => {
+        ws.send(JSON.stringify({ type: 'hello', version: 'v1' }));
+      };
+
+      ws.onmessage = (event) => {
+        if (typeof event.data !== 'string') return;
+        let payload: any;
+        try {
+          payload = JSON.parse(event.data);
+        } catch {
+          return;
+        }
+
+        const type = payload?.type;
+        if (type === 'hello.ack') {
+          ws.send(
+            JSON.stringify({
+              type: 'session.start',
+              audio: { encoding: 'pcm_s16le', sample_rate_hz: 16000, channels: 1 },
+              metadata,
+            })
+          );
+          return;
+        }
+
+        if (type === 'session.started') {
+          wsReadyRef.current = true;
+          setWsStatus('ready');
+          pendingResolveRef.current?.();
+          pendingResolveRef.current = null;
+          pendingRejectRef.current = null;
+          return;
+        }
+
+        if (type === 'assistant.response.delta') {
+          const delta = String(payload.text || '');
+          if (!delta) return;
+          setMessages((prev) => {
+            const idx = assistantDraftIndexRef.current;
+            if (idx === null || !prev[idx] || prev[idx].role !== 'model') {
+              const next = [...prev, { role: 'model' as const, text: delta }];
+              assistantDraftIndexRef.current = next.length - 1;
+              return next;
+            }
+            const next = [...prev];
+            next[idx] = { ...next[idx], text: next[idx].text + delta };
+            return next;
+          });
+          return;
+        }
+
+        if (type === 'assistant.response.final') {
+          const finalText = String(payload.text || '');
+          setMessages((prev) => {
+            const idx = assistantDraftIndexRef.current;
+            assistantDraftIndexRef.current = null;
+            if (idx !== null && prev[idx] && prev[idx].role === 
'model') { + const next = [...prev]; + next[idx] = { ...next[idx], text: finalText || next[idx].text }; + return next; + } + return finalText ? [...prev, { role: 'model', text: finalText }] : prev; + }); + setIsLoading(false); + return; + } + + if (type === 'error') { + const message = String(payload.message || 'Unknown error'); + setWsStatus('error'); + setWsError(message); + setIsLoading(false); + const err = new Error(message); + pendingRejectRef.current?.(err); + pendingResolveRef.current = null; + pendingRejectRef.current = null; + } + }; + + ws.onerror = () => { + const err = new Error('WebSocket connection error'); + setWsStatus('error'); + setWsError(err.message); + setIsLoading(false); + pendingRejectRef.current?.(err); + pendingResolveRef.current = null; + pendingRejectRef.current = null; + }; + + ws.onclose = () => { + wsReadyRef.current = false; + if (wsStatus !== 'error') setWsStatus('disconnected'); + }; + }); + }; + const TranscriptionLog = () => (
+    <div ref={scrollRef}>
+      {messages.length === 0 && <div>暂无转写记录</div>}
+    </div>
+  );
@@ -1205,7 +1391,38 @@ export const DebugDrawer: React.FC<{ isOpen: boolean; onClose: () => void; assis
             {mode === 'text' ? (
-
+              <div>
+                <div>
+                  <Input value={wsUrl} onChange={(e) => setWsUrl(e.target.value)} placeholder="ws://localhost:8000/ws" />
+                  <span>WS: {wsStatus}</span>
+                  {wsError && <span>{wsError}</span>}
+                </div>
+                <Button onClick={() => setResolvedConfigOpen(!resolvedConfigOpen)}>Resolved Config</Button>
+                {resolvedConfigOpen && (
+                  <pre>
+                    {resolvedConfigView || 'Connect to load resolved config...'}
+                  </pre>
+                )}
+              </div>
diff --git a/web/services/backendApi.ts b/web/services/backendApi.ts index 82625d5..e6498c8 100644 --- a/web/services/backendApi.ts +++ b/web/services/backendApi.ts @@ -41,6 +41,10 @@ const mapAssistant = (raw: AnyRecord): Assistant => ({ configMode: readField(raw, ['configMode', 'config_mode'], 'platform') as 'platform' | 'dify' | 'fastgpt' | 'none', apiUrl: readField(raw, ['apiUrl', 'api_url'], ''), apiKey: readField(raw, ['apiKey', 'api_key'], ''), + llmModelId: readField(raw, ['llmModelId', 'llm_model_id'], ''), + asrModelId: readField(raw, ['asrModelId', 'asr_model_id'], ''), + embeddingModelId: readField(raw, ['embeddingModelId', 'embedding_model_id'], ''), + rerankModelId: readField(raw, ['rerankModelId', 'rerank_model_id'], ''), }); const mapVoice = (raw: AnyRecord): Voice => ({ @@ -218,6 +222,10 @@ export const updateAssistant = async (id: string, data: Partial): Pro configMode: data.configMode, apiUrl: data.apiUrl, apiKey: data.apiKey, + llmModelId: data.llmModelId, + asrModelId: data.asrModelId, + embeddingModelId: data.embeddingModelId, + rerankModelId: data.rerankModelId, }; const response = await apiRequest(`/assistants/${id}`, { method: 'PUT', body: payload }); return mapAssistant(response); @@ -227,6 +235,21 @@ export const deleteAssistant = async (id: string): Promise => { await apiRequest(`/assistants/${id}`, { method: 'DELETE' }); }; +export interface AssistantRuntimeConfigResponse { + assistantId: string; + sessionStartMetadata: Record; + sources?: { + llmModelId?: string; + asrModelId?: string; + voiceId?: string; + }; + warnings?: string[]; +} + +export const fetchAssistantRuntimeConfig = async (assistantId: string): Promise => { + return apiRequest(`/assistants/${assistantId}/runtime-config`); +}; + export const fetchVoices = async (): Promise => { const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>('/voices'); const list = Array.isArray(response) ? response : (response.list || []);