Improve schema
This commit is contained in:
@@ -107,8 +107,6 @@ class Session:
|
||||
# Track IDs
|
||||
self.current_track_id: str = self.TRACK_CONTROL
|
||||
self._event_seq: int = 0
|
||||
self._audio_ingress_buffer: bytes = b""
|
||||
self._audio_frame_error_reported: bool = False
|
||||
self._history_call_id: Optional[str] = None
|
||||
self._history_turn_index: int = 0
|
||||
self._history_call_started_mono: Optional[float] = None
|
||||
@@ -179,23 +177,18 @@ class Session:
|
||||
return
|
||||
|
||||
frame_bytes = self.AUDIO_FRAME_BYTES
|
||||
self._audio_ingress_buffer += audio_bytes
|
||||
|
||||
# Protocol v1 audio framing: 20ms PCM frame (640 bytes).
|
||||
# Allow aggregated frames in one WS message (multiple of 640).
|
||||
if len(audio_bytes) % frame_bytes != 0 and not self._audio_frame_error_reported:
|
||||
self._audio_frame_error_reported = True
|
||||
if len(audio_bytes) % frame_bytes != 0:
|
||||
await self._send_error(
|
||||
"client",
|
||||
f"Audio frame size should be multiple of {frame_bytes} bytes (20ms PCM)",
|
||||
f"Audio frame size must be a multiple of {frame_bytes} bytes (20ms PCM)",
|
||||
"audio.frame_size_mismatch",
|
||||
stage="audio",
|
||||
retryable=True,
|
||||
retryable=False,
|
||||
)
|
||||
return
|
||||
|
||||
while len(self._audio_ingress_buffer) >= frame_bytes:
|
||||
frame = self._audio_ingress_buffer[:frame_bytes]
|
||||
self._audio_ingress_buffer = self._audio_ingress_buffer[frame_bytes:]
|
||||
for i in range(0, len(audio_bytes), frame_bytes):
|
||||
frame = audio_bytes[i : i + frame_bytes]
|
||||
await self.pipeline.process_audio(frame)
|
||||
except Exception as e:
|
||||
logger.error(f"Session {self.id} handle_audio error: {e}", exc_info=True)
|
||||
@@ -246,7 +239,7 @@ class Session:
|
||||
else:
|
||||
await self.pipeline.interrupt()
|
||||
elif isinstance(message, ToolCallResultsMessage):
|
||||
await self.pipeline.handle_tool_call_results(message.results)
|
||||
await self.pipeline.handle_tool_call_results([item.model_dump() for item in message.results])
|
||||
elif isinstance(message, SessionStopMessage):
|
||||
await self._handle_session_stop(message.reason)
|
||||
else:
|
||||
@@ -268,9 +261,9 @@ class Session:
|
||||
self.ws_state = WsSessionState.STOPPED
|
||||
return
|
||||
|
||||
auth_payload = message.auth or {}
|
||||
api_key = auth_payload.get("apiKey")
|
||||
jwt = auth_payload.get("jwt")
|
||||
auth_payload = message.auth
|
||||
api_key = auth_payload.apiKey if auth_payload else None
|
||||
jwt = auth_payload.jwt if auth_payload else None
|
||||
|
||||
if settings.ws_api_key:
|
||||
if api_key != settings.ws_api_key:
|
||||
@@ -330,7 +323,7 @@ class Session:
|
||||
"audio_out": self.TRACK_AUDIO_OUT,
|
||||
"control": self.TRACK_CONTROL,
|
||||
},
|
||||
audio=message.audio or {},
|
||||
audio=message.audio.model_dump() if message.audio else {},
|
||||
)
|
||||
)
|
||||
await self._send_event(
|
||||
@@ -383,6 +376,7 @@ class Session:
|
||||
code: str,
|
||||
stage: Optional[str] = None,
|
||||
retryable: Optional[bool] = None,
|
||||
track_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Send error event to client.
|
||||
@@ -394,6 +388,7 @@ class Session:
|
||||
"""
|
||||
resolved_stage = stage or self._infer_error_stage(code)
|
||||
resolved_retryable = retryable if retryable is not None else (resolved_stage in {"asr", "llm", "tts", "tool", "audio"})
|
||||
resolved_track_id = track_id or self._error_track_id(resolved_stage, code)
|
||||
await self._send_event(
|
||||
ev(
|
||||
"error",
|
||||
@@ -402,7 +397,7 @@ class Session:
|
||||
message=error_message,
|
||||
stage=resolved_stage,
|
||||
retryable=resolved_retryable,
|
||||
trackId=self.current_track_id,
|
||||
trackId=resolved_track_id,
|
||||
data={
|
||||
"error": {
|
||||
"stage": resolved_stage,
|
||||
@@ -666,6 +661,15 @@ class Session:
|
||||
return "tts"
|
||||
return "protocol"
|
||||
|
||||
def _error_track_id(self, stage: str, code: str) -> str:
|
||||
if stage in {"audio", "asr"}:
|
||||
return self.TRACK_AUDIO_IN
|
||||
if stage in {"llm", "tts", "tool"}:
|
||||
return self.TRACK_AUDIO_OUT
|
||||
if str(code or "").strip().lower().startswith("auth."):
|
||||
return self.TRACK_CONTROL
|
||||
return self.TRACK_CONTROL
|
||||
|
||||
def _envelope_event(self, event: Dict[str, Any]) -> Dict[str, Any]:
|
||||
event_type = str(event.get("type") or "")
|
||||
source = str(event.get("source") or self._event_source(event_type))
|
||||
|
||||
Reference in New Issue
Block a user