- Introduced DashScope as a new ASR model in the database initialization.
- Updated ASRModel schema to include vendor information.
- Enhanced ASR router to support DashScope-specific functionality, including connection testing and preview capabilities.
- Modified frontend components to accommodate DashScope as a selectable vendor with appropriate default settings.
- Added tests to validate DashScope ASR model creation, updates, and connectivity.
- Updated backend API to handle DashScope-specific base URLs and vendor normalization.
786 lines
26 KiB
Python
786 lines
26 KiB
Python
import asyncio
|
||
import base64
|
||
import io
|
||
import json
|
||
import os
|
||
import sys
|
||
import threading
|
||
import time
|
||
import wave
|
||
from array import array
|
||
from typing import Any, Dict, List, Optional, Tuple
|
||
|
||
import httpx
|
||
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
|
||
from sqlalchemy.orm import Session
|
||
|
||
from ..db import get_db
|
||
from ..id_generator import unique_short_id
|
||
from ..models import ASRModel
|
||
from ..schemas import (
|
||
ASRModelCreate, ASRModelUpdate, ASRModelOut,
|
||
ASRTestRequest, ASRTestResponse
|
||
)
|
||
|
||
router = APIRouter(prefix="/asr", tags=["ASR Models"])
|
||
|
||
# Default model identifiers per vendor family. OpenAI-compatible vendors
# (e.g. SiliconFlow) default to SenseVoice; DashScope defaults to the Qwen3
# realtime ASR model reached over its websocket endpoint.
OPENAI_COMPATIBLE_DEFAULT_ASR_MODEL = "FunAudioLLM/SenseVoiceSmall"
DASHSCOPE_DEFAULT_ASR_MODEL = "qwen3-asr-flash-realtime"
DASHSCOPE_DEFAULT_BASE_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"

# Optional DashScope SDK import. This module must stay importable without the
# SDK installed, so any import failure only flips DASHSCOPE_SDK_AVAILABLE and
# records the error text for later diagnostics.
try:
    import dashscope
    from dashscope.audio.qwen_omni import MultiModality, OmniRealtimeCallback, OmniRealtimeConversation

    # TranscriptionParams has moved between modules across SDK releases; try
    # the public location first, then the internal one.
    try:
        from dashscope.audio.qwen_omni import TranscriptionParams
    except ImportError:
        from dashscope.audio.qwen_omni.omni_realtime import TranscriptionParams

    DASHSCOPE_SDK_AVAILABLE = True
    DASHSCOPE_IMPORT_ERROR = ""
except Exception as exc:
    dashscope = None  # type: ignore[assignment]
    MultiModality = None  # type: ignore[assignment]
    OmniRealtimeConversation = None  # type: ignore[assignment]
    TranscriptionParams = None  # type: ignore[assignment]
    DASHSCOPE_SDK_AVAILABLE = False
    DASHSCOPE_IMPORT_ERROR = f"{type(exc).__name__}: {exc}"

    class OmniRealtimeCallback:  # type: ignore[no-redef]
        """Fallback callback base when DashScope SDK is unavailable."""

        pass
|
||
|
||
|
||
def _is_openai_compatible_vendor(vendor: str) -> bool:
|
||
normalized = (vendor or "").strip().lower()
|
||
return normalized in {
|
||
"openai compatible",
|
||
"openai-compatible",
|
||
"siliconflow", # backward compatibility
|
||
"硅基流动", # backward compatibility
|
||
}
|
||
|
||
|
||
def _is_dashscope_vendor(vendor: str) -> bool:
|
||
return (vendor or "").strip().lower() == "dashscope"
|
||
|
||
|
||
def _default_asr_model(vendor: str) -> str:
    """Pick the default ASR model name for *vendor* ("whisper-1" otherwise)."""
    # The vendor predicates are mutually exclusive, so check order is free.
    if _is_dashscope_vendor(vendor):
        return DASHSCOPE_DEFAULT_ASR_MODEL
    if _is_openai_compatible_vendor(vendor):
        return OPENAI_COMPATIBLE_DEFAULT_ASR_MODEL
    return "whisper-1"
