diff --git a/docs/content/api-reference/websocket.md b/docs/content/api-reference/websocket.md index 995fd1d..66b2924 100644 --- a/docs/content/api-reference/websocket.md +++ b/docs/content/api-reference/websocket.md @@ -30,6 +30,7 @@ Server <- assistant.response.delta / assistant.response.final Server <- output.audio.start Server <- (binary pcm frames...) Server <- output.audio.end +Client -> output.audio.played (optional) Client -> session.stop Server <- session.stopped ``` @@ -143,7 +144,33 @@ Server <- session.stopped --- -### 4. Tool Call Results: `tool_call.results` +### 4. Output Audio Played: `output.audio.played` + +客户端回执音频已在本地播放完成(含本地 jitter buffer / 播放队列)。 + +```json +{ + "type": "output.audio.played", + "tts_id": "tts_001", + "response_id": "resp_001", + "turn_id": "turn_001", + "played_at_ms": 1730000018450, + "played_ms": 2520 +} +``` + +| 字段 | 类型 | 必填 | 说明 | +|---|---|---|---| +| `type` | string | 是 | 固定为 `"output.audio.played"` | +| `tts_id` | string | 是 | 已完成播放的 TTS 段 ID | +| `response_id` | string | 否 | 所属回复 ID(建议回传) | +| `turn_id` | string | 否 | 所属轮次 ID(建议回传) | +| `played_at_ms` | number | 否 | 客户端本地播放完成时间戳(毫秒,非负整数) | +| `played_ms` | number | 否 | 本次播放耗时(毫秒,非负整数) | + +--- + +### 5. Tool Call Results: `tool_call.results` 回传客户端执行的工具结果。 @@ -174,7 +201,7 @@ Server <- session.stopped --- -### 5. Session Stop: `session.stop` +### 6. Session Stop: `session.stop` 结束对话会话。 @@ -192,7 +219,7 @@ Server <- session.stopped --- -### 6. Binary Audio +### 7. 
Binary Audio 在 `session.started` 之后可持续发送二进制 PCM 音频。 @@ -707,6 +734,8 @@ TTS 音频播放结束标记。 | `data.tts_id` | string | TTS 播放段 ID | | `data.turn_id` | string | 当前对话轮次 ID | +**说明**:`output.audio.end` 表示服务端已发送完成,不代表客户端扬声器已播完。若需要“真实播完”信号,客户端应发送 `output.audio.played`。 + --- #### `response.interrupted` diff --git a/engine/core/duplex_pipeline.py b/engine/core/duplex_pipeline.py index 77deceb..13f1852 100644 --- a/engine/core/duplex_pipeline.py +++ b/engine/core/duplex_pipeline.py @@ -396,6 +396,12 @@ class DuplexPipeline: self._early_tool_results: Dict[str, Dict[str, Any]] = {} self._completed_tool_call_ids: set[str] = set() self._pending_client_tool_call_ids: set[str] = set() + self._pending_client_playback_tts_ids: set[str] = set() + self._tts_playback_context: Dict[str, Dict[str, Optional[str]]] = {} + self._last_client_played_tts_id: Optional[str] = None + self._last_client_played_response_id: Optional[str] = None + self._last_client_played_turn_id: Optional[str] = None + self._last_client_played_at_ms: Optional[int] = None self._next_seq: Optional[Callable[[], int]] = None self._local_seq: int = 0 @@ -632,8 +638,13 @@ class DuplexPipeline: def _start_tts(self) -> str: self._tts_count += 1 - self._current_tts_id = self._new_id("tts", self._tts_count) - return self._current_tts_id + tts_id = self._new_id("tts", self._tts_count) + self._current_tts_id = tts_id + self._tts_playback_context[tts_id] = { + "turn_id": self._current_turn_id, + "response_id": self._current_response_id, + } + return tts_id def _finalize_utterance(self) -> str: if self._current_utterance_id: @@ -644,6 +655,53 @@ class DuplexPipeline: self._start_turn() return self._current_utterance_id + def _mark_client_playback_started(self, tts_id: Optional[str]) -> None: + normalized_tts_id = str(tts_id or "").strip() + if not normalized_tts_id: + return + self._pending_client_playback_tts_ids.add(normalized_tts_id) + + def _clear_client_playback_tracking(self) -> None: + 
self._pending_client_playback_tts_ids.clear() + self._tts_playback_context.clear() + + async def handle_output_audio_played( + self, + *, + tts_id: str, + response_id: Optional[str] = None, + turn_id: Optional[str] = None, + played_at_ms: Optional[int] = None, + played_ms: Optional[int] = None, + ) -> None: + """Record client-side playback completion for a TTS segment.""" + normalized_tts_id = str(tts_id or "").strip() + if not normalized_tts_id: + return + + was_pending = normalized_tts_id in self._pending_client_playback_tts_ids + self._pending_client_playback_tts_ids.discard(normalized_tts_id) + + context = self._tts_playback_context.pop(normalized_tts_id, {}) + resolved_response_id = str(response_id or context.get("response_id") or "").strip() or None + resolved_turn_id = str(turn_id or context.get("turn_id") or "").strip() or None + + self._last_client_played_tts_id = normalized_tts_id + self._last_client_played_response_id = resolved_response_id + self._last_client_played_turn_id = resolved_turn_id + if isinstance(played_at_ms, int) and played_at_ms >= 0: + self._last_client_played_at_ms = played_at_ms + else: + self._last_client_played_at_ms = self._get_timestamp_ms() + + duration_ms = played_ms if isinstance(played_ms, int) and played_ms >= 0 else None + logger.info( + f"[PlaybackAck] tts_id={normalized_tts_id} response_id={resolved_response_id or '-'} " + f"turn_id={resolved_turn_id or '-'} pending_before={was_pending} " + f"pending_now={len(self._pending_client_playback_tts_ids)} " + f"played_ms={duration_ms if duration_ms is not None else '-'}" + ) + def _envelope_event(self, event: Dict[str, Any]) -> Dict[str, Any]: event_type = str(event.get("type") or "") source = str(event.get("source") or self._event_source(event_type)) @@ -1046,7 +1104,8 @@ class DuplexPipeline: try: self._drop_outbound_audio = False - self._start_tts() + tts_id = self._start_tts() + self._mark_client_playback_started(tts_id) await self._send_event( { **ev( @@ -2254,7 +2313,8 @@ 
class DuplexPipeline: if self._tts_output_enabled() and not self._interrupt_event.is_set(): if not first_audio_sent: - self._start_tts() + tts_id = self._start_tts() + self._mark_client_playback_started(tts_id) await self._send_event( { **ev( @@ -2294,7 +2354,8 @@ class DuplexPipeline: and not self._interrupt_event.is_set() ): if not first_audio_sent: - self._start_tts() + tts_id = self._start_tts() + self._mark_client_playback_started(tts_id) await self._send_event( { **ev( @@ -2554,7 +2615,8 @@ class DuplexPipeline: first_audio_sent = False # Send track start event - self._start_tts() + tts_id = self._start_tts() + self._mark_client_playback_started(tts_id) await self._send_event({ **ev( "output.audio.start", @@ -2625,6 +2687,7 @@ class DuplexPipeline: self._is_bot_speaking = False self._drop_outbound_audio = True self._audio_out_frame_buffer = b"" + self._clear_client_playback_tracking() interrupted_turn_id = self._current_turn_id interrupted_utterance_id = self._current_utterance_id interrupted_response_id = self._current_response_id @@ -2666,6 +2729,7 @@ class DuplexPipeline: """Stop any current speech task.""" self._drop_outbound_audio = True self._audio_out_frame_buffer = b"" + self._clear_client_playback_tracking() if self._current_turn_task and not self._current_turn_task.done(): self._interrupt_event.set() self._current_turn_task.cancel() @@ -2709,8 +2773,13 @@ class DuplexPipeline: @property def is_speaking(self) -> bool: - """Check if bot is currently speaking.""" - return self._is_bot_speaking + """Check if assistant audio is still active (server send or client playback).""" + return self._is_bot_speaking or self.is_client_playing_audio + + @property + def is_client_playing_audio(self) -> bool: + """Check if client has unacknowledged assistant audio playback.""" + return bool(self._pending_client_playback_tts_ids) @property def state(self) -> ConversationState: diff --git a/engine/core/session.py b/engine/core/session.py index 1603bda..de00855 100644 
--- a/engine/core/session.py +++ b/engine/core/session.py @@ -24,6 +24,7 @@ from models.ws_v1 import ( SessionStopMessage, InputTextMessage, ResponseCancelMessage, + OutputAudioPlayedMessage, ToolCallResultsMessage, ) @@ -267,6 +268,14 @@ class Session: logger.info(f"Session {self.id} graceful response.cancel") else: await self.pipeline.interrupt() + elif isinstance(message, OutputAudioPlayedMessage): + await self.pipeline.handle_output_audio_played( + tts_id=message.tts_id, + response_id=message.response_id, + turn_id=message.turn_id, + played_at_ms=message.played_at_ms, + played_ms=message.played_ms, + ) elif isinstance(message, ToolCallResultsMessage): await self.pipeline.handle_tool_call_results([item.model_dump() for item in message.results]) elif isinstance(message, SessionStopMessage): diff --git a/engine/docs/ws_v1_schema.md b/engine/docs/ws_v1_schema.md index 22a9dcf..94fdf73 100644 --- a/engine/docs/ws_v1_schema.md +++ b/engine/docs/ws_v1_schema.md @@ -20,7 +20,7 @@ Required message order: 1. Client connects to `/ws?assistant_id=`. 2. Client sends `session.start`. 3. Server replies `session.started`. -4. Client may stream binary audio and/or send `input.text`. +4. Client may stream binary audio and/or send `input.text`, `response.cancel`, `output.audio.played`, `tool_call.results`. 5. Client sends `session.stop` (or closes socket). If order is violated, server emits `error` with `code = "protocol.order"`. @@ -100,6 +100,22 @@ Text-only mode: } ``` +### `output.audio.played` + +Client playback ACK after assistant audio is actually drained on local speakers +(including jitter buffer / playback queue). + +```json +{ + "type": "output.audio.played", + "tts_id": "tts_001", + "response_id": "resp_001", + "turn_id": "turn_001", + "played_at_ms": 1730000018450, + "played_ms": 2520 +} +``` + ### `session.stop` ```json @@ -223,6 +239,8 @@ Framing rules: TTS boundary events: - `output.audio.start` and `output.audio.end` mark assistant playback boundaries. 
+- `output.audio.end` means the server has finished sending the audio; it does not guarantee that the client's speakers have finished playing it. +- For speaker-drain confirmation, the client should send `output.audio.played` once local playback has finished; until that ACK arrives (or the current speech is interrupted or stopped), the server keeps that `tts_id` marked as pending client playback. ## Event Throttling diff --git a/engine/docs/ws_v1_schema_zh.md b/engine/docs/ws_v1_schema_zh.md index 3313b73..ce5f175 100644 --- a/engine/docs/ws_v1_schema_zh.md +++ b/engine/docs/ws_v1_schema_zh.md @@ -46,6 +46,7 @@ - 二进制音频 - `input.text`(可选) - `response.cancel`(可选) + - `output.audio.played`(可选) - `tool_call.results`(可选) 6. 客户端发送 `session.stop` 或直接断开连接 @@ -190,7 +191,35 @@ | `type` | string | 是 | - | 固定 `"response.cancel"` | 请求中断当前回答 | | `graceful` | boolean | 否 | `false` | 取消方式 | `false` 立即打断;`true` 当前实现主要用于记录日志,不强制中断 | -## 3.5 `tool_call.results` +## 3.5 `output.audio.played` + +客户端在本地扬声器真正播完后回执(含 jitter buffer / 播放队列)。 + +示例: + +```json +{ + "type": "output.audio.played", + "tts_id": "tts_001", + "response_id": "resp_001", + "turn_id": "turn_001", + "played_at_ms": 1730000018450, + "played_ms": 2520 +} +``` + +字段说明: + +| 字段 | 类型 | 必填 | 约束 | 含义 | 使用说明 | +|---|---|---|---|---|---| +| `type` | string | 是 | 固定 `"output.audio.played"` | 播放完成回执 | 客户端播完后上送 | +| `tts_id` | string | 是 | 非空字符串 | TTS 段 ID | 建议使用 `output.audio.start/end` 中同一 `tts_id` | +| `response_id` | string \| null | 否 | 任意字符串 | 回复 ID | 建议回传,便于聚合 | +| `turn_id` | string \| null | 否 | 任意字符串 | 轮次 ID | 建议回传,便于聚合 | +| `played_at_ms` | number \| null | 否 | 非负整数,毫秒时间戳 | 客户端播放完成时间 | 用于时延分析 | +| `played_ms` | number \| null | 否 | 非负整数 | 客户端播放耗时 | 用于播放器统计 | + +## 3.6 `tool_call.results` 仅在工具执行端为客户端时使用(`assistant.tool_call.executor == "client"`)。 @@ -228,7 +257,7 @@ - 重复回传会被忽略; - 超时未回传会由服务端合成超时结果(`504`)。 -## 3.6 `session.stop` +## 3.7 `session.stop` 示例: @@ -406,7 +435,7 @@ - 含义:TTS 音频输出开始边界 6. `output.audio.end` -- 含义:TTS 音频输出结束边界 +- 含义:TTS 音频输出结束边界(服务端发送完成,不等价于扬声器已播完) 7. 
`response.interrupted` - 含义:当前回答被打断(barge-in 或 cancel) @@ -434,6 +463,7 @@ - 音频为 PCM 二进制帧; - 发送单位对齐到 `640 bytes`(不足会补零后发送); - 前端通常结合 `output.audio.start/end` 做播放边界控制; +- 若需要“扬声器真实播完”语义,前端应在播完后发送 `output.audio.played`; - 收到 `response.interrupted` 后应丢弃队列中未播放完的旧音频。 --- @@ -502,8 +532,9 @@ 2. 语音输入严格按 16k/16bit/mono,并保证每个 WS 二进制消息长度是 `640*n`。 3. UI 层把 `assistant.response.delta` 当作流式显示,把 `assistant.response.final` 当作收敛结果。 4. 播放器用 `output.audio.start/end` 管理一轮播报生命周期。 -5. 工具调用场景下,若 `executor=client`,务必按 `tool_call_id` 回传 `tool_call.results`。 -6. 出现 `error` 时优先按 `code` 分流处理,而不是仅看 `message`。 +5. 若业务依赖“扬声器真实播完”,请在播完时上送 `output.audio.played`。 +6. 工具调用场景下,若 `executor=client`,务必按 `tool_call_id` 回传 `tool_call.results`。 +7. 出现 `error` 时优先按 `code` 分流处理,而不是仅看 `message`。 --- @@ -521,6 +552,7 @@ Server <- assistant.response.delta / assistant.response.final Server <- output.audio.start Server <- (binary pcm frames...) Server <- output.audio.end +Client -> output.audio.played (optional) Client -> session.stop Server <- session.stopped ``` diff --git a/engine/models/ws_v1.py b/engine/models/ws_v1.py index 39bd61e..5cc4fac 100644 --- a/engine/models/ws_v1.py +++ b/engine/models/ws_v1.py @@ -45,6 +45,15 @@ class ResponseCancelMessage(_StrictModel): graceful: bool = False +class OutputAudioPlayedMessage(_StrictModel): + type: Literal["output.audio.played"] + tts_id: str = Field(..., min_length=1) + response_id: Optional[str] = None + turn_id: Optional[str] = None + played_at_ms: Optional[int] = Field(default=None, ge=0) + played_ms: Optional[int] = Field(default=None, ge=0) + + class ToolCallResultStatus(_StrictModel): code: int message: str @@ -67,6 +76,7 @@ CLIENT_MESSAGE_TYPES = { "session.stop": SessionStopMessage, "input.text": InputTextMessage, "response.cancel": ResponseCancelMessage, + "output.audio.played": OutputAudioPlayedMessage, "tool_call.results": ToolCallResultsMessage, } diff --git a/engine/tests/test_tool_call_flow.py b/engine/tests/test_tool_call_flow.py index 
5b78213..11a7b77 100644 --- a/engine/tests/test_tool_call_flow.py +++ b/engine/tests/test_tool_call_flow.py @@ -6,7 +6,7 @@ import pytest from core.conversation import ConversationState from core.duplex_pipeline import DuplexPipeline -from models.ws_v1 import ToolCallResultsMessage, parse_client_message +from models.ws_v1 import OutputAudioPlayedMessage, ToolCallResultsMessage, parse_client_message from services.base import LLMStreamEvent @@ -432,6 +432,45 @@ async def test_ws_message_parses_tool_call_results(): assert msg.results[0].tool_call_id == "call_1" +@pytest.mark.asyncio +async def test_ws_message_parses_output_audio_played(): + msg = parse_client_message( + { + "type": "output.audio.played", + "tts_id": "tts_1", + "response_id": "resp_1", + "turn_id": "turn_1", + "played_at_ms": 1234567890, + "played_ms": 2100, + } + ) + assert isinstance(msg, OutputAudioPlayedMessage) + assert msg.tts_id == "tts_1" + + +@pytest.mark.asyncio +async def test_output_audio_played_updates_client_playback_state(monkeypatch): + pipeline, _events = _build_pipeline(monkeypatch, [[LLMStreamEvent(type="done")]]) + tts_id = pipeline._start_tts() + pipeline._mark_client_playback_started(tts_id) + + assert pipeline.is_client_playing_audio is True + + await pipeline.handle_output_audio_played( + tts_id=tts_id, + response_id="resp_1", + turn_id="turn_1", + played_at_ms=1234567000, + played_ms=1800, + ) + + assert pipeline.is_client_playing_audio is False + assert pipeline._last_client_played_tts_id == tts_id + assert pipeline._last_client_played_response_id == "resp_1" + assert pipeline._last_client_played_turn_id == "turn_1" + assert pipeline._last_client_played_at_ms == 1234567000 + + @pytest.mark.asyncio async def test_turn_without_tool_keeps_streaming(monkeypatch): pipeline, events = _build_pipeline( diff --git a/engine/tests/test_ws_protocol_session_start.py b/engine/tests/test_ws_protocol_session_start.py index fb2c3f4..90ac179 100644 --- 
a/engine/tests/test_ws_protocol_session_start.py +++ b/engine/tests/test_ws_protocol_session_start.py @@ -1,7 +1,7 @@ import pytest from core.session import Session, WsSessionState -from models.ws_v1 import SessionStartMessage, parse_client_message +from models.ws_v1 import OutputAudioPlayedMessage, SessionStartMessage, parse_client_message def _session() -> Session: @@ -16,6 +16,17 @@ def test_parse_client_message_rejects_hello_message(): parse_client_message({"type": "hello", "version": "v1"}) +def test_parse_client_message_accepts_output_audio_played(): + message = parse_client_message({"type": "output.audio.played", "tts_id": "tts_001"}) + assert isinstance(message, OutputAudioPlayedMessage) + assert message.tts_id == "tts_001" + + +def test_parse_client_message_rejects_output_audio_played_without_tts_id(): + with pytest.raises(ValueError, match="tts_id"): + parse_client_message({"type": "output.audio.played", "tts_id": ""}) + + @pytest.mark.asyncio async def test_handle_text_reports_invalid_message_for_hello(): session = Session.__new__(Session) @@ -42,6 +53,45 @@ async def test_handle_text_reports_invalid_message_for_hello(): assert "Unknown client message type: hello" in message +@pytest.mark.asyncio +async def test_handle_v1_message_routes_output_audio_played_to_pipeline(): + session = Session.__new__(Session) + session.id = "sess_output_audio_played" + session.ws_state = WsSessionState.ACTIVE + + received = {} + + class _Pipeline: + async def handle_output_audio_played(self, **payload): + received.update(payload) + + session.pipeline = _Pipeline() + + async def _send_error(sender, message, code, **kwargs): + raise AssertionError(f"Unexpected error: sender={sender} code={code} message={message} kwargs={kwargs}") + + session._send_error = _send_error + + await session._handle_v1_message( + OutputAudioPlayedMessage( + type="output.audio.played", + tts_id="tts_001", + response_id="resp_001", + turn_id="turn_001", + played_at_ms=1730000018450, + played_ms=2520, 
+ ) + ) + + assert received == { + "tts_id": "tts_001", + "response_id": "resp_001", + "turn_id": "turn_001", + "played_at_ms": 1730000018450, + "played_ms": 2520, + } + + def test_validate_metadata_rejects_services_payload(): session = _session() sanitized, error = session._validate_and_sanitize_client_metadata({"services": {"llm": {"provider": "openai"}}})