Files
AI-VideoAssistant/api/app/routers/asr.py
Xin Wang bfe165daae Add DashScope ASR model support and enhance related components
- Introduced DashScope as a new ASR model in the database initialization.
- Updated ASRModel schema to include vendor information.
- Enhanced ASR router to support DashScope-specific functionality, including connection testing and preview capabilities.
- Modified frontend components to accommodate DashScope as a selectable vendor with appropriate default settings.
- Added tests to validate DashScope ASR model creation, updates, and connectivity.
- Updated backend API to handle DashScope-specific base URLs and vendor normalization.
2026-03-09 07:37:00 +08:00

786 lines
26 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import asyncio
import base64
import io
import json
import os
import sys
import threading
import time
import wave
from array import array
from typing import Any, Dict, List, Optional, Tuple
import httpx
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
from sqlalchemy.orm import Session
from ..db import get_db
from ..id_generator import unique_short_id
from ..models import ASRModel
from ..schemas import (
ASRModelCreate, ASRModelUpdate, ASRModelOut,
ASRTestRequest, ASRTestResponse
)
router = APIRouter(prefix="/asr", tags=["ASR Models"])

# Default model identifiers per vendor family.
OPENAI_COMPATIBLE_DEFAULT_ASR_MODEL = "FunAudioLLM/SenseVoiceSmall"
DASHSCOPE_DEFAULT_ASR_MODEL = "qwen3-asr-flash-realtime"
# DashScope realtime ASR speaks over a websocket endpoint, not plain HTTPS.
DASHSCOPE_DEFAULT_BASE_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"

# Optional DashScope SDK import: the router must keep serving the
# OpenAI-compatible vendors even when the `dashscope` package is missing,
# so record availability plus the import error instead of failing at startup.
try:
    import dashscope
    from dashscope.audio.qwen_omni import MultiModality, OmniRealtimeCallback, OmniRealtimeConversation
    try:
        # TranscriptionParams moved between modules across SDK versions.
        from dashscope.audio.qwen_omni import TranscriptionParams
    except ImportError:
        from dashscope.audio.qwen_omni.omni_realtime import TranscriptionParams
    DASHSCOPE_SDK_AVAILABLE = True
    DASHSCOPE_IMPORT_ERROR = ""
except Exception as exc:
    dashscope = None  # type: ignore[assignment]
    MultiModality = None  # type: ignore[assignment]
    OmniRealtimeConversation = None  # type: ignore[assignment]
    TranscriptionParams = None  # type: ignore[assignment]
    DASHSCOPE_SDK_AVAILABLE = False
    # Preserved so later error messages can explain *why* the SDK is unusable.
    DASHSCOPE_IMPORT_ERROR = f"{type(exc).__name__}: {exc}"

    class OmniRealtimeCallback:  # type: ignore[no-redef]
        """Fallback callback base when DashScope SDK is unavailable."""
        pass
def _is_openai_compatible_vendor(vendor: str) -> bool:
normalized = (vendor or "").strip().lower()
return normalized in {
"openai compatible",
"openai-compatible",
"siliconflow", # backward compatibility
"硅基流动", # backward compatibility
}
def _is_dashscope_vendor(vendor: str) -> bool:
return (vendor or "").strip().lower() == "dashscope"
def _default_asr_model(vendor: str) -> str:
    """Map a vendor name to the default ASR model identifier for it.

    Unknown vendors fall back to OpenAI's "whisper-1".
    """
    if _is_dashscope_vendor(vendor):
        return DASHSCOPE_DEFAULT_ASR_MODEL
    if _is_openai_compatible_vendor(vendor):
        return OPENAI_COMPATIBLE_DEFAULT_ASR_MODEL
    return "whisper-1"
