Update backend api

This commit is contained in:
Xin Wang
2026-02-08 15:52:16 +08:00
parent 727fe8a997
commit 7012f8edaf
15 changed files with 3436 additions and 19 deletions

268
api/app/routers/asr.py Normal file
View File

@@ -0,0 +1,268 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from typing import List, Optional
import uuid
import httpx
import time
import base64
import json
from datetime import datetime
from ..db import get_db
from ..models import ASRModel
from ..schemas import (
ASRModelCreate, ASRModelUpdate, ASRModelOut,
ASRTestRequest, ASRTestResponse, ListResponse
)
# Router for the ASR model endpoints; every route registered on it is served
# under the "/asr" prefix and tagged "ASR Models".
router = APIRouter(prefix="/asr", tags=["ASR Models"])
# ============ ASR Models CRUD ============
@router.get("", response_model=ListResponse)
def list_asr_models(
    language: Optional[str] = None,
    enabled: Optional[bool] = None,
    page: int = 1,
    limit: int = 50,
    db: Session = Depends(get_db)
):
    """List ASR models, newest first, with optional filtering and paging.

    Args:
        language: When non-empty, keep only models whose ``language`` equals it.
        enabled: When not ``None``, keep only models whose ``enabled`` flag
            equals it.
        page: 1-based page number. Values below 1 are treated as 1.
        limit: Page size. Negative values are treated as 0.
        db: Database session injected by FastAPI.

    Returns:
        A dict with ``total`` (number of matching rows before paging), the
        effective ``page`` and ``limit``, and ``list`` (the models on the
        requested page).
    """
    # Clamp the paging inputs: page < 1 would produce a negative OFFSET and a
    # negative limit a negative LIMIT, which databases either reject or
    # interpret as "no limit".
    page = max(page, 1)
    limit = max(limit, 0)

    query = db.query(ASRModel)
    if language:
        query = query.filter(ASRModel.language == language)
    if enabled is not None:
        query = query.filter(ASRModel.enabled == enabled)

    # Count before paging so ``total`` reflects every matching row.
    total = query.count()
    models = (
        query.order_by(ASRModel.created_at.desc())
        .offset((page - 1) * limit)
        .limit(limit)
        .all()
    )
    return {"total": total, "page": page, "limit": limit, "list": models}
@router.get("/{id}", response_model=ASRModelOut)
def get_asr_model(id: str, db: Session = Depends(get_db)):
    """Return the details of a single ASR model.

    Raises:
        HTTPException: 404 when no model with the given ``id`` exists.
    """
    record = db.query(ASRModel).filter(ASRModel.id == id).first()
    if record:
        return record
    raise HTTPException(status_code=404, detail="ASR Model not found")
@router.post("", response_model=ASRModelOut)
def create_asr_model(data: ASRModelCreate, db: Session = Depends(get_db)):
    """Create an ASR model record and return it.

    When ``data.id`` is empty, an id is generated from the first eight
    characters of a random UUID4 string.

    Args:
        data: Fields of the model to create.
        db: Database session injected by FastAPI.

    Returns:
        The persisted model, refreshed from the database after commit.

    Raises:
        HTTPException: 409 when ``data.id`` is supplied and a model with that
            id already exists.
    """
    # The other routes look models up by id and take the first match, so a
    # caller-supplied id must not collide with an existing row. Reject it
    # explicitly instead of relying on whatever the insert would do.
    if data.id:
        existing = db.query(ASRModel).filter(ASRModel.id == data.id).first()
        if existing is not None:
            raise HTTPException(status_code=409, detail="ASR Model id already exists")

    asr_model = ASRModel(
        id=data.id or str(uuid.uuid4())[:8],
        user_id=1,  # default user (hard-coded)
        name=data.name,
        vendor=data.vendor,
        language=data.language,
        base_url=data.base_url,
        api_key=data.api_key,
        model_name=data.model_name,
        hotwords=data.hotwords,
        enable_punctuation=data.enable_punctuation,
        enable_normalization=data.enable_normalization,
        enabled=data.enabled,
    )
    db.add(asr_model)
    db.commit()
    db.refresh(asr_model)
    return asr_model
