Add DashScope ASR model support and enhance related components
- Introduced DashScope as a new ASR model in the database initialization. - Updated ASRModel schema to include vendor information. - Enhanced ASR router to support DashScope-specific functionality, including connection testing and preview capabilities. - Modified frontend components to accommodate DashScope as a selectable vendor with appropriate default settings. - Added tests to validate DashScope ASR model creation, updates, and connectivity. - Updated backend API to handle DashScope-specific base URLs and vendor normalization.
This commit is contained in:
@@ -1,6 +1,14 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from typing import List, Optional
|
||||
import wave
|
||||
from array import array
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
|
||||
@@ -17,6 +25,32 @@ from ..schemas import (
|
||||
router = APIRouter(prefix="/asr", tags=["ASR Models"])
|
||||
|
||||
# Default model served by OpenAI-compatible ASR vendors.
OPENAI_COMPATIBLE_DEFAULT_ASR_MODEL = "FunAudioLLM/SenseVoiceSmall"
# Default DashScope realtime ASR model.
DASHSCOPE_DEFAULT_ASR_MODEL = "qwen3-asr-flash-realtime"
# Default DashScope realtime websocket endpoint used when no base_url is configured.
DASHSCOPE_DEFAULT_BASE_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
|
||||
|
||||
# Import the DashScope SDK lazily so this router still loads when the optional
# dependency is absent; callers re-check DASHSCOPE_SDK_AVAILABLE at runtime.
try:
    import dashscope
    from dashscope.audio.qwen_omni import MultiModality, OmniRealtimeCallback, OmniRealtimeConversation

    # TranscriptionParams has lived in different modules across SDK versions.
    try:
        from dashscope.audio.qwen_omni import TranscriptionParams
    except ImportError:
        from dashscope.audio.qwen_omni.omni_realtime import TranscriptionParams

    DASHSCOPE_SDK_AVAILABLE = True
    DASHSCOPE_IMPORT_ERROR = ""
except Exception as exc:
    # SDK missing or broken: stub out every symbol and remember why, so error
    # messages can include the original import failure.
    dashscope = None  # type: ignore[assignment]
    MultiModality = None  # type: ignore[assignment]
    OmniRealtimeConversation = None  # type: ignore[assignment]
    TranscriptionParams = None  # type: ignore[assignment]
    DASHSCOPE_SDK_AVAILABLE = False
    DASHSCOPE_IMPORT_ERROR = f"{type(exc).__name__}: {exc}"

    class OmniRealtimeCallback:  # type: ignore[no-redef]
        """Fallback callback base when DashScope SDK is unavailable."""

        pass
|
||||
|
||||
|
||||
def _is_openai_compatible_vendor(vendor: str) -> bool:
|
||||
@@ -29,12 +63,377 @@ def _is_openai_compatible_vendor(vendor: str) -> bool:
|
||||
}
|
||||
|
||||
|
||||
def _is_dashscope_vendor(vendor: str) -> bool:
|
||||
return (vendor or "").strip().lower() == "dashscope"
|
||||
|
||||
|
||||
def _default_asr_model(vendor: str) -> str:
    """Pick the vendor-appropriate default ASR model name; whisper-1 otherwise."""
    if _is_openai_compatible_vendor(vendor):
        return OPENAI_COMPATIBLE_DEFAULT_ASR_MODEL
    return DASHSCOPE_DEFAULT_ASR_MODEL if _is_dashscope_vendor(vendor) else "whisper-1"
