Add output.audio.played message handling and update documentation

- Introduced `output.audio.played` message type for client acknowledgment of audio playback completion.
- Updated `DuplexPipeline` to track client playback state and handle playback completion events.
- Enhanced session handling to route `output.audio.played` messages to the pipeline.
- Revised API documentation to include details about the new message type and its fields.
- Updated schema documentation to reflect the addition of `output.audio.played` in the message flow.
This commit is contained in:
Xin Wang
2026-03-04 10:01:34 +08:00
parent 80fff09b76
commit 7d4af18815
8 changed files with 275 additions and 19 deletions

View File

@@ -396,6 +396,12 @@ class DuplexPipeline:
self._early_tool_results: Dict[str, Dict[str, Any]] = {}
self._completed_tool_call_ids: set[str] = set()
self._pending_client_tool_call_ids: set[str] = set()
self._pending_client_playback_tts_ids: set[str] = set()
self._tts_playback_context: Dict[str, Dict[str, Optional[str]]] = {}
self._last_client_played_tts_id: Optional[str] = None
self._last_client_played_response_id: Optional[str] = None
self._last_client_played_turn_id: Optional[str] = None
self._last_client_played_at_ms: Optional[int] = None
self._next_seq: Optional[Callable[[], int]] = None
self._local_seq: int = 0
@@ -632,8 +638,13 @@ class DuplexPipeline:
def _start_tts(self) -> str:
self._tts_count += 1
self._current_tts_id = self._new_id("tts", self._tts_count)
return self._current_tts_id
tts_id = self._new_id("tts", self._tts_count)
self._current_tts_id = tts_id
self._tts_playback_context[tts_id] = {
"turn_id": self._current_turn_id,
"response_id": self._current_response_id,
}
return tts_id
def _finalize_utterance(self) -> str:
if self._current_utterance_id:
@@ -644,6 +655,53 @@ class DuplexPipeline:
self._start_turn()
return self._current_utterance_id
def _mark_client_playback_started(self, tts_id: Optional[str]) -> None:
normalized_tts_id = str(tts_id or "").strip()
if not normalized_tts_id:
return
self._pending_client_playback_tts_ids.add(normalized_tts_id)
def _clear_client_playback_tracking(self) -> None:
self._pending_client_playback_tts_ids.clear()
self._tts_playback_context.clear()
async def handle_output_audio_played(
self,
*,
tts_id: str,
response_id: Optional[str] = None,
turn_id: Optional[str] = None,
played_at_ms: Optional[int] = None,
played_ms: Optional[int] = None,
) -> None:
"""Record client-side playback completion for a TTS segment."""
normalized_tts_id = str(tts_id or "").strip()
if not normalized_tts_id:
return
was_pending = normalized_tts_id in self._pending_client_playback_tts_ids
self._pending_client_playback_tts_ids.discard(normalized_tts_id)
context = self._tts_playback_context.pop(normalized_tts_id, {})
resolved_response_id = str(response_id or context.get("response_id") or "").strip() or None
resolved_turn_id = str(turn_id or context.get("turn_id") or "").strip() or None
self._last_client_played_tts_id = normalized_tts_id
self._last_client_played_response_id = resolved_response_id
self._last_client_played_turn_id = resolved_turn_id
if isinstance(played_at_ms, int) and played_at_ms >= 0:
self._last_client_played_at_ms = played_at_ms
else:
self._last_client_played_at_ms = self._get_timestamp_ms()
duration_ms = played_ms if isinstance(played_ms, int) and played_ms >= 0 else None
logger.info(
f"[PlaybackAck] tts_id={normalized_tts_id} response_id={resolved_response_id or '-'} "
f"turn_id={resolved_turn_id or '-'} pending_before={was_pending} "
f"pending_now={len(self._pending_client_playback_tts_ids)} "
f"played_ms={duration_ms if duration_ms is not None else '-'}"
)
def _envelope_event(self, event: Dict[str, Any]) -> Dict[str, Any]:
event_type = str(event.get("type") or "")
source = str(event.get("source") or self._event_source(event_type))
@@ -1046,7 +1104,8 @@ class DuplexPipeline:
try:
self._drop_outbound_audio = False
self._start_tts()
tts_id = self._start_tts()
self._mark_client_playback_started(tts_id)
await self._send_event(
{
**ev(
@@ -2254,7 +2313,8 @@ class DuplexPipeline:
if self._tts_output_enabled() and not self._interrupt_event.is_set():
if not first_audio_sent:
self._start_tts()
tts_id = self._start_tts()
self._mark_client_playback_started(tts_id)
await self._send_event(
{
**ev(
@@ -2294,7 +2354,8 @@ class DuplexPipeline:
and not self._interrupt_event.is_set()
):
if not first_audio_sent:
self._start_tts()
tts_id = self._start_tts()
self._mark_client_playback_started(tts_id)
await self._send_event(
{
**ev(
@@ -2554,7 +2615,8 @@ class DuplexPipeline:
first_audio_sent = False
# Send track start event
self._start_tts()
tts_id = self._start_tts()
self._mark_client_playback_started(tts_id)
await self._send_event({
**ev(
"output.audio.start",
@@ -2625,6 +2687,7 @@ class DuplexPipeline:
self._is_bot_speaking = False
self._drop_outbound_audio = True
self._audio_out_frame_buffer = b""
self._clear_client_playback_tracking()
interrupted_turn_id = self._current_turn_id
interrupted_utterance_id = self._current_utterance_id
interrupted_response_id = self._current_response_id
@@ -2666,6 +2729,7 @@ class DuplexPipeline:
"""Stop any current speech task."""
self._drop_outbound_audio = True
self._audio_out_frame_buffer = b""
self._clear_client_playback_tracking()
if self._current_turn_task and not self._current_turn_task.done():
self._interrupt_event.set()
self._current_turn_task.cancel()
@@ -2709,8 +2773,13 @@ class DuplexPipeline:
@property
def is_speaking(self) -> bool:
"""Check if bot is currently speaking."""
return self._is_bot_speaking
"""Check if assistant audio is still active (server send or client playback)."""
return self._is_bot_speaking or self.is_client_playing_audio
@property
def is_client_playing_audio(self) -> bool:
"""Check if client has unacknowledged assistant audio playback."""
return bool(self._pending_client_playback_tts_ids)
@property
def state(self) -> ConversationState:

View File

@@ -24,6 +24,7 @@ from models.ws_v1 import (
SessionStopMessage,
InputTextMessage,
ResponseCancelMessage,
OutputAudioPlayedMessage,
ToolCallResultsMessage,
)
@@ -267,6 +268,14 @@ class Session:
logger.info(f"Session {self.id} graceful response.cancel")
else:
await self.pipeline.interrupt()
elif isinstance(message, OutputAudioPlayedMessage):
await self.pipeline.handle_output_audio_played(
tts_id=message.tts_id,
response_id=message.response_id,
turn_id=message.turn_id,
played_at_ms=message.played_at_ms,
played_ms=message.played_ms,
)
elif isinstance(message, ToolCallResultsMessage):
await self.pipeline.handle_tool_call_results([item.model_dump() for item in message.results])
elif isinstance(message, SessionStopMessage):