"""Session management for active calls."""
|
|
|
|
import asyncio
|
|
import uuid
|
|
import json
|
|
import time
|
|
import re
|
|
from enum import Enum
|
|
from typing import Optional, Dict, Any, List
|
|
from loguru import logger
|
|
|
|
from app.backend_client import (
|
|
create_history_call_record,
|
|
add_history_transcript,
|
|
finalize_history_call_record,
|
|
)
|
|
from core.transports import BaseTransport
|
|
from core.duplex_pipeline import DuplexPipeline
|
|
from core.conversation import ConversationTurn
|
|
from core.workflow_runner import WorkflowRunner, WorkflowTransition, WorkflowNodeDef, WorkflowEdgeDef
|
|
from app.config import settings
|
|
from services.base import LLMMessage
|
|
from models.ws_v1 import (
|
|
parse_client_message,
|
|
ev,
|
|
HelloMessage,
|
|
SessionStartMessage,
|
|
SessionStopMessage,
|
|
InputTextMessage,
|
|
ResponseCancelMessage,
|
|
ToolCallResultsMessage,
|
|
)
|
|
|
|
|
|
class WsSessionState(str, Enum):
    """Protocol state machine for WS sessions.

    Progression enforced by Session's handlers:
    WAIT_HELLO -> WAIT_START -> ACTIVE -> STOPPED.
    """

    WAIT_HELLO = "wait_hello"  # connected; awaiting the client's hello handshake
    WAIT_START = "wait_start"  # hello accepted; awaiting session.start
    ACTIVE = "active"          # session started; audio/text/control accepted
    STOPPED = "stopped"        # session ended; further messages are rejected
|
|
|
|
|
|
class Session:
|
|
"""
|
|
Manages a single call session.
|
|
|
|
Handles command routing, audio processing, and session lifecycle.
|
|
Uses full duplex voice conversation pipeline.
|
|
"""
|
|
|
|
def __init__(self, session_id: str, transport: BaseTransport, use_duplex: Optional[bool] = None):
    """
    Initialize session.

    Args:
        session_id: Unique session identifier
        transport: Transport instance for communication
        use_duplex: Whether to use duplex pipeline (defaults to
            settings.duplex_enabled). NOTE(review): the flag is recorded
            but this constructor builds a DuplexPipeline unconditionally.
    """
    self.id = session_id
    self.transport = transport
    # Explicit None check so a caller-supplied use_duplex=False is honored.
    self.use_duplex = use_duplex if use_duplex is not None else settings.duplex_enabled

    # Full-duplex voice pipeline; prompt and greeting come from settings
    # (may be overridden later via apply_runtime_overrides).
    self.pipeline = DuplexPipeline(
        transport=transport,
        session_id=session_id,
        system_prompt=settings.duplex_system_prompt,
        greeting=settings.duplex_greeting
    )

    # Session state
    self.created_at = None  # not set in this module; presumably stamped by the session manager — TODO confirm
    self.state = "created"  # Legacy call state for /call/lists
    self.ws_state = WsSessionState.WAIT_HELLO  # WS v1 handshake state machine
    self._pipeline_started = False  # guards against double pipeline.start()
    self.protocol_version: Optional[str] = None  # negotiated in _handle_hello
    self.authenticated: bool = False

    # Track IDs
    self.current_track_id: Optional[str] = str(uuid.uuid4())
    # History-bridge bookkeeping (backend call record for transcripts).
    self._history_call_id: Optional[str] = None
    self._history_turn_index: int = 0
    self._history_call_started_mono: Optional[float] = None  # time.monotonic() at record creation
    self._history_finalized: bool = False
    # cleanup() is idempotent; the lock serializes concurrent invocations.
    self._cleanup_lock = asyncio.Lock()
    self._cleaned_up = False
    # Optional workflow graph driving per-node prompt/service overrides.
    self.workflow_runner: Optional[WorkflowRunner] = None
    self._workflow_last_user_text: str = ""
    self._workflow_initial_node: Optional[WorkflowNodeDef] = None

    # Persist finished turns and advance the workflow on each completed turn.
    self.pipeline.conversation.on_turn_complete(self._on_turn_complete)

    logger.info(f"Session {self.id} created (duplex={self.use_duplex})")
|
|
|
|
async def handle_text(self, text_data: str) -> None:
    """
    Handle incoming text data (WS v1 JSON control messages).

    Decodes JSON, validates it into a typed WS v1 message, and routes it.
    Each failure mode is reported back to the client with a distinct
    machine-readable error code.

    Args:
        text_data: JSON text data
    """
    try:
        data = json.loads(text_data)
        message = parse_client_message(data)
        await self._handle_v1_message(message)

    except json.JSONDecodeError as e:
        logger.error(f"Session {self.id} JSON decode error: {e}")
        await self._send_error("client", f"Invalid JSON: {e}", "protocol.invalid_json")

    except ValueError as e:
        logger.error(f"Session {self.id} command parse error: {e}")
        await self._send_error("client", f"Invalid message: {e}", "protocol.invalid_message")

    except Exception as e:
        # Fix: loguru does not understand the stdlib `exc_info=True` kwarg,
        # so the traceback was silently dropped. logger.exception() records it.
        logger.exception(f"Session {self.id} handle_text error: {e}")
        await self._send_error("server", f"Internal error: {e}", "server.internal")
|
|
|
|
async def handle_audio(self, audio_bytes: bytes) -> None:
    """
    Handle incoming audio data.

    Audio is only accepted while the session is ACTIVE; frames arriving
    before session.start are rejected with a protocol-order error.
    NOTE(review): an error event is sent for every early frame, so a client
    streaming before session.start will receive repeated errors.

    Args:
        audio_bytes: PCM audio data
    """
    if self.ws_state != WsSessionState.ACTIVE:
        await self._send_error(
            "client",
            "Audio received before session.start",
            "protocol.order",
        )
        return

    try:
        await self.pipeline.process_audio(audio_bytes)
    except Exception as e:
        # Fix: loguru ignores the stdlib-style `exc_info=True` kwarg;
        # logger.exception() actually captures the traceback.
        logger.exception(f"Session {self.id} handle_audio error: {e}")
|
|
|
|
async def _handle_v1_message(self, message: Any) -> None:
    """Route validated WS v1 message to handlers.

    Enforces protocol ordering: hello is accepted in any state; everything
    else requires the handshake, and most message types additionally
    require an ACTIVE session.
    """
    msg_type = message.type
    logger.info(f"Session {self.id} received message: {msg_type}")

    # hello is the only message type valid in every state.
    if isinstance(message, HelloMessage):
        await self._handle_hello(message)
        return

    # All messages below require hello handshake first
    if self.ws_state == WsSessionState.WAIT_HELLO:
        await self._send_error(
            "client",
            "Expected hello message first",
            "protocol.order",
        )
        return

    if isinstance(message, SessionStartMessage):
        await self._handle_session_start(message)
        return

    # All messages below require active session
    if self.ws_state != WsSessionState.ACTIVE:
        await self._send_error(
            "client",
            f"Message '{msg_type}' requires active session",
            "protocol.order",
        )
        return

    # Active-session dispatch, one early return per message type.
    if isinstance(message, InputTextMessage):
        await self.pipeline.process_text(message.text)
        return
    if isinstance(message, ResponseCancelMessage):
        if message.graceful:
            logger.info(f"Session {self.id} graceful response.cancel")
        else:
            await self.pipeline.interrupt()
        return
    if isinstance(message, ToolCallResultsMessage):
        await self.pipeline.handle_tool_call_results(message.results)
        return
    if isinstance(message, SessionStopMessage):
        await self._handle_session_stop(message.reason)
        return

    await self._send_error("client", f"Unsupported message type: {msg_type}", "protocol.unsupported")
|
|
|
|
async def _handle_hello(self, message: HelloMessage) -> None:
    """Handle initial hello/auth/version negotiation.

    Success: marks the session authenticated, stores the protocol version,
    transitions WAIT_HELLO -> WAIT_START, and replies with "hello.ack".
    Version mismatch or auth failure: sends an error, closes the
    transport, and moves straight to STOPPED.
    """
    if self.ws_state != WsSessionState.WAIT_HELLO:
        await self._send_error("client", "Duplicate hello", "protocol.order")
        return

    # Exact version equality required — no range/feature negotiation.
    if message.version != settings.ws_protocol_version:
        await self._send_error(
            "client",
            f"Unsupported protocol version '{message.version}'",
            "protocol.version_unsupported",
        )
        await self.transport.close()
        self.ws_state = WsSessionState.STOPPED
        return

    auth_payload = message.auth or {}
    api_key = auth_payload.get("apiKey")
    jwt = auth_payload.get("jwt")

    # When a server-side API key is configured it is mandatory. Otherwise,
    # if ws_require_auth is set, mere presence of apiKey or jwt suffices —
    # the jwt's contents are not verified here.
    if settings.ws_api_key:
        if api_key != settings.ws_api_key:
            await self._send_error("auth", "Invalid API key", "auth.invalid_api_key")
            await self.transport.close()
            self.ws_state = WsSessionState.STOPPED
            return
    elif settings.ws_require_auth and not (api_key or jwt):
        await self._send_error("auth", "Authentication required", "auth.required")
        await self.transport.close()
        self.ws_state = WsSessionState.STOPPED
        return

    self.authenticated = True
    self.protocol_version = message.version
    self.ws_state = WsSessionState.WAIT_START
    await self.transport.send_event(
        ev(
            "hello.ack",
            sessionId=self.id,
            version=self.protocol_version,
        )
    )
