"""Session management for active calls.""" import asyncio import uuid import json import time import re from enum import Enum from typing import Optional, Dict, Any, List from loguru import logger from app.backend_client import ( create_history_call_record, add_history_transcript, finalize_history_call_record, ) from core.transports import BaseTransport from core.duplex_pipeline import DuplexPipeline from core.conversation import ConversationTurn from core.workflow_runner import WorkflowRunner, WorkflowTransition, WorkflowNodeDef, WorkflowEdgeDef from app.config import settings from services.base import LLMMessage from models.ws_v1 import ( parse_client_message, ev, HelloMessage, SessionStartMessage, SessionStopMessage, InputTextMessage, ResponseCancelMessage, ) class WsSessionState(str, Enum): """Protocol state machine for WS sessions.""" WAIT_HELLO = "wait_hello" WAIT_START = "wait_start" ACTIVE = "active" STOPPED = "stopped" class Session: """ Manages a single call session. Handles command routing, audio processing, and session lifecycle. Uses full duplex voice conversation pipeline. """ def __init__(self, session_id: str, transport: BaseTransport, use_duplex: bool = None): """ Initialize session. Args: session_id: Unique session identifier transport: Transport instance for communication use_duplex: Whether to use duplex pipeline (defaults to settings.duplex_enabled) """ self.id = session_id self.transport = transport self.use_duplex = use_duplex if use_duplex is not None else settings.duplex_enabled self.pipeline = DuplexPipeline( transport=transport, session_id=session_id, system_prompt=settings.duplex_system_prompt, greeting=settings.duplex_greeting ) # Session state self.created_at = None self.state = "created" # Legacy call state for /call/lists self.ws_state = WsSessionState.WAIT_HELLO self._pipeline_started = False self.protocol_version: Optional[str] = None self.authenticated: bool = False # Track IDs self.current_track_id: Optional[str] = str(uuid.uuid4()) self._history_call_id: Optional[str] = None self._history_turn_index: int = 0 self._history_call_started_mono: Optional[float] = None self._history_finalized: bool = False self._cleanup_lock = asyncio.Lock() self._cleaned_up = False self.workflow_runner: Optional[WorkflowRunner] = None self._workflow_last_user_text: str = "" self._workflow_initial_node: Optional[WorkflowNodeDef] = None self.pipeline.conversation.on_turn_complete(self._on_turn_complete) logger.info(f"Session {self.id} created (duplex={self.use_duplex})") async def handle_text(self, text_data: str) -> None: """ Handle incoming text data (WS v1 JSON control messages). Args: text_data: JSON text data """ try: data = json.loads(text_data) message = parse_client_message(data) await self._handle_v1_message(message) except json.JSONDecodeError as e: logger.error(f"Session {self.id} JSON decode error: {e}") await self._send_error("client", f"Invalid JSON: {e}", "protocol.invalid_json") except ValueError as e: logger.error(f"Session {self.id} command parse error: {e}") await self._send_error("client", f"Invalid message: {e}", "protocol.invalid_message") except Exception as e: logger.error(f"Session {self.id} handle_text error: {e}", exc_info=True) await self._send_error("server", f"Internal error: {e}", "server.internal") async def handle_audio(self, audio_bytes: bytes) -> None: """ Handle incoming audio data. Args: audio_bytes: PCM audio data """ if self.ws_state != WsSessionState.ACTIVE: await self._send_error( "client", "Audio received before session.start", "protocol.order", ) return try: await self.pipeline.process_audio(audio_bytes) except Exception as e: logger.error(f"Session {self.id} handle_audio error: {e}", exc_info=True) async def _handle_v1_message(self, message: Any) -> None: """Route validated WS v1 message to handlers.""" msg_type = message.type logger.info(f"Session {self.id} received message: {msg_type}") if isinstance(message, HelloMessage): await self._handle_hello(message) return # All messages below require hello handshake first if self.ws_state == WsSessionState.WAIT_HELLO: await self._send_error( "client", "Expected hello message first", "protocol.order", ) return if isinstance(message, SessionStartMessage): await self._handle_session_start(message) return # All messages below require active session if self.ws_state != WsSessionState.ACTIVE: await self._send_error( "client", f"Message '{msg_type}' requires active session", "protocol.order", ) return if isinstance(message, InputTextMessage): await self.pipeline.process_text(message.text) elif isinstance(message, ResponseCancelMessage): if message.graceful: logger.info(f"Session {self.id} graceful response.cancel") else: await self.pipeline.interrupt() elif isinstance(message, SessionStopMessage): await self._handle_session_stop(message.reason) else: await self._send_error("client", f"Unsupported message type: {msg_type}", "protocol.unsupported") async def _handle_hello(self, message: HelloMessage) -> None: """Handle initial hello/auth/version negotiation.""" if self.ws_state != WsSessionState.WAIT_HELLO: await self._send_error("client", "Duplicate hello", "protocol.order") return if message.version != settings.ws_protocol_version: await self._send_error( "client", f"Unsupported protocol version '{message.version}'", "protocol.version_unsupported", ) await self.transport.close() self.ws_state = WsSessionState.STOPPED return auth_payload = message.auth or {} api_key = auth_payload.get("apiKey") jwt = auth_payload.get("jwt") if settings.ws_api_key: if api_key != settings.ws_api_key: await self._send_error("auth", "Invalid API key", "auth.invalid_api_key") await self.transport.close() self.ws_state = WsSessionState.STOPPED return elif settings.ws_require_auth and not (api_key or jwt): await self._send_error("auth", "Authentication required", "auth.required") await self.transport.close() self.ws_state = WsSessionState.STOPPED return self.authenticated = True self.protocol_version = message.version self.ws_state = WsSessionState.WAIT_START await self.transport.send_event( ev( "hello.ack", sessionId=self.id, version=self.protocol_version, ) ) async def _handle_session_start(self, message: SessionStartMessage) -> None: """Handle explicit session start after successful hello.""" if self.ws_state != WsSessionState.WAIT_START: await self._send_error("client", "Duplicate session.start", "protocol.order") return metadata = message.metadata or {} metadata = self._merge_runtime_metadata(metadata, self._bootstrap_workflow(metadata)) # Create history call record early so later turn callbacks can append transcripts. await self._start_history_bridge(metadata) # Apply runtime service/prompt overrides from backend if provided self.pipeline.apply_runtime_overrides(metadata) # Start duplex pipeline if not self._pipeline_started: await self.pipeline.start() self._pipeline_started = True logger.info(f"Session {self.id} duplex pipeline started") self.state = "accepted" self.ws_state = WsSessionState.ACTIVE await self.transport.send_event( ev( "session.started", sessionId=self.id, trackId=self.current_track_id, audio=message.audio or {}, ) ) if self.workflow_runner and self._workflow_initial_node: await self.transport.send_event( ev( "workflow.started", sessionId=self.id, workflowId=self.workflow_runner.workflow_id, workflowName=self.workflow_runner.name, nodeId=self._workflow_initial_node.id, ) ) await self.transport.send_event( ev( "workflow.node.entered", sessionId=self.id, workflowId=self.workflow_runner.workflow_id, nodeId=self._workflow_initial_node.id, nodeName=self._workflow_initial_node.name, nodeType=self._workflow_initial_node.node_type, ) ) async def _handle_session_stop(self, reason: Optional[str]) -> None: """Handle session stop.""" if self.ws_state == WsSessionState.STOPPED: return stop_reason = reason or "client_requested" self.state = "hungup" self.ws_state = WsSessionState.STOPPED await self.transport.send_event( ev( "session.stopped", sessionId=self.id, reason=stop_reason, ) ) await self._finalize_history(status="connected") await self.transport.close() async def _send_error(self, sender: str, error_message: str, code: str) -> None: """ Send error event to client. Args: sender: Component that generated the error error_message: Error message code: Machine-readable error code """ await self.transport.send_event( ev( "error", sender=sender, code=code, message=error_message, trackId=self.current_track_id, ) ) def _get_timestamp_ms(self) -> int: """Get current timestamp in milliseconds.""" import time return int(time.time() * 1000) async def cleanup(self) -> None: """Cleanup session resources.""" async with self._cleanup_lock: if self._cleaned_up: return self._cleaned_up = True logger.info(f"Session {self.id} cleaning up") await self._finalize_history(status="connected") await self.pipeline.cleanup() await self.transport.close() async def _start_history_bridge(self, metadata: Dict[str, Any]) -> None: """Initialize backend history call record for this session.""" if self._history_call_id: return history_meta: Dict[str, Any] = {} if isinstance(metadata.get("history"), dict): history_meta = metadata["history"] raw_user_id = history_meta.get("userId", metadata.get("userId", settings.history_default_user_id)) try: user_id = int(raw_user_id) except (TypeError, ValueError): user_id = settings.history_default_user_id assistant_id = history_meta.get("assistantId", metadata.get("assistantId")) source = str(history_meta.get("source", metadata.get("source", "debug"))) call_id = await create_history_call_record( user_id=user_id, assistant_id=str(assistant_id) if assistant_id else None, source=source, ) if not call_id: return self._history_call_id = call_id self._history_call_started_mono = time.monotonic() self._history_turn_index = 0 self._history_finalized = False logger.info(f"Session {self.id} history bridge enabled (call_id={call_id}, source={source})") async def _on_turn_complete(self, turn: ConversationTurn) -> None: """Process workflow transitions and persist completed turns to history.""" if turn.text and turn.text.strip(): role = (turn.role or "").lower() if role == "user": self._workflow_last_user_text = turn.text.strip() elif role == "assistant": await self._maybe_advance_workflow(turn.text.strip()) if not self._history_call_id: return if not turn.text or not turn.text.strip(): return role = (turn.role or "").lower() speaker = "human" if role == "user" else "ai" end_ms = 0 if self._history_call_started_mono is not None: end_ms = max(0, int((time.monotonic() - self._history_call_started_mono) * 1000)) estimated_duration_ms = max(300, min(12000, len(turn.text.strip()) * 80)) start_ms = max(0, end_ms - estimated_duration_ms) turn_index = self._history_turn_index await add_history_transcript( call_id=self._history_call_id, turn_index=turn_index, speaker=speaker, content=turn.text.strip(), start_ms=start_ms, end_ms=end_ms, duration_ms=max(1, end_ms - start_ms), ) self._history_turn_index += 1 async def _finalize_history(self, status: str) -> None: """Finalize history call record once.""" if not self._history_call_id or self._history_finalized: return duration_seconds = 0 if self._history_call_started_mono is not None: duration_seconds = max(0, int(time.monotonic() - self._history_call_started_mono)) ok = await finalize_history_call_record( call_id=self._history_call_id, status=status, duration_seconds=duration_seconds, ) if ok: self._history_finalized = True def _bootstrap_workflow(self, metadata: Dict[str, Any]) -> Dict[str, Any]: """Parse workflow payload and return initial runtime overrides.""" payload = metadata.get("workflow") self.workflow_runner = WorkflowRunner.from_payload(payload) self._workflow_initial_node = None if not self.workflow_runner: return {} node = self.workflow_runner.bootstrap() if not node: logger.warning(f"Session {self.id} workflow payload had no resolvable start node") self.workflow_runner = None return {} self._workflow_initial_node = node logger.info( "Session {} workflow enabled: workflow={} start_node={}", self.id, self.workflow_runner.workflow_id, node.id, ) return self.workflow_runner.build_runtime_metadata(node) async def _maybe_advance_workflow(self, assistant_text: str) -> None: """Attempt node transfer after assistant turn finalization.""" if not self.workflow_runner or self.ws_state == WsSessionState.STOPPED: return transition = await self.workflow_runner.route( user_text=self._workflow_last_user_text, assistant_text=assistant_text, llm_router=self._workflow_llm_route, ) if not transition: return await self._apply_workflow_transition(transition, reason="rule_match") # Auto-advance through utility nodes when default edges are present. max_auto_hops = 6 auto_hops = 0 while self.workflow_runner and self.ws_state != WsSessionState.STOPPED: current = self.workflow_runner.current_node if not current or current.node_type not in {"start", "tool"}: break next_default = self.workflow_runner.next_default_transition() if not next_default: break auto_hops += 1 await self._apply_workflow_transition(next_default, reason="auto") if auto_hops >= max_auto_hops: logger.warning( "Session {} workflow auto-advance reached hop limit (possible cycle)", self.id, ) break async def _apply_workflow_transition(self, transition: WorkflowTransition, reason: str) -> None: """Apply graph transition and emit workflow lifecycle events.""" if not self.workflow_runner: return self.workflow_runner.apply_transition(transition) node = transition.node edge = transition.edge await self.transport.send_event( ev( "workflow.edge.taken", sessionId=self.id, workflowId=self.workflow_runner.workflow_id, edgeId=edge.id, fromNodeId=edge.from_node_id, toNodeId=edge.to_node_id, reason=reason, ) ) await self.transport.send_event( ev( "workflow.node.entered", sessionId=self.id, workflowId=self.workflow_runner.workflow_id, nodeId=node.id, nodeName=node.name, nodeType=node.node_type, ) ) node_runtime = self.workflow_runner.build_runtime_metadata(node) if node_runtime: self.pipeline.apply_runtime_overrides(node_runtime) if node.node_type == "tool": await self.transport.send_event( ev( "workflow.tool.requested", sessionId=self.id, workflowId=self.workflow_runner.workflow_id, nodeId=node.id, tool=node.tool or {}, ) ) return if node.node_type == "human_transfer": await self.transport.send_event( ev( "workflow.human_transfer", sessionId=self.id, workflowId=self.workflow_runner.workflow_id, nodeId=node.id, ) ) await self._handle_session_stop("workflow_human_transfer") return if node.node_type == "end": await self.transport.send_event( ev( "workflow.ended", sessionId=self.id, workflowId=self.workflow_runner.workflow_id, nodeId=node.id, ) ) await self._handle_session_stop("workflow_end") async def _workflow_llm_route( self, node: WorkflowNodeDef, candidates: List[WorkflowEdgeDef], context: Dict[str, str], ) -> Optional[str]: """LLM-based edge routing for condition.type == 'llm' edges.""" llm_service = self.pipeline.llm_service if not llm_service: return None candidate_rows = [ { "edgeId": edge.id, "toNodeId": edge.to_node_id, "label": edge.label, "hint": edge.condition.get("prompt") if isinstance(edge.condition, dict) else None, } for edge in candidates ] system_prompt = ( "You are a workflow router. Pick exactly one edge. " "Return JSON only: {\"edgeId\":\"...\"}." ) user_prompt = json.dumps( { "nodeId": node.id, "nodeName": node.name, "userText": context.get("userText", ""), "assistantText": context.get("assistantText", ""), "candidates": candidate_rows, }, ensure_ascii=False, ) try: reply = await llm_service.generate( [ LLMMessage(role="system", content=system_prompt), LLMMessage(role="user", content=user_prompt), ], temperature=0.0, max_tokens=64, ) except Exception as exc: logger.warning(f"Session {self.id} workflow llm routing failed: {exc}") return None if not reply: return None edge_ids = {edge.id for edge in candidates} node_ids = {edge.to_node_id for edge in candidates} parsed = self._extract_json_obj(reply) if isinstance(parsed, dict): edge_id = parsed.get("edgeId") or parsed.get("id") node_id = parsed.get("toNodeId") or parsed.get("nodeId") if isinstance(edge_id, str) and edge_id in edge_ids: return edge_id if isinstance(node_id, str) and node_id in node_ids: return node_id token_candidates = sorted(edge_ids | node_ids, key=len, reverse=True) lowered_reply = reply.lower() for token in token_candidates: if token.lower() in lowered_reply: return token return None def _merge_runtime_metadata(self, base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]: """Merge node-level metadata overrides into session.start metadata.""" merged = dict(base or {}) if not overrides: return merged for key, value in overrides.items(): if key == "services" and isinstance(value, dict): existing = merged.get("services") merged_services = dict(existing) if isinstance(existing, dict) else {} merged_services.update(value) merged["services"] = merged_services else: merged[key] = value return merged def _extract_json_obj(self, text: str) -> Optional[Dict[str, Any]]: """Best-effort extraction of a JSON object from freeform text.""" try: parsed = json.loads(text) if isinstance(parsed, dict): return parsed except Exception: pass match = re.search(r"\{.*\}", text, re.DOTALL) if not match: return None try: parsed = json.loads(match.group(0)) return parsed if isinstance(parsed, dict) else None except Exception: return None