335 lines
11 KiB
Python
335 lines
11 KiB
Python
import os
|
||
import time
|
||
from typing import List, Optional
|
||
|
||
import httpx
|
||
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
|
||
from sqlalchemy.orm import Session
|
||
|
||
from ..db import get_db
|
||
from ..id_generator import unique_short_id
|
||
from ..models import ASRModel
|
||
from ..schemas import (
|
||
ASRModelCreate, ASRModelUpdate, ASRModelOut,
|
||
ASRTestRequest, ASRTestResponse
|
||
)
|
||
|
||
# Router for the ASR model endpoints; every route declared on it is served
# under the "/asr" prefix and tagged "ASR Models".
router = APIRouter(prefix="/asr", tags=["ASR Models"])
|
||
|
||
# Model name returned by _default_asr_model() for OpenAI-compatible vendors,
# i.e. the fallback when a model record has no explicit model_name.
OPENAI_COMPATIBLE_DEFAULT_ASR_MODEL = "FunAudioLLM/SenseVoiceSmall"
|
||
|
||
|
||
def _is_openai_compatible_vendor(vendor: str) -> bool:
|
||
normalized = (vendor or "").strip().lower()
|
||
return normalized in {
|
||
"openai compatible",
|
||
"openai-compatible",
|
||
"siliconflow", # backward compatibility
|
||
"硅基流动", # backward compatibility
|
||
}
|
||
|
||
|
||
def _default_asr_model(vendor: str) -> str:
    """Pick the fallback model name for *vendor*.

    OpenAI-compatible vendors get ``OPENAI_COMPATIBLE_DEFAULT_ASR_MODEL``;
    every other vendor gets ``"whisper-1"``.
    """
    return (
        OPENAI_COMPATIBLE_DEFAULT_ASR_MODEL
        if _is_openai_compatible_vendor(vendor)
        else "whisper-1"
    )
|
||
|
||
|
||
# ============ ASR Models CRUD ============
|
||
@router.get("")
def list_asr_models(
    language: Optional[str] = None,
    enabled: Optional[bool] = None,
    page: int = 1,
    limit: int = 50,
    db: Session = Depends(get_db),
):
    """List ASR models, newest first, with optional filtering and paging.

    Args:
        language: When non-empty, keep only models whose ``language`` equals it.
        enabled: When not ``None``, keep only models whose ``enabled`` matches.
        page: 1-based page number; values below 1 are treated as 1.
        limit: Page size; negative values are treated as 0.
        db: SQLAlchemy session injected by FastAPI.

    Returns:
        A dict with ``total`` (row count after filtering), the effective
        ``page`` and ``limit``, and ``list`` (the models on that page).
    """
    # A non-positive page or a negative limit would turn into a negative
    # OFFSET/LIMIT, which depending on the database backend is either
    # rejected or silently changes meaning. Clamp before building the query.
    page = max(page, 1)
    limit = max(limit, 0)

    query = db.query(ASRModel)
    if language:
        query = query.filter(ASRModel.language == language)
    if enabled is not None:
        query = query.filter(ASRModel.enabled == enabled)

    # Count before paging so "total" reflects the whole filtered set.
    total = query.count()
    models = (
        query.order_by(ASRModel.created_at.desc())
        .offset((page - 1) * limit)
        .limit(limit)
        .all()
    )

    return {"total": total, "page": page, "limit": limit, "list": models}
|
||
|
||
|
||
@router.get("/{id}", response_model=ASRModelOut)
def get_asr_model(id: str, db: Session = Depends(get_db)):
    """Return the ASR model with the given id.

    Raises:
        HTTPException: 404 when no model has that id.
    """
    record = db.query(ASRModel).filter(ASRModel.id == id).first()
    if record:
        return record
    raise HTTPException(status_code=404, detail="ASR Model not found")
|
||
|
||
|
||
@router.post("", response_model=ASRModelOut)
def create_asr_model(data: ASRModelCreate, db: Session = Depends(get_db)):
    """Create an ASR model from the request payload and persist it.

    The id is generated with ``unique_short_id("asr", ...)`` and the record
    is always attributed to user id 1.

    Returns:
        The stored model, refreshed from the database after commit.
    """
    generated_id = unique_short_id("asr", db, ASRModel)
    column_values = {
        "id": generated_id,
        "user_id": 1,  # default user
        "name": data.name,
        "vendor": data.vendor,
        "language": data.language,
        "base_url": data.base_url,
        "api_key": data.api_key,
        "model_name": data.model_name,
        "hotwords": data.hotwords,
        "enable_punctuation": data.enable_punctuation,
        "enable_normalization": data.enable_normalization,
        "enabled": data.enabled,
    }
    created = ASRModel(**column_values)

    db.add(created)
    db.commit()
    db.refresh(created)
    return created
