"""HTTP routes for managing and exercising ASR (speech recognition) models."""

import os
import time
import uuid
from typing import Any, Dict, List, Optional

import httpx
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
from sqlalchemy.orm import Session

from ..db import get_db
from ..models import ASRModel
from ..schemas import (
    ASRModelCreate, ASRModelUpdate, ASRModelOut,
    ASRTestRequest, ASRTestResponse
)

router = APIRouter(prefix="/asr", tags=["ASR Models"])

SILICONFLOW_DEFAULT_ASR_MODEL = "FunAudioLLM/SenseVoiceSmall"


def _is_siliconflow_vendor(vendor: str) -> bool:
    """Return True when *vendor* names SiliconFlow (English or Chinese spelling)."""
    return (vendor or "").strip().lower() in {"siliconflow", "硅基流动"}


def _default_asr_model(vendor: str) -> str:
    """Return the fallback model name used when a record has no ``model_name``."""
    if _is_siliconflow_vendor(vendor):
        return SILICONFLOW_DEFAULT_ASR_MODEL
    return "whisper-1"


def _clean_base_url(raw: Optional[str]) -> str:
    """Strip whitespace and trailing slashes so ``f"{base}/path"`` never yields ``//``."""
    return (raw or "").strip().rstrip("/")


def _first_result(payload: Any) -> Dict[str, Any]:
    """Return the first dict in ``payload["results"]``, or ``{}`` when there is none.

    A missing key, an empty list, a non-dict entry, or a non-dict payload all
    yield ``{}`` so callers can use a plain truthiness check instead of
    indexing and risking ``IndexError`` / ``AttributeError``.
    """
    if not isinstance(payload, dict):
        return {}
    results = payload.get("results")
    if isinstance(results, list) and results and isinstance(results[0], dict):
        return results[0]
    return {}


# ============ ASR Models CRUD ============

@router.get("")
def list_asr_models(
    language: Optional[str] = None,
    enabled: Optional[bool] = None,
    page: int = 1,
    limit: int = 50,
    db: Session = Depends(get_db)
):
    """List ASR models, newest first, with optional filters and paging.

    Args:
        language: When given, keep only models with this language.
        enabled: When given, keep only models whose ``enabled`` flag matches.
        page: 1-based page number.
        limit: Page size.
        db: Database session.

    Returns:
        A dict with ``total`` (count after filtering), the echoed ``page`` and
        ``limit``, and ``list`` (the models on the requested page).
    """
    query = db.query(ASRModel)
    if language:
        query = query.filter(ASRModel.language == language)
    if enabled is not None:
        query = query.filter(ASRModel.enabled == enabled)

    total = query.count()
    # Clamp so that page < 1 never produces a negative OFFSET, which several
    # databases reject outright (SQLite already treats it as zero).
    offset = max((page - 1) * limit, 0)
    models = (
        query.order_by(ASRModel.created_at.desc())
        .offset(offset)
        .limit(limit)
        .all()
    )
    return {"total": total, "page": page, "limit": limit, "list": models}


@router.get("/{id}", response_model=ASRModelOut)
def get_asr_model(id: str, db: Session = Depends(get_db)):
    """Return a single ASR model by id.

    Raises:
        HTTPException: 404 when no model has this id.
    """
    model = db.query(ASRModel).filter(ASRModel.id == id).first()
    if not model:
        raise HTTPException(status_code=404, detail="ASR Model not found")
    return model


@router.post("", response_model=ASRModelOut)
def create_asr_model(data: ASRModelCreate, db: Session = Depends(get_db)):
    """Create an ASR model record and return it.

    When ``data.id`` is empty, the first 8 characters of a random UUID4 are
    used as the id.
    """
    asr_model = ASRModel(
        id=data.id or str(uuid.uuid4())[:8],
        user_id=1,  # default user
        name=data.name,
        vendor=data.vendor,
        language=data.language,
        base_url=data.base_url,
        api_key=data.api_key,
        model_name=data.model_name,
        hotwords=data.hotwords,
        enable_punctuation=data.enable_punctuation,
        enable_normalization=data.enable_normalization,
        enabled=data.enabled,
    )
    db.add(asr_model)
    db.commit()
    db.refresh(asr_model)
    return asr_model


@router.put("/{id}", response_model=ASRModelOut)
def update_asr_model(id: str, data: ASRModelUpdate, db: Session = Depends(get_db)):
    """Apply a partial update to an ASR model and return the refreshed record.

    Only fields explicitly present in the request body are written.

    Raises:
        HTTPException: 404 when no model has this id.
    """
    model = db.query(ASRModel).filter(ASRModel.id == id).first()
    if not model:
        raise HTTPException(status_code=404, detail="ASR Model not found")

    update_data = data.model_dump(exclude_unset=True)
    for field, value in update_data.items():
        setattr(model, field, value)

    db.commit()
    db.refresh(model)
    return model


