Implement WS v1 protocol and runtime-config powered debug drawer
This commit is contained in:
@@ -13,7 +13,7 @@ event-driven design.
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Optional, Callable, Awaitable
|
||||
from typing import Optional, Callable, Awaitable, Dict, Any
|
||||
from loguru import logger
|
||||
|
||||
from core.transports import BaseTransport
|
||||
@@ -28,6 +28,7 @@ from services.asr import BufferedASRService
|
||||
from services.siliconflow_tts import SiliconFlowTTSService
|
||||
from services.siliconflow_asr import SiliconFlowASRService
|
||||
from app.config import settings
|
||||
from models.ws_v1 import ev
|
||||
|
||||
|
||||
class DuplexPipeline:
|
||||
@@ -126,19 +127,66 @@ class DuplexPipeline:
|
||||
self._barge_in_speech_frames: int = 0 # Count speech frames
|
||||
self._barge_in_silence_frames: int = 0 # Count silence frames during potential barge-in
|
||||
self._barge_in_silence_tolerance: int = 3 # Allow up to 3 silence frames (60ms at 20ms chunks)
|
||||
|
||||
# Runtime overrides injected from session.start metadata
|
||||
self._runtime_llm: Dict[str, Any] = {}
|
||||
self._runtime_asr: Dict[str, Any] = {}
|
||||
self._runtime_tts: Dict[str, Any] = {}
|
||||
self._runtime_system_prompt: Optional[str] = None
|
||||
self._runtime_greeting: Optional[str] = None
|
||||
|
||||
logger.info(f"DuplexPipeline initialized for session {session_id}")
|
||||
|
||||
def apply_runtime_overrides(self, metadata: Optional[Dict[str, Any]]) -> None:
|
||||
"""
|
||||
Apply runtime overrides from WS session.start metadata.
|
||||
|
||||
Expected metadata shape:
|
||||
{
|
||||
"systemPrompt": "...",
|
||||
"greeting": "...",
|
||||
"services": {
|
||||
"llm": {...},
|
||||
"asr": {...},
|
||||
"tts": {...}
|
||||
}
|
||||
}
|
||||
"""
|
||||
if not metadata:
|
||||
return
|
||||
|
||||
if "systemPrompt" in metadata:
|
||||
self._runtime_system_prompt = str(metadata.get("systemPrompt") or "")
|
||||
if self._runtime_system_prompt:
|
||||
self.conversation.system_prompt = self._runtime_system_prompt
|
||||
if "greeting" in metadata:
|
||||
self._runtime_greeting = str(metadata.get("greeting") or "")
|
||||
self.conversation.greeting = self._runtime_greeting or None
|
||||
|
||||
services = metadata.get("services") or {}
|
||||
if isinstance(services, dict):
|
||||
if isinstance(services.get("llm"), dict):
|
||||
self._runtime_llm = services["llm"]
|
||||
if isinstance(services.get("asr"), dict):
|
||||
self._runtime_asr = services["asr"]
|
||||
if isinstance(services.get("tts"), dict):
|
||||
self._runtime_tts = services["tts"]
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the pipeline and connect services."""
|
||||
try:
|
||||
# Connect LLM service
|
||||
if not self.llm_service:
|
||||
if settings.openai_api_key:
|
||||
llm_api_key = self._runtime_llm.get("apiKey") or settings.openai_api_key
|
||||
llm_base_url = self._runtime_llm.get("baseUrl") or settings.openai_api_url
|
||||
llm_model = self._runtime_llm.get("model") or settings.llm_model
|
||||
llm_provider = (self._runtime_llm.get("provider") or "openai").lower()
|
||||
|
||||
if llm_provider == "openai" and llm_api_key:
|
||||
self.llm_service = OpenAILLMService(
|
||||
api_key=settings.openai_api_key,
|
||||
base_url=settings.openai_api_url,
|
||||
model=settings.llm_model
|
||||
api_key=llm_api_key,
|
||||
base_url=llm_base_url,
|
||||
model=llm_model
|
||||
)
|
||||
else:
|
||||
logger.warning("No OpenAI API key - using mock LLM")
|
||||
@@ -148,33 +196,52 @@ class DuplexPipeline:
|
||||
|
||||
# Connect TTS service
|
||||
if not self.tts_service:
|
||||
if settings.tts_provider == "siliconflow" and settings.siliconflow_api_key:
|
||||
tts_provider = (self._runtime_tts.get("provider") or settings.tts_provider).lower()
|
||||
tts_api_key = self._runtime_tts.get("apiKey") or settings.siliconflow_api_key
|
||||
tts_voice = self._runtime_tts.get("voice") or settings.tts_voice
|
||||
tts_model = self._runtime_tts.get("model") or settings.siliconflow_tts_model
|
||||
tts_speed = float(self._runtime_tts.get("speed") or settings.tts_speed)
|
||||
|
||||
if tts_provider == "siliconflow" and tts_api_key:
|
||||
self.tts_service = SiliconFlowTTSService(
|
||||
api_key=settings.siliconflow_api_key,
|
||||
voice=settings.tts_voice,
|
||||
model=settings.siliconflow_tts_model,
|
||||
api_key=tts_api_key,
|
||||
voice=tts_voice,
|
||||
model=tts_model,
|
||||
sample_rate=settings.sample_rate,
|
||||
speed=settings.tts_speed
|
||||
speed=tts_speed
|
||||
)
|
||||
logger.info("Using SiliconFlow TTS service")
|
||||
else:
|
||||
self.tts_service = EdgeTTSService(
|
||||
voice=settings.tts_voice,
|
||||
voice=tts_voice,
|
||||
sample_rate=settings.sample_rate
|
||||
)
|
||||
logger.info("Using Edge TTS service")
|
||||
|
||||
await self.tts_service.connect()
|
||||
|
||||
try:
|
||||
await self.tts_service.connect()
|
||||
except Exception as e:
|
||||
logger.warning(f"TTS backend unavailable ({e}); falling back to MockTTS")
|
||||
self.tts_service = MockTTSService(
|
||||
sample_rate=settings.sample_rate
|
||||
)
|
||||
await self.tts_service.connect()
|
||||
|
||||
# Connect ASR service
|
||||
if not self.asr_service:
|
||||
if settings.asr_provider == "siliconflow" and settings.siliconflow_api_key:
|
||||
asr_provider = (self._runtime_asr.get("provider") or settings.asr_provider).lower()
|
||||
asr_api_key = self._runtime_asr.get("apiKey") or settings.siliconflow_api_key
|
||||
asr_model = self._runtime_asr.get("model") or settings.siliconflow_asr_model
|
||||
asr_interim_interval = int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms)
|
||||
asr_min_audio_ms = int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms)
|
||||
|
||||
if asr_provider == "siliconflow" and asr_api_key:
|
||||
self.asr_service = SiliconFlowASRService(
|
||||
api_key=settings.siliconflow_api_key,
|
||||
model=settings.siliconflow_asr_model,
|
||||
api_key=asr_api_key,
|
||||
model=asr_model,
|
||||
sample_rate=settings.sample_rate,
|
||||
interim_interval_ms=settings.asr_interim_interval_ms,
|
||||
min_audio_for_interim_ms=settings.asr_min_audio_ms,
|
||||
interim_interval_ms=asr_interim_interval,
|
||||
min_audio_for_interim_ms=asr_min_audio_ms,
|
||||
on_transcript=self._on_transcript_callback
|
||||
)
|
||||
logger.info("Using SiliconFlow ASR service")
|
||||
@@ -223,6 +290,13 @@ class DuplexPipeline:
|
||||
"trackId": self.session_id,
|
||||
"probability": probability
|
||||
})
|
||||
await self.transport.send_event(
|
||||
ev(
|
||||
"input.speech_started" if event_type == "speaking" else "input.speech_stopped",
|
||||
trackId=self.session_id,
|
||||
probability=probability,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# No state change - keep previous status
|
||||
vad_status = self._last_vad_status
|
||||
@@ -325,11 +399,11 @@ class DuplexPipeline:
|
||||
|
||||
# Send transcript event to client
|
||||
await self.transport.send_event({
|
||||
"event": "transcript",
|
||||
"trackId": self.session_id,
|
||||
"text": text,
|
||||
"isFinal": is_final,
|
||||
"timestamp": self._get_timestamp_ms()
|
||||
**ev(
|
||||
"transcript.final" if is_final else "transcript.delta",
|
||||
trackId=self.session_id,
|
||||
text=text,
|
||||
)
|
||||
})
|
||||
|
||||
logger.debug(f"Sent transcript ({'final' if is_final else 'interim'}): {text[:50]}...")
|
||||
@@ -383,11 +457,11 @@ class DuplexPipeline:
|
||||
|
||||
# Send final transcription to client
|
||||
await self.transport.send_event({
|
||||
"event": "transcript",
|
||||
"trackId": self.session_id,
|
||||
"text": user_text,
|
||||
"isFinal": True,
|
||||
"timestamp": self._get_timestamp_ms()
|
||||
**ev(
|
||||
"transcript.final",
|
||||
trackId=self.session_id,
|
||||
text=user_text,
|
||||
)
|
||||
})
|
||||
|
||||
# Clear buffers
|
||||
@@ -438,11 +512,11 @@ class DuplexPipeline:
|
||||
|
||||
# Send LLM response streaming event to client
|
||||
await self.transport.send_event({
|
||||
"event": "llmResponse",
|
||||
"trackId": self.session_id,
|
||||
"text": text_chunk,
|
||||
"isFinal": False,
|
||||
"timestamp": self._get_timestamp_ms()
|
||||
**ev(
|
||||
"assistant.response.delta",
|
||||
trackId=self.session_id,
|
||||
text=text_chunk,
|
||||
)
|
||||
})
|
||||
|
||||
# Check for sentence completion - synthesize immediately for low latency
|
||||
@@ -462,9 +536,10 @@ class DuplexPipeline:
|
||||
# Send track start on first audio
|
||||
if not first_audio_sent:
|
||||
await self.transport.send_event({
|
||||
"event": "trackStart",
|
||||
"trackId": self.session_id,
|
||||
"timestamp": self._get_timestamp_ms()
|
||||
**ev(
|
||||
"output.audio.start",
|
||||
trackId=self.session_id,
|
||||
)
|
||||
})
|
||||
first_audio_sent = True
|
||||
|
||||
@@ -476,20 +551,21 @@ class DuplexPipeline:
|
||||
# Send final LLM response event
|
||||
if full_response and not self._interrupt_event.is_set():
|
||||
await self.transport.send_event({
|
||||
"event": "llmResponse",
|
||||
"trackId": self.session_id,
|
||||
"text": full_response,
|
||||
"isFinal": True,
|
||||
"timestamp": self._get_timestamp_ms()
|
||||
**ev(
|
||||
"assistant.response.final",
|
||||
trackId=self.session_id,
|
||||
text=full_response,
|
||||
)
|
||||
})
|
||||
|
||||
# Speak any remaining text
|
||||
if sentence_buffer.strip() and not self._interrupt_event.is_set():
|
||||
if not first_audio_sent:
|
||||
await self.transport.send_event({
|
||||
"event": "trackStart",
|
||||
"trackId": self.session_id,
|
||||
"timestamp": self._get_timestamp_ms()
|
||||
**ev(
|
||||
"output.audio.start",
|
||||
trackId=self.session_id,
|
||||
)
|
||||
})
|
||||
first_audio_sent = True
|
||||
await self._speak_sentence(sentence_buffer.strip())
|
||||
@@ -497,9 +573,10 @@ class DuplexPipeline:
|
||||
# Send track end
|
||||
if first_audio_sent:
|
||||
await self.transport.send_event({
|
||||
"event": "trackEnd",
|
||||
"trackId": self.session_id,
|
||||
"timestamp": self._get_timestamp_ms()
|
||||
**ev(
|
||||
"output.audio.end",
|
||||
trackId=self.session_id,
|
||||
)
|
||||
})
|
||||
|
||||
# End assistant turn
|
||||
@@ -545,10 +622,11 @@ class DuplexPipeline:
|
||||
|
||||
# Send TTFB event to client
|
||||
await self.transport.send_event({
|
||||
"event": "ttfb",
|
||||
"trackId": self.session_id,
|
||||
"timestamp": self._get_timestamp_ms(),
|
||||
"latencyMs": round(ttfb_ms)
|
||||
**ev(
|
||||
"metrics.ttfb",
|
||||
trackId=self.session_id,
|
||||
latencyMs=round(ttfb_ms),
|
||||
)
|
||||
})
|
||||
|
||||
# Double-check interrupt right before sending audio
|
||||
@@ -579,9 +657,10 @@ class DuplexPipeline:
|
||||
|
||||
# Send track start event
|
||||
await self.transport.send_event({
|
||||
"event": "trackStart",
|
||||
"trackId": self.session_id,
|
||||
"timestamp": self._get_timestamp_ms()
|
||||
**ev(
|
||||
"output.audio.start",
|
||||
trackId=self.session_id,
|
||||
)
|
||||
})
|
||||
|
||||
self._is_bot_speaking = True
|
||||
@@ -600,10 +679,11 @@ class DuplexPipeline:
|
||||
|
||||
# Send TTFB event to client
|
||||
await self.transport.send_event({
|
||||
"event": "ttfb",
|
||||
"trackId": self.session_id,
|
||||
"timestamp": self._get_timestamp_ms(),
|
||||
"latencyMs": round(ttfb_ms)
|
||||
**ev(
|
||||
"metrics.ttfb",
|
||||
trackId=self.session_id,
|
||||
latencyMs=round(ttfb_ms),
|
||||
)
|
||||
})
|
||||
|
||||
# Send audio to client
|
||||
@@ -614,9 +694,10 @@ class DuplexPipeline:
|
||||
|
||||
# Send track end event
|
||||
await self.transport.send_event({
|
||||
"event": "trackEnd",
|
||||
"trackId": self.session_id,
|
||||
"timestamp": self._get_timestamp_ms()
|
||||
**ev(
|
||||
"output.audio.end",
|
||||
trackId=self.session_id,
|
||||
)
|
||||
})
|
||||
|
||||
except asyncio.CancelledError:
|
||||
@@ -646,9 +727,10 @@ class DuplexPipeline:
|
||||
# Send interrupt event to client IMMEDIATELY
|
||||
# This must happen BEFORE canceling services, so client knows to discard in-flight audio
|
||||
await self.transport.send_event({
|
||||
"event": "interrupt",
|
||||
"trackId": self.session_id,
|
||||
"timestamp": self._get_timestamp_ms()
|
||||
**ev(
|
||||
"response.interrupted",
|
||||
trackId=self.session_id,
|
||||
)
|
||||
})
|
||||
|
||||
# Cancel TTS
|
||||
|
||||
@@ -2,13 +2,31 @@
|
||||
|
||||
import uuid
|
||||
import json
|
||||
from enum import Enum
|
||||
from typing import Optional, Dict, Any
|
||||
from loguru import logger
|
||||
|
||||
from core.transports import BaseTransport
|
||||
from core.duplex_pipeline import DuplexPipeline
|
||||
from models.commands import parse_command, TTSCommand, ChatCommand, InterruptCommand, HangupCommand
|
||||
from app.config import settings
|
||||
from models.ws_v1 import (
|
||||
parse_client_message,
|
||||
ev,
|
||||
HelloMessage,
|
||||
SessionStartMessage,
|
||||
SessionStopMessage,
|
||||
InputTextMessage,
|
||||
ResponseCancelMessage,
|
||||
)
|
||||
|
||||
|
||||
class WsSessionState(str, Enum):
|
||||
"""Protocol state machine for WS sessions."""
|
||||
|
||||
WAIT_HELLO = "wait_hello"
|
||||
WAIT_START = "wait_start"
|
||||
ACTIVE = "active"
|
||||
STOPPED = "stopped"
|
||||
|
||||
|
||||
class Session:
|
||||
@@ -41,8 +59,11 @@ class Session:
|
||||
|
||||
# Session state
|
||||
self.created_at = None
|
||||
self.state = "created" # created, invited, accepted, ringing, hungup
|
||||
self.state = "created" # Legacy call state for /call/lists
|
||||
self.ws_state = WsSessionState.WAIT_HELLO
|
||||
self._pipeline_started = False
|
||||
self.protocol_version: Optional[str] = None
|
||||
self.authenticated: bool = False
|
||||
|
||||
# Track IDs
|
||||
self.current_track_id: Optional[str] = str(uuid.uuid4())
|
||||
@@ -51,69 +72,27 @@ class Session:
|
||||
|
||||
async def handle_text(self, text_data: str) -> None:
|
||||
"""
|
||||
Handle incoming text data (JSON commands).
|
||||
Handle incoming text data (WS v1 JSON control messages).
|
||||
|
||||
Args:
|
||||
text_data: JSON text data
|
||||
"""
|
||||
try:
|
||||
data = json.loads(text_data)
|
||||
command = parse_command(data)
|
||||
command_type = command.command
|
||||
|
||||
logger.info(f"Session {self.id} received command: {command_type}")
|
||||
|
||||
# Route command to appropriate handler
|
||||
if command_type == "invite":
|
||||
await self._handle_invite(data)
|
||||
|
||||
elif command_type == "accept":
|
||||
await self._handle_accept(data)
|
||||
|
||||
elif command_type == "reject":
|
||||
await self._handle_reject(data)
|
||||
|
||||
elif command_type == "ringing":
|
||||
await self._handle_ringing(data)
|
||||
|
||||
elif command_type == "tts":
|
||||
await self._handle_tts(command)
|
||||
|
||||
elif command_type == "play":
|
||||
await self._handle_play(data)
|
||||
|
||||
elif command_type == "interrupt":
|
||||
await self._handle_interrupt(command)
|
||||
|
||||
elif command_type == "pause":
|
||||
await self._handle_pause()
|
||||
|
||||
elif command_type == "resume":
|
||||
await self._handle_resume()
|
||||
|
||||
elif command_type == "hangup":
|
||||
await self._handle_hangup(command)
|
||||
|
||||
elif command_type == "history":
|
||||
await self._handle_history(data)
|
||||
|
||||
elif command_type == "chat":
|
||||
await self._handle_chat(command)
|
||||
|
||||
else:
|
||||
logger.warning(f"Session {self.id} unknown command: {command_type}")
|
||||
message = parse_client_message(data)
|
||||
await self._handle_v1_message(message)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Session {self.id} JSON decode error: {e}")
|
||||
await self._send_error("client", f"Invalid JSON: {e}")
|
||||
await self._send_error("client", f"Invalid JSON: {e}", "protocol.invalid_json")
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Session {self.id} command parse error: {e}")
|
||||
await self._send_error("client", f"Invalid command: {e}")
|
||||
await self._send_error("client", f"Invalid message: {e}", "protocol.invalid_message")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Session {self.id} handle_text error: {e}", exc_info=True)
|
||||
await self._send_error("server", f"Internal error: {e}")
|
||||
await self._send_error("server", f"Internal error: {e}", "server.internal")
|
||||
|
||||
async def handle_audio(self, audio_bytes: bytes) -> None:
|
||||
"""
|
||||
@@ -122,156 +101,166 @@ class Session:
|
||||
Args:
|
||||
audio_bytes: PCM audio data
|
||||
"""
|
||||
if self.ws_state != WsSessionState.ACTIVE:
|
||||
await self._send_error(
|
||||
"client",
|
||||
"Audio received before session.start",
|
||||
"protocol.order",
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
await self.pipeline.process_audio(audio_bytes)
|
||||
except Exception as e:
|
||||
logger.error(f"Session {self.id} handle_audio error: {e}", exc_info=True)
|
||||
|
||||
async def _handle_invite(self, data: Dict[str, Any]) -> None:
|
||||
"""Handle invite command."""
|
||||
self.state = "invited"
|
||||
option = data.get("option", {})
|
||||
async def _handle_v1_message(self, message: Any) -> None:
|
||||
"""Route validated WS v1 message to handlers."""
|
||||
msg_type = message.type
|
||||
logger.info(f"Session {self.id} received message: {msg_type}")
|
||||
|
||||
# Send answer event
|
||||
await self.transport.send_event({
|
||||
"event": "answer",
|
||||
"trackId": self.current_track_id,
|
||||
"timestamp": self._get_timestamp_ms()
|
||||
})
|
||||
if isinstance(message, HelloMessage):
|
||||
await self._handle_hello(message)
|
||||
return
|
||||
|
||||
# All messages below require hello handshake first
|
||||
if self.ws_state == WsSessionState.WAIT_HELLO:
|
||||
await self._send_error(
|
||||
"client",
|
||||
"Expected hello message first",
|
||||
"protocol.order",
|
||||
)
|
||||
return
|
||||
|
||||
if isinstance(message, SessionStartMessage):
|
||||
await self._handle_session_start(message)
|
||||
return
|
||||
|
||||
# All messages below require active session
|
||||
if self.ws_state != WsSessionState.ACTIVE:
|
||||
await self._send_error(
|
||||
"client",
|
||||
f"Message '{msg_type}' requires active session",
|
||||
"protocol.order",
|
||||
)
|
||||
return
|
||||
|
||||
if isinstance(message, InputTextMessage):
|
||||
await self.pipeline.process_text(message.text)
|
||||
elif isinstance(message, ResponseCancelMessage):
|
||||
if message.graceful:
|
||||
logger.info(f"Session {self.id} graceful response.cancel")
|
||||
else:
|
||||
await self.pipeline.interrupt()
|
||||
elif isinstance(message, SessionStopMessage):
|
||||
await self._handle_session_stop(message.reason)
|
||||
else:
|
||||
await self._send_error("client", f"Unsupported message type: {msg_type}", "protocol.unsupported")
|
||||
|
||||
async def _handle_hello(self, message: HelloMessage) -> None:
|
||||
"""Handle initial hello/auth/version negotiation."""
|
||||
if self.ws_state != WsSessionState.WAIT_HELLO:
|
||||
await self._send_error("client", "Duplicate hello", "protocol.order")
|
||||
return
|
||||
|
||||
if message.version != settings.ws_protocol_version:
|
||||
await self._send_error(
|
||||
"client",
|
||||
f"Unsupported protocol version '{message.version}'",
|
||||
"protocol.version_unsupported",
|
||||
)
|
||||
await self.transport.close()
|
||||
self.ws_state = WsSessionState.STOPPED
|
||||
return
|
||||
|
||||
auth_payload = message.auth or {}
|
||||
api_key = auth_payload.get("apiKey")
|
||||
jwt = auth_payload.get("jwt")
|
||||
|
||||
if settings.ws_api_key:
|
||||
if api_key != settings.ws_api_key:
|
||||
await self._send_error("auth", "Invalid API key", "auth.invalid_api_key")
|
||||
await self.transport.close()
|
||||
self.ws_state = WsSessionState.STOPPED
|
||||
return
|
||||
elif settings.ws_require_auth and not (api_key or jwt):
|
||||
await self._send_error("auth", "Authentication required", "auth.required")
|
||||
await self.transport.close()
|
||||
self.ws_state = WsSessionState.STOPPED
|
||||
return
|
||||
|
||||
self.authenticated = True
|
||||
self.protocol_version = message.version
|
||||
self.ws_state = WsSessionState.WAIT_START
|
||||
await self.transport.send_event(
|
||||
ev(
|
||||
"hello.ack",
|
||||
sessionId=self.id,
|
||||
version=self.protocol_version,
|
||||
)
|
||||
)
|
||||
|
||||
async def _handle_session_start(self, message: SessionStartMessage) -> None:
|
||||
"""Handle explicit session start after successful hello."""
|
||||
if self.ws_state != WsSessionState.WAIT_START:
|
||||
await self._send_error("client", "Duplicate session.start", "protocol.order")
|
||||
return
|
||||
|
||||
# Apply runtime service/prompt overrides from backend if provided
|
||||
self.pipeline.apply_runtime_overrides(message.metadata)
|
||||
|
||||
# Start duplex pipeline
|
||||
if not self._pipeline_started:
|
||||
try:
|
||||
await self.pipeline.start()
|
||||
self._pipeline_started = True
|
||||
logger.info(f"Session {self.id} duplex pipeline started")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start duplex pipeline: {e}")
|
||||
await self.pipeline.start()
|
||||
self._pipeline_started = True
|
||||
logger.info(f"Session {self.id} duplex pipeline started")
|
||||
|
||||
logger.info(f"Session {self.id} invited with codec: {option.get('codec', 'pcm')}")
|
||||
|
||||
async def _handle_accept(self, data: Dict[str, Any]) -> None:
|
||||
"""Handle accept command."""
|
||||
self.state = "accepted"
|
||||
logger.info(f"Session {self.id} accepted")
|
||||
self.ws_state = WsSessionState.ACTIVE
|
||||
await self.transport.send_event(
|
||||
ev(
|
||||
"session.started",
|
||||
sessionId=self.id,
|
||||
trackId=self.current_track_id,
|
||||
audio=message.audio or {},
|
||||
)
|
||||
)
|
||||
|
||||
async def _handle_reject(self, data: Dict[str, Any]) -> None:
|
||||
"""Handle reject command."""
|
||||
self.state = "rejected"
|
||||
reason = data.get("reason", "Rejected")
|
||||
logger.info(f"Session {self.id} rejected: {reason}")
|
||||
async def _handle_session_stop(self, reason: Optional[str]) -> None:
|
||||
"""Handle session stop."""
|
||||
if self.ws_state == WsSessionState.STOPPED:
|
||||
return
|
||||
|
||||
async def _handle_ringing(self, data: Dict[str, Any]) -> None:
|
||||
"""Handle ringing command."""
|
||||
self.state = "ringing"
|
||||
logger.info(f"Session {self.id} ringing")
|
||||
|
||||
async def _handle_tts(self, command: TTSCommand) -> None:
|
||||
"""Handle TTS command."""
|
||||
logger.info(f"Session {self.id} TTS: {command.text[:50]}...")
|
||||
|
||||
# Send track start event
|
||||
await self.transport.send_event({
|
||||
"event": "trackStart",
|
||||
"trackId": self.current_track_id,
|
||||
"timestamp": self._get_timestamp_ms(),
|
||||
"playId": command.play_id
|
||||
})
|
||||
|
||||
# TODO: Implement actual TTS synthesis
|
||||
# For now, just send track end event
|
||||
await self.transport.send_event({
|
||||
"event": "trackEnd",
|
||||
"trackId": self.current_track_id,
|
||||
"timestamp": self._get_timestamp_ms(),
|
||||
"duration": 1000,
|
||||
"ssrc": 0,
|
||||
"playId": command.play_id
|
||||
})
|
||||
|
||||
async def _handle_play(self, data: Dict[str, Any]) -> None:
|
||||
"""Handle play command."""
|
||||
url = data.get("url", "")
|
||||
logger.info(f"Session {self.id} play: {url}")
|
||||
|
||||
# Send track start event
|
||||
await self.transport.send_event({
|
||||
"event": "trackStart",
|
||||
"trackId": self.current_track_id,
|
||||
"timestamp": self._get_timestamp_ms(),
|
||||
"playId": url
|
||||
})
|
||||
|
||||
# TODO: Implement actual audio playback
|
||||
# For now, just send track end event
|
||||
await self.transport.send_event({
|
||||
"event": "trackEnd",
|
||||
"trackId": self.current_track_id,
|
||||
"timestamp": self._get_timestamp_ms(),
|
||||
"duration": 1000,
|
||||
"ssrc": 0,
|
||||
"playId": url
|
||||
})
|
||||
|
||||
async def _handle_interrupt(self, command: InterruptCommand) -> None:
|
||||
"""Handle interrupt command."""
|
||||
if command.graceful:
|
||||
logger.info(f"Session {self.id} graceful interrupt")
|
||||
else:
|
||||
logger.info(f"Session {self.id} immediate interrupt")
|
||||
await self.pipeline.interrupt()
|
||||
|
||||
async def _handle_pause(self) -> None:
|
||||
"""Handle pause command."""
|
||||
logger.info(f"Session {self.id} paused")
|
||||
|
||||
async def _handle_resume(self) -> None:
|
||||
"""Handle resume command."""
|
||||
logger.info(f"Session {self.id} resumed")
|
||||
|
||||
async def _handle_hangup(self, command: HangupCommand) -> None:
|
||||
"""Handle hangup command."""
|
||||
stop_reason = reason or "client_requested"
|
||||
self.state = "hungup"
|
||||
reason = command.reason or "User requested"
|
||||
logger.info(f"Session {self.id} hung up: {reason}")
|
||||
|
||||
# Send hangup event
|
||||
await self.transport.send_event({
|
||||
"event": "hangup",
|
||||
"timestamp": self._get_timestamp_ms(),
|
||||
"reason": reason,
|
||||
"initiator": command.initiator or "user"
|
||||
})
|
||||
|
||||
# Close transport
|
||||
self.ws_state = WsSessionState.STOPPED
|
||||
await self.transport.send_event(
|
||||
ev(
|
||||
"session.stopped",
|
||||
sessionId=self.id,
|
||||
reason=stop_reason,
|
||||
)
|
||||
)
|
||||
await self.transport.close()
|
||||
|
||||
async def _handle_history(self, data: Dict[str, Any]) -> None:
|
||||
"""Handle history command."""
|
||||
speaker = data.get("speaker", "unknown")
|
||||
text = data.get("text", "")
|
||||
logger.info(f"Session {self.id} history [{speaker}]: {text[:50]}...")
|
||||
|
||||
async def _handle_chat(self, command: ChatCommand) -> None:
|
||||
"""Handle chat command."""
|
||||
logger.info(f"Session {self.id} chat: {command.text[:50]}...")
|
||||
await self.pipeline.process_text(command.text)
|
||||
|
||||
async def _send_error(self, sender: str, error_message: str) -> None:
|
||||
async def _send_error(self, sender: str, error_message: str, code: str) -> None:
|
||||
"""
|
||||
Send error event to client.
|
||||
|
||||
Args:
|
||||
sender: Component that generated the error
|
||||
error_message: Error message
|
||||
code: Machine-readable error code
|
||||
"""
|
||||
await self.transport.send_event({
|
||||
"event": "error",
|
||||
"trackId": self.current_track_id,
|
||||
"timestamp": self._get_timestamp_ms(),
|
||||
"sender": sender,
|
||||
"error": error_message
|
||||
})
|
||||
await self.transport.send_event(
|
||||
ev(
|
||||
"error",
|
||||
sender=sender,
|
||||
code=code,
|
||||
message=error_message,
|
||||
trackId=self.current_track_id,
|
||||
)
|
||||
)
|
||||
|
||||
def _get_timestamp_ms(self) -> int:
|
||||
"""Get current timestamp in milliseconds."""
|
||||
|
||||
Reference in New Issue
Block a user