Add output.audio.played message handling and update documentation
- Introduced the `output.audio.played` message type, which a client sends to acknowledge that audio playback has completed.
- Updated `DuplexPipeline` to track client playback state and handle playback-completion events.
- Enhanced session handling to route `output.audio.played` messages to the pipeline.
- Revised the API documentation to describe the new message type and its fields.
- Updated the schema documentation to include `output.audio.played` in the message flow.
This commit is contained in:
@@ -396,6 +396,12 @@ class DuplexPipeline:
|
||||
self._early_tool_results: Dict[str, Dict[str, Any]] = {}
|
||||
self._completed_tool_call_ids: set[str] = set()
|
||||
self._pending_client_tool_call_ids: set[str] = set()
|
||||
self._pending_client_playback_tts_ids: set[str] = set()
|
||||
self._tts_playback_context: Dict[str, Dict[str, Optional[str]]] = {}
|
||||
self._last_client_played_tts_id: Optional[str] = None
|
||||
self._last_client_played_response_id: Optional[str] = None
|
||||
self._last_client_played_turn_id: Optional[str] = None
|
||||
self._last_client_played_at_ms: Optional[int] = None
|
||||
self._next_seq: Optional[Callable[[], int]] = None
|
||||
self._local_seq: int = 0
|
||||
|
||||
@@ -632,8 +638,13 @@ class DuplexPipeline:
|
||||
|
||||
def _start_tts(self) -> str:
    """Allocate the next TTS segment id and record its turn/response context.

    Increments ``_tts_count``, builds a new id via ``_new_id("tts", ...)``,
    stores it in ``_current_tts_id``, and snapshots the current turn and
    response ids under that id in ``_tts_playback_context``.
    ``handle_output_audio_played`` falls back to this snapshot when the
    client's acknowledgment omits ``turn_id`` or ``response_id``.

    Returns:
        The newly allocated TTS id.
    """
    self._tts_count += 1
    tts_id = self._new_id("tts", self._tts_count)
    self._current_tts_id = tts_id
    # Capture the ids as they are right now; the current turn/response may
    # have moved on by the time the playback acknowledgment arrives.
    self._tts_playback_context[tts_id] = {
        "turn_id": self._current_turn_id,
        "response_id": self._current_response_id,
    }
    return tts_id
|
||||
|
||||
def _finalize_utterance(self) -> str:
|
||||
if self._current_utterance_id:
|
||||
@@ -644,6 +655,53 @@ class DuplexPipeline:
|
||||
self._start_turn()
|
||||
return self._current_utterance_id
|
||||
|
||||
def _mark_client_playback_started(self, tts_id: Optional[str]) -> None:
    """Register a TTS segment as awaiting a client playback acknowledgment.

    The id is stringified and stripped before being added to
    ``_pending_client_playback_tts_ids``. ``None``, empty, and
    whitespace-only ids are ignored.
    """
    candidate = str(tts_id or "").strip()
    if candidate:
        self._pending_client_playback_tts_ids.add(candidate)
|
||||
|
||||
def _clear_client_playback_tracking(self) -> None:
    """Forget all pending playback acknowledgments and stored TTS contexts.

    Empties both ``_pending_client_playback_tts_ids`` and
    ``_tts_playback_context`` in place.
    """
    for tracked in (
        self._pending_client_playback_tts_ids,
        self._tts_playback_context,
    ):
        tracked.clear()
|
||||
|
||||
async def handle_output_audio_played(
    self,
    *,
    tts_id: str,
    response_id: Optional[str] = None,
    turn_id: Optional[str] = None,
    played_at_ms: Optional[int] = None,
    played_ms: Optional[int] = None,
) -> None:
    """Record client-side playback completion for a TTS segment.

    Args:
        tts_id: Id of the TTS segment the client finished playing. It is
            stringified and stripped; an empty result makes the call a no-op.
        response_id: Response id reported by the client. When missing, the
            value stored for this ``tts_id`` in ``_tts_playback_context``
            is used instead.
        turn_id: Turn id reported by the client, with the same fallback as
            ``response_id``.
        played_at_ms: Client-reported completion timestamp. Used only when
            it is a non-negative, non-boolean int; otherwise
            ``_get_timestamp_ms()`` supplies the value.
        played_ms: Client-reported playback duration. Only logged, and only
            when it is a non-negative, non-boolean int.
    """
    normalized_tts_id = str(tts_id or "").strip()
    if not normalized_tts_id:
        return

    was_pending = normalized_tts_id in self._pending_client_playback_tts_ids
    self._pending_client_playback_tts_ids.discard(normalized_tts_id)

    # The stored context is consumed here, so a repeated acknowledgment for
    # the same id can only rely on the ids the client sends.
    context = self._tts_playback_context.pop(normalized_tts_id, {})
    resolved_response_id = str(response_id or context.get("response_id") or "").strip() or None
    resolved_turn_id = str(turn_id or context.get("turn_id") or "").strip() or None

    self._last_client_played_tts_id = normalized_tts_id
    self._last_client_played_response_id = resolved_response_id
    self._last_client_played_turn_id = resolved_turn_id

    # bool is a subclass of int, so True/False must be rejected explicitly;
    # otherwise they would be recorded as 1 ms / 0 ms.
    client_timestamp_ok = (
        isinstance(played_at_ms, int)
        and not isinstance(played_at_ms, bool)
        and played_at_ms >= 0
    )
    if client_timestamp_ok:
        self._last_client_played_at_ms = played_at_ms
    else:
        self._last_client_played_at_ms = self._get_timestamp_ms()

    duration_ok = (
        isinstance(played_ms, int)
        and not isinstance(played_ms, bool)
        and played_ms >= 0
    )
    duration_ms = played_ms if duration_ok else None
    logger.info(
        f"[PlaybackAck] tts_id={normalized_tts_id} response_id={resolved_response_id or '-'} "
        f"turn_id={resolved_turn_id or '-'} pending_before={was_pending} "
        f"pending_now={len(self._pending_client_playback_tts_ids)} "
        f"played_ms={duration_ms if duration_ms is not None else '-'}"
    )
