Files
AI-VideoAssistant/engine/runtime/session/manager.py
Xin Wang 7e0b777923 Refactor project structure and enhance backend integration
- Expanded package inclusion in `pyproject.toml` to support new modules.
- Introduced new `adapters` and `protocol` packages for better organization.
- Added backend adapter implementations for control plane integration.
- Updated main application imports to reflect new package structure.
- Removed deprecated core components and adjusted documentation accordingly.
- Enhanced architecture documentation to clarify the new runtime and integration layers.
2026-03-06 09:51:56 +08:00

1230 lines
47 KiB
Python

"""Session management for active calls."""
import asyncio
import json
import re
import time
from datetime import datetime, timezone
from enum import Enum
from typing import Optional, Dict, Any, List
from loguru import logger
from adapters.control_plane.backend import build_backend_adapter_from_settings
from runtime.transports import BaseTransport
from runtime.ports import (
AssistantRuntimeConfigProvider,
ControlPlaneGateway,
ConversationHistoryStore,
KnowledgeRetriever,
ToolCatalog,
)
from runtime.pipeline.duplex import DuplexPipeline
from runtime.conversation import ConversationTurn
from runtime.history.bridge import SessionHistoryBridge
from workflow.runner import WorkflowRunner, WorkflowTransition, WorkflowNodeDef, WorkflowEdgeDef
from app.config import settings
from providers.common.base import LLMMessage
from protocol.ws_v1.schema import (
parse_client_message,
ev,
SessionStartMessage,
SessionStopMessage,
InputTextMessage,
ResponseCancelMessage,
OutputAudioPlayedMessage,
ToolCallResultsMessage,
)
# Dynamic-variable keys must look like simple identifiers, at most 64 chars.
_DYNAMIC_VARIABLE_KEY_RE = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]{0,63}$")
# Matches {{ key }} placeholders inside systemPrompt/greeting templates.
_DYNAMIC_VARIABLE_PLACEHOLDER_RE = re.compile(r"\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}\}")
# Hard caps on the size of the client-supplied dynamicVariables payload.
_DYNAMIC_VARIABLE_MAX_ITEMS = 30
_DYNAMIC_VARIABLE_VALUE_MAX_CHARS = 1000
# Reserved keys generated server-side; client-provided values for these are ignored.
# NOTE(review): "system__time" (double underscore) vs "system_utc"/"system_timezone"
# looks inconsistent, but these strings are part of the wire contract — do not
# rename without coordinating with template authors.
_SYSTEM_DYNAMIC_VARIABLE_KEYS = {"system__time", "system_utc", "system_timezone"}
class WsSessionState(str, Enum):
    """Protocol state machine for WS sessions."""
    # Waiting for the initial session.start message.
    WAIT_START = "wait_start"
    # session.start accepted; audio and control messages are processed.
    ACTIVE = "active"
    # Terminal state after session.stop or a fatal protocol error.
    STOPPED = "stopped"
class Session:
    """
    Manages a single call session.
    Handles command routing, audio processing, and session lifecycle.
    Uses full duplex voice conversation pipeline.
    """
    # Logical track identifiers stamped onto outbound event envelopes.
    TRACK_AUDIO_IN = "audio_in"
    TRACK_AUDIO_OUT = "audio_out"
    TRACK_CONTROL = "control"
    AUDIO_FRAME_BYTES = 640  # Legacy fallback: 16k mono pcm_s16le, 20ms
    # Top-level keys a client is allowed to send in session.start metadata.
    _METADATA_ALLOWED_TOP_LEVEL_KEYS = {
        "overrides",
        "dynamicVariables",
        "channel",
        "source",
        "history",
        "workflow",  # explicitly ignored for this MVP protocol version
    }
    # Keys permitted inside metadata.overrides (client-tunable runtime knobs).
    _METADATA_ALLOWED_OVERRIDE_KEYS = {
        "firstTurnMode",
        "greeting",
        "generatedOpenerEnabled",
        "manualOpenerToolCalls",
        "systemPrompt",
        "output",
        "bargeIn",
        "knowledge",
        "knowledgeBaseId",
        "openerAudio",
        "tools",
    }
    # Server-owned identity/config keys; rejected outright if a client sends them.
    _METADATA_FORBIDDEN_TOP_LEVEL_KEYS = {
        "assistantId",
        "appId",
        "app_id",
        "configVersionId",
        "config_version_id",
        "services",
    }
    # Substrings (after lowercasing and stripping _/-) that mark a key as a
    # secret; any nested occurrence anywhere in client metadata is rejected.
    _METADATA_FORBIDDEN_KEY_TOKENS = {
        "apikey",
        "token",
        "secret",
        "password",
        "authorization",
    }