|
||
|
||
|
||
def _dashscope_language(language: Optional[str]) -> Optional[str]:
|
||
normalized = (language or "").strip().lower()
|
||
if not normalized or normalized in {"multi-lingual", "multilingual", "multi_lingual", "auto"}:
|
||
return None
|
||
if normalized.startswith("zh"):
|
||
return "zh"
|
||
if normalized.startswith("en"):
|
||
return "en"
|
||
return normalized
|
||
|
||
|
||
class _DashScopePreviewCallback(OmniRealtimeCallback):
    """Collect DashScope ASR websocket events for preview/test flows.

    The SDK invokes these callbacks from its websocket worker thread, so the
    calling thread synchronizes via three threading.Event flags (open,
    session-ready, done) and the transcript fields are guarded by a lock.
    """

    def __init__(self) -> None:
        super().__init__()
        self._open_event = threading.Event()
        self._session_ready_event = threading.Event()
        self._done_event = threading.Event()
        self._lock = threading.Lock()
        self._final_text = ""
        self._last_interim_text = ""
        self._error_message: Optional[str] = None

    def on_open(self) -> None:
        # Websocket connected; unblock wait_for_open().
        self._open_event.set()

    def on_close(self, code: int, reason: str) -> None:
        # A close after completion is expected; anything earlier is an error.
        if self._done_event.is_set():
            return
        self._error_message = f"DashScope websocket closed unexpectedly: {code} {reason}"
        self._done_event.set()
        # Also release wait_for_session_ready() so callers never block on it.
        self._session_ready_event.set()

    def on_error(self, message: Any) -> None:
        # Record the error and release every waiter.
        self._error_message = str(message)
        self._done_event.set()
        self._session_ready_event.set()

    def on_event(self, response: Any) -> None:
        # Events may arrive as dicts or JSON strings; normalize first.
        payload = _coerce_dashscope_event(response)
        event_type = str(payload.get("type") or "").strip()
        if not event_type:
            return

        if event_type in {"session.created", "session.updated"}:
            self._session_ready_event.set()
            return

        if event_type == "error" or event_type.endswith(".failed"):
            self._error_message = _format_dashscope_error_event(payload)
            self._done_event.set()
            self._session_ready_event.set()
            return

        if event_type == "conversation.item.input_audio_transcription.text":
            # Interim hypothesis; keep the latest one as a fallback transcript.
            interim_text = _extract_dashscope_text(payload, keys=("stash", "text", "transcript"))
            if interim_text:
                with self._lock:
                    self._last_interim_text = interim_text
            return

        if event_type == "conversation.item.input_audio_transcription.completed":
            final_text = _extract_dashscope_text(payload, keys=("transcript", "text", "stash"))
            with self._lock:
                if final_text:
                    self._final_text = final_text
            self._done_event.set()
            return

        if event_type in {"response.done", "session.finished"}:
            self._done_event.set()

    def wait_for_open(self, timeout: float = 10.0) -> None:
        """Block until the websocket opens; raise TimeoutError on timeout."""
        if not self._open_event.wait(timeout):
            raise TimeoutError("DashScope websocket open timeout")

    def wait_for_session_ready(self, timeout: float = 6.0) -> bool:
        """Return True once session.created/updated (or a failure) was seen."""
        return self._session_ready_event.wait(timeout)

    def wait_for_done(self, timeout: float = 20.0) -> None:
        """Block until transcription completes or fails; raise TimeoutError on timeout."""
        if not self._done_event.wait(timeout):
            raise TimeoutError("DashScope transcription timeout")

    def raise_if_error(self) -> None:
        """Re-raise any recorded websocket/transcription error as RuntimeError."""
        if self._error_message:
            raise RuntimeError(self._error_message)

    def read_text(self) -> str:
        """Return the final transcript, falling back to the last interim text."""
        with self._lock:
            return self._final_text or self._last_interim_text
