import asyncio
import base64
import io
import json
import os
import sys
import threading
import time
import wave
from array import array
from typing import Any, Dict, List, Optional, Tuple

import httpx
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
from sqlalchemy.orm import Session

from ..db import get_db
from ..id_generator import unique_short_id
from ..models import ASRModel
from ..schemas import (
    ASRModelCreate,
    ASRModelUpdate,
    ASRModelOut,
    ASRTestRequest,
    ASRTestResponse,
)

router = APIRouter(prefix="/asr", tags=["ASR Models"])

# Default model names used when a stored ASR record has no explicit model_name.
OPENAI_COMPATIBLE_DEFAULT_ASR_MODEL = "FunAudioLLM/SenseVoiceSmall"
DASHSCOPE_DEFAULT_ASR_MODEL = "qwen3-asr-flash-realtime"
DASHSCOPE_DEFAULT_BASE_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"

# The DashScope SDK is optional: when it is missing we still want the module to
# import cleanly so the CRUD endpoints keep working; only the realtime preview
# paths then fail with an actionable install hint.
try:
    import dashscope
    from dashscope.audio.qwen_omni import MultiModality, OmniRealtimeCallback, OmniRealtimeConversation
    try:
        from dashscope.audio.qwen_omni import TranscriptionParams
    except ImportError:
        # Older SDK layouts expose TranscriptionParams from the submodule.
        from dashscope.audio.qwen_omni.omni_realtime import TranscriptionParams
    DASHSCOPE_SDK_AVAILABLE = True
    DASHSCOPE_IMPORT_ERROR = ""
except Exception as exc:
    dashscope = None  # type: ignore[assignment]
    MultiModality = None  # type: ignore[assignment]
    OmniRealtimeConversation = None  # type: ignore[assignment]
    TranscriptionParams = None  # type: ignore[assignment]
    DASHSCOPE_SDK_AVAILABLE = False
    DASHSCOPE_IMPORT_ERROR = f"{type(exc).__name__}: {exc}"

    class OmniRealtimeCallback:  # type: ignore[no-redef]
        """Fallback callback base when DashScope SDK is unavailable."""
        pass


def _is_openai_compatible_vendor(vendor: str) -> bool:
    """Return True when *vendor* names an OpenAI-compatible ASR provider.

    Legacy vendor labels ("siliconflow", "硅基流动") are accepted for
    backward compatibility with records created before the rename.
    """
    normalized = (vendor or "").strip().lower()
    return normalized in {
        "openai compatible",
        "openai-compatible",
        "siliconflow",  # backward compatibility
        "硅基流动",  # backward compatibility
    }


def _is_dashscope_vendor(vendor: str) -> bool:
    """Return True when *vendor* names the DashScope realtime provider."""
    return (vendor or "").strip().lower() == "dashscope"


def _default_asr_model(vendor: str) -> str:
    """Pick a sensible default model name for the given vendor string."""
    if _is_openai_compatible_vendor(vendor):
        return OPENAI_COMPATIBLE_DEFAULT_ASR_MODEL
    if _is_dashscope_vendor(vendor):
        return DASHSCOPE_DEFAULT_ASR_MODEL
    return "whisper-1"


def _dashscope_language(language: Optional[str]) -> Optional[str]:
    """Map a stored language label to a DashScope language hint.

    Returns None for auto/multilingual settings (DashScope then detects the
    language itself); "zh"/"en" prefixes collapse to their two-letter codes;
    anything else is passed through unchanged.
    """
    normalized = (language or "").strip().lower()
    if not normalized or normalized in {"multi-lingual", "multilingual", "multi_lingual", "auto"}:
        return None
    if normalized.startswith("zh"):
        return "zh"
    if normalized.startswith("en"):
        return "en"
    return normalized