def _dashscope_language(language: Optional[str]) -> Optional[str]:
normalized = (language or "").strip().lower()
if not normalized or normalized in {"multi-lingual", "multilingual", "multi_lingual", "auto"}:
return None
if normalized.startswith("zh"):
return "zh"
if normalized.startswith("en"):
return "en"
return normalized
class _DashScopePreviewCallback(OmniRealtimeCallback):
    """Collect DashScope ASR websocket events for preview/test flows.

    Instances bridge the SDK's websocket callback thread and the request
    thread: the request thread blocks on the ``wait_for_*`` methods while
    the SDK thread flips the corresponding events and records transcript
    text or an error message.
    """
    def __init__(self) -> None:
        super().__init__()
        self._open_event = threading.Event()  # websocket connection opened
        self._session_ready_event = threading.Event()  # session.created/updated seen
        self._done_event = threading.Event()  # transcription finished or failed
        self._lock = threading.Lock()  # guards the transcript fields below
        self._final_text = ""
        self._last_interim_text = ""
        self._error_message: Optional[str] = None
    def on_open(self) -> None:
        self._open_event.set()
    def on_close(self, code: int, reason: str) -> None:
        # A close after normal completion is expected; only flag an error
        # if we were still waiting for results.
        if self._done_event.is_set():
            return
        self._error_message = f"DashScope websocket closed unexpectedly: {code} {reason}"
        self._done_event.set()
        # Also unblock any waiter stuck in wait_for_session_ready().
        self._session_ready_event.set()
    def on_error(self, message: Any) -> None:
        self._error_message = str(message)
        self._done_event.set()
        self._session_ready_event.set()
    def on_event(self, response: Any) -> None:
        # Events may arrive as dicts or JSON strings; normalize first.
        payload = _coerce_dashscope_event(response)
        event_type = str(payload.get("type") or "").strip()
        if not event_type:
            return
        if event_type in {"session.created", "session.updated"}:
            self._session_ready_event.set()
            return
        if event_type == "error" or event_type.endswith(".failed"):
            self._error_message = _format_dashscope_error_event(payload)
            self._done_event.set()
            self._session_ready_event.set()
            return
        if event_type == "conversation.item.input_audio_transcription.text":
            # Interim hypothesis; kept as a fallback if no final text arrives.
            interim_text = _extract_dashscope_text(payload, keys=("stash", "text", "transcript"))
            if interim_text:
                with self._lock:
                    self._last_interim_text = interim_text
            return
        if event_type == "conversation.item.input_audio_transcription.completed":
            final_text = _extract_dashscope_text(payload, keys=("transcript", "text", "stash"))
            with self._lock:
                if final_text:
                    self._final_text = final_text
            self._done_event.set()
            return
        if event_type in {"response.done", "session.finished"}:
            self._done_event.set()
    def wait_for_open(self, timeout: float = 10.0) -> None:
        """Block until the websocket opens; raise TimeoutError on timeout."""
        if not self._open_event.wait(timeout):
            raise TimeoutError("DashScope websocket open timeout")
    def wait_for_session_ready(self, timeout: float = 6.0) -> bool:
        """Block until session.created/updated or failure; False on timeout."""
        return self._session_ready_event.wait(timeout)
    def wait_for_done(self, timeout: float = 20.0) -> None:
        """Block until transcription completes or fails; raise TimeoutError on timeout."""
        if not self._done_event.wait(timeout):
            raise TimeoutError("DashScope transcription timeout")
    def raise_if_error(self) -> None:
        """Re-raise any error recorded by the callback thread as RuntimeError."""
        if self._error_message:
            raise RuntimeError(self._error_message)
    def read_text(self) -> str:
        """Return the final transcript, falling back to the last interim one."""
        with self._lock:
            return self._final_text or self._last_interim_text
