"""Session management for active calls."""

import asyncio
import json
import re
import time
from datetime import datetime, timezone
from enum import Enum
from typing import Optional, Dict, Any, List

from loguru import logger

from app.backend_adapters import build_backend_adapter_from_settings
from core.transports import BaseTransport
from core.duplex_pipeline import DuplexPipeline
from core.conversation import ConversationTurn
from core.history_bridge import SessionHistoryBridge
from core.workflow_runner import WorkflowRunner, WorkflowTransition, WorkflowNodeDef, WorkflowEdgeDef
from app.config import settings
from services.base import LLMMessage
from models.ws_v1 import (
    parse_client_message,
    ev,
    SessionStartMessage,
    SessionStopMessage,
    InputTextMessage,
    ResponseCancelMessage,
    ToolCallResultsMessage,
)

# Accepted shape of a client-supplied dynamic variable name (1-64 chars).
_DYNAMIC_VARIABLE_KEY_RE = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]{0,63}$")
# Matches ``{{ name }}`` placeholders inside prompt/greeting templates.
_DYNAMIC_VARIABLE_PLACEHOLDER_RE = re.compile(r"\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}\}")
_DYNAMIC_VARIABLE_MAX_ITEMS = 30
_DYNAMIC_VARIABLE_VALUE_MAX_CHARS = 1000
# Names produced by the server itself; client values for these keys are dropped.
_SYSTEM_DYNAMIC_VARIABLE_KEYS = {"system__time", "system_utc", "system_timezone"}


class WsSessionState(str, Enum):
    """Protocol state machine for WS sessions."""

    WAIT_START = "wait_start"
    ACTIVE = "active"
    STOPPED = "stopped"


