from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from typing import List, Optional
import uuid
import httpx
import time
import base64
import json
from datetime import datetime

from ..db import get_db
from ..models import ASRModel
from ..schemas import (
    ASRModelCreate,
    ASRModelUpdate,
    ASRModelOut,
    ASRTestRequest,
    ASRTestResponse,
    ListResponse,
)

router = APIRouter(prefix="/asr", tags=["ASR Models"])

# Vendors (lower-cased) that share the Paraformer-style JSON request format.
_PARAFORMER_VENDORS = ("siliconflow", "paraformer")


# ============ Helpers ============


def _get_model_or_404(db: Session, model_id: str) -> ASRModel:
    """Return the ASR model with ``model_id``, or raise a 404 HTTPException."""
    model = db.query(ASRModel).filter(ASRModel.id == model_id).first()
    if not model:
        raise HTTPException(status_code=404, detail="ASR Model not found")
    return model


def _endpoint(model: ASRModel, path: str) -> str:
    """Join the model's ``base_url`` with ``path``, tolerating a trailing slash."""
    return f"{(model.base_url or '').rstrip('/')}{path}"


def _auth_headers(model: ASRModel) -> dict:
    """Bearer-token headers for calls to the model's own API (never third-party URLs)."""
    return {"Authorization": f"Bearer {model.api_key}"}


def _build_paraformer_payload(model: ASRModel, hotwords: Optional[List[str]] = None) -> dict:
    """Build a Paraformer-style request body with an empty ``input`` section.

    ``hotwords``, when non-empty, overrides the hotwords stored on the model.
    The caller fills in ``payload["input"]``.
    """
    return {
        "model": model.model_name or "paraformer-v2",
        "input": {},
        "parameters": {
            "hotwords": " ".join(hotwords or model.hotwords or []),
            "enable_punctuation": model.enable_punctuation,
            "enable_normalization": model.enable_normalization,
        },
    }


def _post_paraformer(model: ASRModel, payload: dict, timeout: float):
    """POST ``payload`` to ``{base_url}/asr`` and return the decoded JSON body.

    Raises:
        httpx.HTTPStatusError: the upstream service answered with a 4xx/5xx.
        httpx.HTTPError: transport-level failures.
    """
    with httpx.Client(timeout=timeout) as client:
        response = client.post(
            _endpoint(model, "/asr"), json=payload, headers=_auth_headers(model)
        )
        response.raise_for_status()
        return response.json()


def _first_result(result) -> dict:
    """Return the first entry of ``result["results"]``, or ``{}`` if there is none.

    Unlike indexing ``result.get("results", [{}])[0]`` directly, this does not
    raise when upstream returns an empty ``results`` list or an unexpected shape.
    """
    results = result.get("results") if isinstance(result, dict) else None
    if not isinstance(results, list) or not results:
        return {}
    first = results[0]
    return first if isinstance(first, dict) else {}


# ============ ASR Models CRUD ============


@router.get("", response_model=ListResponse)
def list_asr_models(
    language: Optional[str] = None,
    enabled: Optional[bool] = None,
    page: int = 1,
    limit: int = 50,
    db: Session = Depends(get_db),
):
    """List ASR models, newest first, optionally filtered by language and enabled flag.

    ``page`` is 1-based; ``limit`` is the page size. Both must be >= 1.
    """
    # Reject non-positive values up front: they would otherwise yield a
    # negative OFFSET/LIMIT in the generated SQL.
    if page < 1 or limit < 1:
        raise HTTPException(status_code=400, detail="page and limit must be >= 1")

    query = db.query(ASRModel)
    if language:
        query = query.filter(ASRModel.language == language)
    if enabled is not None:
        query = query.filter(ASRModel.enabled == enabled)

    total = query.count()
    models = (
        query.order_by(ASRModel.created_at.desc())
        .offset((page - 1) * limit)
        .limit(limit)
        .all()
    )
    return {"total": total, "page": page, "limit": limit, "list": models}


@router.get("/{id}", response_model=ASRModelOut)
def get_asr_model(id: str, db: Session = Depends(get_db)):
    """Return a single ASR model by id (404 if it does not exist)."""
    return _get_model_or_404(db, id)


@router.post("", response_model=ASRModelOut)
def create_asr_model(data: ASRModelCreate, db: Session = Depends(get_db)):
    """Create an ASR model from the request body and return the stored row.

    When ``data.id`` is empty, an 8-character id taken from a random UUID is used.
    """
    asr_model = ASRModel(
        id=data.id or str(uuid.uuid4())[:8],
        user_id=1,  # NOTE(review): hard-coded default user; presumably a placeholder until auth exists
        name=data.name,
        vendor=data.vendor,
        language=data.language,
        base_url=data.base_url,
        api_key=data.api_key,
        model_name=data.model_name,
        hotwords=data.hotwords,
        enable_punctuation=data.enable_punctuation,
        enable_normalization=data.enable_normalization,
        enabled=data.enabled,
    )
    db.add(asr_model)
    db.commit()
    db.refresh(asr_model)
    return asr_model


@router.put("/{id}", response_model=ASRModelOut)
def update_asr_model(id: str, data: ASRModelUpdate, db: Session = Depends(get_db)):
    """Apply a partial update: only fields explicitly present in the body are written."""
    model = _get_model_or_404(db, id)
    for field, value in data.model_dump(exclude_unset=True).items():
        setattr(model, field, value)
    db.commit()
    db.refresh(model)
    return model