class _DashScopePreviewCallback(OmniRealtimeCallback):
    """Collect DashScope ASR websocket events for preview/test flows.

    The DashScope SDK invokes callbacks from its own worker thread, so all
    shared text state is guarded by a lock and progress is signalled through
    threading.Events that the request thread waits on.
    """

    def __init__(self) -> None:
        super().__init__()
        self._open_event = threading.Event()
        self._session_ready_event = threading.Event()
        self._done_event = threading.Event()
        self._lock = threading.Lock()
        self._final_text = ""
        self._last_interim_text = ""
        self._error_message: Optional[str] = None

    def on_open(self) -> None:
        self._open_event.set()

    def on_close(self, code: int, reason: str) -> None:
        # A close after we are done is the normal shutdown path; anything
        # earlier is treated as an error so waiters unblock immediately.
        if self._done_event.is_set():
            return
        self._error_message = f"DashScope websocket closed unexpectedly: {code} {reason}"
        self._done_event.set()
        self._session_ready_event.set()

    def on_error(self, message: Any) -> None:
        self._error_message = str(message)
        self._done_event.set()
        self._session_ready_event.set()

    def on_event(self, response: Any) -> None:
        """Dispatch a single realtime event by its ``type`` field."""
        payload = _coerce_dashscope_event(response)
        event_type = str(payload.get("type") or "").strip()
        if not event_type:
            return
        if event_type in {"session.created", "session.updated"}:
            self._session_ready_event.set()
            return
        if event_type == "error" or event_type.endswith(".failed"):
            self._error_message = _format_dashscope_error_event(payload)
            self._done_event.set()
            self._session_ready_event.set()
            return
        if event_type == "conversation.item.input_audio_transcription.text":
            # Interim hypothesis; kept as a fallback if no final text arrives.
            interim_text = _extract_dashscope_text(payload, keys=("stash", "text", "transcript"))
            if interim_text:
                with self._lock:
                    self._last_interim_text = interim_text
            return
        if event_type == "conversation.item.input_audio_transcription.completed":
            final_text = _extract_dashscope_text(payload, keys=("transcript", "text", "stash"))
            with self._lock:
                if final_text:
                    self._final_text = final_text
            self._done_event.set()
            return
        if event_type in {"response.done", "session.finished"}:
            self._done_event.set()

    def wait_for_open(self, timeout: float = 10.0) -> None:
        """Block until the websocket opened; raise TimeoutError otherwise."""
        if not self._open_event.wait(timeout):
            raise TimeoutError("DashScope websocket open timeout")

    def wait_for_session_ready(self, timeout: float = 6.0) -> bool:
        """Wait for session.created/updated; returns False on timeout."""
        return self._session_ready_event.wait(timeout)

    def wait_for_done(self, timeout: float = 20.0) -> None:
        """Block until transcription finished; raise TimeoutError otherwise."""
        if not self._done_event.wait(timeout):
            raise TimeoutError("DashScope transcription timeout")

    def raise_if_error(self) -> None:
        """Re-raise any error captured on the callback thread."""
        if self._error_message:
            raise RuntimeError(self._error_message)

    def read_text(self) -> str:
        """Return the final transcript, falling back to the last interim one."""
        with self._lock:
            return self._final_text or self._last_interim_text


def _coerce_dashscope_event(response: Any) -> Dict[str, Any]:
    """Normalize an SDK event (dict or JSON string) to a dict payload."""
    if isinstance(response, dict):
        return response
    if isinstance(response, str):
        try:
            parsed = json.loads(response)
            if isinstance(parsed, dict):
                return parsed
        except json.JSONDecodeError:
            pass
    return {"type": "raw", "message": str(response)}


def _format_dashscope_error_event(payload: Dict[str, Any]) -> str:
    """Build a human-readable message from a DashScope error event."""
    error = payload.get("error")
    if isinstance(error, dict):
        code = str(error.get("code") or "").strip()
        message = str(error.get("message") or "").strip()
        if code and message:
            return f"{code}: {message}"
        return message or str(error)
    return str(error or "DashScope realtime ASR error")


def _extract_dashscope_text(payload: Dict[str, Any], *, keys: Tuple[str, ...]) -> str:
    """Depth-first search *payload* for the first non-empty string under *keys*.

    Preferred keys are checked at the current level first (in order), then
    nested dicts are searched recursively.
    """
    for key in keys:
        value = payload.get(key)
        if isinstance(value, str) and value.strip():
            return value.strip()
        if isinstance(value, dict):
            nested = _extract_dashscope_text(value, keys=keys)
            if nested:
                return nested
    for value in payload.values():
        if isinstance(value, dict):
            nested = _extract_dashscope_text(value, keys=keys)
            if nested:
                return nested
    return ""


