- Introduced `output.audio.played` message type for client acknowledgment of audio playback completion. - Updated `DuplexPipeline` to track client playback state and handle playback completion events. - Enhanced session handling to route `output.audio.played` messages to the pipeline. - Revised API documentation to include details about the new message type and its fields. - Updated schema documentation to reflect the addition of `output.audio.played` in the message flow.
378 lines
12 KiB
Python
378 lines
12 KiB
Python
import pytest
|
|
|
|
from core.session import Session, WsSessionState
|
|
from models.ws_v1 import OutputAudioPlayedMessage, SessionStartMessage, parse_client_message
|
|
|
|
|
|
def _session() -> Session:
    """Return a bare ``Session`` for tests, created without running ``__init__``.

    Only ``id`` and ``_assistant_id`` are populated; any other attribute a test
    needs must be assigned by that test.
    """
    bare = Session.__new__(Session)
    for attr_name, attr_value in (
        ("id", "sess_test"),
        ("_assistant_id", "assistant_demo"),
    ):
        setattr(bare, attr_name, attr_value)
    return bare
|
|
|
|
|
|
def test_parse_client_message_rejects_hello_message():
    """A ``hello`` message must be rejected as an unknown client message type."""
    hello_payload = {"type": "hello", "version": "v1"}

    with pytest.raises(ValueError, match="Unknown client message type: hello"):
        parse_client_message(hello_payload)
|
|
|
|
|
|
def test_parse_client_message_accepts_output_audio_played():
    """``output.audio.played`` parses into ``OutputAudioPlayedMessage`` keeping ``tts_id``."""
    raw_ack = {"type": "output.audio.played", "tts_id": "tts_001"}

    parsed = parse_client_message(raw_ack)

    assert isinstance(parsed, OutputAudioPlayedMessage)
    assert parsed.tts_id == "tts_001"
|
|
|
|
|
|
def test_parse_client_message_rejects_output_audio_played_without_tts_id():
    """An ``output.audio.played`` message with an empty ``tts_id`` must raise.

    The raised ``ValueError`` is expected to mention ``tts_id``.
    """
    blank_tts_ack = {"type": "output.audio.played", "tts_id": ""}

    with pytest.raises(ValueError, match="tts_id"):
        parse_client_message(blank_tts_ack)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_text_reports_invalid_message_for_hello():
    """``handle_text`` reports ``protocol.invalid_message`` for a raw ``hello`` frame."""

    class _Transport:
        """Transport stub whose ``close`` does nothing."""

        async def close(self):
            return None

    error_calls = []

    async def _send_error(sender, message, code, **kwargs):
        # Record every reported error so the assertions can inspect the first one.
        error_calls.append((sender, code, message, kwargs))

    # Bypass ``__init__`` and set only the attributes this test assigns.
    session = Session.__new__(Session)
    session.id = "sess_invalid_hello"
    session.ws_state = WsSessionState.WAIT_START
    session.transport = _Transport()
    session._send_error = _send_error

    await session.handle_text('{"type":"hello","version":"v1"}')

    assert error_calls
    reported_sender, reported_code, reported_message, _extra = error_calls[0]
    assert reported_sender == "client"
    assert reported_code == "protocol.invalid_message"
    assert "Unknown client message type: hello" in reported_message
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_v1_message_routes_output_audio_played_to_pipeline():
    """An active session forwards ``output.audio.played`` fields to the pipeline."""
    received = {}

    class _Pipeline:
        """Pipeline stub that records the keyword arguments it is handed."""

        async def handle_output_audio_played(self, **fields):
            received.update(fields)

    async def _send_error(sender, message, code, **kwargs):
        # Any error report means the message was not routed as expected.
        raise AssertionError(f"Unexpected error: sender={sender} code={code} message={message} kwargs={kwargs}")

    session = Session.__new__(Session)
    session.id = "sess_output_audio_played"
    session.ws_state = WsSessionState.ACTIVE
    session.pipeline = _Pipeline()
    session._send_error = _send_error

    # The same field values feed the message and the expected pipeline payload.
    played_fields = {
        "tts_id": "tts_001",
        "response_id": "resp_001",
        "turn_id": "turn_001",
        "played_at_ms": 1730000018450,
        "played_ms": 2520,
    }
    ack_message = OutputAudioPlayedMessage(type="output.audio.played", **played_fields)

    await session._handle_v1_message(ack_message)

    assert received == played_fields
|
|
|
|
|
|
def test_validate_metadata_rejects_services_payload():
    """Client metadata carrying a ``services`` section is rejected as an invalid override."""
    client_metadata = {"services": {"llm": {"provider": "openai"}}}

    result = _session()._validate_and_sanitize_client_metadata(client_metadata)
    cleaned, failure = result

    assert cleaned == {}
    assert failure is not None
    assert failure["code"] == "protocol.invalid_override"
