# Files
# AI-VideoAssistant/engine/core/session.py
# 2026-02-09 15:49:00 +08:00
#
# 380 lines
# 13 KiB
# Python

"""Session management for active calls."""
import asyncio
import hmac
import json
import time
import uuid
from enum import Enum
from typing import Optional, Dict, Any

from loguru import logger

from app.backend_client import (
    create_history_call_record,
    add_history_transcript,
    finalize_history_call_record,
)
from core.transports import BaseTransport
from core.duplex_pipeline import DuplexPipeline
from core.conversation import ConversationTurn
from app.config import settings
from models.ws_v1 import (
    parse_client_message,
    ev,
    HelloMessage,
    SessionStartMessage,
    SessionStopMessage,
    InputTextMessage,
    ResponseCancelMessage,
)


class WsSessionState(str, Enum):
    """Protocol state machine for WS sessions.

    Members are ``str`` subclasses, so they compare equal to their raw
    string values.
    """

    # Connection accepted; waiting for the client's hello message.
    WAIT_HELLO = "wait_hello"
    # Hello accepted; waiting for session.start.
    WAIT_START = "wait_start"
    # Session started; audio and input messages are accepted.
    ACTIVE = "active"
    # Session stopped or handshake rejected.
    STOPPED = "stopped"


class Session:
    """
    Manages a single call session.

    Handles command routing, audio processing, and session lifecycle.
    Uses full duplex voice conversation pipeline.

    Protocol order enforced by this class: ``hello`` -> ``session.start`` ->
    (audio / input / cancel messages) -> ``session.stop``.
    """

    def __init__(self, session_id: str, transport: BaseTransport, use_duplex: Optional[bool] = None):
        """
        Initialize session.

        Args:
            session_id: Unique session identifier
            transport: Transport instance for communication
            use_duplex: Whether to use duplex pipeline (defaults to settings.duplex_enabled)
        """
        self.id = session_id
        self.transport = transport
        self.use_duplex = use_duplex if use_duplex is not None else settings.duplex_enabled

        # NOTE(review): the pipeline is constructed regardless of use_duplex;
        # within this class the flag is only stored and logged.
        self.pipeline = DuplexPipeline(
            transport=transport,
            session_id=session_id,
            system_prompt=settings.duplex_system_prompt,
            greeting=settings.duplex_greeting,
        )

        # Session state
        self.created_at = None  # Never assigned inside this class.
        self.state = "created"  # Legacy call state for /call/lists
        self.ws_state = WsSessionState.WAIT_HELLO
        self._pipeline_started = False
        self.protocol_version: Optional[str] = None
        self.authenticated: bool = False

        # Track IDs
        self.current_track_id: Optional[str] = str(uuid.uuid4())

        # Backend history bridge bookkeeping
        self._history_call_id: Optional[str] = None
        self._history_turn_index: int = 0
        self._history_call_started_mono: Optional[float] = None
        self._history_finalized: bool = False

        # Guards cleanup() so it runs at most once.
        self._cleanup_lock = asyncio.Lock()
        self._cleaned_up = False

        self.pipeline.conversation.on_turn_complete(self._on_turn_complete)
        logger.info(f"Session {self.id} created (duplex={self.use_duplex})")

    async def handle_text(self, text_data: str) -> None:
        """
        Handle incoming text data (WS v1 JSON control messages).

        Parse and routing failures are reported to the client as error
        events instead of being raised.

        Args:
            text_data: JSON text data
        """
        try:
            data = json.loads(text_data)
            message = parse_client_message(data)
            await self._handle_v1_message(message)
        except json.JSONDecodeError as e:
            # Must precede ValueError: JSONDecodeError is a ValueError subclass.
            logger.error(f"Session {self.id} JSON decode error: {e}")
            await self._send_error("client", f"Invalid JSON: {e}", "protocol.invalid_json")
        except ValueError as e:
            logger.error(f"Session {self.id} command parse error: {e}")
            await self._send_error("client", f"Invalid message: {e}", "protocol.invalid_message")
        except Exception as e:
            # loguru has no ``exc_info`` flag; logger.exception() attaches the
            # traceback. Values go in as format arguments so braces inside the
            # exception text are never interpreted as format fields.
            logger.exception("Session {} handle_text error: {}", self.id, e)
            await self._send_error("server", f"Internal error: {e}", "server.internal")

    async def handle_audio(self, audio_bytes: bytes) -> None:
        """
        Handle incoming audio data.

        Audio is rejected with a ``protocol.order`` error unless the session
        is ACTIVE. Pipeline failures are logged and not re-raised.

        Args:
            audio_bytes: PCM audio data
        """
        if self.ws_state != WsSessionState.ACTIVE:
            await self._send_error(
                "client",
                "Audio received before session.start",
                "protocol.order",
            )
            return

        try:
            await self.pipeline.process_audio(audio_bytes)
        except Exception as e:
            # See handle_text() for why logger.exception is used here.
            logger.exception("Session {} handle_audio error: {}", self.id, e)

    async def _handle_v1_message(self, message: Any) -> None:
        """Route validated WS v1 message to handlers."""
        msg_type = message.type
        logger.info(f"Session {self.id} received message: {msg_type}")

        if isinstance(message, HelloMessage):
            await self._handle_hello(message)
            return

        # All messages below require hello handshake first
        if self.ws_state == WsSessionState.WAIT_HELLO:
            await self._send_error(
                "client",
                "Expected hello message first",
                "protocol.order",
            )
            return

        if isinstance(message, SessionStartMessage):
            await self._handle_session_start(message)
            return

        # All messages below require active session
        if self.ws_state != WsSessionState.ACTIVE:
            await self._send_error(
                "client",
                f"Message '{msg_type}' requires active session",
                "protocol.order",
            )
            return

        if isinstance(message, InputTextMessage):
            await self.pipeline.process_text(message.text)
        elif isinstance(message, ResponseCancelMessage):
            if message.graceful:
                # Graceful cancel is only logged; the pipeline is not interrupted.
                logger.info(f"Session {self.id} graceful response.cancel")
            else:
                await self.pipeline.interrupt()
        elif isinstance(message, SessionStopMessage):
            await self._handle_session_stop(message.reason)
        else:
            await self._send_error("client", f"Unsupported message type: {msg_type}", "protocol.unsupported")

    @staticmethod
    def _api_key_matches(provided: Any, expected: Any) -> bool:
        """Return True when the client-supplied API key equals the configured one.

        Uses a constant-time comparison so response timing does not reveal
        how much of the key was correct. Non-string input never matches.
        """
        if not isinstance(provided, str):
            return False
        # Compare bytes: compare_digest rejects non-ASCII str arguments.
        return hmac.compare_digest(provided.encode("utf-8"), str(expected).encode("utf-8"))

    async def _reject_hello(self, sender: str, error_message: str, code: str) -> None:
        """Report a failed handshake, close the transport and mark the session stopped."""
        await self._send_error(sender, error_message, code)
        await self.transport.close()
        self.ws_state = WsSessionState.STOPPED

    async def _handle_hello(self, message: HelloMessage) -> None:
        """Handle initial hello/auth/version negotiation.

        On success the session moves to WAIT_START and ``hello.ack`` is sent.
        A version or auth failure sends an error, closes the transport and
        moves the session to STOPPED.
        """
        if self.ws_state != WsSessionState.WAIT_HELLO:
            await self._send_error("client", "Duplicate hello", "protocol.order")
            return

        if message.version != settings.ws_protocol_version:
            await self._reject_hello(
                "client",
                f"Unsupported protocol version '{message.version}'",
                "protocol.version_unsupported",
            )
            return

        auth_payload = message.auth or {}
        api_key = auth_payload.get("apiKey")
        jwt = auth_payload.get("jwt")

        if settings.ws_api_key:
            # A configured API key is mandatory and must match exactly.
            if not self._api_key_matches(api_key, settings.ws_api_key):
                await self._reject_hello("auth", "Invalid API key", "auth.invalid_api_key")
                return
        elif settings.ws_require_auth and not (api_key or jwt):
            # No key configured: only the presence of a credential is checked
            # here; the JWT itself is not validated in this method.
            await self._reject_hello("auth", "Authentication required", "auth.required")
            return

        self.authenticated = True
        self.protocol_version = message.version
        self.ws_state = WsSessionState.WAIT_START
        await self.transport.send_event(
            ev(
                "hello.ack",
                sessionId=self.id,
                version=self.protocol_version,
            )
        )

    async def _handle_session_start(self, message: SessionStartMessage) -> None:
        """Handle explicit session start after successful hello."""
        if self.ws_state != WsSessionState.WAIT_START:
            await self._send_error("client", "Duplicate session.start", "protocol.order")
            return

        metadata = message.metadata or {}

        # Create history call record early so later turn callbacks can append transcripts.
        await self._start_history_bridge(metadata)

        # Apply runtime service/prompt overrides from backend if provided
        self.pipeline.apply_runtime_overrides(metadata)

        # Start duplex pipeline
        if not self._pipeline_started:
            await self.pipeline.start()
            self._pipeline_started = True
            logger.info(f"Session {self.id} duplex pipeline started")

        self.state = "accepted"
        self.ws_state = WsSessionState.ACTIVE
        await self.transport.send_event(
            ev(
                "session.started",
                sessionId=self.id,
                trackId=self.current_track_id,
                audio=message.audio or {},
            )
        )

    async def _handle_session_stop(self, reason: Optional[str]) -> None:
        """Handle session stop.

        Args:
            reason: Client-supplied stop reason; "client_requested" when missing.
        """
        if self.ws_state == WsSessionState.STOPPED:
            return

        stop_reason = reason or "client_requested"
        self.state = "hungup"
        self.ws_state = WsSessionState.STOPPED
        await self.transport.send_event(
            ev(
                "session.stopped",
                sessionId=self.id,
                reason=stop_reason,
            )
        )
        await self._finalize_history(status="connected")
        await self.transport.close()

    async def _send_error(self, sender: str, error_message: str, code: str) -> None:
        """
        Send error event to client.

        Args:
            sender: Component that generated the error
            error_message: Error message
            code: Machine-readable error code
        """
        await self.transport.send_event(
            ev(
                "error",
                sender=sender,
                code=code,
                message=error_message,
                trackId=self.current_track_id,
            )
        )

    def _get_timestamp_ms(self) -> int:
        """Get current timestamp in milliseconds."""
        return int(time.time() * 1000)

    async def cleanup(self) -> None:
        """Cleanup session resources.

        Safe to call more than once; only the first call does any work.
        The pipeline and transport are released even if history
        finalization raises, and the error still propagates.
        """
        async with self._cleanup_lock:
            if self._cleaned_up:
                return
            self._cleaned_up = True

            logger.info(f"Session {self.id} cleaning up")
            try:
                await self._finalize_history(status="connected")
            finally:
                try:
                    await self.pipeline.cleanup()
                finally:
                    await self.transport.close()

    async def _start_history_bridge(self, metadata: Dict[str, Any]) -> None:
        """Initialize backend history call record for this session.

        Values are read from ``metadata["history"]`` first, then from the
        top level of ``metadata``, then from defaults. Does nothing if a
        record already exists or the backend returns no call id.
        """
        if self._history_call_id:
            return

        history_meta: Dict[str, Any] = {}
        if isinstance(metadata.get("history"), dict):
            history_meta = metadata["history"]

        raw_user_id = history_meta.get("userId", metadata.get("userId", settings.history_default_user_id))
        try:
            user_id = int(raw_user_id)
        except (TypeError, ValueError):
            user_id = settings.history_default_user_id

        assistant_id = history_meta.get("assistantId", metadata.get("assistantId"))
        source = str(history_meta.get("source", metadata.get("source", "debug")))

        call_id = await create_history_call_record(
            user_id=user_id,
            assistant_id=str(assistant_id) if assistant_id else None,
            source=source,
        )
        if not call_id:
            return

        self._history_call_id = call_id
        self._history_call_started_mono = time.monotonic()
        self._history_turn_index = 0
        self._history_finalized = False
        logger.info(f"Session {self.id} history bridge enabled (call_id={call_id}, source={source})")

    async def _on_turn_complete(self, turn: ConversationTurn) -> None:
        """Persist completed turns to backend call transcripts.

        Turns with empty or whitespace-only text are skipped. The end time is
        the elapsed time since the history record was created; the start time
        is estimated from the text length.
        """
        if not self._history_call_id:
            return

        content = (turn.text or "").strip()
        if not content:
            return

        role = (turn.role or "").lower()
        speaker = "human" if role == "user" else "ai"

        end_ms = 0
        if self._history_call_started_mono is not None:
            end_ms = max(0, int((time.monotonic() - self._history_call_started_mono) * 1000))

        # Rough duration estimate: 80 ms per character, clamped to 0.3 s..12 s.
        estimated_duration_ms = max(300, min(12000, len(content) * 80))
        start_ms = max(0, end_ms - estimated_duration_ms)

        # Reserve the index before awaiting so overlapping callbacks cannot
        # submit two transcripts under the same turn index.
        turn_index = self._history_turn_index
        self._history_turn_index += 1

        await add_history_transcript(
            call_id=self._history_call_id,
            turn_index=turn_index,
            speaker=speaker,
            content=content,
            start_ms=start_ms,
            end_ms=end_ms,
            duration_ms=max(1, end_ms - start_ms),
        )

    async def _finalize_history(self, status: str) -> None:
        """Finalize history call record once.

        The record is marked finalized only when the backend call reports
        success, so a failed attempt can be retried by a later call.

        Args:
            status: Status value forwarded to the backend unchanged.
        """
        if not self._history_call_id or self._history_finalized:
            return

        duration_seconds = 0
        if self._history_call_started_mono is not None:
            duration_seconds = max(0, int(time.monotonic() - self._history_call_started_mono))

        ok = await finalize_history_call_record(
            call_id=self._history_call_id,
            status=status,
            duration_seconds=duration_seconds,
        )
        if ok:
            self._history_finalized = True