@router.delete("/{id}")
def delete_asr_model(id: str, db: Session = Depends(get_db)):
    """Delete an ASR model.

    Raises:
        HTTPException: 404 when no model has this id.
    """
    model = db.query(ASRModel).filter(ASRModel.id == id).first()
    if not model:
        raise HTTPException(status_code=404, detail="ASR Model not found")
    db.delete(model)
    db.commit()
    return {"message": "Deleted successfully"}


@router.post("/{id}/test", response_model=ASRTestResponse)
def test_asr_model(
    id: str,
    request: Optional[ASRTestRequest] = None,
    db: Session = Depends(get_db)
):
    """Probe the vendor endpoint configured for an ASR model.

    This is a connectivity check (a GET against a vendor-specific path), so it
    does not need real audio input. Failures of the outbound call are reported
    in the response body (``success=False``) rather than raised.

    Args:
        id: Model id.
        request: Accepted for interface compatibility; not read by this handler.
        db: Database session.

    Raises:
        HTTPException: 404 when no model has this id.
    """
    model = db.query(ASRModel).filter(ASRModel.id == id).first()
    if not model:
        raise HTTPException(status_code=404, detail="ASR Model not found")

    start_time = time.time()
    try:
        # Connectivity test first, to avoid depending on real audio input.
        headers = {"Authorization": f"Bearer {model.api_key}"}
        base_url = _clean_base_url(model.base_url)
        # Guard against a NULL vendor column; it falls through to /health.
        vendor = (model.vendor or "").lower()
        if vendor in ("siliconflow", "paraformer"):
            probe_path = "/asr"
        elif vendor == "openai":
            probe_path = "/audio/models"
        else:
            probe_path = "/health"

        with httpx.Client(timeout=60.0) as client:
            response = client.get(f"{base_url}{probe_path}", headers=headers)
            response.raise_for_status()
            raw_result = response.json()

        # Normalise the differing vendor response shapes.
        if isinstance(raw_result, dict) and "results" in raw_result:
            result = raw_result
        elif isinstance(raw_result, dict) and "text" in raw_result:
            result = {"results": [{"transcript": raw_result.get("text", "")}]}
        else:
            result = {"results": [{"transcript": ""}]}

        latency_ms = int((time.time() - start_time) * 1000)

        # An empty or malformed "results" list yields {} here and falls
        # through to the "No transcript" response instead of raising.
        if result_data := _first_result(result):
            transcript = result_data.get("transcript", "")
            return ASRTestResponse(
                success=True,
                transcript=transcript,
                language=result_data.get("language", model.language),
                confidence=result_data.get("confidence"),
                latency_ms=latency_ms,
            )
        return ASRTestResponse(
            success=False,
            message="No transcript in response",
            latency_ms=latency_ms
        )
    except httpx.HTTPStatusError as e:
        return ASRTestResponse(
            success=False,
            error=f"HTTP Error: {e.response.status_code} - {e.response.text[:200]}"
        )
    except Exception as e:
        # Top-level boundary for the probe: report any failure to the caller.
        return ASRTestResponse(
            success=False,
            error=str(e)[:200]
        )


