Implement DashScope ASR provider and enhance ASR service architecture
- Added DashScope ASR service implementation for real-time streaming.
- Updated ASR provider logic to support DashScope alongside existing providers.
- Enhanced runtime metadata resolution to include DashScope as a valid ASR provider.
- Modified configuration files and documentation to reflect the addition of DashScope.
- Introduced tests to validate DashScope integration and ASR service behavior.
- Refactored ASR service factory to accommodate new provider options and modes.
This commit is contained in:
46
engine/tests/test_asr_factory_modes.py
Normal file
46
engine/tests/test_asr_factory_modes.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from providers.asr.buffered import BufferedASRService
|
||||
from providers.asr.dashscope import DashScopeRealtimeASRService
|
||||
from providers.asr.openai_compatible import OpenAICompatibleASRService
|
||||
from providers.factory.default import DefaultRealtimeServiceFactory
|
||||
from runtime.ports import ASRServiceSpec
|
||||
|
||||
|
||||
def test_create_asr_service_dashscope_returns_streaming_provider():
    """A DashScope spec with mode="streaming" yields the streaming DashScope service."""
    spec = ASRServiceSpec(
        provider="dashscope",
        mode="streaming",
        sample_rate=16000,
        api_key="test-key",
        model="qwen3-asr-flash-realtime",
    )
    service = DefaultRealtimeServiceFactory().create_asr_service(spec)
    assert isinstance(service, DashScopeRealtimeASRService)
    assert service.mode == "streaming"
|
||||
|
||||
|
||||
def test_create_asr_service_openai_compatible_returns_offline_provider():
    """An openai_compatible spec (no mode given) yields the offline provider."""
    spec = ASRServiceSpec(
        provider="openai_compatible",
        sample_rate=16000,
        api_key="test-key",
        model="FunAudioLLM/SenseVoiceSmall",
    )
    service = DefaultRealtimeServiceFactory().create_asr_service(spec)
    assert isinstance(service, OpenAICompatibleASRService)
    assert service.mode == "offline"
|
||||
|
||||
|
||||
def test_create_asr_service_fallback_buffered_for_unsupported_provider():
    """Unknown provider names fall back to the buffered offline ASR service."""
    spec = ASRServiceSpec(
        provider="unknown_provider",
        sample_rate=16000,
    )
    service = DefaultRealtimeServiceFactory().create_asr_service(spec)
    assert isinstance(service, BufferedASRService)
    assert service.mode == "offline"
|
||||
67
engine/tests/test_dashscope_asr_provider.py
Normal file
67
engine/tests/test_dashscope_asr_provider.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from providers.asr.dashscope import DashScopeRealtimeASRService
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_dashscope_asr_interim_event_emits_interim_transcript():
    """An interim ws event surfaces through both the queue and the callback as non-final."""
    callback_log = []

    async def _on_transcript(text: str, is_final: bool) -> None:
        callback_log.append((text, is_final))

    service = DashScopeRealtimeASRService(api_key="test-key", on_transcript=_on_transcript)
    service._loop = asyncio.get_running_loop()
    service._running = True

    # NOTE(review): interim text arrives in the "stash" field per the DashScope
    # realtime event schema — confirm against the provider implementation.
    interim_event = {
        "type": "conversation.item.input_audio_transcription.text",
        "stash": "你好世界",
    }
    service._on_ws_event(interim_event)
    # Give the event loop a beat so the thread-safe callback dispatch runs.
    await asyncio.sleep(0.05)

    queued = service._transcript_queue.get_nowait()
    assert queued.text == "你好世界"
    assert queued.is_final is False
    assert callback_log == [("你好世界", False)]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_dashscope_asr_final_event_emits_final_transcript_and_final_queue():
    """A completed-transcription event emits a final transcript and feeds the final queue."""
    callback_log = []

    async def _on_transcript(text: str, is_final: bool) -> None:
        callback_log.append((text, is_final))

    service = DashScopeRealtimeASRService(api_key="test-key", on_transcript=_on_transcript)
    service._loop = asyncio.get_running_loop()
    service._running = True
    # A final is only meaningful after audio was sent in this utterance.
    service._audio_sent_in_utterance = True

    final_event = {
        "type": "conversation.item.input_audio_transcription.completed",
        "transcript": "最终识别结果",
    }
    service._on_ws_event(final_event)
    await asyncio.sleep(0.05)

    queued = service._transcript_queue.get_nowait()
    assert queued.text == "最终识别结果"
    assert queued.is_final is True
    assert service._final_queue.get_nowait() == "最终识别结果"
    assert callback_log == [("最终识别结果", True)]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_dashscope_wait_for_final_falls_back_to_latest_interim_on_timeout():
    """When no final arrives before the timeout, the latest interim text is returned."""
    service = DashScopeRealtimeASRService(api_key="test-key")
    service._audio_sent_in_utterance = True
    service._last_interim_text = "部分结果"

    assert await service.wait_for_final_transcription(timeout_ms=10) == "部分结果"
