Use "OpenAI compatible" as the vendor name

This commit is contained in:
Xin Wang
2026-02-12 18:44:55 +08:00
parent 260ff621bf
commit ff3a03b1ad
23 changed files with 822 additions and 905 deletions

View File

@@ -16,16 +16,22 @@ from ..schemas import (
# Router for the ASR model endpoints; every route shares the "/asr" prefix.
router = APIRouter(prefix="/asr", tags=["ASR Models"])
# Fallback ASR model names returned by _default_asr_model; both constants
# currently hold the same model identifier.
SILICONFLOW_DEFAULT_ASR_MODEL = "FunAudioLLM/SenseVoiceSmall"
OPENAI_COMPATIBLE_DEFAULT_ASR_MODEL = "FunAudioLLM/SenseVoiceSmall"
def _is_siliconflow_vendor(vendor: str) -> bool:
return (vendor or "").strip().lower() in {"siliconflow", "硅基流动"}
def _is_openai_compatible_vendor(vendor: str) -> bool:
normalized = (vendor or "").strip().lower()
return normalized in {
"openai compatible",
"openai-compatible",
"siliconflow", # backward compatibility
"硅基流动", # backward compatibility
}
def _default_asr_model(vendor: str) -> str:
    """Pick the fallback ASR model name for *vendor*.

    SiliconFlow and OpenAI-compatible vendors get their respective default
    constants; every other vendor falls back to ``"whisper-1"``.
    """
    if _is_siliconflow_vendor(vendor):
        return SILICONFLOW_DEFAULT_ASR_MODEL
    return (
        OPENAI_COMPATIBLE_DEFAULT_ASR_MODEL
        if _is_openai_compatible_vendor(vendor)
        else "whisper-1"
    )
@@ -129,7 +135,7 @@ def test_asr_model(
# 连接性测试优先,避免依赖真实音频输入
headers = {"Authorization": f"Bearer {model.api_key}"}
with httpx.Client(timeout=60.0) as client:
if model.vendor.lower() in ["siliconflow", "paraformer"]:
if _is_openai_compatible_vendor(model.vendor) or model.vendor.lower() == "paraformer":
response = client.get(f"{model.base_url}/asr", headers=headers)
elif model.vendor.lower() == "openai":
response = client.get(f"{model.base_url}/audio/models", headers=headers)
@@ -258,7 +264,7 @@ async def preview_asr_model(
raise HTTPException(status_code=400, detail="Uploaded audio file is empty")
effective_api_key = (api_key or "").strip() or (model.api_key or "").strip()
if not effective_api_key and _is_siliconflow_vendor(model.vendor):
if not effective_api_key and _is_openai_compatible_vendor(model.vendor):
effective_api_key = os.getenv("SILICONFLOW_API_KEY", "").strip()
if not effective_api_key:
raise HTTPException(status_code=400, detail=f"API key is required for ASR model: {model.name}")

View File

@@ -13,8 +13,13 @@ from ..schemas import (
# Router for the assistant endpoints; every route shares the "/assistants" prefix.
router = APIRouter(prefix="/assistants", tags=["Assistants"])
def _is_siliconflow_vendor(vendor: Optional[str]) -> bool:
return (vendor or "").strip().lower() in {"siliconflow", "硅基流动"}
def _is_openai_compatible_vendor(vendor: Optional[str]) -> bool:
return (vendor or "").strip().lower() in {
"siliconflow",
"硅基流动",
"openai compatible",
"openai-compatible",
}
def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> dict:
@@ -47,11 +52,11 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> dict:
if assistant.asr_model_id:
asr = db.query(ASRModel).filter(ASRModel.id == assistant.asr_model_id).first()
if asr:
asr_provider = "siliconflow" if _is_siliconflow_vendor(asr.vendor) else "buffered"
asr_provider = "openai_compatible" if _is_openai_compatible_vendor(asr.vendor) else "buffered"
metadata["services"]["asr"] = {
"provider": asr_provider,
"model": asr.model_name or asr.name,
"apiKey": asr.api_key if asr_provider == "siliconflow" else None,
"apiKey": asr.api_key if asr_provider == "openai_compatible" else None,
}
else:
warnings.append(f"ASR model not found: {assistant.asr_model_id}")
@@ -61,12 +66,12 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> dict:
elif assistant.voice:
voice = db.query(Voice).filter(Voice.id == assistant.voice).first()
if voice:
tts_provider = "siliconflow" if _is_siliconflow_vendor(voice.vendor) else "edge"
tts_provider = "openai_compatible" if _is_openai_compatible_vendor(voice.vendor) else "edge"
metadata["services"]["tts"] = {
"enabled": True,
"provider": tts_provider,
"model": voice.model,
"apiKey": voice.api_key if tts_provider == "siliconflow" else None,
"apiKey": voice.api_key if tts_provider == "openai_compatible" else None,
"voice": voice.voice_key or voice.id,
"speed": assistant.speed or voice.speed,
}

View File

@@ -467,7 +467,13 @@ def _test_asr_model(db: Session, model_id: str, result: AutotestResult):
headers = {"Authorization": f"Bearer {model.api_key}"}
with httpx.Client(timeout=30.0) as client:
if model.vendor.lower() in ["siliconflow", "paraformer"]:
normalized_vendor = (model.vendor or "").strip().lower()
if normalized_vendor in [
"openai compatible",
"openai-compatible",
"siliconflow", # backward compatibility
"paraformer",
]:
response = client.get(
f"{model.base_url}/asr",
headers=headers

View File

@@ -13,20 +13,26 @@ from ..schemas import VoiceCreate, VoiceOut, VoicePreviewRequest, VoicePreviewRe
# Router for the voice endpoints; every route shares the "/voices" prefix.
router = APIRouter(prefix="/voices", tags=["Voices"])
# Model name used when a voice does not specify one; both constants
# currently hold the same model identifier.
SILICONFLOW_DEFAULT_MODEL = "FunAudioLLM/CosyVoice2-0.5B"
OPENAI_COMPATIBLE_DEFAULT_MODEL = "FunAudioLLM/CosyVoice2-0.5B"
def _is_siliconflow_vendor(vendor: str) -> bool:
return vendor.strip().lower() in {"siliconflow", "硅基流动"}
def _is_openai_compatible_vendor(vendor: str) -> bool:
normalized = (vendor or "").strip().lower()
return normalized in {
"openai compatible",
"openai-compatible",
"siliconflow", # backward compatibility
"硅基流动", # backward compatibility
}
# Return the default API base URL for a vendor, or None when no default is known.
def _default_base_url(vendor: str) -> Optional[str]:
# NOTE(review): the two conditions below look like the removed and added
# versions of the same line in a diff view — confirm which one applies.
if _is_siliconflow_vendor(vendor):
if _is_openai_compatible_vendor(vendor):
# The only default provided here is the SiliconFlow endpoint.
# NOTE(review): confirm this is intended for every OpenAI-compatible vendor.
return "https://api.siliconflow.cn/v1"
return None
def _build_siliconflow_voice_key(voice: Voice, model: str) -> str:
def _build_openai_compatible_voice_key(voice: Voice, model: str) -> str:
if voice.voice_key:
return voice.voice_key
if ":" in voice.id:
@@ -65,8 +71,8 @@ def create_voice(data: VoiceCreate, db: Session = Depends(get_db)):
model = data.model
voice_key = data.voice_key
if _is_siliconflow_vendor(vendor):
model = model or SILICONFLOW_DEFAULT_MODEL
if _is_openai_compatible_vendor(vendor):
model = model or OPENAI_COMPATIBLE_DEFAULT_MODEL
if not voice_key:
raw_id = (data.id or data.name).strip()
voice_key = raw_id if ":" in raw_id else f"{model}:{raw_id}"
@@ -115,11 +121,11 @@ def update_voice(id: str, data: VoiceUpdate, db: Session = Depends(get_db)):
update_data["vendor"] = update_data["vendor"].strip()
vendor_for_defaults = update_data.get("vendor", voice.vendor)
if _is_siliconflow_vendor(vendor_for_defaults):
model = update_data.get("model") or voice.model or SILICONFLOW_DEFAULT_MODEL
if _is_openai_compatible_vendor(vendor_for_defaults):
model = update_data.get("model") or voice.model or OPENAI_COMPATIBLE_DEFAULT_MODEL
voice_key = update_data.get("voice_key") or voice.voice_key
update_data["model"] = model
update_data["voice_key"] = voice_key or _build_siliconflow_voice_key(voice, model)
update_data["voice_key"] = voice_key or _build_openai_compatible_voice_key(voice, model)
for field, value in update_data.items():
setattr(voice, field, value)
@@ -152,7 +158,7 @@ def preview_voice(id: str, data: VoicePreviewRequest, db: Session = Depends(get_
raise HTTPException(status_code=400, detail="Preview text cannot be empty")
api_key = (data.api_key or "").strip() or (voice.api_key or "").strip()
if not api_key and _is_siliconflow_vendor(voice.vendor):
if not api_key and _is_openai_compatible_vendor(voice.vendor):
api_key = os.getenv("SILICONFLOW_API_KEY", "").strip()
if not api_key:
raise HTTPException(status_code=400, detail=f"API key is required for voice: {voice.name}")
@@ -161,11 +167,11 @@ def preview_voice(id: str, data: VoicePreviewRequest, db: Session = Depends(get_
if not base_url:
raise HTTPException(status_code=400, detail=f"Base URL is required for voice: {voice.name}")
model = voice.model or SILICONFLOW_DEFAULT_MODEL
model = voice.model or OPENAI_COMPATIBLE_DEFAULT_MODEL
payload = {
"model": model,
"input": text,
"voice": voice.voice_key or _build_siliconflow_voice_key(voice, model),
"voice": voice.voice_key or _build_openai_compatible_voice_key(voice, model),
"response_format": "mp3",
"speed": data.speed if data.speed is not None else voice.speed,
}