Merge branch 'engine-v3' of https://gitea.xiaowang.eu.org/wx44wx/AI-VideoAssistant into engine-v3
This commit is contained in:
@@ -1,6 +1,14 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from typing import List, Optional
|
||||
import wave
|
||||
from array import array
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
|
||||
@@ -17,6 +25,32 @@ from ..schemas import (
|
||||
router = APIRouter(prefix="/asr", tags=["ASR Models"])

# Fallback model identifiers used when a stored ASR model row has no explicit model_name.
OPENAI_COMPATIBLE_DEFAULT_ASR_MODEL = "FunAudioLLM/SenseVoiceSmall"
DASHSCOPE_DEFAULT_ASR_MODEL = "qwen3-asr-flash-realtime"
# DashScope realtime ASR speaks WebSocket (wss://), not HTTP.
DASHSCOPE_DEFAULT_BASE_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
|
||||
|
||||
# Import the DashScope realtime SDK defensively: the OpenAI-compatible vendors
# served by this router must keep working when the package is not installed.
try:
    import dashscope
    from dashscope.audio.qwen_omni import MultiModality, OmniRealtimeCallback, OmniRealtimeConversation

    # TranscriptionParams moved between modules across dashscope releases;
    # try the public location first, then the internal one.
    try:
        from dashscope.audio.qwen_omni import TranscriptionParams
    except ImportError:
        from dashscope.audio.qwen_omni.omni_realtime import TranscriptionParams

    DASHSCOPE_SDK_AVAILABLE = True
    DASHSCOPE_IMPORT_ERROR = ""
except Exception as exc:
    # Record why the import failed so later error messages can surface it.
    dashscope = None  # type: ignore[assignment]
    MultiModality = None  # type: ignore[assignment]
    OmniRealtimeConversation = None  # type: ignore[assignment]
    TranscriptionParams = None  # type: ignore[assignment]
    DASHSCOPE_SDK_AVAILABLE = False
    DASHSCOPE_IMPORT_ERROR = f"{type(exc).__name__}: {exc}"

    class OmniRealtimeCallback:  # type: ignore[no-redef]
        """Fallback callback base when DashScope SDK is unavailable."""

        pass
|
||||
|
||||
|
||||
def _is_openai_compatible_vendor(vendor: str) -> bool:
|
||||
@@ -29,12 +63,377 @@ def _is_openai_compatible_vendor(vendor: str) -> bool:
|
||||
}
|
||||
|
||||
|
||||
def _is_dashscope_vendor(vendor: str) -> bool:
|
||||
return (vendor or "").strip().lower() == "dashscope"
|
||||
|
||||
|
||||
def _default_asr_model(vendor: str) -> str:
    """Return the fallback model identifier for *vendor*.

    OpenAI-compatible vendors default to SenseVoice, DashScope to its realtime
    flash model, and anything unrecognized falls back to whisper-1.
    """
    vendor_defaults = (
        (_is_openai_compatible_vendor, OPENAI_COMPATIBLE_DEFAULT_ASR_MODEL),
        (_is_dashscope_vendor, DASHSCOPE_DEFAULT_ASR_MODEL),
    )
    for matches, default_model in vendor_defaults:
        if matches(vendor):
            return default_model
    return "whisper-1"
|
||||
|
||||
|
||||
def _dashscope_language(language: Optional[str]) -> Optional[str]:
|
||||
normalized = (language or "").strip().lower()
|
||||
if not normalized or normalized in {"multi-lingual", "multilingual", "multi_lingual", "auto"}:
|
||||
return None
|
||||
if normalized.startswith("zh"):
|
||||
return "zh"
|
||||
if normalized.startswith("en"):
|
||||
return "en"
|
||||
return normalized
|
||||
|
||||
|
||||
class _DashScopePreviewCallback(OmniRealtimeCallback):
    """Accumulate DashScope realtime ASR websocket events for preview/test flows.

    Thread-safe: the SDK fires callbacks on its websocket reader thread while
    the request handler blocks on the ``wait_for_*`` methods.
    """

    def __init__(self) -> None:
        super().__init__()
        # Signalled once the websocket handshake completes.
        self._open_event = threading.Event()
        # Signalled after session.created / session.updated (or on failure).
        self._session_ready_event = threading.Event()
        # Signalled on final transcript, terminal event, or error.
        self._done_event = threading.Event()
        self._lock = threading.Lock()
        self._final_text = ""
        self._last_interim_text = ""
        self._error_message: Optional[str] = None

    def _fail(self, message: str) -> None:
        # Record the failure and release every waiter so callers never hang.
        self._error_message = message
        self._done_event.set()
        self._session_ready_event.set()

    def on_open(self) -> None:
        self._open_event.set()

    def on_close(self, code: int, reason: str) -> None:
        # A close after we are done is expected teardown, not an error.
        if self._done_event.is_set():
            return
        self._fail(f"DashScope websocket closed unexpectedly: {code} {reason}")

    def on_error(self, message: Any) -> None:
        self._fail(str(message))

    def on_event(self, response: Any) -> None:
        payload = _coerce_dashscope_event(response)
        event_type = str(payload.get("type") or "").strip()
        if not event_type:
            return

        if event_type in {"session.created", "session.updated"}:
            self._session_ready_event.set()
        elif event_type == "error" or event_type.endswith(".failed"):
            self._fail(_format_dashscope_error_event(payload))
        elif event_type == "conversation.item.input_audio_transcription.text":
            # Streaming partial; keep the latest as a fallback transcript.
            interim_text = _extract_dashscope_text(payload, keys=("stash", "text", "transcript"))
            if interim_text:
                with self._lock:
                    self._last_interim_text = interim_text
        elif event_type == "conversation.item.input_audio_transcription.completed":
            final_text = _extract_dashscope_text(payload, keys=("transcript", "text", "stash"))
            with self._lock:
                if final_text:
                    self._final_text = final_text
            self._done_event.set()
        elif event_type in {"response.done", "session.finished"}:
            self._done_event.set()

    def wait_for_open(self, timeout: float = 10.0) -> None:
        if not self._open_event.wait(timeout):
            raise TimeoutError("DashScope websocket open timeout")

    def wait_for_session_ready(self, timeout: float = 6.0) -> bool:
        return self._session_ready_event.wait(timeout)

    def wait_for_done(self, timeout: float = 20.0) -> None:
        if not self._done_event.wait(timeout):
            raise TimeoutError("DashScope transcription timeout")

    def raise_if_error(self) -> None:
        if self._error_message:
            raise RuntimeError(self._error_message)

    def read_text(self) -> str:
        # Prefer the final transcript; fall back to the last interim chunk.
        with self._lock:
            return self._final_text or self._last_interim_text