def _create_dashscope_realtime_client(
    *,
    model: str,
    callback: _DashScopePreviewCallback,
    url: str,
    api_key: str,
) -> Any:
    """Instantiate an OmniRealtimeConversation, tolerating older SDKs.

    Newer SDKs accept ``api_key`` in the constructor; older ones raise
    TypeError, in which case we retry without it (the key is then supplied
    via ``dashscope.api_key`` by the caller).
    """
    if OmniRealtimeConversation is None:
        raise RuntimeError("DashScope SDK unavailable")
    init_kwargs = {
        "model": model,
        "callback": callback,
        "url": url,
    }
    try:
        return OmniRealtimeConversation(api_key=api_key, **init_kwargs)  # type: ignore[misc]
    except TypeError as exc:
        if "api_key" not in str(exc):
            raise
        return OmniRealtimeConversation(**init_kwargs)  # type: ignore[misc]


def _close_dashscope_client(client: Any) -> None:
    """Best-effort shutdown: try finish() then close(), swallowing errors."""
    finish_fn = getattr(client, "finish", None)
    if callable(finish_fn):
        try:
            finish_fn()
        except Exception:
            pass
    close_fn = getattr(client, "close", None)
    if callable(close_fn):
        try:
            close_fn()
        except Exception:
            pass


def _configure_dashscope_session(
    *,
    client: Any,
    callback: _DashScopePreviewCallback,
    sample_rate: int,
    language: Optional[str],
) -> None:
    """Send session.update, degrading gracefully across SDK versions.

    Attempts the richest parameter set first and falls back to simpler ones
    when the installed SDK's ``update_session`` rejects unknown keywords.
    Raises RuntimeError when every attempt fails.
    """
    update_fn = getattr(client, "update_session", None)
    if not callable(update_fn):
        raise RuntimeError("DashScope ASR SDK missing update_session method")

    text_modality: Any = "text"
    if MultiModality is not None and hasattr(MultiModality, "TEXT"):
        text_modality = MultiModality.TEXT

    transcription_params: Optional[Any] = None
    language_hint = _dashscope_language(language)
    if TranscriptionParams is not None:
        try:
            params_kwargs: Dict[str, Any] = {
                "sample_rate": sample_rate,
                "input_audio_format": "pcm",
            }
            if language_hint:
                params_kwargs["language"] = language_hint
            transcription_params = TranscriptionParams(**params_kwargs)
        except Exception:
            transcription_params = None

    update_attempts = [
        {
            "output_modalities": [text_modality],
            "enable_turn_detection": False,
            "enable_input_audio_transcription": True,
            "transcription_params": transcription_params,
        },
        {
            "output_modalities": [text_modality],
            "enable_turn_detection": False,
            "enable_input_audio_transcription": True,
        },
        {
            "output_modalities": [text_modality],
        },
    ]
    last_error: Optional[Exception] = None
    for params in update_attempts:
        if params.get("transcription_params") is None:
            # Never pass transcription_params=None explicitly; drop the key.
            params = {key: value for key, value in params.items() if key != "transcription_params"}
        try:
            update_fn(**params)
            callback.wait_for_session_ready()
            callback.raise_if_error()
            return
        except TypeError as exc:
            # Unknown keyword on this SDK version; try the next attempt.
            last_error = exc
            continue
        except Exception as exc:
            last_error = exc
            continue
    raise RuntimeError(f"DashScope ASR session.update failed: {last_error}")