class Session:
    """
    Manages a single call session.

    Handles command routing, audio processing, and session lifecycle.
    Uses full duplex voice conversation pipeline.
    """

    TRACK_AUDIO_IN = "audio_in"
    TRACK_AUDIO_OUT = "audio_out"
    TRACK_CONTROL = "control"
    AUDIO_FRAME_BYTES = 640  # 16k mono pcm_s16le, 20ms

    # Top-level keys a client may send in session.start metadata.
    _METADATA_ALLOWED_TOP_LEVEL_KEYS = {
        "overrides",
        "dynamicVariables",
        "channel",
        "source",
        "history",
        "workflow",  # explicitly ignored for this MVP protocol version
    }
    # Keys a client may send under metadata.overrides.
    _METADATA_ALLOWED_OVERRIDE_KEYS = {
        "firstTurnMode",
        "greeting",
        "generatedOpenerEnabled",
        "systemPrompt",
        "output",
        "bargeIn",
        "knowledge",
        "knowledgeBaseId",
        "openerAudio",
        "tools",
    }
    # Top-level keys that are rejected outright when sent by a client.
    _METADATA_FORBIDDEN_TOP_LEVEL_KEYS = {
        "assistantId",
        "appId",
        "app_id",
        "configVersionId",
        "config_version_id",
        "services",
    }
    # Substrings (compared against lower-cased keys with "_"/"-" removed) that
    # mark a key as secret-like anywhere in the sanitized client metadata.
    _METADATA_FORBIDDEN_KEY_TOKENS = {
        "apikey",
        "token",
        "secret",
        "password",
        "authorization",
    }

    def __init__(
        self,
        session_id: str,
        transport: BaseTransport,
        use_duplex: Optional[bool] = None,
        backend_gateway: Optional[Any] = None,
        assistant_id: Optional[str] = None,
    ):
        """
        Initialize session.

        Args:
            session_id: Unique session identifier
            transport: Transport instance for communication
            use_duplex: Whether to use duplex pipeline (defaults to settings.duplex_enabled)
            backend_gateway: Backend adapter used for history, knowledge search,
                tool resources and assistant config. When omitted, one is built
                via build_backend_adapter_from_settings().
            assistant_id: Assistant identifier for this session. Blank values are
                normalized to None, in which case session.start is rejected.
        """
        self.id = session_id
        self.transport = transport
        self.use_duplex = use_duplex if use_duplex is not None else settings.duplex_enabled
        self._assistant_id = str(assistant_id or "").strip() or None
        self._backend_gateway = backend_gateway or build_backend_adapter_from_settings()
        self._history_bridge = SessionHistoryBridge(
            history_writer=self._backend_gateway,
            enabled=settings.history_enabled,
            queue_max_size=settings.history_queue_max_size,
            retry_max_attempts=settings.history_retry_max_attempts,
            retry_backoff_sec=settings.history_retry_backoff_sec,
            finalize_drain_timeout_sec=settings.history_finalize_drain_timeout_sec,
        )

        # Optional gateway capabilities are looked up with getattr so adapters
        # that do not implement them simply pass None to the pipeline.
        self.pipeline = DuplexPipeline(
            transport=transport,
            session_id=session_id,
            system_prompt=settings.duplex_system_prompt,
            greeting=settings.duplex_greeting,
            knowledge_searcher=getattr(self._backend_gateway, "search_knowledge_context", None),
            tool_resource_resolver=getattr(self._backend_gateway, "fetch_tool_resource", None),
        )

        # Session state
        self.created_at = None  # NOTE(review): never assigned anywhere in this file
        self.state = "created"  # Legacy call state for /call/lists
        self.ws_state = WsSessionState.WAIT_START
        self._pipeline_started = False

        # Track IDs
        self.current_track_id: str = self.TRACK_CONTROL
        self._event_seq: int = 0
        self._cleanup_lock = asyncio.Lock()
        self._cleaned_up = False
        self.workflow_runner: Optional[WorkflowRunner] = None
        self._workflow_last_user_text: str = ""
        self._workflow_initial_node: Optional[WorkflowNodeDef] = None

        # Pipeline-emitted events share this session's sequence counter.
        self.pipeline.set_event_sequence_provider(self._next_event_seq)
        self.pipeline.conversation.on_turn_complete(self._on_turn_complete)

        logger.info(
            "Session {} created (duplex={}, assistant_id={})",
            self.id,
            self.use_duplex,
            self._assistant_id or "-",
        )

    async def handle_text(self, text_data: str) -> None:
        """
        Handle incoming text data (WS v1 JSON control messages).

        Parse and validation failures are reported to the client as error
        events; they are not raised to the caller.

        Args:
            text_data: JSON text data
        """
        try:
            data = json.loads(text_data)
            message = parse_client_message(data)
            await self._handle_v1_message(message)
        except json.JSONDecodeError as e:
            logger.error(f"Session {self.id} JSON decode error: {e}")
            await self._send_error("client", f"Invalid JSON: {e}", "protocol.invalid_json")
        except ValueError as e:
            logger.error(f"Session {self.id} command parse error: {e}")
            await self._send_error("client", f"Invalid message: {e}", "protocol.invalid_message")
        except Exception as e:
            # loguru has no ``exc_info`` flag: extra kwargs trigger str.format on
            # the message (which breaks on braces in the exception text) and no
            # traceback is attached. ``opt(exception=True)`` records the traceback.
            logger.opt(exception=True).error("Session {} handle_text error: {}", self.id, e)
            await self._send_error("server", f"Internal error: {e}", "server.internal")

    async def handle_audio(self, audio_bytes: bytes) -> None:
        """
        Handle incoming audio data.

        The payload must be a whole number of AUDIO_FRAME_BYTES-sized frames;
        each frame is forwarded to the pipeline in order. Invalid payloads and
        processing failures are reported to the client as error events.

        Args:
            audio_bytes: PCM audio data
        """
        if self.ws_state != WsSessionState.ACTIVE:
            await self._send_error(
                "client",
                "Audio received before session.start",
                "protocol.order",
                stage="protocol",
                retryable=False,
            )
            return

        try:
            if not audio_bytes:
                return
            if len(audio_bytes) % 2 != 0:
                await self._send_error(
                    "client",
                    "Invalid PCM payload: odd number of bytes",
                    "audio.invalid_pcm",
                    stage="audio",
                    retryable=False,
                )
                return

            frame_bytes = self.AUDIO_FRAME_BYTES
            if len(audio_bytes) % frame_bytes != 0:
                await self._send_error(
                    "client",
                    f"Audio frame size must be a multiple of {frame_bytes} bytes (20ms PCM)",
                    "audio.frame_size_mismatch",
                    stage="audio",
                    retryable=False,
                )
                return

            for i in range(0, len(audio_bytes), frame_bytes):
                frame = audio_bytes[i : i + frame_bytes]
                await self.pipeline.process_audio(frame)
        except Exception as e:
            # See handle_text: use opt(exception=True) instead of exc_info=True.
            logger.opt(exception=True).error("Session {} handle_audio error: {}", self.id, e)
            await self._send_error(
                "server",
                f"Audio processing failed: {e}",
                "audio.processing_failed",
                stage="audio",
                retryable=True,
            )

    async def _handle_v1_message(self, message: Any) -> None:
        """Route validated WS v1 message to handlers."""
        msg_type = message.type
        logger.info(f"Session {self.id} received message: {msg_type}")

        if isinstance(message, SessionStartMessage):
            await self._handle_session_start(message)
            return

        # All messages below require session.start first
        if self.ws_state == WsSessionState.WAIT_START:
            await self._send_error(
                "client",
                "Expected session.start message first",
                "protocol.order",
            )
            return

        # All messages below require active session
        if self.ws_state != WsSessionState.ACTIVE:
            await self._send_error(
                "client",
                f"Message '{msg_type}' requires active session",
                "protocol.order",
            )
            return

        if isinstance(message, InputTextMessage):
            await self.pipeline.process_text(message.text)
        elif isinstance(message, ResponseCancelMessage):
            if message.graceful:
                # Graceful cancel is only logged; the pipeline is not interrupted.
                logger.info(f"Session {self.id} graceful response.cancel")
            else:
                await self.pipeline.interrupt()
        elif isinstance(message, ToolCallResultsMessage):
            await self.pipeline.handle_tool_call_results([item.model_dump() for item in message.results])
        elif isinstance(message, SessionStopMessage):
            await self._handle_session_stop(message.reason)
        else:
            await self._send_error("client", f"Unsupported message type: {msg_type}", "protocol.unsupported")

    async def _reject_session_start(self, sender: str, message: str, code: str) -> None:
        """Report a non-retryable session.start failure, close the transport and mark the session stopped."""
        await self._send_error(
            sender,
            message,
            code,
            stage="protocol",
            retryable=False,
        )
        await self.transport.close()
        self.ws_state = WsSessionState.STOPPED

    async def _handle_session_start(self, message: SessionStartMessage) -> None:
        """
        Handle explicit session start.

        Validates client metadata, loads the server-side assistant config,
        merges client overrides and dynamic variables, starts the pipeline and
        emits ``session.started`` (plus ``config.resolved`` when enabled)
        before the initial greeting. Any validation or config failure sends an
        error, closes the transport and leaves the session STOPPED.
        """
        if self.ws_state != WsSessionState.WAIT_START:
            await self._send_error("client", "Duplicate session.start", "protocol.order")
            return

        raw_metadata = message.metadata or {}
        if not self._assistant_id:
            await self._reject_session_start(
                "client",
                "Missing required query parameter assistant_id",
                "protocol.assistant_id_required",
            )
            return

        sanitized_metadata, metadata_error = self._validate_and_sanitize_client_metadata(raw_metadata)
        if metadata_error:
            await self._reject_session_start("client", metadata_error["message"], metadata_error["code"])
            return

        server_runtime, runtime_error = await self._load_server_runtime_metadata(self._assistant_id)
        if runtime_error:
            await self._reject_session_start("server", runtime_error["message"], runtime_error["code"])
            return

        # Server config is the base; sanitized client overrides win on conflict.
        metadata = self._merge_runtime_metadata(server_runtime, sanitized_metadata.get("overrides", {}))
        for key in ("channel", "source", "history"):
            if key in sanitized_metadata:
                metadata[key] = sanitized_metadata[key]

        metadata, dynamic_var_error = self._apply_dynamic_variables(metadata, sanitized_metadata)
        if dynamic_var_error:
            await self._reject_session_start("client", dynamic_var_error["message"], dynamic_var_error["code"])
            return

        # Create history call record early so later turn callbacks can append transcripts.
        await self._start_history_bridge(metadata)

        # Apply runtime service/prompt overrides from backend if provided
        self.pipeline.apply_runtime_overrides(metadata)
        resolved_preview = self.pipeline.resolved_runtime_config()
        resolved_services = resolved_preview.get("services", {}) if isinstance(resolved_preview, dict) else {}
        llm_cfg = resolved_services.get("llm", {}) if isinstance(resolved_services, dict) else {}
        asr_cfg = resolved_services.get("asr", {}) if isinstance(resolved_services, dict) else {}
        tts_cfg = resolved_services.get("tts", {}) if isinstance(resolved_services, dict) else {}
        logger.info(
            "Session {} effective runtime services "
            "(assistantId={}, configVersionId={}, output_mode={}, "
            "llm={}/{}, asr={}/{}, tts={}/{}, tts_enabled={})",
            self.id,
            metadata.get("assistantId"),
            metadata.get("configVersionId") or metadata.get("config_version_id"),
            (resolved_preview.get("output") or {}).get("mode") if isinstance(resolved_preview, dict) else None,
            llm_cfg.get("provider"),
            llm_cfg.get("model"),
            asr_cfg.get("provider"),
            asr_cfg.get("model"),
            tts_cfg.get("provider"),
            tts_cfg.get("model"),
            tts_cfg.get("enabled"),
        )

        # Start duplex pipeline
        if not self._pipeline_started:
            await self.pipeline.start()
            self._pipeline_started = True
            logger.info(f"Session {self.id} duplex pipeline started")

        self.state = "accepted"
        self.ws_state = WsSessionState.ACTIVE
        await self._send_event(
            ev(
                "session.started",
                trackId=self.current_track_id,
                tracks={
                    "audio_in": self.TRACK_AUDIO_IN,
                    "audio_out": self.TRACK_AUDIO_OUT,
                    "control": self.TRACK_CONTROL,
                },
                audio=message.audio.model_dump() if message.audio else {},
            )
        )
        if settings.ws_emit_config_resolved:
            await self._send_event(
                ev(
                    "config.resolved",
                    trackId=self.TRACK_CONTROL,
                    config=self._build_config_resolved(metadata),
                )
            )
        else:
            logger.debug("Session {} skipped config.resolved (ws_emit_config_resolved=false)", self.id)

        # Emit opener only after frontend has received session.started (and optional config event).
        await self.pipeline.emit_initial_greeting()

    async def _handle_session_stop(self, reason: Optional[str]) -> None:
        """
        Handle session stop.

        No-op when the session is already STOPPED. Otherwise marks the session
        stopped, emits ``session.stopped``, finalizes history and closes the
        transport.

        Args:
            reason: Stop reason reported to the client; defaults to
                "client_requested" when empty.
        """
        if self.ws_state == WsSessionState.STOPPED:
            return

        stop_reason = reason or "client_requested"
        self.state = "hungup"
        self.ws_state = WsSessionState.STOPPED
        await self._send_event(
            ev(
                "session.stopped",
                reason=stop_reason,
            )
        )
        await self._finalize_history(status="connected")
        await self.transport.close()

    async def _send_error(
        self,
        sender: str,
        error_message: str,
        code: str,
        stage: Optional[str] = None,
        retryable: Optional[bool] = None,
        track_id: Optional[str] = None,
    ) -> None:
        """
        Send error event to client.

        Args:
            sender: Component that generated the error
            error_message: Error message
            code: Machine-readable error code
            stage: Pipeline stage; inferred from the code prefix when omitted
            retryable: Whether the client may retry; when omitted it is True
                only for the asr/llm/tts/tool/audio stages
            track_id: Track to attribute the error to; derived from the stage
                when omitted
        """
        resolved_stage = stage or self._infer_error_stage(code)
        resolved_retryable = retryable if retryable is not None else (resolved_stage in {"asr", "llm", "tts", "tool", "audio"})
        resolved_track_id = track_id or self._error_track_id(resolved_stage, code)
        await self._send_event(
            ev(
                "error",
                sender=sender,
                code=code,
                message=error_message,
                stage=resolved_stage,
                retryable=resolved_retryable,
                trackId=resolved_track_id,
                data={
                    "error": {
                        "stage": resolved_stage,
                        "code": code,
                        "message": error_message,
                        "retryable": resolved_retryable,
                    }
                },
            )
        )

    def _get_timestamp_ms(self) -> int:
        """Get current timestamp in milliseconds."""
        return int(time.time() * 1000)

    async def cleanup(self) -> None:
        """Cleanup session resources (runs at most once per session)."""
        async with self._cleanup_lock:
            if self._cleaned_up:
                return

            self._cleaned_up = True
            logger.info(f"Session {self.id} cleaning up")
            await self._finalize_history(status="connected")
            await self.pipeline.cleanup()
            await self._history_bridge.shutdown()
            await self.transport.close()

    async def _start_history_bridge(self, metadata: Dict[str, Any]) -> None:
        """
        Initialize backend history call record for this session.

        Does nothing when the bridge already has a call id. ``metadata.history``
        values take precedence over the same keys at the metadata top level.
        """
        if self._history_bridge.call_id:
            return

        history_meta: Dict[str, Any] = {}
        if isinstance(metadata.get("history"), dict):
            history_meta = metadata["history"]

        raw_user_id = history_meta.get("userId", metadata.get("userId", settings.history_default_user_id))
        try:
            user_id = int(raw_user_id)
        except (TypeError, ValueError):
            # Non-numeric user ids fall back to the configured default.
            user_id = settings.history_default_user_id

        assistant_id = history_meta.get("assistantId", metadata.get("assistantId"))
        source = str(history_meta.get("source", metadata.get("source", "debug")))

        call_id = await self._history_bridge.start_call(
            user_id=user_id,
            assistant_id=str(assistant_id) if assistant_id else None,
            source=source,
        )
        if not call_id:
            return

        logger.info(f"Session {self.id} history bridge enabled (call_id={call_id}, source={source})")

    async def _on_turn_complete(self, turn: ConversationTurn) -> None:
        """Process workflow transitions and persist completed turns to history."""
        if turn.text and turn.text.strip():
            role = (turn.role or "").lower()
            if role == "user":
                self._workflow_last_user_text = turn.text.strip()
            elif role == "assistant":
                await self._maybe_advance_workflow(turn.text.strip())

        # Every turn is enqueued, including ones with empty text.
        self._history_bridge.enqueue_turn(role=turn.role or "", text=turn.text or "")

    async def _finalize_history(self, status: str) -> None:
        """Finalize history call record once."""
        await self._history_bridge.finalize(status=status)

    def _bootstrap_workflow(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
        """
        Parse workflow payload and return initial runtime overrides.

        Returns an empty dict (and leaves the workflow disabled) when the
        payload yields no runner or no resolvable start node.
        """
        payload = metadata.get("workflow")
        self.workflow_runner = WorkflowRunner.from_payload(payload)
        self._workflow_initial_node = None
        if not self.workflow_runner:
            return {}

        node = self.workflow_runner.bootstrap()
        if not node:
            logger.warning(f"Session {self.id} workflow payload had no resolvable start node")
            self.workflow_runner = None
            return {}

        self._workflow_initial_node = node
        logger.info(
            "Session {} workflow enabled: workflow={} start_node={}",
            self.id,
            self.workflow_runner.workflow_id,
            node.id,
        )
        return self.workflow_runner.build_runtime_metadata(node)

    async def _maybe_advance_workflow(self, assistant_text: str) -> None:
        """Attempt node transfer after assistant turn finalization."""
        if not self.workflow_runner or self.ws_state == WsSessionState.STOPPED:
            return

        transition = await self.workflow_runner.route(
            user_text=self._workflow_last_user_text,
            assistant_text=assistant_text,
            llm_router=self._workflow_llm_route,
        )
        if not transition:
            return

        await self._apply_workflow_transition(transition, reason="rule_match")

        # Auto-advance through utility nodes when default edges are present.
        max_auto_hops = 6
        auto_hops = 0
        # Re-check runner/state each hop: a transition may stop the session.
        while self.workflow_runner and self.ws_state != WsSessionState.STOPPED:
            current = self.workflow_runner.current_node
            if not current or current.node_type not in {"start", "tool"}:
                break
            next_default = self.workflow_runner.next_default_transition()
            if not next_default:
                break
            auto_hops += 1
            await self._apply_workflow_transition(next_default, reason="auto")
            if auto_hops >= max_auto_hops:
                logger.warning(
                    "Session {} workflow auto-advance reached hop limit (possible cycle)",
                    self.id,
                )
                break

    async def _apply_workflow_transition(self, transition: WorkflowTransition, reason: str) -> None:
        """
        Apply graph transition and emit workflow lifecycle events.

        Emits ``workflow.edge.taken`` and ``workflow.node.entered``, applies the
        node's runtime overrides, then emits the node-type specific event.
        ``human_transfer`` and ``end`` nodes also stop the session.
        """
        if not self.workflow_runner:
            return

        self.workflow_runner.apply_transition(transition)
        node = transition.node
        edge = transition.edge

        await self._send_event(
            ev(
                "workflow.edge.taken",
                workflowId=self.workflow_runner.workflow_id,
                edgeId=edge.id,
                fromNodeId=edge.from_node_id,
                toNodeId=edge.to_node_id,
                reason=reason,
            )
        )
        await self._send_event(
            ev(
                "workflow.node.entered",
                workflowId=self.workflow_runner.workflow_id,
                nodeId=node.id,
                nodeName=node.name,
                nodeType=node.node_type,
            )
        )

        node_runtime = self.workflow_runner.build_runtime_metadata(node)
        if node_runtime:
            self.pipeline.apply_runtime_overrides(node_runtime)

        if node.node_type == "tool":
            await self._send_event(
                ev(
                    "workflow.tool.requested",
                    workflowId=self.workflow_runner.workflow_id,
                    nodeId=node.id,
                    tool=node.tool or {},
                )
            )
            return

        if node.node_type == "human_transfer":
            await self._send_event(
                ev(
                    "workflow.human_transfer",
                    workflowId=self.workflow_runner.workflow_id,
                    nodeId=node.id,
                )
            )
            await self._handle_session_stop("workflow_human_transfer")
            return

        if node.node_type == "end":
            await self._send_event(
                ev(
                    "workflow.ended",
                    workflowId=self.workflow_runner.workflow_id,
                    nodeId=node.id,
                )
            )
            await self._handle_session_stop("workflow_end")

    def _next_event_seq(self) -> int:
        """Increment and return the per-session event sequence number."""
        self._event_seq += 1
        return self._event_seq

    def _event_source(self, event_type: str) -> str:
        """
        Return the default ``source`` for an event type.

        Every event type currently maps to "system"; the parameter is kept so
        callers do not change if per-type sources are introduced.
        """
        return "system"

    def _infer_error_stage(self, code: str) -> str:
        """Derive the error stage from the error code prefix (default "protocol")."""
        normalized = str(code or "").strip().lower()
        if normalized.startswith("audio."):
            return "audio"
        if normalized.startswith("tool."):
            return "tool"
        if normalized.startswith("asr."):
            return "asr"
        if normalized.startswith("llm."):
            return "llm"
        if normalized.startswith("tts."):
            return "tts"
        return "protocol"

    def _error_track_id(self, stage: str, code: str) -> str:
        """Pick the track an error is attributed to based on its stage (``code`` is unused)."""
        if stage in {"audio", "asr"}:
            return self.TRACK_AUDIO_IN
        if stage in {"llm", "tts", "tool"}:
            return self.TRACK_AUDIO_OUT
        return self.TRACK_CONTROL

    def _envelope_event(self, event: Dict[str, Any]) -> Dict[str, Any]:
        """
        Add the session envelope fields to an event, in place.

        Sets ``sessionId``, ``seq``, ``source``, ``trackId`` and ``data``.
        Non-envelope top-level fields are mirrored into ``data`` without
        overwriting keys already present there.

        Returns:
            The same ``event`` dict that was passed in.
        """
        event_type = str(event.get("type") or "")
        source = str(event.get("source") or self._event_source(event_type))
        track_id = event.get("trackId") or self.TRACK_CONTROL

        data = event.get("data")
        if not isinstance(data, dict):
            data = {}
        for k, v in event.items():
            if k in {"type", "timestamp", "sessionId", "seq", "source", "trackId", "data"}:
                continue
            data.setdefault(k, v)

        event["sessionId"] = self.id
        event["seq"] = self._next_event_seq()
        event["source"] = source
        event["trackId"] = track_id
        event["data"] = data
        return event

    async def _send_event(self, event: Dict[str, Any]) -> None:
        """Envelope an event and send it over the transport."""
        await self.transport.send_event(self._envelope_event(event))

    async def send_heartbeat(self) -> None:
        """Send a heartbeat event on the control track."""
        await self._send_event(ev("heartbeat", trackId=self.TRACK_CONTROL))

    async def _workflow_llm_route(
        self,
        node: WorkflowNodeDef,
        candidates: List[WorkflowEdgeDef],
        context: Dict[str, str],
    ) -> Optional[str]:
        """
        LLM-based edge routing for condition.type == 'llm' edges.

        Returns:
            A candidate edge id or target node id chosen by the LLM, or None
            when no LLM service is available, the call fails, or the reply
            names no candidate.
        """
        llm_service = self.pipeline.llm_service
        if not llm_service:
            return None

        candidate_rows = [
            {
                "edgeId": edge.id,
                "toNodeId": edge.to_node_id,
                "label": edge.label,
                "hint": edge.condition.get("prompt") if isinstance(edge.condition, dict) else None,
            }
            for edge in candidates
        ]
        system_prompt = (
            "You are a workflow router. Pick exactly one edge. "
            "Return JSON only: {\"edgeId\":\"...\"}."
        )
        user_prompt = json.dumps(
            {
                "nodeId": node.id,
                "nodeName": node.name,
                "userText": context.get("userText", ""),
                "assistantText": context.get("assistantText", ""),
                "candidates": candidate_rows,
            },
            ensure_ascii=False,
        )

        try:
            reply = await llm_service.generate(
                [
                    LLMMessage(role="system", content=system_prompt),
                    LLMMessage(role="user", content=user_prompt),
                ],
                temperature=0.0,
                max_tokens=64,
            )
        except Exception as exc:
            # Routing is best-effort: an LLM failure means "no route".
            logger.warning(f"Session {self.id} workflow llm routing failed: {exc}")
            return None

        if not reply:
            return None

        edge_ids = {edge.id for edge in candidates}
        node_ids = {edge.to_node_id for edge in candidates}

        parsed = self._extract_json_obj(reply)
        if isinstance(parsed, dict):
            edge_id = parsed.get("edgeId") or parsed.get("id")
            node_id = parsed.get("toNodeId") or parsed.get("nodeId")
            if isinstance(edge_id, str) and edge_id in edge_ids:
                return edge_id
            if isinstance(node_id, str) and node_id in node_ids:
                return node_id

        # Fallback: substring search, longest ids first so a short id that is a
        # prefix of a longer one does not win.
        token_candidates = sorted(edge_ids | node_ids, key=len, reverse=True)
        lowered_reply = reply.lower()
        for token in token_candidates:
            if token.lower() in lowered_reply:
                return token
        return None

    def _merge_runtime_metadata(self, base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
        """
        Return a shallow copy of ``base`` with ``overrides`` applied on top.

        A dict-valued ``services`` override is merged key-by-key into the
        existing ``services`` dict; every other key is replaced wholesale.
        Neither input is mutated.
        """
        merged = dict(base or {})
        if not overrides:
            return merged
        for key, value in overrides.items():
            if key == "services" and isinstance(value, dict):
                existing = merged.get("services")
                merged_services = dict(existing) if isinstance(existing, dict) else {}
                merged_services.update(value)
                merged["services"] = merged_services
            else:
                merged[key] = value
        return merged

    def _extract_dynamic_template_keys(self, text: Any) -> List[str]:
        """Collect placeholder keys from a template string (unique, in first-seen order)."""
        if text is None:
            return []
        keys: List[str] = []
        seen = set()
        for match in _DYNAMIC_VARIABLE_PLACEHOLDER_RE.finditer(str(text)):
            key = str(match.group(1) or "")
            if key and key not in seen:
                seen.add(key)
                keys.append(key)
        return keys

    def _render_dynamic_template(
        self,
        template: str,
        dynamic_vars: Dict[str, str],
    ) -> tuple[str, List[str]]:
        """
        Render one template text and return missing keys if any.

        Placeholders without a value are left untouched in the rendered text
        and reported in the sorted list of missing keys.
        """
        missing = set()

        def _replace(match: re.Match) -> str:
            key = str(match.group(1) or "")
            if key not in dynamic_vars:
                missing.add(key)
                return match.group(0)
            return dynamic_vars[key]

        rendered = _DYNAMIC_VARIABLE_PLACEHOLDER_RE.sub(_replace, str(template or ""))
        return rendered, sorted(missing)

    def _system_dynamic_variables(self) -> Dict[str, str]:
        """Build system-level dynamic variables for the current session timestamp."""
        local_now = datetime.now().astimezone()
        utc_now = local_now.astimezone(timezone.utc)
        return {
            "system__time": local_now.strftime("%Y-%m-%d %H:%M:%S"),
            "system_utc": utc_now.strftime("%Y-%m-%d %H:%M:%S"),
            "system_timezone": str(local_now.tzinfo or ""),
        }

    def _apply_dynamic_variables(
        self,
        metadata: Dict[str, Any],
        client_metadata: Dict[str, Any],
    ) -> tuple[Dict[str, Any], Optional[Dict[str, str]]]:
        """
        Apply session.start metadata.dynamicVariables to prompt/greeting templates.

        Client variables are validated (count, key shape, string values, value
        length) and combined with the server-generated system variables, then
        substituted into ``systemPrompt`` and ``greeting``.

        Returns:
            tuple(merged_metadata, error_payload)
            error_payload shape: {"code": "...", "message": "..."}
        """
        merged = dict(metadata or {})
        raw_dynamic_vars = client_metadata.get("dynamicVariables")
        dynamic_vars: Dict[str, str] = self._system_dynamic_variables()

        if raw_dynamic_vars is not None:
            if not isinstance(raw_dynamic_vars, dict):
                return merged, {
                    "code": "protocol.dynamic_variables_invalid",
                    "message": "metadata.dynamicVariables must be an object with string key/value pairs",
                }
            if len(raw_dynamic_vars) > _DYNAMIC_VARIABLE_MAX_ITEMS:
                return merged, {
                    "code": "protocol.dynamic_variables_invalid",
                    "message": f"metadata.dynamicVariables cannot exceed {_DYNAMIC_VARIABLE_MAX_ITEMS} entries",
                }

            for raw_key, raw_value in raw_dynamic_vars.items():
                if not isinstance(raw_key, str):
                    return merged, {
                        "code": "protocol.dynamic_variables_invalid",
                        "message": "metadata.dynamicVariables keys must be strings",
                    }
                key = raw_key.strip()
                if not _DYNAMIC_VARIABLE_KEY_RE.match(key):
                    return merged, {
                        "code": "protocol.dynamic_variables_invalid",
                        "message": (
                            "Invalid dynamic variable key "
                            f"'{raw_key}'. Expected ^[a-zA-Z_][a-zA-Z0-9_]{{0,63}}$"
                        ),
                    }
                if key in _SYSTEM_DYNAMIC_VARIABLE_KEYS:
                    # Reserved system variables are generated by server time context.
                    continue
                # Keys are stripped above, so e.g. "a" and " a" collide here.
                if key in dynamic_vars:
                    return merged, {
                        "code": "protocol.dynamic_variables_invalid",
                        "message": f"Duplicate dynamic variable key '{key}'",
                    }
                if not isinstance(raw_value, str):
                    return merged, {
                        "code": "protocol.dynamic_variables_invalid",
                        "message": f"Dynamic variable '{key}' value must be a string",
                    }
                if len(raw_value) > _DYNAMIC_VARIABLE_VALUE_MAX_CHARS:
                    return merged, {
                        "code": "protocol.dynamic_variables_invalid",
                        "message": (
                            f"Dynamic variable '{key}' exceeds "
                            f"{_DYNAMIC_VARIABLE_VALUE_MAX_CHARS} characters"
                        ),
                    }
                dynamic_vars[key] = raw_value

        template_keys = set(self._extract_dynamic_template_keys(merged.get("systemPrompt")))
        template_keys.update(self._extract_dynamic_template_keys(merged.get("greeting")))
        if not template_keys:
            return merged, None

        missing_keys = sorted([key for key in template_keys if key not in dynamic_vars])
        if missing_keys:
            return merged, {
                "code": "protocol.dynamic_variables_missing",
                "message": f"Missing dynamic variables for placeholders: {', '.join(missing_keys)}",
            }

        for field in ("systemPrompt", "greeting"):
            value = merged.get(field)
            if value is None:
                continue
            rendered, unresolved = self._render_dynamic_template(str(value), dynamic_vars)
            if unresolved:
                return merged, {
                    "code": "protocol.dynamic_variables_missing",
                    "message": f"Missing dynamic variables for placeholders: {', '.join(unresolved)}",
                }
            merged[field] = rendered
        return merged, None

    async def _load_server_runtime_metadata(
        self,
        assistant_id: str,
    ) -> tuple[Dict[str, Any], Optional[Dict[str, str]]]:
        """
        Load trusted runtime metadata from backend assistant config.

        Returns:
            tuple(runtime_metadata, error_payload). On failure the metadata is
            an empty dict and error_payload has shape
            {"code": "...", "message": "..."}.
        """
        if not assistant_id:
            return {}, {
                "code": "protocol.assistant_id_required",
                "message": "Missing required query parameter assistant_id",
            }

        provider = getattr(self._backend_gateway, "fetch_assistant_config", None)
        if not callable(provider):
            return {}, {
                "code": "assistant.config_unavailable",
                "message": "Assistant config backend unavailable",
            }

        payload = await provider(str(assistant_id).strip())
        if isinstance(payload, dict):
            # The backend signals failures in-band through "__error_code".
            error_code = str(payload.get("__error_code") or "").strip()
            if error_code == "assistant.not_found":
                return {}, {
                    "code": "assistant.not_found",
                    "message": f"Assistant not found: {assistant_id}",
                }
            if error_code == "assistant.config_unavailable":
                return {}, {
                    "code": "assistant.config_unavailable",
                    "message": f"Assistant config unavailable: {assistant_id}",
                }

        if not isinstance(payload, dict):
            return {}, {
                "code": "assistant.config_unavailable",
                "message": f"Assistant config unavailable: {assistant_id}",
            }

        # Layering: sessionStartMetadata first, then "assistant" on top; with
        # neither present the payload itself is treated as the assistant config.
        assistant_cfg: Dict[str, Any] = {}
        session_start_cfg = payload.get("sessionStartMetadata")
        if isinstance(session_start_cfg, dict):
            assistant_cfg.update(session_start_cfg)
        if isinstance(payload.get("assistant"), dict):
            assistant_cfg.update(payload.get("assistant"))
        elif not assistant_cfg:
            assistant_cfg = payload

        if not isinstance(assistant_cfg, dict):
            return {}, {
                "code": "assistant.config_unavailable",
                "message": f"Assistant config unavailable: {assistant_id}",
            }

        runtime: Dict[str, Any] = {}
        passthrough_keys = {
            "firstTurnMode",
            "generatedOpenerEnabled",
            "output",
            "bargeIn",
            "knowledgeBaseId",
            "knowledge",
            "openerAudio",
            "history",
            "userId",
            "source",
            "tools",
            "services",
            "configVersionId",
            "config_version_id",
        }
        for key in passthrough_keys:
            if key in assistant_cfg:
                runtime[key] = assistant_cfg[key]

        # "prompt" and "opener" are accepted as aliases of systemPrompt/greeting.
        if assistant_cfg.get("systemPrompt") is not None:
            runtime["systemPrompt"] = str(assistant_cfg.get("systemPrompt") or "")
        elif assistant_cfg.get("prompt") is not None:
            runtime["systemPrompt"] = str(assistant_cfg.get("prompt") or "")

        if assistant_cfg.get("greeting") is not None:
            runtime["greeting"] = assistant_cfg.get("greeting")
        elif assistant_cfg.get("opener") is not None:
            runtime["greeting"] = assistant_cfg.get("opener")

        resolved_assistant_id = (
            assistant_cfg.get("assistantId")
            or payload.get("assistantId")
            or assistant_id
        )
        runtime["assistantId"] = str(resolved_assistant_id)

        # Normalize the config version onto the camelCase key.
        if runtime.get("configVersionId") is None and payload.get("configVersionId") is not None:
            runtime["configVersionId"] = payload.get("configVersionId")
        if runtime.get("configVersionId") is None and payload.get("config_version_id") is not None:
            runtime["configVersionId"] = payload.get("config_version_id")
        if runtime.get("config_version_id") is not None and runtime.get("configVersionId") is None:
            runtime["configVersionId"] = runtime.get("config_version_id")
        return runtime, None

    def _find_forbidden_secret_key(self, payload: Any, path: str = "metadata") -> Optional[str]:
        """
        Recursively search dicts/lists for a secret-like key.

        Returns:
            The dotted/indexed path of the first matching key, or None.
        """
        if isinstance(payload, dict):
            for key, value in payload.items():
                key_str = str(key)
                normalized = key_str.lower().replace("_", "").replace("-", "")
                if any(token in normalized for token in self._METADATA_FORBIDDEN_KEY_TOKENS):
                    return f"{path}.{key_str}"
                nested = self._find_forbidden_secret_key(value, f"{path}.{key_str}")
                if nested:
                    return nested
            return None
        if isinstance(payload, list):
            for idx, value in enumerate(payload):
                nested = self._find_forbidden_secret_key(value, f"{path}[{idx}]")
                if nested:
                    return nested
        return None

    def _validate_and_sanitize_client_metadata(
        self,
        metadata: Dict[str, Any],
    ) -> tuple[Dict[str, Any], Optional[Dict[str, str]]]:
        """
        Validate client session.start metadata against the allow/deny lists.

        Returns:
            tuple(sanitized_metadata, error_payload). sanitized_metadata always
            contains "overrides" and may contain "dynamicVariables", "channel",
            "source" and "history" (userId only). On failure it is an empty
            dict and error_payload has shape {"code": "...", "message": "..."}.
        """
        if not isinstance(metadata, dict):
            return {}, {
                "code": "protocol.invalid_override",
                "message": "metadata must be an object",
            }

        forbidden_top_level = [key for key in self._METADATA_FORBIDDEN_TOP_LEVEL_KEYS if key in metadata]
        if forbidden_top_level:
            return {}, {
                "code": "protocol.invalid_override",
                "message": f"Forbidden metadata keys: {', '.join(sorted(forbidden_top_level))}",
            }

        unknown_keys = [
            key
            for key in metadata.keys()
            if key not in self._METADATA_ALLOWED_TOP_LEVEL_KEYS
            and key not in self._METADATA_FORBIDDEN_TOP_LEVEL_KEYS
        ]
        if unknown_keys:
            return {}, {
                "code": "protocol.invalid_override",
                "message": f"Unsupported metadata keys: {', '.join(sorted(unknown_keys))}",
            }
        if "workflow" in metadata:
            logger.warning("Session {} received metadata.workflow; workflow payload is ignored in MVP", self.id)

        overrides_raw = metadata.get("overrides")
        overrides: Dict[str, Any] = {}
        if overrides_raw is not None:
            if not isinstance(overrides_raw, dict):
                return {}, {
                    "code": "protocol.invalid_override",
                    "message": "metadata.overrides must be an object",
                }
            unsupported_override_keys = [key for key in overrides_raw.keys() if key not in self._METADATA_ALLOWED_OVERRIDE_KEYS]
            if unsupported_override_keys:
                return {}, {
                    "code": "protocol.invalid_override",
                    "message": f"Unsupported metadata.overrides keys: {', '.join(sorted(unsupported_override_keys))}",
                }
            overrides = dict(overrides_raw)

        dynamic_variables = metadata.get("dynamicVariables")
        history_raw = metadata.get("history")
        history: Dict[str, Any] = {}
        if history_raw is not None:
            if not isinstance(history_raw, dict):
                return {}, {
                    "code": "protocol.invalid_override",
                    "message": "metadata.history must be an object",
                }
            unsupported_history_keys = [key for key in history_raw.keys() if key != "userId"]
            if unsupported_history_keys:
                return {}, {
                    "code": "protocol.invalid_override",
                    "message": f"Unsupported metadata.history keys: {', '.join(sorted(unsupported_history_keys))}",
                }
            if "userId" in history_raw:
                history["userId"] = history_raw.get("userId")

        sanitized: Dict[str, Any] = {"overrides": overrides}
        if dynamic_variables is not None:
            sanitized["dynamicVariables"] = dynamic_variables
        if "channel" in metadata:
            sanitized["channel"] = metadata.get("channel")
        if "source" in metadata:
            sanitized["source"] = metadata.get("source")
        if history:
            sanitized["history"] = history

        # The secret-key scan runs over the sanitized payload only.
        forbidden_path = self._find_forbidden_secret_key(sanitized)
        if forbidden_path:
            return {}, {
                "code": "protocol.invalid_override",
                "message": f"Forbidden secret-like key detected at {forbidden_path}",
            }

        return sanitized, None

    def _build_config_resolved(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
        """Build public resolved config payload (SaaS-safe, no internal runtime details)."""
        runtime = self.pipeline.resolved_runtime_config()
        runtime_output = runtime.get("output", {}) if isinstance(runtime, dict) else {}
        output_mode = str(runtime_output.get("mode") or "").strip().lower() if isinstance(runtime_output, dict) else ""
        if output_mode not in {"audio", "text"}:
            # Unknown or missing modes are reported as "audio".
            output_mode = "audio"

        tools_allowlist: List[str] = []
        runtime_tools = runtime.get("tools", {}) if isinstance(runtime, dict) else {}
        if isinstance(runtime_tools, dict):
            allowlist = runtime_tools.get("allowlist", [])
            if isinstance(allowlist, list):
                tools_allowlist = [str(item) for item in allowlist if item is not None and str(item).strip()]

        # Only the tool count is exposed, never the tool names.
        resolved: Dict[str, Any] = {
            "output": {"mode": output_mode},
            "tools": {
                "enabled": bool(tools_allowlist),
                "count": len(tools_allowlist),
            },
            "tracks": {
                "audio_in": self.TRACK_AUDIO_IN,
                "audio_out": self.TRACK_AUDIO_OUT,
                "control": self.TRACK_CONTROL,
            },
        }
        if metadata.get("channel") is not None:
            resolved["channel"] = metadata.get("channel")
        return resolved

    def _extract_json_obj(self, text: str) -> Optional[Dict[str, Any]]:
        """
        Best-effort extraction of a JSON object from freeform text.

        Tries the whole text first, then the span from the first "{" to the
        last "}". Returns None when neither parses to a dict.
        """
        try:
            parsed = json.loads(text)
            if isinstance(parsed, dict):
                return parsed
        except Exception:
            # Deliberate best-effort: fall through to the brace-span attempt.
            pass

        match = re.search(r"\{.*\}", text, re.DOTALL)
        if not match:
            return None
        try:
            parsed = json.loads(match.group(0))
            return parsed if isinstance(parsed, dict) else None
        except Exception:
            return None