Improve schema

This commit is contained in:
Xin Wang
2026-02-24 05:55:47 +08:00
parent c6c84b5af9
commit 6290fdd60e
4 changed files with 88 additions and 36 deletions

View File

@@ -107,8 +107,6 @@ class Session:
# Track IDs # Track IDs
self.current_track_id: str = self.TRACK_CONTROL self.current_track_id: str = self.TRACK_CONTROL
self._event_seq: int = 0 self._event_seq: int = 0
self._audio_ingress_buffer: bytes = b""
self._audio_frame_error_reported: bool = False
self._history_call_id: Optional[str] = None self._history_call_id: Optional[str] = None
self._history_turn_index: int = 0 self._history_turn_index: int = 0
self._history_call_started_mono: Optional[float] = None self._history_call_started_mono: Optional[float] = None
@@ -179,23 +177,18 @@ class Session:
return return
frame_bytes = self.AUDIO_FRAME_BYTES frame_bytes = self.AUDIO_FRAME_BYTES
self._audio_ingress_buffer += audio_bytes if len(audio_bytes) % frame_bytes != 0:
# Protocol v1 audio framing: 20ms PCM frame (640 bytes).
# Allow aggregated frames in one WS message (multiple of 640).
if len(audio_bytes) % frame_bytes != 0 and not self._audio_frame_error_reported:
self._audio_frame_error_reported = True
await self._send_error( await self._send_error(
"client", "client",
f"Audio frame size should be multiple of {frame_bytes} bytes (20ms PCM)", f"Audio frame size must be a multiple of {frame_bytes} bytes (20ms PCM)",
"audio.frame_size_mismatch", "audio.frame_size_mismatch",
stage="audio", stage="audio",
retryable=True, retryable=False,
) )
return
while len(self._audio_ingress_buffer) >= frame_bytes: for i in range(0, len(audio_bytes), frame_bytes):
frame = self._audio_ingress_buffer[:frame_bytes] frame = audio_bytes[i : i + frame_bytes]
self._audio_ingress_buffer = self._audio_ingress_buffer[frame_bytes:]
await self.pipeline.process_audio(frame) await self.pipeline.process_audio(frame)
except Exception as e: except Exception as e:
logger.error(f"Session {self.id} handle_audio error: {e}", exc_info=True) logger.error(f"Session {self.id} handle_audio error: {e}", exc_info=True)
@@ -246,7 +239,7 @@ class Session:
else: else:
await self.pipeline.interrupt() await self.pipeline.interrupt()
elif isinstance(message, ToolCallResultsMessage): elif isinstance(message, ToolCallResultsMessage):
await self.pipeline.handle_tool_call_results(message.results) await self.pipeline.handle_tool_call_results([item.model_dump() for item in message.results])
elif isinstance(message, SessionStopMessage): elif isinstance(message, SessionStopMessage):
await self._handle_session_stop(message.reason) await self._handle_session_stop(message.reason)
else: else:
@@ -268,9 +261,9 @@ class Session:
self.ws_state = WsSessionState.STOPPED self.ws_state = WsSessionState.STOPPED
return return
auth_payload = message.auth or {} auth_payload = message.auth
api_key = auth_payload.get("apiKey") api_key = auth_payload.apiKey if auth_payload else None
jwt = auth_payload.get("jwt") jwt = auth_payload.jwt if auth_payload else None
if settings.ws_api_key: if settings.ws_api_key:
if api_key != settings.ws_api_key: if api_key != settings.ws_api_key:
@@ -330,7 +323,7 @@ class Session:
"audio_out": self.TRACK_AUDIO_OUT, "audio_out": self.TRACK_AUDIO_OUT,
"control": self.TRACK_CONTROL, "control": self.TRACK_CONTROL,
}, },
audio=message.audio or {}, audio=message.audio.model_dump() if message.audio else {},
) )
) )
await self._send_event( await self._send_event(
@@ -383,6 +376,7 @@ class Session:
code: str, code: str,
stage: Optional[str] = None, stage: Optional[str] = None,
retryable: Optional[bool] = None, retryable: Optional[bool] = None,
track_id: Optional[str] = None,
) -> None: ) -> None:
""" """
Send error event to client. Send error event to client.
@@ -394,6 +388,7 @@ class Session:
""" """
resolved_stage = stage or self._infer_error_stage(code) resolved_stage = stage or self._infer_error_stage(code)
resolved_retryable = retryable if retryable is not None else (resolved_stage in {"asr", "llm", "tts", "tool", "audio"}) resolved_retryable = retryable if retryable is not None else (resolved_stage in {"asr", "llm", "tts", "tool", "audio"})
resolved_track_id = track_id or self._error_track_id(resolved_stage, code)
await self._send_event( await self._send_event(
ev( ev(
"error", "error",
@@ -402,7 +397,7 @@ class Session:
message=error_message, message=error_message,
stage=resolved_stage, stage=resolved_stage,
retryable=resolved_retryable, retryable=resolved_retryable,
trackId=self.current_track_id, trackId=resolved_track_id,
data={ data={
"error": { "error": {
"stage": resolved_stage, "stage": resolved_stage,
@@ -666,6 +661,15 @@ class Session:
return "tts" return "tts"
return "protocol" return "protocol"
def _error_track_id(self, stage: str, code: str) -> str:
if stage in {"audio", "asr"}:
return self.TRACK_AUDIO_IN
if stage in {"llm", "tts", "tool"}:
return self.TRACK_AUDIO_OUT
if str(code or "").strip().lower().startswith("auth."):
return self.TRACK_CONTROL
return self.TRACK_CONTROL
def _envelope_event(self, event: Dict[str, Any]) -> Dict[str, Any]: def _envelope_event(self, event: Dict[str, Any]) -> Dict[str, Any]:
event_type = str(event.get("type") or "") event_type = str(event.get("type") or "")
source = str(event.get("source") or self._event_source(event_type)) source = str(event.get("source") or self._event_source(event_type))

View File

@@ -2,6 +2,11 @@
This document defines the public WebSocket protocol for the `/ws` endpoint. This document defines the public WebSocket protocol for the `/ws` endpoint.
Validation policy:
- WS v1 JSON control messages are validated strictly.
- Unknown top-level fields are rejected for all defined client message types.
- `hello.version` is fixed to `"v1"`.
## Transport ## Transport
- A single WebSocket connection carries: - A single WebSocket connection carries:
@@ -138,7 +143,8 @@ All server events include an envelope:
Envelope notes: Envelope notes:
- `seq` is monotonically increasing within one session (for replay/resume). - `seq` is monotonically increasing within one session (for replay/resume).
- `source` is one of: `asr | llm | tts | tool | system`. - `source` is one of: `asr | llm | tts | tool | system | client | server`.
- For `assistant.tool_result`, `source` may be `client` or `server` to indicate execution side.
- `data` is structured payload; legacy top-level fields are kept for compatibility. - `data` is structured payload; legacy top-level fields are kept for compatibility.
Common events: Common events:
@@ -181,6 +187,10 @@ Common events:
- Fields: `trackId`, `latencyMs` - Fields: `trackId`, `latencyMs`
- `error` - `error`
- Fields: `sender`, `code`, `message`, `trackId` - Fields: `sender`, `code`, `message`, `trackId`
- `trackId` convention:
- `audio_in` for `stage in {audio, asr}`
- `audio_out` for `stage in {llm, tts, tool}`
- `control` otherwise (including protocol/auth errors)
Track IDs (MVP fixed values): Track IDs (MVP fixed values):
- `audio_in`: ASR/VAD input-side events (`input.*`, `transcript.*`) - `audio_in`: ASR/VAD input-side events (`input.*`, `transcript.*`)
@@ -207,7 +217,7 @@ MVP fixed format:
Framing rules: Framing rules:
- Binary audio frame unit is 640 bytes. - Binary audio frame unit is 640 bytes.
- A WS binary message may carry one or multiple complete 640-byte frames. - A WS binary message may carry one or multiple complete 640-byte frames.
- Non-640-multiple payloads are treated as `audio.frame_size_mismatch` protocol errors. - Non-640-multiple payloads are rejected as `audio.frame_size_mismatch`; that WS message is dropped (no partial buffering/reassembly).
TTS boundary events: TTS boundary events:
- `output.audio.start` and `output.audio.end` mark assistant playback boundaries. - `output.audio.start` and `output.audio.end` mark assistant playback boundaries.

View File

@@ -1,7 +1,8 @@
"""WS v1 protocol message models and helpers.""" """WS v1 protocol message models and helpers."""
from typing import Optional, Dict, Any, Literal from typing import Any, Dict, Literal, Optional
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field, ValidationError
def now_ms() -> int: def now_ms() -> int:
@@ -11,37 +12,66 @@ def now_ms() -> int:
return int(time.time() * 1000) return int(time.time() * 1000)
class _StrictModel(BaseModel):
"""Protocol models reject unknown fields to enforce WS v1 schema."""
model_config = ConfigDict(extra="forbid")
# Client -> Server messages # Client -> Server messages
class HelloMessage(BaseModel): class HelloAuth(_StrictModel):
apiKey: Optional[str] = None
jwt: Optional[str] = None
class HelloMessage(_StrictModel):
type: Literal["hello"] type: Literal["hello"]
version: str = Field(..., description="Protocol version, currently v1") version: Literal["v1"] = Field(..., description="Protocol version, fixed to v1")
auth: Optional[Dict[str, str]] = Field(default=None, description="Auth payload, e.g. {'apiKey': '...'}") auth: Optional[HelloAuth] = Field(default=None, description="Auth payload")
class SessionStartMessage(BaseModel): class SessionStartAudio(_StrictModel):
encoding: Literal["pcm_s16le"] = "pcm_s16le"
sample_rate_hz: Literal[16000] = 16000
channels: Literal[1] = 1
class SessionStartMessage(_StrictModel):
type: Literal["session.start"] type: Literal["session.start"]
audio: Optional[Dict[str, Any]] = Field(default=None, description="Optional audio format metadata") audio: Optional[SessionStartAudio] = Field(default=None, description="Optional audio format metadata")
metadata: Optional[Dict[str, Any]] = Field(default=None, description="Optional session metadata") metadata: Optional[Dict[str, Any]] = Field(default=None, description="Optional session metadata")
class SessionStopMessage(BaseModel): class SessionStopMessage(_StrictModel):
type: Literal["session.stop"] type: Literal["session.stop"]
reason: Optional[str] = None reason: Optional[str] = None
class InputTextMessage(BaseModel): class InputTextMessage(_StrictModel):
type: Literal["input.text"] type: Literal["input.text"]
text: str text: str
class ResponseCancelMessage(BaseModel): class ResponseCancelMessage(_StrictModel):
type: Literal["response.cancel"] type: Literal["response.cancel"]
graceful: bool = False graceful: bool = False
class ToolCallResultsMessage(BaseModel): class ToolCallResultStatus(_StrictModel):
code: int
message: str
class ToolCallResult(_StrictModel):
tool_call_id: str
name: str
output: Any = None
status: ToolCallResultStatus
class ToolCallResultsMessage(_StrictModel):
type: Literal["tool_call.results"] type: Literal["tool_call.results"]
results: list[Dict[str, Any]] = Field(default_factory=list) results: list[ToolCallResult] = Field(default_factory=list)
CLIENT_MESSAGE_TYPES = { CLIENT_MESSAGE_TYPES = {
@@ -62,7 +92,15 @@ def parse_client_message(data: Dict[str, Any]) -> BaseModel:
msg_class = CLIENT_MESSAGE_TYPES.get(msg_type) msg_class = CLIENT_MESSAGE_TYPES.get(msg_type)
if not msg_class: if not msg_class:
raise ValueError(f"Unknown client message type: {msg_type}") raise ValueError(f"Unknown client message type: {msg_type}")
return msg_class(**data) try:
return msg_class(**data)
except ValidationError as exc:
details = []
for err in exc.errors():
loc = ".".join(str(part) for part in err.get("loc", ()))
msg = err.get("msg", "invalid field")
details.append(f"{loc}: {msg}" if loc else msg)
raise ValueError("; ".join(details)) from exc
# Server -> Client event helpers # Server -> Client event helpers

View File

@@ -97,11 +97,11 @@ async def test_ws_message_parses_tool_call_results():
msg = parse_client_message( msg = parse_client_message(
{ {
"type": "tool_call.results", "type": "tool_call.results",
"results": [{"tool_call_id": "call_1", "status": {"code": 200, "message": "ok"}}], "results": [{"tool_call_id": "call_1", "name": "weather", "status": {"code": 200, "message": "ok"}}],
} }
) )
assert isinstance(msg, ToolCallResultsMessage) assert isinstance(msg, ToolCallResultsMessage)
assert msg.results[0]["tool_call_id"] == "call_1" assert msg.results[0].tool_call_id == "call_1"
@pytest.mark.asyncio @pytest.mark.asyncio