Add DashScope ASR model support and enhance related components

- Introduced DashScope as a new ASR model in the database initialization.
- Updated ASRModel schema to include vendor information.
- Enhanced ASR router to support DashScope-specific functionality, including connection testing and preview capabilities.
- Modified frontend components to accommodate DashScope as a selectable vendor with appropriate default settings.
- Added tests to validate DashScope ASR model creation, updates, and connectivity.
- Updated backend API to handle DashScope-specific base URLs and vendor normalization.
This commit is contained in:
Xin Wang
2026-03-09 07:37:00 +08:00
parent e41d34fe23
commit bfe165daae
6 changed files with 638 additions and 21 deletions

View File

@@ -1,6 +1,14 @@
import asyncio
import base64
import io
import json
import os
import sys
import threading
import time
from typing import List, Optional
import wave
from array import array
from typing import Any, Dict, List, Optional, Tuple
import httpx
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
@@ -17,6 +25,32 @@ from ..schemas import (
router = APIRouter(prefix="/asr", tags=["ASR Models"])
OPENAI_COMPATIBLE_DEFAULT_ASR_MODEL = "FunAudioLLM/SenseVoiceSmall"
DASHSCOPE_DEFAULT_ASR_MODEL = "qwen3-asr-flash-realtime"
DASHSCOPE_DEFAULT_BASE_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
try:
import dashscope
from dashscope.audio.qwen_omni import MultiModality, OmniRealtimeCallback, OmniRealtimeConversation
try:
from dashscope.audio.qwen_omni import TranscriptionParams
except ImportError:
from dashscope.audio.qwen_omni.omni_realtime import TranscriptionParams
DASHSCOPE_SDK_AVAILABLE = True
DASHSCOPE_IMPORT_ERROR = ""
except Exception as exc:
dashscope = None # type: ignore[assignment]
MultiModality = None # type: ignore[assignment]
OmniRealtimeConversation = None # type: ignore[assignment]
TranscriptionParams = None # type: ignore[assignment]
DASHSCOPE_SDK_AVAILABLE = False
DASHSCOPE_IMPORT_ERROR = f"{type(exc).__name__}: {exc}"
class OmniRealtimeCallback: # type: ignore[no-redef]
"""Fallback callback base when DashScope SDK is unavailable."""
pass
def _is_openai_compatible_vendor(vendor: str) -> bool:
@@ -29,12 +63,377 @@ def _is_openai_compatible_vendor(vendor: str) -> bool:
}
def _is_dashscope_vendor(vendor: str) -> bool:
return (vendor or "").strip().lower() == "dashscope"
def _default_asr_model(vendor: str) -> str:
    """Return the default model identifier for the given vendor.

    OpenAI-compatible vendors default to the SiliconFlow SenseVoice model,
    DashScope defaults to its realtime ASR model, and any other vendor falls
    back to OpenAI's "whisper-1".
    """
    if _is_openai_compatible_vendor(vendor):
        return OPENAI_COMPATIBLE_DEFAULT_ASR_MODEL
    if _is_dashscope_vendor(vendor):
        return DASHSCOPE_DEFAULT_ASR_MODEL
    return "whisper-1"
