Add ASR interim results support in Assistant model and API

- Introduced `asr_interim_enabled` field in the Assistant model to control interim ASR results.
- Updated AssistantBase and AssistantUpdate schemas to include the new field.
- Modified the database schema to add the `asr_interim_enabled` column.
- Enhanced runtime metadata to reflect interim ASR settings.
- Updated API endpoints and tests to validate the new functionality.
- Adjusted documentation to include details about interim ASR results configuration.
This commit is contained in:
Xin Wang
2026-03-06 12:58:54 +08:00
parent e11c3abb9e
commit da38157638
19 changed files with 183 additions and 5 deletions

View File

@@ -249,6 +249,8 @@ class LocalYamlAssistantConfigAdapter(NullBackendAdapter):
asr_runtime["apiKey"] = cls._as_str(asr.get("api_key"))
if cls._as_str(asr.get("api_url")):
asr_runtime["baseUrl"] = cls._as_str(asr.get("api_url"))
if asr.get("enable_interim") is not None:
asr_runtime["enableInterim"] = asr.get("enable_interim")
if asr.get("interim_interval_ms") is not None:
asr_runtime["interimIntervalMs"] = asr.get("interim_interval_ms")
if asr.get("min_audio_ms") is not None:

View File

@@ -89,6 +89,7 @@ class Settings(BaseSettings):
)
asr_api_url: Optional[str] = Field(default=None, description="ASR provider API URL")
asr_model: Optional[str] = Field(default=None, description="ASR model name")
asr_enable_interim: bool = Field(default=False, description="Enable interim transcripts for offline ASR")
asr_interim_interval_ms: int = Field(default=500, description="Interval for interim ASR results in ms")
asr_min_audio_ms: int = Field(default=300, description="Minimum audio duration before first ASR result")
asr_start_min_speech_ms: int = Field(

View File

@@ -44,6 +44,7 @@ agent:
api_key: your_asr_api_key
api_url: https://api.siliconflow.cn/v1/audio/transcriptions
model: FunAudioLLM/SenseVoiceSmall
enable_interim: false
interim_interval_ms: 500
min_audio_ms: 300
start_min_speech_ms: 160

View File

@@ -41,6 +41,7 @@ agent:
api_key: your_asr_api_key
api_url: https://api.siliconflow.cn/v1/audio/transcriptions
model: FunAudioLLM/SenseVoiceSmall
enable_interim: false
interim_interval_ms: 500
min_audio_ms: 300
start_min_speech_ms: 160

View File

@@ -53,6 +53,7 @@ class OpenAICompatibleASRService(BaseASRService):
model: str = "FunAudioLLM/SenseVoiceSmall",
sample_rate: int = 16000,
language: str = "auto",
enable_interim: bool = False,
interim_interval_ms: int = 500, # How often to send interim results
min_audio_for_interim_ms: int = 300, # Min audio before first interim
on_transcript: Optional[Callable[[str, bool], Awaitable[None]]] = None
@@ -66,6 +67,7 @@ class OpenAICompatibleASRService(BaseASRService):
model: ASR model name or alias
sample_rate: Audio sample rate (16000 recommended)
language: Language code (auto for automatic detection)
enable_interim: Whether to generate interim transcriptions in offline mode
interim_interval_ms: How often to generate interim transcriptions
min_audio_for_interim_ms: Minimum audio duration before first interim
on_transcript: Callback for transcription results (text, is_final)
@@ -80,6 +82,7 @@ class OpenAICompatibleASRService(BaseASRService):
raw_api_url = api_url or os.getenv("ASR_API_URL") or self.API_URL
self.api_url = self._resolve_transcriptions_endpoint(raw_api_url)
self.model = self.MODELS.get(model.lower(), model)
self.enable_interim = bool(enable_interim)
self.interim_interval_ms = interim_interval_ms
self.min_audio_for_interim_ms = min_audio_for_interim_ms
self.on_transcript = on_transcript
@@ -181,6 +184,9 @@ class OpenAICompatibleASRService(BaseASRService):
if not self._session:
logger.warning("ASR session not connected")
return None
if not is_final and not self.enable_interim:
return None
# Check minimum audio duration
audio_duration_ms = len(self._audio_buffer) / (self.sample_rate * 2) * 1000
@@ -310,6 +316,9 @@ class OpenAICompatibleASRService(BaseASRService):
This periodically transcribes buffered audio for
real-time feedback to the user.
"""
if not self.enable_interim:
return
if self._interim_task and not self._interim_task.done():
return

View File

@@ -117,6 +117,7 @@ class DefaultRealtimeServiceFactory(RealtimeServiceFactory):
model=spec.model or self._DEFAULT_OPENAI_COMPATIBLE_ASR_MODEL,
sample_rate=spec.sample_rate,
language=spec.language,
enable_interim=spec.enable_interim,
interim_interval_ms=spec.interim_interval_ms,
min_audio_for_interim_ms=spec.min_audio_for_interim_ms,
on_transcript=spec.on_transcript,

View File

@@ -599,6 +599,7 @@ class DuplexPipeline:
"provider": asr_provider,
"mode": self._resolve_asr_mode(asr_provider, self._runtime_asr.get("mode")),
"model": str(self._runtime_asr.get("model") or settings.asr_model or ""),
"enableInterim": self._asr_interim_enabled(),
"interimIntervalMs": int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms),
"minAudioMs": int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms),
},
@@ -865,6 +866,20 @@ class DuplexPipeline:
return self._runtime_barge_in_min_duration_ms
return self._barge_in_min_duration_ms
def _asr_interim_enabled(self) -> bool:
current_mode = self._asr_mode
if not self.asr_service:
current_mode = self._resolve_asr_mode(
self._runtime_asr.get("provider") or settings.asr_provider,
self._runtime_asr.get("mode"),
)
if current_mode != "offline":
return True
enabled = self._coerce_bool(self._runtime_asr.get("enableInterim"))
if enabled is not None:
return enabled
return bool(settings.asr_enable_interim)
def _barge_in_silence_tolerance_frames(self) -> int:
"""Convert silence tolerance from ms to frame count using current chunk size."""
chunk_ms = max(1, settings.chunk_size_ms)
@@ -991,6 +1006,9 @@ class DuplexPipeline:
asr_api_key = self._runtime_asr.get("apiKey")
asr_api_url = self._runtime_asr.get("baseUrl") or settings.asr_api_url
asr_model = self._runtime_asr.get("model") or settings.asr_model
asr_enable_interim = self._coerce_bool(self._runtime_asr.get("enableInterim"))
if asr_enable_interim is None:
asr_enable_interim = bool(settings.asr_enable_interim)
asr_interim_interval = int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms)
asr_min_audio_ms = int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms)
asr_mode = self._resolve_asr_mode(asr_provider, self._runtime_asr.get("mode"))
@@ -1004,6 +1022,7 @@ class DuplexPipeline:
api_key=str(asr_api_key).strip() if asr_api_key else None,
api_url=str(asr_api_url).strip() if asr_api_url else None,
model=str(asr_model).strip() if asr_model else None,
enable_interim=asr_enable_interim,
interim_interval_ms=asr_interim_interval,
min_audio_for_interim_ms=asr_min_audio_ms,
on_transcript=self._on_transcript_callback,
@@ -1481,6 +1500,9 @@ class DuplexPipeline:
text: Transcribed text
is_final: Whether this is the final transcription
"""
if not is_final and not self._asr_interim_enabled():
return
# Avoid sending duplicate transcripts
if text == self._last_sent_transcript and not is_final:
return
@@ -1550,7 +1572,8 @@ class DuplexPipeline:
if self._asr_mode == "streaming":
await self._streaming_asr().begin_utterance()
else:
await self._offline_asr().start_interim_transcription()
if self._asr_interim_enabled():
await self._offline_asr().start_interim_transcription()
# Prime ASR with a short pre-speech context window so the utterance
# start isn't lost while waiting for VAD to transition to Speech.

View File

@@ -22,6 +22,7 @@ class ASRServiceSpec:
api_key: Optional[str] = None
api_url: Optional[str] = None
model: Optional[str] = None
enable_interim: bool = False
interim_interval_ms: int = 500
min_audio_for_interim_ms: int = 300
on_transcript: Optional[TranscriptCallback] = None

View File

@@ -32,6 +32,7 @@ def test_create_asr_service_openai_compatible_returns_offline_provider():
)
assert isinstance(service, OpenAICompatibleASRService)
assert service.mode == "offline"
assert service.enable_interim is False
def test_create_asr_service_fallback_buffered_for_unsupported_provider():

