Unify db api

This commit is contained in:
Xin Wang
2026-02-26 01:58:39 +08:00
parent 56f8aa2191
commit 72ed7d0512
40 changed files with 3926 additions and 593 deletions

View File

@@ -1,22 +1,19 @@
"""Session management for active calls."""
import asyncio
import uuid
import hashlib
import json
import time
import re
import time
from enum import Enum
from typing import Optional, Dict, Any, List
from loguru import logger
from app.backend_client import (
create_history_call_record,
add_history_transcript,
finalize_history_call_record,
)
from app.backend_adapters import build_backend_adapter_from_settings
from core.transports import BaseTransport
from core.duplex_pipeline import DuplexPipeline
from core.conversation import ConversationTurn
from core.history_bridge import SessionHistoryBridge
from core.workflow_runner import WorkflowRunner, WorkflowTransition, WorkflowNodeDef, WorkflowEdgeDef
from app.config import settings
from services.base import LLMMessage
@@ -49,7 +46,39 @@ class Session:
Uses full duplex voice conversation pipeline.
"""
def __init__(self, session_id: str, transport: BaseTransport, use_duplex: bool = None):
TRACK_AUDIO_IN = "audio_in"
TRACK_AUDIO_OUT = "audio_out"
TRACK_CONTROL = "control"
AUDIO_FRAME_BYTES = 640 # 16k mono pcm_s16le, 20ms
_CLIENT_METADATA_OVERRIDES = {
"firstTurnMode",
"greeting",
"generatedOpenerEnabled",
"systemPrompt",
"output",
"bargeIn",
"knowledge",
"knowledgeBaseId",
"history",
"userId",
"assistantId",
"source",
}
_CLIENT_METADATA_ID_KEYS = {
"appId",
"app_id",
"channel",
"configVersionId",
"config_version_id",
}
def __init__(
self,
session_id: str,
transport: BaseTransport,
use_duplex: bool = None,
backend_gateway: Optional[Any] = None,
):
"""
Initialize session.
@@ -61,12 +90,23 @@ class Session:
self.id = session_id
self.transport = transport
self.use_duplex = use_duplex if use_duplex is not None else settings.duplex_enabled
self._backend_gateway = backend_gateway or build_backend_adapter_from_settings()
self._history_bridge = SessionHistoryBridge(
history_writer=self._backend_gateway,
enabled=settings.history_enabled,
queue_max_size=settings.history_queue_max_size,
retry_max_attempts=settings.history_retry_max_attempts,
retry_backoff_sec=settings.history_retry_backoff_sec,
finalize_drain_timeout_sec=settings.history_finalize_drain_timeout_sec,
)
self.pipeline = DuplexPipeline(
transport=transport,
session_id=session_id,
system_prompt=settings.duplex_system_prompt,
greeting=settings.duplex_greeting
greeting=settings.duplex_greeting,
knowledge_searcher=getattr(self._backend_gateway, "search_knowledge_context", None),
tool_resource_resolver=getattr(self._backend_gateway, "fetch_tool_resource", None),
)
# Session state
@@ -78,17 +118,15 @@ class Session:
self.authenticated: bool = False
# Track IDs
self.current_track_id: Optional[str] = str(uuid.uuid4())
self._history_call_id: Optional[str] = None
self._history_turn_index: int = 0
self._history_call_started_mono: Optional[float] = None
self._history_finalized: bool = False
self.current_track_id: str = self.TRACK_CONTROL
self._event_seq: int = 0
self._cleanup_lock = asyncio.Lock()
self._cleaned_up = False
self.workflow_runner: Optional[WorkflowRunner] = None
self._workflow_last_user_text: str = ""
self._workflow_initial_node: Optional[WorkflowNodeDef] = None
self.pipeline.set_event_sequence_provider(self._next_event_seq)
self.pipeline.conversation.on_turn_complete(self._on_turn_complete)
logger.info(f"Session {self.id} created (duplex={self.use_duplex})")
@@ -129,13 +167,47 @@ class Session:
"client",
"Audio received before session.start",
"protocol.order",
stage="protocol",
retryable=False,
)
return
try:
await self.pipeline.process_audio(audio_bytes)
if not audio_bytes:
return
if len(audio_bytes) % 2 != 0:
await self._send_error(
"client",
"Invalid PCM payload: odd number of bytes",
"audio.invalid_pcm",
stage="audio",
retryable=False,
)
return
frame_bytes = self.AUDIO_FRAME_BYTES
if len(audio_bytes) % frame_bytes != 0:
await self._send_error(
"client",
f"Audio frame size must be a multiple of {frame_bytes} bytes (20ms PCM)",
"audio.frame_size_mismatch",
stage="audio",
retryable=False,
)
return
for i in range(0, len(audio_bytes), frame_bytes):
frame = audio_bytes[i : i + frame_bytes]
await self.pipeline.process_audio(frame)
except Exception as e:
logger.error(f"Session {self.id} handle_audio error: {e}", exc_info=True)
await self._send_error(
"server",
f"Audio processing failed: {e}",
"audio.processing_failed",
stage="audio",
retryable=True,
)
async def _handle_v1_message(self, message: Any) -> None:
"""Route validated WS v1 message to handlers."""
@@ -176,7 +248,7 @@ class Session:
else:
await self.pipeline.interrupt()
elif isinstance(message, ToolCallResultsMessage):
await self.pipeline.handle_tool_call_results(message.results)
await self.pipeline.handle_tool_call_results([item.model_dump() for item in message.results])
elif isinstance(message, SessionStopMessage):
await self._handle_session_stop(message.reason)
else:
@@ -198,9 +270,9 @@ class Session:
self.ws_state = WsSessionState.STOPPED
return
auth_payload = message.auth or {}
api_key = auth_payload.get("apiKey")
jwt = auth_payload.get("jwt")
auth_payload = message.auth
api_key = auth_payload.apiKey if auth_payload else None
jwt = auth_payload.jwt if auth_payload else None
if settings.ws_api_key:
if api_key != settings.ws_api_key:
@@ -217,10 +289,9 @@ class Session:
self.authenticated = True
self.protocol_version = message.version
self.ws_state = WsSessionState.WAIT_START
await self.transport.send_event(
await self._send_event(
ev(
"hello.ack",
sessionId=self.id,
version=self.protocol_version,
)
)
@@ -231,8 +302,12 @@ class Session:
await self._send_error("client", "Duplicate session.start", "protocol.order")
return
metadata = message.metadata or {}
metadata = self._merge_runtime_metadata(metadata, self._bootstrap_workflow(metadata))
raw_metadata = message.metadata or {}
workflow_runtime = self._bootstrap_workflow(raw_metadata)
server_runtime = await self._load_server_runtime_metadata(raw_metadata, workflow_runtime)
client_runtime = self._sanitize_client_metadata(raw_metadata)
metadata = self._merge_runtime_metadata(server_runtime, self._sanitize_untrusted_runtime_metadata(workflow_runtime))
metadata = self._merge_runtime_metadata(metadata, client_runtime)
# Create history call record early so later turn callbacks can append transcripts.
await self._start_history_bridge(metadata)
@@ -248,28 +323,37 @@ class Session:
self.state = "accepted"
self.ws_state = WsSessionState.ACTIVE
await self.transport.send_event(
await self._send_event(
ev(
"session.started",
sessionId=self.id,
trackId=self.current_track_id,
audio=message.audio or {},
tracks={
"audio_in": self.TRACK_AUDIO_IN,
"audio_out": self.TRACK_AUDIO_OUT,
"control": self.TRACK_CONTROL,
},
audio=message.audio.model_dump() if message.audio else {},
)
)
await self._send_event(
ev(
"config.resolved",
trackId=self.TRACK_CONTROL,
config=self._build_config_resolved(metadata),
)
)
if self.workflow_runner and self._workflow_initial_node:
await self.transport.send_event(
await self._send_event(
ev(
"workflow.started",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id,
workflowName=self.workflow_runner.name,
nodeId=self._workflow_initial_node.id,
)
)
await self.transport.send_event(
await self._send_event(
ev(
"workflow.node.entered",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id,
nodeId=self._workflow_initial_node.id,
nodeName=self._workflow_initial_node.name,
@@ -285,17 +369,24 @@ class Session:
stop_reason = reason or "client_requested"
self.state = "hungup"
self.ws_state = WsSessionState.STOPPED
await self.transport.send_event(
await self._send_event(
ev(
"session.stopped",
sessionId=self.id,
reason=stop_reason,
)
)
await self._finalize_history(status="connected")
await self.transport.close()
async def _send_error(self, sender: str, error_message: str, code: str) -> None:
async def _send_error(
self,
sender: str,
error_message: str,
code: str,
stage: Optional[str] = None,
retryable: Optional[bool] = None,
track_id: Optional[str] = None,
) -> None:
"""
Send error event to client.
@@ -304,13 +395,26 @@ class Session:
error_message: Error message
code: Machine-readable error code
"""
await self.transport.send_event(
resolved_stage = stage or self._infer_error_stage(code)
resolved_retryable = retryable if retryable is not None else (resolved_stage in {"asr", "llm", "tts", "tool", "audio"})
resolved_track_id = track_id or self._error_track_id(resolved_stage, code)
await self._send_event(
ev(
"error",
sender=sender,
code=code,
message=error_message,
trackId=self.current_track_id,
stage=resolved_stage,
retryable=resolved_retryable,
trackId=resolved_track_id,
data={
"error": {
"stage": resolved_stage,
"code": code,
"message": error_message,
"retryable": resolved_retryable,
}
},
)
)
@@ -329,11 +433,12 @@ class Session:
logger.info(f"Session {self.id} cleaning up")
await self._finalize_history(status="connected")
await self.pipeline.cleanup()
await self._history_bridge.shutdown()
await self.transport.close()
async def _start_history_bridge(self, metadata: Dict[str, Any]) -> None:
"""Initialize backend history call record for this session."""
if self._history_call_id:
if self._history_bridge.call_id:
return
history_meta: Dict[str, Any] = {}
@@ -349,7 +454,7 @@ class Session:
assistant_id = history_meta.get("assistantId", metadata.get("assistantId"))
source = str(history_meta.get("source", metadata.get("source", "debug")))
call_id = await create_history_call_record(
call_id = await self._history_bridge.start_call(
user_id=user_id,
assistant_id=str(assistant_id) if assistant_id else None,
source=source,
@@ -357,10 +462,6 @@ class Session:
if not call_id:
return
self._history_call_id = call_id
self._history_call_started_mono = time.monotonic()
self._history_turn_index = 0
self._history_finalized = False
logger.info(f"Session {self.id} history bridge enabled (call_id={call_id}, source={source})")
async def _on_turn_complete(self, turn: ConversationTurn) -> None:
@@ -372,48 +473,11 @@ class Session:
elif role == "assistant":
await self._maybe_advance_workflow(turn.text.strip())
if not self._history_call_id:
return
if not turn.text or not turn.text.strip():
return
role = (turn.role or "").lower()
speaker = "human" if role == "user" else "ai"
end_ms = 0
if self._history_call_started_mono is not None:
end_ms = max(0, int((time.monotonic() - self._history_call_started_mono) * 1000))
estimated_duration_ms = max(300, min(12000, len(turn.text.strip()) * 80))
start_ms = max(0, end_ms - estimated_duration_ms)
turn_index = self._history_turn_index
await add_history_transcript(
call_id=self._history_call_id,
turn_index=turn_index,
speaker=speaker,
content=turn.text.strip(),
start_ms=start_ms,
end_ms=end_ms,
duration_ms=max(1, end_ms - start_ms),
)
self._history_turn_index += 1
self._history_bridge.enqueue_turn(role=turn.role or "", text=turn.text or "")
async def _finalize_history(self, status: str) -> None:
"""Finalize history call record once."""
if not self._history_call_id or self._history_finalized:
return
duration_seconds = 0
if self._history_call_started_mono is not None:
duration_seconds = max(0, int(time.monotonic() - self._history_call_started_mono))
ok = await finalize_history_call_record(
call_id=self._history_call_id,
status=status,
duration_seconds=duration_seconds,
)
if ok:
self._history_finalized = True
await self._history_bridge.finalize(status=status)
def _bootstrap_workflow(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
"""Parse workflow payload and return initial runtime overrides."""
@@ -483,10 +547,9 @@ class Session:
node = transition.node
edge = transition.edge
await self.transport.send_event(
await self._send_event(
ev(
"workflow.edge.taken",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id,
edgeId=edge.id,
fromNodeId=edge.from_node_id,
@@ -494,10 +557,9 @@ class Session:
reason=reason,
)
)
await self.transport.send_event(
await self._send_event(
ev(
"workflow.node.entered",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id,
nodeId=node.id,
nodeName=node.name,
@@ -510,10 +572,9 @@ class Session:
self.pipeline.apply_runtime_overrides(node_runtime)
if node.node_type == "tool":
await self.transport.send_event(
await self._send_event(
ev(
"workflow.tool.requested",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id,
nodeId=node.id,
tool=node.tool or {},
@@ -522,10 +583,9 @@ class Session:
return
if node.node_type == "human_transfer":
await self.transport.send_event(
await self._send_event(
ev(
"workflow.human_transfer",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id,
nodeId=node.id,
)
@@ -534,16 +594,77 @@ class Session:
return
if node.node_type == "end":
await self.transport.send_event(
await self._send_event(
ev(
"workflow.ended",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id,
nodeId=node.id,
)
)
await self._handle_session_stop("workflow_end")
def _next_event_seq(self) -> int:
self._event_seq += 1
return self._event_seq
def _event_source(self, event_type: str) -> str:
if event_type.startswith("workflow."):
return "system"
if event_type.startswith("session.") or event_type.startswith("hello.") or event_type == "heartbeat":
return "system"
if event_type == "error":
return "system"
return "system"
def _infer_error_stage(self, code: str) -> str:
normalized = str(code or "").strip().lower()
if normalized.startswith("audio."):
return "audio"
if normalized.startswith("tool."):
return "tool"
if normalized.startswith("asr."):
return "asr"
if normalized.startswith("llm."):
return "llm"
if normalized.startswith("tts."):
return "tts"
return "protocol"
def _error_track_id(self, stage: str, code: str) -> str:
if stage in {"audio", "asr"}:
return self.TRACK_AUDIO_IN
if stage in {"llm", "tts", "tool"}:
return self.TRACK_AUDIO_OUT
if str(code or "").strip().lower().startswith("auth."):
return self.TRACK_CONTROL
return self.TRACK_CONTROL
def _envelope_event(self, event: Dict[str, Any]) -> Dict[str, Any]:
event_type = str(event.get("type") or "")
source = str(event.get("source") or self._event_source(event_type))
track_id = event.get("trackId") or self.TRACK_CONTROL
data = event.get("data")
if not isinstance(data, dict):
data = {}
for k, v in event.items():
if k in {"type", "timestamp", "sessionId", "seq", "source", "trackId", "data"}:
continue
data.setdefault(k, v)
event["sessionId"] = self.id
event["seq"] = self._next_event_seq()
event["source"] = source
event["trackId"] = track_id
event["data"] = data
return event
async def _send_event(self, event: Dict[str, Any]) -> None:
await self.transport.send_event(self._envelope_event(event))
async def send_heartbeat(self) -> None:
await self._send_event(ev("heartbeat", trackId=self.TRACK_CONTROL))
async def _workflow_llm_route(
self,
node: WorkflowNodeDef,
@@ -629,6 +750,137 @@ class Session:
merged[key] = value
return merged
async def _load_server_runtime_metadata(
self,
client_metadata: Dict[str, Any],
workflow_runtime: Dict[str, Any],
) -> Dict[str, Any]:
"""Load trusted runtime metadata from backend assistant config."""
assistant_id = (
workflow_runtime.get("assistantId")
or client_metadata.get("assistantId")
or client_metadata.get("appId")
or client_metadata.get("app_id")
)
if assistant_id is None:
return {}
provider = getattr(self._backend_gateway, "fetch_assistant_config", None)
if not callable(provider):
return {}
payload = await provider(str(assistant_id).strip())
if not isinstance(payload, dict):
return {}
assistant_cfg: Dict[str, Any] = {}
session_start_cfg = payload.get("sessionStartMetadata")
if isinstance(session_start_cfg, dict):
assistant_cfg.update(session_start_cfg)
if isinstance(payload.get("assistant"), dict):
assistant_cfg.update(payload.get("assistant"))
elif not assistant_cfg:
assistant_cfg = payload
if not isinstance(assistant_cfg, dict):
return {}
runtime: Dict[str, Any] = {}
passthrough_keys = {
"firstTurnMode",
"generatedOpenerEnabled",
"output",
"bargeIn",
"knowledgeBaseId",
"knowledge",
"history",
"userId",
"source",
"tools",
"services",
"configVersionId",
"config_version_id",
}
for key in passthrough_keys:
if key in assistant_cfg:
runtime[key] = assistant_cfg[key]
if assistant_cfg.get("systemPrompt") is not None:
runtime["systemPrompt"] = str(assistant_cfg.get("systemPrompt") or "")
elif assistant_cfg.get("prompt") is not None:
runtime["systemPrompt"] = str(assistant_cfg.get("prompt") or "")
if assistant_cfg.get("greeting") is not None:
runtime["greeting"] = assistant_cfg.get("greeting")
elif assistant_cfg.get("opener") is not None:
runtime["greeting"] = assistant_cfg.get("opener")
resolved_assistant_id = (
assistant_cfg.get("assistantId")
or payload.get("assistantId")
or assistant_id
)
runtime["assistantId"] = str(resolved_assistant_id)
if runtime.get("configVersionId") is None and payload.get("configVersionId") is not None:
runtime["configVersionId"] = payload.get("configVersionId")
if runtime.get("configVersionId") is None and payload.get("config_version_id") is not None:
runtime["configVersionId"] = payload.get("config_version_id")
if runtime.get("config_version_id") is not None and runtime.get("configVersionId") is None:
runtime["configVersionId"] = runtime.get("config_version_id")
return runtime
def _sanitize_untrusted_runtime_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
"""
Sanitize untrusted metadata sources.
This keeps only a small override whitelist and stable config ID fields.
"""
if not isinstance(metadata, dict):
return {}
sanitized: Dict[str, Any] = {}
for key in self._CLIENT_METADATA_ID_KEYS:
if key in metadata:
sanitized[key] = metadata[key]
for key in self._CLIENT_METADATA_OVERRIDES:
if key in metadata:
sanitized[key] = metadata[key]
return sanitized
def _sanitize_client_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
"""Apply client metadata whitelist and remove forbidden secrets."""
sanitized = self._sanitize_untrusted_runtime_metadata(metadata)
if isinstance(metadata.get("services"), dict):
logger.warning(
"Session {} provided metadata.services from client; client-side service config is ignored",
self.id,
)
return sanitized
def _build_config_resolved(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
"""Build public resolved config payload (secrets removed)."""
system_prompt = str(metadata.get("systemPrompt") or self.pipeline.conversation.system_prompt or "")
prompt_hash = hashlib.sha256(system_prompt.encode("utf-8")).hexdigest() if system_prompt else None
runtime = self.pipeline.resolved_runtime_config()
return {
"appId": metadata.get("appId") or metadata.get("app_id") or metadata.get("assistantId"),
"channel": metadata.get("channel"),
"configVersionId": metadata.get("configVersionId") or metadata.get("config_version_id"),
"prompt": {"sha256": prompt_hash},
"output": runtime.get("output", {}),
"services": runtime.get("services", {}),
"tools": runtime.get("tools", {}),
"tracks": {
"audio_in": self.TRACK_AUDIO_IN,
"audio_out": self.TRACK_AUDIO_OUT,
"control": self.TRACK_CONTROL,
},
}
def _extract_json_obj(self, text: str) -> Optional[Dict[str, Any]]:
"""Best-effort extraction of a JSON object from freeform text."""
try: