"""WS v1 protocol message models and helpers.

Client -> server messages are strict pydantic models (undeclared fields are
rejected). ``CLIENT_MESSAGE_TYPES`` maps each wire ``type`` string to the
model class that validates it.
"""

import time
from typing import Any, Dict, Literal, Optional

from pydantic import BaseModel, ConfigDict, Field, ValidationError


def now_ms() -> int:
    """Return the current Unix time in milliseconds, truncated to an int."""
    seconds = time.time()
    return int(seconds * 1000)


class _StrictModel(BaseModel):
    """Shared base for WS v1 models: undeclared fields fail validation."""

    model_config = ConfigDict(extra="forbid")


# Client -> Server messages
#
# NOTE: the message models below are documented with ``#`` comments rather
# than class docstrings, because pydantic uses a model's class docstring as
# the ``description`` of its generated JSON schema.


# Optional credentials carried by ``hello``. Field names double as the wire
# keys (no aliases are configured), hence the camelCase ``apiKey``.
class HelloAuth(_StrictModel):
    apiKey: Optional[str] = None
    jwt: Optional[str] = None


# Handshake message. ``version`` has no default, so it must be present and
# equal to "v1"; ``auth`` may be omitted.
class HelloMessage(_StrictModel):
    type: Literal["hello"]
    version: Literal["v1"] = Field(..., description="Protocol version, fixed to v1")
    auth: Optional[HelloAuth] = Field(None, description="Auth payload")


# Audio format metadata for ``session.start``. Every field is a single-value
# Literal, so the only accepted format is pcm_s16le / 16000 Hz / 1 channel.
class SessionStartAudio(_StrictModel):
    encoding: Literal["pcm_s16le"] = "pcm_s16le"
    sample_rate_hz: Literal[16000] = 16000
    channels: Literal[1] = 1


# ``session.start``: both the audio descriptor and the free-form metadata
# mapping are optional.
class SessionStartMessage(_StrictModel):
    type: Literal["session.start"]
    audio: Optional[SessionStartAudio] = Field(
        None, description="Optional audio format metadata"
    )
    metadata: Optional[Dict[str, Any]] = Field(
        None, description="Optional session metadata"
    )


# ``session.stop`` with an optional free-text reason.
class SessionStopMessage(_StrictModel):
    type: Literal["session.stop"]
    reason: Optional[str] = None


# ``input.text``: ``text`` is required.
class InputTextMessage(_StrictModel):
    type: Literal["input.text"]
    text: str


# ``response.cancel``: ``graceful`` defaults to False when omitted.
class ResponseCancelMessage(_StrictModel):
    type: Literal["response.cancel"]
    graceful: bool = False


# Status attached to a single tool call result; both fields are required.
class ToolCallResultStatus(_StrictModel):
    code: int
    message: str


# One tool call result. ``output`` is untyped (``Any``) and defaults to None.
class ToolCallResult(_StrictModel):
    tool_call_id: str
    name: str
    output: Any = None
    status: ToolCallResultStatus


# ``tool_call.results``: ``results`` defaults to a fresh empty list.
class ToolCallResultsMessage(_StrictModel):
    type: Literal["tool_call.results"]
    results: list[ToolCallResult] = Field(default_factory=list)


# Wire ``type`` value -> model class; parse_client_message dispatches on this.
CLIENT_MESSAGE_TYPES: Dict[str, type[_StrictModel]] = {
    "hello": HelloMessage,
    "session.start": SessionStartMessage,
    "session.stop": SessionStopMessage,
    "input.text": InputTextMessage,
    "response.cancel": ResponseCancelMessage,
    "tool_call.results": ToolCallResultsMessage,
}
def parse_client_message(data: Dict[str, Any]) -> BaseModel:
    """Parse and validate a decoded WS v1 client message.

    Looks up the model registered for ``data["type"]`` in
    ``CLIENT_MESSAGE_TYPES`` and validates ``data`` against it.

    Args:
        data: The decoded message object. Any mapping is accepted.

    Returns:
        An instance of the model registered for the message's ``type``.

    Raises:
        ValueError: If ``data`` is not a mapping, ``type`` is missing or
            falsy, ``type`` does not name a registered message, or the
            payload fails model validation. Validation failures are
            flattened into one ``"loc: msg"`` entry per error, joined with
            ``"; "``, and the original ``ValidationError`` is chained as
            ``__cause__``.
    """
    from collections.abc import Mapping

    # Payloads come off the wire: a JSON array/string/number would otherwise
    # escape as AttributeError from ``.get`` instead of the documented
    # ValueError.
    if not isinstance(data, Mapping):
        raise ValueError("Client message must be a JSON object")

    msg_type = data.get("type")
    if not msg_type:
        raise ValueError("Missing 'type' field")

    # Only strings can name a message type. Checking this first also avoids a
    # TypeError from the dict lookup when ``type`` is unhashable (list/dict).
    msg_class = (
        CLIENT_MESSAGE_TYPES.get(msg_type) if isinstance(msg_type, str) else None
    )
    if msg_class is None:
        raise ValueError(f"Unknown client message type: {msg_type}")

    try:
        # model_validate (rather than ``msg_class(**data)``) reports non-string
        # keys as validation errors instead of a TypeError from ``**`` unpacking.
        return msg_class.model_validate(dict(data))
    except ValidationError as exc:
        details = []
        for err in exc.errors():
            loc = ".".join(str(part) for part in err.get("loc", ()))
            msg = err.get("msg", "invalid field")
            details.append(f"{loc}: {msg}" if loc else msg)
        raise ValueError("; ".join(details)) from exc


# Server -> Client event helpers


def ev(event_type: str, **payload: Any) -> Dict[str, Any]:
    """Build a WS v1 server event dict.

    The result starts with ``type`` (set to ``event_type``) and ``timestamp``
    (set to ``now_ms()``), followed by the ``payload`` keywords in call order.

    NOTE(review): a ``payload`` key named ``type`` or ``timestamp`` replaces
    the default value (keeping its position). Overriding ``timestamp`` may be
    intentional; overriding ``type`` is likely a caller mistake. Confirm
    before restricting.
    """
    return {"type": event_type, "timestamp": now_ms(), **payload}