|
||||
|
||||
def _envelope_event(self, event: Dict[str, Any]) -> Dict[str, Any]:
|
||||
event_type = str(event.get("type") or "")
|
||||
source = str(event.get("source") or self._event_source(event_type))
|
||||
@@ -1046,7 +1104,8 @@ class DuplexPipeline:
|
||||
|
||||
try:
|
||||
self._drop_outbound_audio = False
|
||||
self._start_tts()
|
||||
tts_id = self._start_tts()
|
||||
self._mark_client_playback_started(tts_id)
|
||||
await self._send_event(
|
||||
{
|
||||
**ev(
|
||||
@@ -2254,7 +2313,8 @@ class DuplexPipeline:
|
||||
|
||||
if self._tts_output_enabled() and not self._interrupt_event.is_set():
|
||||
if not first_audio_sent:
|
||||
self._start_tts()
|
||||
tts_id = self._start_tts()
|
||||
self._mark_client_playback_started(tts_id)
|
||||
await self._send_event(
|
||||
{
|
||||
**ev(
|
||||
@@ -2294,7 +2354,8 @@ class DuplexPipeline:
|
||||
and not self._interrupt_event.is_set()
|
||||
):
|
||||
if not first_audio_sent:
|
||||
self._start_tts()
|
||||
tts_id = self._start_tts()
|
||||
self._mark_client_playback_started(tts_id)
|
||||
await self._send_event(
|
||||
{
|
||||
**ev(
|
||||
@@ -2554,7 +2615,8 @@ class DuplexPipeline:
|
||||
first_audio_sent = False
|
||||
|
||||
# Send track start event
|
||||
self._start_tts()
|
||||
tts_id = self._start_tts()
|
||||
self._mark_client_playback_started(tts_id)
|
||||
await self._send_event({
|
||||
**ev(
|
||||
"output.audio.start",
|
||||
@@ -2625,6 +2687,7 @@ class DuplexPipeline:
|
||||
self._is_bot_speaking = False
|
||||
self._drop_outbound_audio = True
|
||||
self._audio_out_frame_buffer = b""
|
||||
self._clear_client_playback_tracking()
|
||||
interrupted_turn_id = self._current_turn_id
|
||||
interrupted_utterance_id = self._current_utterance_id
|
||||
interrupted_response_id = self._current_response_id
|
||||
@@ -2666,6 +2729,7 @@ class DuplexPipeline:
|
||||
"""Stop any current speech task."""
|
||||
self._drop_outbound_audio = True
|
||||
self._audio_out_frame_buffer = b""
|
||||
self._clear_client_playback_tracking()
|
||||
if self._current_turn_task and not self._current_turn_task.done():
|
||||
self._interrupt_event.set()
|
||||
self._current_turn_task.cancel()
|
||||
@@ -2709,8 +2773,13 @@ class DuplexPipeline:
|
||||
|
||||
@property
def is_speaking(self) -> bool:
    """Check if assistant audio is still active (server send or client playback).

    Returns:
        True while the server-side speaking flag ``_is_bot_speaking`` is
        set, or while ``is_client_playing_audio`` reports playback the
        client has not yet acknowledged.
    """
    return self._is_bot_speaking or self.is_client_playing_audio
|
||||
|
||||
@property
def is_client_playing_audio(self) -> bool:
    """Report whether any TTS segment still awaits a client playback acknowledgment.

    Returns:
        True when ``_pending_client_playback_tts_ids`` is non-empty.
    """
    return len(self._pending_client_playback_tts_ids) != 0
|
||||
|
||||
@property
|
||||
def state(self) -> ConversationState:
|
||||
|
||||
@@ -24,6 +24,7 @@ from models.ws_v1 import (
|
||||
SessionStopMessage,
|
||||
InputTextMessage,
|
||||
ResponseCancelMessage,
|
||||
OutputAudioPlayedMessage,
|
||||
ToolCallResultsMessage,
|
||||
)
|
||||
|
||||
@@ -267,6 +268,14 @@ class Session:
|
||||
logger.info(f"Session {self.id} graceful response.cancel")
|
||||
else:
|
||||
await self.pipeline.interrupt()
|
||||
elif isinstance(message, OutputAudioPlayedMessage):
|
||||
await self.pipeline.handle_output_audio_played(
|
||||
tts_id=message.tts_id,
|
||||
response_id=message.response_id,
|
||||
turn_id=message.turn_id,
|
||||
played_at_ms=message.played_at_ms,
|
||||
played_ms=message.played_ms,
|
||||
)
|
||||
elif isinstance(message, ToolCallResultsMessage):
|
||||
await self.pipeline.handle_tool_call_results([item.model_dump() for item in message.results])
|
||||
elif isinstance(message, SessionStopMessage):
|
||||
|
||||
Reference in New Issue
Block a user