def _load_wav_pcm16_mono(audio_bytes: bytes) -> Tuple[bytes, int]:
    """Decode a WAV payload to mono 16-bit PCM.

    Returns (pcm_bytes, sample_rate). Multi-channel audio is downmixed by
    averaging samples across channels. Raises RuntimeError for non-WAV,
    compressed, non-16-bit, or empty input.
    """
    try:
        with wave.open(io.BytesIO(audio_bytes), "rb") as wav_file:
            channel_count = wav_file.getnchannels()
            sample_width = wav_file.getsampwidth()
            sample_rate = wav_file.getframerate()
            compression = wav_file.getcomptype()
            pcm_frames = wav_file.readframes(wav_file.getnframes())
    except wave.Error as exc:
        raise RuntimeError("DashScope preview currently supports WAV audio. Record in browser or upload a .wav file.") from exc
    if compression != "NONE":
        raise RuntimeError("DashScope preview requires uncompressed PCM WAV audio.")
    if sample_width != 2:
        raise RuntimeError("DashScope preview requires 16-bit PCM WAV audio.")
    if not pcm_frames:
        raise RuntimeError("Uploaded WAV file is empty")
    if channel_count <= 1:
        return pcm_frames, sample_rate

    # WAV stores little-endian samples; byteswap on big-endian hosts so the
    # array reads correct values, and swap back before serializing.
    samples = array("h")
    samples.frombytes(pcm_frames)
    if sys.byteorder == "big":
        samples.byteswap()
    mono_samples = array(
        "h",
        (
            int(sum(samples[index:index + channel_count]) / channel_count)
            for index in range(0, len(samples), channel_count)
        ),
    )
    if sys.byteorder == "big":
        mono_samples.byteswap()
    return mono_samples.tobytes(), sample_rate


def _probe_dashscope_asr_connection(*, api_key: str, base_url: str, model: str, language: Optional[str]) -> None:
    """Open a DashScope realtime session and configure it, without audio.

    Used by the /test endpoint as a connectivity check. Raises RuntimeError
    when the SDK is missing or the handshake/configuration fails.
    """
    if not DASHSCOPE_SDK_AVAILABLE:
        hint = f"`{sys.executable} -m pip install dashscope>=1.25.11`"
        detail = f"; import error: {DASHSCOPE_IMPORT_ERROR}" if DASHSCOPE_IMPORT_ERROR else ""
        raise RuntimeError(f"dashscope package not installed; install with {hint}{detail}")
    callback = _DashScopePreviewCallback()
    if dashscope is not None:
        dashscope.api_key = api_key
    client = _create_dashscope_realtime_client(
        model=model,
        callback=callback,
        url=base_url,
        api_key=api_key,
    )
    try:
        client.connect()
        callback.wait_for_open()
        _configure_dashscope_session(
            client=client,
            callback=callback,
            sample_rate=16000,
            language=language,
        )
    finally:
        _close_dashscope_client(client)


def _transcribe_dashscope_preview(
    *,
    audio_bytes: bytes,
    api_key: str,
    base_url: str,
    model: str,
    language: Optional[str],
) -> Dict[str, Any]:
    """Run a one-shot DashScope realtime transcription of a WAV payload.

    Returns a dict with "transcript", "language", and "confidence" (always
    None — the realtime API does not report one here). Raises RuntimeError
    on SDK/connection/decoding failures.
    """
    if not DASHSCOPE_SDK_AVAILABLE:
        hint = f"`{sys.executable} -m pip install dashscope>=1.25.11`"
        detail = f"; import error: {DASHSCOPE_IMPORT_ERROR}" if DASHSCOPE_IMPORT_ERROR else ""
        raise RuntimeError(f"dashscope package not installed; install with {hint}{detail}")
    pcm_audio, sample_rate = _load_wav_pcm16_mono(audio_bytes)
    callback = _DashScopePreviewCallback()
    if dashscope is not None:
        dashscope.api_key = api_key
    client = _create_dashscope_realtime_client(
        model=model,
        callback=callback,
        url=base_url,
        api_key=api_key,
    )
    try:
        client.connect()
        callback.wait_for_open()
        _configure_dashscope_session(
            client=client,
            callback=callback,
            sample_rate=sample_rate,
            language=language,
        )
        append_fn = getattr(client, "append_audio", None)
        if not callable(append_fn):
            raise RuntimeError("DashScope ASR SDK missing append_audio method")
        commit_fn = getattr(client, "commit", None)
        if not callable(commit_fn):
            raise RuntimeError("DashScope ASR SDK missing commit method")
        # The realtime API expects base64-encoded PCM chunks.
        append_fn(base64.b64encode(pcm_audio).decode("ascii"))
        commit_fn()
        callback.wait_for_done()
        callback.raise_if_error()
        return {
            "transcript": callback.read_text(),
            "language": _dashscope_language(language) or "Multi-lingual",
            "confidence": None,
        }
    finally:
        _close_dashscope_client(client)


