- Introduced `asr_interim_enabled` field in the Assistant model to control interim ASR results. - Updated AssistantBase and AssistantUpdate schemas to include the new field. - Modified the database schema to add the `asr_interim_enabled` column. - Enhanced runtime metadata to reflect interim ASR settings. - Updated API endpoints and tests to validate the new functionality. - Adjusted documentation to include details about interim ASR results configuration.
261 lines
7.3 KiB
Python
261 lines
7.3 KiB
Python
import asyncio
|
|
from typing import Any, Dict, List
|
|
|
|
import pytest
|
|
|
|
from runtime.pipeline.duplex import DuplexPipeline
|
|
|
|
|
|
class _DummySileroVAD:
    """Test double patched in place of ``SileroVAD``; it does no real work."""

    def __init__(self, *_args, **_kwargs):
        """Accept and discard any constructor arguments."""

    def process_audio(self, _pcm: bytes) -> float:
        """Return ``0.0`` for every audio chunk, regardless of its content."""
        constant_result = 0.0
        return constant_result
|
|
|
|
|
|
class _DummyVADProcessor:
    """Test double patched in place of ``VADProcessor``."""

    def __init__(self, *_args, **_kwargs):
        """Accept and discard any constructor arguments."""

    def process(self, _speech_prob: float):
        """Return the fixed pair ``("Silence", 0.0)`` whatever the input is."""
        status, value = "Silence", 0.0
        return (status, value)
|
|
|
|
|
|
class _DummyEouDetector:
    """Test double patched in place of ``EouDetector``.

    ``process`` never signals an end of utterance; ``is_speaking`` starts as
    ``True`` and is flipped to ``False`` by ``reset``.
    """

    def __init__(self, *_args, **_kwargs):
        """Discard constructor arguments and start with ``is_speaking`` set."""
        self.is_speaking = True

    def process(self, _vad_status: str, force_eligible: bool = False) -> bool:
        """Always return ``False``; both arguments are ignored."""
        del force_eligible  # accepted for signature compatibility only
        return False

    def reset(self) -> None:
        """Clear the ``is_speaking`` flag."""
        self.is_speaking = False
|
|
|
|
|
|
class _FakeTransport:
    """Transport double whose send methods accept input and do nothing."""

    async def send_event(self, _event: Dict[str, Any]) -> None:
        """Ignore the event and return ``None``."""

    async def send_audio(self, _audio: bytes) -> None:
        """Ignore the audio bytes and return ``None``."""
|
|
|
|
|
|
class _FakeStreamingASR:
    """In-memory ASR double with ``mode == "streaming"`` that records usage.

    The ``*_calls`` counters track how often the utterance-control methods
    ran, ``sent_audio`` keeps every chunk given to ``send_audio`` and
    ``wait_text`` is the value returned by ``wait_for_final_transcription``.
    """

    mode = "streaming"

    def __init__(self):
        """Start with zeroed counters, no recorded audio and empty text."""
        self.wait_text = ""
        self.sent_audio: List[bytes] = []
        self.begin_calls = self.end_calls = self.wait_calls = 0

    async def connect(self) -> None:
        """Do nothing."""

    async def disconnect(self) -> None:
        """Do nothing."""

    async def send_audio(self, audio: bytes) -> None:
        """Record ``audio`` at the end of ``sent_audio``."""
        self.sent_audio.append(audio)

    async def receive_transcripts(self):
        """Async generator that finishes without yielding anything."""
        return
        yield  # unreachable; its presence makes this an async generator

    async def begin_utterance(self) -> None:
        """Count one ``begin_utterance`` call."""
        self.begin_calls = self.begin_calls + 1

    async def end_utterance(self) -> None:
        """Count one ``end_utterance`` call."""
        self.end_calls = self.end_calls + 1

    async def wait_for_final_transcription(self, timeout_ms: int = 800) -> str:
        """Count the call and return ``wait_text``; ``timeout_ms`` is unused."""
        del timeout_ms
        self.wait_calls = self.wait_calls + 1
        return self.wait_text

    def clear_utterance(self) -> None:
        """Do nothing."""
