Improve schema
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
"""WS v1 protocol message models and helpers."""
|
||||
|
||||
from typing import Optional, Dict, Any, Literal
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationError
|
||||
|
||||
|
||||
def now_ms() -> int:
|
||||
@@ -11,37 +12,66 @@ def now_ms() -> int:
|
||||
return int(time.time() * 1000)
|
||||
|
||||
|
||||
class _StrictModel(BaseModel):
|
||||
"""Protocol models reject unknown fields to enforce WS v1 schema."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
# Client -> Server messages
|
||||
class HelloMessage(BaseModel):
|
||||
class HelloAuth(_StrictModel):
|
||||
apiKey: Optional[str] = None
|
||||
jwt: Optional[str] = None
|
||||
|
||||
|
||||
class HelloMessage(_StrictModel):
|
||||
type: Literal["hello"]
|
||||
version: str = Field(..., description="Protocol version, currently v1")
|
||||
auth: Optional[Dict[str, str]] = Field(default=None, description="Auth payload, e.g. {'apiKey': '...'}")
|
||||
version: Literal["v1"] = Field(..., description="Protocol version, fixed to v1")
|
||||
auth: Optional[HelloAuth] = Field(default=None, description="Auth payload")
|
||||
|
||||
|
||||
class SessionStartMessage(BaseModel):
|
||||
class SessionStartAudio(_StrictModel):
|
||||
encoding: Literal["pcm_s16le"] = "pcm_s16le"
|
||||
sample_rate_hz: Literal[16000] = 16000
|
||||
channels: Literal[1] = 1
|
||||
|
||||
|
||||
class SessionStartMessage(_StrictModel):
|
||||
type: Literal["session.start"]
|
||||
audio: Optional[Dict[str, Any]] = Field(default=None, description="Optional audio format metadata")
|
||||
audio: Optional[SessionStartAudio] = Field(default=None, description="Optional audio format metadata")
|
||||
metadata: Optional[Dict[str, Any]] = Field(default=None, description="Optional session metadata")
|
||||
|
||||
|
||||
class SessionStopMessage(BaseModel):
|
||||
class SessionStopMessage(_StrictModel):
|
||||
type: Literal["session.stop"]
|
||||
reason: Optional[str] = None
|
||||
|
||||
|
||||
class InputTextMessage(BaseModel):
|
||||
class InputTextMessage(_StrictModel):
|
||||
type: Literal["input.text"]
|
||||
text: str
|
||||
|
||||
|
||||
class ResponseCancelMessage(BaseModel):
|
||||
class ResponseCancelMessage(_StrictModel):
|
||||
type: Literal["response.cancel"]
|
||||
graceful: bool = False
|
||||
|
||||
|
||||
class ToolCallResultsMessage(BaseModel):
|
||||
class ToolCallResultStatus(_StrictModel):
|
||||
code: int
|
||||
message: str
|
||||
|
||||
|
||||
class ToolCallResult(_StrictModel):
|
||||
tool_call_id: str
|
||||
name: str
|
||||
output: Any = None
|
||||
status: ToolCallResultStatus
|
||||
|
||||
|
||||
class ToolCallResultsMessage(_StrictModel):
|
||||
type: Literal["tool_call.results"]
|
||||
results: list[Dict[str, Any]] = Field(default_factory=list)
|
||||
results: list[ToolCallResult] = Field(default_factory=list)
|
||||
|
||||
|
||||
CLIENT_MESSAGE_TYPES = {
|
||||
@@ -62,7 +92,15 @@ def parse_client_message(data: Dict[str, Any]) -> BaseModel:
|
||||
msg_class = CLIENT_MESSAGE_TYPES.get(msg_type)
|
||||
if not msg_class:
|
||||
raise ValueError(f"Unknown client message type: {msg_type}")
|
||||
return msg_class(**data)
|
||||
try:
|
||||
return msg_class(**data)
|
||||
except ValidationError as exc:
|
||||
details = []
|
||||
for err in exc.errors():
|
||||
loc = ".".join(str(part) for part in err.get("loc", ()))
|
||||
msg = err.get("msg", "invalid field")
|
||||
details.append(f"{loc}: {msg}" if loc else msg)
|
||||
raise ValueError("; ".join(details)) from exc
|
||||
|
||||
|
||||
# Server -> Client event helpers
|
||||
|
||||
Reference in New Issue
Block a user