# ============ ASR Models CRUD ============

@router.get("")
def list_asr_models(
    language: Optional[str] = None,
    enabled: Optional[bool] = None,
    page: int = 1,
    limit: int = 50,
    db: Session = Depends(get_db)
):
    """List ASR models with optional language/enabled filters and pagination."""
    query = db.query(ASRModel)
    if language:
        query = query.filter(ASRModel.language == language)
    if enabled is not None:
        query = query.filter(ASRModel.enabled == enabled)
    total = query.count()
    models = query.order_by(ASRModel.created_at.desc()) \
        .offset((page - 1) * limit).limit(limit).all()
    return {"total": total, "page": page, "limit": limit, "list": models}


@router.get("/{id}", response_model=ASRModelOut)
def get_asr_model(id: str, db: Session = Depends(get_db)):
    """Fetch a single ASR model by id; 404 when it does not exist."""
    model = db.query(ASRModel).filter(ASRModel.id == id).first()
    if not model:
        raise HTTPException(status_code=404, detail="ASR Model not found")
    return model


@router.post("", response_model=ASRModelOut)
def create_asr_model(data: ASRModelCreate, db: Session = Depends(get_db)):
    """Create an ASR model record from the request payload."""
    asr_model = ASRModel(
        id=unique_short_id("asr", db, ASRModel),
        user_id=1,  # default user
        name=data.name,
        vendor=data.vendor,
        language=data.language,
        base_url=data.base_url,
        api_key=data.api_key,
        model_name=data.model_name,
        hotwords=data.hotwords,
        enable_punctuation=data.enable_punctuation,
        enable_normalization=data.enable_normalization,
        enabled=data.enabled,
    )
    db.add(asr_model)
    db.commit()
    db.refresh(asr_model)
    return asr_model


@router.put("/{id}", response_model=ASRModelOut)
def update_asr_model(id: str, data: ASRModelUpdate, db: Session = Depends(get_db)):
    """Partially update an ASR model; only fields present in the body change."""
    model = db.query(ASRModel).filter(ASRModel.id == id).first()
    if not model:
        raise HTTPException(status_code=404, detail="ASR Model not found")
    update_data = data.model_dump(exclude_unset=True)
    for field, value in update_data.items():
        setattr(model, field, value)
    db.commit()
    db.refresh(model)
    return model


@router.delete("/{id}")
def delete_asr_model(id: str, db: Session = Depends(get_db)):
    """Delete an ASR model by id; 404 when it does not exist."""
    model = db.query(ASRModel).filter(ASRModel.id == id).first()
    if not model:
        raise HTTPException(status_code=404, detail="ASR Model not found")
    db.delete(model)
    db.commit()
    return {"message": "Deleted successfully"}