|
||||
|
||||
|
||||
def _coerce_dashscope_event(response: Any) -> Dict[str, Any]:
|
||||
if isinstance(response, dict):
|
||||
return response
|
||||
if isinstance(response, str):
|
||||
try:
|
||||
parsed = json.loads(response)
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return {"type": "raw", "message": str(response)}
|
||||
|
||||
|
||||
def _format_dashscope_error_event(payload: Dict[str, Any]) -> str:
|
||||
error = payload.get("error")
|
||||
if isinstance(error, dict):
|
||||
code = str(error.get("code") or "").strip()
|
||||
message = str(error.get("message") or "").strip()
|
||||
if code and message:
|
||||
return f"{code}: {message}"
|
||||
return message or str(error)
|
||||
return str(error or "DashScope realtime ASR error")
|
||||
|
||||
|
||||
def _extract_dashscope_text(payload: Dict[str, Any], *, keys: Tuple[str, ...]) -> str:
|
||||
for key in keys:
|
||||
value = payload.get(key)
|
||||
if isinstance(value, str) and value.strip():
|
||||
return value.strip()
|
||||
if isinstance(value, dict):
|
||||
nested = _extract_dashscope_text(value, keys=keys)
|
||||
if nested:
|
||||
return nested
|
||||
|
||||
for value in payload.values():
|
||||
if isinstance(value, dict):
|
||||
nested = _extract_dashscope_text(value, keys=keys)
|
||||
if nested:
|
||||
return nested
|
||||
return ""
|
||||
|
||||
|
||||
def _create_dashscope_realtime_client(
    *,
    model: str,
    callback: _DashScopePreviewCallback,
    url: str,
    api_key: str,
) -> Any:
    """Instantiate OmniRealtimeConversation, tolerating SDKs without an api_key kwarg.

    Raises RuntimeError when the DashScope SDK was never imported.
    """
    if OmniRealtimeConversation is None:
        raise RuntimeError("DashScope SDK unavailable")

    shared_kwargs: Dict[str, Any] = {
        "model": model,
        "callback": callback,
        "url": url,
    }
    try:
        return OmniRealtimeConversation(api_key=api_key, **shared_kwargs)  # type: ignore[misc]
    except TypeError as exc:
        # Only swallow the "unexpected keyword 'api_key'" signature mismatch;
        # older SDKs read the key from dashscope.api_key instead.
        if "api_key" not in str(exc):
            raise
        return OmniRealtimeConversation(**shared_kwargs)  # type: ignore[misc]