View File

@@ -282,7 +282,7 @@ async def test_local_yaml_adapter_rejects_path_traversal_like_assistant_id(tmp_p
@pytest.mark.asyncio
async def test_local_yaml_translates_agent_schema_to_runtime_services(tmp_path):
async def test_local_yaml_translates_agent_schema_with_asr_interim_flag(tmp_path):
config_dir = tmp_path / "assistants"
config_dir.mkdir(parents=True, exist_ok=True)
(config_dir / "default.yaml").write_text(
@@ -305,6 +305,7 @@ async def test_local_yaml_translates_agent_schema_to_runtime_services(tmp_path):
" model: asr-model",
" api_key: sk-asr",
" api_url: https://asr.example.com/v1/audio/transcriptions",
" enable_interim: false",
" duplex:",
" system_prompt: You are test assistant",
]
@@ -321,4 +322,5 @@ async def test_local_yaml_translates_agent_schema_to_runtime_services(tmp_path):
assert services.get("llm", {}).get("apiKey") == "sk-llm"
assert services.get("tts", {}).get("apiKey") == "sk-tts"
assert services.get("asr", {}).get("apiKey") == "sk-asr"
assert services.get("asr", {}).get("enableInterim") is False
assert assistant.get("systemPrompt") == "You are test assistant"

View File

@@ -145,10 +145,11 @@ async def test_start_asr_capture_uses_streaming_begin(monkeypatch):
@pytest.mark.asyncio
async def test_start_asr_capture_uses_offline_interim_control(monkeypatch):
async def test_start_asr_capture_uses_offline_interim_control_when_enabled(monkeypatch):
asr = _FakeOfflineASR()
pipeline = _build_pipeline(monkeypatch, asr)
pipeline._asr_mode = "offline"
pipeline._runtime_asr["enableInterim"] = True
pipeline._pending_speech_audio = b"\x00" * 320
pipeline._pre_speech_buffer = b"\x00" * 640
@@ -159,6 +160,69 @@ async def test_start_asr_capture_uses_offline_interim_control(monkeypatch):
assert pipeline._asr_capture_active is True
@pytest.mark.asyncio
async def test_start_asr_capture_skips_offline_interim_control_when_disabled(monkeypatch):
asr = _FakeOfflineASR()
pipeline = _build_pipeline(monkeypatch, asr)
pipeline._asr_mode = "offline"
pipeline._runtime_asr["enableInterim"] = False
pipeline._pending_speech_audio = b"\x00" * 320
pipeline._pre_speech_buffer = b"\x00" * 640
await pipeline._start_asr_capture()
assert asr.start_interim_calls == 0
assert asr.sent_audio
assert pipeline._asr_capture_active is True
@pytest.mark.asyncio
async def test_offline_interim_callback_ignored_when_disabled(monkeypatch):
asr = _FakeOfflineASR()
pipeline = _build_pipeline(monkeypatch, asr)
pipeline._asr_mode = "offline"
pipeline._runtime_asr["enableInterim"] = False
captured_events = []
captured_deltas = []
async def _capture_event(event: Dict[str, Any], priority: int = 20):
_ = priority
captured_events.append(event)
async def _capture_delta(text: str):
captured_deltas.append(text)
monkeypatch.setattr(pipeline, "_send_event", _capture_event)
monkeypatch.setattr(pipeline, "_emit_transcript_delta", _capture_delta)
await pipeline._on_transcript_callback("ignored interim", is_final=False)
assert captured_events == []
assert captured_deltas == []
assert pipeline._latest_asr_interim_text == ""
@pytest.mark.asyncio
async def test_offline_final_callback_emits_when_interim_disabled(monkeypatch):
asr = _FakeOfflineASR()
pipeline = _build_pipeline(monkeypatch, asr)
pipeline._asr_mode = "offline"
pipeline._runtime_asr["enableInterim"] = False
captured_events = []
async def _capture_event(event: Dict[str, Any], priority: int = 20):
_ = priority
captured_events.append(event)
monkeypatch.setattr(pipeline, "_send_event", _capture_event)
await pipeline._on_transcript_callback("final only", is_final=True)
assert any(event.get("type") == "transcript.final" for event in captured_events)
@pytest.mark.asyncio
async def test_streaming_eou_falls_back_to_latest_interim(monkeypatch):
asr = _FakeStreamingASR()