@router.put("/{id}", response_model=ASRModelOut)
def update_asr_model(id: str, data: ASRModelUpdate, db: Session = Depends(get_db)):
    """Update an existing ASR model with the fields set in the request.

    Only fields present in ``data.model_dump(exclude_unset=True)`` are
    written onto the stored model.

    Raises:
        HTTPException: 404 when no model with the given ``id`` exists.
    """
    target = db.query(ASRModel).filter(ASRModel.id == id).first()
    if not target:
        raise HTTPException(status_code=404, detail="ASR Model not found")

    changes = data.model_dump(exclude_unset=True)
    for attr_name, new_value in changes.items():
        setattr(target, attr_name, new_value)

    db.commit()
    db.refresh(target)
    return target
@router.delete("/{id}")
def delete_asr_model(id: str, db: Session = Depends(get_db)):
    """Delete an ASR model.

    Returns:
        A dict with a confirmation ``message``.

    Raises:
        HTTPException: 404 when no model with the given ``id`` exists.
    """
    doomed = db.query(ASRModel).filter(ASRModel.id == id).first()
    if doomed:
        db.delete(doomed)
        db.commit()
        return {"message": "Deleted successfully"}
    raise HTTPException(status_code=404, detail="ASR Model not found")
@router.post("/{id}/test", response_model=ASRTestResponse)
def test_asr_model(
    id: str,
    request: Optional[ASRTestRequest] = None,
    db: Session = Depends(get_db)
):
    """Send a test recognition request to the model's vendor endpoint.

    The request format depends on ``model.vendor`` (case-insensitive):

    * ``siliconflow`` / ``paraformer``: JSON POST to ``{base_url}/asr``.
    * ``openai``: multipart POST to ``{base_url}/audio/transcriptions``;
      audio comes from base64 ``audio_data`` or is downloaded from
      ``audio_url``.
    * anything else: reported as unsupported.

    Args:
        id: Id of the ASR model to test.
        request: Optional test input carrying ``audio_data`` and/or
            ``audio_url``.
        db: Database session injected by FastAPI.

    Returns:
        An ``ASRTestResponse``. Failures after the model lookup are reported
        with ``success=False`` rather than raised.

    Raises:
        HTTPException: 404 when no model with the given ``id`` exists.
    """
    model = db.query(ASRModel).filter(ASRModel.id == id).first()
    if not model:
        raise HTTPException(status_code=404, detail="ASR Model not found")

    start_time = time.time()
    try:
        # Build a vendor-specific request.
        vendor = model.vendor.lower()
        headers = {"Authorization": f"Bearer {model.api_key}"}

        if vendor in ("siliconflow", "paraformer"):
            # SiliconFlow / Paraformer format.
            payload = {
                "model": model.model_name or "paraformer-v2",
                "input": {},
                "parameters": {
                    "hotwords": " ".join(model.hotwords) if model.hotwords else "",
                    "enable_punctuation": model.enable_punctuation,
                    "enable_normalization": model.enable_normalization,
                },
            }
            if request and request.audio_data:
                # NOTE(review): the inline audio is not forwarded; only an
                # empty ``file_urls`` list is sent. Confirm the intended
                # upload mechanism for this vendor.
                payload["input"]["file_urls"] = []
            elif request and request.audio_url:
                payload["input"]["url"] = request.audio_url

            with httpx.Client(timeout=60.0) as client:
                response = client.post(
                    f"{model.base_url}/asr",
                    json=payload,
                    headers=headers,
                )
                response.raise_for_status()
                result = response.json()

        elif vendor == "openai":
            # OpenAI Whisper format.
            with httpx.Client(timeout=60.0) as client:
                if request and request.audio_data:
                    audio_bytes = base64.b64decode(request.audio_data)
                elif request and request.audio_url:
                    # Check the download status so an error page is never
                    # uploaded as if it were audio.
                    download = client.get(request.audio_url)
                    download.raise_for_status()
                    audio_bytes = download.content
                else:
                    return ASRTestResponse(
                        success=False,
                        error="No audio data or URL provided"
                    )

                response = client.post(
                    f"{model.base_url}/audio/transcriptions",
                    files={"file": ("audio.wav", audio_bytes, "audio/wav")},
                    data={"model": model.model_name or "whisper-1"},
                    headers=headers,
                )
                response.raise_for_status()
                body = response.json()
            # Normalise to the same shape as the SiliconFlow/Paraformer result.
            result = {"results": [{"transcript": body.get("text", "")}]}

        else:
            # Generic format (extend as needed).
            return ASRTestResponse(
                success=False,
                message=f"Unsupported vendor: {model.vendor}"
            )

        latency_ms = int((time.time() - start_time) * 1000)

        # Parse the result. Guard against a missing or empty "results" list so
        # that case is reported as "No transcript" instead of an IndexError.
        results = result.get("results") if isinstance(result, dict) else None
        result_data = results[0] if isinstance(results, list) and results else None
        if isinstance(result_data, dict) and result_data:
            return ASRTestResponse(
                success=True,
                transcript=result_data.get("transcript", ""),
                language=result_data.get("language", model.language),
                confidence=result_data.get("confidence"),
                latency_ms=latency_ms,
            )
        return ASRTestResponse(
            success=False,
            message="No transcript in response",
            latency_ms=latency_ms
        )
    except httpx.HTTPStatusError as e:
        return ASRTestResponse(
            success=False,
            error=f"HTTP Error: {e.response.status_code} - {e.response.text[:200]}"
        )
    except Exception as e:
        # Best-effort endpoint: any other failure is reported, not raised.
        return ASRTestResponse(
            success=False,
            error=str(e)[:200]
        )