|
||||
|
||||
|
||||
def _close_dashscope_client(client: Any) -> None:
|
||||
finish_fn = getattr(client, "finish", None)
|
||||
if callable(finish_fn):
|
||||
try:
|
||||
finish_fn()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
close_fn = getattr(client, "close", None)
|
||||
if callable(close_fn):
|
||||
try:
|
||||
close_fn()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _configure_dashscope_session(
    *,
    client: Any,
    callback: _DashScopePreviewCallback,
    sample_rate: int,
    language: Optional[str],
) -> None:
    """Issue ``session.update`` against a connected DashScope realtime client.

    Tries progressively smaller keyword sets so the call works across SDK
    versions whose ``update_session`` signatures differ. Raises RuntimeError
    when every attempt fails or the server reports a session error.
    """
    update_fn = getattr(client, "update_session", None)
    if not callable(update_fn):
        raise RuntimeError("DashScope ASR SDK missing update_session method")

    # Prefer the SDK enum for the text modality; fall back to the raw string.
    text_modality: Any = "text"
    if MultiModality is not None and hasattr(MultiModality, "TEXT"):
        text_modality = MultiModality.TEXT

    transcription_params: Optional[Any] = None
    language_hint = _dashscope_language(language)
    if TranscriptionParams is not None:
        try:
            params_kwargs: Dict[str, Any] = {
                "sample_rate": sample_rate,
                "input_audio_format": "pcm",
            }
            if language_hint:
                params_kwargs["language"] = language_hint
            transcription_params = TranscriptionParams(**params_kwargs)
        except Exception:
            # Constructor signature varies per SDK version; proceed without it.
            transcription_params = None

    # Ordered from most specific to most permissive keyword set.
    update_attempts = [
        {
            "output_modalities": [text_modality],
            "enable_turn_detection": False,
            "enable_input_audio_transcription": True,
            "transcription_params": transcription_params,
        },
        {
            "output_modalities": [text_modality],
            "enable_turn_detection": False,
            "enable_input_audio_transcription": True,
        },
        {
            "output_modalities": [text_modality],
        },
    ]

    last_error: Optional[Exception] = None
    for params in update_attempts:
        if params.get("transcription_params") is None:
            # Never pass an explicit transcription_params=None to the SDK.
            params = {key: value for key, value in params.items() if key != "transcription_params"}
        try:
            update_fn(**params)
            callback.wait_for_session_ready()
            callback.raise_if_error()
            return
        except TypeError as exc:
            # SDK rejected a keyword; retry with the next, smaller attempt.
            last_error = exc
            continue
        except Exception as exc:
            last_error = exc
            continue

    raise RuntimeError(f"DashScope ASR session.update failed: {last_error}")
|
||||
|
||||
|
||||
def _load_wav_pcm16_mono(audio_bytes: bytes) -> Tuple[bytes, int]:
|
||||
try:
|
||||
with wave.open(io.BytesIO(audio_bytes), "rb") as wav_file:
|
||||
channel_count = wav_file.getnchannels()
|
||||
sample_width = wav_file.getsampwidth()
|
||||
sample_rate = wav_file.getframerate()
|
||||
compression = wav_file.getcomptype()
|
||||
pcm_frames = wav_file.readframes(wav_file.getnframes())
|
||||
except wave.Error as exc:
|
||||
raise RuntimeError("DashScope preview currently supports WAV audio. Record in browser or upload a .wav file.") from exc
|
||||
|
||||
if compression != "NONE":
|
||||
raise RuntimeError("DashScope preview requires uncompressed PCM WAV audio.")
|
||||
if sample_width != 2:
|
||||
raise RuntimeError("DashScope preview requires 16-bit PCM WAV audio.")
|
||||
if not pcm_frames:
|
||||
raise RuntimeError("Uploaded WAV file is empty")
|
||||
if channel_count <= 1:
|
||||
return pcm_frames, sample_rate
|
||||
|
||||
samples = array("h")
|
||||
samples.frombytes(pcm_frames)
|
||||
if sys.byteorder == "big":
|
||||
samples.byteswap()
|
||||
|
||||
mono_samples = array(
|
||||
"h",
|
||||
(
|
||||
int(sum(samples[index:index + channel_count]) / channel_count)
|
||||
for index in range(0, len(samples), channel_count)
|
||||
),
|
||||
)
|
||||
if sys.byteorder == "big":
|
||||
mono_samples.byteswap()
|
||||
return mono_samples.tobytes(), sample_rate
|
||||
|
||||
|
||||
def _probe_dashscope_asr_connection(*, api_key: str, base_url: str, model: str, language: Optional[str]) -> None:
    """Open and configure a DashScope realtime session without sending audio.

    Used by the /test endpoint as a pure connectivity check. Raises
    RuntimeError (SDK missing / handshake failure) or TimeoutError when the
    server is unresponsive.
    """
    if not DASHSCOPE_SDK_AVAILABLE:
        install_hint = f"`{sys.executable} -m pip install dashscope>=1.25.11`"
        import_detail = f"; import error: {DASHSCOPE_IMPORT_ERROR}" if DASHSCOPE_IMPORT_ERROR else ""
        raise RuntimeError(f"dashscope package not installed; install with {install_hint}{import_detail}")

    callback = _DashScopePreviewCallback()
    if dashscope is not None:
        # Older SDKs read the key from the module-level attribute.
        dashscope.api_key = api_key
    client = _create_dashscope_realtime_client(
        model=model,
        callback=callback,
        url=base_url,
        api_key=api_key,
    )

    try:
        client.connect()
        callback.wait_for_open()
        # 16 kHz is the canonical probe rate; no audio is actually streamed.
        _configure_dashscope_session(
            client=client,
            callback=callback,
            sample_rate=16000,
            language=language,
        )
    finally:
        _close_dashscope_client(client)