|
|
|
|
async def _handle_session_start(self, message: SessionStartMessage) -> None:
    """Handle explicit session start after successful hello.

    Order matters here: workflow bootstrap overrides are merged into the
    metadata first, the history record is created before the pipeline
    starts (so turn callbacks can append transcripts), then runtime
    overrides are applied and the pipeline started. Only after that are
    "session.started" and any workflow start events emitted.
    """
    if self.ws_state != WsSessionState.WAIT_START:
        # Covers both duplicate starts and starts while already ACTIVE/STOPPED.
        await self._send_error("client", "Duplicate session.start", "protocol.order")
        return

    metadata = message.metadata or {}
    # Fold the workflow's initial-node overrides into the start metadata.
    metadata = self._merge_runtime_metadata(metadata, self._bootstrap_workflow(metadata))

    # Create history call record early so later turn callbacks can append transcripts.
    await self._start_history_bridge(metadata)

    # Apply runtime service/prompt overrides from backend if provided
    self.pipeline.apply_runtime_overrides(metadata)

    # Start duplex pipeline
    if not self._pipeline_started:
        await self.pipeline.start()
        self._pipeline_started = True
        logger.info(f"Session {self.id} duplex pipeline started")

    self.state = "accepted"
    self.ws_state = WsSessionState.ACTIVE
    await self.transport.send_event(
        ev(
            "session.started",
            sessionId=self.id,
            trackId=self.current_track_id,
            audio=message.audio or {},
        )
    )
    # When a workflow was bootstrapped, announce it and its entry node
    # (in that order) right after session.started.
    if self.workflow_runner and self._workflow_initial_node:
        await self.transport.send_event(
            ev(
                "workflow.started",
                sessionId=self.id,
                workflowId=self.workflow_runner.workflow_id,
                workflowName=self.workflow_runner.name,
                nodeId=self._workflow_initial_node.id,
            )
        )
        await self.transport.send_event(
            ev(
                "workflow.node.entered",
                sessionId=self.id,
                workflowId=self.workflow_runner.workflow_id,
                nodeId=self._workflow_initial_node.id,
                nodeName=self._workflow_initial_node.name,
                nodeType=self._workflow_initial_node.node_type,
            )
        )
|
|
|
|
async def _handle_session_stop(self, reason: Optional[str]) -> None:
    """Handle session stop.

    Idempotent: once STOPPED, later calls are no-ops. Emits
    "session.stopped", finalizes the history record, then closes the
    transport.
    """
    if self.ws_state == WsSessionState.STOPPED:
        return

    self.state = "hungup"
    self.ws_state = WsSessionState.STOPPED

    stop_event = ev(
        "session.stopped",
        sessionId=self.id,
        reason=reason or "client_requested",
    )
    await self.transport.send_event(stop_event)
    await self._finalize_history(status="connected")
    await self.transport.close()
|
|
|
|
async def _send_error(self, sender: str, error_message: str, code: str) -> None:
    """
    Send error event to client.

    The current track id is attached so the client can correlate the
    error with an in-flight response.

    Args:
        sender: Component that generated the error
        error_message: Error message
        code: Machine-readable error code
    """
    error_event = ev(
        "error",
        sender=sender,
        code=code,
        message=error_message,
        trackId=self.current_track_id,
    )
    await self.transport.send_event(error_event)