|
|
|
|
|
|
def test_validate_metadata_rejects_secret_like_override_keys():
    """An ``apiKey`` entry inside ``overrides`` is rejected as an invalid override."""
    overrides_with_secret = {
        "output": {"mode": "audio"},
        "apiKey": "xxx",
    }
    client_metadata = {"overrides": overrides_with_secret}

    result = _session()._validate_and_sanitize_client_metadata(client_metadata)
    cleaned, failure = result

    assert cleaned == {}
    assert failure is not None
    assert failure["code"] == "protocol.invalid_override"
|
|
|
|
|
|
def test_validate_metadata_ignores_workflow_payload():
    """A ``workflow`` section is dropped without error; other fields are kept."""
    client_metadata = {
        "workflow": {"nodes": [], "edges": []},
        "channel": "web_debug",
        "overrides": {"output": {"mode": "text"}},
    }

    result = _session()._validate_and_sanitize_client_metadata(client_metadata)
    cleaned, failure = result

    assert failure is None
    assert "workflow" not in cleaned
    assert cleaned["channel"] == "web_debug"
    assert cleaned["overrides"]["output"]["mode"] == "text"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_load_server_runtime_metadata_returns_not_found_error():
    """A gateway reply with ``__error_code`` ``assistant.not_found`` becomes that error code."""

    class _Gateway:
        """Backend gateway stub that always reports the assistant as not found."""

        async def fetch_assistant_config(self, assistant_id: str):
            _ = assistant_id  # argument intentionally unused
            return {"__error_code": "assistant.not_found"}

    session = _session()
    session._backend_gateway = _Gateway()

    outcome = await session._load_server_runtime_metadata("assistant_demo")
    runtime_metadata, failure = outcome

    assert runtime_metadata == {}
    assert failure is not None
    assert failure["code"] == "assistant.not_found"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_load_server_runtime_metadata_returns_config_unavailable_error():
    """A gateway reply of ``None`` is reported as ``assistant.config_unavailable``."""

    class _Gateway:
        """Backend gateway stub that returns no configuration at all."""

        async def fetch_assistant_config(self, assistant_id: str):
            _ = assistant_id  # argument intentionally unused
            return None

    session = _session()
    session._backend_gateway = _Gateway()

    outcome = await session._load_server_runtime_metadata("assistant_demo")
    runtime_metadata, failure = outcome

    assert runtime_metadata == {}
    assert failure is not None
    assert failure["code"] == "assistant.config_unavailable"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_session_start_requires_assistant_id_and_closes_transport():
    """``session.start`` with no assistant id reports an error, closes, and stops.

    Expects exactly one ``protocol.assistant_id_required`` error, a closed
    transport, and the session left in ``WsSessionState.STOPPED``.
    """

    class _Transport:
        """Transport stub that remembers whether ``close`` was awaited."""

        def __init__(self):
            self.closed = False

        async def close(self):
            self.closed = True

    reported_codes = []

    async def _send_error(sender, message, code, **kwargs):
        _ = (sender, message, kwargs)  # only the error code matters here
        reported_codes.append(code)

    session = Session.__new__(Session)
    session.id = "sess_missing_assistant"
    session.ws_state = WsSessionState.WAIT_START
    session._assistant_id = None

    transport = _Transport()
    session.transport = transport
    session._send_error = _send_error

    start_message = SessionStartMessage(type="session.start", metadata={})
    await session._handle_session_start(start_message)

    assert reported_codes == ["protocol.assistant_id_required"]
    assert transport.closed is True
    assert session.ws_state == WsSessionState.STOPPED
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_session_start_applies_whitelisted_overrides_and_ignores_workflow(monkeypatch):
    """``session.start`` applies client overrides and emits no workflow/config events.

    With ``ws_emit_config_resolved`` patched to ``False``, the session must
    become active, start the pipeline, hand the pipeline metadata in which the
    client's ``greeting``/``output``/``tools`` overrides are present alongside
    the server ``assistantId``, and emit neither ``workflow.*`` nor
    ``config.resolved`` events.
    """
    monkeypatch.setattr("core.session.settings.ws_emit_config_resolved", False)

    class _Transport:
        """Transport stub whose ``close`` does nothing."""

        async def close(self):
            return None

    class _Pipeline:
        """Pipeline stub recording whether it started and what metadata it got."""

        def __init__(self):
            self.started = False
            self.applied = {}
            conversation_cls = type("Conversation", (), {"system_prompt": ""})
            self.conversation = conversation_cls()

        async def start(self):
            self.started = True

        async def emit_initial_greeting(self):
            return None

        def apply_runtime_overrides(self, metadata):
            self.applied = dict(metadata)

        def resolved_runtime_config(self):
            resolved = {
                "output": {"mode": "text"},
                "services": {"llm": {"provider": "openai", "model": "gpt-4o-mini"}},
                "tools": {"allowlist": ["calculator"]},
            }
            return resolved

    emitted_events = []

    async def _start_history_bridge(_metadata):
        return None

    async def _load_server_runtime_metadata(_assistant_id):
        # A fresh dict per call, paired with "no error".
        server_metadata = {
            "assistantId": "assistant_demo",
            "configVersionId": "cfg_1",
            "systemPrompt": "Base prompt",
            "greeting": "Base greeting",
            "output": {"mode": "audio"},
        }
        return (server_metadata, None)

    async def _send_event(event):
        emitted_events.append(event)

    async def _send_error(sender, message, code, **kwargs):
        # The happy path must never report an error.
        raise AssertionError(f"Unexpected error: sender={sender} code={code} message={message} kwargs={kwargs}")

    # Bypass ``__init__`` and set only the attributes this test assigns.
    session = Session.__new__(Session)
    session.id = "sess_start_ok"
    session.ws_state = WsSessionState.WAIT_START
    session.state = "created"
    session._assistant_id = "assistant_demo"
    session.current_track_id = Session.TRACK_CONTROL
    session._pipeline_started = False

    session.transport = _Transport()
    session.pipeline = _Pipeline()

    session._start_history_bridge = _start_history_bridge
    session._load_server_runtime_metadata = _load_server_runtime_metadata
    session._send_event = _send_event
    session._send_error = _send_error

    client_metadata = {
        "workflow": {"nodes": []},
        "channel": "web_debug",
        "source": "debug_ui",
        "history": {"userId": 7},
        "overrides": {
            "greeting": "Override greeting",
            "output": {"mode": "text"},
            "tools": [{"name": "calculator"}],
        },
    }
    start_message = SessionStartMessage(type="session.start", metadata=client_metadata)

    await session._handle_session_start(start_message)

    assert session.ws_state == WsSessionState.ACTIVE
    assert session.pipeline.started is True

    applied = session.pipeline.applied
    assert applied["assistantId"] == "assistant_demo"
    assert applied["greeting"] == "Override greeting"
    assert applied["output"]["mode"] == "text"
    assert applied["tools"] == [{"name": "calculator"}]

    assert all(not str(item.get("type", "")).startswith("workflow.") for item in emitted_events)
    assert all(not (item.get("type") == "config.resolved") for item in emitted_events)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_session_start_emits_config_resolved_when_enabled(monkeypatch):
    """With ``ws_emit_config_resolved`` on, ``session.start`` emits ``config.resolved``.

    The event's ``config`` must omit ``appId``, ``configVersionId`` and
    ``services``, carry the client ``channel`` and the ``text`` output mode,
    and summarise tools as ``enabled`` with a ``count`` of 1.
    """
    monkeypatch.setattr("core.session.settings.ws_emit_config_resolved", True)

    class _Transport:
        """Transport stub whose ``close`` does nothing."""

        async def close(self):
            return None

    class _Pipeline:
        """Pipeline stub recording whether it started and what metadata it got."""

        def __init__(self):
            self.started = False
            self.applied = {}
            conversation_cls = type("Conversation", (), {"system_prompt": ""})
            self.conversation = conversation_cls()

        async def start(self):
            self.started = True

        async def emit_initial_greeting(self):
            return None

        def apply_runtime_overrides(self, metadata):
            self.applied = dict(metadata)

        def resolved_runtime_config(self):
            resolved = {
                "output": {"mode": "text"},
                "services": {"llm": {"provider": "openai", "model": "gpt-4o-mini"}},
                "tools": {"allowlist": ["calculator"]},
            }
            return resolved

    emitted_events = []

    async def _start_history_bridge(_metadata):
        return None

    async def _load_server_runtime_metadata(_assistant_id):
        # A fresh dict per call, paired with "no error".
        server_metadata = {
            "assistantId": "assistant_demo",
            "configVersionId": "cfg_1",
            "systemPrompt": "Base prompt",
            "greeting": "Base greeting",
            "output": {"mode": "audio"},
        }
        return (server_metadata, None)

    async def _send_event(event):
        emitted_events.append(event)

    async def _send_error(sender, message, code, **kwargs):
        # The happy path must never report an error.
        raise AssertionError(f"Unexpected error: sender={sender} code={code} message={message} kwargs={kwargs}")

    # Bypass ``__init__`` and set only the attributes this test assigns.
    session = Session.__new__(Session)
    session.id = "sess_start_emit_config"
    session.ws_state = WsSessionState.WAIT_START
    session.state = "created"
    session._assistant_id = "assistant_demo"
    session.current_track_id = Session.TRACK_CONTROL
    session._pipeline_started = False

    session.transport = _Transport()
    session.pipeline = _Pipeline()

    session._start_history_bridge = _start_history_bridge
    session._load_server_runtime_metadata = _load_server_runtime_metadata
    session._send_event = _send_event
    session._send_error = _send_error

    client_metadata = {
        "channel": "web_debug",
        "overrides": {
            "output": {"mode": "text"},
        },
    }
    start_message = SessionStartMessage(type="session.start", metadata=client_metadata)

    await session._handle_session_start(start_message)

    # ``next`` raises StopIteration if no ``config.resolved`` event was emitted.
    config_event = next(item for item in emitted_events if item.get("type") == "config.resolved")

    for hidden_key in ("appId", "configVersionId", "services"):
        assert hidden_key not in config_event["config"]

    assert config_event["config"]["channel"] == "web_debug"
    assert config_event["config"]["output"]["mode"] == "text"
    assert config_event["config"]["tools"]["enabled"] is True
    assert config_event["config"]["tools"]["count"] == 1
|