@router.delete("/{id}")
def delete_asr_model(id: str, db: Session = Depends(get_db)):
    """Delete an ASR model by id (404 if it does not exist)."""
    model = _get_model_or_404(db, id)
    db.delete(model)
    db.commit()
    return {"message": "Deleted successfully"}


@router.post("/{id}/test", response_model=ASRTestResponse)
def test_asr_model(
    id: str,
    request: Optional[ASRTestRequest] = None,
    db: Session = Depends(get_db),
):
    """Send a test recognition request to an ASR model and report the outcome.

    The request format is chosen by ``model.vendor``: Paraformer-style JSON for
    "siliconflow"/"paraformer", multipart Whisper upload for "openai". Upstream
    and decoding failures are reported as ``success=False`` responses with an
    ``error`` string rather than raised. Only a missing model raises (404).
    """
    model = _get_model_or_404(db, id)
    vendor = (model.vendor or "").lower()

    start_time = time.perf_counter()
    try:
        if vendor in _PARAFORMER_VENDORS:
            payload = _build_paraformer_payload(model)
            if request and request.audio_data:
                # NOTE(review): inline base64 audio is NOT transmitted here;
                # ``file_urls`` is sent empty, so upstream never sees the audio.
                # Confirm the vendor's upload contract before wiring it in.
                payload["input"]["file_urls"] = []
            elif request and request.audio_url:
                payload["input"]["url"] = request.audio_url
            result = _post_paraformer(model, payload, timeout=60.0)

        elif vendor == "openai":
            # OpenAI Whisper: multipart upload of the raw audio bytes.
            if not (request and (request.audio_data or request.audio_url)):
                return ASRTestResponse(
                    success=False, error="No audio data or URL provided"
                )
            with httpx.Client(timeout=60.0) as client:
                if request.audio_data:
                    audio_bytes = base64.b64decode(request.audio_data)
                else:
                    # Download through the managed client (same timeout) and
                    # check the status so an error page is never uploaded as
                    # audio. Auth headers are deliberately not sent here.
                    download = client.get(request.audio_url)
                    download.raise_for_status()
                    audio_bytes = download.content
                response = client.post(
                    _endpoint(model, "/audio/transcriptions"),
                    files={"file": ("audio.wav", audio_bytes, "audio/wav")},
                    data={"model": model.model_name or "whisper-1"},
                    headers=_auth_headers(model),
                )
                response.raise_for_status()
                text = response.json().get("text", "")
            # Normalize to the Paraformer-style shape parsed below.
            result = {"results": [{"transcript": text}]}

        else:
            # Other vendors are not supported yet; extend here as needed.
            return ASRTestResponse(
                success=False, error=f"Unsupported vendor: {model.vendor}"
            )

        latency_ms = int((time.perf_counter() - start_time) * 1000)

        result_data = _first_result(result)
        if result_data:
            return ASRTestResponse(
                success=True,
                transcript=result_data.get("transcript", ""),
                language=result_data.get("language", model.language),
                confidence=result_data.get("confidence"),
                latency_ms=latency_ms,
            )
        return ASRTestResponse(
            success=False, error="No transcript in response", latency_ms=latency_ms
        )

    except httpx.HTTPStatusError as e:
        return ASRTestResponse(
            success=False,
            error=f"HTTP Error: {e.response.status_code} - {e.response.text[:200]}",
        )
    except Exception as e:
        # Test endpoint is best-effort: surface any failure as a result, truncated.
        return ASRTestResponse(success=False, error=str(e)[:200])


@router.post("/{id}/transcribe")
def transcribe_audio(
    id: str,
    audio_url: Optional[str] = None,
    audio_data: Optional[str] = None,
    # NOTE(review): FastAPI binds a bare ``List[...]`` parameter from the
    # request body, not the query string; wrap in ``Query()`` if query
    # binding was intended (API change, so left as-is).
    hotwords: Optional[List[str]] = None,
    db: Session = Depends(get_db),
):
    """Transcribe audio with an ASR model using the Paraformer-style request format.

    ``audio_url`` takes precedence over ``audio_data``. ``hotwords``, when
    non-empty, overrides the model's stored hotwords.

    Returns a dict with ``success`` plus ``transcript``/``language``/``confidence``
    on success, or ``error`` when upstream returned no result.

    Raises:
        HTTPException: 404 if the model does not exist, 400 if no audio source
            is given, 500 if the upstream call or decoding fails.
    """
    model = _get_model_or_404(db, id)
    if not (audio_url or audio_data):
        raise HTTPException(status_code=400, detail="No audio data or URL provided")

    try:
        # NOTE(review): the request format is Paraformer-style regardless of
        # ``model.vendor``; OpenAI-type models are not handled here.
        payload = _build_paraformer_payload(model, hotwords)
        if audio_url:
            payload["input"]["url"] = audio_url
        else:
            # NOTE(review): inline audio is not transmitted (``file_urls`` is
            # empty); confirm the vendor's upload contract.
            payload["input"]["file_urls"] = []

        result = _post_paraformer(model, payload, timeout=120.0)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    result_data = _first_result(result)
    if result_data:
        return {
            "success": True,
            "transcript": result_data.get("transcript", ""),
            "language": result_data.get("language", model.language),
            "confidence": result_data.get("confidence"),
        }
    return {"success": False, "error": "No transcript in response"}