def _coerce_dashscope_event(response: Any) -> Dict[str, Any]:
if isinstance(response, dict):
return response
if isinstance(response, str):
try:
parsed = json.loads(response)
if isinstance(parsed, dict):
return parsed
except json.JSONDecodeError:
pass
return {"type": "raw", "message": str(response)}
def _format_dashscope_error_event(payload: Dict[str, Any]) -> str:
error = payload.get("error")
if isinstance(error, dict):
code = str(error.get("code") or "").strip()
message = str(error.get("message") or "").strip()
if code and message:
return f"{code}: {message}"
return message or str(error)
return str(error or "DashScope realtime ASR error")
def _extract_dashscope_text(payload: Dict[str, Any], *, keys: Tuple[str, ...]) -> str:
for key in keys:
value = payload.get(key)
if isinstance(value, str) and value.strip():
return value.strip()
if isinstance(value, dict):
nested = _extract_dashscope_text(value, keys=keys)
if nested:
return nested
for value in payload.values():
if isinstance(value, dict):
nested = _extract_dashscope_text(value, keys=keys)
if nested:
return nested
return ""
def _create_dashscope_realtime_client(
    *,
    model: str,
    callback: _DashScopePreviewCallback,
    url: str,
    api_key: str,
) -> Any:
    """Instantiate the DashScope realtime conversation client.

    Some SDK versions reject an ``api_key`` constructor keyword with a
    TypeError, so retry without it in that case (the key is then read from
    the module-level ``dashscope.api_key``). Raises RuntimeError when the
    SDK is not importable.
    """
    if OmniRealtimeConversation is None:
        raise RuntimeError("DashScope SDK unavailable")
    common_kwargs: Dict[str, Any] = {"model": model, "callback": callback, "url": url}
    try:
        return OmniRealtimeConversation(api_key=api_key, **common_kwargs)  # type: ignore[misc]
    except TypeError as exc:
        if "api_key" in str(exc):
            return OmniRealtimeConversation(**common_kwargs)  # type: ignore[misc]
        raise
def _close_dashscope_client(client: Any) -> None:
finish_fn = getattr(client, "finish", None)
if callable(finish_fn):
try:
finish_fn()
except Exception:
pass
close_fn = getattr(client, "close", None)
if callable(close_fn):
try:
close_fn()
except Exception:
pass
def _configure_dashscope_session(
    *,
    client: Any,
    callback: _DashScopePreviewCallback,
    sample_rate: int,
    language: Optional[str],
) -> None:
    """Send session.update for transcription-only use, tolerating SDK drift.

    Tries progressively smaller keyword sets against ``update_session`` so
    that older SDKs (which reject unknown kwargs with TypeError) still get
    a usable session. Raises RuntimeError when every attempt fails.
    """
    update_fn = getattr(client, "update_session", None)
    if not callable(update_fn):
        raise RuntimeError("DashScope ASR SDK missing update_session method")
    # Prefer the SDK enum when available; fall back to the raw string.
    text_modality: Any = "text"
    if MultiModality is not None and hasattr(MultiModality, "TEXT"):
        text_modality = MultiModality.TEXT
    transcription_params: Optional[Any] = None
    language_hint = _dashscope_language(language)
    if TranscriptionParams is not None:
        try:
            params_kwargs: Dict[str, Any] = {
                "sample_rate": sample_rate,
                "input_audio_format": "pcm",
            }
            if language_hint:
                params_kwargs["language"] = language_hint
            transcription_params = TranscriptionParams(**params_kwargs)
        except Exception:
            # Constructor signature varies across SDK releases; treat any
            # failure as "no transcription params" rather than aborting.
            transcription_params = None
    # Most-specific attempt first; each later one drops options an older
    # SDK may not understand.
    update_attempts = [
        {
            "output_modalities": [text_modality],
            "enable_turn_detection": False,
            "enable_input_audio_transcription": True,
            "transcription_params": transcription_params,
        },
        {
            "output_modalities": [text_modality],
            "enable_turn_detection": False,
            "enable_input_audio_transcription": True,
        },
        {
            "output_modalities": [text_modality],
        },
    ]
    last_error: Optional[Exception] = None
    for params in update_attempts:
        if params.get("transcription_params") is None:
            # Never pass an explicit None for transcription_params.
            params = {key: value for key, value in params.items() if key != "transcription_params"}
        try:
            update_fn(**params)
            callback.wait_for_session_ready()
            callback.raise_if_error()
            return
        except TypeError as exc:
            # SDK rejected a keyword; retry with the next, simpler set.
            last_error = exc
            continue
        except Exception as exc:
            last_error = exc
            continue
    raise RuntimeError(f"DashScope ASR session.update failed: {last_error}")
