Add output.audio.played message handling and update documentation

- Introduced `output.audio.played` message type for client acknowledgment of audio playback completion.
- Updated `DuplexPipeline` to track client playback state and handle playback completion events.
- Enhanced session handling to route `output.audio.played` messages to the pipeline.
- Revised API documentation to include details about the new message type and its fields.
- Updated schema documentation to reflect the addition of `output.audio.played` in the message flow.
This commit is contained in:
Xin Wang
2026-03-04 10:01:34 +08:00
parent 80fff09b76
commit 7d4af18815
8 changed files with 275 additions and 19 deletions

View File

@@ -6,7 +6,7 @@ import pytest
from core.conversation import ConversationState
from core.duplex_pipeline import DuplexPipeline
from models.ws_v1 import ToolCallResultsMessage, parse_client_message
from models.ws_v1 import OutputAudioPlayedMessage, ToolCallResultsMessage, parse_client_message
from services.base import LLMStreamEvent
@@ -432,6 +432,45 @@ async def test_ws_message_parses_tool_call_results():
assert msg.results[0].tool_call_id == "call_1"
@pytest.mark.asyncio
async def test_ws_message_parses_output_audio_played():
msg = parse_client_message(
{
"type": "output.audio.played",
"tts_id": "tts_1",
"response_id": "resp_1",
"turn_id": "turn_1",
"played_at_ms": 1234567890,
"played_ms": 2100,
}
)
assert isinstance(msg, OutputAudioPlayedMessage)
assert msg.tts_id == "tts_1"
@pytest.mark.asyncio
async def test_output_audio_played_updates_client_playback_state(monkeypatch):
pipeline, _events = _build_pipeline(monkeypatch, [[LLMStreamEvent(type="done")]])
tts_id = pipeline._start_tts()
pipeline._mark_client_playback_started(tts_id)
assert pipeline.is_client_playing_audio is True
await pipeline.handle_output_audio_played(
tts_id=tts_id,
response_id="resp_1",
turn_id="turn_1",
played_at_ms=1234567000,
played_ms=1800,
)
assert pipeline.is_client_playing_audio is False
assert pipeline._last_client_played_tts_id == tts_id
assert pipeline._last_client_played_response_id == "resp_1"
assert pipeline._last_client_played_turn_id == "turn_1"
assert pipeline._last_client_played_at_ms == 1234567000
@pytest.mark.asyncio
async def test_turn_without_tool_keeps_streaming(monkeypatch):
pipeline, events = _build_pipeline(

View File

@@ -1,7 +1,7 @@
import pytest
from core.session import Session, WsSessionState
from models.ws_v1 import SessionStartMessage, parse_client_message
from models.ws_v1 import OutputAudioPlayedMessage, SessionStartMessage, parse_client_message
def _session() -> Session:
@@ -16,6 +16,17 @@ def test_parse_client_message_rejects_hello_message():
parse_client_message({"type": "hello", "version": "v1"})
def test_parse_client_message_accepts_output_audio_played():
message = parse_client_message({"type": "output.audio.played", "tts_id": "tts_001"})
assert isinstance(message, OutputAudioPlayedMessage)
assert message.tts_id == "tts_001"
def test_parse_client_message_rejects_output_audio_played_without_tts_id():
with pytest.raises(ValueError, match="tts_id"):
parse_client_message({"type": "output.audio.played", "tts_id": ""})
@pytest.mark.asyncio
async def test_handle_text_reports_invalid_message_for_hello():
session = Session.__new__(Session)
@@ -42,6 +53,45 @@ async def test_handle_text_reports_invalid_message_for_hello():
assert "Unknown client message type: hello" in message
@pytest.mark.asyncio
async def test_handle_v1_message_routes_output_audio_played_to_pipeline():
session = Session.__new__(Session)
session.id = "sess_output_audio_played"
session.ws_state = WsSessionState.ACTIVE
received = {}
class _Pipeline:
async def handle_output_audio_played(self, **payload):
received.update(payload)
session.pipeline = _Pipeline()
async def _send_error(sender, message, code, **kwargs):
raise AssertionError(f"Unexpected error: sender={sender} code={code} message={message} kwargs={kwargs}")
session._send_error = _send_error
await session._handle_v1_message(
OutputAudioPlayedMessage(
type="output.audio.played",
tts_id="tts_001",
response_id="resp_001",
turn_id="turn_001",
played_at_ms=1730000018450,
played_ms=2520,
)
)
assert received == {
"tts_id": "tts_001",
"response_id": "resp_001",
"turn_id": "turn_001",
"played_at_ms": 1730000018450,
"played_ms": 2520,
}
def test_validate_metadata_rejects_services_payload():
session = _session()
sanitized, error = session._validate_and_sanitize_client_metadata({"services": {"llm": {"provider": "openai"}}})