def __init__(
    self,
    session_id: str,
    transport: BaseTransport,
    use_duplex: Optional[bool] = None,
    control_plane_gateway: Optional[ControlPlaneGateway] = None,
    runtime_config_provider: Optional[AssistantRuntimeConfigProvider] = None,
    history_store: Optional[ConversationHistoryStore] = None,
    knowledge_retriever: Optional[KnowledgeRetriever] = None,
    tool_catalog: Optional[ToolCatalog] = None,
    assistant_id: Optional[str] = None,
):
    """
    Initialize session.
    Args:
        session_id: Unique session identifier
        transport: Transport instance for communication
        use_duplex: Whether to use duplex pipeline (defaults to settings.duplex_enabled)
        control_plane_gateway: Optional composite control-plane dependency
        runtime_config_provider: Optional assistant runtime config provider
        history_store: Optional conversation history store
        knowledge_retriever: Optional knowledge retrieval dependency
        tool_catalog: Optional tool resource catalog
        assistant_id: Optional assistant identifier from the connection query
    """
    self.id = session_id
    self.transport = transport
    self.use_duplex = use_duplex if use_duplex is not None else settings.duplex_enabled
    self.audio_frame_bytes = self._compute_audio_frame_bytes()
    # Empty/whitespace assistant ids are normalized to None.
    self._assistant_id = str(assistant_id or "").strip() or None
    # Every control-plane port not injected explicitly falls back to the
    # backend adapter built from settings (the gateway implements all ports).
    self._control_plane_gateway = control_plane_gateway or build_backend_adapter_from_settings()
    self._runtime_config_provider = runtime_config_provider or self._control_plane_gateway
    self._history_store = history_store or self._control_plane_gateway
    self._knowledge_retriever = knowledge_retriever or self._control_plane_gateway
    self._tool_catalog = tool_catalog or self._control_plane_gateway
    self._history_bridge = SessionHistoryBridge(
        history_writer=self._history_store,
        enabled=settings.history_enabled,
        queue_max_size=settings.history_queue_max_size,
        retry_max_attempts=settings.history_retry_max_attempts,
        retry_backoff_sec=settings.history_retry_backoff_sec,
        finalize_drain_timeout_sec=settings.history_finalize_drain_timeout_sec,
    )
    self.pipeline = DuplexPipeline(
        transport=transport,
        session_id=session_id,
        system_prompt=settings.duplex_system_prompt,
        greeting=settings.duplex_greeting,
        # Optional capabilities: None when the gateway doesn't implement them.
        knowledge_searcher=getattr(self._knowledge_retriever, "search_knowledge_context", None),
        tool_resource_resolver=getattr(self._tool_catalog, "fetch_tool_resource", None),
    )
    # Session state
    self.created_at = None
    self.state = "created"  # Legacy call state for /call/lists
    self.ws_state = WsSessionState.WAIT_START
    self._pipeline_started = False
    # Track IDs
    self.current_track_id: str = self.TRACK_CONTROL
    self._event_seq: int = 0
    self._cleanup_lock = asyncio.Lock()
    self._cleaned_up = False
    self.workflow_runner: Optional[WorkflowRunner] = None
    self._workflow_last_user_text: str = ""
    self._workflow_initial_node: Optional[WorkflowNodeDef] = None
    # Pipeline shares this session's event sequence and turn callbacks.
    self.pipeline.set_event_sequence_provider(self._next_event_seq)
    self.pipeline.conversation.on_turn_complete(self._on_turn_complete)
    logger.info(
        "Session {} created (duplex={}, assistant_id={})",
        self.id,
        self.use_duplex,
        self._assistant_id or "-",
    )
async def handle_text(self, text_data: str) -> None:
    """
    Handle incoming text data (WS v1 JSON control messages).

    Parses the payload, validates it against the WS v1 schema, and routes
    it to a handler. Parse/validation failures are reported to the client
    as protocol errors; anything unexpected becomes a server error.

    Args:
        text_data: JSON text data
    """
    try:
        data = json.loads(text_data)
        message = parse_client_message(data)
        await self._handle_v1_message(message)
    except json.JSONDecodeError as e:
        logger.error(f"Session {self.id} JSON decode error: {e}")
        await self._send_error("client", f"Invalid JSON: {e}", "protocol.invalid_json")
    except ValueError as e:
        logger.error(f"Session {self.id} command parse error: {e}")
        await self._send_error("client", f"Invalid message: {e}", "protocol.invalid_message")
    except Exception as e:
        # loguru does not understand stdlib logging's exc_info=True kwarg,
        # so the traceback was never recorded; logger.exception() logs it.
        logger.exception(f"Session {self.id} handle_text error: {e}")
        await self._send_error("server", f"Internal error: {e}", "server.internal")
async def handle_audio(self, audio_bytes: bytes) -> None:
    """
    Handle incoming audio data.

    Validates protocol state and PCM framing, then feeds fixed-size frames
    into the duplex pipeline.

    Args:
        audio_bytes: PCM audio data (pcm_s16le); length must be a multiple
            of the configured frame size
    """
    if self.ws_state != WsSessionState.ACTIVE:
        await self._send_error(
            "client",
            "Audio received before session.start",
            "protocol.order",
            stage="protocol",
            retryable=False,
        )
        return
    try:
        if not audio_bytes:
            return
        # pcm_s16le samples are 2 bytes each; an odd length is corrupt.
        if len(audio_bytes) % 2 != 0:
            await self._send_error(
                "client",
                "Invalid PCM payload: odd number of bytes",
                "audio.invalid_pcm",
                stage="audio",
                retryable=False,
            )
            return
        frame_bytes = getattr(self, "audio_frame_bytes", self._compute_audio_frame_bytes())
        if len(audio_bytes) % frame_bytes != 0:
            await self._send_error(
                "client",
                (
                    f"Audio frame size must be a multiple of {frame_bytes} bytes "
                    f"({settings.chunk_size_ms}ms PCM @ {settings.sample_rate}Hz)"
                ),
                "audio.frame_size_mismatch",
                stage="audio",
                retryable=False,
            )
            return
        # Split the payload into pipeline-sized frames.
        for i in range(0, len(audio_bytes), frame_bytes):
            frame = audio_bytes[i : i + frame_bytes]
            await self.pipeline.process_audio(frame)
    except Exception as e:
        # loguru ignores stdlib's exc_info kwarg; logger.exception()
        # records the traceback properly.
        logger.exception(f"Session {self.id} handle_audio error: {e}")
        await self._send_error(
            "server",
            f"Audio processing failed: {e}",
            "audio.processing_failed",
            stage="audio",
            retryable=True,
        )
async def _handle_v1_message(self, message: Any) -> None:
    """Route validated WS v1 message to handlers.

    Ordering guards: session.start is always accepted; every other message
    type requires ws_state == ACTIVE, otherwise a protocol.order error is
    returned to the client.
    """
    msg_type = message.type
    logger.info(f"Session {self.id} received message: {msg_type}")
    if isinstance(message, SessionStartMessage):
        await self._handle_v1_start_guard(message) if False else await self._handle_session_start(message)
        return
    # All messages below require session.start first
    if self.ws_state == WsSessionState.WAIT_START:
        await self._send_error(
            "client",
            "Expected session.start message first",
            "protocol.order",
        )
        return
    # All messages below require active session
    if self.ws_state != WsSessionState.ACTIVE:
        await self._send_error(
            "client",
            f"Message '{msg_type}' requires active session",
            "protocol.order",
        )
        return
    if isinstance(message, InputTextMessage):
        await self.pipeline.process_text(message.text)
    elif isinstance(message, ResponseCancelMessage):
        if message.graceful:
            # Graceful cancel: no interrupt; the current response finishes.
            logger.info(f"Session {self.id} graceful response.cancel")
        else:
            await self.pipeline.interrupt()
    elif isinstance(message, OutputAudioPlayedMessage):
        # Client-side playback acknowledgement, used for barge-in accounting.
        await self.pipeline.handle_output_audio_played(
            tts_id=message.tts_id,
            response_id=message.response_id,
            turn_id=message.turn_id,
            played_at_ms=message.played_at_ms,
            played_ms=message.played_ms,
        )
    elif isinstance(message, ToolCallResultsMessage):
        await self.pipeline.handle_tool_call_results([item.model_dump() for item in message.results])
    elif isinstance(message, SessionStopMessage):
        await self._handle_session_stop(message.reason)
    else:
        await self._send_error("client", f"Unsupported message type: {msg_type}", "protocol.unsupported")
