Add output.audio.played message handling and update documentation
- Introduced `output.audio.played` message type for client acknowledgment of audio playback completion. - Updated `DuplexPipeline` to track client playback state and handle playback completion events. - Enhanced session handling to route `output.audio.played` messages to the pipeline. - Revised API documentation to include details about the new message type and its fields. - Updated schema documentation to reflect the addition of `output.audio.played` in the message flow.
This commit is contained in:
@@ -6,7 +6,7 @@ import pytest
|
||||
|
||||
from core.conversation import ConversationState
|
||||
from core.duplex_pipeline import DuplexPipeline
|
||||
from models.ws_v1 import ToolCallResultsMessage, parse_client_message
|
||||
from models.ws_v1 import OutputAudioPlayedMessage, ToolCallResultsMessage, parse_client_message
|
||||
from services.base import LLMStreamEvent
|
||||
|
||||
|
||||
@@ -432,6 +432,45 @@ async def test_ws_message_parses_tool_call_results():
|
||||
assert msg.results[0].tool_call_id == "call_1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ws_message_parses_output_audio_played():
|
||||
msg = parse_client_message(
|
||||
{
|
||||
"type": "output.audio.played",
|
||||
"tts_id": "tts_1",
|
||||
"response_id": "resp_1",
|
||||
"turn_id": "turn_1",
|
||||
"played_at_ms": 1234567890,
|
||||
"played_ms": 2100,
|
||||
}
|
||||
)
|
||||
assert isinstance(msg, OutputAudioPlayedMessage)
|
||||
assert msg.tts_id == "tts_1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_output_audio_played_updates_client_playback_state(monkeypatch):
|
||||
pipeline, _events = _build_pipeline(monkeypatch, [[LLMStreamEvent(type="done")]])
|
||||
tts_id = pipeline._start_tts()
|
||||
pipeline._mark_client_playback_started(tts_id)
|
||||
|
||||
assert pipeline.is_client_playing_audio is True
|
||||
|
||||
await pipeline.handle_output_audio_played(
|
||||
tts_id=tts_id,
|
||||
response_id="resp_1",
|
||||
turn_id="turn_1",
|
||||
played_at_ms=1234567000,
|
||||
played_ms=1800,
|
||||
)
|
||||
|
||||
assert pipeline.is_client_playing_audio is False
|
||||
assert pipeline._last_client_played_tts_id == tts_id
|
||||
assert pipeline._last_client_played_response_id == "resp_1"
|
||||
assert pipeline._last_client_played_turn_id == "turn_1"
|
||||
assert pipeline._last_client_played_at_ms == 1234567000
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_turn_without_tool_keeps_streaming(monkeypatch):
|
||||
pipeline, events = _build_pipeline(
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
|
||||
from core.session import Session, WsSessionState
|
||||
from models.ws_v1 import SessionStartMessage, parse_client_message
|
||||
from models.ws_v1 import OutputAudioPlayedMessage, SessionStartMessage, parse_client_message
|
||||
|
||||
|
||||
def _session() -> Session:
|
||||
@@ -16,6 +16,17 @@ def test_parse_client_message_rejects_hello_message():
|
||||
parse_client_message({"type": "hello", "version": "v1"})
|
||||
|
||||
|
||||
def test_parse_client_message_accepts_output_audio_played():
|
||||
message = parse_client_message({"type": "output.audio.played", "tts_id": "tts_001"})
|
||||
assert isinstance(message, OutputAudioPlayedMessage)
|
||||
assert message.tts_id == "tts_001"
|
||||
|
||||
|
||||
def test_parse_client_message_rejects_output_audio_played_without_tts_id():
|
||||
with pytest.raises(ValueError, match="tts_id"):
|
||||
parse_client_message({"type": "output.audio.played", "tts_id": ""})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_text_reports_invalid_message_for_hello():
|
||||
session = Session.__new__(Session)
|
||||
@@ -42,6 +53,45 @@ async def test_handle_text_reports_invalid_message_for_hello():
|
||||
assert "Unknown client message type: hello" in message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_v1_message_routes_output_audio_played_to_pipeline():
|
||||
session = Session.__new__(Session)
|
||||
session.id = "sess_output_audio_played"
|
||||
session.ws_state = WsSessionState.ACTIVE
|
||||
|
||||
received = {}
|
||||
|
||||
class _Pipeline:
|
||||
async def handle_output_audio_played(self, **payload):
|
||||
received.update(payload)
|
||||
|
||||
session.pipeline = _Pipeline()
|
||||
|
||||
async def _send_error(sender, message, code, **kwargs):
|
||||
raise AssertionError(f"Unexpected error: sender={sender} code={code} message={message} kwargs={kwargs}")
|
||||
|
||||
session._send_error = _send_error
|
||||
|
||||
await session._handle_v1_message(
|
||||
OutputAudioPlayedMessage(
|
||||
type="output.audio.played",
|
||||
tts_id="tts_001",
|
||||
response_id="resp_001",
|
||||
turn_id="turn_001",
|
||||
played_at_ms=1730000018450,
|
||||
played_ms=2520,
|
||||
)
|
||||
)
|
||||
|
||||
assert received == {
|
||||
"tts_id": "tts_001",
|
||||
"response_id": "resp_001",
|
||||
"turn_id": "turn_001",
|
||||
"played_at_ms": 1730000018450,
|
||||
"played_ms": 2520,
|
||||
}
|
||||
|
||||
|
||||
def test_validate_metadata_rejects_services_payload():
|
||||
session = _session()
|
||||
sanitized, error = session._validate_and_sanitize_client_metadata({"services": {"llm": {"provider": "openai"}}})
|
||||
|
||||
Reference in New Issue
Block a user