|
|
|
|
def _get_timestamp_ms(self) -> int:
|
|
"""Get current timestamp in milliseconds."""
|
|
import time
|
|
return int(time.time() * 1000)
|
|
|
|
async def cleanup(self) -> None:
    """Cleanup session resources.

    Safe to call multiple times: the lock serializes concurrent callers
    and the _cleaned_up flag makes repeat calls no-ops.
    """
    async with self._cleanup_lock:
        if self._cleaned_up:
            return
        self._cleaned_up = True

        logger.info(f"Session {self.id} cleaning up")
        await self._finalize_history(status="connected")
        await self.pipeline.cleanup()
        await self.transport.close()
|
|
|
|
async def _start_history_bridge(self, metadata: Dict[str, Any]) -> None:
    """Initialize backend history call record for this session.

    No-op when a record already exists or when the backend declines to
    create one (the session then simply runs without the bridge).
    """
    if self._history_call_id:
        return

    # Optional nested "history" block takes precedence over top-level keys.
    history_meta = metadata["history"] if isinstance(metadata.get("history"), dict) else {}

    raw_user_id = history_meta.get("userId", metadata.get("userId", settings.history_default_user_id))
    try:
        user_id = int(raw_user_id)
    except (TypeError, ValueError):
        # Unparseable user ids fall back to the configured default.
        user_id = settings.history_default_user_id

    assistant_id = history_meta.get("assistantId", metadata.get("assistantId"))
    source = str(history_meta.get("source", metadata.get("source", "debug")))

    call_id = await create_history_call_record(
        user_id=user_id,
        assistant_id=str(assistant_id) if assistant_id else None,
        source=source,
    )
    if not call_id:
        return

    self._history_call_id = call_id
    self._history_call_started_mono = time.monotonic()
    self._history_turn_index = 0
    self._history_finalized = False
    logger.info(f"Session {self.id} history bridge enabled (call_id={call_id}, source={source})")
|
|
|
|
async def _on_turn_complete(self, turn: ConversationTurn) -> None:
    """Process workflow transitions and persist completed turns to history.

    Registered as the conversation's turn-complete callback. User turns
    are cached as routing context; assistant turns may advance the
    workflow. Non-empty turns are then appended to the backend history
    record when the bridge is active.
    """
    if turn.text and turn.text.strip():
        role = (turn.role or "").lower()
        if role == "user":
            # Cached so edge routing sees the latest user utterance.
            self._workflow_last_user_text = turn.text.strip()
        elif role == "assistant":
            await self._maybe_advance_workflow(turn.text.strip())

    # History persistence: skip when the bridge is off or the turn is empty.
    if not self._history_call_id:
        return
    if not turn.text or not turn.text.strip():
        return

    role = (turn.role or "").lower()
    # Any non-user role (assistant, system, ...) is recorded as "ai".
    speaker = "human" if role == "user" else "ai"

    # Timestamps are relative to history-record creation (monotonic clock).
    end_ms = 0
    if self._history_call_started_mono is not None:
        end_ms = max(0, int((time.monotonic() - self._history_call_started_mono) * 1000))
    # Rough speech-duration estimate: ~80 ms per character, clamped to 0.3-12 s.
    estimated_duration_ms = max(300, min(12000, len(turn.text.strip()) * 80))
    start_ms = max(0, end_ms - estimated_duration_ms)

    turn_index = self._history_turn_index
    await add_history_transcript(
        call_id=self._history_call_id,
        turn_index=turn_index,
        speaker=speaker,
        content=turn.text.strip(),
        start_ms=start_ms,
        end_ms=end_ms,
        duration_ms=max(1, end_ms - start_ms),
    )
    # Index advances only after the append attempt, keeping ordering stable.
    self._history_turn_index += 1
|
|
|
|
async def _finalize_history(self, status: str) -> None:
    """Finalize history call record once.

    The finalized flag is only set on backend success, so a failed
    attempt can be retried by a later caller.
    """
    if not self._history_call_id or self._history_finalized:
        return

    started = self._history_call_started_mono
    duration_seconds = 0 if started is None else max(0, int(time.monotonic() - started))

    finalized = await finalize_history_call_record(
        call_id=self._history_call_id,
        status=status,
        duration_seconds=duration_seconds,
    )
    if finalized:
        self._history_finalized = True
