Improve schema
This commit is contained in:
@@ -107,8 +107,6 @@ class Session:
|
|||||||
# Track IDs
|
# Track IDs
|
||||||
self.current_track_id: str = self.TRACK_CONTROL
|
self.current_track_id: str = self.TRACK_CONTROL
|
||||||
self._event_seq: int = 0
|
self._event_seq: int = 0
|
||||||
self._audio_ingress_buffer: bytes = b""
|
|
||||||
self._audio_frame_error_reported: bool = False
|
|
||||||
self._history_call_id: Optional[str] = None
|
self._history_call_id: Optional[str] = None
|
||||||
self._history_turn_index: int = 0
|
self._history_turn_index: int = 0
|
||||||
self._history_call_started_mono: Optional[float] = None
|
self._history_call_started_mono: Optional[float] = None
|
||||||
@@ -179,23 +177,18 @@ class Session:
|
|||||||
return
|
return
|
||||||
|
|
||||||
frame_bytes = self.AUDIO_FRAME_BYTES
|
frame_bytes = self.AUDIO_FRAME_BYTES
|
||||||
self._audio_ingress_buffer += audio_bytes
|
if len(audio_bytes) % frame_bytes != 0:
|
||||||
|
|
||||||
# Protocol v1 audio framing: 20ms PCM frame (640 bytes).
|
|
||||||
# Allow aggregated frames in one WS message (multiple of 640).
|
|
||||||
if len(audio_bytes) % frame_bytes != 0 and not self._audio_frame_error_reported:
|
|
||||||
self._audio_frame_error_reported = True
|
|
||||||
await self._send_error(
|
await self._send_error(
|
||||||
"client",
|
"client",
|
||||||
f"Audio frame size should be multiple of {frame_bytes} bytes (20ms PCM)",
|
f"Audio frame size must be a multiple of {frame_bytes} bytes (20ms PCM)",
|
||||||
"audio.frame_size_mismatch",
|
"audio.frame_size_mismatch",
|
||||||
stage="audio",
|
stage="audio",
|
||||||
retryable=True,
|
retryable=False,
|
||||||
)
|
)
|
||||||
|
return
|
||||||
|
|
||||||
while len(self._audio_ingress_buffer) >= frame_bytes:
|
for i in range(0, len(audio_bytes), frame_bytes):
|
||||||
frame = self._audio_ingress_buffer[:frame_bytes]
|
frame = audio_bytes[i : i + frame_bytes]
|
||||||
self._audio_ingress_buffer = self._audio_ingress_buffer[frame_bytes:]
|
|
||||||
await self.pipeline.process_audio(frame)
|
await self.pipeline.process_audio(frame)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Session {self.id} handle_audio error: {e}", exc_info=True)
|
logger.error(f"Session {self.id} handle_audio error: {e}", exc_info=True)
|
||||||
@@ -246,7 +239,7 @@ class Session:
|
|||||||
else:
|
else:
|
||||||
await self.pipeline.interrupt()
|
await self.pipeline.interrupt()
|
||||||
elif isinstance(message, ToolCallResultsMessage):
|
elif isinstance(message, ToolCallResultsMessage):
|
||||||
await self.pipeline.handle_tool_call_results(message.results)
|
await self.pipeline.handle_tool_call_results([item.model_dump() for item in message.results])
|
||||||
elif isinstance(message, SessionStopMessage):
|
elif isinstance(message, SessionStopMessage):
|
||||||
await self._handle_session_stop(message.reason)
|
await self._handle_session_stop(message.reason)
|
||||||
else:
|
else:
|
||||||
@@ -268,9 +261,9 @@ class Session:
|
|||||||
self.ws_state = WsSessionState.STOPPED
|
self.ws_state = WsSessionState.STOPPED
|
||||||
return
|
return
|
||||||
|
|
||||||
auth_payload = message.auth or {}
|
auth_payload = message.auth
|
||||||
api_key = auth_payload.get("apiKey")
|
api_key = auth_payload.apiKey if auth_payload else None
|
||||||
jwt = auth_payload.get("jwt")
|
jwt = auth_payload.jwt if auth_payload else None
|
||||||
|
|
||||||
if settings.ws_api_key:
|
if settings.ws_api_key:
|
||||||
if api_key != settings.ws_api_key:
|
if api_key != settings.ws_api_key:
|
||||||
@@ -330,7 +323,7 @@ class Session:
|
|||||||
"audio_out": self.TRACK_AUDIO_OUT,
|
"audio_out": self.TRACK_AUDIO_OUT,
|
||||||
"control": self.TRACK_CONTROL,
|
"control": self.TRACK_CONTROL,
|
||||||
},
|
},
|
||||||
audio=message.audio or {},
|
audio=message.audio.model_dump() if message.audio else {},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
await self._send_event(
|
await self._send_event(
|
||||||
@@ -383,6 +376,7 @@ class Session:
|
|||||||
code: str,
|
code: str,
|
||||||
stage: Optional[str] = None,
|
stage: Optional[str] = None,
|
||||||
retryable: Optional[bool] = None,
|
retryable: Optional[bool] = None,
|
||||||
|
track_id: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Send error event to client.
|
Send error event to client.
|
||||||
@@ -394,6 +388,7 @@ class Session:
|
|||||||
"""
|
"""
|
||||||
resolved_stage = stage or self._infer_error_stage(code)
|
resolved_stage = stage or self._infer_error_stage(code)
|
||||||
resolved_retryable = retryable if retryable is not None else (resolved_stage in {"asr", "llm", "tts", "tool", "audio"})
|
resolved_retryable = retryable if retryable is not None else (resolved_stage in {"asr", "llm", "tts", "tool", "audio"})
|
||||||
|
resolved_track_id = track_id or self._error_track_id(resolved_stage, code)
|
||||||
await self._send_event(
|
await self._send_event(
|
||||||
ev(
|
ev(
|
||||||
"error",
|
"error",
|
||||||
@@ -402,7 +397,7 @@ class Session:
|
|||||||
message=error_message,
|
message=error_message,
|
||||||
stage=resolved_stage,
|
stage=resolved_stage,
|
||||||
retryable=resolved_retryable,
|
retryable=resolved_retryable,
|
||||||
trackId=self.current_track_id,
|
trackId=resolved_track_id,
|
||||||
data={
|
data={
|
||||||
"error": {
|
"error": {
|
||||||
"stage": resolved_stage,
|
"stage": resolved_stage,
|
||||||
@@ -666,6 +661,15 @@ class Session:
|
|||||||
return "tts"
|
return "tts"
|
||||||
return "protocol"
|
return "protocol"
|
||||||
|
|
||||||
|
def _error_track_id(self, stage: str, code: str) -> str:
    """Choose the track ID to report on an error event.

    Args:
        stage: Resolved error stage (e.g. ``"audio"``, ``"llm"``).
        code: Error code. Currently unused: ``auth.*`` codes, like every
            other non-audio stage, resolve to the control track. Kept so
            callers do not need to change.

    Returns:
        ``TRACK_AUDIO_IN`` for input-side stages (``audio``, ``asr``),
        ``TRACK_AUDIO_OUT`` for output-side stages (``llm``, ``tts``,
        ``tool``), and ``TRACK_CONTROL`` otherwise.
    """
    if stage in {"audio", "asr"}:
        return self.TRACK_AUDIO_IN
    if stage in {"llm", "tts", "tool"}:
        return self.TRACK_AUDIO_OUT
    # The former `auth.` prefix check returned the same value as this
    # fall-through, so it was redundant and has been removed.
    return self.TRACK_CONTROL
|
||||||
|
|
||||||
def _envelope_event(self, event: Dict[str, Any]) -> Dict[str, Any]:
|
def _envelope_event(self, event: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
event_type = str(event.get("type") or "")
|
event_type = str(event.get("type") or "")
|
||||||
source = str(event.get("source") or self._event_source(event_type))
|
source = str(event.get("source") or self._event_source(event_type))
|
||||||
|
|||||||
@@ -2,6 +2,11 @@
|
|||||||
|
|
||||||
This document defines the public WebSocket protocol for the `/ws` endpoint.
|
This document defines the public WebSocket protocol for the `/ws` endpoint.
|
||||||
|
|
||||||
|
Validation policy:
|
||||||
|
- WS v1 JSON control messages are validated strictly.
|
||||||
|
- Unknown top-level fields are rejected for all defined client message types.
|
||||||
|
- `hello.version` is fixed to `"v1"`.
|
||||||
|
|
||||||
## Transport
|
## Transport
|
||||||
|
|
||||||
- A single WebSocket connection carries:
|
- A single WebSocket connection carries:
|
||||||
@@ -138,7 +143,8 @@ All server events include an envelope:
|
|||||||
|
|
||||||
Envelope notes:
|
Envelope notes:
|
||||||
- `seq` is monotonically increasing within one session (for replay/resume).
|
- `seq` is monotonically increasing within one session (for replay/resume).
|
||||||
- `source` is one of: `asr | llm | tts | tool | system`.
|
- `source` is one of: `asr | llm | tts | tool | system | client | server`.
|
||||||
|
- For `assistant.tool_result`, `source` may be `client` or `server` to indicate execution side.
|
||||||
- `data` is structured payload; legacy top-level fields are kept for compatibility.
|
- `data` is structured payload; legacy top-level fields are kept for compatibility.
|
||||||
|
|
||||||
Common events:
|
Common events:
|
||||||
@@ -181,6 +187,10 @@ Common events:
|
|||||||
- Fields: `trackId`, `latencyMs`
|
- Fields: `trackId`, `latencyMs`
|
||||||
- `error`
|
- `error`
|
||||||
- Fields: `sender`, `code`, `message`, `trackId`
|
- Fields: `sender`, `code`, `message`, `trackId`
|
||||||
|
- `trackId` convention:
|
||||||
|
- `audio_in` for `stage in {audio, asr}`
|
||||||
|
- `audio_out` for `stage in {llm, tts, tool}`
|
||||||
|
- `control` otherwise (including protocol/auth errors)
|
||||||
|
|
||||||
Track IDs (MVP fixed values):
|
Track IDs (MVP fixed values):
|
||||||
- `audio_in`: ASR/VAD input-side events (`input.*`, `transcript.*`)
|
- `audio_in`: ASR/VAD input-side events (`input.*`, `transcript.*`)
|
||||||
@@ -207,7 +217,7 @@ MVP fixed format:
|
|||||||
Framing rules:
|
Framing rules:
|
||||||
- Binary audio frame unit is 640 bytes.
|
- Binary audio frame unit is 640 bytes.
|
||||||
- A WS binary message may carry one or multiple complete 640-byte frames.
|
- A WS binary message may carry one or multiple complete 640-byte frames.
|
||||||
- Non-640-multiple payloads are treated as `audio.frame_size_mismatch` protocol errors.
|
- Non-640-multiple payloads are rejected as `audio.frame_size_mismatch`; that WS message is dropped (no partial buffering/reassembly).
|
||||||
|
|
||||||
TTS boundary events:
|
TTS boundary events:
|
||||||
- `output.audio.start` and `output.audio.end` mark assistant playback boundaries.
|
- `output.audio.start` and `output.audio.end` mark assistant playback boundaries.
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
"""WS v1 protocol message models and helpers."""
|
"""WS v1 protocol message models and helpers."""
|
||||||
|
|
||||||
from typing import Optional, Dict, Any, Literal
|
from typing import Any, Dict, Literal, Optional
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, Field, ValidationError
|
||||||
|
|
||||||
|
|
||||||
def now_ms() -> int:
|
def now_ms() -> int:
|
||||||
@@ -11,37 +12,66 @@ def now_ms() -> int:
|
|||||||
return int(time.time() * 1000)
|
return int(time.time() * 1000)
|
||||||
|
|
||||||
|
|
||||||
|
class _StrictModel(BaseModel):
    """Base class for WS v1 protocol models; unknown fields are rejected."""

    # extra="forbid" turns any undeclared field into a validation error.
    model_config = ConfigDict(extra="forbid")
|
||||||
|
|
||||||
|
|
||||||
# Client -> Server messages
|
# Client -> Server messages
|
||||||
class HelloMessage(BaseModel):
|
class HelloAuth(_StrictModel):
    """Optional credentials carried in the ``auth`` field of ``hello``."""

    # Both credentials are optional; either may be omitted.
    apiKey: Optional[str] = None
    jwt: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class HelloMessage(_StrictModel):
|
||||||
type: Literal["hello"]
|
type: Literal["hello"]
|
||||||
version: str = Field(..., description="Protocol version, currently v1")
|
version: Literal["v1"] = Field(..., description="Protocol version, fixed to v1")
|
||||||
auth: Optional[Dict[str, str]] = Field(default=None, description="Auth payload, e.g. {'apiKey': '...'}")
|
auth: Optional[HelloAuth] = Field(default=None, description="Auth payload")
|
||||||
|
|
||||||
|
|
||||||
class SessionStartMessage(BaseModel):
|
class SessionStartAudio(_StrictModel):
    """Audio format metadata accepted in ``session.start``.

    Each field is a single-value ``Literal`` with a matching default, so
    any other value fails validation.
    """

    encoding: Literal["pcm_s16le"] = "pcm_s16le"
    sample_rate_hz: Literal[16000] = 16000
    channels: Literal[1] = 1
|
||||||
|
|
||||||
|
|
||||||
|
class SessionStartMessage(_StrictModel):
|
||||||
type: Literal["session.start"]
|
type: Literal["session.start"]
|
||||||
audio: Optional[Dict[str, Any]] = Field(default=None, description="Optional audio format metadata")
|
audio: Optional[SessionStartAudio] = Field(default=None, description="Optional audio format metadata")
|
||||||
metadata: Optional[Dict[str, Any]] = Field(default=None, description="Optional session metadata")
|
metadata: Optional[Dict[str, Any]] = Field(default=None, description="Optional session metadata")
|
||||||
|
|
||||||
|
|
||||||
class SessionStopMessage(BaseModel):
|
class SessionStopMessage(_StrictModel):
|
||||||
type: Literal["session.stop"]
|
type: Literal["session.stop"]
|
||||||
reason: Optional[str] = None
|
reason: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class InputTextMessage(BaseModel):
|
class InputTextMessage(_StrictModel):
|
||||||
type: Literal["input.text"]
|
type: Literal["input.text"]
|
||||||
text: str
|
text: str
|
||||||
|
|
||||||
|
|
||||||
class ResponseCancelMessage(BaseModel):
|
class ResponseCancelMessage(_StrictModel):
|
||||||
type: Literal["response.cancel"]
|
type: Literal["response.cancel"]
|
||||||
graceful: bool = False
|
graceful: bool = False
|
||||||
|
|
||||||
|
|
||||||
class ToolCallResultsMessage(BaseModel):
|
class ToolCallResultStatus(_StrictModel):
    """Status of one tool call result: a numeric code and a message."""

    code: int
    message: str
|
||||||
|
|
||||||
|
|
||||||
|
class ToolCallResult(_StrictModel):
    """One entry in the ``results`` list of a ``tool_call.results`` message."""

    # Required fields.
    tool_call_id: str
    name: str
    # Optional payload of any type; defaults to None when omitted.
    output: Any = None
    # Required status object (code + message).
    status: ToolCallResultStatus
|
||||||
|
|
||||||
|
|
||||||
|
class ToolCallResultsMessage(_StrictModel):
|
||||||
type: Literal["tool_call.results"]
|
type: Literal["tool_call.results"]
|
||||||
results: list[Dict[str, Any]] = Field(default_factory=list)
|
results: list[ToolCallResult] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
CLIENT_MESSAGE_TYPES = {
|
CLIENT_MESSAGE_TYPES = {
|
||||||
@@ -62,7 +92,15 @@ def parse_client_message(data: Dict[str, Any]) -> BaseModel:
|
|||||||
msg_class = CLIENT_MESSAGE_TYPES.get(msg_type)
|
msg_class = CLIENT_MESSAGE_TYPES.get(msg_type)
|
||||||
if not msg_class:
|
if not msg_class:
|
||||||
raise ValueError(f"Unknown client message type: {msg_type}")
|
raise ValueError(f"Unknown client message type: {msg_type}")
|
||||||
|
try:
|
||||||
return msg_class(**data)
|
return msg_class(**data)
|
||||||
|
except ValidationError as exc:
|
||||||
|
details = []
|
||||||
|
for err in exc.errors():
|
||||||
|
loc = ".".join(str(part) for part in err.get("loc", ()))
|
||||||
|
msg = err.get("msg", "invalid field")
|
||||||
|
details.append(f"{loc}: {msg}" if loc else msg)
|
||||||
|
raise ValueError("; ".join(details)) from exc
|
||||||
|
|
||||||
|
|
||||||
# Server -> Client event helpers
|
# Server -> Client event helpers
|
||||||
|
|||||||
@@ -97,11 +97,11 @@ async def test_ws_message_parses_tool_call_results():
|
|||||||
msg = parse_client_message(
|
msg = parse_client_message(
|
||||||
{
|
{
|
||||||
"type": "tool_call.results",
|
"type": "tool_call.results",
|
||||||
"results": [{"tool_call_id": "call_1", "status": {"code": 200, "message": "ok"}}],
|
"results": [{"tool_call_id": "call_1", "name": "weather", "status": {"code": 200, "message": "ok"}}],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
assert isinstance(msg, ToolCallResultsMessage)
|
assert isinstance(msg, ToolCallResultsMessage)
|
||||||
assert msg.results[0]["tool_call_id"] == "call_1"
|
assert msg.results[0].tool_call_id == "call_1"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
Reference in New Issue
Block a user