async def _handle_session_start(self, message: SessionStartMessage) -> None:
    """Handle explicit session start.

    Flow: validate assistant_id -> sanitize client metadata -> load trusted
    server-side runtime config -> merge client overrides -> apply dynamic
    variables -> start history bridge + pipeline -> emit session.started
    (and optional config.resolved) -> emit opener. Any validation failure
    sends an error event, closes the transport, and stops the session.
    """
    if self.ws_state != WsSessionState.WAIT_START:
        await self._send_error("client", "Duplicate session.start", "protocol.order")
        return
    raw_metadata = message.metadata or {}
    if not self._assistant_id:
        await self._send_error(
            "client",
            "Missing required query parameter assistant_id",
            "protocol.assistant_id_required",
            stage="protocol",
            retryable=False,
        )
        await self.transport.close()
        self.ws_state = WsSessionState.STOPPED
        return
    sanitized_metadata, metadata_error = self._validate_and_sanitize_client_metadata(raw_metadata)
    if metadata_error:
        await self._send_error(
            "client",
            metadata_error["message"],
            metadata_error["code"],
            stage="protocol",
            retryable=False,
        )
        await self.transport.close()
        self.ws_state = WsSessionState.STOPPED
        return
    server_runtime, runtime_error = await self._load_server_runtime_metadata(self._assistant_id)
    if runtime_error:
        await self._send_error(
            "server",
            runtime_error["message"],
            runtime_error["code"],
            stage="protocol",
            retryable=False,
        )
        await self.transport.close()
        self.ws_state = WsSessionState.STOPPED
        return
    # Client overrides are layered on top of the trusted server runtime.
    metadata = self._merge_runtime_metadata(server_runtime, sanitized_metadata.get("overrides", {}))
    for key in ("channel", "source", "history"):
        if key in sanitized_metadata:
            metadata[key] = sanitized_metadata[key]
    metadata, dynamic_var_error = self._apply_dynamic_variables(metadata, sanitized_metadata)
    if dynamic_var_error:
        await self._send_error(
            "client",
            dynamic_var_error["message"],
            dynamic_var_error["code"],
            stage="protocol",
            retryable=False,
        )
        await self.transport.close()
        self.ws_state = WsSessionState.STOPPED
        return
    # Create history call record early so later turn callbacks can append transcripts.
    await self._start_history_bridge(metadata)
    # Apply runtime service/prompt overrides from backend if provided
    self.pipeline.apply_runtime_overrides(metadata)
    resolved_preview = self.pipeline.resolved_runtime_config()
    # Defensive extraction: the resolved config's shape is owned by the pipeline.
    resolved_services = resolved_preview.get("services", {}) if isinstance(resolved_preview, dict) else {}
    llm_cfg = resolved_services.get("llm", {}) if isinstance(resolved_services, dict) else {}
    asr_cfg = resolved_services.get("asr", {}) if isinstance(resolved_services, dict) else {}
    tts_cfg = resolved_services.get("tts", {}) if isinstance(resolved_services, dict) else {}
    logger.info(
        "Session {} effective runtime services "
        "(assistantId={}, configVersionId={}, output_mode={}, "
        "llm={}/{}, asr={}/{}, tts={}/{}, tts_enabled={})",
        self.id,
        metadata.get("assistantId"),
        metadata.get("configVersionId") or metadata.get("config_version_id"),
        (resolved_preview.get("output") or {}).get("mode") if isinstance(resolved_preview, dict) else None,
        llm_cfg.get("provider"),
        llm_cfg.get("model"),
        asr_cfg.get("provider"),
        asr_cfg.get("model"),
        tts_cfg.get("provider"),
        tts_cfg.get("model"),
        tts_cfg.get("enabled"),
    )
    # Start duplex pipeline
    if not self._pipeline_started:
        await self.pipeline.start()
        self._pipeline_started = True
        logger.info(f"Session {self.id} duplex pipeline started")
    self.state = "accepted"
    self.ws_state = WsSessionState.ACTIVE
    await self._send_event(
        ev(
            "session.started",
            trackId=self.current_track_id,
            protocolVersion=self._public_ws_protocol_version(),
            tracks={
                "audio_in": self.TRACK_AUDIO_IN,
                "audio_out": self.TRACK_AUDIO_OUT,
                "control": self.TRACK_CONTROL,
            },
            audio=message.audio.model_dump() if message.audio else {},
        )
    )
    if settings.ws_emit_config_resolved:
        await self._send_event(
            ev(
                "config.resolved",
                trackId=self.TRACK_CONTROL,
                config=self._build_config_resolved(metadata),
            )
        )
    else:
        logger.debug("Session {} skipped config.resolved (ws_emit_config_resolved=false)", self.id)
    # Emit opener only after frontend has received session.started (and optional config event).
    await self.pipeline.emit_initial_greeting()
async def _handle_session_stop(self, reason: Optional[str]) -> None:
    """Handle session stop.

    Emits session.stopped, finalizes the history record, and closes the
    transport. Idempotent: repeated calls no-op once STOPPED.
    """
    if self.ws_state == WsSessionState.STOPPED:
        return
    stop_reason = reason or "client_requested"
    self.state = "hungup"
    self.ws_state = WsSessionState.STOPPED
    await self._send_event(
        ev(
            "session.stopped",
            reason=stop_reason,
        )
    )
    # NOTE(review): finalizing with status="connected" mirrors cleanup();
    # confirm this is the intended terminal status for the history API.
    await self._finalize_history(status="connected")
    await self.transport.close()