|
||||
196
engine/tests/test_duplex_asr_modes.py
Normal file
196
engine/tests/test_duplex_asr_modes.py
Normal file
@@ -0,0 +1,196 @@
|
||||
import asyncio
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pytest
|
||||
|
||||
from runtime.pipeline.duplex import DuplexPipeline
|
||||
|
||||
|
||||
class _DummySileroVAD:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def process_audio(self, _pcm: bytes) -> float:
|
||||
return 0.0
|
||||
|
||||
|
||||
class _DummyVADProcessor:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def process(self, _speech_prob: float):
|
||||
return "Silence", 0.0
|
||||
|
||||
|
||||
class _DummyEouDetector:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.is_speaking = True
|
||||
|
||||
def process(self, _vad_status: str, force_eligible: bool = False) -> bool:
|
||||
_ = force_eligible
|
||||
return False
|
||||
|
||||
def reset(self) -> None:
|
||||
self.is_speaking = False
|
||||
|
||||
|
||||
class _FakeTransport:
|
||||
async def send_event(self, _event: Dict[str, Any]) -> None:
|
||||
return None
|
||||
|
||||
async def send_audio(self, _audio: bytes) -> None:
|
||||
return None
|
||||
|
||||
|
||||
class _FakeStreamingASR:
|
||||
mode = "streaming"
|
||||
|
||||
def __init__(self):
|
||||
self.begin_calls = 0
|
||||
self.end_calls = 0
|
||||
self.wait_calls = 0
|
||||
self.sent_audio: List[bytes] = []
|
||||
self.wait_text = ""
|
||||
|
||||
async def connect(self) -> None:
|
||||
return None
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
return None
|
||||
|
||||
async def send_audio(self, audio: bytes) -> None:
|
||||
self.sent_audio.append(audio)
|
||||
|
||||
async def receive_transcripts(self):
|
||||
if False:
|
||||
yield None
|
||||
|
||||
async def begin_utterance(self) -> None:
|
||||
self.begin_calls += 1
|
||||
|
||||
async def end_utterance(self) -> None:
|
||||
self.end_calls += 1
|
||||
|
||||
async def wait_for_final_transcription(self, timeout_ms: int = 800) -> str:
|
||||
_ = timeout_ms
|
||||
self.wait_calls += 1
|
||||
return self.wait_text
|
||||
|
||||
def clear_utterance(self) -> None:
|
||||
return None
|
||||
|
||||
|
||||
class _FakeOfflineASR:
|
||||
mode = "offline"
|
||||
|
||||
def __init__(self):
|
||||
self.start_interim_calls = 0
|
||||
self.stop_interim_calls = 0
|
||||
self.sent_audio: List[bytes] = []
|
||||
self.final_text = "offline final"
|
||||
|
||||
async def connect(self) -> None:
|
||||
return None
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
return None
|
||||
|
||||
async def send_audio(self, audio: bytes) -> None:
|
||||
self.sent_audio.append(audio)
|
||||
|
||||
async def receive_transcripts(self):
|
||||
if False:
|
||||
yield None
|
||||
|
||||
async def start_interim_transcription(self) -> None:
|
||||
self.start_interim_calls += 1
|
||||
|
||||
async def stop_interim_transcription(self) -> None:
|
||||
self.stop_interim_calls += 1
|
||||
|
||||
async def get_final_transcription(self) -> str:
|
||||
return self.final_text
|
||||
|
||||
def clear_buffer(self) -> None:
|
||||
return None
|
||||
|
||||
def get_and_clear_text(self) -> str:
|
||||
return self.final_text
|
||||
|
||||
|
||||
def _build_pipeline(monkeypatch, asr_service):
    """Build a DuplexPipeline with its VAD/EOU internals stubbed out."""
    stubs = (
        ("runtime.pipeline.duplex.SileroVAD", _DummySileroVAD),
        ("runtime.pipeline.duplex.VADProcessor", _DummyVADProcessor),
        ("runtime.pipeline.duplex.EouDetector", _DummyEouDetector),
    )
    for target, replacement in stubs:
        monkeypatch.setattr(target, replacement)
    return DuplexPipeline(
        transport=_FakeTransport(),
        session_id="asr_mode_test",
        asr_service=asr_service,
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_start_asr_capture_uses_streaming_begin(monkeypatch):
    """Streaming mode should open an utterance and flush buffered audio to the ASR."""
    fake_asr = _FakeStreamingASR()
    pipeline = _build_pipeline(monkeypatch, fake_asr)
    pipeline._asr_mode = "streaming"
    pipeline._pre_speech_buffer = b"\x00" * 640
    pipeline._pending_speech_audio = b"\x00" * 320

    await pipeline._start_asr_capture()

    assert pipeline._asr_capture_active is True
    assert fake_asr.begin_calls == 1
    assert fake_asr.sent_audio
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_start_asr_capture_uses_offline_interim_control(monkeypatch):
    """Offline mode should start interim transcription and flush buffered audio."""
    fake_asr = _FakeOfflineASR()
    pipeline = _build_pipeline(monkeypatch, fake_asr)
    pipeline._asr_mode = "offline"
    pipeline._pre_speech_buffer = b"\x00" * 640
    pipeline._pending_speech_audio = b"\x00" * 320

    await pipeline._start_asr_capture()

    assert pipeline._asr_capture_active is True
    assert fake_asr.start_interim_calls == 1
    assert fake_asr.sent_audio
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_streaming_eou_falls_back_to_latest_interim(monkeypatch):
    """An empty final transcription falls back to the newest interim text at EOU."""
    streaming_asr = _FakeStreamingASR()
    streaming_asr.wait_text = ""  # force the fallback branch
    pipeline = _build_pipeline(monkeypatch, streaming_asr)
    pipeline._asr_mode = "streaming"
    pipeline._asr_capture_active = True
    pipeline._latest_asr_interim_text = "fallback interim text"
    await pipeline.conversation.start_user_turn()

    events: List[Dict[str, Any]] = []
    turns: List[str] = []

    async def _record_event(event: Dict[str, Any], priority: int = 20):
        del priority
        events.append(event)

    async def _silent_stop() -> None:
        return None

    async def _record_turn(user_text: str, *args, **kwargs) -> None:
        del args, kwargs
        turns.append(user_text)

    monkeypatch.setattr(pipeline, "_send_event", _record_event)
    monkeypatch.setattr(pipeline, "_stop_current_speech", _silent_stop)
    monkeypatch.setattr(pipeline, "_handle_turn", _record_turn)

    await pipeline._on_end_of_utterance()
    await asyncio.sleep(0.05)

    assert streaming_asr.end_calls == 1
    assert streaming_asr.wait_calls == 1
    assert turns == ["fallback interim text"]
    assert any(evt.get("type") == "transcript.final" for evt in events)
|
||||
@@ -52,9 +52,33 @@ class _FakeTTS:
|
||||
|
||||
|
||||
class _FakeASR:
|
||||
mode = "offline"
|
||||
|
||||
async def connect(self) -> None:
|
||||
return None
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
return None
|
||||
|
||||
async def send_audio(self, _audio: bytes) -> None:
|
||||
return None
|
||||
|
||||
async def receive_transcripts(self):
|
||||
if False:
|
||||
yield None
|
||||
|
||||
def clear_buffer(self) -> None:
|
||||
return None
|
||||
|
||||
async def start_interim_transcription(self) -> None:
|
||||
return None
|
||||
|
||||
async def stop_interim_transcription(self) -> None:
|
||||
return None
|
||||
|
||||
async def get_final_transcription(self) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
class _FakeLLM:
|
||||
def __init__(self, rounds: List[List[LLMStreamEvent]]):
|
||||
|
||||
Reference in New Issue
Block a user