diff --git a/core/session.py b/core/session.py index 597b694..b741425 100644 --- a/core/session.py +++ b/core/session.py @@ -107,8 +107,6 @@ class Session: # Track IDs self.current_track_id: str = self.TRACK_CONTROL self._event_seq: int = 0 - self._audio_ingress_buffer: bytes = b"" - self._audio_frame_error_reported: bool = False self._history_call_id: Optional[str] = None self._history_turn_index: int = 0 self._history_call_started_mono: Optional[float] = None @@ -179,23 +177,18 @@ class Session: return frame_bytes = self.AUDIO_FRAME_BYTES - self._audio_ingress_buffer += audio_bytes - - # Protocol v1 audio framing: 20ms PCM frame (640 bytes). - # Allow aggregated frames in one WS message (multiple of 640). - if len(audio_bytes) % frame_bytes != 0 and not self._audio_frame_error_reported: - self._audio_frame_error_reported = True + if len(audio_bytes) % frame_bytes != 0: await self._send_error( "client", - f"Audio frame size should be multiple of {frame_bytes} bytes (20ms PCM)", + f"Audio frame size must be a multiple of {frame_bytes} bytes (20ms PCM)", "audio.frame_size_mismatch", stage="audio", - retryable=True, + retryable=False, ) + return - while len(self._audio_ingress_buffer) >= frame_bytes: - frame = self._audio_ingress_buffer[:frame_bytes] - self._audio_ingress_buffer = self._audio_ingress_buffer[frame_bytes:] + for i in range(0, len(audio_bytes), frame_bytes): + frame = audio_bytes[i : i + frame_bytes] await self.pipeline.process_audio(frame) except Exception as e: logger.error(f"Session {self.id} handle_audio error: {e}", exc_info=True) @@ -246,7 +239,7 @@ class Session: else: await self.pipeline.interrupt() elif isinstance(message, ToolCallResultsMessage): - await self.pipeline.handle_tool_call_results(message.results) + await self.pipeline.handle_tool_call_results([item.model_dump() for item in message.results]) elif isinstance(message, SessionStopMessage): await self._handle_session_stop(message.reason) else: @@ -268,9 +261,9 @@ class Session: self.ws_state = WsSessionState.STOPPED return - auth_payload = message.auth or {} - api_key = auth_payload.get("apiKey") - jwt = auth_payload.get("jwt") + auth_payload = message.auth + api_key = auth_payload.apiKey if auth_payload else None + jwt = auth_payload.jwt if auth_payload else None if settings.ws_api_key: if api_key != settings.ws_api_key: @@ -330,7 +323,7 @@ class Session: "audio_out": self.TRACK_AUDIO_OUT, "control": self.TRACK_CONTROL, }, - audio=message.audio or {}, + audio=message.audio.model_dump() if message.audio else {}, ) ) await self._send_event( @@ -383,6 +376,7 @@ class Session: code: str, stage: Optional[str] = None, retryable: Optional[bool] = None, + track_id: Optional[str] = None, ) -> None: """ Send error event to client. @@ -394,6 +388,7 @@ class Session: """ resolved_stage = stage or self._infer_error_stage(code) resolved_retryable = retryable if retryable is not None else (resolved_stage in {"asr", "llm", "tts", "tool", "audio"}) + resolved_track_id = track_id or self._error_track_id(resolved_stage, code) await self._send_event( ev( "error", @@ -402,7 +397,7 @@ class Session: message=error_message, stage=resolved_stage, retryable=resolved_retryable, - trackId=self.current_track_id, + trackId=resolved_track_id, data={ "error": { "stage": resolved_stage, @@ -666,6 +661,15 @@ class Session: return "tts" return "protocol" + def _error_track_id(self, stage: str, code: str) -> str: + if stage in {"audio", "asr"}: + return self.TRACK_AUDIO_IN + if stage in {"llm", "tts", "tool"}: + return self.TRACK_AUDIO_OUT + if str(code or "").strip().lower().startswith("auth."): + return self.TRACK_CONTROL + return self.TRACK_CONTROL + def _envelope_event(self, event: Dict[str, Any]) -> Dict[str, Any]: event_type = str(event.get("type") or "") source = str(event.get("source") or self._event_source(event_type)) diff --git a/docs/ws_v1_schema.md b/docs/ws_v1_schema.md index c2a7ab4..1d35e98 100644 --- a/docs/ws_v1_schema.md +++ b/docs/ws_v1_schema.md @@ -2,6 +2,11 @@ This document defines the public WebSocket protocol for the `/ws` endpoint. +Validation policy: +- WS v1 JSON control messages are validated strictly. +- Unknown top-level fields are rejected for all defined client message types. +- `hello.version` is fixed to `"v1"`. + ## Transport - A single WebSocket connection carries: @@ -138,7 +143,8 @@ All server events include an envelope: Envelope notes: - `seq` is monotonically increasing within one session (for replay/resume). -- `source` is one of: `asr | llm | tts | tool | system`. +- `source` is one of: `asr | llm | tts | tool | system | client | server`. + - For `assistant.tool_result`, `source` may be `client` or `server` to indicate execution side. - `data` is structured payload; legacy top-level fields are kept for compatibility. Common events: @@ -181,6 +187,10 @@ Common events: - Fields: `trackId`, `latencyMs` - `error` - Fields: `sender`, `code`, `message`, `trackId` + - `trackId` convention: + - `audio_in` for `stage in {audio, asr}` + - `audio_out` for `stage in {llm, tts, tool}` + - `control` otherwise (including protocol/auth errors) Track IDs (MVP fixed values): - `audio_in`: ASR/VAD input-side events (`input.*`, `transcript.*`) @@ -207,7 +217,7 @@ MVP fixed format: Framing rules: - Binary audio frame unit is 640 bytes. - A WS binary message may carry one or multiple complete 640-byte frames. -- Non-640-multiple payloads are treated as `audio.frame_size_mismatch` protocol errors. +- Non-640-multiple payloads are rejected as `audio.frame_size_mismatch`; that WS message is dropped (no partial buffering/reassembly). TTS boundary events: - `output.audio.start` and `output.audio.end` mark assistant playback boundaries. diff --git a/models/ws_v1.py b/models/ws_v1.py index b8f5524..6e67164 100644 --- a/models/ws_v1.py +++ b/models/ws_v1.py @@ -1,7 +1,8 @@ """WS v1 protocol message models and helpers.""" -from typing import Optional, Dict, Any, Literal -from pydantic import BaseModel, Field +from typing import Any, Dict, Literal, Optional + +from pydantic import BaseModel, ConfigDict, Field, ValidationError def now_ms() -> int: @@ -11,37 +12,66 @@ def now_ms() -> int: return int(time.time() * 1000) +class _StrictModel(BaseModel): + """Protocol models reject unknown fields to enforce WS v1 schema.""" + + model_config = ConfigDict(extra="forbid") + + # Client -> Server messages -class HelloMessage(BaseModel): +class HelloAuth(_StrictModel): + apiKey: Optional[str] = None + jwt: Optional[str] = None + + +class HelloMessage(_StrictModel): type: Literal["hello"] - version: str = Field(..., description="Protocol version, currently v1") - auth: Optional[Dict[str, str]] = Field(default=None, description="Auth payload, e.g. {'apiKey': '...'}") + version: Literal["v1"] = Field(..., description="Protocol version, fixed to v1") + auth: Optional[HelloAuth] = Field(default=None, description="Auth payload") -class SessionStartMessage(BaseModel): +class SessionStartAudio(_StrictModel): + encoding: Literal["pcm_s16le"] = "pcm_s16le" + sample_rate_hz: Literal[16000] = 16000 + channels: Literal[1] = 1 + + +class SessionStartMessage(_StrictModel): type: Literal["session.start"] - audio: Optional[Dict[str, Any]] = Field(default=None, description="Optional audio format metadata") + audio: Optional[SessionStartAudio] = Field(default=None, description="Optional audio format metadata") metadata: Optional[Dict[str, Any]] = Field(default=None, description="Optional session metadata") -class SessionStopMessage(BaseModel): +class SessionStopMessage(_StrictModel): type: Literal["session.stop"] reason: Optional[str] = None -class InputTextMessage(BaseModel): +class InputTextMessage(_StrictModel): type: Literal["input.text"] text: str -class ResponseCancelMessage(BaseModel): +class ResponseCancelMessage(_StrictModel): type: Literal["response.cancel"] graceful: bool = False -class ToolCallResultsMessage(BaseModel): +class ToolCallResultStatus(_StrictModel): + code: int + message: str + + +class ToolCallResult(_StrictModel): + tool_call_id: str + name: str + output: Any = None + status: ToolCallResultStatus + + +class ToolCallResultsMessage(_StrictModel): type: Literal["tool_call.results"] - results: list[Dict[str, Any]] = Field(default_factory=list) + results: list[ToolCallResult] = Field(default_factory=list) CLIENT_MESSAGE_TYPES = { @@ -62,7 +92,15 @@ def parse_client_message(data: Dict[str, Any]) -> BaseModel: msg_class = CLIENT_MESSAGE_TYPES.get(msg_type) if not msg_class: raise ValueError(f"Unknown client message type: {msg_type}") - return msg_class(**data) + try: + return msg_class(**data) + except ValidationError as exc: + details = [] + for err in exc.errors(): + loc = ".".join(str(part) for part in err.get("loc", ())) + msg = err.get("msg", "invalid field") + details.append(f"{loc}: {msg}" if loc else msg) + raise ValueError("; ".join(details)) from exc # Server -> Client event helpers diff --git a/tests/test_tool_call_flow.py b/tests/test_tool_call_flow.py index e5f241b..20d264b 100644 --- a/tests/test_tool_call_flow.py +++ b/tests/test_tool_call_flow.py @@ -97,11 +97,11 @@ async def test_ws_message_parses_tool_call_results(): msg = parse_client_message( { "type": "tool_call.results", - "results": [{"tool_call_id": "call_1", "status": {"code": 200, "message": "ok"}}], + "results": [{"tool_call_id": "call_1", "name": "weather", "status": {"code": 200, "message": "ok"}}], } ) assert isinstance(msg, ToolCallResultsMessage) - assert msg.results[0]["tool_call_id"] == "call_1" + assert msg.results[0].tool_call_id == "call_1" @pytest.mark.asyncio