@router.post("/{id}/test", response_model=ASRTestResponse)
def test_asr_model(
    id: str,
    request: Optional[ASRTestRequest] = None,
    db: Session = Depends(get_db)
):
    """Connectivity test for an ASR model.

    For DashScope, opens and configures a realtime session. For other vendors
    this probes a vendor-specific endpoint without sending real audio, then
    normalizes whatever JSON comes back into the test response shape.
    """
    model = db.query(ASRModel).filter(ASRModel.id == id).first()
    if not model:
        raise HTTPException(status_code=404, detail="ASR Model not found")
    start_time = time.time()
    try:
        if _is_dashscope_vendor(model.vendor):
            effective_api_key = (model.api_key or "").strip() or os.getenv("DASHSCOPE_API_KEY", "").strip() or os.getenv("ASR_API_KEY", "").strip()
            if not effective_api_key:
                return ASRTestResponse(success=False, error=f"API key is required for ASR model: {model.name}")
            base_url = (model.base_url or "").strip() or DASHSCOPE_DEFAULT_BASE_URL
            selected_model = (model.model_name or "").strip() or _default_asr_model(model.vendor)
            _probe_dashscope_asr_connection(
                api_key=effective_api_key,
                base_url=base_url,
                model=selected_model,
                language=model.language,
            )
            latency_ms = int((time.time() - start_time) * 1000)
            return ASRTestResponse(
                success=True,
                language=model.language,
                latency_ms=latency_ms,
                message="DashScope realtime ASR connected",
            )

        # Connectivity-first test; avoids requiring real audio input.
        headers = {"Authorization": f"Bearer {model.api_key}"}
        # Guard against NULL vendor (the _is_* helpers already do the same).
        vendor_normalized = (model.vendor or "").lower()
        with httpx.Client(timeout=60.0) as client:
            if _is_openai_compatible_vendor(model.vendor) or vendor_normalized == "paraformer":
                response = client.get(f"{model.base_url}/asr", headers=headers)
            elif vendor_normalized == "openai":
                response = client.get(f"{model.base_url}/audio/models", headers=headers)
            else:
                response = client.get(f"{model.base_url}/health", headers=headers)
            response.raise_for_status()
            raw_result = response.json()

        # Normalize different vendor response shapes.
        if isinstance(raw_result, dict) and "results" in raw_result:
            result = raw_result
        elif isinstance(raw_result, dict) and "text" in raw_result:
            result = {"results": [{"transcript": raw_result.get("text", "")}]}
        else:
            result = {"results": [{"transcript": ""}]}

        latency_ms = int((time.time() - start_time) * 1000)

        # BUGFIX: result.get("results", [{}]) only defaults when the key is
        # absent; a present-but-empty list previously raised IndexError.
        results_list = result.get("results") or [{}]
        if result_data := results_list[0]:
            transcript = result_data.get("transcript", "")
            return ASRTestResponse(
                success=True,
                transcript=transcript,
                language=result_data.get("language", model.language),
                confidence=result_data.get("confidence"),
                latency_ms=latency_ms,
            )
        return ASRTestResponse(
            success=False,
            message="No transcript in response",
            latency_ms=latency_ms
        )
    except httpx.HTTPStatusError as e:
        return ASRTestResponse(
            success=False,
            error=f"HTTP Error: {e.response.status_code} - {e.response.text[:200]}"
        )
    except Exception as e:
        return ASRTestResponse(
            success=False,
            error=str(e)[:200]
        )