|
||||
|
||||
|
||||
def _transcribe_dashscope_preview(
    *,
    audio_bytes: bytes,
    api_key: str,
    base_url: str,
    model: str,
    language: Optional[str],
) -> Dict[str, Any]:
    """Run a one-shot transcription of a WAV upload via DashScope realtime ASR.

    Returns a dict with ``transcript``, ``language`` and ``confidence`` keys.
    Raises RuntimeError on SDK/protocol failures and TimeoutError when the
    server stalls.
    """
    if not DASHSCOPE_SDK_AVAILABLE:
        hint = f"`{sys.executable} -m pip install dashscope>=1.25.11`"
        detail = f"; import error: {DASHSCOPE_IMPORT_ERROR}" if DASHSCOPE_IMPORT_ERROR else ""
        raise RuntimeError(f"dashscope package not installed; install with {hint}{detail}")

    # Validate the upload and downmix it to 16-bit mono PCM frames.
    pcm_audio, sample_rate = _load_wav_pcm16_mono(audio_bytes)
    callback = _DashScopePreviewCallback()
    if dashscope is not None:
        # Older SDKs read the key from the module-level attribute.
        dashscope.api_key = api_key
    client = _create_dashscope_realtime_client(
        model=model,
        callback=callback,
        url=base_url,
        api_key=api_key,
    )

    try:
        client.connect()
        callback.wait_for_open()
        _configure_dashscope_session(
            client=client,
            callback=callback,
            sample_rate=sample_rate,
            language=language,
        )

        # Duck-type the streaming API: method names vary across SDK builds.
        append_fn = getattr(client, "append_audio", None)
        if not callable(append_fn):
            raise RuntimeError("DashScope ASR SDK missing append_audio method")
        commit_fn = getattr(client, "commit", None)
        if not callable(commit_fn):
            raise RuntimeError("DashScope ASR SDK missing commit method")

        # The realtime protocol expects base64-encoded PCM chunks; send the
        # whole clip at once, then commit to request a final transcript.
        append_fn(base64.b64encode(pcm_audio).decode("ascii"))
        commit_fn()
        callback.wait_for_done()
        callback.raise_if_error()
        return {
            "transcript": callback.read_text(),
            "language": _dashscope_language(language) or "Multi-lingual",
            "confidence": None,
        }
    finally:
        _close_dashscope_client(client)