@router.post("/{id}/transcribe")
def transcribe_audio(
    id: str,
    audio_url: Optional[str] = None,
    audio_data: Optional[str] = None,
    hotwords: Optional[List[str]] = None,
    db: Session = Depends(get_db)
):
    """Transcribe audio by POSTing a JSON job to the model's ``/asr`` endpoint.

    Args:
        id: Model id.
        audio_url: URL of the audio; sent as ``input.url`` when given.
        audio_data: Alternative to ``audio_url`` (see the review note below).
        hotwords: Overrides the model's stored hotwords when non-empty.
        db: Database session.

    Returns:
        ``{"success": True, "transcript", "language", "confidence"}`` when the
        vendor response has a first result, otherwise
        ``{"success": False, "error": "No transcript in response"}``.

    Raises:
        HTTPException: 404 when no model has this id; 500 when the vendor
            call or response handling fails.
    """
    model = db.query(ASRModel).filter(ASRModel.id == id).first()
    if not model:
        raise HTTPException(status_code=404, detail="ASR Model not found")

    try:
        payload = {
            "model": model.model_name or "paraformer-v2",
            "input": {},
            "parameters": {
                "hotwords": " ".join(hotwords or model.hotwords or []),
                "enable_punctuation": model.enable_punctuation,
                "enable_normalization": model.enable_normalization,
            }
        }
        headers = {"Authorization": f"Bearer {model.api_key}"}

        if audio_url:
            payload["input"]["url"] = audio_url
        elif audio_data:
            # NOTE(review): audio_data itself is never forwarded; only an empty
            # file_urls list is sent. Looks unfinished; confirm intended behaviour.
            payload["input"]["file_urls"] = []

        with httpx.Client(timeout=120.0) as client:
            response = client.post(
                f"{_clean_base_url(model.base_url)}/asr",
                json=payload,
                headers=headers
            )
            response.raise_for_status()
            result = response.json()

        # {} for an empty/malformed result list, so we report "no transcript"
        # rather than surfacing an IndexError as a 500.
        if result_data := _first_result(result):
            return {
                "success": True,
                "transcript": result_data.get("transcript", ""),
                "language": result_data.get("language", model.language),
                "confidence": result_data.get("confidence"),
            }
        return {"success": False, "error": "No transcript in response"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.post("/{id}/preview", response_model=ASRTestResponse)
async def preview_asr_model(
    id: str,
    file: UploadFile = File(...),
    language: Optional[str] = Form(None),
    api_key: Optional[str] = Form(None),
    db: Session = Depends(get_db),
):
    """Preview ASR: upload audio and call an OpenAI-compatible /audio/transcriptions.

    Args:
        id: Model id.
        file: Uploaded audio; its content type must start with ``audio/``.
        language: Optional language hint forwarded to the vendor.
        api_key: Optional key overriding the one stored on the model.
        db: Database session.

    Raises:
        HTTPException: 404 when the model is unknown; 400 for a missing,
            non-audio or empty file, a missing API key, or a missing base URL;
            502 when the vendor request fails or returns a non-200 status.
    """
    model = db.query(ASRModel).filter(ASRModel.id == id).first()
    if not model:
        raise HTTPException(status_code=404, detail="ASR Model not found")

    if not file:
        raise HTTPException(status_code=400, detail="Audio file is required")

    filename = file.filename or "preview.wav"
    content_type = file.content_type or "application/octet-stream"
    if not content_type.startswith("audio/"):
        raise HTTPException(status_code=400, detail="Only audio files are supported")

    audio_bytes = await file.read()
    if not audio_bytes:
        raise HTTPException(status_code=400, detail="Uploaded audio file is empty")

    # Key precedence: form field, then stored key, then (SiliconFlow only) env var.
    effective_api_key = (api_key or "").strip() or (model.api_key or "").strip()
    if not effective_api_key and _is_siliconflow_vendor(model.vendor):
        effective_api_key = os.getenv("SILICONFLOW_API_KEY", "").strip()
    if not effective_api_key:
        raise HTTPException(status_code=400, detail=f"API key is required for ASR model: {model.name}")

    base_url = _clean_base_url(model.base_url)
    if not base_url:
        raise HTTPException(status_code=400, detail=f"Base URL is required for ASR model: {model.name}")

    selected_model = (model.model_name or "").strip() or _default_asr_model(model.vendor)
    data = {"model": selected_model}
    effective_language = (language or "").strip() or None
    if effective_language:
        data["language"] = effective_language
    if model.hotwords:
        data["prompt"] = " ".join(model.hotwords)

    headers = {"Authorization": f"Bearer {effective_api_key}"}
    files = {"file": (filename, audio_bytes, content_type)}

    start_time = time.time()
    try:
        # Async client: a blocking httpx.Client here would stall the event
        # loop (and every other request) for the duration of the vendor call.
        async with httpx.AsyncClient(timeout=90.0) as client:
            response = await client.post(
                f"{base_url}/audio/transcriptions",
                headers=headers,
                data=data,
                files=files,
            )
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"ASR request failed: {exc}") from exc

    if response.status_code != 200:
        # Prefer a structured vendor message; fall back to the raw body.
        detail = response.text
        try:
            detail_json = response.json()
        except ValueError:
            detail_json = None
        if isinstance(detail_json, dict):
            error_field = detail_json.get("error")
            if isinstance(error_field, dict):
                vendor_message = error_field.get("message")
            else:
                # Some vendors send {"error": "plain message"}.
                vendor_message = error_field
            detail = vendor_message or detail_json.get("detail") or detail
        raise HTTPException(status_code=502, detail=f"ASR vendor error: {detail}")

    try:
        payload = response.json()
    except ValueError:
        # Non-JSON 200 body: treat the raw text as the transcript.
        payload = {"text": response.text}

    transcript = ""
    response_language = model.language
    confidence = None
    if isinstance(payload, dict):
        transcript = str(payload.get("text") or payload.get("transcript") or "")
        response_language = str(payload.get("language") or effective_language or model.language)
        raw_confidence = payload.get("confidence")
        if raw_confidence is not None:
            try:
                confidence = float(raw_confidence)
            except (TypeError, ValueError):
                confidence = None

    latency_ms = int((time.time() - start_time) * 1000)
    return ASRTestResponse(
        success=bool(transcript),
        transcript=transcript,
        language=response_language,
        confidence=confidence,
        latency_ms=latency_ms,
        message=None if transcript else "No transcript in response",
    )