|
||||
|
||||
|
||||
def _dashscope_language(language: Optional[str]) -> Optional[str]:
|
||||
normalized = (language or "").strip().lower()
|
||||
if not normalized or normalized in {"multi-lingual", "multilingual", "multi_lingual", "auto"}:
|
||||
return None
|
||||
if normalized.startswith("zh"):
|
||||
return "zh"
|
||||
if normalized.startswith("en"):
|
||||
return "en"
|
||||
return normalized
|
||||
|
||||
|
||||
class _DashScopePreviewCallback(OmniRealtimeCallback):
    """Collect DashScope ASR websocket events for preview/test flows.

    The SDK invokes the ``on_*`` hooks from its websocket thread while the
    request thread blocks inside the ``wait_for_*`` helpers; the
    ``threading.Event`` instances bridge the two sides and ``_lock`` guards
    the accumulated transcript strings.
    """

    def __init__(self) -> None:
        super().__init__()
        # Set by on_open once the websocket is connected.
        self._open_event = threading.Event()
        # Set once session.created/session.updated (or any error) arrives.
        self._session_ready_event = threading.Event()
        # Set when a final transcript, terminal event, or error arrives.
        self._done_event = threading.Event()
        self._lock = threading.Lock()
        self._final_text = ""
        self._last_interim_text = ""
        self._error_message: Optional[str] = None

    def on_open(self) -> None:
        self._open_event.set()

    def on_close(self, code: int, reason: str) -> None:
        # Closing after completion is the normal shutdown path; anything
        # earlier is treated as an error.
        if self._done_event.is_set():
            return
        self._error_message = f"DashScope websocket closed unexpectedly: {code} {reason}"
        self._done_event.set()
        # Unblock any waiter stuck in wait_for_session_ready().
        self._session_ready_event.set()

    def on_error(self, message: Any) -> None:
        self._error_message = str(message)
        self._done_event.set()
        self._session_ready_event.set()

    def on_event(self, response: Any) -> None:
        # Events may arrive as dicts or JSON strings; normalize first.
        payload = _coerce_dashscope_event(response)
        event_type = str(payload.get("type") or "").strip()
        if not event_type:
            return

        if event_type in {"session.created", "session.updated"}:
            self._session_ready_event.set()
            return

        if event_type == "error" or event_type.endswith(".failed"):
            self._error_message = _format_dashscope_error_event(payload)
            self._done_event.set()
            self._session_ready_event.set()
            return

        if event_type == "conversation.item.input_audio_transcription.text":
            # Interim hypothesis; keep the latest as a fallback transcript.
            interim_text = _extract_dashscope_text(payload, keys=("stash", "text", "transcript"))
            if interim_text:
                with self._lock:
                    self._last_interim_text = interim_text
            return

        if event_type == "conversation.item.input_audio_transcription.completed":
            final_text = _extract_dashscope_text(payload, keys=("transcript", "text", "stash"))
            with self._lock:
                if final_text:
                    self._final_text = final_text
            self._done_event.set()
            return

        if event_type in {"response.done", "session.finished"}:
            self._done_event.set()

    def wait_for_open(self, timeout: float = 10.0) -> None:
        # Raises rather than returning a flag: an unopened socket is fatal.
        if not self._open_event.wait(timeout):
            raise TimeoutError("DashScope websocket open timeout")

    def wait_for_session_ready(self, timeout: float = 6.0) -> bool:
        # Best-effort: callers inspect raise_if_error() afterwards.
        return self._session_ready_event.wait(timeout)

    def wait_for_done(self, timeout: float = 20.0) -> None:
        if not self._done_event.wait(timeout):
            raise TimeoutError("DashScope transcription timeout")

    def raise_if_error(self) -> None:
        if self._error_message:
            raise RuntimeError(self._error_message)

    def read_text(self) -> str:
        # Prefer the final transcript; fall back to the last interim text.
        with self._lock:
            return self._final_text or self._last_interim_text
