Improve schema

This commit is contained in:
Xin Wang
2026-02-24 05:55:47 +08:00
parent c6c84b5af9
commit 6290fdd60e
4 changed files with 88 additions and 36 deletions

View File

@@ -107,8 +107,6 @@ class Session:
# Track IDs
self.current_track_id: str = self.TRACK_CONTROL
self._event_seq: int = 0
self._audio_ingress_buffer: bytes = b""
self._audio_frame_error_reported: bool = False
self._history_call_id: Optional[str] = None
self._history_turn_index: int = 0
self._history_call_started_mono: Optional[float] = None
@@ -179,23 +177,18 @@ class Session:
return
frame_bytes = self.AUDIO_FRAME_BYTES
self._audio_ingress_buffer += audio_bytes
# Protocol v1 audio framing: 20ms PCM frame (640 bytes).
# Allow aggregated frames in one WS message (multiple of 640).
if len(audio_bytes) % frame_bytes != 0 and not self._audio_frame_error_reported:
self._audio_frame_error_reported = True
if len(audio_bytes) % frame_bytes != 0:
await self._send_error(
"client",
f"Audio frame size should be multiple of {frame_bytes} bytes (20ms PCM)",
f"Audio frame size must be a multiple of {frame_bytes} bytes (20ms PCM)",
"audio.frame_size_mismatch",
stage="audio",
retryable=True,
retryable=False,
)
return
while len(self._audio_ingress_buffer) >= frame_bytes:
frame = self._audio_ingress_buffer[:frame_bytes]
self._audio_ingress_buffer = self._audio_ingress_buffer[frame_bytes:]
for i in range(0, len(audio_bytes), frame_bytes):
frame = audio_bytes[i : i + frame_bytes]
await self.pipeline.process_audio(frame)
except Exception as e:
logger.error(f"Session {self.id} handle_audio error: {e}", exc_info=True)
@@ -246,7 +239,7 @@ class Session:
else:
await self.pipeline.interrupt()
elif isinstance(message, ToolCallResultsMessage):
await self.pipeline.handle_tool_call_results(message.results)
await self.pipeline.handle_tool_call_results([item.model_dump() for item in message.results])
elif isinstance(message, SessionStopMessage):
await self._handle_session_stop(message.reason)
else:
@@ -268,9 +261,9 @@ class Session:
self.ws_state = WsSessionState.STOPPED
return
auth_payload = message.auth or {}
api_key = auth_payload.get("apiKey")
jwt = auth_payload.get("jwt")
auth_payload = message.auth
api_key = auth_payload.apiKey if auth_payload else None
jwt = auth_payload.jwt if auth_payload else None
if settings.ws_api_key:
if api_key != settings.ws_api_key:
@@ -330,7 +323,7 @@ class Session:
"audio_out": self.TRACK_AUDIO_OUT,
"control": self.TRACK_CONTROL,
},
audio=message.audio or {},
audio=message.audio.model_dump() if message.audio else {},
)
)
await self._send_event(
@@ -383,6 +376,7 @@ class Session:
code: str,
stage: Optional[str] = None,
retryable: Optional[bool] = None,
track_id: Optional[str] = None,
) -> None:
"""
Send error event to client.
@@ -394,6 +388,7 @@ class Session:
"""
resolved_stage = stage or self._infer_error_stage(code)
resolved_retryable = retryable if retryable is not None else (resolved_stage in {"asr", "llm", "tts", "tool", "audio"})
resolved_track_id = track_id or self._error_track_id(resolved_stage, code)
await self._send_event(
ev(
"error",
@@ -402,7 +397,7 @@ class Session:
message=error_message,
stage=resolved_stage,
retryable=resolved_retryable,
trackId=self.current_track_id,
trackId=resolved_track_id,
data={
"error": {
"stage": resolved_stage,
@@ -666,6 +661,15 @@ class Session:
return "tts"
return "protocol"
def _error_track_id(self, stage: str, code: str) -> str:
    """Pick the track ID to attach to an error event.

    Input-side stages (``audio``, ``asr``) are reported on
    ``TRACK_AUDIO_IN``; output-side stages (``llm``, ``tts``, ``tool``)
    on ``TRACK_AUDIO_OUT``. Every other stage, including protocol and
    ``auth.*`` errors, is reported on ``TRACK_CONTROL``.

    Args:
        stage: Resolved error stage.
        code: Error code. It does not influence the result today; the
            parameter is kept so existing callers keep working.

    Returns:
        One of the session's track ID constants.
    """
    if stage in {"audio", "asr"}:
        return self.TRACK_AUDIO_IN
    if stage in {"llm", "tts", "tool"}:
        return self.TRACK_AUDIO_OUT
    # The former ``auth.*`` special case returned the same value as this
    # fall-through, so it was removed as redundant.
    return self.TRACK_CONTROL
def _envelope_event(self, event: Dict[str, Any]) -> Dict[str, Any]:
event_type = str(event.get("type") or "")
source = str(event.get("source") or self._event_source(event_type))

View File

@@ -2,6 +2,11 @@
This document defines the public WebSocket protocol for the `/ws` endpoint.
Validation policy:
- WS v1 JSON control messages are validated strictly.
- Unknown top-level fields are rejected for all defined client message types.
- `hello.version` is fixed to `"v1"`.
## Transport
- A single WebSocket connection carries:
@@ -138,7 +143,8 @@ All server events include an envelope:
Envelope notes:
- `seq` is monotonically increasing within one session (for replay/resume).
- `source` is one of: `asr | llm | tts | tool | system`.
- `source` is one of: `asr | llm | tts | tool | system | client | server`.
- For `assistant.tool_result`, `source` may be `client` or `server` to indicate execution side.
- `data` is structured payload; legacy top-level fields are kept for compatibility.
Common events:
@@ -181,6 +187,10 @@ Common events:
- Fields: `trackId`, `latencyMs`
- `error`
- Fields: `sender`, `code`, `message`, `stage`, `retryable`, `trackId`
- `trackId` convention:
- `audio_in` for `stage in {audio, asr}`
- `audio_out` for `stage in {llm, tts, tool}`
- `control` otherwise (including protocol/auth errors)
Track IDs (MVP fixed values):
- `audio_in`: ASR/VAD input-side events (`input.*`, `transcript.*`)
@@ -207,7 +217,7 @@ MVP fixed format:
Framing rules:
- Binary audio frame unit is 640 bytes.
- A WS binary message may carry one or multiple complete 640-byte frames.
- Non-640-multiple payloads are treated as `audio.frame_size_mismatch` protocol errors.
- Non-640-multiple payloads are rejected as `audio.frame_size_mismatch`; that WS message is dropped (no partial buffering/reassembly).
TTS boundary events:
- `output.audio.start` and `output.audio.end` mark assistant playback boundaries.

View File

@@ -1,7 +1,8 @@
"""WS v1 protocol message models and helpers."""
from typing import Optional, Dict, Any, Literal
from pydantic import BaseModel, Field
from typing import Any, Dict, Literal, Optional
from pydantic import BaseModel, ConfigDict, Field, ValidationError
def now_ms() -> int:
@@ -11,37 +12,66 @@ def now_ms() -> int:
return int(time.time() * 1000)
class _StrictModel(BaseModel):
"""Protocol models reject unknown fields to enforce WS v1 schema."""
# extra="forbid" makes validation fail when a payload carries a field the
# model does not declare, instead of silently ignoring it.
model_config = ConfigDict(extra="forbid")
# Client -> Server messages
class HelloMessage(BaseModel):
class HelloAuth(_StrictModel):
# Credentials optionally supplied in `hello.auth`; both fields may be omitted.
# apiKey: the session compares it with settings.ws_api_key when that setting is set.
apiKey: Optional[str] = None
# jwt: read by the session during hello handling; how it is verified is not
# visible in this view.
jwt: Optional[str] = None
class HelloMessage(_StrictModel):
type: Literal["hello"]
version: str = Field(..., description="Protocol version, currently v1")
auth: Optional[Dict[str, str]] = Field(default=None, description="Auth payload, e.g. {'apiKey': '...'}")
version: Literal["v1"] = Field(..., description="Protocol version, fixed to v1")
auth: Optional[HelloAuth] = Field(default=None, description="Auth payload")
class SessionStartMessage(BaseModel):
class SessionStartAudio(_StrictModel):
# Audio format metadata for `session.start`. Every field is a single-value
# Literal, so the default is also the only value that validates.
encoding: Literal["pcm_s16le"] = "pcm_s16le"
sample_rate_hz: Literal[16000] = 16000
channels: Literal[1] = 1
class SessionStartMessage(_StrictModel):
type: Literal["session.start"]
audio: Optional[Dict[str, Any]] = Field(default=None, description="Optional audio format metadata")
audio: Optional[SessionStartAudio] = Field(default=None, description="Optional audio format metadata")
metadata: Optional[Dict[str, Any]] = Field(default=None, description="Optional session metadata")
class SessionStopMessage(BaseModel):
class SessionStopMessage(_StrictModel):
type: Literal["session.stop"]
reason: Optional[str] = None
class InputTextMessage(BaseModel):
class InputTextMessage(_StrictModel):
type: Literal["input.text"]
text: str
class ResponseCancelMessage(BaseModel):
class ResponseCancelMessage(_StrictModel):
type: Literal["response.cancel"]
graceful: bool = False
class ToolCallResultsMessage(BaseModel):
class ToolCallResultStatus(_StrictModel):
# Outcome of a single tool call; both fields are required.
# NOTE(review): the test payload uses code=200 / message="ok", so `code`
# looks HTTP-like — confirm the intended code space.
code: int
message: str
class ToolCallResult(_StrictModel):
# One entry of `tool_call.results`. `tool_call_id`, `name` and `status` are
# required; `output` accepts any value and defaults to None when omitted.
tool_call_id: str
name: str
output: Any = None
status: ToolCallResultStatus
class ToolCallResultsMessage(_StrictModel):
type: Literal["tool_call.results"]
results: list[Dict[str, Any]] = Field(default_factory=list)
results: list[ToolCallResult] = Field(default_factory=list)
CLIENT_MESSAGE_TYPES = {
@@ -62,7 +92,15 @@ def parse_client_message(data: Dict[str, Any]) -> BaseModel:
msg_class = CLIENT_MESSAGE_TYPES.get(msg_type)
if not msg_class:
raise ValueError(f"Unknown client message type: {msg_type}")
return msg_class(**data)
try:
return msg_class(**data)
except ValidationError as exc:
details = []
for err in exc.errors():
loc = ".".join(str(part) for part in err.get("loc", ()))
msg = err.get("msg", "invalid field")
details.append(f"{loc}: {msg}" if loc else msg)
raise ValueError("; ".join(details)) from exc
# Server -> Client event helpers

View File

@@ -97,11 +97,11 @@ async def test_ws_message_parses_tool_call_results():
msg = parse_client_message(
{
"type": "tool_call.results",
"results": [{"tool_call_id": "call_1", "status": {"code": 200, "message": "ok"}}],
"results": [{"tool_call_id": "call_1", "name": "weather", "status": {"code": 200, "message": "ok"}}],
}
)
assert isinstance(msg, ToolCallResultsMessage)
assert msg.results[0]["tool_call_id"] == "call_1"
assert msg.results[0].tool_call_id == "call_1"
@pytest.mark.asyncio