|
||
|
||
|
||
@router.put("/{id}", response_model=ASRModelOut)
def update_asr_model(id: str, data: ASRModelUpdate, db: Session = Depends(get_db)):
    """Apply a partial update to an existing ASR model.

    Only fields explicitly set in the request payload are written.

    Raises:
        HTTPException: 404 when no model has the given id.
    """
    target = db.query(ASRModel).filter(ASRModel.id == id).first()
    if target is None:
        raise HTTPException(status_code=404, detail="ASR Model not found")

    # exclude_unset keeps fields the client did not send out of the update.
    changes = data.model_dump(exclude_unset=True)
    for field_name in changes:
        setattr(target, field_name, changes[field_name])

    db.commit()
    db.refresh(target)
    return target
|
||
|
||
|
||
@router.delete("/{id}")
def delete_asr_model(id: str, db: Session = Depends(get_db)):
    """Delete the ASR model with the given id.

    Raises:
        HTTPException: 404 when no model has that id.
    """
    doomed = db.query(ASRModel).filter(ASRModel.id == id).first()
    if doomed is None:
        raise HTTPException(status_code=404, detail="ASR Model not found")

    db.delete(doomed)
    db.commit()
    confirmation = {"message": "Deleted successfully"}
    return confirmation
|
||
|
||
|
||
@router.post("/{id}/test", response_model=ASRTestResponse)
def test_asr_model(
    id: str,
    request: Optional[ASRTestRequest] = None,
    db: Session = Depends(get_db),
):
    """Run a connectivity test against the vendor endpoint of an ASR model.

    A GET request is sent to a vendor-specific path under the model's base
    URL (``/asr`` for OpenAI-compatible vendors and "paraformer",
    ``/audio/models`` for "openai", ``/health`` otherwise). Any 2xx answer
    counts as reachable, whether or not the body is JSON.

    Args:
        id: Id of the ASR model to test.
        request: Optional test payload; currently not used by this handler.
        db: SQLAlchemy session injected by FastAPI.

    Returns:
        An ``ASRTestResponse``. Vendor and transport failures are reported
        through ``success=False`` and ``error`` rather than raised.

    Raises:
        HTTPException: 404 when no model has the given id.
    """
    model = db.query(ASRModel).filter(ASRModel.id == id).first()
    if not model:
        raise HTTPException(status_code=404, detail="ASR Model not found")

    start_time = time.time()

    def elapsed_ms() -> int:
        # Wall-clock time since the test started, in whole milliseconds.
        return int((time.time() - start_time) * 1000)

    # Normalise once: a missing vendor must not crash on .lower(), and a
    # trailing slash on the base URL must not yield "//" in the probe URL.
    vendor = (model.vendor or "").strip().lower()
    base_url = (model.base_url or "").strip().rstrip("/")
    if not base_url:
        return ASRTestResponse(
            success=False,
            error="Base URL is not configured",
            latency_ms=elapsed_ms(),
        )

    if _is_openai_compatible_vendor(vendor) or vendor == "paraformer":
        probe_path = "/asr"
    elif vendor == "openai":
        probe_path = "/audio/models"
    else:
        probe_path = "/health"

    try:
        # Connectivity check first, so no real audio input is required.
        headers = {"Authorization": f"Bearer {model.api_key}"}
        with httpx.Client(timeout=60.0) as client:
            response = client.get(f"{base_url}{probe_path}", headers=headers)
            response.raise_for_status()
            try:
                raw_result = response.json()
            except ValueError:
                # A 2xx answer with a non-JSON body (e.g. plain "OK") still
                # proves the endpoint is reachable.
                raw_result = None

        # Bring the different vendor formats into one "results" list.
        if isinstance(raw_result, dict) and "results" in raw_result:
            results = raw_result["results"]
        elif isinstance(raw_result, dict) and "text" in raw_result:
            results = [{"transcript": raw_result.get("text", "")}]
        else:
            results = [{"transcript": ""}]

        latency_ms = elapsed_ms()

        # Guard against an empty or malformed list instead of letting [0]
        # raise IndexError.
        first = results[0] if isinstance(results, list) and results else None
        if isinstance(first, dict) and first:
            return ASRTestResponse(
                success=True,
                transcript=first.get("transcript", ""),
                language=first.get("language", model.language),
                confidence=first.get("confidence"),
                latency_ms=latency_ms,
            )

        return ASRTestResponse(
            success=False,
            message="No transcript in response",
            latency_ms=latency_ms,
        )

    except httpx.HTTPStatusError as e:
        return ASRTestResponse(
            success=False,
            error=f"HTTP Error: {e.response.status_code} - {e.response.text[:200]}",
            latency_ms=elapsed_ms(),
        )
    except Exception as e:
        # Deliberate catch-all: a test endpoint reports failures in the
        # response body instead of surfacing a 500.
        return ASRTestResponse(
            success=False,
            error=str(e)[:200],
            latency_ms=elapsed_ms(),
        )