|
||||
|
||||
|
||||
def _coerce_dashscope_event(response: Any) -> Dict[str, Any]:
|
||||
if isinstance(response, dict):
|
||||
return response
|
||||
if isinstance(response, str):
|
||||
try:
|
||||
parsed = json.loads(response)
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return {"type": "raw", "message": str(response)}
|
||||
|
||||
|
||||
def _format_dashscope_error_event(payload: Dict[str, Any]) -> str:
|
||||
error = payload.get("error")
|
||||
if isinstance(error, dict):
|
||||
code = str(error.get("code") or "").strip()
|
||||
message = str(error.get("message") or "").strip()
|
||||
if code and message:
|
||||
return f"{code}: {message}"
|
||||
return message or str(error)
|
||||
return str(error or "DashScope realtime ASR error")
|
||||
|
||||
|
||||
def _extract_dashscope_text(payload: Dict[str, Any], *, keys: Tuple[str, ...]) -> str:
|
||||
for key in keys:
|
||||
value = payload.get(key)
|
||||
if isinstance(value, str) and value.strip():
|
||||
return value.strip()
|
||||
if isinstance(value, dict):
|
||||
nested = _extract_dashscope_text(value, keys=keys)
|
||||
if nested:
|
||||
return nested
|
||||
|
||||
for value in payload.values():
|
||||
if isinstance(value, dict):
|
||||
nested = _extract_dashscope_text(value, keys=keys)
|
||||
if nested:
|
||||
return nested
|
||||
return ""
|
||||
|
||||
|
||||
def _create_dashscope_realtime_client(
    *,
    model: str,
    callback: _DashScopePreviewCallback,
    url: str,
    api_key: str,
) -> Any:
    """Instantiate an OmniRealtimeConversation, tolerating SDK signature drift.

    Some SDK versions appear to accept ``api_key`` in the constructor while
    others do not (presumably reading it from ``dashscope.api_key`` instead),
    so the kwarg is retried without when the constructor rejects it.
    Raises RuntimeError when the SDK was never imported.
    """
    if OmniRealtimeConversation is None:
        raise RuntimeError("DashScope SDK unavailable")

    init_kwargs = {
        "model": model,
        "callback": callback,
        "url": url,
    }
    try:
        return OmniRealtimeConversation(api_key=api_key, **init_kwargs)  # type: ignore[misc]
    except TypeError as exc:
        # Only swallow the "unexpected keyword 'api_key'" case; any other
        # TypeError is a genuine bug and must propagate.
        if "api_key" not in str(exc):
            raise
        return OmniRealtimeConversation(**init_kwargs)  # type: ignore[misc]
|
||||
|
||||
|
||||
def _close_dashscope_client(client: Any) -> None:
|
||||
finish_fn = getattr(client, "finish", None)
|
||||
if callable(finish_fn):
|
||||
try:
|
||||
finish_fn()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
close_fn = getattr(client, "close", None)
|
||||
if callable(close_fn):
|
||||
try:
|
||||
close_fn()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _configure_dashscope_session(
    *,
    client: Any,
    callback: _DashScopePreviewCallback,
    sample_rate: int,
    language: Optional[str],
) -> None:
    """Send session.update with progressively simpler parameter sets.

    Different DashScope SDK versions accept different ``update_session``
    kwargs, so each attempt drops optional parameters until one succeeds.
    Raises RuntimeError when the SDK lacks ``update_session`` or every
    attempt fails (the last error is included in the message).
    """
    update_fn = getattr(client, "update_session", None)
    if not callable(update_fn):
        raise RuntimeError("DashScope ASR SDK missing update_session method")

    # Prefer the SDK enum for the text modality when available.
    text_modality: Any = "text"
    if MultiModality is not None and hasattr(MultiModality, "TEXT"):
        text_modality = MultiModality.TEXT

    transcription_params: Optional[Any] = None
    language_hint = _dashscope_language(language)
    if TranscriptionParams is not None:
        try:
            params_kwargs: Dict[str, Any] = {
                "sample_rate": sample_rate,
                "input_audio_format": "pcm",
            }
            if language_hint:
                params_kwargs["language"] = language_hint
            transcription_params = TranscriptionParams(**params_kwargs)
        except Exception:
            # Constructor signature varies across SDK versions; treat any
            # failure as "transcription params unsupported".
            transcription_params = None

    # Ordered from most to least specific; later entries are fallbacks.
    update_attempts = [
        {
            "output_modalities": [text_modality],
            "enable_turn_detection": False,
            "enable_input_audio_transcription": True,
            "transcription_params": transcription_params,
        },
        {
            "output_modalities": [text_modality],
            "enable_turn_detection": False,
            "enable_input_audio_transcription": True,
        },
        {
            "output_modalities": [text_modality],
        },
    ]

    last_error: Optional[Exception] = None
    for params in update_attempts:
        # Never pass an explicit transcription_params=None to the SDK.
        if params.get("transcription_params") is None:
            params = {key: value for key, value in params.items() if key != "transcription_params"}
        try:
            update_fn(**params)
            callback.wait_for_session_ready()
            callback.raise_if_error()
            return
        except TypeError as exc:
            # Unsupported kwarg shape for this SDK version; try the next one.
            last_error = exc
            continue
        except Exception as exc:
            last_error = exc
            continue

    raise RuntimeError(f"DashScope ASR session.update failed: {last_error}")