|
||||
|
||||
|
||||
# ============ ASR Models CRUD ============
|
||||
@router.get("")
|
||||
def list_asr_models(
|
||||
@@ -132,6 +531,27 @@ def test_asr_model(
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
if _is_dashscope_vendor(model.vendor):
|
||||
effective_api_key = (model.api_key or "").strip() or os.getenv("DASHSCOPE_API_KEY", "").strip() or os.getenv("ASR_API_KEY", "").strip()
|
||||
if not effective_api_key:
|
||||
return ASRTestResponse(success=False, error=f"API key is required for ASR model: {model.name}")
|
||||
|
||||
base_url = (model.base_url or "").strip() or DASHSCOPE_DEFAULT_BASE_URL
|
||||
selected_model = (model.model_name or "").strip() or _default_asr_model(model.vendor)
|
||||
_probe_dashscope_asr_connection(
|
||||
api_key=effective_api_key,
|
||||
base_url=base_url,
|
||||
model=selected_model,
|
||||
language=model.language,
|
||||
)
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
return ASRTestResponse(
|
||||
success=True,
|
||||
language=model.language,
|
||||
latency_ms=latency_ms,
|
||||
message="DashScope realtime ASR connected",
|
||||
)
|
||||
|
||||
# 连接性测试优先,避免依赖真实音频输入
|
||||
headers = {"Authorization": f"Bearer {model.api_key}"}
|
||||
with httpx.Client(timeout=60.0) as client:
|
||||
@@ -246,7 +666,7 @@ async def preview_asr_model(
|
||||
api_key: Optional[str] = Form(None),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""预览 ASR:上传音频并调用 OpenAI-compatible /audio/transcriptions。"""
|
||||
"""预览 ASR:根据供应商调用 OpenAI-compatible 或 DashScope 实时识别。"""
|
||||
model = db.query(ASRModel).filter(ASRModel.id == id).first()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="ASR Model not found")
|
||||
@@ -264,18 +684,50 @@ async def preview_asr_model(
|
||||
raise HTTPException(status_code=400, detail="Uploaded audio file is empty")
|
||||
|
||||
effective_api_key = (api_key or "").strip() or (model.api_key or "").strip()
|
||||
if not effective_api_key and _is_openai_compatible_vendor(model.vendor):
|
||||
effective_api_key = os.getenv("SILICONFLOW_API_KEY", "").strip()
|
||||
if not effective_api_key:
|
||||
if _is_openai_compatible_vendor(model.vendor):
|
||||
effective_api_key = os.getenv("SILICONFLOW_API_KEY", "").strip()
|
||||
elif _is_dashscope_vendor(model.vendor):
|
||||
effective_api_key = os.getenv("DASHSCOPE_API_KEY", "").strip() or os.getenv("ASR_API_KEY", "").strip()
|
||||
if not effective_api_key:
|
||||
raise HTTPException(status_code=400, detail=f"API key is required for ASR model: {model.name}")
|
||||
|
||||
base_url = (model.base_url or "").strip().rstrip("/")
|
||||
if _is_dashscope_vendor(model.vendor) and not base_url:
|
||||
base_url = DASHSCOPE_DEFAULT_BASE_URL
|
||||
if not base_url:
|
||||
raise HTTPException(status_code=400, detail=f"Base URL is required for ASR model: {model.name}")
|
||||
|
||||
selected_model = (model.model_name or "").strip() or _default_asr_model(model.vendor)
|
||||
data = {"model": selected_model}
|
||||
effective_language = (language or "").strip() or None
|
||||
|
||||
start_time = time.time()
|
||||
if _is_dashscope_vendor(model.vendor):
|
||||
try:
|
||||
payload = await asyncio.to_thread(
|
||||
_transcribe_dashscope_preview,
|
||||
audio_bytes=audio_bytes,
|
||||
api_key=effective_api_key,
|
||||
base_url=base_url,
|
||||
model=selected_model,
|
||||
language=effective_language or model.language,
|
||||
)
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=502, detail=f"DashScope ASR request failed: {exc}") from exc
|
||||
|
||||
transcript = str(payload.get("transcript") or "")
|
||||
response_language = str(payload.get("language") or effective_language or model.language)
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
return ASRTestResponse(
|
||||
success=bool(transcript),
|
||||
transcript=transcript,
|
||||
language=response_language,
|
||||
confidence=None,
|
||||
latency_ms=latency_ms,
|
||||
message=None if transcript else "No transcript in response",
|
||||
)
|
||||
|
||||
data = {"model": selected_model}
|
||||
if effective_language:
|
||||
data["language"] = effective_language
|
||||
if model.hotwords:
|
||||
@@ -284,7 +736,6 @@ async def preview_asr_model(
|
||||
headers = {"Authorization": f"Bearer {effective_api_key}"}
|
||||
files = {"file": (filename, audio_bytes, content_type)}
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
with httpx.Client(timeout=90.0) as client:
|
||||
response = client.post(
|
||||
|
||||
@@ -191,6 +191,7 @@ class ASRModelCreate(ASRModelBase):
|
||||
|
||||
class ASRModelUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
vendor: Optional[str] = None
|
||||
language: Optional[str] = None
|
||||
base_url: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
|
||||
@@ -34,6 +34,7 @@ SEED_LLM_IDS = {
|
||||
# Stable seed IDs for the default ASR model rows (one short_id per seed key).
SEED_ASR_IDS = {
    "sensevoice_small": short_id("asr"),
    "telespeech_asr": short_id("asr"),
    "dashscope_realtime": short_id("asr"),
}
|
||||
|
||||
SEED_ASSISTANT_IDS = {
|
||||
@@ -408,6 +409,20 @@ def init_default_asr_models():
|
||||
enable_normalization=True,
|
||||
enabled=True,
|
||||
),
|
||||
ASRModel(
|
||||
id=SEED_ASR_IDS["dashscope_realtime"],
|
||||
user_id=1,
|
||||
name="DashScope Realtime ASR",
|
||||
vendor="DashScope",
|
||||
language="Multi-lingual",
|
||||
base_url=DASHSCOPE_REALTIME_URL,
|
||||
api_key="YOUR_API_KEY",
|
||||
model_name="qwen3-asr-flash-realtime",
|
||||
hotwords=[],
|
||||
enable_punctuation=True,
|
||||
enable_normalization=True,
|
||||
enabled=True,
|
||||
),
|
||||
]
|
||||
seed_if_empty(db, ASRModel, asr_models, "✅ 默认ASR模型已初始化")
|
||||
|
||||
|
||||
@@ -1,8 +1,21 @@
|
||||
"""Tests for ASR Model API endpoints"""
|
||||
import io
|
||||
import wave
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
|
||||
def _make_wav_bytes(sample_rate: int = 16000) -> bytes:
|
||||
with io.BytesIO() as buffer:
|
||||
with wave.open(buffer, "wb") as wav_file:
|
||||
wav_file.setnchannels(1)
|
||||
wav_file.setsampwidth(2)
|
||||
wav_file.setframerate(sample_rate)
|
||||
wav_file.writeframes(b"\x00\x00" * sample_rate)
|
||||
return buffer.getvalue()
|
||||
|
||||
|
||||
class TestASRModelAPI:
|
||||
"""Test cases for ASR Model endpoints"""
|
||||
|
||||
@@ -75,6 +88,24 @@ class TestASRModelAPI:
|
||||
assert data["language"] == "en"
|
||||
assert data["enable_punctuation"] == False
|
||||
|
||||
def test_update_asr_model_vendor(self, client, sample_asr_model_data):
    """Test updating ASR vendor metadata."""
    # Create a row first so there is something to update.
    create_response = client.post("/api/asr", json=sample_asr_model_data)
    model_id = create_response.json()["id"]

    # Switch the row over to the DashScope realtime vendor.
    response = client.put(
        f"/api/asr/{model_id}",
        json={
            "vendor": "DashScope",
            "model_name": "qwen3-asr-flash-realtime",
            "base_url": "wss://dashscope.aliyuncs.com/api-ws/v1/realtime",
        },
    )
    assert response.status_code == 200
    data = response.json()
    # The update must round-trip the new vendor metadata.
    assert data["vendor"] == "DashScope"
    assert data["model_name"] == "qwen3-asr-flash-realtime"
|
||||
|
||||
def test_delete_asr_model(self, client, sample_asr_model_data):
|
||||
"""Test deleting an ASR model"""
|
||||
# Create first
|
||||
@@ -234,6 +265,28 @@ class TestASRModelAPI:
|
||||
response = client.post(f"/api/asr/{model_id}/test")
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_test_asr_model_dashscope(self, client, sample_asr_model_data, monkeypatch):
    """Test DashScope ASR connectivity probe."""
    from app.routers import asr as asr_router

    # Reconfigure the fixture payload as a DashScope realtime model.
    sample_asr_model_data["vendor"] = "DashScope"
    sample_asr_model_data["base_url"] = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
    sample_asr_model_data["model_name"] = "qwen3-asr-flash-realtime"
    create_response = client.post("/api/asr", json=sample_asr_model_data)
    model_id = create_response.json()["id"]

    # Stub the websocket probe; only verify the router forwards the right args.
    def fake_probe(**kwargs):
        assert kwargs["api_key"] == sample_asr_model_data["api_key"]
        assert kwargs["model"] == "qwen3-asr-flash-realtime"

    monkeypatch.setattr(asr_router, "_probe_dashscope_asr_connection", fake_probe)

    response = client.post(f"/api/asr/{model_id}/test")
    assert response.status_code == 200
    data = response.json()
    assert data["success"] is True
    assert data["message"] == "DashScope realtime ASR connected"
|
||||
|
||||
@patch('httpx.Client')
|
||||
def test_test_asr_model_failure(self, mock_client_class, client, sample_asr_model_data):
|
||||
"""Test testing an ASR model with failed connection"""
|
||||
@@ -274,7 +327,7 @@ class TestASRModelAPI:
|
||||
|
||||
def test_different_asr_vendors(self, client):
|
||||
"""Test creating ASR models with different vendors"""
|
||||
vendors = ["SiliconFlow", "OpenAI", "Azure"]
|
||||
vendors = ["SiliconFlow", "OpenAI", "Azure", "DashScope"]
|
||||
for vendor in vendors:
|
||||
data = {
|
||||
"id": f"asr-vendor-{vendor.lower()}",
|
||||
@@ -345,3 +398,33 @@ class TestASRModelAPI:
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert "Only audio files are supported" in response.text
|
||||
|
||||
def test_preview_asr_model_dashscope(self, client, sample_asr_model_data, monkeypatch):
    """Test ASR preview endpoint with DashScope realtime helper."""
    from app.routers import asr as asr_router

    # Reconfigure the fixture payload as a DashScope realtime model.
    sample_asr_model_data["vendor"] = "DashScope"
    sample_asr_model_data["base_url"] = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
    sample_asr_model_data["model_name"] = "qwen3-asr-flash-realtime"
    create_response = client.post("/api/asr", json=sample_asr_model_data)
    model_id = create_response.json()["id"]

    # Stub the realtime transcription; verify args and return a fixed result.
    def fake_preview(**kwargs):
        assert kwargs["base_url"] == sample_asr_model_data["base_url"]
        assert kwargs["model"] == sample_asr_model_data["model_name"]
        return {
            "transcript": "你好,这是实时识别",
            "language": "zh",
            "confidence": None,
        }

    monkeypatch.setattr(asr_router, "_transcribe_dashscope_preview", fake_preview)

    # Upload a real (silent) WAV so the router's audio validation passes.
    response = client.post(
        f"/api/asr/{model_id}/preview",
        files={"file": ("sample.wav", _make_wav_bytes(), "audio/wav")},
    )
    assert response.status_code == 200
    payload = response.json()
    assert payload["success"] is True
    assert payload["transcript"] == "你好,这是实时识别"
|
||||
|
||||
@@ -82,6 +82,16 @@ const convertRecordedBlobToWav = async (blob: Blob): Promise<File> => {
|
||||
}
|
||||
};
|
||||
|
||||
const OPENAI_COMPATIBLE_DEFAULT_MODEL = 'FunAudioLLM/SenseVoiceSmall';
|
||||
const OPENAI_COMPATIBLE_DEFAULT_BASE_URL = 'https://api.siliconflow.cn/v1';
|
||||
const DASHSCOPE_DEFAULT_MODEL = 'qwen3-asr-flash-realtime';
|
||||
const DASHSCOPE_DEFAULT_BASE_URL = 'wss://dashscope.aliyuncs.com/api-ws/v1/realtime';
|
||||
|
||||
type ASRVendor = 'OpenAI Compatible' | 'DashScope';
|
||||
|
||||
const normalizeVendor = (value?: string): ASRVendor =>
|
||||
String(value || '').trim().toLowerCase() === 'dashscope' ? 'DashScope' : 'OpenAI Compatible';
|
||||
|
||||
export const ASRLibraryPage: React.FC = () => {
|
||||
const [models, setModels] = useState<ASRModel[]>([]);
|
||||
const [searchTerm, setSearchTerm] = useState('');
|
||||
@@ -271,10 +281,10 @@ const ASRModelModal: React.FC<{
|
||||
initialModel?: ASRModel;
|
||||
}> = ({ isOpen, onClose, onSubmit, initialModel }) => {
|
||||
const [name, setName] = useState('');
|
||||
const [vendor, setVendor] = useState('OpenAI Compatible');
|
||||
const [vendor, setVendor] = useState<ASRVendor>('OpenAI Compatible');
|
||||
const [language, setLanguage] = useState('zh');
|
||||
const [modelName, setModelName] = useState('FunAudioLLM/SenseVoiceSmall');
|
||||
const [baseUrl, setBaseUrl] = useState('https://api.siliconflow.cn/v1');
|
||||
const [modelName, setModelName] = useState(OPENAI_COMPATIBLE_DEFAULT_MODEL);
|
||||
const [baseUrl, setBaseUrl] = useState(OPENAI_COMPATIBLE_DEFAULT_BASE_URL);
|
||||
const [apiKey, setApiKey] = useState('');
|
||||
const [hotwords, setHotwords] = useState('');
|
||||
const [enablePunctuation, setEnablePunctuation] = useState(true);
|
||||
@@ -282,14 +292,40 @@ const ASRModelModal: React.FC<{
|
||||
const [enabled, setEnabled] = useState(true);
|
||||
const [saving, setSaving] = useState(false);
|
||||
|
||||
const getDefaultModel = (nextVendor: ASRVendor): string =>
|
||||
nextVendor === 'DashScope' ? DASHSCOPE_DEFAULT_MODEL : OPENAI_COMPATIBLE_DEFAULT_MODEL;
|
||||
|
||||
const getDefaultBaseUrl = (nextVendor: ASRVendor): string =>
|
||||
nextVendor === 'DashScope' ? DASHSCOPE_DEFAULT_BASE_URL : OPENAI_COMPATIBLE_DEFAULT_BASE_URL;
|
||||
|
||||
const handleVendorChange = (nextVendor: ASRVendor) => {
|
||||
const previousVendor = vendor;
|
||||
setVendor(nextVendor);
|
||||
|
||||
const previousDefaultModel = getDefaultModel(previousVendor);
|
||||
const nextDefaultModel = getDefaultModel(nextVendor);
|
||||
const trimmedModelName = modelName.trim();
|
||||
if (!trimmedModelName || trimmedModelName === previousDefaultModel) {
|
||||
setModelName(nextDefaultModel);
|
||||
}
|
||||
|
||||
const previousDefaultBaseUrl = getDefaultBaseUrl(previousVendor);
|
||||
const nextDefaultBaseUrl = getDefaultBaseUrl(nextVendor);
|
||||
const trimmedBaseUrl = baseUrl.trim();
|
||||
if (!trimmedBaseUrl || trimmedBaseUrl === previousDefaultBaseUrl) {
|
||||
setBaseUrl(nextDefaultBaseUrl);
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (!isOpen) return;
|
||||
if (initialModel) {
|
||||
const nextVendor = normalizeVendor(initialModel.vendor);
|
||||
setName(initialModel.name || '');
|
||||
setVendor(initialModel.vendor || 'OpenAI Compatible');
|
||||
setVendor(nextVendor);
|
||||
setLanguage(initialModel.language || 'zh');
|
||||
setModelName(initialModel.modelName || 'FunAudioLLM/SenseVoiceSmall');
|
||||
setBaseUrl(initialModel.baseUrl || 'https://api.siliconflow.cn/v1');
|
||||
setModelName(initialModel.modelName || getDefaultModel(nextVendor));
|
||||
setBaseUrl(initialModel.baseUrl || getDefaultBaseUrl(nextVendor));
|
||||
setApiKey(initialModel.apiKey || '');
|
||||
setHotwords(toHotwordsValue(initialModel.hotwords));
|
||||
setEnablePunctuation(initialModel.enablePunctuation ?? true);
|
||||
@@ -301,8 +337,8 @@ const ASRModelModal: React.FC<{
|
||||
setName('');
|
||||
setVendor('OpenAI Compatible');
|
||||
setLanguage('zh');
|
||||
setModelName('FunAudioLLM/SenseVoiceSmall');
|
||||
setBaseUrl('https://api.siliconflow.cn/v1');
|
||||
setModelName(OPENAI_COMPATIBLE_DEFAULT_MODEL);
|
||||
setBaseUrl(OPENAI_COMPATIBLE_DEFAULT_BASE_URL);
|
||||
setApiKey('');
|
||||
setHotwords('');
|
||||
setEnablePunctuation(true);
|
||||
@@ -368,9 +404,10 @@ const ASRModelModal: React.FC<{
|
||||
<label className="text-[10px] font-black text-muted-foreground uppercase tracking-widest block">接口类型</label>
|
||||
<Select
|
||||
value={vendor}
|
||||
onChange={(e) => setVendor(e.target.value)}
|
||||
onChange={(e) => handleVendorChange(e.target.value as ASRVendor)}
|
||||
>
|
||||
<option value="OpenAI Compatible">OpenAI Compatible</option>
|
||||
<option value="DashScope">DashScope</option>
|
||||
</Select>
|
||||
</div>
|
||||
<div className="space-y-1.5">
|
||||
@@ -388,13 +425,22 @@ const ASRModelModal: React.FC<{
|
||||
|
||||
<div className="space-y-1.5">
|
||||
<label className="text-[10px] font-black text-muted-foreground uppercase tracking-widest block">Model Name</label>
|
||||
<Input value={modelName} onChange={(e) => setModelName(e.target.value)} placeholder="FunAudioLLM/SenseVoiceSmall" />
|
||||
<Input
|
||||
value={modelName}
|
||||
onChange={(e) => setModelName(e.target.value)}
|
||||
placeholder={vendor === 'DashScope' ? DASHSCOPE_DEFAULT_MODEL : OPENAI_COMPATIBLE_DEFAULT_MODEL}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 gap-4">
|
||||
<div className="space-y-1.5">
|
||||
<label className="text-[10px] font-black text-muted-foreground uppercase tracking-widest block flex items-center"><Server className="w-3 h-3 mr-1.5" />Base URL</label>
|
||||
<Input value={baseUrl} onChange={(e) => setBaseUrl(e.target.value)} placeholder="https://api.siliconflow.cn/v1" className="font-mono text-xs" />
|
||||
<Input
|
||||
value={baseUrl}
|
||||
onChange={(e) => setBaseUrl(e.target.value)}
|
||||
placeholder={vendor === 'DashScope' ? DASHSCOPE_DEFAULT_BASE_URL : OPENAI_COMPATIBLE_DEFAULT_BASE_URL}
|
||||
className="font-mono text-xs"
|
||||
/>
|
||||
</div>
|
||||
<div className="space-y-1.5">
|
||||
<label className="text-[10px] font-black text-muted-foreground uppercase tracking-widest block flex items-center"><Key className="w-3 h-3 mr-1.5" />API Key</label>
|
||||
@@ -405,6 +451,11 @@ const ASRModelModal: React.FC<{
|
||||
<div className="space-y-1.5">
|
||||
<label className="text-[10px] font-black text-muted-foreground uppercase tracking-widest block">热词 (comma separated)</label>
|
||||
<Input value={hotwords} onChange={(e) => setHotwords(e.target.value)} placeholder="品牌名, 人名, 专有词" />
|
||||
{vendor === 'DashScope' && (
|
||||
<p className="text-[11px] text-muted-foreground">
|
||||
DashScope 走实时 WebSocket ASR。预览建议使用浏览器录音或上传 WAV 文件。
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="grid grid-cols-1 md:grid-cols-3 gap-2">
|
||||
|
||||
@@ -3,6 +3,8 @@ import { apiRequest, getApiBaseUrl } from './apiClient';
|
||||
|
||||
type AnyRecord = Record<string, any>;
|
||||
const DEFAULT_LIST_LIMIT = 1000;
|
||||
const OPENAI_COMPATIBLE_DEFAULT_ASR_BASE_URL = 'https://api.siliconflow.cn/v1';
|
||||
const DASHSCOPE_DEFAULT_ASR_BASE_URL = 'wss://dashscope.aliyuncs.com/api-ws/v1/realtime';
|
||||
const TOOL_ID_ALIASES: Record<string, string> = {
|
||||
voice_message_prompt: 'voice_msg_prompt',
|
||||
};
|
||||
@@ -129,7 +131,16 @@ const mapVoice = (raw: AnyRecord): Voice => ({
|
||||
const mapASRModel = (raw: AnyRecord): ASRModel => ({
|
||||
id: String(readField(raw, ['id'], '')),
|
||||
name: readField(raw, ['name'], ''),
|
||||
vendor: readField(raw, ['vendor'], 'OpenAI Compatible'),
|
||||
vendor: (() => {
|
||||
const vendor = String(readField(raw, ['vendor'], '')).trim().toLowerCase();
|
||||
if (vendor === 'dashscope') {
|
||||
return 'DashScope';
|
||||
}
|
||||
if (vendor === 'siliconflow' || vendor === 'openai compatible' || vendor === 'openai-compatible' || vendor === '硅基流动') {
|
||||
return 'OpenAI Compatible';
|
||||
}
|
||||
return String(readField(raw, ['vendor'], 'OpenAI Compatible')) || 'OpenAI Compatible';
|
||||
})(),
|
||||
language: readField(raw, ['language'], 'zh'),
|
||||
baseUrl: readField(raw, ['baseUrl', 'base_url'], ''),
|
||||
apiKey: readField(raw, ['apiKey', 'api_key'], ''),
|
||||
@@ -457,11 +468,16 @@ export const fetchASRModels = async (): Promise<ASRModel[]> => {
|
||||
};
|
||||
|
||||
export const createASRModel = async (data: Partial<ASRModel>): Promise<ASRModel> => {
|
||||
const vendor = data.vendor || 'OpenAI Compatible';
|
||||
const normalizedVendor = String(vendor).trim().toLowerCase();
|
||||
const defaultBaseUrl = normalizedVendor === 'dashscope'
|
||||
? DASHSCOPE_DEFAULT_ASR_BASE_URL
|
||||
: OPENAI_COMPATIBLE_DEFAULT_ASR_BASE_URL;
|
||||
const payload = {
|
||||
name: data.name || 'New ASR Model',
|
||||
vendor: data.vendor || 'OpenAI Compatible',
|
||||
vendor,
|
||||
language: data.language || 'zh',
|
||||
base_url: data.baseUrl || 'https://api.siliconflow.cn/v1',
|
||||
base_url: data.baseUrl || defaultBaseUrl,
|
||||
api_key: data.apiKey || '',
|
||||
model_name: data.modelName || undefined,
|
||||
hotwords: data.hotwords || [],
|
||||
|
||||
Reference in New Issue
Block a user