Improve schema

This commit is contained in:
Xin Wang
2026-02-24 05:55:47 +08:00
parent c6c84b5af9
commit 6290fdd60e
4 changed files with 88 additions and 36 deletions

View File

@@ -1,7 +1,8 @@
"""WS v1 protocol message models and helpers."""
from typing import Optional, Dict, Any, Literal
from pydantic import BaseModel, Field
from typing import Any, Dict, Literal, Optional
from pydantic import BaseModel, ConfigDict, Field, ValidationError
def now_ms() -> int:
@@ -11,37 +12,66 @@ def now_ms() -> int:
return int(time.time() * 1000)
class _StrictModel(BaseModel):
"""Protocol models reject unknown fields to enforce WS v1 schema."""
model_config = ConfigDict(extra="forbid")
# Client -> Server messages
class HelloMessage(BaseModel):
class HelloAuth(_StrictModel):
apiKey: Optional[str] = None
jwt: Optional[str] = None
class HelloMessage(_StrictModel):
type: Literal["hello"]
version: str = Field(..., description="Protocol version, currently v1")
auth: Optional[Dict[str, str]] = Field(default=None, description="Auth payload, e.g. {'apiKey': '...'}")
version: Literal["v1"] = Field(..., description="Protocol version, fixed to v1")
auth: Optional[HelloAuth] = Field(default=None, description="Auth payload")
class SessionStartMessage(BaseModel):
class SessionStartAudio(_StrictModel):
encoding: Literal["pcm_s16le"] = "pcm_s16le"
sample_rate_hz: Literal[16000] = 16000
channels: Literal[1] = 1
class SessionStartMessage(_StrictModel):
type: Literal["session.start"]
audio: Optional[Dict[str, Any]] = Field(default=None, description="Optional audio format metadata")
audio: Optional[SessionStartAudio] = Field(default=None, description="Optional audio format metadata")
metadata: Optional[Dict[str, Any]] = Field(default=None, description="Optional session metadata")
class SessionStopMessage(BaseModel):
class SessionStopMessage(_StrictModel):
type: Literal["session.stop"]
reason: Optional[str] = None
class InputTextMessage(BaseModel):
class InputTextMessage(_StrictModel):
type: Literal["input.text"]
text: str
class ResponseCancelMessage(BaseModel):
class ResponseCancelMessage(_StrictModel):
type: Literal["response.cancel"]
graceful: bool = False
class ToolCallResultsMessage(BaseModel):
class ToolCallResultStatus(_StrictModel):
code: int
message: str
class ToolCallResult(_StrictModel):
tool_call_id: str
name: str
output: Any = None
status: ToolCallResultStatus
class ToolCallResultsMessage(_StrictModel):
type: Literal["tool_call.results"]
results: list[Dict[str, Any]] = Field(default_factory=list)
results: list[ToolCallResult] = Field(default_factory=list)
CLIENT_MESSAGE_TYPES = {
@@ -62,7 +92,15 @@ def parse_client_message(data: Dict[str, Any]) -> BaseModel:
msg_class = CLIENT_MESSAGE_TYPES.get(msg_type)
if not msg_class:
raise ValueError(f"Unknown client message type: {msg_type}")
return msg_class(**data)
try:
return msg_class(**data)
except ValidationError as exc:
details = []
for err in exc.errors():
loc = ".".join(str(part) for part in err.get("loc", ()))
msg = err.get("msg", "invalid field")
details.append(f"{loc}: {msg}" if loc else msg)
raise ValueError("; ".join(details)) from exc
# Server -> Client event helpers