def _dashscope_language(language: Optional[str]) -> Optional[str]:
normalized = (language or "").strip().lower()
if not normalized or normalized in {"multi-lingual", "multilingual", "multi_lingual", "auto"}:
return None
if normalized.startswith("zh"):
return "zh"
if normalized.startswith("en"):
return "en"
return normalized
class _DashScopePreviewCallback(OmniRealtimeCallback):
    """Collect DashScope ASR websocket events for preview/test flows.

    The DashScope SDK invokes these callbacks from its own websocket thread
    while the request thread blocks on the wait_* helpers. threading.Event
    objects bridge the two threads; a lock guards the transcript fields.
    """

    def __init__(self) -> None:
        super().__init__()
        self._open_event = threading.Event()           # websocket opened
        self._session_ready_event = threading.Event()  # session.created/updated seen
        self._done_event = threading.Event()           # transcription finished or failed
        self._lock = threading.Lock()                  # guards the transcript fields below
        self._final_text = ""                          # completed transcript, if any
        self._last_interim_text = ""                   # latest partial transcript (fallback)
        self._error_message: Optional[str] = None

    def on_open(self) -> None:
        # Connection established; unblock wait_for_open().
        self._open_event.set()

    def on_close(self, code: int, reason: str) -> None:
        if self._done_event.is_set():
            return
        # A close before "done" is abnormal: record the reason and release
        # every waiter so the request thread cannot hang on a dead socket.
        self._error_message = f"DashScope websocket closed unexpectedly: {code} {reason}"
        self._done_event.set()
        self._session_ready_event.set()

    def on_error(self, message: Any) -> None:
        # Record the error and release all waiters.
        self._error_message = str(message)
        self._done_event.set()
        self._session_ready_event.set()

    def on_event(self, response: Any) -> None:
        # Dispatch on the "type" field of each realtime event payload.
        payload = _coerce_dashscope_event(response)
        event_type = str(payload.get("type") or "").strip()
        if not event_type:
            return
        if event_type in {"session.created", "session.updated"}:
            self._session_ready_event.set()
            return
        if event_type == "error" or event_type.endswith(".failed"):
            self._error_message = _format_dashscope_error_event(payload)
            self._done_event.set()
            self._session_ready_event.set()
            return
        if event_type == "conversation.item.input_audio_transcription.text":
            # Interim (streaming) update; keep the latest text as a fallback
            # in case a "completed" event never arrives.
            interim_text = _extract_dashscope_text(payload, keys=("stash", "text", "transcript"))
            if interim_text:
                with self._lock:
                    self._last_interim_text = interim_text
            return
        if event_type == "conversation.item.input_audio_transcription.completed":
            final_text = _extract_dashscope_text(payload, keys=("transcript", "text", "stash"))
            with self._lock:
                if final_text:
                    self._final_text = final_text
            self._done_event.set()
            return
        if event_type in {"response.done", "session.finished"}:
            self._done_event.set()

    def wait_for_open(self, timeout: float = 10.0) -> None:
        # Block until the websocket opens; raise on timeout.
        if not self._open_event.wait(timeout):
            raise TimeoutError("DashScope websocket open timeout")

    def wait_for_session_ready(self, timeout: float = 6.0) -> bool:
        # Returns False on timeout instead of raising — callers retry
        # session.update with simpler parameters.
        return self._session_ready_event.wait(timeout)

    def wait_for_done(self, timeout: float = 20.0) -> None:
        # Block until the transcription completes or fails; raise on timeout.
        if not self._done_event.wait(timeout):
            raise TimeoutError("DashScope transcription timeout")

    def raise_if_error(self) -> None:
        # Surface any error recorded by the callback thread.
        if self._error_message:
            raise RuntimeError(self._error_message)

    def read_text(self) -> str:
        # Prefer the final transcript; fall back to the last interim text.
        with self._lock:
            return self._final_text or self._last_interim_text
def _coerce_dashscope_event(response: Any) -> Dict[str, Any]:
if isinstance(response, dict):
return response
if isinstance(response, str):
try:
parsed = json.loads(response)
if isinstance(parsed, dict):
return parsed
except json.JSONDecodeError:
pass
return {"type": "raw", "message": str(response)}
def _format_dashscope_error_event(payload: Dict[str, Any]) -> str:
error = payload.get("error")
if isinstance(error, dict):
code = str(error.get("code") or "").strip()
message = str(error.get("message") or "").strip()
if code and message:
return f"{code}: {message}"
return message or str(error)
return str(error or "DashScope realtime ASR error")
def _extract_dashscope_text(payload: Dict[str, Any], *, keys: Tuple[str, ...]) -> str:
for key in keys:
value = payload.get(key)
if isinstance(value, str) and value.strip():
return value.strip()
if isinstance(value, dict):
nested = _extract_dashscope_text(value, keys=keys)
if nested:
return nested
for value in payload.values():
if isinstance(value, dict):
nested = _extract_dashscope_text(value, keys=keys)
if nested:
return nested
return ""
def _create_dashscope_realtime_client(
    *,
    model: str,
    callback: _DashScopePreviewCallback,
    url: str,
    api_key: str,
) -> Any:
    """Instantiate OmniRealtimeConversation, tolerating SDKs without an api_key kwarg.

    Raises RuntimeError when the DashScope SDK is not importable.
    """
    if OmniRealtimeConversation is None:
        raise RuntimeError("DashScope SDK unavailable")
    common_kwargs: Dict[str, Any] = {
        "model": model,
        "callback": callback,
        "url": url,
    }
    try:
        return OmniRealtimeConversation(api_key=api_key, **common_kwargs)  # type: ignore[misc]
    except TypeError as exc:
        # Older SDK builds read the key from dashscope.api_key instead of a
        # constructor kwarg; only swallow the TypeError that mentions api_key.
        if "api_key" not in str(exc):
            raise
        return OmniRealtimeConversation(**common_kwargs)  # type: ignore[misc]
