Unify db api

This commit is contained in:
Xin Wang
2026-02-26 01:58:39 +08:00
parent 56f8aa2191
commit 72ed7d0512
40 changed files with 3926 additions and 593 deletions

View File

@@ -1,7 +1,8 @@
"""WS v1 protocol message models and helpers."""
from typing import Optional, Dict, Any, Literal
from pydantic import BaseModel, Field
from typing import Any, Dict, Literal, Optional
from pydantic import BaseModel, ConfigDict, Field, ValidationError
def now_ms() -> int:
@@ -11,37 +12,66 @@ def now_ms() -> int:
return int(time.time() * 1000)
class _StrictModel(BaseModel):
"""Protocol models reject unknown fields to enforce WS v1 schema."""
model_config = ConfigDict(extra="forbid")
# Client -> Server messages
class HelloMessage(BaseModel):
class HelloAuth(_StrictModel):
apiKey: Optional[str] = None
jwt: Optional[str] = None
class HelloMessage(_StrictModel):
type: Literal["hello"]
version: str = Field(..., description="Protocol version, currently v1")
auth: Optional[Dict[str, str]] = Field(default=None, description="Auth payload, e.g. {'apiKey': '...'}")
version: Literal["v1"] = Field(..., description="Protocol version, fixed to v1")
auth: Optional[HelloAuth] = Field(default=None, description="Auth payload")
class SessionStartMessage(BaseModel):
class SessionStartAudio(_StrictModel):
encoding: Literal["pcm_s16le"] = "pcm_s16le"
sample_rate_hz: Literal[16000] = 16000
channels: Literal[1] = 1
class SessionStartMessage(_StrictModel):
type: Literal["session.start"]
audio: Optional[Dict[str, Any]] = Field(default=None, description="Optional audio format metadata")
audio: Optional[SessionStartAudio] = Field(default=None, description="Optional audio format metadata")
metadata: Optional[Dict[str, Any]] = Field(default=None, description="Optional session metadata")
class SessionStopMessage(BaseModel):
class SessionStopMessage(_StrictModel):
type: Literal["session.stop"]
reason: Optional[str] = None
class InputTextMessage(BaseModel):
class InputTextMessage(_StrictModel):
type: Literal["input.text"]
text: str
class ResponseCancelMessage(BaseModel):
class ResponseCancelMessage(_StrictModel):
type: Literal["response.cancel"]
graceful: bool = False
class ToolCallResultsMessage(BaseModel):
class ToolCallResultStatus(_StrictModel):
code: int
message: str
class ToolCallResult(_StrictModel):
tool_call_id: str
name: str
output: Any = None
status: ToolCallResultStatus
class ToolCallResultsMessage(_StrictModel):
type: Literal["tool_call.results"]
results: list[Dict[str, Any]] = Field(default_factory=list)
results: list[ToolCallResult] = Field(default_factory=list)
CLIENT_MESSAGE_TYPES = {
@@ -62,7 +92,15 @@ def parse_client_message(data: Dict[str, Any]) -> BaseModel:
msg_class = CLIENT_MESSAGE_TYPES.get(msg_type)
if not msg_class:
raise ValueError(f"Unknown client message type: {msg_type}")
return msg_class(**data)
try:
return msg_class(**data)
except ValidationError as exc:
details = []
for err in exc.errors():
loc = ".".join(str(part) for part in err.get("loc", ()))
msg = err.get("msg", "invalid field")
details.append(f"{loc}: {msg}" if loc else msg)
raise ValueError("; ".join(details)) from exc
# Server -> Client event helpers