@router.post("/{id}/transcribe")
def transcribe_audio(
    id: str,
    audio_url: Optional[str] = None,
    audio_data: Optional[str] = None,
    hotwords: Optional[List[str]] = None,
    db: Session = Depends(get_db)
):
    """Transcribe audio by POSTing a JSON request to ``{base_url}/asr``.

    Args:
        id: Id of the ASR model to use.
        audio_url: URL of the audio; sent as ``input.url`` when given.
        audio_data: Inline audio; consulted only when ``audio_url`` is empty.
        hotwords: Hotwords for this call; falls back to the model's own
            ``hotwords`` when empty.
        db: Database session injected by FastAPI.

    Returns:
        On success, a dict with ``success=True``, ``transcript``, ``language``
        and ``confidence``. When the response holds no usable result, a dict
        with ``success=False`` and an ``error`` message.

    Raises:
        HTTPException: 404 when no model with the given ``id`` exists; 500
            when building or sending the request, or reading the response,
            fails.
    """
    model = db.query(ASRModel).filter(ASRModel.id == id).first()
    if not model:
        raise HTTPException(status_code=404, detail="ASR Model not found")

    try:
        payload = {
            "model": model.model_name or "paraformer-v2",
            "input": {},
            "parameters": {
                "hotwords": " ".join(hotwords or model.hotwords or []),
                "enable_punctuation": model.enable_punctuation,
                "enable_normalization": model.enable_normalization,
            },
        }
        headers = {"Authorization": f"Bearer {model.api_key}"}

        if audio_url:
            payload["input"]["url"] = audio_url
        elif audio_data:
            # NOTE(review): the inline audio is not forwarded; only an empty
            # ``file_urls`` list is sent. Confirm the intended upload mechanism.
            payload["input"]["file_urls"] = []

        with httpx.Client(timeout=120.0) as client:
            response = client.post(
                f"{model.base_url}/asr",
                json=payload,
                headers=headers,
            )
            response.raise_for_status()
            result = response.json()

        # Guard against a missing or empty "results" list so that case yields
        # the "No transcript" reply instead of an IndexError turned into a 500.
        results = result.get("results") if isinstance(result, dict) else None
        result_data = results[0] if isinstance(results, list) and results else None
        if isinstance(result_data, dict) and result_data:
            return {
                "success": True,
                "transcript": result_data.get("transcript", ""),
                "language": result_data.get("language", model.language),
                "confidence": result_data.get("confidence"),
            }
        return {"success": False, "error": "No transcript in response"}
    except Exception as e:
        # Keep the original cause attached for server-side tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e