|
||
|
||
|
||
def _coerce_dashscope_event(response: Any) -> Dict[str, Any]:
|
||
if isinstance(response, dict):
|
||
return response
|
||
if isinstance(response, str):
|
||
try:
|
||
parsed = json.loads(response)
|
||
if isinstance(parsed, dict):
|
||
return parsed
|
||
except json.JSONDecodeError:
|
||
pass
|
||
return {"type": "raw", "message": str(response)}
|
||
|
||
|
||
def _format_dashscope_error_event(payload: Dict[str, Any]) -> str:
|
||
error = payload.get("error")
|
||
if isinstance(error, dict):
|
||
code = str(error.get("code") or "").strip()
|
||
message = str(error.get("message") or "").strip()
|
||
if code and message:
|
||
return f"{code}: {message}"
|
||
return message or str(error)
|
||
return str(error or "DashScope realtime ASR error")
|
||
|
||
|
||
def _extract_dashscope_text(payload: Dict[str, Any], *, keys: Tuple[str, ...]) -> str:
|
||
for key in keys:
|
||
value = payload.get(key)
|
||
if isinstance(value, str) and value.strip():
|
||
return value.strip()
|
||
if isinstance(value, dict):
|
||
nested = _extract_dashscope_text(value, keys=keys)
|
||
if nested:
|
||
return nested
|
||
|
||
for value in payload.values():
|
||
if isinstance(value, dict):
|
||
nested = _extract_dashscope_text(value, keys=keys)
|
||
if nested:
|
||
return nested
|
||
return ""
|
||
|
||
|
||
def _create_dashscope_realtime_client(
    *,
    model: str,
    callback: _DashScopePreviewCallback,
    url: str,
    api_key: str,
) -> Any:
    """Instantiate an OmniRealtimeConversation, tolerating SDKs without api_key.

    Raises RuntimeError when the DashScope SDK was not importable.
    """
    if OmniRealtimeConversation is None:
        raise RuntimeError("DashScope SDK unavailable")

    common_kwargs: Dict[str, Any] = {
        "model": model,
        "callback": callback,
        "url": url,
    }
    try:
        return OmniRealtimeConversation(api_key=api_key, **common_kwargs)  # type: ignore[misc]
    except TypeError as exc:
        # Older SDK versions reject the api_key kwarg; retry without it, but
        # let any unrelated TypeError propagate unchanged.
        if "api_key" not in str(exc):
            raise
    return OmniRealtimeConversation(**common_kwargs)  # type: ignore[misc]
|
||
|
||
|
||
def _close_dashscope_client(client: Any) -> None:
|
||
finish_fn = getattr(client, "finish", None)
|
||
if callable(finish_fn):
|
||
try:
|
||
finish_fn()
|
||
except Exception:
|
||
pass
|
||
|
||
close_fn = getattr(client, "close", None)
|
||
if callable(close_fn):
|
||
try:
|
||
close_fn()
|
||
except Exception:
|
||
pass
|
||
|
||
|
||
def _configure_dashscope_session(
    *,
    client: Any,
    callback: _DashScopePreviewCallback,
    sample_rate: int,
    language: Optional[str],
) -> None:
    """Send session.update enabling input-audio transcription, degrading gracefully.

    Different dashscope SDK versions accept different update_session kwarg
    sets, so progressively smaller parameter dicts are attempted until one
    succeeds. Raises RuntimeError when every attempt fails.
    """
    update_fn = getattr(client, "update_session", None)
    if not callable(update_fn):
        raise RuntimeError("DashScope ASR SDK missing update_session method")

    # Prefer the SDK enum for the text modality; fall back to the raw string.
    text_modality: Any = "text"
    if MultiModality is not None and hasattr(MultiModality, "TEXT"):
        text_modality = MultiModality.TEXT

    # Build TranscriptionParams when the SDK exposes it; any constructor
    # incompatibility simply downgrades to the kwarg-only attempts below.
    transcription_params: Optional[Any] = None
    language_hint = _dashscope_language(language)
    if TranscriptionParams is not None:
        try:
            params_kwargs: Dict[str, Any] = {
                "sample_rate": sample_rate,
                "input_audio_format": "pcm",
            }
            if language_hint:
                params_kwargs["language"] = language_hint
            transcription_params = TranscriptionParams(**params_kwargs)
        except Exception:
            transcription_params = None

    # Attempts ordered from most specific to most minimal.
    update_attempts = [
        {
            "output_modalities": [text_modality],
            "enable_turn_detection": False,
            "enable_input_audio_transcription": True,
            "transcription_params": transcription_params,
        },
        {
            "output_modalities": [text_modality],
            "enable_turn_detection": False,
            "enable_input_audio_transcription": True,
        },
        {
            "output_modalities": [text_modality],
        },
    ]

    last_error: Optional[Exception] = None
    for params in update_attempts:
        # Drop a None transcription_params instead of passing it through.
        if params.get("transcription_params") is None:
            params = {key: value for key, value in params.items() if key != "transcription_params"}
        try:
            update_fn(**params)
            callback.wait_for_session_ready()
            callback.raise_if_error()
            return
        except TypeError as exc:
            # This SDK version rejects one of the kwargs; try the next shape.
            last_error = exc
            continue
        except Exception as exc:
            last_error = exc
            continue

    raise RuntimeError(f"DashScope ASR session.update failed: {last_error}")
