"""Session management for active calls.""" import uuid import json from enum import Enum from typing import Optional, Dict, Any from loguru import logger from core.transports import BaseTransport from core.duplex_pipeline import DuplexPipeline from app.config import settings from models.ws_v1 import ( parse_client_message, ev, HelloMessage, SessionStartMessage, SessionStopMessage, InputTextMessage, ResponseCancelMessage, ) class WsSessionState(str, Enum): """Protocol state machine for WS sessions.""" WAIT_HELLO = "wait_hello" WAIT_START = "wait_start" ACTIVE = "active" STOPPED = "stopped" class Session: """ Manages a single call session. Handles command routing, audio processing, and session lifecycle. Uses full duplex voice conversation pipeline. """ def __init__(self, session_id: str, transport: BaseTransport, use_duplex: bool = None): """ Initialize session. Args: session_id: Unique session identifier transport: Transport instance for communication use_duplex: Whether to use duplex pipeline (defaults to settings.duplex_enabled) """ self.id = session_id self.transport = transport self.use_duplex = use_duplex if use_duplex is not None else settings.duplex_enabled self.pipeline = DuplexPipeline( transport=transport, session_id=session_id, system_prompt=settings.duplex_system_prompt, greeting=settings.duplex_greeting ) # Session state self.created_at = None self.state = "created" # Legacy call state for /call/lists self.ws_state = WsSessionState.WAIT_HELLO self._pipeline_started = False self.protocol_version: Optional[str] = None self.authenticated: bool = False # Track IDs self.current_track_id: Optional[str] = str(uuid.uuid4()) logger.info(f"Session {self.id} created (duplex={self.use_duplex})") async def handle_text(self, text_data: str) -> None: """ Handle incoming text data (WS v1 JSON control messages). Args: text_data: JSON text data """ try: data = json.loads(text_data) message = parse_client_message(data) await self._handle_v1_message(message) except json.JSONDecodeError as e: logger.error(f"Session {self.id} JSON decode error: {e}") await self._send_error("client", f"Invalid JSON: {e}", "protocol.invalid_json") except ValueError as e: logger.error(f"Session {self.id} command parse error: {e}") await self._send_error("client", f"Invalid message: {e}", "protocol.invalid_message") except Exception as e: logger.error(f"Session {self.id} handle_text error: {e}", exc_info=True) await self._send_error("server", f"Internal error: {e}", "server.internal") async def handle_audio(self, audio_bytes: bytes) -> None: """ Handle incoming audio data. Args: audio_bytes: PCM audio data """ if self.ws_state != WsSessionState.ACTIVE: await self._send_error( "client", "Audio received before session.start", "protocol.order", ) return try: await self.pipeline.process_audio(audio_bytes) except Exception as e: logger.error(f"Session {self.id} handle_audio error: {e}", exc_info=True) async def _handle_v1_message(self, message: Any) -> None: """Route validated WS v1 message to handlers.""" msg_type = message.type logger.info(f"Session {self.id} received message: {msg_type}") if isinstance(message, HelloMessage): await self._handle_hello(message) return # All messages below require hello handshake first if self.ws_state == WsSessionState.WAIT_HELLO: await self._send_error( "client", "Expected hello message first", "protocol.order", ) return if isinstance(message, SessionStartMessage): await self._handle_session_start(message) return # All messages below require active session if self.ws_state != WsSessionState.ACTIVE: await self._send_error( "client", f"Message '{msg_type}' requires active session", "protocol.order", ) return if isinstance(message, InputTextMessage): await self.pipeline.process_text(message.text) elif isinstance(message, ResponseCancelMessage): if message.graceful: logger.info(f"Session {self.id} graceful response.cancel") else: await self.pipeline.interrupt() elif isinstance(message, SessionStopMessage): await self._handle_session_stop(message.reason) else: await self._send_error("client", f"Unsupported message type: {msg_type}", "protocol.unsupported") async def _handle_hello(self, message: HelloMessage) -> None: """Handle initial hello/auth/version negotiation.""" if self.ws_state != WsSessionState.WAIT_HELLO: await self._send_error("client", "Duplicate hello", "protocol.order") return if message.version != settings.ws_protocol_version: await self._send_error( "client", f"Unsupported protocol version '{message.version}'", "protocol.version_unsupported", ) await self.transport.close() self.ws_state = WsSessionState.STOPPED return auth_payload = message.auth or {} api_key = auth_payload.get("apiKey") jwt = auth_payload.get("jwt") if settings.ws_api_key: if api_key != settings.ws_api_key: await self._send_error("auth", "Invalid API key", "auth.invalid_api_key") await self.transport.close() self.ws_state = WsSessionState.STOPPED return elif settings.ws_require_auth and not (api_key or jwt): await self._send_error("auth", "Authentication required", "auth.required") await self.transport.close() self.ws_state = WsSessionState.STOPPED return self.authenticated = True self.protocol_version = message.version self.ws_state = WsSessionState.WAIT_START await self.transport.send_event( ev( "hello.ack", sessionId=self.id, version=self.protocol_version, ) ) async def _handle_session_start(self, message: SessionStartMessage) -> None: """Handle explicit session start after successful hello.""" if self.ws_state != WsSessionState.WAIT_START: await self._send_error("client", "Duplicate session.start", "protocol.order") return # Apply runtime service/prompt overrides from backend if provided self.pipeline.apply_runtime_overrides(message.metadata) # Start duplex pipeline if not self._pipeline_started: await self.pipeline.start() self._pipeline_started = True logger.info(f"Session {self.id} duplex pipeline started") self.state = "accepted" self.ws_state = WsSessionState.ACTIVE await self.transport.send_event( ev( "session.started", sessionId=self.id, trackId=self.current_track_id, audio=message.audio or {}, ) ) async def _handle_session_stop(self, reason: Optional[str]) -> None: """Handle session stop.""" if self.ws_state == WsSessionState.STOPPED: return stop_reason = reason or "client_requested" self.state = "hungup" self.ws_state = WsSessionState.STOPPED await self.transport.send_event( ev( "session.stopped", sessionId=self.id, reason=stop_reason, ) ) await self.transport.close() async def _send_error(self, sender: str, error_message: str, code: str) -> None: """ Send error event to client. Args: sender: Component that generated the error error_message: Error message code: Machine-readable error code """ await self.transport.send_event( ev( "error", sender=sender, code=code, message=error_message, trackId=self.current_track_id, ) ) def _get_timestamp_ms(self) -> int: """Get current timestamp in milliseconds.""" import time return int(time.time() * 1000) async def cleanup(self) -> None: """Cleanup session resources.""" logger.info(f"Session {self.id} cleaning up") await self.pipeline.cleanup() await self.transport.close()