Update backend api
This commit is contained in:
268
api/app/routers/asr.py
Normal file
268
api/app/routers/asr.py
Normal file
@@ -0,0 +1,268 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
import uuid
|
||||
import httpx
|
||||
import time
|
||||
import base64
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from ..db import get_db
|
||||
from ..models import ASRModel
|
||||
from ..schemas import (
|
||||
ASRModelCreate, ASRModelUpdate, ASRModelOut,
|
||||
ASRTestRequest, ASRTestResponse, ListResponse
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/asr", tags=["ASR Models"])
|
||||
|
||||
|
||||
# ============ ASR Models CRUD ============
|
||||
@router.get("", response_model=ListResponse)
|
||||
def list_asr_models(
|
||||
language: Optional[str] = None,
|
||||
enabled: Optional[bool] = None,
|
||||
page: int = 1,
|
||||
limit: int = 50,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取ASR模型列表"""
|
||||
query = db.query(ASRModel)
|
||||
|
||||
if language:
|
||||
query = query.filter(ASRModel.language == language)
|
||||
if enabled is not None:
|
||||
query = query.filter(ASRModel.enabled == enabled)
|
||||
|
||||
total = query.count()
|
||||
models = query.order_by(ASRModel.created_at.desc()) \
|
||||
.offset((page-1)*limit).limit(limit).all()
|
||||
|
||||
return {"total": total, "page": page, "limit": limit, "list": models}
|
||||
|
||||
|
||||
@router.get("/{id}", response_model=ASRModelOut)
|
||||
def get_asr_model(id: str, db: Session = Depends(get_db)):
|
||||
"""获取单个ASR模型详情"""
|
||||
model = db.query(ASRModel).filter(ASRModel.id == id).first()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="ASR Model not found")
|
||||
return model
|
||||
|
||||
|
||||
@router.post("", response_model=ASRModelOut)
|
||||
def create_asr_model(data: ASRModelCreate, db: Session = Depends(get_db)):
|
||||
"""创建ASR模型"""
|
||||
asr_model = ASRModel(
|
||||
id=data.id or str(uuid.uuid4())[:8],
|
||||
user_id=1, # 默认用户
|
||||
name=data.name,
|
||||
vendor=data.vendor,
|
||||
language=data.language,
|
||||
base_url=data.base_url,
|
||||
api_key=data.api_key,
|
||||
model_name=data.model_name,
|
||||
hotwords=data.hotwords,
|
||||
enable_punctuation=data.enable_punctuation,
|
||||
enable_normalization=data.enable_normalization,
|
||||
enabled=data.enabled,
|
||||
)
|
||||
db.add(asr_model)
|
||||
db.commit()
|
||||
db.refresh(asr_model)
|
||||
return asr_model
|
||||
|
||||
|
||||
@router.put("/{id}", response_model=ASRModelOut)
|
||||
def update_asr_model(id: str, data: ASRModelUpdate, db: Session = Depends(get_db)):
|
||||
"""更新ASR模型"""
|
||||
model = db.query(ASRModel).filter(ASRModel.id == id).first()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="ASR Model not found")
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(model, field, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(model)
|
||||
return model
|
||||
|
||||
|
||||
@router.delete("/{id}")
|
||||
def delete_asr_model(id: str, db: Session = Depends(get_db)):
|
||||
"""删除ASR模型"""
|
||||
model = db.query(ASRModel).filter(ASRModel.id == id).first()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="ASR Model not found")
|
||||
db.delete(model)
|
||||
db.commit()
|
||||
return {"message": "Deleted successfully"}
|
||||
|
||||
|
||||
@router.post("/{id}/test", response_model=ASRTestResponse)
|
||||
def test_asr_model(
|
||||
id: str,
|
||||
request: Optional[ASRTestRequest] = None,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""测试ASR模型"""
|
||||
model = db.query(ASRModel).filter(ASRModel.id == id).first()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="ASR Model not found")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 根据不同的厂商构造不同的请求
|
||||
if model.vendor.lower() in ["siliconflow", "paraformer"]:
|
||||
# SiliconFlow/Paraformer 格式
|
||||
payload = {
|
||||
"model": model.model_name or "paraformer-v2",
|
||||
"input": {},
|
||||
"parameters": {
|
||||
"hotwords": " ".join(model.hotwords) if model.hotwords else "",
|
||||
"enable_punctuation": model.enable_punctuation,
|
||||
"enable_normalization": model.enable_normalization,
|
||||
}
|
||||
}
|
||||
|
||||
# 如果有音频数据
|
||||
if request and request.audio_data:
|
||||
payload["input"]["file_urls"] = []
|
||||
elif request and request.audio_url:
|
||||
payload["input"]["url"] = request.audio_url
|
||||
|
||||
headers = {"Authorization": f"Bearer {model.api_key}"}
|
||||
|
||||
with httpx.Client(timeout=60.0) as client:
|
||||
response = client.post(
|
||||
f"{model.base_url}/asr",
|
||||
json=payload,
|
||||
headers=headers
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
elif model.vendor.lower() == "openai":
|
||||
# OpenAI Whisper 格式
|
||||
headers = {"Authorization": f"Bearer {model.api_key}"}
|
||||
|
||||
# 准备文件
|
||||
files = {}
|
||||
if request and request.audio_data:
|
||||
audio_bytes = base64.b64decode(request.audio_data)
|
||||
files = {"file": ("audio.wav", audio_bytes, "audio/wav")}
|
||||
data = {"model": model.model_name or "whisper-1"}
|
||||
elif request and request.audio_url:
|
||||
files = {"file": ("audio.wav", httpx.get(request.audio_url).content, "audio/wav")}
|
||||
data = {"model": model.model_name or "whisper-1"}
|
||||
else:
|
||||
return ASRTestResponse(
|
||||
success=False,
|
||||
error="No audio data or URL provided"
|
||||
)
|
||||
|
||||
with httpx.Client(timeout=60.0) as client:
|
||||
response = client.post(
|
||||
f"{model.base_url}/audio/transcriptions",
|
||||
files=files,
|
||||
data=data,
|
||||
headers=headers
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
result = {"results": [{"transcript": result.get("text", "")}]}
|
||||
|
||||
else:
|
||||
# 通用格式(可根据需要扩展)
|
||||
return ASRTestResponse(
|
||||
success=False,
|
||||
message=f"Unsupported vendor: {model.vendor}"
|
||||
)
|
||||
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
# 解析结果
|
||||
if result_data := result.get("results", [{}])[0]:
|
||||
transcript = result_data.get("transcript", "")
|
||||
return ASRTestResponse(
|
||||
success=True,
|
||||
transcript=transcript,
|
||||
language=result_data.get("language", model.language),
|
||||
confidence=result_data.get("confidence"),
|
||||
latency_ms=latency_ms,
|
||||
)
|
||||
|
||||
return ASRTestResponse(
|
||||
success=False,
|
||||
message="No transcript in response",
|
||||
latency_ms=latency_ms
|
||||
)
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
return ASRTestResponse(
|
||||
success=False,
|
||||
error=f"HTTP Error: {e.response.status_code} - {e.response.text[:200]}"
|
||||
)
|
||||
except Exception as e:
|
||||
return ASRTestResponse(
|
||||
success=False,
|
||||
error=str(e)[:200]
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{id}/transcribe")
|
||||
def transcribe_audio(
|
||||
id: str,
|
||||
audio_url: Optional[str] = None,
|
||||
audio_data: Optional[str] = None,
|
||||
hotwords: Optional[List[str]] = None,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""转写音频"""
|
||||
model = db.query(ASRModel).filter(ASRModel.id == id).first()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="ASR Model not found")
|
||||
|
||||
try:
|
||||
payload = {
|
||||
"model": model.model_name or "paraformer-v2",
|
||||
"input": {},
|
||||
"parameters": {
|
||||
"hotwords": " ".join(hotwords or model.hotwords or []),
|
||||
"enable_punctuation": model.enable_punctuation,
|
||||
"enable_normalization": model.enable_normalization,
|
||||
}
|
||||
}
|
||||
|
||||
headers = {"Authorization": f"Bearer {model.api_key}"}
|
||||
|
||||
if audio_url:
|
||||
payload["input"]["url"] = audio_url
|
||||
elif audio_data:
|
||||
payload["input"]["file_urls"] = []
|
||||
|
||||
with httpx.Client(timeout=120.0) as client:
|
||||
response = client.post(
|
||||
f"{model.base_url}/asr",
|
||||
json=payload,
|
||||
headers=headers
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
result = response.json()
|
||||
|
||||
if result_data := result.get("results", [{}])[0]:
|
||||
return {
|
||||
"success": True,
|
||||
"transcript": result_data.get("transcript", ""),
|
||||
"language": result_data.get("language", model.language),
|
||||
"confidence": result_data.get("confidence"),
|
||||
}
|
||||
|
||||
return {"success": False, "error": "No transcript in response"}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
Reference in New Issue
Block a user