Improve schema

This commit is contained in:
Xin Wang
2026-02-24 05:55:47 +08:00
parent c6c84b5af9
commit 6290fdd60e
4 changed files with 88 additions and 36 deletions

View File

@@ -107,8 +107,6 @@ class Session:
# Track IDs
self.current_track_id: str = self.TRACK_CONTROL
self._event_seq: int = 0
self._audio_ingress_buffer: bytes = b""
self._audio_frame_error_reported: bool = False
self._history_call_id: Optional[str] = None
self._history_turn_index: int = 0
self._history_call_started_mono: Optional[float] = None
@@ -179,23 +177,18 @@ class Session:
return
frame_bytes = self.AUDIO_FRAME_BYTES
self._audio_ingress_buffer += audio_bytes
# Protocol v1 audio framing: 20ms PCM frame (640 bytes).
# Allow aggregated frames in one WS message (multiple of 640).
if len(audio_bytes) % frame_bytes != 0 and not self._audio_frame_error_reported:
self._audio_frame_error_reported = True
if len(audio_bytes) % frame_bytes != 0:
await self._send_error(
"client",
f"Audio frame size should be multiple of {frame_bytes} bytes (20ms PCM)",
f"Audio frame size must be a multiple of {frame_bytes} bytes (20ms PCM)",
"audio.frame_size_mismatch",
stage="audio",
retryable=True,
retryable=False,
)
return
while len(self._audio_ingress_buffer) >= frame_bytes:
frame = self._audio_ingress_buffer[:frame_bytes]
self._audio_ingress_buffer = self._audio_ingress_buffer[frame_bytes:]
for i in range(0, len(audio_bytes), frame_bytes):
frame = audio_bytes[i : i + frame_bytes]
await self.pipeline.process_audio(frame)
except Exception as e:
logger.error(f"Session {self.id} handle_audio error: {e}", exc_info=True)
@@ -246,7 +239,7 @@ class Session:
else:
await self.pipeline.interrupt()
elif isinstance(message, ToolCallResultsMessage):
await self.pipeline.handle_tool_call_results(message.results)
await self.pipeline.handle_tool_call_results([item.model_dump() for item in message.results])
elif isinstance(message, SessionStopMessage):
await self._handle_session_stop(message.reason)
else:
@@ -268,9 +261,9 @@ class Session:
self.ws_state = WsSessionState.STOPPED
return
auth_payload = message.auth or {}
api_key = auth_payload.get("apiKey")
jwt = auth_payload.get("jwt")
auth_payload = message.auth
api_key = auth_payload.apiKey if auth_payload else None
jwt = auth_payload.jwt if auth_payload else None
if settings.ws_api_key:
if api_key != settings.ws_api_key:
@@ -330,7 +323,7 @@ class Session:
"audio_out": self.TRACK_AUDIO_OUT,
"control": self.TRACK_CONTROL,
},
audio=message.audio or {},
audio=message.audio.model_dump() if message.audio else {},
)
)
await self._send_event(
@@ -383,6 +376,7 @@ class Session:
code: str,
stage: Optional[str] = None,
retryable: Optional[bool] = None,
track_id: Optional[str] = None,
) -> None:
"""
Send error event to client.
@@ -394,6 +388,7 @@ class Session:
"""
resolved_stage = stage or self._infer_error_stage(code)
resolved_retryable = retryable if retryable is not None else (resolved_stage in {"asr", "llm", "tts", "tool", "audio"})
resolved_track_id = track_id or self._error_track_id(resolved_stage, code)
await self._send_event(
ev(
"error",
@@ -402,7 +397,7 @@ class Session:
message=error_message,
stage=resolved_stage,
retryable=resolved_retryable,
trackId=self.current_track_id,
trackId=resolved_track_id,
data={
"error": {
"stage": resolved_stage,
@@ -666,6 +661,15 @@ class Session:
return "tts"
return "protocol"
def _error_track_id(self, stage: str, code: str) -> str:
if stage in {"audio", "asr"}:
return self.TRACK_AUDIO_IN
if stage in {"llm", "tts", "tool"}:
return self.TRACK_AUDIO_OUT
if str(code or "").strip().lower().startswith("auth."):
return self.TRACK_CONTROL
return self.TRACK_CONTROL
def _envelope_event(self, event: Dict[str, Any]) -> Dict[str, Any]:
event_type = str(event.get("type") or "")
source = str(event.get("source") or self._event_source(event_type))