|
||||
|
||||
|
||||
def _load_wav_pcm16_mono(audio_bytes: bytes) -> Tuple[bytes, int]:
|
||||
try:
|
||||
with wave.open(io.BytesIO(audio_bytes), "rb") as wav_file:
|
||||
channel_count = wav_file.getnchannels()
|
||||
sample_width = wav_file.getsampwidth()
|
||||
sample_rate = wav_file.getframerate()
|
||||
compression = wav_file.getcomptype()
|
||||
pcm_frames = wav_file.readframes(wav_file.getnframes())
|
||||
except wave.Error as exc:
|
||||
raise RuntimeError("DashScope preview currently supports WAV audio. Record in browser or upload a .wav file.") from exc
|
||||
|
||||
if compression != "NONE":
|
||||
raise RuntimeError("DashScope preview requires uncompressed PCM WAV audio.")
|
||||
if sample_width != 2:
|
||||
raise RuntimeError("DashScope preview requires 16-bit PCM WAV audio.")
|
||||
if not pcm_frames:
|
||||
raise RuntimeError("Uploaded WAV file is empty")
|
||||
if channel_count <= 1:
|
||||
return pcm_frames, sample_rate
|
||||
|
||||
samples = array("h")
|
||||
samples.frombytes(pcm_frames)
|
||||
if sys.byteorder == "big":
|
||||
samples.byteswap()
|
||||
|
||||
mono_samples = array(
|
||||
"h",
|
||||
(
|
||||
int(sum(samples[index:index + channel_count]) / channel_count)
|
||||
for index in range(0, len(samples), channel_count)
|
||||
),
|
||||
)
|
||||
if sys.byteorder == "big":
|
||||
mono_samples.byteswap()
|
||||
return mono_samples.tobytes(), sample_rate
|
||||
|
||||
|
||||
def _probe_dashscope_asr_connection(*, api_key: str, base_url: str, model: str, language: Optional[str]) -> None:
    """Open and configure a DashScope realtime session to verify connectivity.

    No audio is sent; a successful session.update is treated as proof the
    credentials and endpoint work. Returns None on success; raises
    RuntimeError when the SDK is missing or the session fails, and
    TimeoutError when the websocket never opens.
    """
    if not DASHSCOPE_SDK_AVAILABLE:
        hint = f"`{sys.executable} -m pip install dashscope>=1.25.11`"
        detail = f"; import error: {DASHSCOPE_IMPORT_ERROR}" if DASHSCOPE_IMPORT_ERROR else ""
        raise RuntimeError(f"dashscope package not installed; install with {hint}{detail}")

    callback = _DashScopePreviewCallback()
    if dashscope is not None:
        # Some SDK code paths read the key from the module-level attribute.
        dashscope.api_key = api_key
    client = _create_dashscope_realtime_client(
        model=model,
        callback=callback,
        url=base_url,
        api_key=api_key,
    )

    try:
        client.connect()
        callback.wait_for_open()
        _configure_dashscope_session(
            client=client,
            callback=callback,
            sample_rate=16000,
            language=language,
        )
    finally:
        # Always tear the websocket down, even on failure.
        _close_dashscope_client(client)