|
||
|
||
|
||
def _load_wav_pcm16_mono(audio_bytes: bytes) -> Tuple[bytes, int]:
|
||
try:
|
||
with wave.open(io.BytesIO(audio_bytes), "rb") as wav_file:
|
||
channel_count = wav_file.getnchannels()
|
||
sample_width = wav_file.getsampwidth()
|
||
sample_rate = wav_file.getframerate()
|
||
compression = wav_file.getcomptype()
|
||
pcm_frames = wav_file.readframes(wav_file.getnframes())
|
||
except wave.Error as exc:
|
||
raise RuntimeError("DashScope preview currently supports WAV audio. Record in browser or upload a .wav file.") from exc
|
||
|
||
if compression != "NONE":
|
||
raise RuntimeError("DashScope preview requires uncompressed PCM WAV audio.")
|
||
if sample_width != 2:
|
||
raise RuntimeError("DashScope preview requires 16-bit PCM WAV audio.")
|
||
if not pcm_frames:
|
||
raise RuntimeError("Uploaded WAV file is empty")
|
||
if channel_count <= 1:
|
||
return pcm_frames, sample_rate
|
||
|
||
samples = array("h")
|
||
samples.frombytes(pcm_frames)
|
||
if sys.byteorder == "big":
|
||
samples.byteswap()
|
||
|
||
mono_samples = array(
|
||
"h",
|
||
(
|
||
int(sum(samples[index:index + channel_count]) / channel_count)
|
||
for index in range(0, len(samples), channel_count)
|
||
),
|
||
)
|
||
if sys.byteorder == "big":
|
||
mono_samples.byteswap()
|
||
return mono_samples.tobytes(), sample_rate
|
||
|
||
|
||
def _probe_dashscope_asr_connection(*, api_key: str, base_url: str, model: str, language: Optional[str]) -> None:
    """Open, configure, and tear down a DashScope realtime session to verify connectivity.

    Raises RuntimeError when the SDK is missing or the session fails, and
    TimeoutError when the service does not respond in time.
    """
    if not DASHSCOPE_SDK_AVAILABLE:
        hint = f"`{sys.executable} -m pip install dashscope>=1.25.11`"
        detail = f"; import error: {DASHSCOPE_IMPORT_ERROR}" if DASHSCOPE_IMPORT_ERROR else ""
        raise RuntimeError(f"dashscope package not installed; install with {hint}{detail}")

    callback = _DashScopePreviewCallback()
    # Some SDK versions read the module-level api_key rather than a kwarg.
    if dashscope is not None:
        dashscope.api_key = api_key
    client = _create_dashscope_realtime_client(
        model=model, callback=callback, url=base_url, api_key=api_key
    )

    try:
        client.connect()
        callback.wait_for_open()
        _configure_dashscope_session(
            client=client, callback=callback, sample_rate=16000, language=language
        )
    finally:
        _close_dashscope_client(client)