def _load_wav_pcm16_mono(audio_bytes: bytes) -> Tuple[bytes, int]:
try:
with wave.open(io.BytesIO(audio_bytes), "rb") as wav_file:
channel_count = wav_file.getnchannels()
sample_width = wav_file.getsampwidth()
sample_rate = wav_file.getframerate()
compression = wav_file.getcomptype()
pcm_frames = wav_file.readframes(wav_file.getnframes())
except wave.Error as exc:
raise RuntimeError("DashScope preview currently supports WAV audio. Record in browser or upload a .wav file.") from exc
if compression != "NONE":
raise RuntimeError("DashScope preview requires uncompressed PCM WAV audio.")
if sample_width != 2:
raise RuntimeError("DashScope preview requires 16-bit PCM WAV audio.")
if not pcm_frames:
raise RuntimeError("Uploaded WAV file is empty")
if channel_count <= 1:
return pcm_frames, sample_rate
samples = array("h")
samples.frombytes(pcm_frames)
if sys.byteorder == "big":
samples.byteswap()
mono_samples = array(
"h",
(
int(sum(samples[index:index + channel_count]) / channel_count)
for index in range(0, len(samples), channel_count)
),
)
if sys.byteorder == "big":
mono_samples.byteswap()
return mono_samples.tobytes(), sample_rate
def _probe_dashscope_asr_connection(*, api_key: str, base_url: str, model: str, language: Optional[str]) -> None:
    """Open and configure a DashScope realtime session to verify connectivity.

    No audio is streamed; success means the websocket opened and the
    session accepted our configuration. Raises RuntimeError/TimeoutError
    on SDK absence or handshake failure.
    """
    if not DASHSCOPE_SDK_AVAILABLE:
        install_cmd = f"`{sys.executable} -m pip install dashscope>=1.25.11`"
        import_note = f"; import error: {DASHSCOPE_IMPORT_ERROR}" if DASHSCOPE_IMPORT_ERROR else ""
        raise RuntimeError(f"dashscope package not installed; install with {install_cmd}{import_note}")
    probe_callback = _DashScopePreviewCallback()
    if dashscope is not None:
        # Also set the module-level key; the SDK may read it from there.
        dashscope.api_key = api_key
    probe_client = _create_dashscope_realtime_client(
        model=model,
        callback=probe_callback,
        url=base_url,
        api_key=api_key,
    )
    try:
        probe_client.connect()
        probe_callback.wait_for_open()
        _configure_dashscope_session(
            client=probe_client,
            callback=probe_callback,
            sample_rate=16000,  # nominal rate; no audio is actually sent
            language=language,
        )
    finally:
        _close_dashscope_client(probe_client)