|
||||
|
||||
|
||||
def _transcribe_dashscope_preview(
    *,
    audio_bytes: bytes,
    api_key: str,
    base_url: str,
    model: str,
    language: Optional[str],
) -> Dict[str, Any]:
    """Transcribe a WAV clip through the DashScope realtime websocket.

    Returns a dict with ``transcript``, ``language`` and ``confidence`` keys
    (confidence is always None — the realtime API result carries none here).
    Raises RuntimeError when the SDK is missing, the audio is unusable, or
    the session reports an error; TimeoutError when events stall.
    """
    if not DASHSCOPE_SDK_AVAILABLE:
        hint = f"`{sys.executable} -m pip install dashscope>=1.25.11`"
        detail = f"; import error: {DASHSCOPE_IMPORT_ERROR}" if DASHSCOPE_IMPORT_ERROR else ""
        raise RuntimeError(f"dashscope package not installed; install with {hint}{detail}")

    # Convert upload to mono 16-bit PCM before opening the session.
    pcm_audio, sample_rate = _load_wav_pcm16_mono(audio_bytes)
    callback = _DashScopePreviewCallback()
    if dashscope is not None:
        # Some SDK code paths read the key from the module-level attribute.
        dashscope.api_key = api_key
    client = _create_dashscope_realtime_client(
        model=model,
        callback=callback,
        url=base_url,
        api_key=api_key,
    )

    try:
        client.connect()
        callback.wait_for_open()
        _configure_dashscope_session(
            client=client,
            callback=callback,
            sample_rate=sample_rate,
            language=language,
        )

        append_fn = getattr(client, "append_audio", None)
        if not callable(append_fn):
            raise RuntimeError("DashScope ASR SDK missing append_audio method")
        commit_fn = getattr(client, "commit", None)
        if not callable(commit_fn):
            raise RuntimeError("DashScope ASR SDK missing commit method")

        # The realtime API expects base64-encoded PCM audio chunks.
        append_fn(base64.b64encode(pcm_audio).decode("ascii"))
        commit_fn()
        callback.wait_for_done()
        callback.raise_if_error()
        return {
            "transcript": callback.read_text(),
            "language": _dashscope_language(language) or "Multi-lingual",
            "confidence": None,
        }
    finally:
        # Always tear the websocket down, even on failure.
        _close_dashscope_client(client)