|
||
|
||
|
||
@router.post("/{id}/transcribe")
def transcribe_audio(
    id: str,
    audio_url: Optional[str] = None,
    audio_data: Optional[str] = None,
    hotwords: Optional[List[str]] = None,
    db: Session = Depends(get_db),
):
    """Transcribe audio by POSTing to the model's ``/asr`` endpoint.

    Args:
        id: Id of the ASR model to use.
        audio_url: URL of the audio; takes precedence over ``audio_data``.
        audio_data: Alternative audio input (see the review note below).
        hotwords: Per-request hotwords; falls back to the model's hotwords.
        db: SQLAlchemy session injected by FastAPI.

    Returns:
        ``{"success": True, "transcript", "language", "confidence"}`` when the
        vendor returned a first result, otherwise
        ``{"success": False, "error": "No transcript in response"}``.

    Raises:
        HTTPException: 404 when the model does not exist, 400 when neither
            ``audio_url`` nor ``audio_data`` is given, 500 when building or
            sending the vendor request fails.
    """
    model = db.query(ASRModel).filter(ASRModel.id == id).first()
    if not model:
        raise HTTPException(status_code=404, detail="ASR Model not found")

    # Without any audio input the vendor call cannot succeed; fail fast with
    # a client error instead of forwarding an empty "input" object.
    if not audio_url and not audio_data:
        raise HTTPException(
            status_code=400,
            detail="Either audio_url or audio_data is required",
        )

    try:
        audio_input = {}
        if audio_url:
            audio_input["url"] = audio_url
        else:
            # NOTE(review): the value of audio_data is never forwarded; only
            # an empty "file_urls" list is sent. Confirm the intended vendor
            # payload for inline audio.
            audio_input["file_urls"] = []

        payload = {
            "model": model.model_name or "paraformer-v2",
            "input": audio_input,
            "parameters": {
                "hotwords": " ".join(hotwords or model.hotwords or []),
                "enable_punctuation": model.enable_punctuation,
                "enable_normalization": model.enable_normalization,
            },
        }
        headers = {"Authorization": f"Bearer {model.api_key}"}
        # Strip a trailing slash so the request URL never contains "//asr".
        base_url = (model.base_url or "").strip().rstrip("/")

        with httpx.Client(timeout=120.0) as client:
            response = client.post(
                f"{base_url}/asr",
                json=payload,
                headers=headers,
            )
            response.raise_for_status()
            result = response.json()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Guard against a missing, empty, or malformed "results" list instead of
    # letting [0] raise IndexError.
    results = result.get("results") if isinstance(result, dict) else None
    first = results[0] if isinstance(results, list) and results else None
    if isinstance(first, dict) and first:
        return {
            "success": True,
            "transcript": first.get("transcript", ""),
            "language": first.get("language", model.language),
            "confidence": first.get("confidence"),
        }

    return {"success": False, "error": "No transcript in response"}
