diff --git a/api/app/routers/asr.py b/api/app/routers/asr.py index b167802..07596a6 100644 --- a/api/app/routers/asr.py +++ b/api/app/routers/asr.py @@ -1,6 +1,14 @@ +import asyncio +import base64 +import io +import json import os +import sys +import threading import time -from typing import List, Optional +import wave +from array import array +from typing import Any, Dict, List, Optional, Tuple import httpx from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile @@ -17,6 +25,32 @@ from ..schemas import ( router = APIRouter(prefix="/asr", tags=["ASR Models"]) OPENAI_COMPATIBLE_DEFAULT_ASR_MODEL = "FunAudioLLM/SenseVoiceSmall" +DASHSCOPE_DEFAULT_ASR_MODEL = "qwen3-asr-flash-realtime" +DASHSCOPE_DEFAULT_BASE_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime" + +try: + import dashscope + from dashscope.audio.qwen_omni import MultiModality, OmniRealtimeCallback, OmniRealtimeConversation + + try: + from dashscope.audio.qwen_omni import TranscriptionParams + except ImportError: + from dashscope.audio.qwen_omni.omni_realtime import TranscriptionParams + + DASHSCOPE_SDK_AVAILABLE = True + DASHSCOPE_IMPORT_ERROR = "" +except Exception as exc: + dashscope = None # type: ignore[assignment] + MultiModality = None # type: ignore[assignment] + OmniRealtimeConversation = None # type: ignore[assignment] + TranscriptionParams = None # type: ignore[assignment] + DASHSCOPE_SDK_AVAILABLE = False + DASHSCOPE_IMPORT_ERROR = f"{type(exc).__name__}: {exc}" + + class OmniRealtimeCallback: # type: ignore[no-redef] + """Fallback callback base when DashScope SDK is unavailable.""" + + pass def _is_openai_compatible_vendor(vendor: str) -> bool: @@ -29,12 +63,377 @@ def _is_openai_compatible_vendor(vendor: str) -> bool: } +def _is_dashscope_vendor(vendor: str) -> bool: + return (vendor or "").strip().lower() == "dashscope" + + def _default_asr_model(vendor: str) -> str: if _is_openai_compatible_vendor(vendor): return OPENAI_COMPATIBLE_DEFAULT_ASR_MODEL + if _is_dashscope_vendor(vendor): + return DASHSCOPE_DEFAULT_ASR_MODEL return "whisper-1" +def _dashscope_language(language: Optional[str]) -> Optional[str]: + normalized = (language or "").strip().lower() + if not normalized or normalized in {"multi-lingual", "multilingual", "multi_lingual", "auto"}: + return None + if normalized.startswith("zh"): + return "zh" + if normalized.startswith("en"): + return "en" + return normalized + + +class _DashScopePreviewCallback(OmniRealtimeCallback): + """Collect DashScope ASR websocket events for preview/test flows.""" + + def __init__(self) -> None: + super().__init__() + self._open_event = threading.Event() + self._session_ready_event = threading.Event() + self._done_event = threading.Event() + self._lock = threading.Lock() + self._final_text = "" + self._last_interim_text = "" + self._error_message: Optional[str] = None + + def on_open(self) -> None: + self._open_event.set() + + def on_close(self, code: int, reason: str) -> None: + if self._done_event.is_set(): + return + self._error_message = f"DashScope websocket closed unexpectedly: {code} {reason}" + self._done_event.set() + self._session_ready_event.set() + + def on_error(self, message: Any) -> None: + self._error_message = str(message) + self._done_event.set() + self._session_ready_event.set() + + def on_event(self, response: Any) -> None: + payload = _coerce_dashscope_event(response) + event_type = str(payload.get("type") or "").strip() + if not event_type: + return + + if event_type in {"session.created", "session.updated"}: + self._session_ready_event.set() + return + + if event_type == "error" or event_type.endswith(".failed"): + self._error_message = _format_dashscope_error_event(payload) + self._done_event.set() + self._session_ready_event.set() + return + + if event_type == "conversation.item.input_audio_transcription.text": + interim_text = _extract_dashscope_text(payload, keys=("stash", "text", "transcript")) + if interim_text: + with self._lock: + self._last_interim_text = interim_text + return + + if event_type == "conversation.item.input_audio_transcription.completed": + final_text = _extract_dashscope_text(payload, keys=("transcript", "text", "stash")) + with self._lock: + if final_text: + self._final_text = final_text + self._done_event.set() + return + + if event_type in {"response.done", "session.finished"}: + self._done_event.set() + + def wait_for_open(self, timeout: float = 10.0) -> None: + if not self._open_event.wait(timeout): + raise TimeoutError("DashScope websocket open timeout") + + def wait_for_session_ready(self, timeout: float = 6.0) -> bool: + return self._session_ready_event.wait(timeout) + + def wait_for_done(self, timeout: float = 20.0) -> None: + if not self._done_event.wait(timeout): + raise TimeoutError("DashScope transcription timeout") + + def raise_if_error(self) -> None: + if self._error_message: + raise RuntimeError(self._error_message) + + def read_text(self) -> str: + with self._lock: + return self._final_text or self._last_interim_text + + +def _coerce_dashscope_event(response: Any) -> Dict[str, Any]: + if isinstance(response, dict): + return response + if isinstance(response, str): + try: + parsed = json.loads(response) + if isinstance(parsed, dict): + return parsed + except json.JSONDecodeError: + pass + return {"type": "raw", "message": str(response)} + + +def _format_dashscope_error_event(payload: Dict[str, Any]) -> str: + error = payload.get("error") + if isinstance(error, dict): + code = str(error.get("code") or "").strip() + message = str(error.get("message") or "").strip() + if code and message: + return f"{code}: {message}" + return message or str(error) + return str(error or "DashScope realtime ASR error") + + +def _extract_dashscope_text(payload: Dict[str, Any], *, keys: Tuple[str, ...]) -> str: + for key in keys: + value = payload.get(key) + if isinstance(value, str) and value.strip(): + return value.strip() + if isinstance(value, dict): + nested = _extract_dashscope_text(value, keys=keys) + if nested: + return nested + + for value in payload.values(): + if isinstance(value, dict): + nested = _extract_dashscope_text(value, keys=keys) + if nested: + return nested + return "" + + +def _create_dashscope_realtime_client( + *, + model: str, + callback: _DashScopePreviewCallback, + url: str, + api_key: str, +) -> Any: + if OmniRealtimeConversation is None: + raise RuntimeError("DashScope SDK unavailable") + + init_kwargs = { + "model": model, + "callback": callback, + "url": url, + } + try: + return OmniRealtimeConversation(api_key=api_key, **init_kwargs) # type: ignore[misc] + except TypeError as exc: + if "api_key" not in str(exc): + raise + return OmniRealtimeConversation(**init_kwargs) # type: ignore[misc] + + +def _close_dashscope_client(client: Any) -> None: + finish_fn = getattr(client, "finish", None) + if callable(finish_fn): + try: + finish_fn() + except Exception: + pass + + close_fn = getattr(client, "close", None) + if callable(close_fn): + try: + close_fn() + except Exception: + pass + + +def _configure_dashscope_session( + *, + client: Any, + callback: _DashScopePreviewCallback, + sample_rate: int, + language: Optional[str], +) -> None: + update_fn = getattr(client, "update_session", None) + if not callable(update_fn): + raise RuntimeError("DashScope ASR SDK missing update_session method") + + text_modality: Any = "text" + if MultiModality is not None and hasattr(MultiModality, "TEXT"): + text_modality = MultiModality.TEXT + + transcription_params: Optional[Any] = None + language_hint = _dashscope_language(language) + if TranscriptionParams is not None: + try: + params_kwargs: Dict[str, Any] = { + "sample_rate": sample_rate, + "input_audio_format": "pcm", + } + if language_hint: + params_kwargs["language"] = language_hint + transcription_params = TranscriptionParams(**params_kwargs) + except Exception: + transcription_params = None + + update_attempts = [ + { + "output_modalities": [text_modality], + "enable_turn_detection": False, + "enable_input_audio_transcription": True, + "transcription_params": transcription_params, + }, + { + "output_modalities": [text_modality], + "enable_turn_detection": False, + "enable_input_audio_transcription": True, + }, + { + "output_modalities": [text_modality], + }, + ] + + last_error: Optional[Exception] = None + for params in update_attempts: + if params.get("transcription_params") is None: + params = {key: value for key, value in params.items() if key != "transcription_params"} + try: + update_fn(**params) + callback.wait_for_session_ready() + callback.raise_if_error() + return + except TypeError as exc: + last_error = exc + continue + except Exception as exc: + last_error = exc + continue + + raise RuntimeError(f"DashScope ASR session.update failed: {last_error}") + + +def _load_wav_pcm16_mono(audio_bytes: bytes) -> Tuple[bytes, int]: + try: + with wave.open(io.BytesIO(audio_bytes), "rb") as wav_file: + channel_count = wav_file.getnchannels() + sample_width = wav_file.getsampwidth() + sample_rate = wav_file.getframerate() + compression = wav_file.getcomptype() + pcm_frames = wav_file.readframes(wav_file.getnframes()) + except wave.Error as exc: + raise RuntimeError("DashScope preview currently supports WAV audio. Record in browser or upload a .wav file.") from exc + + if compression != "NONE": + raise RuntimeError("DashScope preview requires uncompressed PCM WAV audio.") + if sample_width != 2: + raise RuntimeError("DashScope preview requires 16-bit PCM WAV audio.") + if not pcm_frames: + raise RuntimeError("Uploaded WAV file is empty") + if channel_count <= 1: + return pcm_frames, sample_rate + + samples = array("h") + samples.frombytes(pcm_frames) + if sys.byteorder == "big": + samples.byteswap() + + mono_samples = array( + "h", + ( + int(sum(samples[index:index + channel_count]) / channel_count) + for index in range(0, len(samples), channel_count) + ), + ) + if sys.byteorder == "big": + mono_samples.byteswap() + return mono_samples.tobytes(), sample_rate + + +def _probe_dashscope_asr_connection(*, api_key: str, base_url: str, model: str, language: Optional[str]) -> None: + if not DASHSCOPE_SDK_AVAILABLE: + hint = f"`{sys.executable} -m pip install dashscope>=1.25.11`" + detail = f"; import error: {DASHSCOPE_IMPORT_ERROR}" if DASHSCOPE_IMPORT_ERROR else "" + raise RuntimeError(f"dashscope package not installed; install with {hint}{detail}") + + callback = _DashScopePreviewCallback() + if dashscope is not None: + dashscope.api_key = api_key + client = _create_dashscope_realtime_client( + model=model, + callback=callback, + url=base_url, + api_key=api_key, + ) + + try: + client.connect() + callback.wait_for_open() + _configure_dashscope_session( + client=client, + callback=callback, + sample_rate=16000, + language=language, + ) + finally: + _close_dashscope_client(client) + + +def _transcribe_dashscope_preview( + *, + audio_bytes: bytes, + api_key: str, + base_url: str, + model: str, + language: Optional[str], +) -> Dict[str, Any]: + if not DASHSCOPE_SDK_AVAILABLE: + hint = f"`{sys.executable} -m pip install dashscope>=1.25.11`" + detail = f"; import error: {DASHSCOPE_IMPORT_ERROR}" if DASHSCOPE_IMPORT_ERROR else "" + raise RuntimeError(f"dashscope package not installed; install with {hint}{detail}") + + pcm_audio, sample_rate = _load_wav_pcm16_mono(audio_bytes) + callback = _DashScopePreviewCallback() + if dashscope is not None: + dashscope.api_key = api_key + client = _create_dashscope_realtime_client( + model=model, + callback=callback, + url=base_url, + api_key=api_key, + ) + + try: + client.connect() + callback.wait_for_open() + _configure_dashscope_session( + client=client, + callback=callback, + sample_rate=sample_rate, + language=language, + ) + + append_fn = getattr(client, "append_audio", None) + if not callable(append_fn): + raise RuntimeError("DashScope ASR SDK missing append_audio method") + commit_fn = getattr(client, "commit", None) + if not callable(commit_fn): + raise RuntimeError("DashScope ASR SDK missing commit method") + + append_fn(base64.b64encode(pcm_audio).decode("ascii")) + commit_fn() + callback.wait_for_done() + callback.raise_if_error() + return { + "transcript": callback.read_text(), + "language": _dashscope_language(language) or "Multi-lingual", + "confidence": None, + } + finally: + _close_dashscope_client(client) + + # ============ ASR Models CRUD ============ @router.get("") def list_asr_models( @@ -132,6 +531,27 @@ def test_asr_model( start_time = time.time() try: + if _is_dashscope_vendor(model.vendor): + effective_api_key = (model.api_key or "").strip() or os.getenv("DASHSCOPE_API_KEY", "").strip() or os.getenv("ASR_API_KEY", "").strip() + if not effective_api_key: + return ASRTestResponse(success=False, error=f"API key is required for ASR model: {model.name}") + + base_url = (model.base_url or "").strip() or DASHSCOPE_DEFAULT_BASE_URL + selected_model = (model.model_name or "").strip() or _default_asr_model(model.vendor) + _probe_dashscope_asr_connection( + api_key=effective_api_key, + base_url=base_url, + model=selected_model, + language=model.language, + ) + latency_ms = int((time.time() - start_time) * 1000) + return ASRTestResponse( + success=True, + language=model.language, + latency_ms=latency_ms, + message="DashScope realtime ASR connected", + ) + # 连接性测试优先,避免依赖真实音频输入 headers = {"Authorization": f"Bearer {model.api_key}"} with httpx.Client(timeout=60.0) as client: @@ -246,7 +666,7 @@ async def preview_asr_model( api_key: Optional[str] = Form(None), db: Session = Depends(get_db), ): - """预览 ASR:上传音频并调用 OpenAI-compatible /audio/transcriptions。""" + """预览 ASR:根据供应商调用 OpenAI-compatible 或 DashScope 实时识别。""" model = db.query(ASRModel).filter(ASRModel.id == id).first() if not model: raise HTTPException(status_code=404, detail="ASR Model not found") @@ -264,18 +684,50 @@ async def preview_asr_model( raise HTTPException(status_code=400, detail="Uploaded audio file is empty") effective_api_key = (api_key or "").strip() or (model.api_key or "").strip() - if not effective_api_key and _is_openai_compatible_vendor(model.vendor): - effective_api_key = os.getenv("SILICONFLOW_API_KEY", "").strip() + if not effective_api_key: + if _is_openai_compatible_vendor(model.vendor): + effective_api_key = os.getenv("SILICONFLOW_API_KEY", "").strip() + elif _is_dashscope_vendor(model.vendor): + effective_api_key = os.getenv("DASHSCOPE_API_KEY", "").strip() or os.getenv("ASR_API_KEY", "").strip() if not effective_api_key: raise HTTPException(status_code=400, detail=f"API key is required for ASR model: {model.name}") base_url = (model.base_url or "").strip().rstrip("/") + if _is_dashscope_vendor(model.vendor) and not base_url: + base_url = DASHSCOPE_DEFAULT_BASE_URL if not base_url: raise HTTPException(status_code=400, detail=f"Base URL is required for ASR model: {model.name}") selected_model = (model.model_name or "").strip() or _default_asr_model(model.vendor) - data = {"model": selected_model} effective_language = (language or "").strip() or None + + start_time = time.time() + if _is_dashscope_vendor(model.vendor): + try: + payload = await asyncio.to_thread( + _transcribe_dashscope_preview, + audio_bytes=audio_bytes, + api_key=effective_api_key, + base_url=base_url, + model=selected_model, + language=effective_language or model.language, + ) + except Exception as exc: + raise HTTPException(status_code=502, detail=f"DashScope ASR request failed: {exc}") from exc + + transcript = str(payload.get("transcript") or "") + response_language = str(payload.get("language") or effective_language or model.language) + latency_ms = int((time.time() - start_time) * 1000) + return ASRTestResponse( + success=bool(transcript), + transcript=transcript, + language=response_language, + confidence=None, + latency_ms=latency_ms, + message=None if transcript else "No transcript in response", + ) + + data = {"model": selected_model} if effective_language: data["language"] = effective_language if model.hotwords: @@ -284,7 +736,6 @@ async def preview_asr_model( headers = {"Authorization": f"Bearer {effective_api_key}"} files = {"file": (filename, audio_bytes, content_type)} - start_time = time.time() try: with httpx.Client(timeout=90.0) as client: response = client.post( diff --git a/api/app/schemas.py b/api/app/schemas.py index f0ad0c3..cbce453 100644 --- a/api/app/schemas.py +++ b/api/app/schemas.py @@ -191,6 +191,7 @@ class ASRModelCreate(ASRModelBase): class ASRModelUpdate(BaseModel): name: Optional[str] = None + vendor: Optional[str] = None language: Optional[str] = None base_url: Optional[str] = None api_key: Optional[str] = None diff --git a/api/init_db.py b/api/init_db.py index e3373f6..162eb99 100644 --- a/api/init_db.py +++ b/api/init_db.py @@ -34,6 +34,7 @@ SEED_LLM_IDS = { SEED_ASR_IDS = { "sensevoice_small": short_id("asr"), "telespeech_asr": short_id("asr"), + "dashscope_realtime": short_id("asr"), } SEED_ASSISTANT_IDS = { @@ -408,6 +409,20 @@ def init_default_asr_models(): enable_normalization=True, enabled=True, ), + ASRModel( + id=SEED_ASR_IDS["dashscope_realtime"], + user_id=1, + name="DashScope Realtime ASR", + vendor="DashScope", + language="Multi-lingual", + base_url=DASHSCOPE_REALTIME_URL, + api_key="YOUR_API_KEY", + model_name="qwen3-asr-flash-realtime", + hotwords=[], + enable_punctuation=True, + enable_normalization=True, + enabled=True, + ), ] seed_if_empty(db, ASRModel, asr_models, "✅ 默认ASR模型已初始化") diff --git a/api/tests/test_asr.py b/api/tests/test_asr.py index 209116c..1cd3c01 100644 --- a/api/tests/test_asr.py +++ b/api/tests/test_asr.py @@ -1,8 +1,21 @@ """Tests for ASR Model API endpoints""" +import io +import wave + import pytest from unittest.mock import patch, MagicMock +def _make_wav_bytes(sample_rate: int = 16000) -> bytes: + with io.BytesIO() as buffer: + with wave.open(buffer, "wb") as wav_file: + wav_file.setnchannels(1) + wav_file.setsampwidth(2) + wav_file.setframerate(sample_rate) + wav_file.writeframes(b"\x00\x00" * sample_rate) + return buffer.getvalue() + + class TestASRModelAPI: """Test cases for ASR Model endpoints""" @@ -75,6 +88,24 @@ class TestASRModelAPI: assert data["language"] == "en" assert data["enable_punctuation"] == False + def test_update_asr_model_vendor(self, client, sample_asr_model_data): + """Test updating ASR vendor metadata.""" + create_response = client.post("/api/asr", json=sample_asr_model_data) + model_id = create_response.json()["id"] + + response = client.put( + f"/api/asr/{model_id}", + json={ + "vendor": "DashScope", + "model_name": "qwen3-asr-flash-realtime", + "base_url": "wss://dashscope.aliyuncs.com/api-ws/v1/realtime", + }, + ) + assert response.status_code == 200 + data = response.json() + assert data["vendor"] == "DashScope" + assert data["model_name"] == "qwen3-asr-flash-realtime" + def test_delete_asr_model(self, client, sample_asr_model_data): """Test deleting an ASR model""" # Create first @@ -234,6 +265,28 @@ class TestASRModelAPI: response = client.post(f"/api/asr/{model_id}/test") assert response.status_code == 200 + def test_test_asr_model_dashscope(self, client, sample_asr_model_data, monkeypatch): + """Test DashScope ASR connectivity probe.""" + from app.routers import asr as asr_router + + sample_asr_model_data["vendor"] = "DashScope" + sample_asr_model_data["base_url"] = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime" + sample_asr_model_data["model_name"] = "qwen3-asr-flash-realtime" + create_response = client.post("/api/asr", json=sample_asr_model_data) + model_id = create_response.json()["id"] + + def fake_probe(**kwargs): + assert kwargs["api_key"] == sample_asr_model_data["api_key"] + assert kwargs["model"] == "qwen3-asr-flash-realtime" + + monkeypatch.setattr(asr_router, "_probe_dashscope_asr_connection", fake_probe) + + response = client.post(f"/api/asr/{model_id}/test") + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["message"] == "DashScope realtime ASR connected" + @patch('httpx.Client') def test_test_asr_model_failure(self, mock_client_class, client, sample_asr_model_data): """Test testing an ASR model with failed connection""" @@ -274,7 +327,7 @@ class TestASRModelAPI: def test_different_asr_vendors(self, client): """Test creating ASR models with different vendors""" - vendors = ["SiliconFlow", "OpenAI", "Azure"] + vendors = ["SiliconFlow", "OpenAI", "Azure", "DashScope"] for vendor in vendors: data = { "id": f"asr-vendor-{vendor.lower()}", @@ -345,3 +398,33 @@ class TestASRModelAPI: ) assert response.status_code == 400 assert "Only audio files are supported" in response.text + + def test_preview_asr_model_dashscope(self, client, sample_asr_model_data, monkeypatch): + """Test ASR preview endpoint with DashScope realtime helper.""" + from app.routers import asr as asr_router + + sample_asr_model_data["vendor"] = "DashScope" + sample_asr_model_data["base_url"] = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime" + sample_asr_model_data["model_name"] = "qwen3-asr-flash-realtime" + create_response = client.post("/api/asr", json=sample_asr_model_data) + model_id = create_response.json()["id"] + + def fake_preview(**kwargs): + assert kwargs["base_url"] == sample_asr_model_data["base_url"] + assert kwargs["model"] == sample_asr_model_data["model_name"] + return { + "transcript": "你好,这是实时识别", + "language": "zh", + "confidence": None, + } + + monkeypatch.setattr(asr_router, "_transcribe_dashscope_preview", fake_preview) + + response = client.post( + f"/api/asr/{model_id}/preview", + files={"file": ("sample.wav", _make_wav_bytes(), "audio/wav")}, + ) + assert response.status_code == 200 + payload = response.json() + assert payload["success"] is True + assert payload["transcript"] == "你好,这是实时识别" diff --git a/web/pages/ASRLibrary.tsx b/web/pages/ASRLibrary.tsx index 03f54ae..52a08f8 100644 --- a/web/pages/ASRLibrary.tsx +++ b/web/pages/ASRLibrary.tsx @@ -82,6 +82,16 @@ const convertRecordedBlobToWav = async (blob: Blob): Promise => { } }; +const OPENAI_COMPATIBLE_DEFAULT_MODEL = 'FunAudioLLM/SenseVoiceSmall'; +const OPENAI_COMPATIBLE_DEFAULT_BASE_URL = 'https://api.siliconflow.cn/v1'; +const DASHSCOPE_DEFAULT_MODEL = 'qwen3-asr-flash-realtime'; +const DASHSCOPE_DEFAULT_BASE_URL = 'wss://dashscope.aliyuncs.com/api-ws/v1/realtime'; + +type ASRVendor = 'OpenAI Compatible' | 'DashScope'; + +const normalizeVendor = (value?: string): ASRVendor => + String(value || '').trim().toLowerCase() === 'dashscope' ? 'DashScope' : 'OpenAI Compatible'; + export const ASRLibraryPage: React.FC = () => { const [models, setModels] = useState([]); const [searchTerm, setSearchTerm] = useState(''); @@ -271,10 +281,10 @@ const ASRModelModal: React.FC<{ initialModel?: ASRModel; }> = ({ isOpen, onClose, onSubmit, initialModel }) => { const [name, setName] = useState(''); - const [vendor, setVendor] = useState('OpenAI Compatible'); + const [vendor, setVendor] = useState('OpenAI Compatible'); const [language, setLanguage] = useState('zh'); - const [modelName, setModelName] = useState('FunAudioLLM/SenseVoiceSmall'); - const [baseUrl, setBaseUrl] = useState('https://api.siliconflow.cn/v1'); + const [modelName, setModelName] = useState(OPENAI_COMPATIBLE_DEFAULT_MODEL); + const [baseUrl, setBaseUrl] = useState(OPENAI_COMPATIBLE_DEFAULT_BASE_URL); const [apiKey, setApiKey] = useState(''); const [hotwords, setHotwords] = useState(''); const [enablePunctuation, setEnablePunctuation] = useState(true); @@ -282,14 +292,40 @@ const ASRModelModal: React.FC<{ const [enabled, setEnabled] = useState(true); const [saving, setSaving] = useState(false); + const getDefaultModel = (nextVendor: ASRVendor): string => + nextVendor === 'DashScope' ? DASHSCOPE_DEFAULT_MODEL : OPENAI_COMPATIBLE_DEFAULT_MODEL; + + const getDefaultBaseUrl = (nextVendor: ASRVendor): string => + nextVendor === 'DashScope' ? DASHSCOPE_DEFAULT_BASE_URL : OPENAI_COMPATIBLE_DEFAULT_BASE_URL; + + const handleVendorChange = (nextVendor: ASRVendor) => { + const previousVendor = vendor; + setVendor(nextVendor); + + const previousDefaultModel = getDefaultModel(previousVendor); + const nextDefaultModel = getDefaultModel(nextVendor); + const trimmedModelName = modelName.trim(); + if (!trimmedModelName || trimmedModelName === previousDefaultModel) { + setModelName(nextDefaultModel); + } + + const previousDefaultBaseUrl = getDefaultBaseUrl(previousVendor); + const nextDefaultBaseUrl = getDefaultBaseUrl(nextVendor); + const trimmedBaseUrl = baseUrl.trim(); + if (!trimmedBaseUrl || trimmedBaseUrl === previousDefaultBaseUrl) { + setBaseUrl(nextDefaultBaseUrl); + } + }; + useEffect(() => { if (!isOpen) return; if (initialModel) { + const nextVendor = normalizeVendor(initialModel.vendor); setName(initialModel.name || ''); - setVendor(initialModel.vendor || 'OpenAI Compatible'); + setVendor(nextVendor); setLanguage(initialModel.language || 'zh'); - setModelName(initialModel.modelName || 'FunAudioLLM/SenseVoiceSmall'); - setBaseUrl(initialModel.baseUrl || 'https://api.siliconflow.cn/v1'); + setModelName(initialModel.modelName || getDefaultModel(nextVendor)); + setBaseUrl(initialModel.baseUrl || getDefaultBaseUrl(nextVendor)); setApiKey(initialModel.apiKey || ''); setHotwords(toHotwordsValue(initialModel.hotwords)); setEnablePunctuation(initialModel.enablePunctuation ?? true); @@ -301,8 +337,8 @@ const ASRModelModal: React.FC<{ setName(''); setVendor('OpenAI Compatible'); setLanguage('zh'); - setModelName('FunAudioLLM/SenseVoiceSmall'); - setBaseUrl('https://api.siliconflow.cn/v1'); + setModelName(OPENAI_COMPATIBLE_DEFAULT_MODEL); + setBaseUrl(OPENAI_COMPATIBLE_DEFAULT_BASE_URL); setApiKey(''); setHotwords(''); setEnablePunctuation(true); @@ -368,9 +404,10 @@ const ASRModelModal: React.FC<{
@@ -388,13 +425,22 @@ const ASRModelModal: React.FC<{
- setModelName(e.target.value)} placeholder="FunAudioLLM/SenseVoiceSmall" /> + setModelName(e.target.value)} + placeholder={vendor === 'DashScope' ? DASHSCOPE_DEFAULT_MODEL : OPENAI_COMPATIBLE_DEFAULT_MODEL} + />
- setBaseUrl(e.target.value)} placeholder="https://api.siliconflow.cn/v1" className="font-mono text-xs" /> + setBaseUrl(e.target.value)} + placeholder={vendor === 'DashScope' ? DASHSCOPE_DEFAULT_BASE_URL : OPENAI_COMPATIBLE_DEFAULT_BASE_URL} + className="font-mono text-xs" + />
@@ -405,6 +451,11 @@ const ASRModelModal: React.FC<{
setHotwords(e.target.value)} placeholder="品牌名, 人名, 专有词" /> + {vendor === 'DashScope' && ( +

+ DashScope 走实时 WebSocket ASR。预览建议使用浏览器录音或上传 WAV 文件。 +

+ )}
diff --git a/web/services/backendApi.ts b/web/services/backendApi.ts index f4b6caa..50aec96 100644 --- a/web/services/backendApi.ts +++ b/web/services/backendApi.ts @@ -3,6 +3,8 @@ import { apiRequest, getApiBaseUrl } from './apiClient'; type AnyRecord = Record; const DEFAULT_LIST_LIMIT = 1000; +const OPENAI_COMPATIBLE_DEFAULT_ASR_BASE_URL = 'https://api.siliconflow.cn/v1'; +const DASHSCOPE_DEFAULT_ASR_BASE_URL = 'wss://dashscope.aliyuncs.com/api-ws/v1/realtime'; const TOOL_ID_ALIASES: Record = { voice_message_prompt: 'voice_msg_prompt', }; @@ -129,7 +131,16 @@ const mapVoice = (raw: AnyRecord): Voice => ({ const mapASRModel = (raw: AnyRecord): ASRModel => ({ id: String(readField(raw, ['id'], '')), name: readField(raw, ['name'], ''), - vendor: readField(raw, ['vendor'], 'OpenAI Compatible'), + vendor: (() => { + const vendor = String(readField(raw, ['vendor'], '')).trim().toLowerCase(); + if (vendor === 'dashscope') { + return 'DashScope'; + } + if (vendor === 'siliconflow' || vendor === 'openai compatible' || vendor === 'openai-compatible' || vendor === '硅基流动') { + return 'OpenAI Compatible'; + } + return String(readField(raw, ['vendor'], 'OpenAI Compatible')) || 'OpenAI Compatible'; + })(), language: readField(raw, ['language'], 'zh'), baseUrl: readField(raw, ['baseUrl', 'base_url'], ''), apiKey: readField(raw, ['apiKey', 'api_key'], ''), @@ -457,11 +468,16 @@ export const fetchASRModels = async (): Promise => { }; export const createASRModel = async (data: Partial): Promise => { + const vendor = data.vendor || 'OpenAI Compatible'; + const normalizedVendor = String(vendor).trim().toLowerCase(); + const defaultBaseUrl = normalizedVendor === 'dashscope' + ? DASHSCOPE_DEFAULT_ASR_BASE_URL + : OPENAI_COMPATIBLE_DEFAULT_ASR_BASE_URL; const payload = { name: data.name || 'New ASR Model', - vendor: data.vendor || 'OpenAI Compatible', + vendor, language: data.language || 'zh', - base_url: data.baseUrl || 'https://api.siliconflow.cn/v1', + base_url: data.baseUrl || defaultBaseUrl, api_key: data.apiKey || '', model_name: data.modelName || undefined, hotwords: data.hotwords || [],