async def _send_error(
    self,
    sender: str,
    error_message: str,
    code: str,
    stage: Optional[str] = None,
    retryable: Optional[bool] = None,
    track_id: Optional[str] = None,
) -> None:
    """
    Send error event to client.

    Args:
        sender: Component that generated the error
        error_message: Error message
        code: Machine-readable error code
        stage: Pipeline stage; inferred from the code prefix when falsy
        retryable: Whether the client may retry; defaults by stage
        track_id: Target track; derived from the stage when falsy
    """
    final_stage = stage or self._infer_error_stage(code)
    if retryable is None:
        # Service-side stages are retryable by default; protocol errors are not.
        final_retryable = final_stage in {"asr", "llm", "tts", "tool", "audio"}
    else:
        final_retryable = retryable
    final_track = track_id or self._error_track_id(final_stage, code)
    error_detail = {
        "stage": final_stage,
        "code": code,
        "message": error_message,
        "retryable": final_retryable,
    }
    event = ev(
        "error",
        sender=sender,
        code=code,
        message=error_message,
        stage=final_stage,
        retryable=final_retryable,
        trackId=final_track,
        data={"error": error_detail},
    )
    await self._send_event(event)
def _get_timestamp_ms(self) -> int:
"""Get current timestamp in milliseconds."""
import time
return int(time.time() * 1000)
async def cleanup(self) -> None:
    """Cleanup session resources.

    Idempotent: the lock plus the _cleaned_up flag guarantee the teardown
    runs at most once even under concurrent invocation. Teardown order
    matters: history is finalized first, then the pipeline and history
    bridge are shut down, and the transport is closed last.
    """
    async with self._cleanup_lock:
        if self._cleaned_up:
            return
        self._cleaned_up = True
        logger.info(f"Session {self.id} cleaning up")
        # NOTE(review): status="connected" mirrors _handle_session_stop;
        # confirm it is the intended terminal status for history records.
        await self._finalize_history(status="connected")
        await self.pipeline.cleanup()
        await self._history_bridge.shutdown()
        await self.transport.close()
async def _start_history_bridge(self, metadata: Dict[str, Any]) -> None:
    """Initialize backend history call record for this session.

    No-op when a call record already exists. userId resolution order:
    metadata.history.userId -> metadata.userId -> settings default; a
    non-integer value falls back to the settings default.
    """
    if self._history_bridge.call_id:
        return
    history_meta: Dict[str, Any] = {}
    if isinstance(metadata.get("history"), dict):
        history_meta = metadata["history"]
    raw_user_id = history_meta.get("userId", metadata.get("userId", settings.history_default_user_id))
    try:
        user_id = int(raw_user_id)
    except (TypeError, ValueError):
        user_id = settings.history_default_user_id
    assistant_id = history_meta.get("assistantId", metadata.get("assistantId"))
    source = str(history_meta.get("source", metadata.get("source", "debug")))
    call_id = await self._history_bridge.start_call(
        user_id=user_id,
        assistant_id=str(assistant_id) if assistant_id else None,
        source=source,
    )
    if not call_id:
        # Bridge disabled or backend rejected the call; history is skipped.
        return
    logger.info(f"Session {self.id} history bridge enabled (call_id={call_id}, source={source})")
async def _on_turn_complete(self, turn: ConversationTurn) -> None:
    """Process workflow transitions and persist completed turns to history."""
    spoken = (turn.text or "").strip()
    if spoken:
        speaker = (turn.role or "").lower()
        if speaker == "user":
            # Remember the latest user utterance for workflow routing context.
            self._workflow_last_user_text = spoken
        elif speaker == "assistant":
            await self._maybe_advance_workflow(spoken)
    self._history_bridge.enqueue_turn(role=turn.role or "", text=turn.text or "")
async def _finalize_history(self, status: str) -> None:
    """Finalize history call record once (idempotency presumably enforced
    by the bridge — callers invoke this from both stop and cleanup paths)."""
    await self._history_bridge.finalize(status=status)