|
||
|
||
|
||
@router.post("/{id}/preview", response_model=ASRTestResponse)
async def preview_asr_model(
    id: str,
    file: UploadFile = File(...),
    language: Optional[str] = Form(None),
    api_key: Optional[str] = Form(None),
    db: Session = Depends(get_db),
):
    """Preview ASR: upload audio and call OpenAI-compatible /audio/transcriptions.

    Args:
        id: Id of the ASR model to use.
        file: Uploaded audio; its content type must start with ``audio/``.
        language: Optional language hint forwarded to the vendor.
        api_key: Optional key overriding the one stored on the model. For
            OpenAI-compatible vendors the ``SILICONFLOW_API_KEY`` environment
            variable is the last fallback.
        db: SQLAlchemy session injected by FastAPI.

    Returns:
        An ``ASRTestResponse``; ``success`` is true only when the vendor
        returned a non-empty transcript.

    Raises:
        HTTPException: 404 for an unknown model; 400 for a missing, non-audio
            or empty file, a missing API key, or a missing base URL; 502 when
            the vendor request fails or answers with a non-200 status.
    """
    model = db.query(ASRModel).filter(ASRModel.id == id).first()
    if not model:
        raise HTTPException(status_code=404, detail="ASR Model not found")

    if not file:
        raise HTTPException(status_code=400, detail="Audio file is required")

    filename = file.filename or "preview.wav"
    content_type = file.content_type or "application/octet-stream"
    if not content_type.startswith("audio/"):
        raise HTTPException(status_code=400, detail="Only audio files are supported")

    audio_bytes = await file.read()
    if not audio_bytes:
        raise HTTPException(status_code=400, detail="Uploaded audio file is empty")

    # Key precedence: form field, then stored key, then environment fallback.
    effective_api_key = (api_key or "").strip() or (model.api_key or "").strip()
    if not effective_api_key and _is_openai_compatible_vendor(model.vendor):
        effective_api_key = os.getenv("SILICONFLOW_API_KEY", "").strip()
    if not effective_api_key:
        raise HTTPException(status_code=400, detail=f"API key is required for ASR model: {model.name}")

    base_url = (model.base_url or "").strip().rstrip("/")
    if not base_url:
        raise HTTPException(status_code=400, detail=f"Base URL is required for ASR model: {model.name}")

    selected_model = (model.model_name or "").strip() or _default_asr_model(model.vendor)
    data = {"model": selected_model}
    effective_language = (language or "").strip() or None
    if effective_language:
        data["language"] = effective_language
    if model.hotwords:
        # Hotwords are passed to the vendor as the transcription prompt.
        data["prompt"] = " ".join(model.hotwords)

    headers = {"Authorization": f"Bearer {effective_api_key}"}
    files = {"file": (filename, audio_bytes, content_type)}

    start_time = time.time()
    try:
        # Use the async client: a blocking httpx.Client inside this coroutine
        # would stall the event loop, and every other request with it, for
        # as long as the vendor takes to answer.
        async with httpx.AsyncClient(timeout=90.0) as client:
            response = await client.post(
                f"{base_url}/audio/transcriptions",
                headers=headers,
                data=data,
                files=files,
            )
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"ASR request failed: {exc}") from exc

    if response.status_code != 200:
        # Prefer a structured vendor message; fall back to the raw body.
        detail = response.text
        try:
            detail_json = response.json()
        except ValueError:
            detail_json = None
        if isinstance(detail_json, dict):
            vendor_error = detail_json.get("error")
            if isinstance(vendor_error, dict):
                vendor_error = vendor_error.get("message")
            detail = vendor_error or detail_json.get("detail") or detail
        raise HTTPException(status_code=502, detail=f"ASR vendor error: {detail}")

    try:
        payload = response.json()
    except ValueError:
        # Some vendors answer with plain text; treat the body as transcript.
        payload = {"text": response.text}

    transcript = ""
    response_language = model.language
    confidence = None
    if isinstance(payload, dict):
        transcript = str(payload.get("text") or payload.get("transcript") or "")
        detected = payload.get("language") or effective_language or model.language
        # Avoid str(None), which would report the literal language "None".
        response_language = str(detected) if detected is not None else None
        raw_confidence = payload.get("confidence")
        if raw_confidence is not None:
            try:
                confidence = float(raw_confidence)
            except (TypeError, ValueError):
                confidence = None

    latency_ms = int((time.time() - start_time) * 1000)
    return ASRTestResponse(
        success=bool(transcript),
        transcript=transcript,
        language=response_language,
        confidence=confidence,
        latency_ms=latency_ms,
        message=None if transcript else "No transcript in response",
    )
|