|
|
|
|
def _bootstrap_workflow(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
    """Parse workflow payload and return initial runtime overrides.

    Resets workflow state on every call; returns {} when no workflow is
    present or the payload has no usable start node (in which case the
    runner is discarded).
    """
    self.workflow_runner = WorkflowRunner.from_payload(metadata.get("workflow"))
    self._workflow_initial_node = None
    runner = self.workflow_runner
    if not runner:
        return {}

    start_node = runner.bootstrap()
    if not start_node:
        # A payload without a resolvable entry point disables the workflow.
        logger.warning(f"Session {self.id} workflow payload had no resolvable start node")
        self.workflow_runner = None
        return {}

    self._workflow_initial_node = start_node
    logger.info(
        "Session {} workflow enabled: workflow={} start_node={}",
        self.id,
        runner.workflow_id,
        start_node.id,
    )
    return runner.build_runtime_metadata(start_node)
|
|
|
|
async def _maybe_advance_workflow(self, assistant_text: str) -> None:
    """Attempt node transfer after assistant turn finalization.

    Routes one transition from the current node (runner rules, with LLM
    fallback via _workflow_llm_route). If a transition fires, utility
    nodes are auto-advanced along default edges, bounded by a hop limit
    so a graph cycle cannot spin forever.
    """
    if not self.workflow_runner or self.ws_state == WsSessionState.STOPPED:
        return

    transition = await self.workflow_runner.route(
        user_text=self._workflow_last_user_text,
        assistant_text=assistant_text,
        llm_router=self._workflow_llm_route,
    )
    # No matching edge — remain on the current node.
    if not transition:
        return

    await self._apply_workflow_transition(transition, reason="rule_match")

    # Auto-advance through utility nodes when default edges are present.
    # Applying a transition can stop the session (end/human_transfer), so
    # the runner and state are re-checked on every hop.
    max_auto_hops = 6
    auto_hops = 0
    while self.workflow_runner and self.ws_state != WsSessionState.STOPPED:
        current = self.workflow_runner.current_node
        # Only pass-through node types advance without further user input.
        if not current or current.node_type not in {"start", "tool"}:
            break

        next_default = self.workflow_runner.next_default_transition()
        if not next_default:
            break

        auto_hops += 1
        await self._apply_workflow_transition(next_default, reason="auto")
        # Limit is checked after applying, so up to max_auto_hops hops occur.
        if auto_hops >= max_auto_hops:
            logger.warning(
                "Session {} workflow auto-advance reached hop limit (possible cycle)",
                self.id,
            )
            break
|
|
|
|
async def _apply_workflow_transition(self, transition: WorkflowTransition, reason: str) -> None:
    """Apply graph transition and emit workflow lifecycle events.

    Event order is part of the contract: "workflow.edge.taken" then
    "workflow.node.entered", after which the node's runtime overrides are
    pushed to the pipeline. Special node types then act: "tool" requests
    client-side tool execution, "human_transfer" and "end" emit their
    event and stop the session.

    Args:
        transition: Edge/node pair chosen by the workflow runner.
        reason: Why the edge was taken ("rule_match" or "auto"); echoed
            to the client in the edge event.
    """
    if not self.workflow_runner:
        return

    self.workflow_runner.apply_transition(transition)
    node = transition.node
    edge = transition.edge

    await self.transport.send_event(
        ev(
            "workflow.edge.taken",
            sessionId=self.id,
            workflowId=self.workflow_runner.workflow_id,
            edgeId=edge.id,
            fromNodeId=edge.from_node_id,
            toNodeId=edge.to_node_id,
            reason=reason,
        )
    )
    await self.transport.send_event(
        ev(
            "workflow.node.entered",
            sessionId=self.id,
            workflowId=self.workflow_runner.workflow_id,
            nodeId=node.id,
            nodeName=node.name,
            nodeType=node.node_type,
        )
    )

    # Per-node prompt/service overrides take effect before any node action.
    node_runtime = self.workflow_runner.build_runtime_metadata(node)
    if node_runtime:
        self.pipeline.apply_runtime_overrides(node_runtime)

    if node.node_type == "tool":
        # Tool execution happens client-side; results presumably return
        # via tool_call.results and pipeline.handle_tool_call_results.
        await self.transport.send_event(
            ev(
                "workflow.tool.requested",
                sessionId=self.id,
                workflowId=self.workflow_runner.workflow_id,
                nodeId=node.id,
                tool=node.tool or {},
            )
        )
        return

    if node.node_type == "human_transfer":
        await self.transport.send_event(
            ev(
                "workflow.human_transfer",
                sessionId=self.id,
                workflowId=self.workflow_runner.workflow_id,
                nodeId=node.id,
            )
        )
        # Terminal for this session: hand off and stop.
        await self._handle_session_stop("workflow_human_transfer")
        return

    if node.node_type == "end":
        await self.transport.send_event(
            ev(
                "workflow.ended",
                sessionId=self.id,
                workflowId=self.workflow_runner.workflow_id,
                nodeId=node.id,
            )
        )
        await self._handle_session_stop("workflow_end")
|
|
|
|
async def _workflow_llm_route(
    self,
    node: WorkflowNodeDef,
    candidates: List[WorkflowEdgeDef],
    context: Dict[str, str],
) -> Optional[str]:
    """LLM-based edge routing for condition.type == 'llm' edges.

    Asks the pipeline's LLM to pick one candidate edge, then resolves the
    reply in decreasing strictness: exact JSON edgeId, exact JSON
    toNodeId, finally a case-insensitive substring scan of the raw reply.

    Returns:
        A known edge id or target node id, or None when routing fails for
        any reason (no LLM service, request error, unusable reply).
    """
    llm_service = self.pipeline.llm_service
    if not llm_service:
        return None

    # Compact candidate descriptions for the router prompt; "hint" carries
    # the edge condition's own prompt when present.
    candidate_rows = [
        {
            "edgeId": edge.id,
            "toNodeId": edge.to_node_id,
            "label": edge.label,
            "hint": edge.condition.get("prompt") if isinstance(edge.condition, dict) else None,
        }
        for edge in candidates
    ]
    system_prompt = (
        "You are a workflow router. Pick exactly one edge. "
        "Return JSON only: {\"edgeId\":\"...\"}."
    )
    user_prompt = json.dumps(
        {
            "nodeId": node.id,
            "nodeName": node.name,
            "userText": context.get("userText", ""),
            "assistantText": context.get("assistantText", ""),
            "candidates": candidate_rows,
        },
        ensure_ascii=False,
    )

    try:
        # Deterministic, short reply: temperature 0, small token budget.
        reply = await llm_service.generate(
            [
                LLMMessage(role="system", content=system_prompt),
                LLMMessage(role="user", content=user_prompt),
            ],
            temperature=0.0,
            max_tokens=64,
        )
    except Exception as exc:
        # Routing is best-effort; a failed LLM call just means "no route".
        logger.warning(f"Session {self.id} workflow llm routing failed: {exc}")
        return None

    if not reply:
        return None

    edge_ids = {edge.id for edge in candidates}
    node_ids = {edge.to_node_id for edge in candidates}

    # Preferred: a well-formed JSON object naming a known edge or node.
    parsed = self._extract_json_obj(reply)
    if isinstance(parsed, dict):
        edge_id = parsed.get("edgeId") or parsed.get("id")
        node_id = parsed.get("toNodeId") or parsed.get("nodeId")
        if isinstance(edge_id, str) and edge_id in edge_ids:
            return edge_id
        if isinstance(node_id, str) and node_id in node_ids:
            return node_id

    # Last resort: scan the raw reply for any known id, longest first so a
    # short id that is a substring of a longer one cannot shadow it.
    token_candidates = sorted(edge_ids | node_ids, key=len, reverse=True)
    lowered_reply = reply.lower()
    for token in token_candidates:
        if token.lower() in lowered_reply:
            return token
    return None
|
|
|
|
def _merge_runtime_metadata(self, base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
|
|
"""Merge node-level metadata overrides into session.start metadata."""
|
|
merged = dict(base or {})
|
|
if not overrides:
|
|
return merged
|
|
for key, value in overrides.items():
|
|
if key == "services" and isinstance(value, dict):
|
|
existing = merged.get("services")
|
|
merged_services = dict(existing) if isinstance(existing, dict) else {}
|
|
merged_services.update(value)
|
|
merged["services"] = merged_services
|
|
else:
|
|
merged[key] = value
|
|
return merged
|
|
|
|
def _extract_json_obj(self, text: str) -> Optional[Dict[str, Any]]:
|
|
"""Best-effort extraction of a JSON object from freeform text."""
|
|
try:
|
|
parsed = json.loads(text)
|
|
if isinstance(parsed, dict):
|
|
return parsed
|
|
except Exception:
|
|
pass
|
|
|
|
match = re.search(r"\{.*\}", text, re.DOTALL)
|
|
if not match:
|
|
return None
|
|
try:
|
|
parsed = json.loads(match.group(0))
|
|
return parsed if isinstance(parsed, dict) else None
|
|
except Exception:
|
|
return None
|