Files
AI-VideoAssistant/api/app/routers/asr.py
2026-02-08 16:10:40 +08:00

222 lines
6.9 KiB
Python

from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from typing import List, Optional
import uuid
import httpx
import time
import base64
import json
from datetime import datetime
from ..db import get_db
from ..models import ASRModel
from ..schemas import (
ASRModelCreate, ASRModelUpdate, ASRModelOut,
ASRTestRequest, ASRTestResponse
)
# Router collecting the ASR model endpoints defined in this module; every
# route registered on it is served under the "/asr" prefix and tagged
# "ASR Models".
router = APIRouter(prefix="/asr", tags=["ASR Models"])
# ============ ASR Models CRUD ============
@router.get("")
def list_asr_models(
    language: Optional[str] = None,
    enabled: Optional[bool] = None,
    page: int = 1,
    limit: int = 50,
    db: Session = Depends(get_db)
):
    """Return one page of ASR models, newest first.

    Args:
        language: When non-empty, keep only models whose language equals it.
        enabled: When not None, keep only models whose ``enabled`` flag
            equals it.
        page: 1-based page number. Values below 1 are treated as 1.
        limit: Page size. Negative values are treated as 0.
        db: Database session provided by the ``get_db`` dependency.

    Returns:
        A dict with ``total`` (number of rows matching the filters, before
        pagination), the effective ``page`` and ``limit``, and ``list``
        (the models on the requested page).
    """
    # Clamp pagination inputs so we never emit a negative OFFSET/LIMIT,
    # which some database backends reject outright.
    page = max(page, 1)
    limit = max(limit, 0)

    query = db.query(ASRModel)
    if language:
        query = query.filter(ASRModel.language == language)
    if enabled is not None:
        query = query.filter(ASRModel.enabled == enabled)

    # Count before applying offset/limit so the total covers all pages.
    total = query.count()
    models = (
        query.order_by(ASRModel.created_at.desc())
        .offset((page - 1) * limit)
        .limit(limit)
        .all()
    )
    return {"total": total, "page": page, "limit": limit, "list": models}
@router.get("/{id}", response_model=ASRModelOut)
def get_asr_model(id: str, db: Session = Depends(get_db)):
    """Return the ASR model with the given id.

    Raises:
        HTTPException: 404 when no model with that id exists.
    """
    record = db.query(ASRModel).filter(ASRModel.id == id).first()
    if record:
        return record
    raise HTTPException(status_code=404, detail="ASR Model not found")
@router.post("", response_model=ASRModelOut)
def create_asr_model(data: ASRModelCreate, db: Session = Depends(get_db)):
    """Create an ASR model from the request payload and return the stored row.

    When the payload carries no id, the first 8 characters of a random
    UUID4 string are used as the id.
    """
    # Fall back to a short random id only when the caller did not supply one.
    new_id = data.id if data.id else str(uuid.uuid4())[:8]

    attributes = {
        "id": new_id,
        "user_id": 1,  # default user
        "name": data.name,
        "vendor": data.vendor,
        "language": data.language,
        "base_url": data.base_url,
        "api_key": data.api_key,
        "model_name": data.model_name,
        "hotwords": data.hotwords,
        "enable_punctuation": data.enable_punctuation,
        "enable_normalization": data.enable_normalization,
        "enabled": data.enabled,
    }
    created = ASRModel(**attributes)

    db.add(created)
    db.commit()
    # Reload so attributes populated by the database are present on return.
    db.refresh(created)
    return created
@router.put("/{id}", response_model=ASRModelOut)
def update_asr_model(id: str, data: ASRModelUpdate, db: Session = Depends(get_db)):
    """Apply a partial update to an ASR model and return the updated row.

    Only fields explicitly set in the request payload are written.

    Raises:
        HTTPException: 404 when no model with that id exists.
    """
    target = db.query(ASRModel).filter(ASRModel.id == id).first()
    if target is None:
        raise HTTPException(status_code=404, detail="ASR Model not found")

    # exclude_unset keeps fields the client omitted from being overwritten.
    changes = data.model_dump(exclude_unset=True)
    for attr_name in changes:
        setattr(target, attr_name, changes[attr_name])

    db.commit()
    db.refresh(target)
    return target
@router.delete("/{id}")
def delete_asr_model(id: str, db: Session = Depends(get_db)):
    """Delete the ASR model with the given id.

    Returns:
        A dict with a confirmation ``message``.

    Raises:
        HTTPException: 404 when no model with that id exists.
    """
    doomed = db.query(ASRModel).filter(ASRModel.id == id).first()
    if doomed is None:
        raise HTTPException(status_code=404, detail="ASR Model not found")

    db.delete(doomed)
    db.commit()

    confirmation = {"message": "Deleted successfully"}
    return confirmation