def _bootstrap_workflow(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
    """Parse workflow payload and return initial runtime overrides.

    Returns {} when no workflow is configured or the payload has no
    resolvable start node. NOTE(review): metadata.workflow is explicitly
    ignored by the MVP metadata validation, so this path may currently be
    dormant — confirm against callers outside this view.
    """
    payload = metadata.get("workflow")
    self.workflow_runner = WorkflowRunner.from_payload(payload)
    self._workflow_initial_node = None
    if not self.workflow_runner:
        return {}
    node = self.workflow_runner.bootstrap()
    if not node:
        # Malformed graph: disable workflow entirely rather than run half-configured.
        logger.warning(f"Session {self.id} workflow payload had no resolvable start node")
        self.workflow_runner = None
        return {}
    self._workflow_initial_node = node
    logger.info(
        "Session {} workflow enabled: workflow={} start_node={}",
        self.id,
        self.workflow_runner.workflow_id,
        node.id,
    )
    return self.workflow_runner.build_runtime_metadata(node)
async def _maybe_advance_workflow(self, assistant_text: str) -> None:
    """Attempt node transfer after assistant turn finalization.

    Routes once on the completed turn, then auto-advances through
    start/tool nodes that expose a default edge, capped at max_auto_hops
    to guard against graph cycles.
    """
    if not self.workflow_runner or self.ws_state == WsSessionState.STOPPED:
        return
    transition = await self.workflow_runner.route(
        user_text=self._workflow_last_user_text,
        assistant_text=assistant_text,
        llm_router=self._workflow_llm_route,
    )
    if not transition:
        return
    await self._apply_workflow_transition(transition, reason="rule_match")
    # Auto-advance through utility nodes when default edges are present.
    max_auto_hops = 6
    auto_hops = 0
    # Transition application may stop the session (end/human_transfer nodes),
    # so both conditions are re-checked every iteration.
    while self.workflow_runner and self.ws_state != WsSessionState.STOPPED:
        current = self.workflow_runner.current_node
        if not current or current.node_type not in {"start", "tool"}:
            break
        next_default = self.workflow_runner.next_default_transition()
        if not next_default:
            break
        auto_hops += 1
        await self._apply_workflow_transition(next_default, reason="auto")
        if auto_hops >= max_auto_hops:
            logger.warning(
                "Session {} workflow auto-advance reached hop limit (possible cycle)",
                self.id,
            )
            break
async def _apply_workflow_transition(self, transition: WorkflowTransition, reason: str) -> None:
    """Apply graph transition and emit workflow lifecycle events.

    Emits workflow.edge.taken and workflow.node.entered, applies the new
    node's runtime overrides, then handles terminal node types: tool nodes
    request client-side execution; human_transfer and end nodes stop the
    session with a matching reason.
    """
    if not self.workflow_runner:
        return
    self.workflow_runner.apply_transition(transition)
    node = transition.node
    edge = transition.edge
    await self._send_event(
        ev(
            "workflow.edge.taken",
            workflowId=self.workflow_runner.workflow_id,
            edgeId=edge.id,
            fromNodeId=edge.from_node_id,
            toNodeId=edge.to_node_id,
            reason=reason,
        )
    )
    await self._send_event(
        ev(
            "workflow.node.entered",
            workflowId=self.workflow_runner.workflow_id,
            nodeId=node.id,
            nodeName=node.name,
            nodeType=node.node_type,
        )
    )
    # Node-scoped overrides (e.g. prompt changes) take effect immediately.
    node_runtime = self.workflow_runner.build_runtime_metadata(node)
    if node_runtime:
        self.pipeline.apply_runtime_overrides(node_runtime)
    if node.node_type == "tool":
        await self._send_event(
            ev(
                "workflow.tool.requested",
                workflowId=self.workflow_runner.workflow_id,
                nodeId=node.id,
                tool=node.tool or {},
            )
        )
        return
    if node.node_type == "human_transfer":
        await self._send_event(
            ev(
                "workflow.human_transfer",
                workflowId=self.workflow_runner.workflow_id,
                nodeId=node.id,
            )
        )
        await self._handle_session_stop("workflow_human_transfer")
        return
    if node.node_type == "end":
        await self._send_event(
            ev(
                "workflow.ended",
                workflowId=self.workflow_runner.workflow_id,
                nodeId=node.id,
            )
        )
        await self._handle_session_stop("workflow_end")
def _next_event_seq(self) -> int:
self._event_seq += 1
return self._event_seq
def _event_source(self, event_type: str) -> str:
if event_type.startswith("workflow."):
return "system"
if event_type.startswith("session.") or event_type == "heartbeat":
return "system"
if event_type == "error":
return "system"
return "system"
def _infer_error_stage(self, code: str) -> str:
normalized = str(code or "").strip().lower()
if normalized.startswith("audio."):
return "audio"
if normalized.startswith("tool."):
return "tool"
if normalized.startswith("asr."):
return "asr"
if normalized.startswith("llm."):
return "llm"
if normalized.startswith("tts."):
return "tts"
return "protocol"
def _error_track_id(self, stage: str, code: str) -> str:
if stage in {"audio", "asr"}:
return self.TRACK_AUDIO_IN
if stage in {"llm", "tts", "tool"}:
return self.TRACK_AUDIO_OUT
return self.TRACK_CONTROL
def _envelope_event(self, event: Dict[str, Any]) -> Dict[str, Any]:
    """Normalize an outbound event into the WS v1 envelope (mutates in place).

    Any non-envelope top-level key is demoted into data (without
    overwriting existing data entries), then sessionId, seq, source,
    trackId and data are stamped onto the event.
    """
    event_type = str(event.get("type") or "")
    source = str(event.get("source") or self._event_source(event_type))
    track_id = event.get("trackId") or self.TRACK_CONTROL
    data = event.get("data")
    if not isinstance(data, dict):
        data = {}
    for k, v in event.items():
        if k in {"type", "timestamp", "sessionId", "seq", "source", "trackId", "data"}:
            continue
        data.setdefault(k, v)
    event["sessionId"] = self.id
    # NOTE(review): seq is assigned here unconditionally even though the
    # pipeline also holds _next_event_seq as its sequence provider —
    # confirm a pre-set seq from ev() is meant to be overwritten.
    event["seq"] = self._next_event_seq()
    event["source"] = source
    event["trackId"] = track_id
    event["data"] = data
    return event
async def _send_event(self, event: Dict[str, Any]) -> None:
    """Wrap the event in the WS v1 envelope and push it to the transport."""
    await self.transport.send_event(self._envelope_event(event))
async def send_heartbeat(self) -> None:
    """Emit a keep-alive heartbeat event on the control track."""
    await self._send_event(ev("heartbeat", trackId=self.TRACK_CONTROL))
async def _workflow_llm_route(
    self,
    node: WorkflowNodeDef,
    candidates: List[WorkflowEdgeDef],
    context: Dict[str, str],
) -> Optional[str]:
    """LLM-based edge routing for condition.type == 'llm' edges.

    Asks the session LLM to pick one candidate edge; accepts either an edge
    id or a target node id in the reply. Returns None when no LLM service
    is available, the call fails, or the reply names no known candidate.
    """
    llm_service = self.pipeline.llm_service
    if not llm_service:
        return None
    candidate_rows = [
        {
            "edgeId": edge.id,
            "toNodeId": edge.to_node_id,
            "label": edge.label,
            "hint": edge.condition.get("prompt") if isinstance(edge.condition, dict) else None,
        }
        for edge in candidates
    ]
    system_prompt = (
        "You are a workflow router. Pick exactly one edge. "
        "Return JSON only: {\"edgeId\":\"...\"}."
    )
    user_prompt = json.dumps(
        {
            "nodeId": node.id,
            "nodeName": node.name,
            "userText": context.get("userText", ""),
            "assistantText": context.get("assistantText", ""),
            "candidates": candidate_rows,
        },
        ensure_ascii=False,
    )
    try:
        # Deterministic, short completion: routing only needs an id back.
        reply = await llm_service.generate(
            [
                LLMMessage(role="system", content=system_prompt),
                LLMMessage(role="user", content=user_prompt),
            ],
            temperature=0.0,
            max_tokens=64,
        )
    except Exception as exc:
        logger.warning(f"Session {self.id} workflow llm routing failed: {exc}")
        return None
    if not reply:
        return None
    edge_ids = {edge.id for edge in candidates}
    node_ids = {edge.to_node_id for edge in candidates}
    # _extract_json_obj is defined elsewhere in this class (outside this view).
    parsed = self._extract_json_obj(reply)
    if isinstance(parsed, dict):
        edge_id = parsed.get("edgeId") or parsed.get("id")
        node_id = parsed.get("toNodeId") or parsed.get("nodeId")
        if isinstance(edge_id, str) and edge_id in edge_ids:
            return edge_id
        if isinstance(node_id, str) and node_id in node_ids:
            return node_id
    # Fallback: scan the raw reply for any known id, longest first so an id
    # that is a substring of another cannot shadow it.
    token_candidates = sorted(edge_ids | node_ids, key=len, reverse=True)
    lowered_reply = reply.lower()
    for token in token_candidates:
        if token.lower() in lowered_reply:
            return token
    return None
def _merge_runtime_metadata(self, base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
"""Merge node-level metadata overrides into session.start metadata."""
merged = dict(base or {})
if not overrides:
return merged
for key, value in overrides.items():
if key == "services" and isinstance(value, dict):
existing = merged.get("services")
merged_services = dict(existing) if isinstance(existing, dict) else {}
merged_services.update(value)
merged["services"] = merged_services
else:
merged[key] = value
return merged
def _extract_dynamic_template_keys(self, text: Any) -> List[str]:
    """Collect placeholder keys from a template string (ordered, de-duplicated)."""
    if text is None:
        return []
    # dict preserves insertion order and de-duplicates in one structure.
    ordered: Dict[str, None] = {}
    for found in _DYNAMIC_VARIABLE_PLACEHOLDER_RE.finditer(str(text)):
        token = str(found.group(1) or "")
        if token:
            ordered.setdefault(token, None)
    return list(ordered)
def _render_dynamic_template(
    self,
    template: str,
    dynamic_vars: Dict[str, str],
) -> tuple[str, List[str]]:
    """Render one template text and return (rendered, sorted missing keys)."""
    unresolved: set = set()

    def _substitute(found: re.Match) -> str:
        placeholder = str(found.group(1) or "")
        if placeholder in dynamic_vars:
            return dynamic_vars[placeholder]
        # Unknown key: keep the literal {{ key }} text and record it.
        unresolved.add(placeholder)
        return found.group(0)

    rendered_text = _DYNAMIC_VARIABLE_PLACEHOLDER_RE.sub(_substitute, str(template or ""))
    return rendered_text, sorted(unresolved)
def _system_dynamic_variables(self) -> Dict[str, str]:
    """Build system-level dynamic variables for the current session timestamp.

    Uses the host's local timezone for system__time/system_timezone and UTC
    for system_utc. Keys must stay in sync with _SYSTEM_DYNAMIC_VARIABLE_KEYS.
    """
    local_now = datetime.now().astimezone()
    utc_now = local_now.astimezone(timezone.utc)
    return {
        "system__time": local_now.strftime("%Y-%m-%d %H:%M:%S"),
        "system_utc": utc_now.strftime("%Y-%m-%d %H:%M:%S"),
        "system_timezone": str(local_now.tzinfo or ""),
    }
def _apply_dynamic_variables(
    self,
    metadata: Dict[str, Any],
    client_metadata: Dict[str, Any],
) -> tuple[Dict[str, Any], Optional[Dict[str, str]]]:
    """
    Apply session.start metadata.dynamicVariables to prompt/greeting templates.

    Validation: dynamicVariables must be an object with identifier-like
    string keys and string values, capped at _DYNAMIC_VARIABLE_MAX_ITEMS
    entries and _DYNAMIC_VARIABLE_VALUE_MAX_CHARS per value. Reserved
    system__* keys are always generated server-side. Every placeholder
    referenced by systemPrompt/greeting must resolve or an error is returned.

    Returns:
        tuple(merged_metadata, error_payload)
        error_payload shape: {"code": "...", "message": "..."}
    """
    merged = dict(metadata or {})
    raw_dynamic_vars = client_metadata.get("dynamicVariables")
    # Server-generated time variables are seeded first; clients cannot override them.
    dynamic_vars: Dict[str, str] = self._system_dynamic_variables()
    if raw_dynamic_vars is not None:
        if not isinstance(raw_dynamic_vars, dict):
            return merged, {
                "code": "protocol.dynamic_variables_invalid",
                "message": "metadata.dynamicVariables must be an object with string key/value pairs",
            }
        if len(raw_dynamic_vars) > _DYNAMIC_VARIABLE_MAX_ITEMS:
            return merged, {
                "code": "protocol.dynamic_variables_invalid",
                "message": f"metadata.dynamicVariables cannot exceed {_DYNAMIC_VARIABLE_MAX_ITEMS} entries",
            }
        for raw_key, raw_value in raw_dynamic_vars.items():
            if not isinstance(raw_key, str):
                return merged, {
                    "code": "protocol.dynamic_variables_invalid",
                    "message": "metadata.dynamicVariables keys must be strings",
                }
            key = raw_key.strip()
            if not _DYNAMIC_VARIABLE_KEY_RE.match(key):
                return merged, {
                    "code": "protocol.dynamic_variables_invalid",
                    "message": (
                        "Invalid dynamic variable key "
                        f"'{raw_key}'. Expected ^[a-zA-Z_][a-zA-Z0-9_]{{0,63}}$"
                    ),
                }
            if key in _SYSTEM_DYNAMIC_VARIABLE_KEYS:
                # Reserved system variables are generated by server time context.
                continue
            # Duplicates can only arise after strip() (e.g. "a" and "a ")
            # since dict keys themselves are unique.
            if key in dynamic_vars:
                return merged, {
                    "code": "protocol.dynamic_variables_invalid",
                    "message": f"Duplicate dynamic variable key '{key}'",
                }
            if not isinstance(raw_value, str):
                return merged, {
                    "code": "protocol.dynamic_variables_invalid",
                    "message": f"Dynamic variable '{key}' value must be a string",
                }
            if len(raw_value) > _DYNAMIC_VARIABLE_VALUE_MAX_CHARS:
                return merged, {
                    "code": "protocol.dynamic_variables_invalid",
                    "message": (
                        f"Dynamic variable '{key}' exceeds "
                        f"{_DYNAMIC_VARIABLE_VALUE_MAX_CHARS} characters"
                    ),
                }
            dynamic_vars[key] = raw_value
    # Fast path: no placeholders referenced means nothing to render.
    template_keys = set(self._extract_dynamic_template_keys(merged.get("systemPrompt")))
    template_keys.update(self._extract_dynamic_template_keys(merged.get("greeting")))
    if not template_keys:
        return merged, None
    missing_keys = sorted([key for key in template_keys if key not in dynamic_vars])
    if missing_keys:
        return merged, {
            "code": "protocol.dynamic_variables_missing",
            "message": f"Missing dynamic variables for placeholders: {', '.join(missing_keys)}",
        }
    for field in ("systemPrompt", "greeting"):
        value = merged.get(field)
        if value is None:
            continue
        rendered, unresolved = self._render_dynamic_template(str(value), dynamic_vars)
        if unresolved:
            return merged, {
                "code": "protocol.dynamic_variables_missing",
                "message": f"Missing dynamic variables for placeholders: {', '.join(unresolved)}",
            }
        merged[field] = rendered
    return merged, None
async def _load_server_runtime_metadata(
    self,
    assistant_id: str,
) -> tuple[Dict[str, Any], Optional[Dict[str, str]]]:
    """Load trusted runtime metadata from control-plane assistant config.

    Returns (runtime, None) on success or ({}, error_payload) on failure,
    where error_payload is {"code": ..., "message": ...}. Only an
    allow-listed subset of the assistant config is passed through, plus
    normalized systemPrompt/greeting (with prompt/opener fallbacks) and
    assistantId/configVersionId identity fields.
    """
    if not assistant_id:
        return {}, {
            "code": "protocol.assistant_id_required",
            "message": "Missing required query parameter assistant_id",
        }
    provider = getattr(self._runtime_config_provider, "fetch_assistant_config", None)
    if not callable(provider):
        return {}, {
            "code": "assistant.config_unavailable",
            "message": "Assistant config control plane unavailable",
        }
    payload = await provider(str(assistant_id).strip())
    # Adapters signal failures in-band via a "__error_code" marker key.
    if isinstance(payload, dict):
        error_code = str(payload.get("__error_code") or "").strip()
        if error_code == "assistant.not_found":
            return {}, {
                "code": "assistant.not_found",
                "message": f"Assistant not found: {assistant_id}",
            }
        if error_code == "assistant.config_unavailable":
            return {}, {
                "code": "assistant.config_unavailable",
                "message": f"Assistant config unavailable: {assistant_id}",
            }
    if not isinstance(payload, dict):
        return {}, {
            "code": "assistant.config_unavailable",
            "message": f"Assistant config unavailable: {assistant_id}",
        }
    # Accept several payload shapes: sessionStartMetadata, nested
    # "assistant" object, or the flat payload itself as a last resort.
    assistant_cfg: Dict[str, Any] = {}
    session_start_cfg = payload.get("sessionStartMetadata")
    if isinstance(session_start_cfg, dict):
        assistant_cfg.update(session_start_cfg)
    if isinstance(payload.get("assistant"), dict):
        assistant_cfg.update(payload.get("assistant"))
    elif not assistant_cfg:
        assistant_cfg = payload
    if not isinstance(assistant_cfg, dict):
        return {}, {
            "code": "assistant.config_unavailable",
            "message": f"Assistant config unavailable: {assistant_id}",
        }
    runtime: Dict[str, Any] = {}
    # Only these keys may flow from the assistant config into the session.
    passthrough_keys = {
        "firstTurnMode",
        "generatedOpenerEnabled",
        "manualOpenerToolCalls",
        "output",
        "bargeIn",
        "knowledgeBaseId",
        "knowledge",
        "openerAudio",
        "history",
        "userId",
        "source",
        "tools",
        "services",
        "configVersionId",
        "config_version_id",
    }
    for key in passthrough_keys:
        if key in assistant_cfg:
            runtime[key] = assistant_cfg[key]
    # systemPrompt preferred; legacy "prompt" accepted as fallback.
    if assistant_cfg.get("systemPrompt") is not None:
        runtime["systemPrompt"] = str(assistant_cfg.get("systemPrompt") or "")
    elif assistant_cfg.get("prompt") is not None:
        runtime["systemPrompt"] = str(assistant_cfg.get("prompt") or "")
    # greeting preferred; legacy "opener" accepted as fallback.
    if assistant_cfg.get("greeting") is not None:
        runtime["greeting"] = assistant_cfg.get("greeting")
    elif assistant_cfg.get("opener") is not None:
        runtime["greeting"] = assistant_cfg.get("opener")
    resolved_assistant_id = (
        assistant_cfg.get("assistantId")
        or payload.get("assistantId")
        or assistant_id
    )
    runtime["assistantId"] = str(resolved_assistant_id)
    # configVersionId normalization: camelCase wins; snake_case is legacy.
    if runtime.get("configVersionId") is None and payload.get("configVersionId") is not None:
        runtime["configVersionId"] = payload.get("configVersionId")
    if runtime.get("configVersionId") is None and payload.get("config_version_id") is not None:
        runtime["configVersionId"] = payload.get("config_version_id")
    if runtime.get("config_version_id") is not None and runtime.get("configVersionId") is None:
        runtime["configVersionId"] = runtime.get("config_version_id")
    return runtime, None
def _find_forbidden_secret_key(self, payload: Any, path: str = "metadata") -> Optional[str]:
if isinstance(payload, dict):
for key, value in payload.items():
key_str = str(key)
normalized = key_str.lower().replace("_", "").replace("-", "")
if any(token in normalized for token in self._METADATA_FORBIDDEN_KEY_TOKENS):
return f"{path}.{key_str}"
nested = self._find_forbidden_secret_key(value, f"{path}.{key_str}")
if nested:
return nested
return None
if isinstance(payload, list):
for idx, value in enumerate(payload):
nested = self._find_forbidden_secret_key(value, f"{path}[{idx}]")
if nested:
return nested
return None
    def _validate_and_sanitize_client_metadata(
        self,
        metadata: Dict[str, Any],
    ) -> tuple[Dict[str, Any], Optional[Dict[str, str]]]:
        """Validate client-supplied session metadata and reduce it to a safe subset.

        Returns ``(sanitized, error)``: on success ``error`` is ``None`` and
        ``sanitized`` contains only whitelisted fields (``overrides``,
        ``dynamicVariables``, ``channel``, ``source``, ``history``); on any
        violation ``sanitized`` is ``{}`` and ``error`` is a dict with ``code``
        (always ``protocol.invalid_override``) and a human-readable ``message``.
        Checks run in order, so the first violation wins.
        """
        # Reject non-object payloads outright.
        if not isinstance(metadata, dict):
            return {}, {
                "code": "protocol.invalid_override",
                "message": "metadata must be an object",
            }
        # Explicitly forbidden top-level keys (class-level denylist) are an error,
        # checked before the unknown-key check so they get the clearer message.
        forbidden_top_level = [key for key in self._METADATA_FORBIDDEN_TOP_LEVEL_KEYS if key in metadata]
        if forbidden_top_level:
            return {}, {
                "code": "protocol.invalid_override",
                "message": f"Forbidden metadata keys: {', '.join(sorted(forbidden_top_level))}",
            }
        # Anything not in the allowlist (and not already caught above) is unsupported.
        unknown_keys = [
            key
            for key in metadata.keys()
            if key not in self._METADATA_ALLOWED_TOP_LEVEL_KEYS
            and key not in self._METADATA_FORBIDDEN_TOP_LEVEL_KEYS
        ]
        if unknown_keys:
            return {}, {
                "code": "protocol.invalid_override",
                "message": f"Unsupported metadata keys: {', '.join(sorted(unknown_keys))}",
            }
        # "workflow" is tolerated but dropped in MVP; warn so callers can notice.
        if "workflow" in metadata:
            logger.warning("Session {} received metadata.workflow; workflow payload is ignored in MVP", self.id)
        # overrides: optional object restricted to _METADATA_ALLOWED_OVERRIDE_KEYS;
        # shallow-copied so later mutation of the caller's dict cannot leak in.
        overrides_raw = metadata.get("overrides")
        overrides: Dict[str, Any] = {}
        if overrides_raw is not None:
            if not isinstance(overrides_raw, dict):
                return {}, {
                    "code": "protocol.invalid_override",
                    "message": "metadata.overrides must be an object",
                }
            unsupported_override_keys = [key for key in overrides_raw.keys() if key not in self._METADATA_ALLOWED_OVERRIDE_KEYS]
            if unsupported_override_keys:
                return {}, {
                    "code": "protocol.invalid_override",
                    "message": f"Unsupported metadata.overrides keys: {', '.join(sorted(unsupported_override_keys))}",
                }
            overrides = dict(overrides_raw)
        # dynamicVariables is passed through as-is here; detailed validation
        # (key pattern, item count, value length) happens elsewhere.
        dynamic_variables = metadata.get("dynamicVariables")
        # history: optional object that may only carry "userId".
        history_raw = metadata.get("history")
        history: Dict[str, Any] = {}
        if history_raw is not None:
            if not isinstance(history_raw, dict):
                return {}, {
                    "code": "protocol.invalid_override",
                    "message": "metadata.history must be an object",
                }
            unsupported_history_keys = [key for key in history_raw.keys() if key != "userId"]
            if unsupported_history_keys:
                return {}, {
                    "code": "protocol.invalid_override",
                    "message": f"Unsupported metadata.history keys: {', '.join(sorted(unsupported_history_keys))}",
                }
            if "userId" in history_raw:
                history["userId"] = history_raw.get("userId")
        # Rebuild the payload from scratch so only vetted fields survive.
        sanitized: Dict[str, Any] = {"overrides": overrides}
        if dynamic_variables is not None:
            sanitized["dynamicVariables"] = dynamic_variables
        if "channel" in metadata:
            sanitized["channel"] = metadata.get("channel")
        if "source" in metadata:
            sanitized["source"] = metadata.get("source")
        if history:
            sanitized["history"] = history
        # Final deep scan: nested secret-like keys anywhere in the sanitized
        # payload (e.g. inside overrides values) are still rejected.
        forbidden_path = self._find_forbidden_secret_key(sanitized)
        if forbidden_path:
            return {}, {
                "code": "protocol.invalid_override",
                "message": f"Forbidden secret-like key detected at {forbidden_path}",
            }
        return sanitized, None
def _build_config_resolved(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
"""Build public resolved config payload (SaaS-safe, no internal runtime details)."""
runtime = self.pipeline.resolved_runtime_config()
runtime_output = runtime.get("output", {}) if isinstance(runtime, dict) else {}
output_mode = str(runtime_output.get("mode") or "").strip().lower() if isinstance(runtime_output, dict) else ""
if output_mode not in {"audio", "text"}:
output_mode = "audio"
output_codec = str(runtime_output.get("codec") or settings.default_codec or "pcm").strip().lower() or "pcm"
tools_allowlist: List[str] = []
runtime_tools = runtime.get("tools", {}) if isinstance(runtime, dict) else {}
if isinstance(runtime_tools, dict):
allowlist = runtime_tools.get("allowlist", [])
if isinstance(allowlist, list):
tools_allowlist = [str(item) for item in allowlist if item is not None and str(item).strip()]
resolved: Dict[str, Any] = {
"protocolVersion": self._public_ws_protocol_version(),
"output": {
"mode": output_mode,
"codec": output_codec,
},
"tools": {
"enabled": bool(tools_allowlist),
"count": len(tools_allowlist),
},
"tracks": {
"audio_in": self.TRACK_AUDIO_IN,
"audio_out": self.TRACK_AUDIO_OUT,
"control": self.TRACK_CONTROL,
},
}
if metadata.get("channel") is not None:
resolved["channel"] = metadata.get("channel")
return resolved
@staticmethod
def _compute_audio_frame_bytes() -> int:
"""Compute expected PCM frame bytes from SAMPLE_RATE and CHUNK_SIZE_MS."""
sample_rate = max(1, int(getattr(settings, "sample_rate", 16000)))
chunk_ms = max(1, int(getattr(settings, "chunk_size_ms", 20)))
bytes_per_frame = int(round(sample_rate * 2 * (chunk_ms / 1000.0)))
if bytes_per_frame < 2:
bytes_per_frame = 2
if bytes_per_frame % 2 != 0:
bytes_per_frame += 1
return bytes_per_frame
@staticmethod
def _public_ws_protocol_version() -> str:
"""Return public protocol version label announced to clients."""
version = str(getattr(settings, "ws_protocol_version", "v1") or "v1").strip()
return version or "v1"
def _extract_json_obj(self, text: str) -> Optional[Dict[str, Any]]:
"""Best-effort extraction of a JSON object from freeform text."""
try:
parsed = json.loads(text)
if isinstance(parsed, dict):
return parsed
except Exception:
pass
match = re.search(r"\{.*\}", text, re.DOTALL)
if not match:
return None
try:
parsed = json.loads(match.group(0))
return parsed if isinstance(parsed, dict) else None
except Exception:
return None