Compare commits
2 Commits
479cfb797b
...
c15c5283e2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c15c5283e2 | ||
|
|
fb6d1eb1da |
@@ -5,7 +5,7 @@ import uuid
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from ..db import get_db
|
from ..db import get_db
|
||||||
from ..models import Assistant
|
from ..models import Assistant, LLMModel, ASRModel, Voice
|
||||||
from ..schemas import (
|
from ..schemas import (
|
||||||
AssistantCreate, AssistantUpdate, AssistantOut
|
AssistantCreate, AssistantUpdate, AssistantOut
|
||||||
)
|
)
|
||||||
@@ -13,6 +13,73 @@ from ..schemas import (
|
|||||||
router = APIRouter(prefix="/assistants", tags=["Assistants"])
|
router = APIRouter(prefix="/assistants", tags=["Assistants"])
|
||||||
|
|
||||||
|
|
||||||
|
def _is_siliconflow_vendor(vendor: Optional[str]) -> bool:
|
||||||
|
return (vendor or "").strip().lower() in {"siliconflow", "硅基流动"}
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> dict:
|
||||||
|
metadata = {
|
||||||
|
"systemPrompt": assistant.prompt or "",
|
||||||
|
"greeting": assistant.opener or "",
|
||||||
|
"services": {},
|
||||||
|
}
|
||||||
|
warnings = []
|
||||||
|
|
||||||
|
if assistant.llm_model_id:
|
||||||
|
llm = db.query(LLMModel).filter(LLMModel.id == assistant.llm_model_id).first()
|
||||||
|
if llm:
|
||||||
|
metadata["services"]["llm"] = {
|
||||||
|
"provider": "openai",
|
||||||
|
"model": llm.model_name or llm.name,
|
||||||
|
"apiKey": llm.api_key,
|
||||||
|
"baseUrl": llm.base_url,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
warnings.append(f"LLM model not found: {assistant.llm_model_id}")
|
||||||
|
|
||||||
|
if assistant.asr_model_id:
|
||||||
|
asr = db.query(ASRModel).filter(ASRModel.id == assistant.asr_model_id).first()
|
||||||
|
if asr:
|
||||||
|
asr_provider = "siliconflow" if _is_siliconflow_vendor(asr.vendor) else "buffered"
|
||||||
|
metadata["services"]["asr"] = {
|
||||||
|
"provider": asr_provider,
|
||||||
|
"model": asr.model_name or asr.name,
|
||||||
|
"apiKey": asr.api_key if asr_provider == "siliconflow" else None,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
warnings.append(f"ASR model not found: {assistant.asr_model_id}")
|
||||||
|
|
||||||
|
if assistant.voice:
|
||||||
|
voice = db.query(Voice).filter(Voice.id == assistant.voice).first()
|
||||||
|
if voice:
|
||||||
|
tts_provider = "siliconflow" if _is_siliconflow_vendor(voice.vendor) else "edge"
|
||||||
|
metadata["services"]["tts"] = {
|
||||||
|
"provider": tts_provider,
|
||||||
|
"model": voice.model,
|
||||||
|
"apiKey": voice.api_key if tts_provider == "siliconflow" else None,
|
||||||
|
"voice": voice.voice_key or voice.id,
|
||||||
|
"speed": assistant.speed or voice.speed,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# Keep assistant.voice as direct voice identifier fallback
|
||||||
|
metadata["services"]["tts"] = {
|
||||||
|
"voice": assistant.voice,
|
||||||
|
"speed": assistant.speed or 1.0,
|
||||||
|
}
|
||||||
|
warnings.append(f"Voice resource not found: {assistant.voice}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"assistantId": assistant.id,
|
||||||
|
"sessionStartMetadata": metadata,
|
||||||
|
"sources": {
|
||||||
|
"llmModelId": assistant.llm_model_id,
|
||||||
|
"asrModelId": assistant.asr_model_id,
|
||||||
|
"voiceId": assistant.voice,
|
||||||
|
},
|
||||||
|
"warnings": warnings,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def assistant_to_dict(assistant: Assistant) -> dict:
|
def assistant_to_dict(assistant: Assistant) -> dict:
|
||||||
return {
|
return {
|
||||||
"id": assistant.id,
|
"id": assistant.id,
|
||||||
@@ -84,6 +151,15 @@ def get_assistant(id: str, db: Session = Depends(get_db)):
|
|||||||
return assistant_to_dict(assistant)
|
return assistant_to_dict(assistant)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{id}/runtime-config")
|
||||||
|
def get_assistant_runtime_config(id: str, db: Session = Depends(get_db)):
|
||||||
|
"""Resolve assistant runtime config for engine WS session.start metadata."""
|
||||||
|
assistant = db.query(Assistant).filter(Assistant.id == id).first()
|
||||||
|
if not assistant:
|
||||||
|
raise HTTPException(status_code=404, detail="Assistant not found")
|
||||||
|
return _resolve_runtime_metadata(db, assistant)
|
||||||
|
|
||||||
|
|
||||||
@router.post("", response_model=AssistantOut)
|
@router.post("", response_model=AssistantOut)
|
||||||
def create_assistant(data: AssistantCreate, db: Session = Depends(get_db)):
|
def create_assistant(data: AssistantCreate, db: Session = Depends(get_db)):
|
||||||
"""创建新助手"""
|
"""创建新助手"""
|
||||||
@@ -139,4 +215,3 @@ def delete_assistant(id: str, db: Session = Depends(get_db)):
|
|||||||
db.delete(assistant)
|
db.delete(assistant)
|
||||||
db.commit()
|
db.commit()
|
||||||
return {"message": "Deleted successfully"}
|
return {"message": "Deleted successfully"}
|
||||||
|
|
||||||
|
|||||||
@@ -166,3 +166,39 @@ class TestAssistantAPI:
|
|||||||
response = client.post("/api/assistants", json=sample_assistant_data)
|
response = client.post("/api/assistants", json=sample_assistant_data)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json()["language"] == lang
|
assert response.json()["language"] == lang
|
||||||
|
|
||||||
|
def test_get_runtime_config(self, client, sample_assistant_data, sample_llm_model_data, sample_asr_model_data, sample_voice_data):
|
||||||
|
"""Test resolved runtime config endpoint for WS session.start metadata."""
|
||||||
|
llm_resp = client.post("/api/llm", json=sample_llm_model_data)
|
||||||
|
assert llm_resp.status_code == 200
|
||||||
|
|
||||||
|
asr_resp = client.post("/api/asr", json=sample_asr_model_data)
|
||||||
|
assert asr_resp.status_code == 200
|
||||||
|
|
||||||
|
voice_resp = client.post("/api/voices", json=sample_voice_data)
|
||||||
|
assert voice_resp.status_code == 200
|
||||||
|
voice_id = voice_resp.json()["id"]
|
||||||
|
|
||||||
|
sample_assistant_data.update({
|
||||||
|
"llmModelId": sample_llm_model_data["id"],
|
||||||
|
"asrModelId": sample_asr_model_data["id"],
|
||||||
|
"voice": voice_id,
|
||||||
|
"prompt": "runtime prompt",
|
||||||
|
"opener": "runtime opener",
|
||||||
|
"speed": 1.1,
|
||||||
|
})
|
||||||
|
assistant_resp = client.post("/api/assistants", json=sample_assistant_data)
|
||||||
|
assert assistant_resp.status_code == 200
|
||||||
|
assistant_id = assistant_resp.json()["id"]
|
||||||
|
|
||||||
|
runtime_resp = client.get(f"/api/assistants/{assistant_id}/runtime-config")
|
||||||
|
assert runtime_resp.status_code == 200
|
||||||
|
payload = runtime_resp.json()
|
||||||
|
|
||||||
|
assert payload["assistantId"] == assistant_id
|
||||||
|
metadata = payload["sessionStartMetadata"]
|
||||||
|
assert metadata["systemPrompt"] == "runtime prompt"
|
||||||
|
assert metadata["greeting"] == "runtime opener"
|
||||||
|
assert metadata["services"]["llm"]["model"] == sample_llm_model_data["model_name"]
|
||||||
|
assert metadata["services"]["asr"]["model"] == sample_asr_model_data["model_name"]
|
||||||
|
assert metadata["services"]["tts"]["voice"] == sample_voice_data["voice_key"]
|
||||||
|
|||||||
@@ -23,3 +23,9 @@ python examples/test_websocket.py
|
|||||||
```
|
```
|
||||||
python mic_client.py
|
python mic_client.py
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## WS Protocol
|
||||||
|
|
||||||
|
`/ws` uses a strict `v1` JSON control protocol with binary PCM audio frames.
|
||||||
|
|
||||||
|
See `engine/docs/ws_v1_schema.md` (path relative to the repository root).
|
||||||
|
|||||||
@@ -87,6 +87,9 @@ class Settings(BaseSettings):
|
|||||||
# WebSocket heartbeat and inactivity
|
# WebSocket heartbeat and inactivity
|
||||||
inactivity_timeout_sec: int = Field(default=60, description="Close connection after no message from client (seconds)")
|
inactivity_timeout_sec: int = Field(default=60, description="Close connection after no message from client (seconds)")
|
||||||
heartbeat_interval_sec: int = Field(default=50, description="Send heartBeat event to client every N seconds")
|
heartbeat_interval_sec: int = Field(default=50, description="Send heartBeat event to client every N seconds")
|
||||||
|
ws_protocol_version: str = Field(default="v1", description="Public WS protocol version")
|
||||||
|
ws_api_key: Optional[str] = Field(default=None, description="Optional API key required for WS hello auth")
|
||||||
|
ws_require_auth: bool = Field(default=False, description="Require auth in hello message even when ws_api_key is not set")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def chunk_size_bytes(self) -> int:
|
def chunk_size_bytes(self) -> int:
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from core.transports import SocketTransport, WebRtcTransport, BaseTransport
|
|||||||
from core.session import Session
|
from core.session import Session
|
||||||
from processors.tracks import Resampled16kTrack
|
from processors.tracks import Resampled16kTrack
|
||||||
from core.events import get_event_bus, reset_event_bus
|
from core.events import get_event_bus, reset_event_bus
|
||||||
|
from models.ws_v1 import ev
|
||||||
|
|
||||||
# Check interval for heartbeat/timeout (seconds)
|
# Check interval for heartbeat/timeout (seconds)
|
||||||
_HEARTBEAT_CHECK_INTERVAL_SEC = 5
|
_HEARTBEAT_CHECK_INTERVAL_SEC = 5
|
||||||
@@ -54,8 +55,7 @@ async def heartbeat_and_timeout_task(
|
|||||||
if now - last_heartbeat_at[0] >= heartbeat_interval_sec:
|
if now - last_heartbeat_at[0] >= heartbeat_interval_sec:
|
||||||
try:
|
try:
|
||||||
await transport.send_event({
|
await transport.send_event({
|
||||||
"event": "heartBeat",
|
**ev("heartbeat"),
|
||||||
"timestamp": int(time.time() * 1000),
|
|
||||||
})
|
})
|
||||||
last_heartbeat_at[0] = now
|
last_heartbeat_at[0] = now
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ event-driven design.
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
from typing import Optional, Callable, Awaitable
|
from typing import Optional, Callable, Awaitable, Dict, Any
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from core.transports import BaseTransport
|
from core.transports import BaseTransport
|
||||||
@@ -28,6 +28,7 @@ from services.asr import BufferedASRService
|
|||||||
from services.siliconflow_tts import SiliconFlowTTSService
|
from services.siliconflow_tts import SiliconFlowTTSService
|
||||||
from services.siliconflow_asr import SiliconFlowASRService
|
from services.siliconflow_asr import SiliconFlowASRService
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
|
from models.ws_v1 import ev
|
||||||
|
|
||||||
|
|
||||||
class DuplexPipeline:
|
class DuplexPipeline:
|
||||||
@@ -127,18 +128,65 @@ class DuplexPipeline:
|
|||||||
self._barge_in_silence_frames: int = 0 # Count silence frames during potential barge-in
|
self._barge_in_silence_frames: int = 0 # Count silence frames during potential barge-in
|
||||||
self._barge_in_silence_tolerance: int = 3 # Allow up to 3 silence frames (60ms at 20ms chunks)
|
self._barge_in_silence_tolerance: int = 3 # Allow up to 3 silence frames (60ms at 20ms chunks)
|
||||||
|
|
||||||
|
# Runtime overrides injected from session.start metadata
|
||||||
|
self._runtime_llm: Dict[str, Any] = {}
|
||||||
|
self._runtime_asr: Dict[str, Any] = {}
|
||||||
|
self._runtime_tts: Dict[str, Any] = {}
|
||||||
|
self._runtime_system_prompt: Optional[str] = None
|
||||||
|
self._runtime_greeting: Optional[str] = None
|
||||||
|
|
||||||
logger.info(f"DuplexPipeline initialized for session {session_id}")
|
logger.info(f"DuplexPipeline initialized for session {session_id}")
|
||||||
|
|
||||||
|
def apply_runtime_overrides(self, metadata: Optional[Dict[str, Any]]) -> None:
|
||||||
|
"""
|
||||||
|
Apply runtime overrides from WS session.start metadata.
|
||||||
|
|
||||||
|
Expected metadata shape:
|
||||||
|
{
|
||||||
|
"systemPrompt": "...",
|
||||||
|
"greeting": "...",
|
||||||
|
"services": {
|
||||||
|
"llm": {...},
|
||||||
|
"asr": {...},
|
||||||
|
"tts": {...}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
if not metadata:
|
||||||
|
return
|
||||||
|
|
||||||
|
if "systemPrompt" in metadata:
|
||||||
|
self._runtime_system_prompt = str(metadata.get("systemPrompt") or "")
|
||||||
|
if self._runtime_system_prompt:
|
||||||
|
self.conversation.system_prompt = self._runtime_system_prompt
|
||||||
|
if "greeting" in metadata:
|
||||||
|
self._runtime_greeting = str(metadata.get("greeting") or "")
|
||||||
|
self.conversation.greeting = self._runtime_greeting or None
|
||||||
|
|
||||||
|
services = metadata.get("services") or {}
|
||||||
|
if isinstance(services, dict):
|
||||||
|
if isinstance(services.get("llm"), dict):
|
||||||
|
self._runtime_llm = services["llm"]
|
||||||
|
if isinstance(services.get("asr"), dict):
|
||||||
|
self._runtime_asr = services["asr"]
|
||||||
|
if isinstance(services.get("tts"), dict):
|
||||||
|
self._runtime_tts = services["tts"]
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
"""Start the pipeline and connect services."""
|
"""Start the pipeline and connect services."""
|
||||||
try:
|
try:
|
||||||
# Connect LLM service
|
# Connect LLM service
|
||||||
if not self.llm_service:
|
if not self.llm_service:
|
||||||
if settings.openai_api_key:
|
llm_api_key = self._runtime_llm.get("apiKey") or settings.openai_api_key
|
||||||
|
llm_base_url = self._runtime_llm.get("baseUrl") or settings.openai_api_url
|
||||||
|
llm_model = self._runtime_llm.get("model") or settings.llm_model
|
||||||
|
llm_provider = (self._runtime_llm.get("provider") or "openai").lower()
|
||||||
|
|
||||||
|
if llm_provider == "openai" and llm_api_key:
|
||||||
self.llm_service = OpenAILLMService(
|
self.llm_service = OpenAILLMService(
|
||||||
api_key=settings.openai_api_key,
|
api_key=llm_api_key,
|
||||||
base_url=settings.openai_api_url,
|
base_url=llm_base_url,
|
||||||
model=settings.llm_model
|
model=llm_model
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning("No OpenAI API key - using mock LLM")
|
logger.warning("No OpenAI API key - using mock LLM")
|
||||||
@@ -148,33 +196,52 @@ class DuplexPipeline:
|
|||||||
|
|
||||||
# Connect TTS service
|
# Connect TTS service
|
||||||
if not self.tts_service:
|
if not self.tts_service:
|
||||||
if settings.tts_provider == "siliconflow" and settings.siliconflow_api_key:
|
tts_provider = (self._runtime_tts.get("provider") or settings.tts_provider).lower()
|
||||||
|
tts_api_key = self._runtime_tts.get("apiKey") or settings.siliconflow_api_key
|
||||||
|
tts_voice = self._runtime_tts.get("voice") or settings.tts_voice
|
||||||
|
tts_model = self._runtime_tts.get("model") or settings.siliconflow_tts_model
|
||||||
|
tts_speed = float(self._runtime_tts.get("speed") or settings.tts_speed)
|
||||||
|
|
||||||
|
if tts_provider == "siliconflow" and tts_api_key:
|
||||||
self.tts_service = SiliconFlowTTSService(
|
self.tts_service = SiliconFlowTTSService(
|
||||||
api_key=settings.siliconflow_api_key,
|
api_key=tts_api_key,
|
||||||
voice=settings.tts_voice,
|
voice=tts_voice,
|
||||||
model=settings.siliconflow_tts_model,
|
model=tts_model,
|
||||||
sample_rate=settings.sample_rate,
|
sample_rate=settings.sample_rate,
|
||||||
speed=settings.tts_speed
|
speed=tts_speed
|
||||||
)
|
)
|
||||||
logger.info("Using SiliconFlow TTS service")
|
logger.info("Using SiliconFlow TTS service")
|
||||||
else:
|
else:
|
||||||
self.tts_service = EdgeTTSService(
|
self.tts_service = EdgeTTSService(
|
||||||
voice=settings.tts_voice,
|
voice=tts_voice,
|
||||||
sample_rate=settings.sample_rate
|
sample_rate=settings.sample_rate
|
||||||
)
|
)
|
||||||
logger.info("Using Edge TTS service")
|
logger.info("Using Edge TTS service")
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self.tts_service.connect()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"TTS backend unavailable ({e}); falling back to MockTTS")
|
||||||
|
self.tts_service = MockTTSService(
|
||||||
|
sample_rate=settings.sample_rate
|
||||||
|
)
|
||||||
await self.tts_service.connect()
|
await self.tts_service.connect()
|
||||||
|
|
||||||
# Connect ASR service
|
# Connect ASR service
|
||||||
if not self.asr_service:
|
if not self.asr_service:
|
||||||
if settings.asr_provider == "siliconflow" and settings.siliconflow_api_key:
|
asr_provider = (self._runtime_asr.get("provider") or settings.asr_provider).lower()
|
||||||
|
asr_api_key = self._runtime_asr.get("apiKey") or settings.siliconflow_api_key
|
||||||
|
asr_model = self._runtime_asr.get("model") or settings.siliconflow_asr_model
|
||||||
|
asr_interim_interval = int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms)
|
||||||
|
asr_min_audio_ms = int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms)
|
||||||
|
|
||||||
|
if asr_provider == "siliconflow" and asr_api_key:
|
||||||
self.asr_service = SiliconFlowASRService(
|
self.asr_service = SiliconFlowASRService(
|
||||||
api_key=settings.siliconflow_api_key,
|
api_key=asr_api_key,
|
||||||
model=settings.siliconflow_asr_model,
|
model=asr_model,
|
||||||
sample_rate=settings.sample_rate,
|
sample_rate=settings.sample_rate,
|
||||||
interim_interval_ms=settings.asr_interim_interval_ms,
|
interim_interval_ms=asr_interim_interval,
|
||||||
min_audio_for_interim_ms=settings.asr_min_audio_ms,
|
min_audio_for_interim_ms=asr_min_audio_ms,
|
||||||
on_transcript=self._on_transcript_callback
|
on_transcript=self._on_transcript_callback
|
||||||
)
|
)
|
||||||
logger.info("Using SiliconFlow ASR service")
|
logger.info("Using SiliconFlow ASR service")
|
||||||
@@ -223,6 +290,13 @@ class DuplexPipeline:
|
|||||||
"trackId": self.session_id,
|
"trackId": self.session_id,
|
||||||
"probability": probability
|
"probability": probability
|
||||||
})
|
})
|
||||||
|
await self.transport.send_event(
|
||||||
|
ev(
|
||||||
|
"input.speech_started" if event_type == "speaking" else "input.speech_stopped",
|
||||||
|
trackId=self.session_id,
|
||||||
|
probability=probability,
|
||||||
|
)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# No state change - keep previous status
|
# No state change - keep previous status
|
||||||
vad_status = self._last_vad_status
|
vad_status = self._last_vad_status
|
||||||
@@ -325,11 +399,11 @@ class DuplexPipeline:
|
|||||||
|
|
||||||
# Send transcript event to client
|
# Send transcript event to client
|
||||||
await self.transport.send_event({
|
await self.transport.send_event({
|
||||||
"event": "transcript",
|
**ev(
|
||||||
"trackId": self.session_id,
|
"transcript.final" if is_final else "transcript.delta",
|
||||||
"text": text,
|
trackId=self.session_id,
|
||||||
"isFinal": is_final,
|
text=text,
|
||||||
"timestamp": self._get_timestamp_ms()
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
logger.debug(f"Sent transcript ({'final' if is_final else 'interim'}): {text[:50]}...")
|
logger.debug(f"Sent transcript ({'final' if is_final else 'interim'}): {text[:50]}...")
|
||||||
@@ -383,11 +457,11 @@ class DuplexPipeline:
|
|||||||
|
|
||||||
# Send final transcription to client
|
# Send final transcription to client
|
||||||
await self.transport.send_event({
|
await self.transport.send_event({
|
||||||
"event": "transcript",
|
**ev(
|
||||||
"trackId": self.session_id,
|
"transcript.final",
|
||||||
"text": user_text,
|
trackId=self.session_id,
|
||||||
"isFinal": True,
|
text=user_text,
|
||||||
"timestamp": self._get_timestamp_ms()
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
# Clear buffers
|
# Clear buffers
|
||||||
@@ -438,11 +512,11 @@ class DuplexPipeline:
|
|||||||
|
|
||||||
# Send LLM response streaming event to client
|
# Send LLM response streaming event to client
|
||||||
await self.transport.send_event({
|
await self.transport.send_event({
|
||||||
"event": "llmResponse",
|
**ev(
|
||||||
"trackId": self.session_id,
|
"assistant.response.delta",
|
||||||
"text": text_chunk,
|
trackId=self.session_id,
|
||||||
"isFinal": False,
|
text=text_chunk,
|
||||||
"timestamp": self._get_timestamp_ms()
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
# Check for sentence completion - synthesize immediately for low latency
|
# Check for sentence completion - synthesize immediately for low latency
|
||||||
@@ -462,9 +536,10 @@ class DuplexPipeline:
|
|||||||
# Send track start on first audio
|
# Send track start on first audio
|
||||||
if not first_audio_sent:
|
if not first_audio_sent:
|
||||||
await self.transport.send_event({
|
await self.transport.send_event({
|
||||||
"event": "trackStart",
|
**ev(
|
||||||
"trackId": self.session_id,
|
"output.audio.start",
|
||||||
"timestamp": self._get_timestamp_ms()
|
trackId=self.session_id,
|
||||||
|
)
|
||||||
})
|
})
|
||||||
first_audio_sent = True
|
first_audio_sent = True
|
||||||
|
|
||||||
@@ -476,20 +551,21 @@ class DuplexPipeline:
|
|||||||
# Send final LLM response event
|
# Send final LLM response event
|
||||||
if full_response and not self._interrupt_event.is_set():
|
if full_response and not self._interrupt_event.is_set():
|
||||||
await self.transport.send_event({
|
await self.transport.send_event({
|
||||||
"event": "llmResponse",
|
**ev(
|
||||||
"trackId": self.session_id,
|
"assistant.response.final",
|
||||||
"text": full_response,
|
trackId=self.session_id,
|
||||||
"isFinal": True,
|
text=full_response,
|
||||||
"timestamp": self._get_timestamp_ms()
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
# Speak any remaining text
|
# Speak any remaining text
|
||||||
if sentence_buffer.strip() and not self._interrupt_event.is_set():
|
if sentence_buffer.strip() and not self._interrupt_event.is_set():
|
||||||
if not first_audio_sent:
|
if not first_audio_sent:
|
||||||
await self.transport.send_event({
|
await self.transport.send_event({
|
||||||
"event": "trackStart",
|
**ev(
|
||||||
"trackId": self.session_id,
|
"output.audio.start",
|
||||||
"timestamp": self._get_timestamp_ms()
|
trackId=self.session_id,
|
||||||
|
)
|
||||||
})
|
})
|
||||||
first_audio_sent = True
|
first_audio_sent = True
|
||||||
await self._speak_sentence(sentence_buffer.strip())
|
await self._speak_sentence(sentence_buffer.strip())
|
||||||
@@ -497,9 +573,10 @@ class DuplexPipeline:
|
|||||||
# Send track end
|
# Send track end
|
||||||
if first_audio_sent:
|
if first_audio_sent:
|
||||||
await self.transport.send_event({
|
await self.transport.send_event({
|
||||||
"event": "trackEnd",
|
**ev(
|
||||||
"trackId": self.session_id,
|
"output.audio.end",
|
||||||
"timestamp": self._get_timestamp_ms()
|
trackId=self.session_id,
|
||||||
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
# End assistant turn
|
# End assistant turn
|
||||||
@@ -545,10 +622,11 @@ class DuplexPipeline:
|
|||||||
|
|
||||||
# Send TTFB event to client
|
# Send TTFB event to client
|
||||||
await self.transport.send_event({
|
await self.transport.send_event({
|
||||||
"event": "ttfb",
|
**ev(
|
||||||
"trackId": self.session_id,
|
"metrics.ttfb",
|
||||||
"timestamp": self._get_timestamp_ms(),
|
trackId=self.session_id,
|
||||||
"latencyMs": round(ttfb_ms)
|
latencyMs=round(ttfb_ms),
|
||||||
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
# Double-check interrupt right before sending audio
|
# Double-check interrupt right before sending audio
|
||||||
@@ -579,9 +657,10 @@ class DuplexPipeline:
|
|||||||
|
|
||||||
# Send track start event
|
# Send track start event
|
||||||
await self.transport.send_event({
|
await self.transport.send_event({
|
||||||
"event": "trackStart",
|
**ev(
|
||||||
"trackId": self.session_id,
|
"output.audio.start",
|
||||||
"timestamp": self._get_timestamp_ms()
|
trackId=self.session_id,
|
||||||
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
self._is_bot_speaking = True
|
self._is_bot_speaking = True
|
||||||
@@ -600,10 +679,11 @@ class DuplexPipeline:
|
|||||||
|
|
||||||
# Send TTFB event to client
|
# Send TTFB event to client
|
||||||
await self.transport.send_event({
|
await self.transport.send_event({
|
||||||
"event": "ttfb",
|
**ev(
|
||||||
"trackId": self.session_id,
|
"metrics.ttfb",
|
||||||
"timestamp": self._get_timestamp_ms(),
|
trackId=self.session_id,
|
||||||
"latencyMs": round(ttfb_ms)
|
latencyMs=round(ttfb_ms),
|
||||||
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
# Send audio to client
|
# Send audio to client
|
||||||
@@ -614,9 +694,10 @@ class DuplexPipeline:
|
|||||||
|
|
||||||
# Send track end event
|
# Send track end event
|
||||||
await self.transport.send_event({
|
await self.transport.send_event({
|
||||||
"event": "trackEnd",
|
**ev(
|
||||||
"trackId": self.session_id,
|
"output.audio.end",
|
||||||
"timestamp": self._get_timestamp_ms()
|
trackId=self.session_id,
|
||||||
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
@@ -646,9 +727,10 @@ class DuplexPipeline:
|
|||||||
# Send interrupt event to client IMMEDIATELY
|
# Send interrupt event to client IMMEDIATELY
|
||||||
# This must happen BEFORE canceling services, so client knows to discard in-flight audio
|
# This must happen BEFORE canceling services, so client knows to discard in-flight audio
|
||||||
await self.transport.send_event({
|
await self.transport.send_event({
|
||||||
"event": "interrupt",
|
**ev(
|
||||||
"trackId": self.session_id,
|
"response.interrupted",
|
||||||
"timestamp": self._get_timestamp_ms()
|
trackId=self.session_id,
|
||||||
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
# Cancel TTS
|
# Cancel TTS
|
||||||
|
|||||||
@@ -2,13 +2,31 @@
|
|||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
import json
|
import json
|
||||||
|
from enum import Enum
|
||||||
from typing import Optional, Dict, Any
|
from typing import Optional, Dict, Any
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from core.transports import BaseTransport
|
from core.transports import BaseTransport
|
||||||
from core.duplex_pipeline import DuplexPipeline
|
from core.duplex_pipeline import DuplexPipeline
|
||||||
from models.commands import parse_command, TTSCommand, ChatCommand, InterruptCommand, HangupCommand
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
|
from models.ws_v1 import (
|
||||||
|
parse_client_message,
|
||||||
|
ev,
|
||||||
|
HelloMessage,
|
||||||
|
SessionStartMessage,
|
||||||
|
SessionStopMessage,
|
||||||
|
InputTextMessage,
|
||||||
|
ResponseCancelMessage,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class WsSessionState(str, Enum):
|
||||||
|
"""Protocol state machine for WS sessions."""
|
||||||
|
|
||||||
|
WAIT_HELLO = "wait_hello"
|
||||||
|
WAIT_START = "wait_start"
|
||||||
|
ACTIVE = "active"
|
||||||
|
STOPPED = "stopped"
|
||||||
|
|
||||||
|
|
||||||
class Session:
|
class Session:
|
||||||
@@ -41,8 +59,11 @@ class Session:
|
|||||||
|
|
||||||
# Session state
|
# Session state
|
||||||
self.created_at = None
|
self.created_at = None
|
||||||
self.state = "created" # created, invited, accepted, ringing, hungup
|
self.state = "created" # Legacy call state for /call/lists
|
||||||
|
self.ws_state = WsSessionState.WAIT_HELLO
|
||||||
self._pipeline_started = False
|
self._pipeline_started = False
|
||||||
|
self.protocol_version: Optional[str] = None
|
||||||
|
self.authenticated: bool = False
|
||||||
|
|
||||||
# Track IDs
|
# Track IDs
|
||||||
self.current_track_id: Optional[str] = str(uuid.uuid4())
|
self.current_track_id: Optional[str] = str(uuid.uuid4())
|
||||||
@@ -51,69 +72,27 @@ class Session:
|
|||||||
|
|
||||||
async def handle_text(self, text_data: str) -> None:
|
async def handle_text(self, text_data: str) -> None:
|
||||||
"""
|
"""
|
||||||
Handle incoming text data (JSON commands).
|
Handle incoming text data (WS v1 JSON control messages).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text_data: JSON text data
|
text_data: JSON text data
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
data = json.loads(text_data)
|
data = json.loads(text_data)
|
||||||
command = parse_command(data)
|
message = parse_client_message(data)
|
||||||
command_type = command.command
|
await self._handle_v1_message(message)
|
||||||
|
|
||||||
logger.info(f"Session {self.id} received command: {command_type}")
|
|
||||||
|
|
||||||
# Route command to appropriate handler
|
|
||||||
if command_type == "invite":
|
|
||||||
await self._handle_invite(data)
|
|
||||||
|
|
||||||
elif command_type == "accept":
|
|
||||||
await self._handle_accept(data)
|
|
||||||
|
|
||||||
elif command_type == "reject":
|
|
||||||
await self._handle_reject(data)
|
|
||||||
|
|
||||||
elif command_type == "ringing":
|
|
||||||
await self._handle_ringing(data)
|
|
||||||
|
|
||||||
elif command_type == "tts":
|
|
||||||
await self._handle_tts(command)
|
|
||||||
|
|
||||||
elif command_type == "play":
|
|
||||||
await self._handle_play(data)
|
|
||||||
|
|
||||||
elif command_type == "interrupt":
|
|
||||||
await self._handle_interrupt(command)
|
|
||||||
|
|
||||||
elif command_type == "pause":
|
|
||||||
await self._handle_pause()
|
|
||||||
|
|
||||||
elif command_type == "resume":
|
|
||||||
await self._handle_resume()
|
|
||||||
|
|
||||||
elif command_type == "hangup":
|
|
||||||
await self._handle_hangup(command)
|
|
||||||
|
|
||||||
elif command_type == "history":
|
|
||||||
await self._handle_history(data)
|
|
||||||
|
|
||||||
elif command_type == "chat":
|
|
||||||
await self._handle_chat(command)
|
|
||||||
|
|
||||||
else:
|
|
||||||
logger.warning(f"Session {self.id} unknown command: {command_type}")
|
|
||||||
|
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
logger.error(f"Session {self.id} JSON decode error: {e}")
|
logger.error(f"Session {self.id} JSON decode error: {e}")
|
||||||
await self._send_error("client", f"Invalid JSON: {e}")
|
await self._send_error("client", f"Invalid JSON: {e}", "protocol.invalid_json")
|
||||||
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.error(f"Session {self.id} command parse error: {e}")
|
logger.error(f"Session {self.id} command parse error: {e}")
|
||||||
await self._send_error("client", f"Invalid command: {e}")
|
await self._send_error("client", f"Invalid message: {e}", "protocol.invalid_message")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Session {self.id} handle_text error: {e}", exc_info=True)
|
logger.error(f"Session {self.id} handle_text error: {e}", exc_info=True)
|
||||||
await self._send_error("server", f"Internal error: {e}")
|
await self._send_error("server", f"Internal error: {e}", "server.internal")
|
||||||
|
|
||||||
async def handle_audio(self, audio_bytes: bytes) -> None:
|
async def handle_audio(self, audio_bytes: bytes) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -122,156 +101,166 @@ class Session:
|
|||||||
Args:
|
Args:
|
||||||
audio_bytes: PCM audio data
|
audio_bytes: PCM audio data
|
||||||
"""
|
"""
|
||||||
|
if self.ws_state != WsSessionState.ACTIVE:
|
||||||
|
await self._send_error(
|
||||||
|
"client",
|
||||||
|
"Audio received before session.start",
|
||||||
|
"protocol.order",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.pipeline.process_audio(audio_bytes)
|
await self.pipeline.process_audio(audio_bytes)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Session {self.id} handle_audio error: {e}", exc_info=True)
|
logger.error(f"Session {self.id} handle_audio error: {e}", exc_info=True)
|
||||||
|
|
||||||
async def _handle_invite(self, data: Dict[str, Any]) -> None:
|
async def _handle_v1_message(self, message: Any) -> None:
|
||||||
"""Handle invite command."""
|
"""Route validated WS v1 message to handlers."""
|
||||||
self.state = "invited"
|
msg_type = message.type
|
||||||
option = data.get("option", {})
|
logger.info(f"Session {self.id} received message: {msg_type}")
|
||||||
|
|
||||||
# Send answer event
|
if isinstance(message, HelloMessage):
|
||||||
await self.transport.send_event({
|
await self._handle_hello(message)
|
||||||
"event": "answer",
|
return
|
||||||
"trackId": self.current_track_id,
|
|
||||||
"timestamp": self._get_timestamp_ms()
|
# All messages below require hello handshake first
|
||||||
})
|
if self.ws_state == WsSessionState.WAIT_HELLO:
|
||||||
|
await self._send_error(
|
||||||
|
"client",
|
||||||
|
"Expected hello message first",
|
||||||
|
"protocol.order",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if isinstance(message, SessionStartMessage):
|
||||||
|
await self._handle_session_start(message)
|
||||||
|
return
|
||||||
|
|
||||||
|
# All messages below require active session
|
||||||
|
if self.ws_state != WsSessionState.ACTIVE:
|
||||||
|
await self._send_error(
|
||||||
|
"client",
|
||||||
|
f"Message '{msg_type}' requires active session",
|
||||||
|
"protocol.order",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if isinstance(message, InputTextMessage):
|
||||||
|
await self.pipeline.process_text(message.text)
|
||||||
|
elif isinstance(message, ResponseCancelMessage):
|
||||||
|
if message.graceful:
|
||||||
|
logger.info(f"Session {self.id} graceful response.cancel")
|
||||||
|
else:
|
||||||
|
await self.pipeline.interrupt()
|
||||||
|
elif isinstance(message, SessionStopMessage):
|
||||||
|
await self._handle_session_stop(message.reason)
|
||||||
|
else:
|
||||||
|
await self._send_error("client", f"Unsupported message type: {msg_type}", "protocol.unsupported")
|
||||||
|
|
||||||
|
async def _handle_hello(self, message: HelloMessage) -> None:
|
||||||
|
"""Handle initial hello/auth/version negotiation."""
|
||||||
|
if self.ws_state != WsSessionState.WAIT_HELLO:
|
||||||
|
await self._send_error("client", "Duplicate hello", "protocol.order")
|
||||||
|
return
|
||||||
|
|
||||||
|
if message.version != settings.ws_protocol_version:
|
||||||
|
await self._send_error(
|
||||||
|
"client",
|
||||||
|
f"Unsupported protocol version '{message.version}'",
|
||||||
|
"protocol.version_unsupported",
|
||||||
|
)
|
||||||
|
await self.transport.close()
|
||||||
|
self.ws_state = WsSessionState.STOPPED
|
||||||
|
return
|
||||||
|
|
||||||
|
auth_payload = message.auth or {}
|
||||||
|
api_key = auth_payload.get("apiKey")
|
||||||
|
jwt = auth_payload.get("jwt")
|
||||||
|
|
||||||
|
if settings.ws_api_key:
|
||||||
|
if api_key != settings.ws_api_key:
|
||||||
|
await self._send_error("auth", "Invalid API key", "auth.invalid_api_key")
|
||||||
|
await self.transport.close()
|
||||||
|
self.ws_state = WsSessionState.STOPPED
|
||||||
|
return
|
||||||
|
elif settings.ws_require_auth and not (api_key or jwt):
|
||||||
|
await self._send_error("auth", "Authentication required", "auth.required")
|
||||||
|
await self.transport.close()
|
||||||
|
self.ws_state = WsSessionState.STOPPED
|
||||||
|
return
|
||||||
|
|
||||||
|
self.authenticated = True
|
||||||
|
self.protocol_version = message.version
|
||||||
|
self.ws_state = WsSessionState.WAIT_START
|
||||||
|
await self.transport.send_event(
|
||||||
|
ev(
|
||||||
|
"hello.ack",
|
||||||
|
sessionId=self.id,
|
||||||
|
version=self.protocol_version,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _handle_session_start(self, message: SessionStartMessage) -> None:
|
||||||
|
"""Handle explicit session start after successful hello."""
|
||||||
|
if self.ws_state != WsSessionState.WAIT_START:
|
||||||
|
await self._send_error("client", "Duplicate session.start", "protocol.order")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Apply runtime service/prompt overrides from backend if provided
|
||||||
|
self.pipeline.apply_runtime_overrides(message.metadata)
|
||||||
|
|
||||||
# Start duplex pipeline
|
# Start duplex pipeline
|
||||||
if not self._pipeline_started:
|
if not self._pipeline_started:
|
||||||
try:
|
|
||||||
await self.pipeline.start()
|
await self.pipeline.start()
|
||||||
self._pipeline_started = True
|
self._pipeline_started = True
|
||||||
logger.info(f"Session {self.id} duplex pipeline started")
|
logger.info(f"Session {self.id} duplex pipeline started")
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to start duplex pipeline: {e}")
|
|
||||||
|
|
||||||
logger.info(f"Session {self.id} invited with codec: {option.get('codec', 'pcm')}")
|
|
||||||
|
|
||||||
async def _handle_accept(self, data: Dict[str, Any]) -> None:
|
|
||||||
"""Handle accept command."""
|
|
||||||
self.state = "accepted"
|
self.state = "accepted"
|
||||||
logger.info(f"Session {self.id} accepted")
|
self.ws_state = WsSessionState.ACTIVE
|
||||||
|
await self.transport.send_event(
|
||||||
|
ev(
|
||||||
|
"session.started",
|
||||||
|
sessionId=self.id,
|
||||||
|
trackId=self.current_track_id,
|
||||||
|
audio=message.audio or {},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
async def _handle_reject(self, data: Dict[str, Any]) -> None:
|
async def _handle_session_stop(self, reason: Optional[str]) -> None:
|
||||||
"""Handle reject command."""
|
"""Handle session stop."""
|
||||||
self.state = "rejected"
|
if self.ws_state == WsSessionState.STOPPED:
|
||||||
reason = data.get("reason", "Rejected")
|
return
|
||||||
logger.info(f"Session {self.id} rejected: {reason}")
|
|
||||||
|
|
||||||
async def _handle_ringing(self, data: Dict[str, Any]) -> None:
|
stop_reason = reason or "client_requested"
|
||||||
"""Handle ringing command."""
|
|
||||||
self.state = "ringing"
|
|
||||||
logger.info(f"Session {self.id} ringing")
|
|
||||||
|
|
||||||
async def _handle_tts(self, command: TTSCommand) -> None:
|
|
||||||
"""Handle TTS command."""
|
|
||||||
logger.info(f"Session {self.id} TTS: {command.text[:50]}...")
|
|
||||||
|
|
||||||
# Send track start event
|
|
||||||
await self.transport.send_event({
|
|
||||||
"event": "trackStart",
|
|
||||||
"trackId": self.current_track_id,
|
|
||||||
"timestamp": self._get_timestamp_ms(),
|
|
||||||
"playId": command.play_id
|
|
||||||
})
|
|
||||||
|
|
||||||
# TODO: Implement actual TTS synthesis
|
|
||||||
# For now, just send track end event
|
|
||||||
await self.transport.send_event({
|
|
||||||
"event": "trackEnd",
|
|
||||||
"trackId": self.current_track_id,
|
|
||||||
"timestamp": self._get_timestamp_ms(),
|
|
||||||
"duration": 1000,
|
|
||||||
"ssrc": 0,
|
|
||||||
"playId": command.play_id
|
|
||||||
})
|
|
||||||
|
|
||||||
async def _handle_play(self, data: Dict[str, Any]) -> None:
|
|
||||||
"""Handle play command."""
|
|
||||||
url = data.get("url", "")
|
|
||||||
logger.info(f"Session {self.id} play: {url}")
|
|
||||||
|
|
||||||
# Send track start event
|
|
||||||
await self.transport.send_event({
|
|
||||||
"event": "trackStart",
|
|
||||||
"trackId": self.current_track_id,
|
|
||||||
"timestamp": self._get_timestamp_ms(),
|
|
||||||
"playId": url
|
|
||||||
})
|
|
||||||
|
|
||||||
# TODO: Implement actual audio playback
|
|
||||||
# For now, just send track end event
|
|
||||||
await self.transport.send_event({
|
|
||||||
"event": "trackEnd",
|
|
||||||
"trackId": self.current_track_id,
|
|
||||||
"timestamp": self._get_timestamp_ms(),
|
|
||||||
"duration": 1000,
|
|
||||||
"ssrc": 0,
|
|
||||||
"playId": url
|
|
||||||
})
|
|
||||||
|
|
||||||
async def _handle_interrupt(self, command: InterruptCommand) -> None:
|
|
||||||
"""Handle interrupt command."""
|
|
||||||
if command.graceful:
|
|
||||||
logger.info(f"Session {self.id} graceful interrupt")
|
|
||||||
else:
|
|
||||||
logger.info(f"Session {self.id} immediate interrupt")
|
|
||||||
await self.pipeline.interrupt()
|
|
||||||
|
|
||||||
async def _handle_pause(self) -> None:
|
|
||||||
"""Handle pause command."""
|
|
||||||
logger.info(f"Session {self.id} paused")
|
|
||||||
|
|
||||||
async def _handle_resume(self) -> None:
|
|
||||||
"""Handle resume command."""
|
|
||||||
logger.info(f"Session {self.id} resumed")
|
|
||||||
|
|
||||||
async def _handle_hangup(self, command: HangupCommand) -> None:
|
|
||||||
"""Handle hangup command."""
|
|
||||||
self.state = "hungup"
|
self.state = "hungup"
|
||||||
reason = command.reason or "User requested"
|
self.ws_state = WsSessionState.STOPPED
|
||||||
logger.info(f"Session {self.id} hung up: {reason}")
|
await self.transport.send_event(
|
||||||
|
ev(
|
||||||
# Send hangup event
|
"session.stopped",
|
||||||
await self.transport.send_event({
|
sessionId=self.id,
|
||||||
"event": "hangup",
|
reason=stop_reason,
|
||||||
"timestamp": self._get_timestamp_ms(),
|
)
|
||||||
"reason": reason,
|
)
|
||||||
"initiator": command.initiator or "user"
|
|
||||||
})
|
|
||||||
|
|
||||||
# Close transport
|
|
||||||
await self.transport.close()
|
await self.transport.close()
|
||||||
|
|
||||||
async def _handle_history(self, data: Dict[str, Any]) -> None:
|
async def _send_error(self, sender: str, error_message: str, code: str) -> None:
|
||||||
"""Handle history command."""
|
|
||||||
speaker = data.get("speaker", "unknown")
|
|
||||||
text = data.get("text", "")
|
|
||||||
logger.info(f"Session {self.id} history [{speaker}]: {text[:50]}...")
|
|
||||||
|
|
||||||
async def _handle_chat(self, command: ChatCommand) -> None:
|
|
||||||
"""Handle chat command."""
|
|
||||||
logger.info(f"Session {self.id} chat: {command.text[:50]}...")
|
|
||||||
await self.pipeline.process_text(command.text)
|
|
||||||
|
|
||||||
async def _send_error(self, sender: str, error_message: str) -> None:
|
|
||||||
"""
|
"""
|
||||||
Send error event to client.
|
Send error event to client.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sender: Component that generated the error
|
sender: Component that generated the error
|
||||||
error_message: Error message
|
error_message: Error message
|
||||||
|
code: Machine-readable error code
|
||||||
"""
|
"""
|
||||||
await self.transport.send_event({
|
await self.transport.send_event(
|
||||||
"event": "error",
|
ev(
|
||||||
"trackId": self.current_track_id,
|
"error",
|
||||||
"timestamp": self._get_timestamp_ms(),
|
sender=sender,
|
||||||
"sender": sender,
|
code=code,
|
||||||
"error": error_message
|
message=error_message,
|
||||||
})
|
trackId=self.current_track_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def _get_timestamp_ms(self) -> int:
|
def _get_timestamp_ms(self) -> int:
|
||||||
"""Get current timestamp in milliseconds."""
|
"""Get current timestamp in milliseconds."""
|
||||||
|
|||||||
169
engine/docs/ws_v1_schema.md
Normal file
169
engine/docs/ws_v1_schema.md
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
# WS v1 Protocol Schema (`/ws`)
|
||||||
|
|
||||||
|
This document defines the public WebSocket protocol for the `/ws` endpoint.
|
||||||
|
|
||||||
|
## Transport
|
||||||
|
|
||||||
|
- A single WebSocket connection carries:
|
||||||
|
- JSON text frames for control/events.
|
||||||
|
- Binary frames for raw PCM audio (`pcm_s16le`, mono, 16kHz by default).
|
||||||
|
|
||||||
|
## Handshake and State Machine
|
||||||
|
|
||||||
|
Required message order:
|
||||||
|
|
||||||
|
1. Client sends `hello`.
|
||||||
|
2. Server replies `hello.ack`.
|
||||||
|
3. Client sends `session.start`.
|
||||||
|
4. Server replies `session.started`.
|
||||||
|
5. Client may stream binary audio and/or send `input.text`.
|
||||||
|
6. Client sends `session.stop` (or closes socket).
|
||||||
|
|
||||||
|
If order is violated, server emits `error` with `code = "protocol.order"`.
|
||||||
|
|
||||||
|
## Client -> Server Messages
|
||||||
|
|
||||||
|
### `hello`
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"type": "hello",
|
||||||
|
"version": "v1",
|
||||||
|
"auth": {
|
||||||
|
"apiKey": "optional-api-key",
|
||||||
|
"jwt": "optional-jwt"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Rules:
|
||||||
|
- `version` must be `v1`.
|
||||||
|
- If `WS_API_KEY` is configured on server, `auth.apiKey` must match.
|
||||||
|
- If `WS_REQUIRE_AUTH=true`, either `auth.apiKey` or `auth.jwt` must be present.
|
||||||
|
|
||||||
|
### `session.start`
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"type": "session.start",
|
||||||
|
"audio": {
|
||||||
|
"encoding": "pcm_s16le",
|
||||||
|
"sample_rate_hz": 16000,
|
||||||
|
"channels": 1
|
||||||
|
},
|
||||||
|
"metadata": {
|
||||||
|
"client": "web-debug",
|
||||||
|
"systemPrompt": "You are concise.",
|
||||||
|
"greeting": "Hi, how can I help?",
|
||||||
|
"services": {
|
||||||
|
"llm": {
|
||||||
|
"provider": "openai",
|
||||||
|
"model": "gpt-4o-mini",
|
||||||
|
"apiKey": "sk-...",
|
||||||
|
"baseUrl": "https://api.openai.com/v1"
|
||||||
|
},
|
||||||
|
"asr": {
|
||||||
|
"provider": "siliconflow",
|
||||||
|
"model": "FunAudioLLM/SenseVoiceSmall",
|
||||||
|
"apiKey": "sf-...",
|
||||||
|
"interimIntervalMs": 500,
|
||||||
|
"minAudioMs": 300
|
||||||
|
},
|
||||||
|
"tts": {
|
||||||
|
"provider": "siliconflow",
|
||||||
|
"model": "FunAudioLLM/CosyVoice2-0.5B",
|
||||||
|
"apiKey": "sf-...",
|
||||||
|
"voice": "anna",
|
||||||
|
"speed": 1.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
`metadata.services` is optional. If omitted, server defaults to environment configuration.
|
||||||
|
|
||||||
|
### `input.text`
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"type": "input.text",
|
||||||
|
"text": "What can you do?"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### `response.cancel`
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"type": "response.cancel",
|
||||||
|
"graceful": false
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### `session.stop`
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"type": "session.stop",
|
||||||
|
"reason": "client_disconnect"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Server -> Client Events
|
||||||
|
|
||||||
|
All server events include:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"type": "event.name",
|
||||||
|
"timestamp": 1730000000000
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Common events:
|
||||||
|
|
||||||
|
- `hello.ack`
|
||||||
|
- Fields: `sessionId`, `version`
|
||||||
|
- `session.started`
|
||||||
|
- Fields: `sessionId`, `trackId`, `audio`
|
||||||
|
- `session.stopped`
|
||||||
|
- Fields: `sessionId`, `reason`
|
||||||
|
- `heartbeat`
|
||||||
|
- `input.speech_started`
|
||||||
|
- Fields: `trackId`, `probability`
|
||||||
|
- `input.speech_stopped`
|
||||||
|
- Fields: `trackId`, `probability`
|
||||||
|
- `transcript.delta`
|
||||||
|
- Fields: `trackId`, `text`
|
||||||
|
- `transcript.final`
|
||||||
|
- Fields: `trackId`, `text`
|
||||||
|
- `assistant.response.delta`
|
||||||
|
- Fields: `trackId`, `text`
|
||||||
|
- `assistant.response.final`
|
||||||
|
- Fields: `trackId`, `text`
|
||||||
|
- `output.audio.start`
|
||||||
|
- Fields: `trackId`
|
||||||
|
- `output.audio.end`
|
||||||
|
- Fields: `trackId`
|
||||||
|
- `response.interrupted`
|
||||||
|
- Fields: `trackId`
|
||||||
|
- `metrics.ttfb`
|
||||||
|
- Fields: `trackId`, `latencyMs`
|
||||||
|
- `error`
|
||||||
|
- Fields: `sender`, `code`, `message`, `trackId`
|
||||||
|
|
||||||
|
## Binary Audio Frames
|
||||||
|
|
||||||
|
After `session.started`, client may send binary PCM chunks continuously.
|
||||||
|
|
||||||
|
Recommended format:
|
||||||
|
- 16-bit signed little-endian PCM.
|
||||||
|
- 1 channel.
|
||||||
|
- 16000 Hz.
|
||||||
|
- 20ms frames (640 bytes) preferred.
|
||||||
|
|
||||||
|
## Compatibility
|
||||||
|
|
||||||
|
This endpoint now enforces v1 message schema for JSON control frames.
|
||||||
|
Legacy command names (`invite`, `chat`, etc.) are no longer part of the public protocol.
|
||||||
@@ -36,7 +36,7 @@ def generate_sine_wave(duration_ms=1000):
|
|||||||
return audio_data
|
return audio_data
|
||||||
|
|
||||||
|
|
||||||
async def receive_loop(ws):
|
async def receive_loop(ws, ready_event: asyncio.Event):
|
||||||
"""Listen for incoming messages from the server."""
|
"""Listen for incoming messages from the server."""
|
||||||
print("👂 Listening for server responses...")
|
print("👂 Listening for server responses...")
|
||||||
async for msg in ws:
|
async for msg in ws:
|
||||||
@@ -45,8 +45,10 @@ async def receive_loop(ws):
|
|||||||
if msg.type == aiohttp.WSMsgType.TEXT:
|
if msg.type == aiohttp.WSMsgType.TEXT:
|
||||||
try:
|
try:
|
||||||
data = json.loads(msg.data)
|
data = json.loads(msg.data)
|
||||||
event_type = data.get('event', 'Unknown')
|
event_type = data.get('type', 'Unknown')
|
||||||
print(f"[{timestamp}] 📨 Event: {event_type} | {msg.data[:150]}...")
|
print(f"[{timestamp}] 📨 Event: {event_type} | {msg.data[:150]}...")
|
||||||
|
if event_type == "session.started":
|
||||||
|
ready_event.set()
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
print(f"[{timestamp}] 📨 Text: {msg.data[:100]}...")
|
print(f"[{timestamp}] 📨 Text: {msg.data[:100]}...")
|
||||||
|
|
||||||
@@ -118,35 +120,43 @@ async def run_client(url, file_path=None, use_sine=False):
|
|||||||
print(f"🔌 Connecting to {url}...")
|
print(f"🔌 Connecting to {url}...")
|
||||||
async with session.ws_connect(url) as ws:
|
async with session.ws_connect(url) as ws:
|
||||||
print("✅ Connected!")
|
print("✅ Connected!")
|
||||||
|
session_ready = asyncio.Event()
|
||||||
|
recv_task = asyncio.create_task(receive_loop(ws, session_ready))
|
||||||
|
|
||||||
# Send initial invite command
|
# Send v1 hello + session.start handshake
|
||||||
init_cmd = {
|
await ws.send_json({"type": "hello", "version": "v1"})
|
||||||
"command": "invite",
|
await ws.send_json({
|
||||||
"option": {
|
"type": "session.start",
|
||||||
"codec": "pcm",
|
"audio": {
|
||||||
"samplerate": SAMPLE_RATE
|
"encoding": "pcm_s16le",
|
||||||
|
"sample_rate_hz": SAMPLE_RATE,
|
||||||
|
"channels": 1
|
||||||
}
|
}
|
||||||
}
|
})
|
||||||
await ws.send_json(init_cmd)
|
print("📤 Sent v1 hello/session.start")
|
||||||
print("📤 Sent Invite Command")
|
await asyncio.wait_for(session_ready.wait(), timeout=8)
|
||||||
|
|
||||||
# Select sender based on args
|
# Select sender based on args
|
||||||
if use_sine:
|
if use_sine:
|
||||||
sender_task = send_sine_loop(ws)
|
await send_sine_loop(ws)
|
||||||
elif file_path:
|
elif file_path:
|
||||||
sender_task = send_file_loop(ws, file_path)
|
await send_file_loop(ws, file_path)
|
||||||
else:
|
else:
|
||||||
# Default to sine wave
|
# Default to sine wave
|
||||||
sender_task = send_sine_loop(ws)
|
await send_sine_loop(ws)
|
||||||
|
|
||||||
# Run send and receive loops in parallel
|
await ws.send_json({"type": "session.stop", "reason": "test_complete"})
|
||||||
await asyncio.gather(
|
await asyncio.sleep(1)
|
||||||
receive_loop(ws),
|
recv_task.cancel()
|
||||||
sender_task
|
try:
|
||||||
)
|
await recv_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
except aiohttp.ClientConnectorError:
|
except aiohttp.ClientConnectorError:
|
||||||
print(f"❌ Connection Failed. Is the server running at {url}?")
|
print(f"❌ Connection Failed. Is the server running at {url}?")
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
print("❌ Timeout waiting for session.started")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"❌ Error: {e}")
|
print(f"❌ Error: {e}")
|
||||||
finally:
|
finally:
|
||||||
|
|||||||
@@ -547,7 +547,7 @@
|
|||||||
setStatus(true, "Session open");
|
setStatus(true, "Session open");
|
||||||
logLine("sys", "WebSocket connected");
|
logLine("sys", "WebSocket connected");
|
||||||
ensureAudioContext();
|
ensureAudioContext();
|
||||||
sendCommand({ command: "invite", option: { codec: "pcm", sampleRate: targetSampleRate } });
|
sendCommand({ type: "hello", version: "v1" });
|
||||||
};
|
};
|
||||||
|
|
||||||
ws.onclose = () => {
|
ws.onclose = () => {
|
||||||
@@ -574,7 +574,10 @@
|
|||||||
}
|
}
|
||||||
|
|
||||||
function disconnect() {
|
function disconnect() {
|
||||||
if (ws) ws.close();
|
if (ws && ws.readyState === WebSocket.OPEN) {
|
||||||
|
sendCommand({ type: "session.stop", reason: "client_disconnect" });
|
||||||
|
ws.close();
|
||||||
|
}
|
||||||
ws = null;
|
ws = null;
|
||||||
setStatus(false, "Disconnected");
|
setStatus(false, "Disconnected");
|
||||||
}
|
}
|
||||||
@@ -585,40 +588,48 @@
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
ws.send(JSON.stringify(cmd));
|
ws.send(JSON.stringify(cmd));
|
||||||
logLine("sys", `→ ${cmd.command}`, cmd);
|
logLine("sys", `→ ${cmd.type}`, cmd);
|
||||||
}
|
}
|
||||||
|
|
||||||
function handleEvent(event) {
|
function handleEvent(event) {
|
||||||
const type = event.event || "unknown";
|
const type = event.type || "unknown";
|
||||||
logLine("event", type, event);
|
logLine("event", type, event);
|
||||||
if (type === "transcript") {
|
if (type === "hello.ack") {
|
||||||
if (event.isFinal && event.text) {
|
sendCommand({
|
||||||
|
type: "session.start",
|
||||||
|
audio: { encoding: "pcm_s16le", sample_rate_hz: targetSampleRate, channels: 1 },
|
||||||
|
});
|
||||||
|
}
|
||||||
|
if (type === "transcript.final") {
|
||||||
|
if (event.text) {
|
||||||
setInterim("You", "");
|
setInterim("You", "");
|
||||||
addChat("You", event.text);
|
addChat("You", event.text);
|
||||||
} else if (event.text) {
|
|
||||||
interimUserText += event.text;
|
|
||||||
setInterim("You", interimUserText);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (type === "llmResponse") {
|
if (type === "transcript.delta" && event.text) {
|
||||||
if (event.isFinal && event.text) {
|
setInterim("You", event.text);
|
||||||
|
}
|
||||||
|
if (type === "assistant.response.final") {
|
||||||
|
if (event.text) {
|
||||||
setInterim("AI", "");
|
setInterim("AI", "");
|
||||||
addChat("AI", event.text);
|
addChat("AI", event.text);
|
||||||
} else if (event.text) {
|
}
|
||||||
|
}
|
||||||
|
if (type === "assistant.response.delta" && event.text) {
|
||||||
interimAiText += event.text;
|
interimAiText += event.text;
|
||||||
setInterim("AI", interimAiText);
|
setInterim("AI", interimAiText);
|
||||||
}
|
}
|
||||||
}
|
if (type === "output.audio.start") {
|
||||||
if (type === "trackStart") {
|
|
||||||
// New bot audio: stop any previous playback to avoid overlap
|
// New bot audio: stop any previous playback to avoid overlap
|
||||||
stopPlayback();
|
stopPlayback();
|
||||||
discardAudio = false;
|
discardAudio = false;
|
||||||
|
interimAiText = "";
|
||||||
}
|
}
|
||||||
if (type === "speaking") {
|
if (type === "input.speech_started") {
|
||||||
// User started speaking: clear any in-flight audio to avoid overlap
|
// User started speaking: clear any in-flight audio to avoid overlap
|
||||||
stopPlayback();
|
stopPlayback();
|
||||||
}
|
}
|
||||||
if (type === "interrupt") {
|
if (type === "response.interrupted") {
|
||||||
stopPlayback();
|
stopPlayback();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -716,7 +727,7 @@
|
|||||||
if (!text) return;
|
if (!text) return;
|
||||||
ensureAudioContext();
|
ensureAudioContext();
|
||||||
addChat("You", text);
|
addChat("You", text);
|
||||||
sendCommand({ command: "chat", text });
|
sendCommand({ type: "input.text", text });
|
||||||
chatInput.value = "";
|
chatInput.value = "";
|
||||||
});
|
});
|
||||||
clearLogBtn.addEventListener("click", () => {
|
clearLogBtn.addEventListener("click", () => {
|
||||||
|
|||||||
67
engine/models/ws_v1.py
Normal file
67
engine/models/ws_v1.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
"""WS v1 protocol message models and helpers."""
|
||||||
|
|
||||||
|
from typing import Optional, Dict, Any, Literal
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
def now_ms() -> int:
|
||||||
|
"""Current unix timestamp in milliseconds."""
|
||||||
|
import time
|
||||||
|
|
||||||
|
return int(time.time() * 1000)
|
||||||
|
|
||||||
|
|
||||||
|
# Client -> Server messages
|
||||||
|
class HelloMessage(BaseModel):
|
||||||
|
type: Literal["hello"]
|
||||||
|
version: str = Field(..., description="Protocol version, currently v1")
|
||||||
|
auth: Optional[Dict[str, str]] = Field(default=None, description="Auth payload, e.g. {'apiKey': '...'}")
|
||||||
|
|
||||||
|
|
||||||
|
class SessionStartMessage(BaseModel):
|
||||||
|
type: Literal["session.start"]
|
||||||
|
audio: Optional[Dict[str, Any]] = Field(default=None, description="Optional audio format metadata")
|
||||||
|
metadata: Optional[Dict[str, Any]] = Field(default=None, description="Optional session metadata")
|
||||||
|
|
||||||
|
|
||||||
|
class SessionStopMessage(BaseModel):
|
||||||
|
type: Literal["session.stop"]
|
||||||
|
reason: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class InputTextMessage(BaseModel):
|
||||||
|
type: Literal["input.text"]
|
||||||
|
text: str
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseCancelMessage(BaseModel):
|
||||||
|
type: Literal["response.cancel"]
|
||||||
|
graceful: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
CLIENT_MESSAGE_TYPES = {
|
||||||
|
"hello": HelloMessage,
|
||||||
|
"session.start": SessionStartMessage,
|
||||||
|
"session.stop": SessionStopMessage,
|
||||||
|
"input.text": InputTextMessage,
|
||||||
|
"response.cancel": ResponseCancelMessage,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def parse_client_message(data: Dict[str, Any]) -> BaseModel:
|
||||||
|
"""Parse and validate a WS v1 client message."""
|
||||||
|
msg_type = data.get("type")
|
||||||
|
if not msg_type:
|
||||||
|
raise ValueError("Missing 'type' field")
|
||||||
|
msg_class = CLIENT_MESSAGE_TYPES.get(msg_type)
|
||||||
|
if not msg_class:
|
||||||
|
raise ValueError(f"Unknown client message type: {msg_type}")
|
||||||
|
return msg_class(**data)
|
||||||
|
|
||||||
|
|
||||||
|
# Server -> Client event helpers
|
||||||
|
def ev(event_type: str, **payload: Any) -> Dict[str, Any]:
|
||||||
|
"""Create a WS v1 server event payload."""
|
||||||
|
base = {"type": event_type, "timestamp": now_ms()}
|
||||||
|
base.update(payload)
|
||||||
|
return base
|
||||||
@@ -4,8 +4,7 @@ import { Plus, Search, Play, Copy, Trash2, Edit2, Mic, MessageSquare, Save, Vide
|
|||||||
import { Button, Input, Card, Badge, Drawer, Dialog } from '../components/UI';
|
import { Button, Input, Card, Badge, Drawer, Dialog } from '../components/UI';
|
||||||
import { mockLLMModels, mockASRModels } from '../services/mockData';
|
import { mockLLMModels, mockASRModels } from '../services/mockData';
|
||||||
import { Assistant, KnowledgeBase, TabValue, Voice } from '../types';
|
import { Assistant, KnowledgeBase, TabValue, Voice } from '../types';
|
||||||
import { GoogleGenAI } from "@google/genai";
|
import { createAssistant, deleteAssistant, fetchAssistantRuntimeConfig, fetchAssistants, fetchKnowledgeBases, fetchVoices, updateAssistant as updateAssistantApi } from '../services/backendApi';
|
||||||
import { createAssistant, deleteAssistant, fetchAssistants, fetchKnowledgeBases, fetchVoices, updateAssistant as updateAssistantApi } from '../services/backendApi';
|
|
||||||
|
|
||||||
interface ToolItem {
|
interface ToolItem {
|
||||||
id: string;
|
id: string;
|
||||||
@@ -1022,11 +1021,26 @@ export const DebugDrawer: React.FC<{ isOpen: boolean; onClose: () => void; assis
|
|||||||
const [inputText, setInputText] = useState('');
|
const [inputText, setInputText] = useState('');
|
||||||
const [isLoading, setIsLoading] = useState(false);
|
const [isLoading, setIsLoading] = useState(false);
|
||||||
const [callStatus, setCallStatus] = useState<'idle' | 'calling' | 'active'>('idle');
|
const [callStatus, setCallStatus] = useState<'idle' | 'calling' | 'active'>('idle');
|
||||||
|
const [wsStatus, setWsStatus] = useState<'disconnected' | 'connecting' | 'ready' | 'error'>('disconnected');
|
||||||
|
const [wsError, setWsError] = useState('');
|
||||||
|
const [resolvedConfigOpen, setResolvedConfigOpen] = useState(false);
|
||||||
|
const [resolvedConfigView, setResolvedConfigView] = useState<string>('');
|
||||||
|
const [wsUrl, setWsUrl] = useState<string>(() => {
|
||||||
|
const fromStorage = localStorage.getItem('debug_ws_url');
|
||||||
|
if (fromStorage) return fromStorage;
|
||||||
|
const defaultHost = window.location.hostname || 'localhost';
|
||||||
|
return `ws://${defaultHost}:8000/ws`;
|
||||||
|
});
|
||||||
|
|
||||||
// Media State
|
// Media State
|
||||||
const videoRef = useRef<HTMLVideoElement>(null);
|
const videoRef = useRef<HTMLVideoElement>(null);
|
||||||
const streamRef = useRef<MediaStream | null>(null);
|
const streamRef = useRef<MediaStream | null>(null);
|
||||||
const scrollRef = useRef<HTMLDivElement>(null);
|
const scrollRef = useRef<HTMLDivElement>(null);
|
||||||
|
const wsRef = useRef<WebSocket | null>(null);
|
||||||
|
const wsReadyRef = useRef(false);
|
||||||
|
const pendingResolveRef = useRef<(() => void) | null>(null);
|
||||||
|
const pendingRejectRef = useRef<((e: Error) => void) | null>(null);
|
||||||
|
const assistantDraftIndexRef = useRef<number | null>(null);
|
||||||
|
|
||||||
const [devices, setDevices] = useState<MediaDeviceInfo[]>([]);
|
const [devices, setDevices] = useState<MediaDeviceInfo[]>([]);
|
||||||
const [selectedCamera, setSelectedCamera] = useState<string>('');
|
const [selectedCamera, setSelectedCamera] = useState<string>('');
|
||||||
@@ -1045,11 +1059,16 @@ export const DebugDrawer: React.FC<{ isOpen: boolean; onClose: () => void; assis
|
|||||||
} else {
|
} else {
|
||||||
setMode('text');
|
setMode('text');
|
||||||
stopMedia();
|
stopMedia();
|
||||||
|
closeWs();
|
||||||
setIsSwapped(false);
|
setIsSwapped(false);
|
||||||
setCallStatus('idle');
|
setCallStatus('idle');
|
||||||
}
|
}
|
||||||
}, [isOpen, assistant, mode]);
|
}, [isOpen, assistant, mode]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
localStorage.setItem('debug_ws_url', wsUrl);
|
||||||
|
}, [wsUrl]);
|
||||||
|
|
||||||
// Auto-scroll logic
|
// Auto-scroll logic
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (scrollRef.current) {
|
if (scrollRef.current) {
|
||||||
@@ -1132,15 +1151,9 @@ export const DebugDrawer: React.FC<{ isOpen: boolean; onClose: () => void; assis
|
|||||||
setIsLoading(true);
|
setIsLoading(true);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
if (process.env.API_KEY) {
|
if (mode === 'text') {
|
||||||
const ai = new GoogleGenAI({ apiKey: process.env.API_KEY });
|
await ensureWsSession();
|
||||||
const chat = ai.chats.create({
|
wsRef.current?.send(JSON.stringify({ type: 'input.text', text: userMsg }));
|
||||||
model: "gemini-3-flash-preview",
|
|
||||||
config: { systemInstruction: assistant.prompt },
|
|
||||||
history: messages.map(m => ({ role: m.role, parts: [{ text: m.text }] }))
|
|
||||||
});
|
|
||||||
const result = await chat.sendMessage({ message: userMsg });
|
|
||||||
setMessages(prev => [...prev, { role: 'model', text: result.text || '' }]);
|
|
||||||
} else {
|
} else {
|
||||||
setTimeout(() => {
|
setTimeout(() => {
|
||||||
setMessages(prev => [...prev, { role: 'model', text: `[Mock Response]: Received "${userMsg}"` }]);
|
setMessages(prev => [...prev, { role: 'model', text: `[Mock Response]: Received "${userMsg}"` }]);
|
||||||
@@ -1150,11 +1163,184 @@ export const DebugDrawer: React.FC<{ isOpen: boolean; onClose: () => void; assis
|
|||||||
} catch (e) {
|
} catch (e) {
|
||||||
console.error(e);
|
console.error(e);
|
||||||
setMessages(prev => [...prev, { role: 'model', text: "Error: Failed to connect to AI service." }]);
|
setMessages(prev => [...prev, { role: 'model', text: "Error: Failed to connect to AI service." }]);
|
||||||
} finally {
|
|
||||||
setIsLoading(false);
|
setIsLoading(false);
|
||||||
|
} finally {
|
||||||
|
if (mode !== 'text') setIsLoading(false);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const fetchRuntimeMetadata = async (): Promise<Record<string, any>> => {
|
||||||
|
try {
|
||||||
|
const resolved = await fetchAssistantRuntimeConfig(assistant.id);
|
||||||
|
setResolvedConfigView(
|
||||||
|
JSON.stringify(
|
||||||
|
{
|
||||||
|
assistantId: resolved.assistantId,
|
||||||
|
sources: resolved.sources,
|
||||||
|
warnings: resolved.warnings,
|
||||||
|
sessionStartMetadata: resolved.sessionStartMetadata || {},
|
||||||
|
},
|
||||||
|
null,
|
||||||
|
2,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
return resolved.sessionStartMetadata || {};
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Failed to load runtime config, using fallback.', error);
|
||||||
|
const fallback = {
|
||||||
|
systemPrompt: assistant.prompt || '',
|
||||||
|
greeting: assistant.opener || '',
|
||||||
|
services: {},
|
||||||
|
};
|
||||||
|
setResolvedConfigView(
|
||||||
|
JSON.stringify(
|
||||||
|
{
|
||||||
|
assistantId: assistant.id,
|
||||||
|
sources: {},
|
||||||
|
warnings: ['runtime-config endpoint failed; using local fallback'],
|
||||||
|
sessionStartMetadata: fallback,
|
||||||
|
},
|
||||||
|
null,
|
||||||
|
2,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
return fallback;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const closeWs = () => {
|
||||||
|
if (wsRef.current && wsRef.current.readyState === WebSocket.OPEN) {
|
||||||
|
wsRef.current.send(JSON.stringify({ type: 'session.stop', reason: 'debug_drawer_closed' }));
|
||||||
|
}
|
||||||
|
wsRef.current?.close();
|
||||||
|
wsRef.current = null;
|
||||||
|
wsReadyRef.current = false;
|
||||||
|
pendingResolveRef.current = null;
|
||||||
|
pendingRejectRef.current = null;
|
||||||
|
assistantDraftIndexRef.current = null;
|
||||||
|
if (isOpen) setWsStatus('disconnected');
|
||||||
|
};
|
||||||
|
|
||||||
|
const ensureWsSession = async () => {
|
||||||
|
if (wsRef.current && wsReadyRef.current && wsRef.current.readyState === WebSocket.OPEN) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (wsRef.current && wsRef.current.readyState === WebSocket.CONNECTING) {
|
||||||
|
await new Promise<void>((resolve, reject) => {
|
||||||
|
pendingResolveRef.current = resolve;
|
||||||
|
pendingRejectRef.current = reject;
|
||||||
|
});
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const metadata = await fetchRuntimeMetadata();
|
||||||
|
setWsStatus('connecting');
|
||||||
|
setWsError('');
|
||||||
|
|
||||||
|
await new Promise<void>((resolve, reject) => {
|
||||||
|
pendingResolveRef.current = resolve;
|
||||||
|
pendingRejectRef.current = reject;
|
||||||
|
const ws = new WebSocket(wsUrl);
|
||||||
|
ws.binaryType = 'arraybuffer';
|
||||||
|
wsRef.current = ws;
|
||||||
|
|
||||||
|
ws.onopen = () => {
|
||||||
|
ws.send(JSON.stringify({ type: 'hello', version: 'v1' }));
|
||||||
|
};
|
||||||
|
|
||||||
|
ws.onmessage = (event) => {
|
||||||
|
if (typeof event.data !== 'string') return;
|
||||||
|
let payload: any;
|
||||||
|
try {
|
||||||
|
payload = JSON.parse(event.data);
|
||||||
|
} catch {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const type = payload?.type;
|
||||||
|
if (type === 'hello.ack') {
|
||||||
|
ws.send(
|
||||||
|
JSON.stringify({
|
||||||
|
type: 'session.start',
|
||||||
|
audio: { encoding: 'pcm_s16le', sample_rate_hz: 16000, channels: 1 },
|
||||||
|
metadata,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (type === 'session.started') {
|
||||||
|
wsReadyRef.current = true;
|
||||||
|
setWsStatus('ready');
|
||||||
|
pendingResolveRef.current?.();
|
||||||
|
pendingResolveRef.current = null;
|
||||||
|
pendingRejectRef.current = null;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (type === 'assistant.response.delta') {
|
||||||
|
const delta = String(payload.text || '');
|
||||||
|
if (!delta) return;
|
||||||
|
setMessages((prev) => {
|
||||||
|
const idx = assistantDraftIndexRef.current;
|
||||||
|
if (idx === null || !prev[idx] || prev[idx].role !== 'model') {
|
||||||
|
const next = [...prev, { role: 'model' as const, text: delta }];
|
||||||
|
assistantDraftIndexRef.current = next.length - 1;
|
||||||
|
return next;
|
||||||
|
}
|
||||||
|
const next = [...prev];
|
||||||
|
next[idx] = { ...next[idx], text: next[idx].text + delta };
|
||||||
|
return next;
|
||||||
|
});
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (type === 'assistant.response.final') {
|
||||||
|
const finalText = String(payload.text || '');
|
||||||
|
setMessages((prev) => {
|
||||||
|
const idx = assistantDraftIndexRef.current;
|
||||||
|
assistantDraftIndexRef.current = null;
|
||||||
|
if (idx !== null && prev[idx] && prev[idx].role === 'model') {
|
||||||
|
const next = [...prev];
|
||||||
|
next[idx] = { ...next[idx], text: finalText || next[idx].text };
|
||||||
|
return next;
|
||||||
|
}
|
||||||
|
return finalText ? [...prev, { role: 'model', text: finalText }] : prev;
|
||||||
|
});
|
||||||
|
setIsLoading(false);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (type === 'error') {
|
||||||
|
const message = String(payload.message || 'Unknown error');
|
||||||
|
setWsStatus('error');
|
||||||
|
setWsError(message);
|
||||||
|
setIsLoading(false);
|
||||||
|
const err = new Error(message);
|
||||||
|
pendingRejectRef.current?.(err);
|
||||||
|
pendingResolveRef.current = null;
|
||||||
|
pendingRejectRef.current = null;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
ws.onerror = () => {
|
||||||
|
const err = new Error('WebSocket connection error');
|
||||||
|
setWsStatus('error');
|
||||||
|
setWsError(err.message);
|
||||||
|
setIsLoading(false);
|
||||||
|
pendingRejectRef.current?.(err);
|
||||||
|
pendingResolveRef.current = null;
|
||||||
|
pendingRejectRef.current = null;
|
||||||
|
};
|
||||||
|
|
||||||
|
ws.onclose = () => {
|
||||||
|
wsReadyRef.current = false;
|
||||||
|
if (wsStatus !== 'error') setWsStatus('disconnected');
|
||||||
|
};
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
const TranscriptionLog = () => (
|
const TranscriptionLog = () => (
|
||||||
<div ref={scrollRef} className="flex-1 overflow-y-auto space-y-4 p-2 border border-white/5 rounded-md bg-black/20 min-h-0">
|
<div ref={scrollRef} className="flex-1 overflow-y-auto space-y-4 p-2 border border-white/5 rounded-md bg-black/20 min-h-0">
|
||||||
{messages.length === 0 && <div className="text-center text-muted-foreground text-xs py-4">暂无转写记录</div>}
|
{messages.length === 0 && <div className="text-center text-muted-foreground text-xs py-4">暂无转写记录</div>}
|
||||||
@@ -1205,7 +1391,38 @@ export const DebugDrawer: React.FC<{ isOpen: boolean; onClose: () => void; assis
|
|||||||
|
|
||||||
<div className="flex-1 overflow-hidden flex flex-col min-h-0 mb-4">
|
<div className="flex-1 overflow-hidden flex flex-col min-h-0 mb-4">
|
||||||
{mode === 'text' ? (
|
{mode === 'text' ? (
|
||||||
|
<div className="flex flex-col gap-2 h-full min-h-0">
|
||||||
|
<div className="shrink-0 rounded-md border border-white/10 bg-white/5 p-2 grid grid-cols-1 md:grid-cols-3 gap-2">
|
||||||
|
<Input value={wsUrl} onChange={(e) => setWsUrl(e.target.value)} placeholder="ws://localhost:8000/ws" />
|
||||||
|
<div className="md:col-span-2 flex items-center gap-2">
|
||||||
|
<Badge variant="outline" className="text-xs">
|
||||||
|
WS: {wsStatus}
|
||||||
|
</Badge>
|
||||||
|
<Button size="sm" variant="secondary" onClick={() => ensureWsSession()} disabled={wsStatus === 'connecting'}>
|
||||||
|
Connect
|
||||||
|
</Button>
|
||||||
|
<Button size="sm" variant="ghost" onClick={closeWs}>
|
||||||
|
Disconnect
|
||||||
|
</Button>
|
||||||
|
{wsError && <span className="text-xs text-red-400 truncate">{wsError}</span>}
|
||||||
|
</div>
|
||||||
|
<div className="md:col-span-3 rounded-md border border-white/10 bg-black/20">
|
||||||
|
<button
|
||||||
|
className="w-full px-3 py-2 text-left text-xs text-muted-foreground hover:text-foreground flex items-center justify-between"
|
||||||
|
onClick={() => setResolvedConfigOpen((v) => !v)}
|
||||||
|
>
|
||||||
|
<span>View Resolved Runtime Config (read-only)</span>
|
||||||
|
<ChevronDown className={`h-3.5 w-3.5 transition-transform ${resolvedConfigOpen ? 'rotate-180' : ''}`} />
|
||||||
|
</button>
|
||||||
|
{resolvedConfigOpen && (
|
||||||
|
<pre className="px-3 pb-3 text-[11px] leading-5 text-cyan-100/90 whitespace-pre-wrap break-all max-h-52 overflow-auto">
|
||||||
|
{resolvedConfigView || 'Connect to load resolved config...'}
|
||||||
|
</pre>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
<TranscriptionLog />
|
<TranscriptionLog />
|
||||||
|
</div>
|
||||||
) : callStatus === 'idle' ? (
|
) : callStatus === 'idle' ? (
|
||||||
<div className="flex-1 flex flex-col items-center justify-center space-y-6 border border-white/5 rounded-xl bg-black/20 animate-in fade-in zoom-in-95">
|
<div className="flex-1 flex flex-col items-center justify-center space-y-6 border border-white/5 rounded-xl bg-black/20 animate-in fade-in zoom-in-95">
|
||||||
<div className="relative">
|
<div className="relative">
|
||||||
|
|||||||
@@ -41,6 +41,10 @@ const mapAssistant = (raw: AnyRecord): Assistant => ({
|
|||||||
configMode: readField(raw, ['configMode', 'config_mode'], 'platform') as 'platform' | 'dify' | 'fastgpt' | 'none',
|
configMode: readField(raw, ['configMode', 'config_mode'], 'platform') as 'platform' | 'dify' | 'fastgpt' | 'none',
|
||||||
apiUrl: readField(raw, ['apiUrl', 'api_url'], ''),
|
apiUrl: readField(raw, ['apiUrl', 'api_url'], ''),
|
||||||
apiKey: readField(raw, ['apiKey', 'api_key'], ''),
|
apiKey: readField(raw, ['apiKey', 'api_key'], ''),
|
||||||
|
llmModelId: readField(raw, ['llmModelId', 'llm_model_id'], ''),
|
||||||
|
asrModelId: readField(raw, ['asrModelId', 'asr_model_id'], ''),
|
||||||
|
embeddingModelId: readField(raw, ['embeddingModelId', 'embedding_model_id'], ''),
|
||||||
|
rerankModelId: readField(raw, ['rerankModelId', 'rerank_model_id'], ''),
|
||||||
});
|
});
|
||||||
|
|
||||||
const mapVoice = (raw: AnyRecord): Voice => ({
|
const mapVoice = (raw: AnyRecord): Voice => ({
|
||||||
@@ -218,6 +222,10 @@ export const updateAssistant = async (id: string, data: Partial<Assistant>): Pro
|
|||||||
configMode: data.configMode,
|
configMode: data.configMode,
|
||||||
apiUrl: data.apiUrl,
|
apiUrl: data.apiUrl,
|
||||||
apiKey: data.apiKey,
|
apiKey: data.apiKey,
|
||||||
|
llmModelId: data.llmModelId,
|
||||||
|
asrModelId: data.asrModelId,
|
||||||
|
embeddingModelId: data.embeddingModelId,
|
||||||
|
rerankModelId: data.rerankModelId,
|
||||||
};
|
};
|
||||||
const response = await apiRequest<AnyRecord>(`/assistants/${id}`, { method: 'PUT', body: payload });
|
const response = await apiRequest<AnyRecord>(`/assistants/${id}`, { method: 'PUT', body: payload });
|
||||||
return mapAssistant(response);
|
return mapAssistant(response);
|
||||||
@@ -227,6 +235,21 @@ export const deleteAssistant = async (id: string): Promise<void> => {
|
|||||||
await apiRequest(`/assistants/${id}`, { method: 'DELETE' });
|
await apiRequest(`/assistants/${id}`, { method: 'DELETE' });
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export interface AssistantRuntimeConfigResponse {
|
||||||
|
assistantId: string;
|
||||||
|
sessionStartMetadata: Record<string, any>;
|
||||||
|
sources?: {
|
||||||
|
llmModelId?: string;
|
||||||
|
asrModelId?: string;
|
||||||
|
voiceId?: string;
|
||||||
|
};
|
||||||
|
warnings?: string[];
|
||||||
|
}
|
||||||
|
|
||||||
|
export const fetchAssistantRuntimeConfig = async (assistantId: string): Promise<AssistantRuntimeConfigResponse> => {
|
||||||
|
return apiRequest<AssistantRuntimeConfigResponse>(`/assistants/${assistantId}/runtime-config`);
|
||||||
|
};
|
||||||
|
|
||||||
export const fetchVoices = async (): Promise<Voice[]> => {
|
export const fetchVoices = async (): Promise<Voice[]> => {
|
||||||
const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>('/voices');
|
const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>('/voices');
|
||||||
const list = Array.isArray(response) ? response : (response.list || []);
|
const list = Array.isArray(response) ? response : (response.list || []);
|
||||||
|
|||||||
Reference in New Issue
Block a user