def _transcribe_dashscope_preview(
    *,
    audio_bytes: bytes,
    api_key: str,
    base_url: str,
    model: str,
    language: Optional[str],
) -> Dict[str, Any]:
    """Run a one-shot DashScope realtime transcription over uploaded WAV audio.

    Returns a dict with ``transcript``, ``language`` and ``confidence``
    keys. Raises RuntimeError when the SDK is missing, the audio is
    unsupported, or the websocket session fails; TimeoutError on stalls.
    """
    if not DASHSCOPE_SDK_AVAILABLE:
        hint = f"`{sys.executable} -m pip install dashscope>=1.25.11`"
        detail = f"; import error: {DASHSCOPE_IMPORT_ERROR}" if DASHSCOPE_IMPORT_ERROR else ""
        raise RuntimeError(f"dashscope package not installed; install with {hint}{detail}")
    # Preview only supports 16-bit PCM WAV; multi-channel gets downmixed.
    pcm_audio, sample_rate = _load_wav_pcm16_mono(audio_bytes)
    callback = _DashScopePreviewCallback()
    if dashscope is not None:
        # Also set the module-level key; the SDK may read it from there.
        dashscope.api_key = api_key
    client = _create_dashscope_realtime_client(
        model=model,
        callback=callback,
        url=base_url,
        api_key=api_key,
    )
    try:
        client.connect()
        callback.wait_for_open()
        _configure_dashscope_session(
            client=client,
            callback=callback,
            sample_rate=sample_rate,
            language=language,
        )
        # Resolve SDK entry points defensively; method names have varied.
        append_fn = getattr(client, "append_audio", None)
        if not callable(append_fn):
            raise RuntimeError("DashScope ASR SDK missing append_audio method")
        commit_fn = getattr(client, "commit", None)
        if not callable(commit_fn):
            raise RuntimeError("DashScope ASR SDK missing commit method")
        # The realtime API expects base64-encoded PCM payloads.
        append_fn(base64.b64encode(pcm_audio).decode("ascii"))
        commit_fn()
        callback.wait_for_done()
        callback.raise_if_error()
        return {
            "transcript": callback.read_text(),
            "language": _dashscope_language(language) or "Multi-lingual",
            "confidence": None,
        }
    finally:
        _close_dashscope_client(client)
# ============ ASR Models CRUD ============
@router.get("")
def list_asr_models(
    language: Optional[str] = None,
    enabled: Optional[bool] = None,
    page: int = 1,
    limit: int = 50,
    db: Session = Depends(get_db)
):
    """List ASR models with optional language/enabled filters and pagination."""
    query = db.query(ASRModel)
    if language:
        query = query.filter(ASRModel.language == language)
    if enabled is not None:
        query = query.filter(ASRModel.enabled == enabled)
    total = query.count()
    offset = (page - 1) * limit
    rows = (
        query.order_by(ASRModel.created_at.desc())
        .offset(offset)
        .limit(limit)
        .all()
    )
    return {"total": total, "page": page, "limit": limit, "list": rows}