def _close_dashscope_client(client: Any) -> None:
finish_fn = getattr(client, "finish", None)
if callable(finish_fn):
try:
finish_fn()
except Exception:
pass
close_fn = getattr(client, "close", None)
if callable(close_fn):
try:
close_fn()
except Exception:
pass
def _configure_dashscope_session(
    *,
    client: Any,
    callback: _DashScopePreviewCallback,
    sample_rate: int,
    language: Optional[str],
) -> None:
    """Send session.update to DashScope, degrading gracefully across SDK versions.

    Tries progressively simpler parameter sets so SDKs that reject the newer
    transcription/turn-detection keywords still work. Returns on the first
    successful attempt; raises RuntimeError when every attempt fails.
    """
    update_fn = getattr(client, "update_session", None)
    if not callable(update_fn):
        raise RuntimeError("DashScope ASR SDK missing update_session method")
    # Prefer the SDK's modality enum when available; fall back to the string.
    text_modality: Any = "text"
    if MultiModality is not None and hasattr(MultiModality, "TEXT"):
        text_modality = MultiModality.TEXT
    transcription_params: Optional[Any] = None
    language_hint = _dashscope_language(language)
    if TranscriptionParams is not None:
        try:
            params_kwargs: Dict[str, Any] = {
                "sample_rate": sample_rate,
                "input_audio_format": "pcm",
            }
            if language_hint:
                params_kwargs["language"] = language_hint
            transcription_params = TranscriptionParams(**params_kwargs)
        except Exception:
            # Constructor signature varies across SDK versions; proceed
            # without explicit transcription params rather than failing.
            transcription_params = None
    # Ordered from most to least specific; later entries are compatibility
    # fallbacks for SDKs that reject the newer keyword arguments.
    update_attempts = [
        {
            "output_modalities": [text_modality],
            "enable_turn_detection": False,
            "enable_input_audio_transcription": True,
            "transcription_params": transcription_params,
        },
        {
            "output_modalities": [text_modality],
            "enable_turn_detection": False,
            "enable_input_audio_transcription": True,
        },
        {
            "output_modalities": [text_modality],
        },
    ]
    last_error: Optional[Exception] = None
    for params in update_attempts:
        if params.get("transcription_params") is None:
            # Never pass an explicit transcription_params=None to the SDK.
            params = {key: value for key, value in params.items() if key != "transcription_params"}
        try:
            update_fn(**params)
            callback.wait_for_session_ready()
            callback.raise_if_error()
            return
        except TypeError as exc:
            # SDK rejected a keyword; retry with the next, simpler set.
            last_error = exc
            continue
        except Exception as exc:
            last_error = exc
            continue
    raise RuntimeError(f"DashScope ASR session.update failed: {last_error}")
def _load_wav_pcm16_mono(audio_bytes: bytes) -> Tuple[bytes, int]:
try:
with wave.open(io.BytesIO(audio_bytes), "rb") as wav_file:
channel_count = wav_file.getnchannels()
sample_width = wav_file.getsampwidth()
sample_rate = wav_file.getframerate()
compression = wav_file.getcomptype()
pcm_frames = wav_file.readframes(wav_file.getnframes())
except wave.Error as exc:
raise RuntimeError("DashScope preview currently supports WAV audio. Record in browser or upload a .wav file.") from exc
if compression != "NONE":
raise RuntimeError("DashScope preview requires uncompressed PCM WAV audio.")
if sample_width != 2:
raise RuntimeError("DashScope preview requires 16-bit PCM WAV audio.")
if not pcm_frames:
raise RuntimeError("Uploaded WAV file is empty")
if channel_count <= 1:
return pcm_frames, sample_rate
samples = array("h")
samples.frombytes(pcm_frames)
if sys.byteorder == "big":
samples.byteswap()
mono_samples = array(
"h",
(
int(sum(samples[index:index + channel_count]) / channel_count)
for index in range(0, len(samples), channel_count)
),
)
if sys.byteorder == "big":
mono_samples.byteswap()
return mono_samples.tobytes(), sample_rate
def _probe_dashscope_asr_connection(*, api_key: str, base_url: str, model: str, language: Optional[str]) -> None:
    """Connectivity check: open a realtime session and configure it without sending audio.

    Raises RuntimeError when the SDK is missing or the session reports an
    error, and TimeoutError when the websocket never opens.
    """
    if not DASHSCOPE_SDK_AVAILABLE:
        hint = f"`{sys.executable} -m pip install dashscope>=1.25.11`"
        detail = f"; import error: {DASHSCOPE_IMPORT_ERROR}" if DASHSCOPE_IMPORT_ERROR else ""
        raise RuntimeError(f"dashscope package not installed; install with {hint}{detail}")
    callback = _DashScopePreviewCallback()
    if dashscope is not None:
        # Some SDK versions read the key from this module global rather than
        # accepting it as a constructor argument.
        dashscope.api_key = api_key
    client = _create_dashscope_realtime_client(
        model=model,
        callback=callback,
        url=base_url,
        api_key=api_key,
    )
    try:
        client.connect()
        callback.wait_for_open()
        # 16 kHz is a supported rate; no audio is sent for a pure probe.
        _configure_dashscope_session(
            client=client,
            callback=callback,
            sample_rate=16000,
            language=language,
        )
    finally:
        _close_dashscope_client(client)