|
|
|
|
|
|
class _FakeOfflineASR:
    """In-memory ASR double with ``mode == "offline"`` that records usage.

    It counts calls to the interim start/stop methods, keeps every chunk given
    to ``send_audio`` in ``sent_audio`` and returns ``final_text`` from both
    text getters.
    """

    mode = "offline"

    def __init__(self):
        """Start with zeroed counters, no recorded audio and a canned text."""
        self.final_text = "offline final"
        self.sent_audio: List[bytes] = []
        self.start_interim_calls = self.stop_interim_calls = 0

    async def connect(self) -> None:
        """Do nothing."""

    async def disconnect(self) -> None:
        """Do nothing."""

    async def send_audio(self, audio: bytes) -> None:
        """Record ``audio`` at the end of ``sent_audio``."""
        self.sent_audio.append(audio)

    async def receive_transcripts(self):
        """Async generator that finishes without yielding anything."""
        return
        yield  # unreachable; its presence makes this an async generator

    async def start_interim_transcription(self) -> None:
        """Count one ``start_interim_transcription`` call."""
        self.start_interim_calls = self.start_interim_calls + 1

    async def stop_interim_transcription(self) -> None:
        """Count one ``stop_interim_transcription`` call."""
        self.stop_interim_calls = self.stop_interim_calls + 1

    async def get_final_transcription(self) -> str:
        """Return the current ``final_text``."""
        return self.final_text

    def clear_buffer(self) -> None:
        """Do nothing."""

    def get_and_clear_text(self) -> str:
        """Return ``final_text``; this fake does not actually clear it."""
        return self.final_text