|
||||
|
||||
|
||||
# ============ ASR Models CRUD ============
|
||||
@router.get("")
|
||||
def list_asr_models(
|
||||
@@ -132,6 +531,27 @@ def test_asr_model(
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
if _is_dashscope_vendor(model.vendor):
|
||||
effective_api_key = (model.api_key or "").strip() or os.getenv("DASHSCOPE_API_KEY", "").strip() or os.getenv("ASR_API_KEY", "").strip()
|
||||
if not effective_api_key:
|
||||
return ASRTestResponse(success=False, error=f"API key is required for ASR model: {model.name}")
|
||||
|
||||
base_url = (model.base_url or "").strip() or DASHSCOPE_DEFAULT_BASE_URL
|
||||
selected_model = (model.model_name or "").strip() or _default_asr_model(model.vendor)
|
||||
_probe_dashscope_asr_connection(
|
||||
api_key=effective_api_key,
|
||||
base_url=base_url,
|
||||
model=selected_model,
|
||||
language=model.language,
|
||||
)
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
return ASRTestResponse(
|
||||
success=True,
|
||||
language=model.language,
|
||||
latency_ms=latency_ms,
|
||||
message="DashScope realtime ASR connected",
|
||||
)
|
||||
|
||||
# 连接性测试优先,避免依赖真实音频输入
|
||||
headers = {"Authorization": f"Bearer {model.api_key}"}
|
||||
with httpx.Client(timeout=60.0) as client:
|
||||
@@ -246,7 +666,7 @@ async def preview_asr_model(
|
||||
api_key: Optional[str] = Form(None),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""预览 ASR:上传音频并调用 OpenAI-compatible /audio/transcriptions。"""
|
||||
"""预览 ASR:根据供应商调用 OpenAI-compatible 或 DashScope 实时识别。"""
|
||||
model = db.query(ASRModel).filter(ASRModel.id == id).first()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="ASR Model not found")
|
||||
@@ -264,18 +684,50 @@ async def preview_asr_model(
|
||||
raise HTTPException(status_code=400, detail="Uploaded audio file is empty")
|
||||
|
||||
effective_api_key = (api_key or "").strip() or (model.api_key or "").strip()
|
||||
if not effective_api_key and _is_openai_compatible_vendor(model.vendor):
|
||||
effective_api_key = os.getenv("SILICONFLOW_API_KEY", "").strip()
|
||||
if not effective_api_key:
|
||||
if _is_openai_compatible_vendor(model.vendor):
|
||||
effective_api_key = os.getenv("SILICONFLOW_API_KEY", "").strip()
|
||||
elif _is_dashscope_vendor(model.vendor):
|
||||
effective_api_key = os.getenv("DASHSCOPE_API_KEY", "").strip() or os.getenv("ASR_API_KEY", "").strip()
|
||||
if not effective_api_key:
|
||||
raise HTTPException(status_code=400, detail=f"API key is required for ASR model: {model.name}")
|
||||
|
||||
base_url = (model.base_url or "").strip().rstrip("/")
|
||||
if _is_dashscope_vendor(model.vendor) and not base_url:
|
||||
base_url = DASHSCOPE_DEFAULT_BASE_URL
|
||||
if not base_url:
|
||||
raise HTTPException(status_code=400, detail=f"Base URL is required for ASR model: {model.name}")
|
||||
|
||||
selected_model = (model.model_name or "").strip() or _default_asr_model(model.vendor)
|
||||
data = {"model": selected_model}
|
||||
effective_language = (language or "").strip() or None
|
||||
|
||||
start_time = time.time()
|
||||
if _is_dashscope_vendor(model.vendor):
|
||||
try:
|
||||
payload = await asyncio.to_thread(
|
||||
_transcribe_dashscope_preview,
|
||||
audio_bytes=audio_bytes,
|
||||
api_key=effective_api_key,
|
||||
base_url=base_url,
|
||||
model=selected_model,
|
||||
language=effective_language or model.language,
|
||||
)
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=502, detail=f"DashScope ASR request failed: {exc}") from exc
|
||||
|
||||
transcript = str(payload.get("transcript") or "")
|
||||
response_language = str(payload.get("language") or effective_language or model.language)
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
return ASRTestResponse(
|
||||
success=bool(transcript),
|
||||
transcript=transcript,
|
||||
language=response_language,
|
||||
confidence=None,
|
||||
latency_ms=latency_ms,
|
||||
message=None if transcript else "No transcript in response",
|
||||
)
|
||||
|
||||
data = {"model": selected_model}
|
||||
if effective_language:
|
||||
data["language"] = effective_language
|
||||
if model.hotwords:
|
||||
@@ -284,7 +736,6 @@ async def preview_asr_model(
|
||||
headers = {"Authorization": f"Bearer {effective_api_key}"}
|
||||
files = {"file": (filename, audio_bytes, content_type)}
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
with httpx.Client(timeout=90.0) as client:
|
||||
response = client.post(
|
||||
|
||||
@@ -191,6 +191,7 @@ class ASRModelCreate(ASRModelBase):
|
||||
|
||||
class ASRModelUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
vendor: Optional[str] = None
|
||||
language: Optional[str] = None
|
||||
base_url: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
|
||||
Reference in New Issue
Block a user