Update asr library preview
This commit is contained in:
@@ -1,12 +1,11 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
import uuid
|
||||
import httpx
|
||||
import os
|
||||
import time
|
||||
import base64
|
||||
import json
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..db import get_db
|
||||
from ..models import ASRModel
|
||||
@@ -17,6 +16,18 @@ from ..schemas import (
|
||||
|
||||
# Router for the ASR model endpoints; every route below is served under /asr.
router = APIRouter(prefix="/asr", tags=["ASR Models"])
|
||||
|
||||
# Model name returned by _default_asr_model() for SiliconFlow vendors, i.e. the
# fallback used when an ASR model record has no explicit model_name.
SILICONFLOW_DEFAULT_ASR_MODEL = "FunAudioLLM/SenseVoiceSmall"
|
||||
|
||||
|
||||
def _is_siliconflow_vendor(vendor: str) -> bool:
|
||||
return (vendor or "").strip().lower() in {"siliconflow", "硅基流动"}
|
||||
|
||||
|
||||
def _default_asr_model(vendor: str) -> str:
    """Return the fallback transcription model name for *vendor*.

    SiliconFlow vendors get ``SILICONFLOW_DEFAULT_ASR_MODEL``; every other
    vendor falls back to ``"whisper-1"``.
    """
    return (
        SILICONFLOW_DEFAULT_ASR_MODEL
        if _is_siliconflow_vendor(vendor)
        else "whisper-1"
    )
|
||||
|
||||
|
||||
# ============ ASR Models CRUD ============
|
||||
@router.get("")
|
||||
@@ -219,3 +230,99 @@ def transcribe_audio(
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/{id}/preview", response_model=ASRTestResponse)
async def preview_asr_model(
    id: str,
    file: UploadFile = File(...),
    language: Optional[str] = Form(None),
    api_key: Optional[str] = Form(None),
    db: Session = Depends(get_db),
):
    """Preview an ASR model by transcribing an uploaded audio file.

    The audio is forwarded as multipart form data to the model's
    OpenAI-compatible ``{base_url}/audio/transcriptions`` endpoint.

    Args:
        id: Primary key of the ``ASRModel`` record to preview.
        file: Uploaded audio; its content type must start with ``audio/``.
        language: Optional language hint forwarded to the vendor.
        api_key: Optional key that overrides the one stored on the model.
        db: Database session used to look up the model.

    Returns:
        An ``ASRTestResponse`` with the transcript, language, optional
        confidence and the measured round-trip latency in milliseconds.

    Raises:
        HTTPException: 404 when the model does not exist; 400 for a missing,
            non-audio or empty upload, or a missing API key / base URL;
            502 when the vendor request fails or returns a non-200 status.
    """
    model = db.query(ASRModel).filter(ASRModel.id == id).first()
    if not model:
        raise HTTPException(status_code=404, detail="ASR Model not found")

    if not file:
        raise HTTPException(status_code=400, detail="Audio file is required")

    filename = file.filename or "preview.wav"
    content_type = file.content_type or "application/octet-stream"
    if not content_type.startswith("audio/"):
        raise HTTPException(status_code=400, detail="Only audio files are supported")

    audio_bytes = await file.read()
    if not audio_bytes:
        raise HTTPException(status_code=400, detail="Uploaded audio file is empty")

    # Key precedence: request form field, then the stored model key, then
    # (SiliconFlow only) the SILICONFLOW_API_KEY environment variable.
    effective_api_key = (api_key or "").strip() or (model.api_key or "").strip()
    if not effective_api_key and _is_siliconflow_vendor(model.vendor):
        effective_api_key = os.getenv("SILICONFLOW_API_KEY", "").strip()
    if not effective_api_key:
        raise HTTPException(status_code=400, detail=f"API key is required for ASR model: {model.name}")

    base_url = (model.base_url or "").strip().rstrip("/")
    if not base_url:
        raise HTTPException(status_code=400, detail=f"Base URL is required for ASR model: {model.name}")

    selected_model = (model.model_name or "").strip() or _default_asr_model(model.vendor)
    data = {"model": selected_model}
    effective_language = (language or "").strip() or None
    if effective_language:
        data["language"] = effective_language
    if model.hotwords:
        # Hotwords are passed to the vendor as a space-separated prompt.
        data["prompt"] = " ".join(model.hotwords)

    headers = {"Authorization": f"Bearer {effective_api_key}"}
    files = {"file": (filename, audio_bytes, content_type)}

    # perf_counter is monotonic, so the latency cannot be skewed by wall-clock
    # adjustments during the request.
    start_time = time.perf_counter()
    try:
        # Use the async client: a blocking httpx.Client call inside an
        # ``async def`` endpoint would stall the event loop for up to 90 s.
        async with httpx.AsyncClient(timeout=90.0) as client:
            response = await client.post(
                f"{base_url}/audio/transcriptions",
                headers=headers,
                data=data,
                files=files,
            )
    except Exception as exc:
        # Endpoint boundary: any transport/URL failure is reported as 502.
        raise HTTPException(status_code=502, detail=f"ASR request failed: {exc}") from exc

    if response.status_code != 200:
        # Prefer a structured vendor message; fall back to the raw body.
        detail = response.text
        try:
            detail_json = response.json()
        except ValueError:
            detail_json = None
        if isinstance(detail_json, dict):
            error = detail_json.get("error")
            # Vendors return either {"error": {"message": ...}} or a plain
            # {"error": "..."} string; accept both shapes.
            message = error.get("message") if isinstance(error, dict) else error
            detail = message or detail_json.get("detail") or detail
        raise HTTPException(status_code=502, detail=f"ASR vendor error: {detail}")

    try:
        payload = response.json()
    except ValueError:
        # Non-JSON success body: treat the raw text as the transcript.
        payload = {"text": response.text}

    transcript = ""
    response_language = model.language
    confidence = None
    if isinstance(payload, dict):
        transcript = str(payload.get("text") or payload.get("transcript") or "")
        detected_language = payload.get("language") or effective_language or model.language
        # Avoid str(None) turning a missing language into the literal "None".
        response_language = str(detected_language) if detected_language is not None else None
        raw_confidence = payload.get("confidence")
        if raw_confidence is not None:
            try:
                confidence = float(raw_confidence)
            except (TypeError, ValueError):
                confidence = None

    latency_ms = int((time.perf_counter() - start_time) * 1000)
    return ASRTestResponse(
        success=bool(transcript),
        transcript=transcript,
        language=response_language,
        confidence=confidence,
        latency_ms=latency_ms,
        message=None if transcript else "No transcript in response",
    )
|
||||
|
||||
Reference in New Issue
Block a user