275 lines
9.1 KiB
Python
275 lines
9.1 KiB
Python
"""Session management for active calls."""
|
|
|
|
import hmac
import json
import time
import uuid
from enum import Enum
from typing import Optional, Dict, Any

from loguru import logger

from app.config import settings
from core.duplex_pipeline import DuplexPipeline
from core.transports import BaseTransport
from models.ws_v1 import (
    parse_client_message,
    ev,
    HelloMessage,
    SessionStartMessage,
    SessionStopMessage,
    InputTextMessage,
    ResponseCancelMessage,
)
|
|
|
|
|
|
class WsSessionState(str, Enum):
    """Protocol states a WS session can be in.

    Members are also ``str`` instances, so each compares equal to its
    wire-format value (e.g. ``WsSessionState.ACTIVE == "active"``).
    """

    # Connection is open; no hello has been accepted yet.
    WAIT_HELLO = "wait_hello"
    # Hello accepted; waiting for session.start.
    WAIT_START = "wait_start"
    # session.start accepted; audio and input messages are processed.
    ACTIVE = "active"
    # Session ended (stop requested or handshake rejected).
    STOPPED = "stopped"
|
|
|
|
|
|
class Session:
    """
    Manages a single call session.

    Handles command routing, audio processing, and session lifecycle.
    Uses full duplex voice conversation pipeline.

    The WebSocket protocol state (``ws_state``) normally advances
    WAIT_HELLO -> WAIT_START -> ACTIVE -> STOPPED; a rejected hello moves the
    session straight to STOPPED and closes the transport.
    """

    def __init__(
        self,
        session_id: str,
        transport: BaseTransport,
        use_duplex: Optional[bool] = None,
    ):
        """
        Initialize session.

        Args:
            session_id: Unique session identifier
            transport: Transport instance for communication
            use_duplex: Whether to use duplex pipeline (defaults to settings.duplex_enabled)
        """
        self.id = session_id
        self.transport = transport
        # NOTE(review): the flag is stored and logged, but the pipeline below is
        # always a DuplexPipeline regardless of its value.
        self.use_duplex = use_duplex if use_duplex is not None else settings.duplex_enabled

        self.pipeline = DuplexPipeline(
            transport=transport,
            session_id=session_id,
            system_prompt=settings.duplex_system_prompt,
            greeting=settings.duplex_greeting,
        )

        # Session state
        # NOTE(review): created_at is never populated inside this class.
        self.created_at = None
        self.state = "created"  # Legacy call state for /call/lists
        self.ws_state = WsSessionState.WAIT_HELLO
        self._pipeline_started = False
        self.protocol_version: Optional[str] = None
        self.authenticated: bool = False

        # Track IDs
        self.current_track_id: Optional[str] = str(uuid.uuid4())

        logger.info(f"Session {self.id} created (duplex={self.use_duplex})")

    async def handle_text(self, text_data: str) -> None:
        """
        Handle incoming text data (WS v1 JSON control messages).

        Failures are reported to the client as error events rather than
        raised: malformed JSON as ``protocol.invalid_json``, a ``ValueError``
        as ``protocol.invalid_message``, anything else as ``server.internal``.

        Args:
            text_data: JSON text data
        """
        try:
            data = json.loads(text_data)
            message = parse_client_message(data)
            await self._handle_v1_message(message)

        # JSONDecodeError subclasses ValueError, so it must be caught first.
        except json.JSONDecodeError as e:
            logger.error(f"Session {self.id} JSON decode error: {e}")
            await self._send_error("client", f"Invalid JSON: {e}", "protocol.invalid_json")

        except ValueError as e:
            logger.error(f"Session {self.id} command parse error: {e}")
            await self._send_error("client", f"Invalid message: {e}", "protocol.invalid_message")

        except Exception as e:
            # logger.exception attaches the traceback. Passing exc_info=True to
            # loguru would not: the keyword is used as a str.format argument
            # and breaks when the message text contains braces.
            logger.exception(f"Session {self.id} handle_text error: {e}")
            await self._send_error("server", f"Internal error: {e}", "server.internal")

    async def handle_audio(self, audio_bytes: bytes) -> None:
        """
        Handle incoming audio data.

        Audio is forwarded to the pipeline only while the session is ACTIVE;
        otherwise a ``protocol.order`` error is sent. Pipeline failures are
        logged and not reported to the client.

        Args:
            audio_bytes: PCM audio data
        """
        if self.ws_state != WsSessionState.ACTIVE:
            await self._send_error(
                "client",
                "Audio received before session.start",
                "protocol.order",
            )
            return

        try:
            await self.pipeline.process_audio(audio_bytes)
        except Exception as e:
            logger.exception(f"Session {self.id} handle_audio error: {e}")

    async def _handle_v1_message(self, message: Any) -> None:
        """Route validated WS v1 message to handlers."""
        msg_type = message.type
        logger.info(f"Session {self.id} received message: {msg_type}")

        if isinstance(message, HelloMessage):
            await self._handle_hello(message)
            return

        # All messages below require hello handshake first
        if self.ws_state == WsSessionState.WAIT_HELLO:
            await self._send_error(
                "client",
                "Expected hello message first",
                "protocol.order",
            )
            return

        if isinstance(message, SessionStartMessage):
            await self._handle_session_start(message)
            return

        # All messages below require active session
        if self.ws_state != WsSessionState.ACTIVE:
            await self._send_error(
                "client",
                f"Message '{msg_type}' requires active session",
                "protocol.order",
            )
            return

        if isinstance(message, InputTextMessage):
            await self.pipeline.process_text(message.text)
        elif isinstance(message, ResponseCancelMessage):
            if message.graceful:
                # A graceful cancel is only logged; the pipeline is not interrupted.
                logger.info(f"Session {self.id} graceful response.cancel")
            else:
                await self.pipeline.interrupt()
        elif isinstance(message, SessionStopMessage):
            await self._handle_session_stop(message.reason)
        else:
            await self._send_error(
                "client",
                f"Unsupported message type: {msg_type}",
                "protocol.unsupported",
            )

    @staticmethod
    def _api_key_matches(provided: Any, expected: Any) -> bool:
        """Return True when the client-supplied API key equals the configured one.

        Two strings are compared with ``hmac.compare_digest`` so the time taken
        does not depend on how many leading characters match. Any other
        combination of types falls back to plain equality.
        """
        if isinstance(provided, str) and isinstance(expected, str):
            # Encode first: compare_digest only accepts ASCII-only str values.
            return hmac.compare_digest(provided.encode("utf-8"), expected.encode("utf-8"))
        return provided == expected

    async def _reject_handshake(self, sender: str, error_message: str, code: str) -> None:
        """Send an error event, close the transport and leave the session STOPPED.

        The state is set before any await so the session is already marked
        STOPPED even if sending the error or closing the transport raises.
        """
        self.ws_state = WsSessionState.STOPPED
        await self._send_error(sender, error_message, code)
        await self.transport.close()

    async def _handle_hello(self, message: HelloMessage) -> None:
        """Handle initial hello/auth/version negotiation.

        On success the session moves to WAIT_START and a ``hello.ack`` event is
        sent. A version mismatch or failed authentication sends an error event,
        closes the transport and leaves the session STOPPED.
        """
        if self.ws_state != WsSessionState.WAIT_HELLO:
            await self._send_error("client", "Duplicate hello", "protocol.order")
            return

        if message.version != settings.ws_protocol_version:
            await self._reject_handshake(
                "client",
                f"Unsupported protocol version '{message.version}'",
                "protocol.version_unsupported",
            )
            return

        auth_payload = message.auth or {}
        api_key = auth_payload.get("apiKey")
        jwt = auth_payload.get("jwt")

        if settings.ws_api_key:
            if not self._api_key_matches(api_key, settings.ws_api_key):
                await self._reject_handshake("auth", "Invalid API key", "auth.invalid_api_key")
                return
        elif settings.ws_require_auth and not (api_key or jwt):
            # NOTE(review): with no API key configured, any non-empty apiKey or
            # jwt satisfies this check; the credential is not validated here.
            # Confirm validation happens elsewhere.
            await self._reject_handshake("auth", "Authentication required", "auth.required")
            return

        self.authenticated = True
        self.protocol_version = message.version
        self.ws_state = WsSessionState.WAIT_START
        await self.transport.send_event(
            ev(
                "hello.ack",
                sessionId=self.id,
                version=self.protocol_version,
            )
        )

    async def _handle_session_start(self, message: SessionStartMessage) -> None:
        """Handle explicit session start after successful hello."""
        if self.ws_state != WsSessionState.WAIT_START:
            await self._send_error("client", "Duplicate session.start", "protocol.order")
            return

        # Apply runtime service/prompt overrides from backend if provided
        self.pipeline.apply_runtime_overrides(message.metadata)

        # Start duplex pipeline
        if not self._pipeline_started:
            await self.pipeline.start()
            self._pipeline_started = True
            logger.info(f"Session {self.id} duplex pipeline started")

        self.state = "accepted"
        self.ws_state = WsSessionState.ACTIVE
        await self.transport.send_event(
            ev(
                "session.started",
                sessionId=self.id,
                trackId=self.current_track_id,
                audio=message.audio or {},
            )
        )

    async def _handle_session_stop(self, reason: Optional[str]) -> None:
        """Handle session stop.

        Does nothing if the session is already STOPPED. Otherwise marks the
        session stopped, sends ``session.stopped`` and closes the transport.

        Args:
            reason: Stop reason; "client_requested" is used when empty or None
        """
        if self.ws_state == WsSessionState.STOPPED:
            return

        stop_reason = reason or "client_requested"
        self.state = "hungup"
        self.ws_state = WsSessionState.STOPPED
        await self.transport.send_event(
            ev(
                "session.stopped",
                sessionId=self.id,
                reason=stop_reason,
            )
        )
        await self.transport.close()

    async def _send_error(self, sender: str, error_message: str, code: str) -> None:
        """
        Send error event to client.

        Args:
            sender: Component that generated the error
            error_message: Error message
            code: Machine-readable error code
        """
        await self.transport.send_event(
            ev(
                "error",
                sender=sender,
                code=code,
                message=error_message,
                trackId=self.current_track_id,
            )
        )

    def _get_timestamp_ms(self) -> int:
        """Get current timestamp in milliseconds."""
        return int(time.time() * 1000)

    async def cleanup(self) -> None:
        """Cleanup session resources.

        The transport is closed even when pipeline cleanup raises; the
        pipeline's exception still propagates to the caller.
        """
        logger.info(f"Session {self.id} cleaning up")
        try:
            await self.pipeline.cleanup()
        finally:
            await self.transport.close()
|