@router.post("/{id}/test", response_model=ASRTestResponse)
def test_asr_model(
    id: str,
    request: Optional[ASRTestRequest] = None,
    db: Session = Depends(get_db)
):
    """Probe the configured ASR endpoint and report whether it responds.

    This is a connectivity check: it issues a GET against a vendor-specific
    path with the model's API key as a bearer token, so no real audio input
    is needed.

    Args:
        id: Id of the ASR model to test.
        request: Optional test request body; currently not used by the probe.
        db: Database session provided by the ``get_db`` dependency.

    Returns:
        An ``ASRTestResponse``. ``success`` is True when the endpoint answers
        with a 2xx status and a usable (possibly empty) result; otherwise
        ``success`` is False with either ``message`` or ``error`` set.
        ``latency_ms`` is reported in every case.

    Raises:
        HTTPException: 404 when no model with that id exists.
    """
    model = db.query(ASRModel).filter(ASRModel.id == id).first()
    if not model:
        raise HTTPException(status_code=404, detail="ASR Model not found")

    # perf_counter is monotonic, so the latency cannot be skewed by wall-clock
    # adjustments the way time.time() can.
    start_time = time.perf_counter()

    def elapsed_ms() -> int:
        return int((time.perf_counter() - start_time) * 1000)

    try:
        # Connectivity test first, to avoid depending on real audio input.
        headers = {"Authorization": f"Bearer {model.api_key}"}
        vendor = (model.vendor or "").lower()
        # Strip a trailing slash so the joined URL has no double slash.
        base_url = (model.base_url or "").rstrip("/")

        if vendor in ("siliconflow", "paraformer"):
            probe_path = "/asr"
        elif vendor == "openai":
            probe_path = "/audio/models"
        else:
            probe_path = "/health"

        with httpx.Client(timeout=60.0) as client:
            response = client.get(f"{base_url}{probe_path}", headers=headers)
        response.raise_for_status()

        # A 2xx reply with a non-JSON body (e.g. a plain-text health page)
        # still proves connectivity; treat it like an unrecognised JSON shape.
        try:
            raw_result = response.json()
        except ValueError:
            raw_result = None

        # Normalise the different vendor formats to {"results": [...]}.
        if isinstance(raw_result, dict) and "results" in raw_result:
            result = raw_result
        elif isinstance(raw_result, dict) and "text" in raw_result:
            result = {"results": [{"transcript": raw_result.get("text", "")}]}
        else:
            result = {"results": [{"transcript": ""}]}

        latency_ms = elapsed_ms()

        # Guard the index: vendors may return an empty or malformed "results".
        results = result.get("results")
        result_data = results[0] if isinstance(results, list) and results else None
        if isinstance(result_data, dict) and result_data:
            return ASRTestResponse(
                success=True,
                transcript=result_data.get("transcript", ""),
                language=result_data.get("language", model.language),
                confidence=result_data.get("confidence"),
                latency_ms=latency_ms,
            )
        return ASRTestResponse(
            success=False,
            message="No transcript in response",
            latency_ms=latency_ms
        )
    except httpx.HTTPStatusError as e:
        return ASRTestResponse(
            success=False,
            error=f"HTTP Error: {e.response.status_code} - {e.response.text[:200]}",
            latency_ms=elapsed_ms(),
        )
    except Exception as e:
        # Deliberately broad: this endpoint reports any failure (DNS, TLS,
        # timeout, ...) as a test result instead of a server error.
        return ASRTestResponse(
            success=False,
            error=str(e)[:200],
            latency_ms=elapsed_ms(),
        )
@router.post("/{id}/transcribe")
def transcribe_audio(
    id: str,
    audio_url: Optional[str] = None,
    audio_data: Optional[str] = None,
    hotwords: Optional[List[str]] = None,
    db: Session = Depends(get_db)
):
    """Send audio to the model's ASR endpoint and return the first result.

    Args:
        id: Id of the ASR model to use.
        audio_url: URL of the audio to transcribe; takes precedence over
            ``audio_data``.
        audio_data: Inline audio payload (see the review note in the body).
        hotwords: Hotwords for this request; falls back to the model's
            configured hotwords when empty or omitted.
        db: Database session provided by the ``get_db`` dependency.

    Returns:
        On success, a dict with ``success``, ``transcript``, ``language`` and
        ``confidence``. When the upstream response holds no usable result,
        ``{"success": False, "error": "No transcript in response"}``.

    Raises:
        HTTPException: 404 when the model does not exist, 400 when neither
            ``audio_url`` nor ``audio_data`` is supplied, 500 when the
            upstream request or response handling fails.
    """
    model = db.query(ASRModel).filter(ASRModel.id == id).first()
    if not model:
        raise HTTPException(status_code=404, detail="ASR Model not found")

    # Reject input-less requests up front instead of sending an empty
    # "input" object upstream and surfacing the resulting failure as a 500.
    if not audio_url and not audio_data:
        raise HTTPException(
            status_code=400,
            detail="Either audio_url or audio_data is required"
        )

    try:
        payload = {
            "model": model.model_name or "paraformer-v2",
            "input": {},
            "parameters": {
                "hotwords": " ".join(hotwords or model.hotwords or []),
                "enable_punctuation": model.enable_punctuation,
                "enable_normalization": model.enable_normalization,
            }
        }
        headers = {"Authorization": f"Bearer {model.api_key}"}

        if audio_url:
            payload["input"]["url"] = audio_url
        elif audio_data:
            # NOTE(review): audio_data itself is never forwarded; only an
            # empty file_urls list is sent. The inline-audio payload format
            # expected by the vendor needs confirming before this branch
            # can actually transcribe anything.
            payload["input"]["file_urls"] = []

        # Strip a trailing slash so the joined URL has no double slash.
        base_url = (model.base_url or "").rstrip("/")
        with httpx.Client(timeout=120.0) as client:
            response = client.post(
                f"{base_url}/asr",
                json=payload,
                headers=headers
            )
        response.raise_for_status()
        result = response.json()

        # Guard the index: an empty or malformed "results" must yield the
        # "no transcript" reply rather than an IndexError turned into a 500.
        results = result.get("results") if isinstance(result, dict) else None
        result_data = results[0] if isinstance(results, list) and results else None
        if isinstance(result_data, dict) and result_data:
            return {
                "success": True,
                "transcript": result_data.get("transcript", ""),
                "language": result_data.get("language", model.language),
                "confidence": result_data.get("confidence"),
            }
        return {"success": False, "error": "No transcript in response"}
    except Exception as e:
        # Chain the cause so the original traceback is kept in server logs.
        raise HTTPException(status_code=500, detail=str(e)) from e