"""HTTP endpoints for managing ASR model configurations.

Provides CRUD over ``ASRModel`` rows plus two vendor-facing helpers: a
connectivity test and a transcription proxy.
"""

import base64
import json
import time
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional

import httpx
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session

from ..db import get_db
from ..models import ASRModel
from ..schemas import (
    ASRModelCreate,
    ASRModelUpdate,
    ASRModelOut,
    ASRTestRequest,
    ASRTestResponse,
)

router = APIRouter(prefix="/asr", tags=["ASR Models"])


# ============ Internal helpers ============

def _get_model_or_404(db: Session, model_id: str) -> ASRModel:
    """Return the ASR model with ``model_id`` or raise a 404 HTTPException."""
    model = db.query(ASRModel).filter(ASRModel.id == model_id).first()
    if not model:
        raise HTTPException(status_code=404, detail="ASR Model not found")
    return model


def _auth_headers(model: ASRModel) -> Dict[str, str]:
    """Build the bearer-auth header; empty when the model has no API key."""
    if not model.api_key:
        # Avoid sending the literal string "Bearer None".
        return {}
    return {"Authorization": f"Bearer {model.api_key}"}


def _endpoint(model: ASRModel, path: str) -> str:
    """Join the model's base URL and ``path`` without doubling the slash."""
    return f"{(model.base_url or '').rstrip('/')}/{path}"


def _first_result(payload: Any) -> Dict[str, Any]:
    """Return the first entry of ``payload["results"]`` or ``{}``.

    Tolerates a non-dict payload, a missing or empty ``results`` list, and a
    non-dict first entry, so callers never hit IndexError/AttributeError.
    """
    if not isinstance(payload, dict):
        return {}
    results = payload.get("results")
    if isinstance(results, list) and results and isinstance(results[0], dict):
        return results[0]
    return {}


# ============ ASR Models CRUD ============

@router.get("")
def list_asr_models(
    language: Optional[str] = None,
    enabled: Optional[bool] = None,
    page: int = 1,
    limit: int = 50,
    db: Session = Depends(get_db),
):
    """List ASR models, newest first, with optional filters and pagination.

    Args:
        language: When given, only models with this language are returned.
        enabled: When given, only models with this enabled flag are returned.
        page: 1-based page number; values below 1 are treated as 1.
        limit: Page size; negative values are treated as 0.
        db: Database session.

    Returns:
        A dict with ``total`` (count before pagination), the effective
        ``page`` and ``limit``, and ``list`` (the models on this page).
    """
    # Clamp so that (page - 1) * limit can never become a negative OFFSET.
    page = max(page, 1)
    limit = max(limit, 0)

    query = db.query(ASRModel)
    if language:
        query = query.filter(ASRModel.language == language)
    if enabled is not None:
        query = query.filter(ASRModel.enabled == enabled)

    total = query.count()
    models = (
        query.order_by(ASRModel.created_at.desc())
        .offset((page - 1) * limit)
        .limit(limit)
        .all()
    )
    return {"total": total, "page": page, "limit": limit, "list": models}


@router.get("/{id}", response_model=ASRModelOut)
def get_asr_model(id: str, db: Session = Depends(get_db)):
    """Return a single ASR model by id.

    Raises:
        HTTPException: 404 when no model has this id.
    """
    return _get_model_or_404(db, id)


@router.post("", response_model=ASRModelOut)
def create_asr_model(data: ASRModelCreate, db: Session = Depends(get_db)):
    """Create an ASR model and return the persisted row.

    When ``data.id`` is empty, the first 8 characters of a random UUID4 are
    used as the id.
    """
    asr_model = ASRModel(
        id=data.id or str(uuid.uuid4())[:8],
        user_id=1,  # default user
        name=data.name,
        vendor=data.vendor,
        language=data.language,
        base_url=data.base_url,
        api_key=data.api_key,
        model_name=data.model_name,
        hotwords=data.hotwords,
        enable_punctuation=data.enable_punctuation,
        enable_normalization=data.enable_normalization,
        enabled=data.enabled,
    )
    db.add(asr_model)
    db.commit()
    db.refresh(asr_model)
    return asr_model


@router.put("/{id}", response_model=ASRModelOut)
def update_asr_model(id: str, data: ASRModelUpdate, db: Session = Depends(get_db)):
    """Apply a partial update to an ASR model and return the updated row.

    Only fields explicitly set on ``data`` are written.

    Raises:
        HTTPException: 404 when no model has this id.
    """
    model = _get_model_or_404(db, id)
    for field, value in data.model_dump(exclude_unset=True).items():
        setattr(model, field, value)
    db.commit()
    db.refresh(model)
    return model


@router.delete("/{id}")
def delete_asr_model(id: str, db: Session = Depends(get_db)):
    """Delete an ASR model.

    Raises:
        HTTPException: 404 when no model has this id.
    """
    model = _get_model_or_404(db, id)
    db.delete(model)
    db.commit()
    return {"message": "Deleted successfully"}