|
||
|
||
|
||
def _transcribe_dashscope_preview(
    *,
    audio_bytes: bytes,
    api_key: str,
    base_url: str,
    model: str,
    language: Optional[str],
) -> Dict[str, Any]:
    """Run a one-shot transcription of a WAV blob over DashScope realtime ASR.

    Returns a dict with "transcript", "language" and "confidence" keys.
    Raises RuntimeError for SDK/audio problems and TimeoutError when the
    service does not answer in time.
    """
    if not DASHSCOPE_SDK_AVAILABLE:
        hint = f"`{sys.executable} -m pip install dashscope>=1.25.11`"
        detail = f"; import error: {DASHSCOPE_IMPORT_ERROR}" if DASHSCOPE_IMPORT_ERROR else ""
        raise RuntimeError(f"dashscope package not installed; install with {hint}{detail}")

    pcm_audio, sample_rate = _load_wav_pcm16_mono(audio_bytes)
    callback = _DashScopePreviewCallback()
    # Some SDK versions read the module-level api_key rather than a kwarg.
    if dashscope is not None:
        dashscope.api_key = api_key
    client = _create_dashscope_realtime_client(
        model=model,
        callback=callback,
        url=base_url,
        api_key=api_key,
    )

    try:
        client.connect()
        callback.wait_for_open()
        _configure_dashscope_session(
            client=client,
            callback=callback,
            sample_rate=sample_rate,
            language=language,
        )

        append_fn = getattr(client, "append_audio", None)
        if not callable(append_fn):
            raise RuntimeError("DashScope ASR SDK missing append_audio method")
        commit_fn = getattr(client, "commit", None)
        if not callable(commit_fn):
            raise RuntimeError("DashScope ASR SDK missing commit method")

        # The realtime API expects base64-encoded PCM audio.
        append_fn(base64.b64encode(pcm_audio).decode("ascii"))
        commit_fn()
        callback.wait_for_done()
        callback.raise_if_error()
        return {
            "transcript": callback.read_text(),
            "language": _dashscope_language(language) or "Multi-lingual",
            "confidence": None,
        }
    finally:
        _close_dashscope_client(client)
|
||
|
||
|
||
# ============ ASR Models CRUD ============
@router.get("")
def list_asr_models(
    language: Optional[str] = None,
    enabled: Optional[bool] = None,
    page: int = 1,
    limit: int = 50,
    db: Session = Depends(get_db)
):
    """List ASR models with optional language/enabled filters and pagination."""
    query = db.query(ASRModel)
    if language:
        query = query.filter(ASRModel.language == language)
    if enabled is not None:
        query = query.filter(ASRModel.enabled == enabled)

    total = query.count()
    offset = (page - 1) * limit
    rows = (
        query.order_by(ASRModel.created_at.desc())
        .offset(offset)
        .limit(limit)
        .all()
    )
    return {"total": total, "page": page, "limit": limit, "list": rows}
|
||
|
||
|
||
@router.get("/{id}", response_model=ASRModelOut)
def get_asr_model(id: str, db: Session = Depends(get_db)):
    """Fetch a single ASR model by id; 404 when it does not exist."""
    record = db.query(ASRModel).filter(ASRModel.id == id).first()
    if record is None:
        raise HTTPException(status_code=404, detail="ASR Model not found")
    return record
|
||
|
||
|
||
@router.post("", response_model=ASRModelOut)
def create_asr_model(data: ASRModelCreate, db: Session = Depends(get_db)):
    """Create a new ASR model record owned by the default user."""
    record = ASRModel(
        id=unique_short_id("asr", db, ASRModel),
        user_id=1,  # default user
        name=data.name,
        vendor=data.vendor,
        language=data.language,
        base_url=data.base_url,
        api_key=data.api_key,
        model_name=data.model_name,
        hotwords=data.hotwords,
        enable_punctuation=data.enable_punctuation,
        enable_normalization=data.enable_normalization,
        enabled=data.enabled,
    )
    db.add(record)
    db.commit()
    db.refresh(record)
    return record