@router.post("/{id}/transcribe")
def transcribe_audio(
    id: str,
    audio_url: Optional[str] = None,
    audio_data: Optional[str] = None,
    hotwords: Optional[List[str]] = None,
    db: Session = Depends(get_db)
):
    """Transcribe audio via the vendor's /asr endpoint.

    Accepts either a remote audio_url or (placeholder) inline audio_data;
    request hotwords override the model's stored hotwords.
    """
    model = db.query(ASRModel).filter(ASRModel.id == id).first()
    if not model:
        raise HTTPException(status_code=404, detail="ASR Model not found")
    try:
        payload = {
            "model": model.model_name or "paraformer-v2",
            "input": {},
            "parameters": {
                "hotwords": " ".join(hotwords or model.hotwords or []),
                "enable_punctuation": model.enable_punctuation,
                "enable_normalization": model.enable_normalization,
            }
        }
        headers = {"Authorization": f"Bearer {model.api_key}"}
        if audio_url:
            payload["input"]["url"] = audio_url
        elif audio_data:
            payload["input"]["file_urls"] = []
        with httpx.Client(timeout=120.0) as client:
            response = client.post(
                f"{model.base_url}/asr",
                json=payload,
                headers=headers
            )
            response.raise_for_status()
            result = response.json()
        # BUGFIX: guard against a present-but-empty "results" list, which
        # previously raised IndexError (surfacing as an opaque 500).
        results_list = result.get("results") or [{}]
        if result_data := results_list[0]:
            return {
                "success": True,
                "transcript": result_data.get("transcript", ""),
                "language": result_data.get("language", model.language),
                "confidence": result_data.get("confidence"),
            }
        return {"success": False, "error": "No transcript in response"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/{id}/preview", response_model=ASRTestResponse)
async def preview_asr_model(
    id: str,
    file: UploadFile = File(...),
    language: Optional[str] = Form(None),
    api_key: Optional[str] = Form(None),
    db: Session = Depends(get_db),
):
    """Preview ASR: route to OpenAI-compatible HTTP or DashScope realtime by vendor."""
    model = db.query(ASRModel).filter(ASRModel.id == id).first()
    if not model:
        raise HTTPException(status_code=404, detail="ASR Model not found")
    if not file:
        raise HTTPException(status_code=400, detail="Audio file is required")

    filename = file.filename or "preview.wav"
    content_type = file.content_type or "application/octet-stream"
    if not content_type.startswith("audio/"):
        raise HTTPException(status_code=400, detail="Only audio files are supported")
    audio_bytes = await file.read()
    if not audio_bytes:
        raise HTTPException(status_code=400, detail="Uploaded audio file is empty")

    # Key resolution order: request form > stored model key > environment.
    effective_api_key = (api_key or "").strip() or (model.api_key or "").strip()
    if not effective_api_key:
        if _is_openai_compatible_vendor(model.vendor):
            effective_api_key = os.getenv("SILICONFLOW_API_KEY", "").strip()
        elif _is_dashscope_vendor(model.vendor):
            effective_api_key = os.getenv("DASHSCOPE_API_KEY", "").strip() or os.getenv("ASR_API_KEY", "").strip()
    if not effective_api_key:
        raise HTTPException(status_code=400, detail=f"API key is required for ASR model: {model.name}")

    base_url = (model.base_url or "").strip().rstrip("/")
    if _is_dashscope_vendor(model.vendor) and not base_url:
        base_url = DASHSCOPE_DEFAULT_BASE_URL
    if not base_url:
        raise HTTPException(status_code=400, detail=f"Base URL is required for ASR model: {model.name}")

    selected_model = (model.model_name or "").strip() or _default_asr_model(model.vendor)
    effective_language = (language or "").strip() or None
    start_time = time.time()

    if _is_dashscope_vendor(model.vendor):
        try:
            # The DashScope SDK is blocking; run it off the event loop.
            payload = await asyncio.to_thread(
                _transcribe_dashscope_preview,
                audio_bytes=audio_bytes,
                api_key=effective_api_key,
                base_url=base_url,
                model=selected_model,
                language=effective_language or model.language,
            )
        except Exception as exc:
            raise HTTPException(status_code=502, detail=f"DashScope ASR request failed: {exc}") from exc
        transcript = str(payload.get("transcript") or "")
        response_language = str(payload.get("language") or effective_language or model.language)
        latency_ms = int((time.time() - start_time) * 1000)
        return ASRTestResponse(
            success=bool(transcript),
            transcript=transcript,
            language=response_language,
            confidence=None,
            latency_ms=latency_ms,
            message=None if transcript else "No transcript in response",
        )

    # OpenAI-compatible path: multipart POST to /audio/transcriptions.
    data = {"model": selected_model}
    if effective_language:
        data["language"] = effective_language
    if model.hotwords:
        data["prompt"] = " ".join(model.hotwords)
    headers = {"Authorization": f"Bearer {effective_api_key}"}
    files = {"file": (filename, audio_bytes, content_type)}
    try:
        with httpx.Client(timeout=90.0) as client:
            response = client.post(
                f"{base_url}/audio/transcriptions",
                headers=headers,
                data=data,
                files=files,
            )
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"ASR request failed: {exc}") from exc

    if response.status_code != 200:
        detail = response.text
        try:
            detail_json = response.json()
            detail = detail_json.get("error", {}).get("message") or detail_json.get("detail") or detail
        except Exception:
            pass
        raise HTTPException(status_code=502, detail=f"ASR vendor error: {detail}")

    try:
        payload = response.json()
    except Exception:
        payload = {"text": response.text}

    transcript = ""
    response_language = model.language
    confidence = None
    if isinstance(payload, dict):
        transcript = str(payload.get("text") or payload.get("transcript") or "")
        response_language = str(payload.get("language") or effective_language or model.language)
        raw_confidence = payload.get("confidence")
        if raw_confidence is not None:
            try:
                confidence = float(raw_confidence)
            except (TypeError, ValueError):
                confidence = None
    latency_ms = int((time.time() - start_time) * 1000)
    return ASRTestResponse(
        success=bool(transcript),
        transcript=transcript,
        language=response_language,
        confidence=confidence,
        latency_ms=latency_ms,
        message=None if transcript else "No transcript in response",
    )