Unify db api
This commit is contained in:
@@ -1,22 +1,19 @@
|
||||
"""Session management for active calls."""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
import hashlib
|
||||
import json
|
||||
import time
|
||||
import re
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Optional, Dict, Any, List
|
||||
from loguru import logger
|
||||
|
||||
from app.backend_client import (
|
||||
create_history_call_record,
|
||||
add_history_transcript,
|
||||
finalize_history_call_record,
|
||||
)
|
||||
from app.backend_adapters import build_backend_adapter_from_settings
|
||||
from core.transports import BaseTransport
|
||||
from core.duplex_pipeline import DuplexPipeline
|
||||
from core.conversation import ConversationTurn
|
||||
from core.history_bridge import SessionHistoryBridge
|
||||
from core.workflow_runner import WorkflowRunner, WorkflowTransition, WorkflowNodeDef, WorkflowEdgeDef
|
||||
from app.config import settings
|
||||
from services.base import LLMMessage
|
||||
@@ -49,7 +46,39 @@ class Session:
|
||||
Uses full duplex voice conversation pipeline.
|
||||
"""
|
||||
|
||||
def __init__(self, session_id: str, transport: BaseTransport, use_duplex: bool = None):
|
||||
TRACK_AUDIO_IN = "audio_in"
|
||||
TRACK_AUDIO_OUT = "audio_out"
|
||||
TRACK_CONTROL = "control"
|
||||
AUDIO_FRAME_BYTES = 640 # 16k mono pcm_s16le, 20ms
|
||||
_CLIENT_METADATA_OVERRIDES = {
|
||||
"firstTurnMode",
|
||||
"greeting",
|
||||
"generatedOpenerEnabled",
|
||||
"systemPrompt",
|
||||
"output",
|
||||
"bargeIn",
|
||||
"knowledge",
|
||||
"knowledgeBaseId",
|
||||
"history",
|
||||
"userId",
|
||||
"assistantId",
|
||||
"source",
|
||||
}
|
||||
_CLIENT_METADATA_ID_KEYS = {
|
||||
"appId",
|
||||
"app_id",
|
||||
"channel",
|
||||
"configVersionId",
|
||||
"config_version_id",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
transport: BaseTransport,
|
||||
use_duplex: bool = None,
|
||||
backend_gateway: Optional[Any] = None,
|
||||
):
|
||||
"""
|
||||
Initialize session.
|
||||
|
||||
@@ -61,12 +90,23 @@ class Session:
|
||||
self.id = session_id
|
||||
self.transport = transport
|
||||
self.use_duplex = use_duplex if use_duplex is not None else settings.duplex_enabled
|
||||
self._backend_gateway = backend_gateway or build_backend_adapter_from_settings()
|
||||
self._history_bridge = SessionHistoryBridge(
|
||||
history_writer=self._backend_gateway,
|
||||
enabled=settings.history_enabled,
|
||||
queue_max_size=settings.history_queue_max_size,
|
||||
retry_max_attempts=settings.history_retry_max_attempts,
|
||||
retry_backoff_sec=settings.history_retry_backoff_sec,
|
||||
finalize_drain_timeout_sec=settings.history_finalize_drain_timeout_sec,
|
||||
)
|
||||
|
||||
self.pipeline = DuplexPipeline(
|
||||
transport=transport,
|
||||
session_id=session_id,
|
||||
system_prompt=settings.duplex_system_prompt,
|
||||
greeting=settings.duplex_greeting
|
||||
greeting=settings.duplex_greeting,
|
||||
knowledge_searcher=getattr(self._backend_gateway, "search_knowledge_context", None),
|
||||
tool_resource_resolver=getattr(self._backend_gateway, "fetch_tool_resource", None),
|
||||
)
|
||||
|
||||
# Session state
|
||||
@@ -78,17 +118,15 @@ class Session:
|
||||
self.authenticated: bool = False
|
||||
|
||||
# Track IDs
|
||||
self.current_track_id: Optional[str] = str(uuid.uuid4())
|
||||
self._history_call_id: Optional[str] = None
|
||||
self._history_turn_index: int = 0
|
||||
self._history_call_started_mono: Optional[float] = None
|
||||
self._history_finalized: bool = False
|
||||
self.current_track_id: str = self.TRACK_CONTROL
|
||||
self._event_seq: int = 0
|
||||
self._cleanup_lock = asyncio.Lock()
|
||||
self._cleaned_up = False
|
||||
self.workflow_runner: Optional[WorkflowRunner] = None
|
||||
self._workflow_last_user_text: str = ""
|
||||
self._workflow_initial_node: Optional[WorkflowNodeDef] = None
|
||||
|
||||
self.pipeline.set_event_sequence_provider(self._next_event_seq)
|
||||
self.pipeline.conversation.on_turn_complete(self._on_turn_complete)
|
||||
|
||||
logger.info(f"Session {self.id} created (duplex={self.use_duplex})")
|
||||
@@ -129,13 +167,47 @@ class Session:
|
||||
"client",
|
||||
"Audio received before session.start",
|
||||
"protocol.order",
|
||||
stage="protocol",
|
||||
retryable=False,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
await self.pipeline.process_audio(audio_bytes)
|
||||
if not audio_bytes:
|
||||
return
|
||||
if len(audio_bytes) % 2 != 0:
|
||||
await self._send_error(
|
||||
"client",
|
||||
"Invalid PCM payload: odd number of bytes",
|
||||
"audio.invalid_pcm",
|
||||
stage="audio",
|
||||
retryable=False,
|
||||
)
|
||||
return
|
||||
|
||||
frame_bytes = self.AUDIO_FRAME_BYTES
|
||||
if len(audio_bytes) % frame_bytes != 0:
|
||||
await self._send_error(
|
||||
"client",
|
||||
f"Audio frame size must be a multiple of {frame_bytes} bytes (20ms PCM)",
|
||||
"audio.frame_size_mismatch",
|
||||
stage="audio",
|
||||
retryable=False,
|
||||
)
|
||||
return
|
||||
|
||||
for i in range(0, len(audio_bytes), frame_bytes):
|
||||
frame = audio_bytes[i : i + frame_bytes]
|
||||
await self.pipeline.process_audio(frame)
|
||||
except Exception as e:
|
||||
logger.error(f"Session {self.id} handle_audio error: {e}", exc_info=True)
|
||||
await self._send_error(
|
||||
"server",
|
||||
f"Audio processing failed: {e}",
|
||||
"audio.processing_failed",
|
||||
stage="audio",
|
||||
retryable=True,
|
||||
)
|
||||
|
||||
async def _handle_v1_message(self, message: Any) -> None:
|
||||
"""Route validated WS v1 message to handlers."""
|
||||
@@ -176,7 +248,7 @@ class Session:
|
||||
else:
|
||||
await self.pipeline.interrupt()
|
||||
elif isinstance(message, ToolCallResultsMessage):
|
||||
await self.pipeline.handle_tool_call_results(message.results)
|
||||
await self.pipeline.handle_tool_call_results([item.model_dump() for item in message.results])
|
||||
elif isinstance(message, SessionStopMessage):
|
||||
await self._handle_session_stop(message.reason)
|
||||
else:
|
||||
@@ -198,9 +270,9 @@ class Session:
|
||||
self.ws_state = WsSessionState.STOPPED
|
||||
return
|
||||
|
||||
auth_payload = message.auth or {}
|
||||
api_key = auth_payload.get("apiKey")
|
||||
jwt = auth_payload.get("jwt")
|
||||
auth_payload = message.auth
|
||||
api_key = auth_payload.apiKey if auth_payload else None
|
||||
jwt = auth_payload.jwt if auth_payload else None
|
||||
|
||||
if settings.ws_api_key:
|
||||
if api_key != settings.ws_api_key:
|
||||
@@ -217,10 +289,9 @@ class Session:
|
||||
self.authenticated = True
|
||||
self.protocol_version = message.version
|
||||
self.ws_state = WsSessionState.WAIT_START
|
||||
await self.transport.send_event(
|
||||
await self._send_event(
|
||||
ev(
|
||||
"hello.ack",
|
||||
sessionId=self.id,
|
||||
version=self.protocol_version,
|
||||
)
|
||||
)
|
||||
@@ -231,8 +302,12 @@ class Session:
|
||||
await self._send_error("client", "Duplicate session.start", "protocol.order")
|
||||
return
|
||||
|
||||
metadata = message.metadata or {}
|
||||
metadata = self._merge_runtime_metadata(metadata, self._bootstrap_workflow(metadata))
|
||||
raw_metadata = message.metadata or {}
|
||||
workflow_runtime = self._bootstrap_workflow(raw_metadata)
|
||||
server_runtime = await self._load_server_runtime_metadata(raw_metadata, workflow_runtime)
|
||||
client_runtime = self._sanitize_client_metadata(raw_metadata)
|
||||
metadata = self._merge_runtime_metadata(server_runtime, self._sanitize_untrusted_runtime_metadata(workflow_runtime))
|
||||
metadata = self._merge_runtime_metadata(metadata, client_runtime)
|
||||
|
||||
# Create history call record early so later turn callbacks can append transcripts.
|
||||
await self._start_history_bridge(metadata)
|
||||
@@ -248,28 +323,37 @@ class Session:
|
||||
|
||||
self.state = "accepted"
|
||||
self.ws_state = WsSessionState.ACTIVE
|
||||
await self.transport.send_event(
|
||||
await self._send_event(
|
||||
ev(
|
||||
"session.started",
|
||||
sessionId=self.id,
|
||||
trackId=self.current_track_id,
|
||||
audio=message.audio or {},
|
||||
tracks={
|
||||
"audio_in": self.TRACK_AUDIO_IN,
|
||||
"audio_out": self.TRACK_AUDIO_OUT,
|
||||
"control": self.TRACK_CONTROL,
|
||||
},
|
||||
audio=message.audio.model_dump() if message.audio else {},
|
||||
)
|
||||
)
|
||||
await self._send_event(
|
||||
ev(
|
||||
"config.resolved",
|
||||
trackId=self.TRACK_CONTROL,
|
||||
config=self._build_config_resolved(metadata),
|
||||
)
|
||||
)
|
||||
if self.workflow_runner and self._workflow_initial_node:
|
||||
await self.transport.send_event(
|
||||
await self._send_event(
|
||||
ev(
|
||||
"workflow.started",
|
||||
sessionId=self.id,
|
||||
workflowId=self.workflow_runner.workflow_id,
|
||||
workflowName=self.workflow_runner.name,
|
||||
nodeId=self._workflow_initial_node.id,
|
||||
)
|
||||
)
|
||||
await self.transport.send_event(
|
||||
await self._send_event(
|
||||
ev(
|
||||
"workflow.node.entered",
|
||||
sessionId=self.id,
|
||||
workflowId=self.workflow_runner.workflow_id,
|
||||
nodeId=self._workflow_initial_node.id,
|
||||
nodeName=self._workflow_initial_node.name,
|
||||
@@ -285,17 +369,24 @@ class Session:
|
||||
stop_reason = reason or "client_requested"
|
||||
self.state = "hungup"
|
||||
self.ws_state = WsSessionState.STOPPED
|
||||
await self.transport.send_event(
|
||||
await self._send_event(
|
||||
ev(
|
||||
"session.stopped",
|
||||
sessionId=self.id,
|
||||
reason=stop_reason,
|
||||
)
|
||||
)
|
||||
await self._finalize_history(status="connected")
|
||||
await self.transport.close()
|
||||
|
||||
async def _send_error(self, sender: str, error_message: str, code: str) -> None:
|
||||
async def _send_error(
|
||||
self,
|
||||
sender: str,
|
||||
error_message: str,
|
||||
code: str,
|
||||
stage: Optional[str] = None,
|
||||
retryable: Optional[bool] = None,
|
||||
track_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Send error event to client.
|
||||
|
||||
@@ -304,13 +395,26 @@ class Session:
|
||||
error_message: Error message
|
||||
code: Machine-readable error code
|
||||
"""
|
||||
await self.transport.send_event(
|
||||
resolved_stage = stage or self._infer_error_stage(code)
|
||||
resolved_retryable = retryable if retryable is not None else (resolved_stage in {"asr", "llm", "tts", "tool", "audio"})
|
||||
resolved_track_id = track_id or self._error_track_id(resolved_stage, code)
|
||||
await self._send_event(
|
||||
ev(
|
||||
"error",
|
||||
sender=sender,
|
||||
code=code,
|
||||
message=error_message,
|
||||
trackId=self.current_track_id,
|
||||
stage=resolved_stage,
|
||||
retryable=resolved_retryable,
|
||||
trackId=resolved_track_id,
|
||||
data={
|
||||
"error": {
|
||||
"stage": resolved_stage,
|
||||
"code": code,
|
||||
"message": error_message,
|
||||
"retryable": resolved_retryable,
|
||||
}
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
@@ -329,11 +433,12 @@ class Session:
|
||||
logger.info(f"Session {self.id} cleaning up")
|
||||
await self._finalize_history(status="connected")
|
||||
await self.pipeline.cleanup()
|
||||
await self._history_bridge.shutdown()
|
||||
await self.transport.close()
|
||||
|
||||
async def _start_history_bridge(self, metadata: Dict[str, Any]) -> None:
|
||||
"""Initialize backend history call record for this session."""
|
||||
if self._history_call_id:
|
||||
if self._history_bridge.call_id:
|
||||
return
|
||||
|
||||
history_meta: Dict[str, Any] = {}
|
||||
@@ -349,7 +454,7 @@ class Session:
|
||||
assistant_id = history_meta.get("assistantId", metadata.get("assistantId"))
|
||||
source = str(history_meta.get("source", metadata.get("source", "debug")))
|
||||
|
||||
call_id = await create_history_call_record(
|
||||
call_id = await self._history_bridge.start_call(
|
||||
user_id=user_id,
|
||||
assistant_id=str(assistant_id) if assistant_id else None,
|
||||
source=source,
|
||||
@@ -357,10 +462,6 @@ class Session:
|
||||
if not call_id:
|
||||
return
|
||||
|
||||
self._history_call_id = call_id
|
||||
self._history_call_started_mono = time.monotonic()
|
||||
self._history_turn_index = 0
|
||||
self._history_finalized = False
|
||||
logger.info(f"Session {self.id} history bridge enabled (call_id={call_id}, source={source})")
|
||||
|
||||
async def _on_turn_complete(self, turn: ConversationTurn) -> None:
|
||||
@@ -372,48 +473,11 @@ class Session:
|
||||
elif role == "assistant":
|
||||
await self._maybe_advance_workflow(turn.text.strip())
|
||||
|
||||
if not self._history_call_id:
|
||||
return
|
||||
if not turn.text or not turn.text.strip():
|
||||
return
|
||||
|
||||
role = (turn.role or "").lower()
|
||||
speaker = "human" if role == "user" else "ai"
|
||||
|
||||
end_ms = 0
|
||||
if self._history_call_started_mono is not None:
|
||||
end_ms = max(0, int((time.monotonic() - self._history_call_started_mono) * 1000))
|
||||
estimated_duration_ms = max(300, min(12000, len(turn.text.strip()) * 80))
|
||||
start_ms = max(0, end_ms - estimated_duration_ms)
|
||||
|
||||
turn_index = self._history_turn_index
|
||||
await add_history_transcript(
|
||||
call_id=self._history_call_id,
|
||||
turn_index=turn_index,
|
||||
speaker=speaker,
|
||||
content=turn.text.strip(),
|
||||
start_ms=start_ms,
|
||||
end_ms=end_ms,
|
||||
duration_ms=max(1, end_ms - start_ms),
|
||||
)
|
||||
self._history_turn_index += 1
|
||||
self._history_bridge.enqueue_turn(role=turn.role or "", text=turn.text or "")
|
||||
|
||||
async def _finalize_history(self, status: str) -> None:
|
||||
"""Finalize history call record once."""
|
||||
if not self._history_call_id or self._history_finalized:
|
||||
return
|
||||
|
||||
duration_seconds = 0
|
||||
if self._history_call_started_mono is not None:
|
||||
duration_seconds = max(0, int(time.monotonic() - self._history_call_started_mono))
|
||||
|
||||
ok = await finalize_history_call_record(
|
||||
call_id=self._history_call_id,
|
||||
status=status,
|
||||
duration_seconds=duration_seconds,
|
||||
)
|
||||
if ok:
|
||||
self._history_finalized = True
|
||||
await self._history_bridge.finalize(status=status)
|
||||
|
||||
def _bootstrap_workflow(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Parse workflow payload and return initial runtime overrides."""
|
||||
@@ -483,10 +547,9 @@ class Session:
|
||||
node = transition.node
|
||||
edge = transition.edge
|
||||
|
||||
await self.transport.send_event(
|
||||
await self._send_event(
|
||||
ev(
|
||||
"workflow.edge.taken",
|
||||
sessionId=self.id,
|
||||
workflowId=self.workflow_runner.workflow_id,
|
||||
edgeId=edge.id,
|
||||
fromNodeId=edge.from_node_id,
|
||||
@@ -494,10 +557,9 @@ class Session:
|
||||
reason=reason,
|
||||
)
|
||||
)
|
||||
await self.transport.send_event(
|
||||
await self._send_event(
|
||||
ev(
|
||||
"workflow.node.entered",
|
||||
sessionId=self.id,
|
||||
workflowId=self.workflow_runner.workflow_id,
|
||||
nodeId=node.id,
|
||||
nodeName=node.name,
|
||||
@@ -510,10 +572,9 @@ class Session:
|
||||
self.pipeline.apply_runtime_overrides(node_runtime)
|
||||
|
||||
if node.node_type == "tool":
|
||||
await self.transport.send_event(
|
||||
await self._send_event(
|
||||
ev(
|
||||
"workflow.tool.requested",
|
||||
sessionId=self.id,
|
||||
workflowId=self.workflow_runner.workflow_id,
|
||||
nodeId=node.id,
|
||||
tool=node.tool or {},
|
||||
@@ -522,10 +583,9 @@ class Session:
|
||||
return
|
||||
|
||||
if node.node_type == "human_transfer":
|
||||
await self.transport.send_event(
|
||||
await self._send_event(
|
||||
ev(
|
||||
"workflow.human_transfer",
|
||||
sessionId=self.id,
|
||||
workflowId=self.workflow_runner.workflow_id,
|
||||
nodeId=node.id,
|
||||
)
|
||||
@@ -534,16 +594,77 @@ class Session:
|
||||
return
|
||||
|
||||
if node.node_type == "end":
|
||||
await self.transport.send_event(
|
||||
await self._send_event(
|
||||
ev(
|
||||
"workflow.ended",
|
||||
sessionId=self.id,
|
||||
workflowId=self.workflow_runner.workflow_id,
|
||||
nodeId=node.id,
|
||||
)
|
||||
)
|
||||
await self._handle_session_stop("workflow_end")
|
||||
|
||||
def _next_event_seq(self) -> int:
|
||||
self._event_seq += 1
|
||||
return self._event_seq
|
||||
|
||||
def _event_source(self, event_type: str) -> str:
|
||||
if event_type.startswith("workflow."):
|
||||
return "system"
|
||||
if event_type.startswith("session.") or event_type.startswith("hello.") or event_type == "heartbeat":
|
||||
return "system"
|
||||
if event_type == "error":
|
||||
return "system"
|
||||
return "system"
|
||||
|
||||
def _infer_error_stage(self, code: str) -> str:
|
||||
normalized = str(code or "").strip().lower()
|
||||
if normalized.startswith("audio."):
|
||||
return "audio"
|
||||
if normalized.startswith("tool."):
|
||||
return "tool"
|
||||
if normalized.startswith("asr."):
|
||||
return "asr"
|
||||
if normalized.startswith("llm."):
|
||||
return "llm"
|
||||
if normalized.startswith("tts."):
|
||||
return "tts"
|
||||
return "protocol"
|
||||
|
||||
def _error_track_id(self, stage: str, code: str) -> str:
|
||||
if stage in {"audio", "asr"}:
|
||||
return self.TRACK_AUDIO_IN
|
||||
if stage in {"llm", "tts", "tool"}:
|
||||
return self.TRACK_AUDIO_OUT
|
||||
if str(code or "").strip().lower().startswith("auth."):
|
||||
return self.TRACK_CONTROL
|
||||
return self.TRACK_CONTROL
|
||||
|
||||
def _envelope_event(self, event: Dict[str, Any]) -> Dict[str, Any]:
|
||||
event_type = str(event.get("type") or "")
|
||||
source = str(event.get("source") or self._event_source(event_type))
|
||||
track_id = event.get("trackId") or self.TRACK_CONTROL
|
||||
|
||||
data = event.get("data")
|
||||
if not isinstance(data, dict):
|
||||
data = {}
|
||||
for k, v in event.items():
|
||||
if k in {"type", "timestamp", "sessionId", "seq", "source", "trackId", "data"}:
|
||||
continue
|
||||
data.setdefault(k, v)
|
||||
|
||||
event["sessionId"] = self.id
|
||||
event["seq"] = self._next_event_seq()
|
||||
event["source"] = source
|
||||
event["trackId"] = track_id
|
||||
event["data"] = data
|
||||
return event
|
||||
|
||||
async def _send_event(self, event: Dict[str, Any]) -> None:
|
||||
await self.transport.send_event(self._envelope_event(event))
|
||||
|
||||
async def send_heartbeat(self) -> None:
    """Build a ``heartbeat`` event tagged with the control track and send it."""
    heartbeat = ev("heartbeat", trackId=self.TRACK_CONTROL)
    await self._send_event(heartbeat)
|
||||
|
||||
async def _workflow_llm_route(
|
||||
self,
|
||||
node: WorkflowNodeDef,
|
||||
@@ -629,6 +750,137 @@ class Session:
|
||||
merged[key] = value
|
||||
return merged
|
||||
|
||||
async def _load_server_runtime_metadata(
|
||||
self,
|
||||
client_metadata: Dict[str, Any],
|
||||
workflow_runtime: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""Load trusted runtime metadata from backend assistant config."""
|
||||
assistant_id = (
|
||||
workflow_runtime.get("assistantId")
|
||||
or client_metadata.get("assistantId")
|
||||
or client_metadata.get("appId")
|
||||
or client_metadata.get("app_id")
|
||||
)
|
||||
if assistant_id is None:
|
||||
return {}
|
||||
|
||||
provider = getattr(self._backend_gateway, "fetch_assistant_config", None)
|
||||
if not callable(provider):
|
||||
return {}
|
||||
|
||||
payload = await provider(str(assistant_id).strip())
|
||||
if not isinstance(payload, dict):
|
||||
return {}
|
||||
|
||||
assistant_cfg: Dict[str, Any] = {}
|
||||
session_start_cfg = payload.get("sessionStartMetadata")
|
||||
if isinstance(session_start_cfg, dict):
|
||||
assistant_cfg.update(session_start_cfg)
|
||||
if isinstance(payload.get("assistant"), dict):
|
||||
assistant_cfg.update(payload.get("assistant"))
|
||||
elif not assistant_cfg:
|
||||
assistant_cfg = payload
|
||||
|
||||
if not isinstance(assistant_cfg, dict):
|
||||
return {}
|
||||
|
||||
runtime: Dict[str, Any] = {}
|
||||
passthrough_keys = {
|
||||
"firstTurnMode",
|
||||
"generatedOpenerEnabled",
|
||||
"output",
|
||||
"bargeIn",
|
||||
"knowledgeBaseId",
|
||||
"knowledge",
|
||||
"history",
|
||||
"userId",
|
||||
"source",
|
||||
"tools",
|
||||
"services",
|
||||
"configVersionId",
|
||||
"config_version_id",
|
||||
}
|
||||
for key in passthrough_keys:
|
||||
if key in assistant_cfg:
|
||||
runtime[key] = assistant_cfg[key]
|
||||
|
||||
if assistant_cfg.get("systemPrompt") is not None:
|
||||
runtime["systemPrompt"] = str(assistant_cfg.get("systemPrompt") or "")
|
||||
elif assistant_cfg.get("prompt") is not None:
|
||||
runtime["systemPrompt"] = str(assistant_cfg.get("prompt") or "")
|
||||
|
||||
if assistant_cfg.get("greeting") is not None:
|
||||
runtime["greeting"] = assistant_cfg.get("greeting")
|
||||
elif assistant_cfg.get("opener") is not None:
|
||||
runtime["greeting"] = assistant_cfg.get("opener")
|
||||
|
||||
resolved_assistant_id = (
|
||||
assistant_cfg.get("assistantId")
|
||||
or payload.get("assistantId")
|
||||
or assistant_id
|
||||
)
|
||||
runtime["assistantId"] = str(resolved_assistant_id)
|
||||
|
||||
if runtime.get("configVersionId") is None and payload.get("configVersionId") is not None:
|
||||
runtime["configVersionId"] = payload.get("configVersionId")
|
||||
if runtime.get("configVersionId") is None and payload.get("config_version_id") is not None:
|
||||
runtime["configVersionId"] = payload.get("config_version_id")
|
||||
|
||||
if runtime.get("config_version_id") is not None and runtime.get("configVersionId") is None:
|
||||
runtime["configVersionId"] = runtime.get("config_version_id")
|
||||
return runtime
|
||||
|
||||
def _sanitize_untrusted_runtime_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Sanitize untrusted metadata sources.
|
||||
|
||||
This keeps only a small override whitelist and stable config ID fields.
|
||||
"""
|
||||
if not isinstance(metadata, dict):
|
||||
return {}
|
||||
|
||||
sanitized: Dict[str, Any] = {}
|
||||
for key in self._CLIENT_METADATA_ID_KEYS:
|
||||
if key in metadata:
|
||||
sanitized[key] = metadata[key]
|
||||
for key in self._CLIENT_METADATA_OVERRIDES:
|
||||
if key in metadata:
|
||||
sanitized[key] = metadata[key]
|
||||
|
||||
return sanitized
|
||||
|
||||
def _sanitize_client_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Apply client metadata whitelist and remove forbidden secrets."""
|
||||
sanitized = self._sanitize_untrusted_runtime_metadata(metadata)
|
||||
if isinstance(metadata.get("services"), dict):
|
||||
logger.warning(
|
||||
"Session {} provided metadata.services from client; client-side service config is ignored",
|
||||
self.id,
|
||||
)
|
||||
return sanitized
|
||||
|
||||
def _build_config_resolved(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Build public resolved config payload (secrets removed)."""
|
||||
system_prompt = str(metadata.get("systemPrompt") or self.pipeline.conversation.system_prompt or "")
|
||||
prompt_hash = hashlib.sha256(system_prompt.encode("utf-8")).hexdigest() if system_prompt else None
|
||||
runtime = self.pipeline.resolved_runtime_config()
|
||||
|
||||
return {
|
||||
"appId": metadata.get("appId") or metadata.get("app_id") or metadata.get("assistantId"),
|
||||
"channel": metadata.get("channel"),
|
||||
"configVersionId": metadata.get("configVersionId") or metadata.get("config_version_id"),
|
||||
"prompt": {"sha256": prompt_hash},
|
||||
"output": runtime.get("output", {}),
|
||||
"services": runtime.get("services", {}),
|
||||
"tools": runtime.get("tools", {}),
|
||||
"tracks": {
|
||||
"audio_in": self.TRACK_AUDIO_IN,
|
||||
"audio_out": self.TRACK_AUDIO_OUT,
|
||||
"control": self.TRACK_CONTROL,
|
||||
},
|
||||
}
|
||||
|
||||
def _extract_json_obj(self, text: str) -> Optional[Dict[str, Any]]:
|
||||
"""Best-effort extraction of a JSON object from freeform text."""
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user