@router.get("/{id}", response_model=ASRModelOut)
def get_asr_model(id: str, db: Session = Depends(get_db)):
"""获取单个ASR模型详情"""
model = db.query(ASRModel).filter(ASRModel.id == id).first()
if not model:
raise HTTPException(status_code=404, detail="ASR Model not found")
return model
@router.post("", response_model=ASRModelOut)
def create_asr_model(data: ASRModelCreate, db: Session = Depends(get_db)):
"""创建ASR模型"""
asr_model = ASRModel(
id=unique_short_id("asr", db, ASRModel),
user_id=1, # 默认用户
name=data.name,
vendor=data.vendor,
language=data.language,
base_url=data.base_url,
api_key=data.api_key,
model_name=data.model_name,
hotwords=data.hotwords,
enable_punctuation=data.enable_punctuation,
enable_normalization=data.enable_normalization,
enabled=data.enabled,
)
db.add(asr_model)
db.commit()
db.refresh(asr_model)
return asr_model
@router.put("/{id}", response_model=ASRModelOut)
def update_asr_model(id: str, data: ASRModelUpdate, db: Session = Depends(get_db)):
"""更新ASR模型"""
model = db.query(ASRModel).filter(ASRModel.id == id).first()
if not model:
raise HTTPException(status_code=404, detail="ASR Model not found")
update_data = data.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(model, field, value)
db.commit()
db.refresh(model)
return model
@router.delete("/{id}")
def delete_asr_model(id: str, db: Session = Depends(get_db)):
"""删除ASR模型"""
model = db.query(ASRModel).filter(ASRModel.id == id).first()
if not model:
raise HTTPException(status_code=404, detail="ASR Model not found")
db.delete(model)
db.commit()
return {"message": "Deleted successfully"}
@router.post("/{id}/test", response_model=ASRTestResponse)
def test_asr_model(
id: str,
request: Optional[ASRTestRequest] = None,
db: Session = Depends(get_db)
):
"""测试ASR模型"""
model = db.query(ASRModel).filter(ASRModel.id == id).first()
if not model:
raise HTTPException(status_code=404, detail="ASR Model not found")
start_time = time.time()
try:
if _is_dashscope_vendor(model.vendor):
effective_api_key = (model.api_key or "").strip() or os.getenv("DASHSCOPE_API_KEY", "").strip() or os.getenv("ASR_API_KEY", "").strip()
if not effective_api_key:
return ASRTestResponse(success=False, error=f"API key is required for ASR model: {model.name}")
base_url = (model.base_url or "").strip() or DASHSCOPE_DEFAULT_BASE_URL
selected_model = (model.model_name or "").strip() or _default_asr_model(model.vendor)
_probe_dashscope_asr_connection(
api_key=effective_api_key,
base_url=base_url,
model=selected_model,
language=model.language,
)
latency_ms = int((time.time() - start_time) * 1000)
return ASRTestResponse(
success=True,
language=model.language,
latency_ms=latency_ms,
message="DashScope realtime ASR connected",
)
# 连接性测试优先,避免依赖真实音频输入
headers = {"Authorization": f"Bearer {model.api_key}"}
with httpx.Client(timeout=60.0) as client:
if _is_openai_compatible_vendor(model.vendor) or model.vendor.lower() == "paraformer":
response = client.get(f"{model.base_url}/asr", headers=headers)
elif model.vendor.lower() == "openai":
response = client.get(f"{model.base_url}/audio/models", headers=headers)
else:
response = client.get(f"{model.base_url}/health", headers=headers)
response.raise_for_status()
raw_result = response.json()
# 兼容不同供应商格式
if isinstance(raw_result, dict) and "results" in raw_result:
result = raw_result
elif isinstance(raw_result, dict) and "text" in raw_result:
result = {"results": [{"transcript": raw_result.get("text", "")}]}
else:
result = {"results": [{"transcript": ""}]}
latency_ms = int((time.time() - start_time) * 1000)
# 解析结果
if result_data := result.get("results", [{}])[0]:
transcript = result_data.get("transcript", "")
return ASRTestResponse(
success=True,
transcript=transcript,
language=result_data.get("language", model.language),
confidence=result_data.get("confidence"),
latency_ms=latency_ms,
)
return ASRTestResponse(
success=False,
message="No transcript in response",
latency_ms=latency_ms
)
except httpx.HTTPStatusError as e:
return ASRTestResponse(
success=False,
error=f"HTTP Error: {e.response.status_code} - {e.response.text[:200]}"
)
except Exception as e:
return ASRTestResponse(
success=False,
error=str(e)[:200]
)
@router.post("/{id}/transcribe")
def transcribe_audio(
id: str,
audio_url: Optional[str] = None,
audio_data: Optional[str] = None,
hotwords: Optional[List[str]] = None,
db: Session = Depends(get_db)
):
"""转写音频"""
model = db.query(ASRModel).filter(ASRModel.id == id).first()
if not model:
raise HTTPException(status_code=404, detail="ASR Model not found")
try:
payload = {
"model": model.model_name or "paraformer-v2",
"input": {},
"parameters": {
"hotwords": " ".join(hotwords or model.hotwords or []),
"enable_punctuation": model.enable_punctuation,
"enable_normalization": model.enable_normalization,
}
}
headers = {"Authorization": f"Bearer {model.api_key}"}
if audio_url:
payload["input"]["url"] = audio_url
elif audio_data:
payload["input"]["file_urls"] = []
with httpx.Client(timeout=120.0) as client:
response = client.post(
f"{model.base_url}/asr",
json=payload,
headers=headers
)
response.raise_for_status()
result = response.json()
if result_data := result.get("results", [{}])[0]:
return {
"success": True,
"transcript": result_data.get("transcript", ""),
"language": result_data.get("language", model.language),
"confidence": result_data.get("confidence"),
}
return {"success": False, "error": "No transcript in response"}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/{id}/preview", response_model=ASRTestResponse)
async def preview_asr_model(
id: str,
file: UploadFile = File(...),
language: Optional[str] = Form(None),
api_key: Optional[str] = Form(None),
db: Session = Depends(get_db),
):
"""预览 ASR根据供应商调用 OpenAI-compatible 或 DashScope 实时识别。"""
model = db.query(ASRModel).filter(ASRModel.id == id).first()
if not model:
raise HTTPException(status_code=404, detail="ASR Model not found")
if not file:
raise HTTPException(status_code=400, detail="Audio file is required")
filename = file.filename or "preview.wav"
content_type = file.content_type or "application/octet-stream"
if not content_type.startswith("audio/"):
raise HTTPException(status_code=400, detail="Only audio files are supported")
audio_bytes = await file.read()
if not audio_bytes:
raise HTTPException(status_code=400, detail="Uploaded audio file is empty")
effective_api_key = (api_key or "").strip() or (model.api_key or "").strip()
if not effective_api_key:
if _is_openai_compatible_vendor(model.vendor):
effective_api_key = os.getenv("SILICONFLOW_API_KEY", "").strip()
elif _is_dashscope_vendor(model.vendor):
effective_api_key = os.getenv("DASHSCOPE_API_KEY", "").strip() or os.getenv("ASR_API_KEY", "").strip()
if not effective_api_key:
raise HTTPException(status_code=400, detail=f"API key is required for ASR model: {model.name}")
base_url = (model.base_url or "").strip().rstrip("/")
if _is_dashscope_vendor(model.vendor) and not base_url:
base_url = DASHSCOPE_DEFAULT_BASE_URL
if not base_url:
raise HTTPException(status_code=400, detail=f"Base URL is required for ASR model: {model.name}")
selected_model = (model.model_name or "").strip() or _default_asr_model(model.vendor)
effective_language = (language or "").strip() or None
start_time = time.time()
if _is_dashscope_vendor(model.vendor):
try:
payload = await asyncio.to_thread(
_transcribe_dashscope_preview,
audio_bytes=audio_bytes,
api_key=effective_api_key,
base_url=base_url,
model=selected_model,
language=effective_language or model.language,
)
except Exception as exc:
raise HTTPException(status_code=502, detail=f"DashScope ASR request failed: {exc}") from exc
transcript = str(payload.get("transcript") or "")
response_language = str(payload.get("language") or effective_language or model.language)
latency_ms = int((time.time() - start_time) * 1000)
return ASRTestResponse(
success=bool(transcript),
transcript=transcript,
language=response_language,
confidence=None,
latency_ms=latency_ms,
message=None if transcript else "No transcript in response",
)
data = {"model": selected_model}
if effective_language:
data["language"] = effective_language
if model.hotwords:
data["prompt"] = " ".join(model.hotwords)
headers = {"Authorization": f"Bearer {effective_api_key}"}
files = {"file": (filename, audio_bytes, content_type)}
try:
with httpx.Client(timeout=90.0) as client:
response = client.post(
f"{base_url}/audio/transcriptions",
headers=headers,
data=data,
files=files,
)
except Exception as exc:
raise HTTPException(status_code=502, detail=f"ASR request failed: {exc}") from exc
if response.status_code != 200:
detail = response.text
try:
detail_json = response.json()
detail = detail_json.get("error", {}).get("message") or detail_json.get("detail") or detail
except Exception:
pass
raise HTTPException(status_code=502, detail=f"ASR vendor error: {detail}")
try:
payload = response.json()
except Exception:
payload = {"text": response.text}
transcript = ""
response_language = model.language
confidence = None
if isinstance(payload, dict):
transcript = str(payload.get("text") or payload.get("transcript") or "")
response_language = str(payload.get("language") or effective_language or model.language)
raw_confidence = payload.get("confidence")
if raw_confidence is not None:
try:
confidence = float(raw_confidence)
except (TypeError, ValueError):
confidence = None
latency_ms = int((time.time() - start_time) * 1000)
return ASRTestResponse(
success=bool(transcript),
transcript=transcript,
language=response_language,
confidence=confidence,
latency_ms=latency_ms,
message=None if transcript else "No transcript in response",
)