def _transcribe_dashscope_preview(
    *,
    audio_bytes: bytes,
    api_key: str,
    base_url: str,
    model: str,
    language: Optional[str],
) -> Dict[str, Any]:
    """Run a one-shot transcription of a WAV payload over DashScope realtime ASR.

    Returns {"transcript": str, "language": str, "confidence": None}.
    Raises RuntimeError when the SDK is missing, the audio is unsupported,
    or the remote session reports an error; TimeoutError on stalls.
    """
    if not DASHSCOPE_SDK_AVAILABLE:
        hint = f"`{sys.executable} -m pip install dashscope>=1.25.11`"
        detail = f"; import error: {DASHSCOPE_IMPORT_ERROR}" if DASHSCOPE_IMPORT_ERROR else ""
        raise RuntimeError(f"dashscope package not installed; install with {hint}{detail}")
    # Decode/downmix up-front so audio-format errors surface before any I/O.
    pcm_audio, sample_rate = _load_wav_pcm16_mono(audio_bytes)
    callback = _DashScopePreviewCallback()
    if dashscope is not None:
        # Some SDK versions read the key from this module global.
        dashscope.api_key = api_key
    client = _create_dashscope_realtime_client(
        model=model,
        callback=callback,
        url=base_url,
        api_key=api_key,
    )
    try:
        client.connect()
        callback.wait_for_open()
        _configure_dashscope_session(
            client=client,
            callback=callback,
            sample_rate=sample_rate,
            language=language,
        )
        # Resolve SDK entry points dynamically to tolerate version drift.
        append_fn = getattr(client, "append_audio", None)
        if not callable(append_fn):
            raise RuntimeError("DashScope ASR SDK missing append_audio method")
        commit_fn = getattr(client, "commit", None)
        if not callable(commit_fn):
            raise RuntimeError("DashScope ASR SDK missing commit method")
        # The realtime API expects base64-encoded PCM audio chunks.
        append_fn(base64.b64encode(pcm_audio).decode("ascii"))
        commit_fn()
        callback.wait_for_done()
        callback.raise_if_error()
        return {
            "transcript": callback.read_text(),
            "language": _dashscope_language(language) or "Multi-lingual",
            "confidence": None,
        }
    finally:
        _close_dashscope_client(client)