|
||
|
||
|
||
@router.put("/{id}", response_model=ASRModelOut)
def update_asr_model(id: str, data: ASRModelUpdate, db: Session = Depends(get_db)):
    """Apply only the fields explicitly provided in *data* to an existing model."""
    record = db.query(ASRModel).filter(ASRModel.id == id).first()
    if record is None:
        raise HTTPException(status_code=404, detail="ASR Model not found")

    # exclude_unset: never clobber fields the client did not send.
    for field_name, new_value in data.model_dump(exclude_unset=True).items():
        setattr(record, field_name, new_value)

    db.commit()
    db.refresh(record)
    return record
|
||
|
||
|
||
@router.delete("/{id}")
def delete_asr_model(id: str, db: Session = Depends(get_db)):
    """Delete an ASR model by id; 404 when it does not exist."""
    record = db.query(ASRModel).filter(ASRModel.id == id).first()
    if record is None:
        raise HTTPException(status_code=404, detail="ASR Model not found")
    db.delete(record)
    db.commit()
    return {"message": "Deleted successfully"}
|
||
|
||
|
||
@router.post("/{id}/test", response_model=ASRTestResponse)
def test_asr_model(
    id: str,
    request: Optional[ASRTestRequest] = None,
    db: Session = Depends(get_db)
):
    """Connectivity test for an ASR model.

    DashScope vendors get a realtime websocket probe; other vendors get a
    cheap HTTP GET against a vendor-specific endpoint instead of a real
    transcription. Returns an ASRTestResponse rather than raising (except
    404), so the frontend can surface vendor errors inline.
    """
    model = db.query(ASRModel).filter(ASRModel.id == id).first()
    if not model:
        raise HTTPException(status_code=404, detail="ASR Model not found")

    start_time = time.time()
    # Fix: normalize the vendor once. model.vendor may be None or mixed case;
    # the helper predicates already guard with (vendor or ""), but the direct
    # `.lower()` comparisons below previously raised AttributeError on None,
    # which the broad except turned into a confusing error string.
    vendor_normalized = (model.vendor or "").strip().lower()

    try:
        if _is_dashscope_vendor(model.vendor):
            # Resolve the key: stored value first, then environment fallbacks.
            effective_api_key = (
                (model.api_key or "").strip()
                or os.getenv("DASHSCOPE_API_KEY", "").strip()
                or os.getenv("ASR_API_KEY", "").strip()
            )
            if not effective_api_key:
                return ASRTestResponse(success=False, error=f"API key is required for ASR model: {model.name}")

            base_url = (model.base_url or "").strip() or DASHSCOPE_DEFAULT_BASE_URL
            selected_model = (model.model_name or "").strip() or _default_asr_model(model.vendor)
            _probe_dashscope_asr_connection(
                api_key=effective_api_key,
                base_url=base_url,
                model=selected_model,
                language=model.language,
            )
            latency_ms = int((time.time() - start_time) * 1000)
            return ASRTestResponse(
                success=True,
                language=model.language,
                latency_ms=latency_ms,
                message="DashScope realtime ASR connected",
            )

        # Connectivity check first: avoids needing real audio input.
        headers = {"Authorization": f"Bearer {model.api_key}"}
        with httpx.Client(timeout=60.0) as client:
            if _is_openai_compatible_vendor(model.vendor) or vendor_normalized == "paraformer":
                response = client.get(f"{model.base_url}/asr", headers=headers)
            elif vendor_normalized == "openai":
                response = client.get(f"{model.base_url}/audio/models", headers=headers)
            else:
                response = client.get(f"{model.base_url}/health", headers=headers)
            response.raise_for_status()
            raw_result = response.json()

        # Normalize vendor-specific response shapes into {"results": [...]}.
        if isinstance(raw_result, dict) and "results" in raw_result:
            result = raw_result
        elif isinstance(raw_result, dict) and "text" in raw_result:
            result = {"results": [{"transcript": raw_result.get("text", "")}]}
        else:
            result = {"results": [{"transcript": ""}]}

        latency_ms = int((time.time() - start_time) * 1000)

        # Surface the first (possibly synthetic) result entry.
        if result_data := result.get("results", [{}])[0]:
            return ASRTestResponse(
                success=True,
                transcript=result_data.get("transcript", ""),
                language=result_data.get("language", model.language),
                confidence=result_data.get("confidence"),
                latency_ms=latency_ms,
            )

        return ASRTestResponse(
            success=False,
            message="No transcript in response",
            latency_ms=latency_ms
        )

    except httpx.HTTPStatusError as e:
        return ASRTestResponse(
            success=False,
            error=f"HTTP Error: {e.response.status_code} - {e.response.text[:200]}"
        )
    except Exception as e:
        return ASRTestResponse(
            success=False,
            error=str(e)[:200]
        )