|
|
|
|
|
|
def _build_pipeline(monkeypatch, asr_service):
    """Create a ``DuplexPipeline`` wired to test doubles.

    The VAD model, VAD processor and end-of-utterance detector referenced by
    ``runtime.pipeline.duplex`` are replaced with the dummy classes before the
    pipeline is constructed with a fake transport and the given ASR service.

    Args:
        monkeypatch: pytest ``monkeypatch`` fixture used for the replacements.
        asr_service: ASR object passed through as ``asr_service``.

    Returns:
        The newly constructed ``DuplexPipeline``.
    """
    replacements = (
        ("runtime.pipeline.duplex.SileroVAD", _DummySileroVAD),
        ("runtime.pipeline.duplex.VADProcessor", _DummyVADProcessor),
        ("runtime.pipeline.duplex.EouDetector", _DummyEouDetector),
    )
    for dotted_target, dummy_cls in replacements:
        monkeypatch.setattr(dotted_target, dummy_cls)

    pipeline_kwargs = {
        "transport": _FakeTransport(),
        "session_id": "asr_mode_test",
        "asr_service": asr_service,
    }
    return DuplexPipeline(**pipeline_kwargs)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_start_asr_capture_uses_streaming_begin(monkeypatch):
    """In streaming mode, starting capture begins an utterance and sends audio."""
    streaming_asr = _FakeStreamingASR()
    duplex = _build_pipeline(monkeypatch, streaming_asr)
    duplex._asr_mode = "streaming"
    # Non-empty buffers: the test expects some audio to reach the ASR.
    duplex._pending_speech_audio = bytes(320)
    duplex._pre_speech_buffer = bytes(640)

    await duplex._start_asr_capture()

    assert streaming_asr.begin_calls == 1
    assert len(streaming_asr.sent_audio) > 0
    assert duplex._asr_capture_active is True
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_start_asr_capture_uses_offline_interim_control_when_enabled(monkeypatch):
    """Offline mode with ``enableInterim`` on starts interim transcription once."""
    offline_asr = _FakeOfflineASR()
    duplex = _build_pipeline(monkeypatch, offline_asr)
    duplex._asr_mode = "offline"
    duplex._runtime_asr["enableInterim"] = True
    # Non-empty buffers: the test expects some audio to reach the ASR.
    duplex._pending_speech_audio = bytes(320)
    duplex._pre_speech_buffer = bytes(640)

    await duplex._start_asr_capture()

    assert offline_asr.start_interim_calls == 1
    assert len(offline_asr.sent_audio) > 0
    assert duplex._asr_capture_active is True
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_start_asr_capture_skips_offline_interim_control_when_disabled(monkeypatch):
    """Offline mode with ``enableInterim`` off never starts interim transcription."""
    offline_asr = _FakeOfflineASR()
    duplex = _build_pipeline(monkeypatch, offline_asr)
    duplex._asr_mode = "offline"
    duplex._runtime_asr["enableInterim"] = False
    # Non-empty buffers: audio is still expected to reach the ASR.
    duplex._pending_speech_audio = bytes(320)
    duplex._pre_speech_buffer = bytes(640)

    await duplex._start_asr_capture()

    assert offline_asr.start_interim_calls == 0
    assert len(offline_asr.sent_audio) > 0
    assert duplex._asr_capture_active is True
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_offline_interim_callback_ignored_when_disabled(monkeypatch):
    """A non-final transcript is dropped when offline interim results are off.

    Nothing should be sent as an event or a transcript delta, and the stored
    latest interim text should stay empty.
    """
    duplex = _build_pipeline(monkeypatch, _FakeOfflineASR())
    duplex._asr_mode = "offline"
    duplex._runtime_asr["enableInterim"] = False

    seen_events = []
    seen_deltas = []

    async def _capture_event(event: Dict[str, Any], priority: int = 20):
        del priority  # accepted for signature compatibility only
        seen_events.append(event)

    async def _capture_delta(text: str):
        seen_deltas.append(text)

    # Replace both outbound paths with recorders.
    monkeypatch.setattr(duplex, "_send_event", _capture_event)
    monkeypatch.setattr(duplex, "_emit_transcript_delta", _capture_delta)

    await duplex._on_transcript_callback("ignored interim", is_final=False)

    assert seen_events == []
    assert seen_deltas == []
    assert duplex._latest_asr_interim_text == ""
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_offline_final_callback_emits_when_interim_disabled(monkeypatch):
    """A final transcript still produces ``transcript.final`` with interim off."""
    duplex = _build_pipeline(monkeypatch, _FakeOfflineASR())
    duplex._asr_mode = "offline"
    duplex._runtime_asr["enableInterim"] = False

    seen_events = []

    async def _capture_event(event: Dict[str, Any], priority: int = 20):
        del priority  # accepted for signature compatibility only
        seen_events.append(event)

    monkeypatch.setattr(duplex, "_send_event", _capture_event)

    await duplex._on_transcript_callback("final only", is_final=True)

    event_types = [evt.get("type") for evt in seen_events]
    assert "transcript.final" in event_types
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_streaming_eou_falls_back_to_latest_interim(monkeypatch):
    """Streaming end-of-utterance uses the latest interim text as a fallback.

    The fake ASR returns an empty final transcription, so the turn handler is
    expected to receive the stored interim text and a ``transcript.final``
    event should be emitted.
    """
    streaming_asr = _FakeStreamingASR()
    streaming_asr.wait_text = ""
    duplex = _build_pipeline(monkeypatch, streaming_asr)
    duplex._asr_mode = "streaming"
    duplex._asr_capture_active = True
    duplex._latest_asr_interim_text = "fallback interim text"
    await duplex.conversation.start_user_turn()

    seen_events = []
    seen_turns = []

    async def _capture_event(event: Dict[str, Any], priority: int = 20):
        del priority  # accepted for signature compatibility only
        seen_events.append(event)

    async def _noop_stop_current_speech() -> None:
        """Stand in for ``_stop_current_speech`` without doing anything."""

    async def _capture_turn(user_text: str, *args, **kwargs) -> None:
        del args, kwargs  # only the user text matters here
        seen_turns.append(user_text)

    # Swap in the recorders in the same order as the original test.
    for attr_name, replacement in (
        ("_send_event", _capture_event),
        ("_stop_current_speech", _noop_stop_current_speech),
        ("_handle_turn", _capture_turn),
    ):
        monkeypatch.setattr(duplex, attr_name, replacement)

    await duplex._on_end_of_utterance()
    # Brief pause before asserting; presumably lets scheduled follow-up work
    # (such as the turn handler) run — confirm against the pipeline.
    await asyncio.sleep(0.05)

    assert streaming_asr.end_calls == 1
    assert streaming_asr.wait_calls == 1
    assert seen_turns == ["fallback interim text"]
    event_types = [evt.get("type") for evt in seen_events]
    assert "transcript.final" in event_types
|