# ============ ASR Models CRUD ============
@router.get("")
def list_asr_models(
@@ -132,6 +531,27 @@ def test_asr_model(
start_time = time.time()
try:
if _is_dashscope_vendor(model.vendor):
effective_api_key = (model.api_key or "").strip() or os.getenv("DASHSCOPE_API_KEY", "").strip() or os.getenv("ASR_API_KEY", "").strip()
if not effective_api_key:
return ASRTestResponse(success=False, error=f"API key is required for ASR model: {model.name}")
base_url = (model.base_url or "").strip() or DASHSCOPE_DEFAULT_BASE_URL
selected_model = (model.model_name or "").strip() or _default_asr_model(model.vendor)
_probe_dashscope_asr_connection(
api_key=effective_api_key,
base_url=base_url,
model=selected_model,
language=model.language,
)
latency_ms = int((time.time() - start_time) * 1000)
return ASRTestResponse(
success=True,
language=model.language,
latency_ms=latency_ms,
message="DashScope realtime ASR connected",
)
# 连接性测试优先,避免依赖真实音频输入
headers = {"Authorization": f"Bearer {model.api_key}"}
with httpx.Client(timeout=60.0) as client:
@@ -246,7 +666,7 @@ async def preview_asr_model(
api_key: Optional[str] = Form(None),
db: Session = Depends(get_db),
):
"""预览 ASR上传音频并调用 OpenAI-compatible /audio/transcriptions"""
"""预览 ASR根据供应商调用 OpenAI-compatible 或 DashScope 实时识别"""
model = db.query(ASRModel).filter(ASRModel.id == id).first()
if not model:
raise HTTPException(status_code=404, detail="ASR Model not found")
@@ -264,18 +684,50 @@ async def preview_asr_model(
raise HTTPException(status_code=400, detail="Uploaded audio file is empty")
effective_api_key = (api_key or "").strip() or (model.api_key or "").strip()
if not effective_api_key and _is_openai_compatible_vendor(model.vendor):
effective_api_key = os.getenv("SILICONFLOW_API_KEY", "").strip()
if not effective_api_key:
if _is_openai_compatible_vendor(model.vendor):
effective_api_key = os.getenv("SILICONFLOW_API_KEY", "").strip()
elif _is_dashscope_vendor(model.vendor):
effective_api_key = os.getenv("DASHSCOPE_API_KEY", "").strip() or os.getenv("ASR_API_KEY", "").strip()
if not effective_api_key:
raise HTTPException(status_code=400, detail=f"API key is required for ASR model: {model.name}")
base_url = (model.base_url or "").strip().rstrip("/")
if _is_dashscope_vendor(model.vendor) and not base_url:
base_url = DASHSCOPE_DEFAULT_BASE_URL
if not base_url:
raise HTTPException(status_code=400, detail=f"Base URL is required for ASR model: {model.name}")
selected_model = (model.model_name or "").strip() or _default_asr_model(model.vendor)
data = {"model": selected_model}
effective_language = (language or "").strip() or None
start_time = time.time()
if _is_dashscope_vendor(model.vendor):
try:
payload = await asyncio.to_thread(
_transcribe_dashscope_preview,
audio_bytes=audio_bytes,
api_key=effective_api_key,
base_url=base_url,
model=selected_model,
language=effective_language or model.language,
)
except Exception as exc:
raise HTTPException(status_code=502, detail=f"DashScope ASR request failed: {exc}") from exc
transcript = str(payload.get("transcript") or "")
response_language = str(payload.get("language") or effective_language or model.language)
latency_ms = int((time.time() - start_time) * 1000)
return ASRTestResponse(
success=bool(transcript),
transcript=transcript,
language=response_language,
confidence=None,
latency_ms=latency_ms,
message=None if transcript else "No transcript in response",
)
data = {"model": selected_model}
if effective_language:
data["language"] = effective_language
if model.hotwords:
@@ -284,7 +736,6 @@ async def preview_asr_model(
headers = {"Authorization": f"Bearer {effective_api_key}"}
files = {"file": (filename, audio_bytes, content_type)}
start_time = time.time()
try:
with httpx.Client(timeout=90.0) as client:
response = client.post(

View File

@@ -191,6 +191,7 @@ class ASRModelCreate(ASRModelBase):
class ASRModelUpdate(BaseModel):
name: Optional[str] = None
vendor: Optional[str] = None
language: Optional[str] = None
base_url: Optional[str] = None
api_key: Optional[str] = None

View File

@@ -34,6 +34,7 @@ SEED_LLM_IDS = {
SEED_ASR_IDS = {
"sensevoice_small": short_id("asr"),
"telespeech_asr": short_id("asr"),
"dashscope_realtime": short_id("asr"),
}
SEED_ASSISTANT_IDS = {
@@ -408,6 +409,20 @@ def init_default_asr_models():
enable_normalization=True,
enabled=True,
),
ASRModel(
id=SEED_ASR_IDS["dashscope_realtime"],
user_id=1,
name="DashScope Realtime ASR",
vendor="DashScope",
language="Multi-lingual",
base_url=DASHSCOPE_REALTIME_URL,
api_key="YOUR_API_KEY",
model_name="qwen3-asr-flash-realtime",
hotwords=[],
enable_punctuation=True,
enable_normalization=True,
enabled=True,
),
]
seed_if_empty(db, ASRModel, asr_models, "✅ 默认ASR模型已初始化")

View File

@@ -1,8 +1,21 @@
"""Tests for ASR Model API endpoints"""
import io
import wave
import pytest
from unittest.mock import patch, MagicMock
def _make_wav_bytes(sample_rate: int = 16000) -> bytes:
with io.BytesIO() as buffer:
with wave.open(buffer, "wb") as wav_file:
wav_file.setnchannels(1)
wav_file.setsampwidth(2)
wav_file.setframerate(sample_rate)
wav_file.writeframes(b"\x00\x00" * sample_rate)
return buffer.getvalue()
class TestASRModelAPI:
"""Test cases for ASR Model endpoints"""
@@ -75,6 +88,24 @@ class TestASRModelAPI:
assert data["language"] == "en"
assert data["enable_punctuation"] == False
def test_update_asr_model_vendor(self, client, sample_asr_model_data):
    """Test updating ASR vendor metadata."""
    # Create a baseline model, then switch it to the DashScope vendor.
    created = client.post("/api/asr", json=sample_asr_model_data).json()
    update_payload = {
        "vendor": "DashScope",
        "model_name": "qwen3-asr-flash-realtime",
        "base_url": "wss://dashscope.aliyuncs.com/api-ws/v1/realtime",
    }
    response = client.put(f"/api/asr/{created['id']}", json=update_payload)
    assert response.status_code == 200
    updated = response.json()
    assert updated["vendor"] == "DashScope"
    assert updated["model_name"] == "qwen3-asr-flash-realtime"
def test_delete_asr_model(self, client, sample_asr_model_data):
"""Test deleting an ASR model"""
# Create first
@@ -234,6 +265,28 @@ class TestASRModelAPI:
response = client.post(f"/api/asr/{model_id}/test")
assert response.status_code == 200
def test_test_asr_model_dashscope(self, client, sample_asr_model_data, monkeypatch):
    """Test DashScope ASR connectivity probe."""
    from app.routers import asr as asr_router

    sample_asr_model_data.update(
        {
            "vendor": "DashScope",
            "base_url": "wss://dashscope.aliyuncs.com/api-ws/v1/realtime",
            "model_name": "qwen3-asr-flash-realtime",
        }
    )
    model_id = client.post("/api/asr", json=sample_asr_model_data).json()["id"]

    def fake_probe(**kwargs):
        # The router must forward the stored key and model name to the probe.
        assert kwargs["api_key"] == sample_asr_model_data["api_key"]
        assert kwargs["model"] == "qwen3-asr-flash-realtime"

    monkeypatch.setattr(asr_router, "_probe_dashscope_asr_connection", fake_probe)
    response = client.post(f"/api/asr/{model_id}/test")
    assert response.status_code == 200
    data = response.json()
    assert data["success"] is True
    assert data["message"] == "DashScope realtime ASR connected"
@patch('httpx.Client')
def test_test_asr_model_failure(self, mock_client_class, client, sample_asr_model_data):
"""Test testing an ASR model with failed connection"""
@@ -274,7 +327,7 @@ class TestASRModelAPI:
def test_different_asr_vendors(self, client):
"""Test creating ASR models with different vendors"""
vendors = ["SiliconFlow", "OpenAI", "Azure"]
vendors = ["SiliconFlow", "OpenAI", "Azure", "DashScope"]
for vendor in vendors:
data = {
"id": f"asr-vendor-{vendor.lower()}",
@@ -345,3 +398,33 @@ class TestASRModelAPI:
)
assert response.status_code == 400
assert "Only audio files are supported" in response.text
def test_preview_asr_model_dashscope(self, client, sample_asr_model_data, monkeypatch):
    """Test ASR preview endpoint with DashScope realtime helper."""
    from app.routers import asr as asr_router

    sample_asr_model_data.update(
        {
            "vendor": "DashScope",
            "base_url": "wss://dashscope.aliyuncs.com/api-ws/v1/realtime",
            "model_name": "qwen3-asr-flash-realtime",
        }
    )
    model_id = client.post("/api/asr", json=sample_asr_model_data).json()["id"]

    def fake_preview(**kwargs):
        # The router must forward the stored base URL and model name.
        assert kwargs["base_url"] == sample_asr_model_data["base_url"]
        assert kwargs["model"] == sample_asr_model_data["model_name"]
        return {
            "transcript": "你好,这是实时识别",
            "language": "zh",
            "confidence": None,
        }

    monkeypatch.setattr(asr_router, "_transcribe_dashscope_preview", fake_preview)
    response = client.post(
        f"/api/asr/{model_id}/preview",
        files={"file": ("sample.wav", _make_wav_bytes(), "audio/wav")},
    )
    assert response.status_code == 200
    payload = response.json()
    assert payload["success"] is True
    assert payload["transcript"] == "你好,这是实时识别"

View File

@@ -82,6 +82,16 @@ const convertRecordedBlobToWav = async (blob: Blob): Promise<File> => {
}
};
const OPENAI_COMPATIBLE_DEFAULT_MODEL = 'FunAudioLLM/SenseVoiceSmall';
const OPENAI_COMPATIBLE_DEFAULT_BASE_URL = 'https://api.siliconflow.cn/v1';
const DASHSCOPE_DEFAULT_MODEL = 'qwen3-asr-flash-realtime';
const DASHSCOPE_DEFAULT_BASE_URL = 'wss://dashscope.aliyuncs.com/api-ws/v1/realtime';

// The UI distinguishes exactly two vendor families; anything that is not
// DashScope is treated as OpenAI-compatible.
type ASRVendor = 'OpenAI Compatible' | 'DashScope';

/** Normalize a free-form vendor string into one of the two supported vendors. */
const normalizeVendor = (value?: string): ASRVendor => {
  const canonical = String(value ?? '').trim().toLowerCase();
  return canonical === 'dashscope' ? 'DashScope' : 'OpenAI Compatible';
};
export const ASRLibraryPage: React.FC = () => {
const [models, setModels] = useState<ASRModel[]>([]);
const [searchTerm, setSearchTerm] = useState('');
@@ -271,10 +281,10 @@ const ASRModelModal: React.FC<{
initialModel?: ASRModel;
}> = ({ isOpen, onClose, onSubmit, initialModel }) => {
const [name, setName] = useState('');
const [vendor, setVendor] = useState('OpenAI Compatible');
const [vendor, setVendor] = useState<ASRVendor>('OpenAI Compatible');
const [language, setLanguage] = useState('zh');
const [modelName, setModelName] = useState('FunAudioLLM/SenseVoiceSmall');
const [baseUrl, setBaseUrl] = useState('https://api.siliconflow.cn/v1');
const [modelName, setModelName] = useState(OPENAI_COMPATIBLE_DEFAULT_MODEL);
const [baseUrl, setBaseUrl] = useState(OPENAI_COMPATIBLE_DEFAULT_BASE_URL);
const [apiKey, setApiKey] = useState('');
const [hotwords, setHotwords] = useState('');
const [enablePunctuation, setEnablePunctuation] = useState(true);
@@ -282,14 +292,40 @@ const ASRModelModal: React.FC<{
const [enabled, setEnabled] = useState(true);
const [saving, setSaving] = useState(false);
// Default model identifier for a vendor family.
const getDefaultModel = (nextVendor: ASRVendor): string =>
  nextVendor === 'DashScope' ? DASHSCOPE_DEFAULT_MODEL : OPENAI_COMPATIBLE_DEFAULT_MODEL;

// Default endpoint for a vendor family (DashScope uses a websocket URL).
const getDefaultBaseUrl = (nextVendor: ASRVendor): string =>
  nextVendor === 'DashScope' ? DASHSCOPE_DEFAULT_BASE_URL : OPENAI_COMPATIBLE_DEFAULT_BASE_URL;

// On vendor change, swap the model name and base URL to the new vendor's
// defaults — but only when the user has not entered a custom value, i.e.
// the field is empty or still equals the previous vendor's default.
const handleVendorChange = (nextVendor: ASRVendor) => {
  const previousVendor = vendor;
  setVendor(nextVendor);
  const previousDefaultModel = getDefaultModel(previousVendor);
  const nextDefaultModel = getDefaultModel(nextVendor);
  const trimmedModelName = modelName.trim();
  if (!trimmedModelName || trimmedModelName === previousDefaultModel) {
    setModelName(nextDefaultModel);
  }
  const previousDefaultBaseUrl = getDefaultBaseUrl(previousVendor);
  const nextDefaultBaseUrl = getDefaultBaseUrl(nextVendor);
  const trimmedBaseUrl = baseUrl.trim();
  if (!trimmedBaseUrl || trimmedBaseUrl === previousDefaultBaseUrl) {
    setBaseUrl(nextDefaultBaseUrl);
  }
};
useEffect(() => {
if (!isOpen) return;
if (initialModel) {
const nextVendor = normalizeVendor(initialModel.vendor);
setName(initialModel.name || '');
setVendor(initialModel.vendor || 'OpenAI Compatible');
setVendor(nextVendor);
setLanguage(initialModel.language || 'zh');
setModelName(initialModel.modelName || 'FunAudioLLM/SenseVoiceSmall');
setBaseUrl(initialModel.baseUrl || 'https://api.siliconflow.cn/v1');
setModelName(initialModel.modelName || getDefaultModel(nextVendor));
setBaseUrl(initialModel.baseUrl || getDefaultBaseUrl(nextVendor));
setApiKey(initialModel.apiKey || '');
setHotwords(toHotwordsValue(initialModel.hotwords));
setEnablePunctuation(initialModel.enablePunctuation ?? true);
@@ -301,8 +337,8 @@ const ASRModelModal: React.FC<{
setName('');
setVendor('OpenAI Compatible');
setLanguage('zh');
setModelName('FunAudioLLM/SenseVoiceSmall');
setBaseUrl('https://api.siliconflow.cn/v1');
setModelName(OPENAI_COMPATIBLE_DEFAULT_MODEL);
setBaseUrl(OPENAI_COMPATIBLE_DEFAULT_BASE_URL);
setApiKey('');
setHotwords('');
setEnablePunctuation(true);
@@ -368,9 +404,10 @@ const ASRModelModal: React.FC<{
<label className="text-[10px] font-black text-muted-foreground uppercase tracking-widest block"></label>
<Select
value={vendor}
onChange={(e) => setVendor(e.target.value)}
onChange={(e) => handleVendorChange(e.target.value as ASRVendor)}
>
<option value="OpenAI Compatible">OpenAI Compatible</option>
<option value="DashScope">DashScope</option>
</Select>
</div>
<div className="space-y-1.5">
@@ -388,13 +425,22 @@ const ASRModelModal: React.FC<{
<div className="space-y-1.5">
<label className="text-[10px] font-black text-muted-foreground uppercase tracking-widest block">Model Name</label>
<Input value={modelName} onChange={(e) => setModelName(e.target.value)} placeholder="FunAudioLLM/SenseVoiceSmall" />
<Input
value={modelName}
onChange={(e) => setModelName(e.target.value)}
placeholder={vendor === 'DashScope' ? DASHSCOPE_DEFAULT_MODEL : OPENAI_COMPATIBLE_DEFAULT_MODEL}
/>
</div>
<div className="grid grid-cols-1 md:grid-cols-2 gap-4">
<div className="space-y-1.5">
<label className="text-[10px] font-black text-muted-foreground uppercase tracking-widest block flex items-center"><Server className="w-3 h-3 mr-1.5" />Base URL</label>
<Input value={baseUrl} onChange={(e) => setBaseUrl(e.target.value)} placeholder="https://api.siliconflow.cn/v1" className="font-mono text-xs" />
<Input
value={baseUrl}
onChange={(e) => setBaseUrl(e.target.value)}
placeholder={vendor === 'DashScope' ? DASHSCOPE_DEFAULT_BASE_URL : OPENAI_COMPATIBLE_DEFAULT_BASE_URL}
className="font-mono text-xs"
/>
</div>
<div className="space-y-1.5">
<label className="text-[10px] font-black text-muted-foreground uppercase tracking-widest block flex items-center"><Key className="w-3 h-3 mr-1.5" />API Key</label>
@@ -405,6 +451,11 @@ const ASRModelModal: React.FC<{
<div className="space-y-1.5">
<label className="text-[10px] font-black text-muted-foreground uppercase tracking-widest block"> (comma separated)</label>
<Input value={hotwords} onChange={(e) => setHotwords(e.target.value)} placeholder="品牌名, 人名, 专有词" />
{vendor === 'DashScope' && (
<p className="text-[11px] text-muted-foreground">
DashScope WebSocket ASR使 WAV
</p>
)}
</div>
<div className="grid grid-cols-1 md:grid-cols-3 gap-2">

View File

@@ -3,6 +3,8 @@ import { apiRequest, getApiBaseUrl } from './apiClient';
type AnyRecord = Record<string, any>;
const DEFAULT_LIST_LIMIT = 1000;
const OPENAI_COMPATIBLE_DEFAULT_ASR_BASE_URL = 'https://api.siliconflow.cn/v1';
const DASHSCOPE_DEFAULT_ASR_BASE_URL = 'wss://dashscope.aliyuncs.com/api-ws/v1/realtime';
const TOOL_ID_ALIASES: Record<string, string> = {
voice_message_prompt: 'voice_msg_prompt',
};
@@ -129,7 +131,16 @@ const mapVoice = (raw: AnyRecord): Voice => ({
const mapASRModel = (raw: AnyRecord): ASRModel => ({
id: String(readField(raw, ['id'], '')),
name: readField(raw, ['name'], ''),
vendor: readField(raw, ['vendor'], 'OpenAI Compatible'),
vendor: (() => {
const vendor = String(readField(raw, ['vendor'], '')).trim().toLowerCase();
if (vendor === 'dashscope') {
return 'DashScope';
}
if (vendor === 'siliconflow' || vendor === 'openai compatible' || vendor === 'openai-compatible' || vendor === '硅基流动') {
return 'OpenAI Compatible';
}
return String(readField(raw, ['vendor'], 'OpenAI Compatible')) || 'OpenAI Compatible';
})(),
language: readField(raw, ['language'], 'zh'),
baseUrl: readField(raw, ['baseUrl', 'base_url'], ''),
apiKey: readField(raw, ['apiKey', 'api_key'], ''),
@@ -457,11 +468,16 @@ export const fetchASRModels = async (): Promise<ASRModel[]> => {
};
export const createASRModel = async (data: Partial<ASRModel>): Promise<ASRModel> => {
const vendor = data.vendor || 'OpenAI Compatible';
const normalizedVendor = String(vendor).trim().toLowerCase();
const defaultBaseUrl = normalizedVendor === 'dashscope'
? DASHSCOPE_DEFAULT_ASR_BASE_URL
: OPENAI_COMPATIBLE_DEFAULT_ASR_BASE_URL;
const payload = {
name: data.name || 'New ASR Model',
vendor: data.vendor || 'OpenAI Compatible',
vendor,
language: data.language || 'zh',
base_url: data.baseUrl || 'https://api.siliconflow.cn/v1',
base_url: data.baseUrl || defaultBaseUrl,
api_key: data.apiKey || '',
model_name: data.modelName || undefined,
hotwords: data.hotwords || [],