|
||
|
||
|
||
@router.post("/{id}/transcribe")
def transcribe_audio(
    id: str,
    audio_url: Optional[str] = None,
    audio_data: Optional[str] = None,
    hotwords: Optional[List[str]] = None,
    db: Session = Depends(get_db)
):
    """Transcribe audio through the model's HTTP `/asr` endpoint.

    Accepts a remote audio URL or inline audio data; request-level hotwords
    override the model's configured ones. Raises 404 for an unknown model and
    500 for any downstream failure.
    """
    model = db.query(ASRModel).filter(ASRModel.id == id).first()
    if not model:
        raise HTTPException(status_code=404, detail="ASR Model not found")

    try:
        payload = {
            "model": model.model_name or "paraformer-v2",
            "input": {},
            "parameters": {
                # Request hotwords take precedence over stored ones.
                "hotwords": " ".join(hotwords or model.hotwords or []),
                "enable_punctuation": model.enable_punctuation,
                "enable_normalization": model.enable_normalization,
            }
        }

        headers = {"Authorization": f"Bearer {model.api_key}"}

        if audio_url:
            payload["input"]["url"] = audio_url
        elif audio_data:
            # NOTE(review): inline audio is not actually forwarded here —
            # this sends an empty file_urls list. Looks like upload support
            # is still pending; confirm before relying on this branch.
            payload["input"]["file_urls"] = []

        with httpx.Client(timeout=120.0) as client:
            response = client.post(
                f"{model.base_url}/asr",
                json=payload,
                headers=headers
            )
            response.raise_for_status()

        result = response.json()

        if result_data := result.get("results", [{}])[0]:
            return {
                "success": True,
                "transcript": result_data.get("transcript", ""),
                "language": result_data.get("language", model.language),
                "confidence": result_data.get("confidence"),
            }

        return {"success": False, "error": "No transcript in response"}

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@router.post("/{id}/preview", response_model=ASRTestResponse)
async def preview_asr_model(
    id: str,
    file: UploadFile = File(...),
    language: Optional[str] = Form(None),
    api_key: Optional[str] = Form(None),
    db: Session = Depends(get_db),
):
    """Preview ASR: route to DashScope realtime or an OpenAI-compatible HTTP API by vendor.

    Resolves the API key (form value > stored key > vendor env vars), base
    URL, model name and language, then transcribes the uploaded audio and
    returns an ASRTestResponse with transcript and latency.
    """
    model = db.query(ASRModel).filter(ASRModel.id == id).first()
    if not model:
        raise HTTPException(status_code=404, detail="ASR Model not found")

    if not file:
        raise HTTPException(status_code=400, detail="Audio file is required")

    filename = file.filename or "preview.wav"
    content_type = file.content_type or "application/octet-stream"
    if not content_type.startswith("audio/"):
        raise HTTPException(status_code=400, detail="Only audio files are supported")

    audio_bytes = await file.read()
    if not audio_bytes:
        raise HTTPException(status_code=400, detail="Uploaded audio file is empty")

    # Key resolution: explicit form value, then the stored key, then env vars.
    effective_api_key = (api_key or "").strip() or (model.api_key or "").strip()
    if not effective_api_key:
        if _is_openai_compatible_vendor(model.vendor):
            effective_api_key = os.getenv("SILICONFLOW_API_KEY", "").strip()
        elif _is_dashscope_vendor(model.vendor):
            effective_api_key = os.getenv("DASHSCOPE_API_KEY", "").strip() or os.getenv("ASR_API_KEY", "").strip()
    if not effective_api_key:
        raise HTTPException(status_code=400, detail=f"API key is required for ASR model: {model.name}")

    base_url = (model.base_url or "").strip().rstrip("/")
    # DashScope has a well-known default realtime websocket endpoint.
    if _is_dashscope_vendor(model.vendor) and not base_url:
        base_url = DASHSCOPE_DEFAULT_BASE_URL
    if not base_url:
        raise HTTPException(status_code=400, detail=f"Base URL is required for ASR model: {model.name}")

    selected_model = (model.model_name or "").strip() or _default_asr_model(model.vendor)
    effective_language = (language or "").strip() or None

    start_time = time.time()
    if _is_dashscope_vendor(model.vendor):
        try:
            # The DashScope SDK is blocking; run it off the event loop.
            payload = await asyncio.to_thread(
                _transcribe_dashscope_preview,
                audio_bytes=audio_bytes,
                api_key=effective_api_key,
                base_url=base_url,
                model=selected_model,
                language=effective_language or model.language,
            )
        except Exception as exc:
            raise HTTPException(status_code=502, detail=f"DashScope ASR request failed: {exc}") from exc

        transcript = str(payload.get("transcript") or "")
        response_language = str(payload.get("language") or effective_language or model.language)
        latency_ms = int((time.time() - start_time) * 1000)
        return ASRTestResponse(
            success=bool(transcript),
            transcript=transcript,
            language=response_language,
            confidence=None,
            latency_ms=latency_ms,
            message=None if transcript else "No transcript in response",
        )

    # OpenAI-compatible path: multipart POST to /audio/transcriptions.
    data = {"model": selected_model}
    if effective_language:
        data["language"] = effective_language
    if model.hotwords:
        # Hotwords ride in the Whisper-style prompt field.
        data["prompt"] = " ".join(model.hotwords)

    headers = {"Authorization": f"Bearer {effective_api_key}"}
    files = {"file": (filename, audio_bytes, content_type)}

    try:
        with httpx.Client(timeout=90.0) as client:
            response = client.post(
                f"{base_url}/audio/transcriptions",
                headers=headers,
                data=data,
                files=files,
            )
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"ASR request failed: {exc}") from exc

    if response.status_code != 200:
        # Prefer the structured error message when the vendor returns JSON.
        detail = response.text
        try:
            detail_json = response.json()
            detail = detail_json.get("error", {}).get("message") or detail_json.get("detail") or detail
        except Exception:
            pass
        raise HTTPException(status_code=502, detail=f"ASR vendor error: {detail}")

    try:
        payload = response.json()
    except Exception:
        payload = {"text": response.text}

    transcript = ""
    response_language = model.language
    confidence = None
    if isinstance(payload, dict):
        transcript = str(payload.get("text") or payload.get("transcript") or "")
        response_language = str(payload.get("language") or effective_language or model.language)
        raw_confidence = payload.get("confidence")
        if raw_confidence is not None:
            try:
                confidence = float(raw_confidence)
            except (TypeError, ValueError):
                confidence = None

    latency_ms = int((time.time() - start_time) * 1000)
    return ASRTestResponse(
        success=bool(transcript),
        transcript=transcript,
        language=response_language,
        confidence=confidence,
        latency_ms=latency_ms,
        message=None if transcript else "No transcript in response",
    )
|