@router.post("/{id}/test", response_model=ASRTestResponse)
def test_asr_model(
    id: str,
    request: Optional[ASRTestRequest] = None,
    db: Session = Depends(get_db),
):
    """Run a connectivity test against the model's vendor endpoint.

    Issues a GET to a vendor-specific path so that no real audio input is
    needed. ``request`` is accepted for interface compatibility but is not
    read by this function.

    Returns:
        An ``ASRTestResponse``; request failures are reported in the body
        (``success=False`` with ``error``) rather than raised.

    Raises:
        HTTPException: 404 when no model has this id.
    """
    model = _get_model_or_404(db, id)
    started = time.perf_counter()  # monotonic, unlike time.time()

    def elapsed_ms() -> int:
        return int((time.perf_counter() - started) * 1000)

    try:
        vendor = (model.vendor or "").lower()
        if vendor in ("siliconflow", "paraformer"):
            path = "asr"
        elif vendor == "openai":
            path = "audio/models"
        else:
            path = "health"

        with httpx.Client(timeout=60.0) as client:
            response = client.get(
                _endpoint(model, path), headers=_auth_headers(model)
            )
            response.raise_for_status()

        try:
            raw_result = response.json()
        except ValueError:
            # A 2xx reply with a non-JSON body (e.g. a plain "OK" health
            # check) still proves connectivity.
            raw_result = None

        # Normalize the different vendor response shapes.
        if isinstance(raw_result, dict) and "results" in raw_result:
            result_data = _first_result(raw_result)
        elif isinstance(raw_result, dict) and "text" in raw_result:
            result_data = {"transcript": raw_result.get("text", "")}
        else:
            result_data = {"transcript": ""}

        latency_ms = elapsed_ms()
        if result_data:
            return ASRTestResponse(
                success=True,
                transcript=result_data.get("transcript", ""),
                language=result_data.get("language", model.language),
                confidence=result_data.get("confidence"),
                latency_ms=latency_ms,
            )
        return ASRTestResponse(
            success=False,
            message="No transcript in response",
            latency_ms=latency_ms,
        )
    except httpx.HTTPStatusError as e:
        return ASRTestResponse(
            success=False,
            error=f"HTTP Error: {e.response.status_code} - {e.response.text[:200]}",
            latency_ms=elapsed_ms(),
        )
    except Exception as e:
        # Deliberately broad: a test endpoint reports any failure in-band.
        return ASRTestResponse(
            success=False,
            error=str(e)[:200],
            latency_ms=elapsed_ms(),
        )


@router.post("/{id}/transcribe")
def transcribe_audio(
    id: str,
    audio_url: Optional[str] = None,
    audio_data: Optional[str] = None,
    hotwords: Optional[List[str]] = None,
    db: Session = Depends(get_db),
):
    """Transcribe audio by POSTing to the model's ``/asr`` endpoint.

    Args:
        id: ASR model id.
        audio_url: URL of the audio to transcribe; takes precedence over
            ``audio_data``.
        audio_data: Inline audio payload (see the review note in the body).
        hotwords: Overrides the model's configured hotwords when non-empty.
        db: Database session.

    Returns:
        ``{"success": True, "transcript", "language", "confidence"}`` on
        success, or ``{"success": False, "error": ...}`` when the response
        carries no result entry.

    Raises:
        HTTPException: 404 when no model has this id; 400 when neither
            ``audio_url`` nor ``audio_data`` is given; 500 when building or
            sending the request fails.
    """
    model = _get_model_or_404(db, id)

    # Validate outside the try block: HTTPException is an Exception, so
    # raising it inside would be rewrapped as a 500 by the handler below.
    if not audio_url and not audio_data:
        raise HTTPException(
            status_code=400,
            detail="Either audio_url or audio_data is required",
        )

    try:
        payload = {
            "model": model.model_name or "paraformer-v2",
            "input": {},
            "parameters": {
                "hotwords": " ".join(hotwords or model.hotwords or []),
                "enable_punctuation": model.enable_punctuation,
                "enable_normalization": model.enable_normalization,
            },
        }
        if audio_url:
            payload["input"]["url"] = audio_url
        else:
            # NOTE(review): audio_data is not forwarded; an empty file_urls
            # list is sent instead. Behaviour kept as-is; confirm the upstream
            # format for inline audio and wire it in.
            payload["input"]["file_urls"] = []

        with httpx.Client(timeout=120.0) as client:
            response = client.post(
                _endpoint(model, "asr"),
                json=payload,
                headers=_auth_headers(model),
            )
            response.raise_for_status()
            result = response.json()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    result_data = _first_result(result)
    if result_data:
        return {
            "success": True,
            "transcript": result_data.get("transcript", ""),
            "language": result_data.get("language", model.language),
            "confidence": result_data.get("confidence"),
        }
    return {"success": False, "error": "No transcript in response"}